File size: 2,724 Bytes
0dfc187
 
 
7a13824
c248483
6298bb9
c20305d
7b9aa9d
13a2046
7b9aa9d
13a2046
33179dc
0dfc187
 
 
 
6298bb9
 
 
0dfc187
 
6298bb9
 
 
 
0dfc187
 
 
 
 
 
 
7b9aa9d
0dfc187
 
6298bb9
 
0dfc187
 
 
 
 
 
 
 
 
 
 
 
6298bb9
 
 
0dfc187
 
 
 
 
 
 
 
 
 
d7f02ec
 
 
 
7b9aa9d
d7f02ec
 
 
 
 
7b9aa9d
7e241d8
5469b0c
 
 
 
 
 
0dfc187
7b9aa9d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import gradio as gr
from gradio_client import Client
from PIL import Image
import os
import time
import traceback

# Hugging Face Space IDs that serve the image model; requests are spread
# across them in the order listed here.
_SPACE_IDS = (
    "hsuwill000/LCM_SoteMix_OpenVINO_CPU_Space_TAESD",
    "HelloSun/LCM_Dreamshaper_v7-int8-ov",
)

# One gradio_client.Client per Space, constructed at import time.
clients = [Client(space_id) for space_id in _SPACE_IDS]

# Next sequence number tried when naming "img_XXXXXXXX.jpg" output files.
count = 0

# Index into `clients` of the client that should serve the next request.
client_index = 0

# Gradio Interface Function to handle image generation
def infer_gradio(prompt: str):
    """Generate an image for *prompt* using the module-level ``clients``.

    Clients are used round-robin starting at ``clients[client_index]``.
    ``client_index`` is advanced on every attempt, successful or not, so a
    client that keeps failing no longer captures every later request. If a
    client raises, the error is printed and the next client is tried, until
    each client has been tried once.

    Args:
        prompt: Text prompt forwarded to the remote ``/infer`` endpoint.

    Returns:
        The generated PIL image, or ``None`` if every client failed or
        ``clients`` is empty. If only saving the local copy fails, the error
        is printed and the image is still returned.
    """
    global client_index

    # Prepare the inputs for the prediction
    inputs = {
        "prompt": prompt,
        "num_inference_steps": 10,  # Number of inference steps for the model
    }

    for _ in range(len(clients)):
        client = clients[client_index]
        # Advance before calling so the next request (or the next retry in
        # this loop) goes to a different client whether or not this one works.
        client_index = (client_index + 1) % len(clients)

        try:
            # Presumably a local file path downloaded by gradio_client —
            # Image.open needs a path or file object.
            result = client.predict(inputs, api_name="/infer")
            image = Image.open(result)
            # Image.open is lazy: decode now so a corrupt result fails here
            # (and falls over to the next client) and the file is released.
            image.load()
        except Exception as e:
            _report_exception(e)
            continue

        try:
            filename = _save_image(image)
        except Exception as e:
            # Saving is best-effort; the user still gets the generated image.
            _report_exception(e)
        else:
            print(f"Saved image as {filename}")
        return image

    return None  # Return nothing if every client failed


def _save_image(image):
    """Save *image* as the next unused ``img_XXXXXXXX.jpg`` and return its name.

    Advances the module-level ``count`` past existing files so earlier images
    are never overwritten.
    """
    global count
    filename = f"img_{count:08d}.jpg"
    while os.path.exists(filename):
        count += 1
        filename = f"img_{count:08d}.jpg"
    # JPEG cannot store alpha or palette modes (e.g. RGBA, P); write an RGB
    # copy so those do not raise, leaving the caller's image untouched.
    to_save = image if image.mode in ("RGB", "L") else image.convert("RGB")
    to_save.save(filename)
    count += 1  # the slot just written is taken; start the next call after it
    return filename


def _report_exception(e):
    """Print *e* and the active stack trace; call only inside an ``except``."""
    print(f"An exception occurred: {str(e)}")
    print("Stack trace:")
    traceback.print_exc()  # Print stack trace for debugging

# Build the web UI: a one-line prompt box with a RUN button beside it, and an
# image panel underneath that shows the generated picture.
with gr.Blocks() as demo:
    with gr.Row():
        prompt_input = gr.Textbox(
            placeholder="Type your prompt for image generation here",
            label="Enter Your Prompt",
            show_label=False,  # label is hidden; the placeholder guides the user
            interactive=True,
            lines=1,
        )
        run_button = gr.Button(value="RUN")

    output_image = gr.Image(label="Generated Image")

    # Clicking RUN passes the prompt text to infer_gradio and shows its result.
    run_button.click(fn=infer_gradio, inputs=prompt_input, outputs=output_image)

# Start the Gradio server; this executes whenever the module is run or imported.
demo.launch()