app-cjsgck-39 / models.py
Gertie01's picture
Deploy Gradio app with multiple files
89a570a verified
import os
import requests
from io import BytesIO
from PIL import Image
import google.generativeai as genai
from openai import OpenAI
from huggingface_hub import InferenceClient
import gradio as gr
from config import OPENAI_API_KEY, GOOGLE_API_KEY, HF_API_TOKEN, GEMINI_VISION_MODEL_ID, DALL_E_MODEL_ID, SDXL_HF_MODEL_ID
from utils import images_to_gemini_parts
# --- Initialize API Clients ---
# Each client is created only when its credential is configured; otherwise it
# is left as None and the remix functions raise a user-facing gr.Error.
openai_client = OpenAI(api_key=OPENAI_API_KEY) if OPENAI_API_KEY else None

if GOOGLE_API_KEY:
    genai.configure(api_key=GOOGLE_API_KEY)
    gemini_vision_model = genai.GenerativeModel(GEMINI_VISION_MODEL_ID)
else:
    gemini_vision_model = None

# InferenceClient's first positional parameter is `model`, not the token: the
# token must be passed by keyword, otherwise it is taken as a default model id
# and never sent as the auth credential.
hf_client = InferenceClient(token=HF_API_TOKEN) if HF_API_TOKEN else None
def _generate_detailed_prompt_from_images(images: list[Image.Image], user_prompt: str, target_model_desc: str) -> str:
    """
    Build a detailed text-to-image prompt with the configured Gemini vision model.

    The input images (converted by ``images_to_gemini_parts``) are sent together
    with an instruction asking Gemini to describe them and merge that
    description with the user's request into one prompt for ``target_model_desc``.

    Args:
        images: Input images to analyse; may be empty.
        user_prompt: The user's textual request; may be empty or None.
        target_model_desc: Human-readable name of the downstream image model
            (e.g. "DALL-E 3"), interpolated into the instruction.

    Returns:
        The (whitespace-stripped) prompt produced by Gemini, or a simple
        template-based fallback prompt if the Gemini call fails or returns no text.

    Raises:
        gr.Error: If no Gemini model is configured (GOOGLE_API_KEY missing).
    """
    if not gemini_vision_model:
        raise gr.Error("Google API Key not configured. Cannot use Gemini for prompt generation.")

    # Callers may legitimately send images without any prompt; normalise None
    # so the string building below never fails.
    user_prompt = (user_prompt or "").strip()

    gemini_parts = images_to_gemini_parts(images)
    prompt_instruction = (
        "Analyze these images. Describe their key visual elements, styles, and themes. "
        "Then, combine these descriptions with the following user request to create a single, highly detailed, "
        f"creative, and coherent text-to-image prompt suitable for a modern generative AI model like {target_model_desc}. "
        "The goal is to 'remix' the visual ideas from the input images with the user's intent to produce a new, imaginative image. "
        "Focus on generating a creative and specific prompt, not just a literal summary. "
        "Include artistic styles, lighting, composition, and mood where appropriate. "
        # Avoid ending the instruction with a dangling, empty "User request: ".
        "User request: " + (user_prompt or "(none provided; rely on the images alone)")
    )
    gemini_content = gemini_parts + [{"text": prompt_instruction}]

    try:
        response = gemini_vision_model.generate_content(gemini_content)
        generated = (response.text or "").strip()
    except Exception as e:
        # Best-effort: a failed or blocked Gemini call should not abort the
        # remix, so log it and fall through to the template prompt below.
        print(f"Error generating prompt with Gemini: {e}")
        generated = ""

    if generated:
        return generated

    # Fallback to a simpler prompt if Gemini fails or returns nothing usable
    # (an empty prompt would otherwise be rejected by the image model).
    image_desc = f"a creative remix of {len(images)} input images" if images else "an imaginative scene"
    theme = f" based on the theme '{user_prompt}'" if user_prompt else ""
    return f"A vibrant, imaginative image ({image_desc}){theme}."
