# Spaces: Runtime error
| import os | |
| import requests | |
| import gradio as gr | |
| from PIL import Image, ImageDraw, ImageFont | |
| import io | |
| import time | |
| from concurrent.futures import ThreadPoolExecutor | |
# ===== CONFIGURATION =====
# Access token read from the environment; None when the variable is unset.
HF_API_TOKEN = os.environ.get("HF_API_TOKEN")
MODEL_NAME = "stabilityai/stable-diffusion-xl-base-1.0"  # Using SDXL
API_URL = f"https://api-inference.huggingface.co/models/{MODEL_NAME}"
# Attach the Authorization header only when a token is configured; building
# it unconditionally would send the literal string "Bearer None".
headers = {"Authorization": f"Bearer {HF_API_TOKEN}"} if HF_API_TOKEN else {}
WATERMARK_TEXT = "SelamGPT"
MAX_RETRIES = 3  # Number of request attempts made by generate_image
TIMEOUT = 60  # Seconds; increased for SDXL's longer processing
# Shared pool that generate_image submits its HTTP calls to.
EXECUTOR = ThreadPoolExecutor(max_workers=2)
# ===== WATERMARK FUNCTION =====
def add_watermark(image_bytes):
    """Draw WATERMARK_TEXT on an image and return it as a PNG-backed PIL image.

    The input is decoded and converted to RGB, the watermark is drawn near the
    bottom-right corner (white text over a black copy offset by one pixel),
    and the result is encoded to PNG in memory and reopened.

    Args:
        image_bytes: Encoded image data in any format PIL can open.

    Returns:
        A PIL image. If any step of the watermarking fails, the error is
        printed and the input is reopened and returned without a watermark.

    Raises:
        Whatever ``Image.open`` raises when ``image_bytes`` cannot be opened
        at all, because the fallback path itself is not guarded.
    """
    font_size = 24
    margin = 10
    try:
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        draw = ImageDraw.Draw(image)

        try:
            font = ImageFont.truetype("Roboto-Bold.ttf", font_size)
        except OSError:
            # Font file could not be read; fall back to Pillow's built-in font.
            try:
                font = ImageFont.load_default(font_size)
            except TypeError:
                # Older Pillow releases do not accept a size argument here.
                font = ImageFont.load_default()

        text_width = draw.textlength(WATERMARK_TEXT, font=font)
        x = image.width - text_width - margin
        y = image.height - font_size - margin

        # Shadow first, then the text on top of it.
        # NOTE(review): the image is RGB, so the alpha component of the shadow
        # colour presumably has no blending effect - confirm if a translucent
        # shadow was intended.
        draw.text((x + 1, y + 1), WATERMARK_TEXT, font=font, fill=(0, 0, 0, 128))
        draw.text((x, y), WATERMARK_TEXT, font=font, fill=(255, 255, 255))

        # Encode once to PNG and reopen so the returned image is PNG-backed.
        final_buffer = io.BytesIO()
        image.save(final_buffer, format="PNG", optimize=True)
        final_buffer.seek(0)
        return Image.open(final_buffer)
    except Exception as e:
        # Best effort: a watermark failure must not discard the generated image.
        print(f"Watermark error: {str(e)}")
        return Image.open(io.BytesIO(image_bytes))
# ===== IMAGE GENERATION (SDXL-OPTIMIZED) =====
def generate_image(prompt):
    """Request an image for ``prompt`` from the inference API and watermark it.

    The request is posted to API_URL through EXECUTOR and retried up to
    MAX_RETRIES times while the API answers 503, waiting 15s, 30s, ... between
    attempts. No wait is performed after the final attempt.

    Args:
        prompt: Text description of the desired image. ``None``, empty and
            whitespace-only values are rejected without calling the API.

    Returns:
        A ``(image, status)`` tuple. ``image`` is the watermarked PIL image on
        success and ``None`` otherwise; ``status`` is a user-facing message.
    """
    if not prompt or not prompt.strip():
        return None, "⚠️ Please enter a prompt"

    payload = {
        "inputs": prompt,
        "parameters": {
            "height": 1024,  # SDXL's native resolution
            "width": 1024,
            "num_inference_steps": 30,  # Better quality than 25
            "guidance_scale": 7.5,  # SDXL's optimal value
        },
        "options": {"wait_for_model": True},
    }

    def api_call():
        return requests.post(
            API_URL,
            headers=headers,
            json=payload,
            timeout=TIMEOUT,
        )

    for attempt in range(MAX_RETRIES):
        try:
            response = EXECUTOR.submit(api_call).result()

            if response.status_code == 200:
                return add_watermark(response.content), "✔️ Generation successful"

            if response.status_code == 503:
                # Only wait when another attempt will actually follow.
                if attempt + 1 < MAX_RETRIES:
                    wait_time = (attempt + 1) * 15  # Longer wait for SDXL
                    print(f"Model loading, waiting {wait_time}s...")
                    time.sleep(wait_time)
                continue

            return None, f"⚠️ API Error: {response.text[:200]}"
        except requests.Timeout:
            return None, f"⚠️ Timeout: Model took >{TIMEOUT}s to respond"
        except Exception as e:
            return None, f"⚠️ Unexpected error: {str(e)[:200]}"

    return None, "⚠️ Failed after multiple attempts. Please try later."
# ===== GRADIO INTERFACE =====
# The Blocks constructor below receives `theme`, so it must be bound before
# use; the stock Gradio theme is used here.
theme = gr.themes.Default()

with gr.Blocks(theme=theme, title="SelamGPT Image Generator") as demo:
    gr.Markdown("""
    # 🎨 SelamGPT Image Generator
    *Now powered by Stable Diffusion XL (1024x1024 resolution)*
    """)

    with gr.Row():
        # Left column: prompt entry, action buttons and example prompts.
        with gr.Column(scale=3):
            prompt_input = gr.Textbox(
                label="Describe your image",
                placeholder="A futuristic Ethiopian city with flying cars...",
                lines=3,
                max_lines=5,
                elem_id="prompt-box",
            )
            with gr.Row():
                generate_btn = gr.Button("Generate Image", variant="primary")
                clear_btn = gr.Button("Clear")
            gr.Examples(
                examples=[
                    ["An ancient Aksumite warrior in cyberpunk armor, 4k detailed"],
                    ["Traditional Ethiopian coffee ceremony in zero gravity, photorealistic"],
                    ["Portrait of a Habesha queen with golden jewelry, studio lighting"],
                ],
                inputs=prompt_input,
                label="Try these SDXL-optimized prompts:",
            )

        # Right column: generated image and status message.
        with gr.Column(scale=2):
            output_image = gr.Image(
                label="Generated Image (1024x1024)",
                type="pil",  # Force PIL/PNG output
                format="png",  # Explicit PNG format
                height=512,
                elem_id="output-image",
            )
            status_output = gr.Textbox(
                label="Status",
                interactive=False,
                elem_id="status-box",
            )

    # generate_image returns (image, status), matching the two outputs.
    generate_btn.click(
        fn=generate_image,
        inputs=prompt_input,
        outputs=[output_image, status_output],
        queue=True,
        show_progress="minimal",
    )
    # Reset both outputs: no image and an empty status message.
    clear_btn.click(
        fn=lambda: [None, ""],
        outputs=[output_image, status_output],
    )
if __name__ == "__main__":
    # Enable the request queue with max_size=2, then start the server
    # bound to 0.0.0.0 on port 7860.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    demo.queue(max_size=2)
    demo.launch(**launch_options)