import io
import os
import time
from concurrent.futures import ThreadPoolExecutor

import gradio as gr
import requests
from PIL import Image, ImageDraw, ImageFont
# ===== CONFIGURATION =====
HF_API_TOKEN = os.environ.get("HF_API_TOKEN")
MODEL_NAME = "stabilityai/stable-diffusion-xl-base-1.0"
API_URL = f"https://api-inference.huggingface.co/models/{MODEL_NAME}"
headers = {"Authorization": f"Bearer {HF_API_TOKEN}"}

WATERMARK_TEXT = "SelamGPT"
MAX_RETRIES = 3
TIMEOUT = 60
EXECUTOR = ThreadPoolExecutor(max_workers=2)
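# The Inference API is called over HTTP with the bearer token above; the small
# worker pool (2 threads) is what caps how many generation requests this Space
# can have in flight at once.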
# ===== WATERMARK FUNCTION =====
def add_watermark(image_bytes):
    """Add a SelamGPT watermark and return an optimized PNG image."""
    try:
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        # "RGBA" draw mode lets the semi-transparent shadow blend onto the RGB image.
        draw = ImageDraw.Draw(image, "RGBA")

        font_size = 24
        try:
            font = ImageFont.truetype("Roboto-Bold.ttf", font_size)
        except OSError:
            # Bundled font not found; Pillow >= 10.1 can scale the built-in font.
            font = ImageFont.load_default(font_size)

        # Bottom-right placement with a 10 px right margin.
        text_width = draw.textlength(WATERMARK_TEXT, font=font)
        x = image.width - text_width - 10
        y = image.height - 34

        # 1 px offset shadow, then the white watermark text on top.
        draw.text((x + 1, y + 1), WATERMARK_TEXT, font=font, fill=(0, 0, 0, 128))
        draw.text((x, y), WATERMARK_TEXT, font=font, fill=(255, 255, 255))

        # Re-encode as optimized PNG ("quality" is a JPEG option and has no effect on PNG).
        img_byte_arr = io.BytesIO()
        image.save(img_byte_arr, format="PNG", optimize=True)
        img_byte_arr.seek(0)
        return Image.open(img_byte_arr)
    except Exception as e:
        print(f"Watermark error: {e}")
        return Image.open(io.BytesIO(image_bytes))
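# Hypothetical standalone check for the watermark helper (assumes a local
# sample.png next to this script; not part of the app flow):
#     with open("sample.png", "rb") as f:
#         add_watermark(f.read()).save("sample_watermarked.png")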
# ===== IMAGE GENERATION =====
def generate_image(prompt):
    if not prompt.strip():
        return None, "⚠️ Please enter a prompt"

    def api_call():
        return requests.post(
            API_URL,
            headers=headers,
            json={
                "inputs": prompt,
                "parameters": {
                    "height": 1024,
                    "width": 1024,
                    "num_inference_steps": 30
                },
                "options": {"wait_for_model": True}
            },
            timeout=TIMEOUT
        )

    for attempt in range(MAX_RETRIES):
        try:
            # Run the blocking HTTP call on the shared worker pool; the
            # requests-level timeout bounds how long we wait for a result.
            future = EXECUTOR.submit(api_call)
            response = future.result()

            if response.status_code == 200:
                return add_watermark(response.content), "✔️ Generation successful"
            elif response.status_code == 503:
                # 503 means the model is still loading; back off 15s, 30s, 45s.
                wait_time = (attempt + 1) * 15
                print(f"Model loading, waiting {wait_time}s...")
                time.sleep(wait_time)
                continue
            else:
                return None, f"⚠️ API Error: {response.text[:200]}"
        except requests.Timeout:
            return None, f"⚠️ Timeout: Model took >{TIMEOUT}s to respond"
        except Exception as e:
            return None, f"⚠️ Unexpected error: {str(e)[:200]}"

    return None, "⚠️ Failed after multiple attempts. Please try later."
# ===== GRADIO THEME =====
theme = gr.themes.Default(
    primary_hue="emerald",
    secondary_hue="amber",
    font=[gr.themes.GoogleFont("Poppins"), "Arial", "sans-serif"]
)
# ===== GRADIO INTERFACE =====
with gr.Blocks(theme=theme, title="SelamGPT Image Generator") as demo:
    gr.Markdown("""
    # 🎨 SelamGPT Image Generator
    *Powered by Stable Diffusion XL (1024x1024 PNG output)*
    """)

    with gr.Row():
        with gr.Column(scale=3):
            prompt_input = gr.Textbox(
                label="Describe your image",
                placeholder="A futuristic Ethiopian city with flying cars...",
                lines=3,
                max_lines=5
            )
            with gr.Row():
                generate_btn = gr.Button("Generate Image", variant="primary")
                clear_btn = gr.Button("Clear")
            gr.Examples(
                examples=[
                    ["An ancient Aksumite warrior in cyberpunk armor, 4k detailed"],
                    ["Traditional Ethiopian coffee ceremony in zero gravity"],
                    ["Portrait of a Habesha queen with golden jewelry"]
                ],
                inputs=prompt_input
            )

        with gr.Column(scale=2):
            output_image = gr.Image(
                label="Generated Image",
                type="pil",
                format="png",
                height=512
            )
            status_output = gr.Textbox(
                label="Status",
                interactive=False
            )

    generate_btn.click(
        fn=generate_image,
        inputs=prompt_input,
        outputs=[output_image, status_output],
        queue=True
    )

    clear_btn.click(
        fn=lambda: [None, ""],
        outputs=[output_image, status_output]
    )
if __name__ == "__main__":
    demo.queue(max_size=2)
    demo.launch(server_name="0.0.0.0", server_port=7860)
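# Local run note (assuming gradio, requests, and Pillow are installed, and the
# file is saved as app.py as Hugging Face Spaces expects):
#     HF_API_TOKEN=hf_xxx python app.py
# The UI is then served at http://localhost:7860.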