File size: 6,561 Bytes
89a570a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import os
import requests
from io import BytesIO
from PIL import Image
import google.generativeai as genai
from openai import OpenAI
from huggingface_hub import InferenceClient
import gradio as gr

from config import OPENAI_API_KEY, GOOGLE_API_KEY, HF_API_TOKEN, GEMINI_VISION_MODEL_ID, DALL_E_MODEL_ID, SDXL_HF_MODEL_ID
from utils import images_to_gemini_parts

# --- Initialize API Clients ---
# Each client is created only when its credential is configured; the remix
# functions below check for None and raise a user-facing gr.Error instead.
openai_client = OpenAI(api_key=OPENAI_API_KEY) if OPENAI_API_KEY else None

if GOOGLE_API_KEY:
    genai.configure(api_key=GOOGLE_API_KEY)
    gemini_vision_model = genai.GenerativeModel(GEMINI_VISION_MODEL_ID)
else:
    gemini_vision_model = None

# InferenceClient's first positional parameter is `model`, not the token.
# Passing the token positionally would treat it as a model id and send
# unauthenticated requests, so it must be given by keyword.
hf_client = InferenceClient(token=HF_API_TOKEN) if HF_API_TOKEN else None


def _generate_detailed_prompt_from_images(images: list[Image.Image], user_prompt: str, target_model_desc: str) -> str:
    """
    Ask the Gemini vision model to write a detailed text-to-image prompt.

    The generated prompt blends the visual ideas of ``images`` with the
    user's request and is tailored to the model named by ``target_model_desc``.

    Args:
        images: Input images to analyze. May be empty, in which case Gemini
            only expands the user's text request.
        user_prompt: The user's textual request. May be empty (or None) when
            images are provided.
        target_model_desc: Human-readable name of the downstream model
            (e.g. "DALL-E 3"), inserted into the instruction.

    Returns:
        The prompt text produced by Gemini. If the Gemini call fails or yields
        no text, a simple templated fallback prompt is returned instead of
        raising, so image generation can still proceed.

    Raises:
        gr.Error: If the Gemini model is not configured.
    """
    if not gemini_vision_model:
        raise gr.Error("Google API Key not configured. Cannot use Gemini for prompt generation.")

    user_prompt = (user_prompt or "").strip()
    gemini_parts = images_to_gemini_parts(images)

    if images:
        task = (
            "Analyze these images. Describe their key visual elements, styles, and themes. "
            "Then, combine these descriptions with the following user request to create a single, highly detailed, "
            f"creative, and coherent text-to-image prompt suitable for a modern generative AI model like {target_model_desc}. "
            "The goal is to 'remix' the visual ideas from the input images with the user's intent to produce a new, imaginative image. "
        )
    else:
        # No images were supplied (remix_with_gemini_2 calls this even then),
        # so don't ask Gemini to analyze images that are not there.
        task = (
            "Expand the following user request into a single, highly detailed, "
            f"creative, and coherent text-to-image prompt suitable for a modern generative AI model like {target_model_desc}. "
        )

    prompt_instruction = (
        task
        + "Focus on generating a creative and specific prompt, not just a literal summary. "
        "Include artistic styles, lighting, composition, and mood where appropriate. "
        # The reply is used verbatim as the image prompt, so discourage preambles.
        "Respond with the prompt text only. "
        "User request: " + (user_prompt or "(none - rely on the images alone)")
    )

    gemini_content = gemini_parts + [{"text": prompt_instruction}]

    try:
        response = gemini_vision_model.generate_content(gemini_content)
        # `.text` may raise (e.g. when the response was blocked); that is
        # handled by the fallback below like any other Gemini failure.
        generated = response.text.strip()
    except Exception as e:
        print(f"Error generating prompt with Gemini: {e}")
        generated = ""

    if generated:
        return generated

    print("Gemini returned no prompt text; using fallback prompt.")
    # Fallback to a simpler prompt if Gemini fails
    image_desc = f"a creative remix of {len(images)} input images" if images else "an imaginative scene"
    theme = user_prompt or "free-form creativity"
    return f"A vibrant, imaginative image ({image_desc}) based on the theme '{theme}'."


