import os
import requests
from io import BytesIO
from PIL import Image
import google.generativeai as genai
from openai import OpenAI
from huggingface_hub import InferenceClient
import gradio as gr

from config import OPENAI_API_KEY, GOOGLE_API_KEY, HF_API_TOKEN, GEMINI_VISION_MODEL_ID, DALL_E_MODEL_ID, SDXL_HF_MODEL_ID
from utils import images_to_gemini_parts
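
# NOTE: `config` (not shown here) is assumed to expose the API keys and model IDs used
# below. A minimal sketch, assuming env-var based configuration and illustrative
# (not confirmed) default model IDs:
#
#   OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
#   GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
#   HF_API_TOKEN = os.getenv("HF_API_TOKEN")
#   GEMINI_VISION_MODEL_ID = os.getenv("GEMINI_VISION_MODEL_ID", "gemini-pro-vision")
#   DALL_E_MODEL_ID = os.getenv("DALL_E_MODEL_ID", "dall-e-3")
#   SDXL_HF_MODEL_ID = os.getenv("SDXL_HF_MODEL_ID", "stabilityai/stable-diffusion-xl-base-1.0")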

# --- Initialize API Clients ---
openai_client = OpenAI(api_key=OPENAI_API_KEY) if OPENAI_API_KEY else None

if GOOGLE_API_KEY:
    genai.configure(api_key=GOOGLE_API_KEY)
    gemini_vision_model = genai.GenerativeModel(GEMINI_VISION_MODEL_ID)
else:
    gemini_vision_model = None

# Pass the token as a keyword argument: the first positional argument of
# InferenceClient is the model, not the token.
hf_client = InferenceClient(token=HF_API_TOKEN) if HF_API_TOKEN else None
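
# `images_to_gemini_parts` (imported from utils, not shown in this file) is assumed to
# convert PIL images into Gemini content parts, roughly along these lines:
#
#   def images_to_gemini_parts(images):
#       parts = []
#       for img in images:
#           buf = BytesIO()
#           img.save(buf, format="PNG")
#           parts.append({"mime_type": "image/png", "data": buf.getvalue()})
#       return parts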


def _generate_detailed_prompt_from_images(images: list[Image.Image], user_prompt: str, target_model_desc: str) -> str:
    """
    Uses Gemini Pro Vision to generate a detailed prompt for a text-to-image model
    based on input images and a user's textual request.
    """
    if not gemini_vision_model:
        raise gr.Error("Google API Key not configured. Cannot use Gemini for prompt generation.")

    gemini_parts = images_to_gemini_parts(images)
    prompt_instruction = (
        "Analyze these images. Describe their key visual elements, styles, and themes. "
        "Then, combine these descriptions with the following user request to create a single, highly detailed, "
        f"creative, and coherent text-to-image prompt suitable for a modern generative AI model like {target_model_desc}. "
        "The goal is to 'remix' the visual ideas from the input images with the user's intent to produce a new, imaginative image. "
        "Focus on generating a creative and specific prompt, not just a literal summary. "
        "Include artistic styles, lighting, composition, and mood where appropriate. "
        "User request: " + user_prompt
    )
    gemini_content = gemini_parts + [{"text": prompt_instruction}]

    try:
        response = gemini_vision_model.generate_content(gemini_content)
        return response.text
    except Exception as e:
        print(f"Error generating prompt with Gemini: {e}")
        # Fall back to a simpler prompt if Gemini fails
        image_desc = f"a creative remix of {len(images)} input images" if images else "an imaginative scene"
        return f"A vibrant, imaginative image ({image_desc}) based on the theme '{user_prompt}'."


def remix_with_gpt_image_1(img1: Image.Image | None, img2: Image.Image | None, img3: Image.Image | None, prompt: str) -> Image.Image:
    """
    Remixes images and prompt using DALL-E 3 (GPT Image-1).
    Leverages Gemini Pro Vision to generate a detailed DALL-E prompt from the input images.
    """
    input_images = [img for img in [img1, img2, img3] if img is not None]

    if not openai_client:
        raise gr.Error("OpenAI API Key not configured. Cannot use GPT Image-1 (DALL-E 3).")
    if input_images and not gemini_vision_model:
        raise gr.Error("Google API Key not configured. Gemini Pro Vision is required to describe input images for DALL-E 3 remixing.")

    if input_images:
        print("Generating DALL-E prompt using Gemini Pro Vision...")
        dalle_prompt = _generate_detailed_prompt_from_images(input_images, prompt, "DALL-E 3")
    else:
        dalle_prompt = prompt  # If no images were provided, use the text prompt directly
    print(f"Final DALL-E prompt: {dalle_prompt}")

    try:
        response = openai_client.images.generate(
            model=DALL_E_MODEL_ID,
            prompt=dalle_prompt,
            size="1024x1024",
            quality="standard",
            n=1,
        )
        image_url = response.data[0].url
        # Download the generated image from the returned URL
        img_response = requests.get(image_url, timeout=60)
        img_response.raise_for_status()
        return Image.open(BytesIO(img_response.content))
    except Exception as e:
        print(f"Error generating image with DALL-E: {e}")
        raise gr.Error(f"DALL-E image generation failed: {e}")


def remix_with_gemini_2(img1: Image.Image | None, img2: Image.Image | None, img3: Image.Image | None, prompt: str) -> Image.Image:
    """
    Remixes images and prompt using Gemini Pro Vision to generate a detailed prompt,
    then uses a Hugging Face SDXL model for the final image generation.
    """
    input_images = [img for img in [img1, img2, img3] if img is not None]

    if not gemini_vision_model:
        raise gr.Error("Google API Key not configured. Cannot use Gemini-2's prompt generation.")
    if not hf_client:
        raise gr.Error("Hugging Face API Token not configured. Cannot use SDXL for image generation.")

    print("Generating image generation prompt using Gemini Pro Vision...")
    sdxl_prompt = _generate_detailed_prompt_from_images(input_images, prompt, "Stable Diffusion XL")
    print(f"SDXL prompt generated by Gemini: {sdxl_prompt}")

    try:
        # Use the Hugging Face Inference API for text-to-image generation.
        # InferenceClient.text_to_image returns a PIL Image directly, so no
        # byte decoding is needed.
        image = hf_client.text_to_image(sdxl_prompt, model=SDXL_HF_MODEL_ID)
        return image
    except Exception as e:
        print(f"Error generating image with Hugging Face SDXL: {e}")
        raise gr.Error(f"SDXL image generation failed: {e}")


def remix_images(
    model_choice: str,
    image1: Image.Image | None,
    image2: Image.Image | None,
    image3: Image.Image | None,
    prompt: str,
) -> Image.Image:
    """
    Main function to orchestrate image remixing based on the chosen model.
    """
    # At least one image or a prompt is required
    if not prompt and not (image1 or image2 or image3):
        raise gr.Error("Please provide at least a text prompt or one image to remix.")

    # If there are images but no prompt, suggest adding a prompt for better remixing
    if (image1 or image2 or image3) and not prompt:
        gr.Info("You've provided images but no text prompt. The AI will try to remix based on the images alone. Adding a prompt often yields better results!")

    try:
        if model_choice == "GPT Image-1":
            return remix_with_gpt_image_1(image1, image2, image3, prompt)
        elif model_choice == "Gemini-2":
            return remix_with_gemini_2(image1, image2, image3, prompt)
        else:
            raise gr.Error("Invalid model choice.")
    except gr.Error:
        # Re-raise Gradio errors so they are displayed to the user as-is
        raise
    except Exception as e:
        print(f"An unexpected error occurred during remixing: {e}")
        raise gr.Error(f"An unexpected error occurred: {e}")