def remix_with_gpt_image_1(img1: Image.Image | None, img2: Image.Image | None, img3: Image.Image | None, prompt: str) -> Image.Image:
    """
    Remix images and a prompt using the OpenAI image model (DALL-E 3 / "GPT Image-1").

    The OpenAI call is text-only, so when any input image is provided Gemini
    first turns the images plus ``prompt`` into a detailed prompt; without
    images, ``prompt`` is sent unchanged. The model id is ``DALL_E_MODEL_ID``.

    Args:
        img1, img2, img3: Optional input images; None means "not provided".
        prompt: The user's text prompt.

    Returns:
        The generated 1024x1024 image, fully loaded as a PIL image.

    Raises:
        gr.Error: If a required API key is missing, or if generation or the
            image download/decoding fails.
    """
    # Seconds to wait on the image download before giving up.
    download_timeout = 60

    input_images = [img for img in (img1, img2, img3) if img is not None]
    if not openai_client:
        raise gr.Error("OpenAI API Key not configured. Cannot use GPT Image-1 (DALL-E 3).")
    if input_images and not gemini_vision_model:
        raise gr.Error("Google API Key not configured. Gemini Pro Vision is required to describe input images for DALL-E 3 remixing.")

    if input_images:
        print("Generating DALL-E prompt using Gemini-Pro-Vision...")
        dalle_prompt = _generate_detailed_prompt_from_images(input_images, prompt, "DALL-E 3")
    else:
        dalle_prompt = prompt  # If no images, just use the text prompt directly
    print(f"Final DALL-E prompt: {dalle_prompt}")

    try:
        response = openai_client.images.generate(
            model=DALL_E_MODEL_ID,
            prompt=dalle_prompt,
            size="1024x1024",
            quality="standard",
            n=1,
        )
        image_url = response.data[0].url
        # Bound the download so a stalled connection cannot hang the request,
        # and surface HTTP errors instead of handing an error page to PIL.
        download = requests.get(image_url, timeout=download_timeout)
        download.raise_for_status()
        img_data = Image.open(BytesIO(download.content))
        # Image.open is lazy; decode now so corrupt data is reported here.
        img_data.load()
        return img_data
    except Exception as e:
        print(f"Error generating image with DALL-E: {e}")
        raise gr.Error(f"DALL-E image generation failed: {e}") from e
def remix_with_gemini_2(img1: Image.Image | None, img2: Image.Image | None, img3: Image.Image | None, prompt: str) -> Image.Image:
    """
    Remix images and a prompt via Gemini prompt generation plus Hugging Face SDXL.

    Gemini turns the provided images (possibly none) and ``prompt`` into a
    detailed prompt, which is then rendered by ``SDXL_HF_MODEL_ID`` through the
    Hugging Face Inference API.

    Args:
        img1, img2, img3: Optional input images; None means "not provided".
        prompt: The user's text prompt.

    Returns:
        The generated image as a PIL image.

    Raises:
        gr.Error: If a required API credential is missing, or if image
            generation fails.
    """
    input_images = [img for img in (img1, img2, img3) if img is not None]
    if not gemini_vision_model:
        raise gr.Error("Google API Key not configured. Cannot use Gemini-2's prompt generation.")
    if not hf_client:
        raise gr.Error("Hugging Face API Token not configured. Cannot use SDXL for image generation.")

    print("Generating image generation prompt using Gemini-Pro-Vision...")
    sdxl_prompt = _generate_detailed_prompt_from_images(input_images, prompt, "Stable Diffusion XL")
    print(f"Generated SDXL prompt by Gemini: {sdxl_prompt}")

    try:
        # Use Hugging Face Inference API for text-to-image generation
        result = hf_client.text_to_image(sdxl_prompt, model=SDXL_HF_MODEL_ID)
        # InferenceClient.text_to_image already returns a PIL image; passing it
        # to BytesIO raises TypeError. Still accept raw bytes defensively.
        if isinstance(result, Image.Image):
            return result
        img_data = Image.open(BytesIO(result))
        img_data.load()  # Image.open is lazy; decode inside the try
        return img_data
    except Exception as e:
        print(f"Error generating image with Hugging Face SDXL: {e}")
        raise gr.Error(f"SDXL image generation failed: {e}") from e
def remix_images(
    model_choice: str,
    image1: Image.Image | None,
    image2: Image.Image | None,
    image3: Image.Image | None,
    prompt: str,
) -> Image.Image:
    """
    Validate the user's inputs and dispatch the remix to the selected backend.

    Args:
        model_choice: "GPT Image-1" (OpenAI image model) or "Gemini-2"
            (Gemini prompt generation + Hugging Face SDXL).
        image1, image2, image3: Optional input images; None means "not provided".
        prompt: Optional text prompt; None or whitespace-only counts as empty.

    Returns:
        The image produced by the selected backend.

    Raises:
        gr.Error: If neither a prompt nor an image is given, if
            ``model_choice`` is unknown, or if the backend fails (unexpected
            exceptions are wrapped so Gradio can display them).
    """
    # Treat None / whitespace-only prompts as "no prompt" so validation matches
    # what the models actually receive.
    prompt = (prompt or "").strip()
    # Compare against None explicitly (as the backends do) instead of relying
    # on the truthiness of image objects.
    has_images = any(img is not None for img in (image1, image2, image3))

    # At least one image or a prompt is required
    if not prompt and not has_images:
        raise gr.Error("Please provide at least a text prompt or one image to remix.")
    # Images without a prompt are allowed, but suggest adding one for better remixing
    if has_images and not prompt:
        gr.Info("You've provided images but no text prompt. The AI will try to remix based on images alone. Adding a prompt often yields better results!")

    backends = {
        "GPT Image-1": remix_with_gpt_image_1,
        "Gemini-2": remix_with_gemini_2,
    }
    backend = backends.get(model_choice)
    if backend is None:
        raise gr.Error("Invalid model choice.")

    try:
        return backend(image1, image2, image3, prompt)
    except gr.Error:
        # Already user-facing; let Gradio display it unchanged.
        raise
    except Exception as e:
        print(f"An unexpected error occurred during remixing: {e}")
        raise gr.Error(f"An unexpected error occurred: {e}") from e