def remix_with_gpt_image_1(img1: Image.Image | None, img2: Image.Image | None, img3: Image.Image | None, prompt: str) -> Image.Image:
    """
    Remix up to three images and a text prompt with the OpenAI image model.

    When images are supplied, Gemini vision first converts them (together with
    ``prompt``) into a detailed prompt; with no images, ``prompt`` is sent to
    the OpenAI images endpoint as-is.

    Args:
        img1, img2, img3: Optional input images; ``None`` entries are ignored.
        prompt: The user's text request.

    Returns:
        The generated image as a fully decoded PIL image.

    Raises:
        gr.Error: If a required API key is missing, or if generation,
            download, or decoding of the result fails.
    """
    import base64

    input_images = [img for img in (img1, img2, img3) if img is not None]

    if not openai_client:
        raise gr.Error("OpenAI API Key not configured. Cannot use GPT Image-1 (DALL-E 3).")
    if input_images and not gemini_vision_model:
        raise gr.Error("Google API Key not configured. Gemini Pro Vision is required to describe input images for DALL-E 3 remixing.")

    if input_images:
        print("Generating DALL-E prompt using Gemini-Pro-Vision...")
        dalle_prompt = _generate_detailed_prompt_from_images(input_images, prompt, "DALL-E 3")
    else:
        dalle_prompt = prompt  # If no images, just use the text prompt directly

    print(f"Final DALL-E prompt: {dalle_prompt}")

    try:
        response = openai_client.images.generate(
            model=DALL_E_MODEL_ID,
            prompt=dalle_prompt,
            size="1024x1024",
            quality="standard",
            n=1,
        )
        result = response.data[0]
        if getattr(result, "url", None):
            # Bound the download so a stalled connection cannot hang the
            # request, and surface HTTP errors instead of feeding an error
            # page to PIL.
            download = requests.get(result.url, timeout=60)
            download.raise_for_status()
            raw = download.content
        elif getattr(result, "b64_json", None):
            # Some models / response formats return base64 data instead of a URL.
            raw = base64.b64decode(result.b64_json)
        else:
            raise ValueError("OpenAI response contained neither an image URL nor base64 data.")

        img_data = Image.open(BytesIO(raw))
        # Image.open is lazy; decode now so corrupt data is reported here.
        img_data.load()
        return img_data
    except Exception as e:
        print(f"Error generating image with DALL-E: {e}")
        raise gr.Error(f"DALL-E image generation failed: {e}") from e


def remix_with_gemini_2(img1: Image.Image | None, img2: Image.Image | None, img3: Image.Image | None, prompt: str) -> Image.Image:
    """
    Remix images and a prompt via Gemini prompt-writing plus SDXL rendering.

    Gemini vision writes a detailed prompt from the (possibly empty) set of
    input images and ``prompt``; a Hugging Face SDXL model then renders it.

    Args:
        img1, img2, img3: Optional input images; ``None`` entries are ignored.
        prompt: The user's text request.

    Returns:
        The generated image as a PIL image.

    Raises:
        gr.Error: If a required credential is missing or image generation fails.
    """
    input_images = [img for img in (img1, img2, img3) if img is not None]

    if not gemini_vision_model:
        raise gr.Error("Google API Key not configured. Cannot use Gemini-2's prompt generation.")
    if not hf_client:
        raise gr.Error("Hugging Face API Token not configured. Cannot use SDXL for image generation.")

    print("Generating image generation prompt using Gemini-Pro-Vision...")
    sdxl_prompt = _generate_detailed_prompt_from_images(input_images, prompt, "Stable Diffusion XL")

    print(f"Generated SDXL prompt by Gemini: {sdxl_prompt}")

    try:
        # Use Hugging Face Inference API for text-to-image generation.
        # InferenceClient.text_to_image returns a PIL image, so wrapping it in
        # BytesIO would fail; raw encoded bytes are still accepted defensively.
        result = hf_client.text_to_image(sdxl_prompt, model=SDXL_HF_MODEL_ID)
        if isinstance(result, Image.Image):
            return result
        img_data = Image.open(BytesIO(result))
        # Image.open is lazy; decode now so corrupt data is reported here.
        img_data.load()
        return img_data
    except Exception as e:
        print(f"Error generating image with Hugging Face SDXL: {e}")
        raise gr.Error(f"SDXL image generation failed: {e}") from e


def remix_images(
    model_choice: str,
    image1: Image.Image | None,
    image2: Image.Image | None,
    image3: Image.Image | None,
    prompt: str,
) -> Image.Image:
    """
    Main entry point: dispatch a remix request to the selected backend.

    Args:
        model_choice: Either "GPT Image-1" or "Gemini-2".
        image1, image2, image3: Optional input images (``None`` when empty).
        prompt: The user's text prompt. May be empty if at least one image
            is given; a None or whitespace-only prompt counts as empty.

    Returns:
        The generated image.

    Raises:
        gr.Error: For missing input, an unknown model choice, or any failure
            in the selected backend (unexpected exceptions are wrapped so
            Gradio can display them).
    """
    # Check presence with `is not None`, as the backends do: truth-testing is
    # not a reliable presence check for image objects (e.g. arrays raise).
    has_images = any(img is not None for img in (image1, image2, image3))
    prompt = (prompt or "").strip()

    # At least one image or a prompt is required
    if not prompt and not has_images:
        raise gr.Error("Please provide at least a text prompt or one image to remix.")

    # Images but no prompt: proceed, but suggest adding a prompt for better remixing
    if has_images and not prompt:
        gr.Info("You've provided images but no text prompt. The AI will try to remix based on images alone. Adding a prompt often yields better results!")

    backends = {
        "GPT Image-1": remix_with_gpt_image_1,
        "Gemini-2": remix_with_gemini_2,
    }
    backend = backends.get(model_choice)
    if backend is None:
        raise gr.Error("Invalid model choice.")

    try:
        return backend(image1, image2, image3, prompt)
    except gr.Error:
        # Already user-facing; let Gradio display it unchanged.
        raise
    except Exception as e:
        print(f"An unexpected error occurred during remixing: {e}")
        raise gr.Error(f"An unexpected error occurred: {e}") from e