from fastapi import FastAPI, File, UploadFile, Form, HTTPException
# Keep these if you use them elsewhere in your app (HTML, static files)
# from fastapi.responses import HTMLResponse, FileResponse
# from fastapi.staticfiles import StaticFiles
# from fastapi.templating import Jinja2Templates

# 'requests' removed: the external captioning Space is called via gradio_client
# import requests
import base64  # Not strictly needed for this version; keep only if used elsewhere
import os
import random

# Tokenizer from transformers (used for chat templating / as a fallback)
from transformers import AutoTokenizer

# llama-cpp-python for local CPU inference, huggingface_hub for GGUF downloads
from llama_cpp import Llama  # The core Llama class
from huggingface_hub import hf_hub_download  # For downloading GGUF files

# Gradio client for the external image-captioning Space
from gradio_client import Client, handle_file

# Temporary file handling for uploaded images
import tempfile
# shutil is not needed here; os.remove is sufficient for cleanup
# import shutil

from deep_translator import GoogleTranslator
from deep_translator.exceptions import InvalidSourceOrTargetLanguage

app = FastAPI()
# --- Llama.cpp Language Model Setup (Local CPU Inference) ---
# Hugging Face Hub repository containing the official Qwen1.5 0.5B Chat GGUF files
LLM_MODEL_REPO = "Qwen/Qwen1.5-0.5B-Chat-GGUF"
# Q4_K_M quantization: a good balance of speed and quality on CPU
LLM_MODEL_FILE = "qwen1_5-0_5b-chat-q4_k_m.gguf"
# Original (non-GGUF) model repository, used only to load the tokenizer
ORIGINAL_MODEL_NAME = "Qwen/Qwen1.5-0.5B-Chat"

tokenizer = None  # transformers tokenizer (chat templating / fallback)
llm_model = None  # llama_cpp.Llama instance

# --- Hugging Face Gradio Space Client Setup (External Image Captioning) ---
# Global Gradio client for the captioning Space
caption_client = None
# Gradio Space used for image captioning (BLIP base)
CAPTION_SPACE_URL = "Makhinur/Image-to-Text-Salesforce-blip-image-captioning-base"
# Load the language model (GGUF via llama.cpp) and its tokenizer (from transformers)
def load_language_model():
    global tokenizer, llm_model
    print(f"Loading language model: {LLM_MODEL_FILE} from {LLM_MODEL_REPO}...")
    try:
        # --- Load tokenizer (transformers) from the original model repo ---
        print(f"Loading tokenizer from original model repo: {ORIGINAL_MODEL_NAME}...")
        tokenizer = AutoTokenizer.from_pretrained(ORIGINAL_MODEL_NAME)

        # Set pad_token if not already defined. Qwen tokenizers usually define it,
        # but this check keeps generation/batching behaviour robust.
        if tokenizer.pad_token is None:
            if tokenizer.eos_token is not None:
                tokenizer.pad_token = tokenizer.eos_token
            elif tokenizer.unk_token is not None:
                tokenizer.pad_token = tokenizer.unk_token
            else:
                # Very rare: neither EOS nor UNK exists, so leave pad_token unset
                print("Warning: neither EOS nor UNK token found; leaving pad_token unset.")

        # --- Download the GGUF model file (huggingface_hub) ---
        print(f"Downloading GGUF model file: {LLM_MODEL_FILE} from {LLM_MODEL_REPO}...")
        model_path = hf_hub_download(
            repo_id=LLM_MODEL_REPO,
            filename=LLM_MODEL_FILE,
            # cache_dir="/tmp/hf_cache"  # Optional: custom cache directory
        )
        print(f"GGUF model downloaded to: {model_path}")

        # --- Load the GGUF model (llama-cpp-python) ---
        # n_gpu_layers=0 forces CPU-only inference.
        # n_ctx is the context window; 4096 is a conservative value for CPU use.
        # n_threads should match the available vCPU count (2 here).
        print("Loading GGUF model into llama_cpp...")
        llm_model = Llama(
            model_path=model_path,
            n_gpu_layers=0,  # Explicitly use CPU
            n_ctx=4096,      # Context window size
            n_threads=2      # Use 2 CPU threads
        )
        print("Llama.cpp model loaded successfully.")
    except Exception as e:
        print(f"Error loading language model {LLM_MODEL_REPO}/{LLM_MODEL_FILE}: {e}")
        tokenizer = None
        llm_model = None  # Ensure the model is None if loading fails
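# Optional smoke test (a minimal sketch, not wired into the app): call this by
# hand after startup to confirm the GGUF model answers. The prompt below is
# arbitrary and only for illustration.
def _debug_llm_smoke_test() -> str:
    if llm_model is None:
        return "LLM not loaded"
    out = llm_model.create_chat_completion(
        messages=[{"role": "user", "content": "Say hello in one short sentence."}],
        max_tokens=32,
        temperature=0.7,
    )
    return out["choices"][0]["message"]["content"]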
# Initialize the Gradio client for the external captioning Space
def initialize_caption_client():
    global caption_client
    print(f"Initializing Gradio client for {CAPTION_SPACE_URL}...")
    try:
        # If the target Space is private, store HF_TOKEN as a Space secret
        # and pass it to the client:
        # HF_TOKEN = os.environ.get("HF_TOKEN")
        # if HF_TOKEN:
        #     print("Using HF_TOKEN for Gradio client.")
        #     caption_client = Client(CAPTION_SPACE_URL, hf_token=HF_TOKEN)
        # else:
        #     print("HF_TOKEN not found. Initializing public Gradio client.")
        #     caption_client = Client(CAPTION_SPACE_URL)

        # The captioning Space is assumed to be public
        caption_client = Client(CAPTION_SPACE_URL)
        print("Gradio client initialized successfully.")
    except Exception as e:
        print(f"Error initializing Gradio client for {CAPTION_SPACE_URL}: {e}")
        # Leave the client as None so the endpoint can detect it and return an error
        caption_client = None
# Load models and initialize clients when the app starts
@app.on_event("startup")
async def startup_event():
    # Load the language model (Qwen1.5 0.5B GGUF via llama.cpp)
    load_language_model()
    # Initialize the client for the external captioning Space
    initialize_caption_client()
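# Note: on newer FastAPI versions, @app.on_event("startup") is deprecated in
# favour of a lifespan handler. A minimal sketch of the equivalent (not used
# here; it would replace the on_event hook and the FastAPI() call above):
#
#   from contextlib import asynccontextmanager
#
#   @asynccontextmanager
#   async def lifespan(app: FastAPI):
#       load_language_model()
#       initialize_caption_client()
#       yield  # shutdown cleanup could go after the yield
#
#   app = FastAPI(lifespan=lifespan)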
# --- Image Captioning Function (gradio_client + temporary file) ---
def generate_image_caption(image_file: UploadFile):
    """
    Generate a caption for the uploaded image via the external Gradio Space API.

    Reads the uploaded file's content, saves it to a temporary file, and passes
    the temporary file's path (via handle_file) to the API call.
    """
    if caption_client is None:
        # The client failed to initialize at startup
        error_msg = "Gradio caption client not initialized. Cannot generate caption."
        print(error_msg)
        return f"Error: {error_msg}"

    temp_file_path = None  # Path of the temporary file, if created
    try:
        print(f"Attempting to generate caption for file: {image_file.filename}")
        # Read the uploaded file's content; seek to the beginning in case the
        # file-like object's pointer was moved.
        image_file.file.seek(0)
        image_bytes = image_file.file.read()

        # Write the bytes to a temporary file on the local filesystem.
        # delete=False keeps the file after the handle is closed, and the suffix
        # hints at the file type for the Gradio API.
        suffix = os.path.splitext(image_file.filename or "")[1] or ".jpg"
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
        temp_file.write(image_bytes)
        temp_file.close()  # Close the handle so gradio_client can read the file
        temp_file_path = temp_file.name
        print(f"Saved uploaded file temporarily to: {temp_file_path}")

        # handle_file() prepares the local file path for the Gradio API input.
        prepared_input = handle_file(temp_file_path)

        # api_name="/predict" matches the endpoint documented by the Space.
        caption = caption_client.predict(img=prepared_input, api_name="/predict")
        print("Caption generated successfully.")
        # Return the caption string received from the API
        return caption
    except Exception as e:
        # Any failure while reading, writing, or calling the API ends up here.
        print(f"Error during caption generation API call: {e}")
        # Return a structured error string including the exception type and message
        return f"Error: Unable to generate caption from API. Details: {type(e).__name__}: {e}"
    finally:
        # Clean up the temporary file whether or not the call succeeded
        if temp_file_path and os.path.exists(temp_file_path):
            print(f"Cleaning up temporary file: {temp_file_path}")
            try:
                os.remove(temp_file_path)
            except OSError as e:
                print(f"Error removing temporary file {temp_file_path}: {e}")
# --- Language Model Story Generation (Qwen1.5 0.5B via llama.cpp) ---
def generate_story_qwen_0_5b(prompt_text: str, max_new_tokens: int = 300,
                             temperature: float = 0.7, top_p: float = 0.9,
                             top_k: int = 50) -> str:
    """
    Generate text with the loaded Qwen1.5 0.5B model via llama.cpp.

    The chat template is applied by llama.cpp's create_chat_completion; the
    transformers tokenizer is only checked for availability here (a manual
    templating alternative is sketched after this function).
    """
    # Both the tokenizer and the llama.cpp model must have loaded at startup.
    if tokenizer is None or llm_model is None:
        # RuntimeError is caught by the calling endpoint
        raise RuntimeError("Language model (llama.cpp) or tokenizer not loaded.")

    # Qwen1.5 Chat uses a standard ChatML-like message format.
    messages = [
        # An optional system message can guide the model's persona/style:
        # {"role": "system", "content": "You are a helpful and creative assistant."},
        {"role": "user", "content": prompt_text}
    ]
    try:
        print("Calling llama.cpp create_chat_completion for Qwen 0.5B...")
        # create_chat_completion applies the chat template internally.
        # max_tokens limits the number of generated tokens; temperature, top_p
        # and top_k control sampling.
        response = llm_model.create_chat_completion(
            messages=messages,
            max_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            stream=False  # Return the full response at once
        )
        print("Llama.cpp completion received for Qwen 0.5B.")

        # The response mirrors OpenAI's chat completion structure.
        if response and response.get('choices') and len(response['choices']) > 0:
            story = response['choices'][0].get('message', {}).get('content', '')
        else:
            print("Warning: unexpected llama.cpp response structure or missing content.")
            story = ""
    except Exception as e:
        print(f"Llama.cpp Qwen 0.5B inference failed: {e}")
        # Re-raise as RuntimeError so the endpoint can map it to an HTTP error
        raise RuntimeError(f"Llama.cpp inference failed: {type(e).__name__}: {e}")

    # Return the generated story, stripped of leading/trailing whitespace
    return story.strip()
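# Manual templating alternative (a minimal sketch, not used by the endpoint):
# the transformers tokenizer can build the ChatML prompt explicitly, which is
# then passed to llama.cpp's raw completion call. The sampling values below
# mirror the defaults above and are only illustrative.
def _generate_story_manual_template(prompt_text: str, max_new_tokens: int = 300) -> str:
    if tokenizer is None or llm_model is None:
        raise RuntimeError("Language model (llama.cpp) or tokenizer not loaded.")
    prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": prompt_text}],
        tokenize=False,
        add_generation_prompt=True,
    )
    out = llm_model(prompt, max_tokens=max_new_tokens, temperature=0.7, top_p=0.9)
    return out["choices"][0]["text"].strip()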
# --- FastAPI Endpoint for Story Generation ---
@app.post("/generate-story/")
async def generate_story_endpoint(image_file: UploadFile = File(...), language: str = Form(...)):
    # Choose a random theme for the story prompt
    story_theme = random.choice([
        'an adventurous journey', 'a mysterious encounter', 'a heroic quest',
        'a magical adventure', 'a thrilling escape', 'an unexpected discovery',
        'a dangerous mission', 'a romantic escapade', 'an epic battle',
        'a journey into the unknown'
    ])

    # Step 1: Get an image caption from the external Gradio Space
    caption = generate_image_caption(image_file)
    # The captioning helper returns an "Error: ..." string on failure
    if caption.startswith("Error:"):
        print(f"Caption generation failed: {caption}")
        raise HTTPException(status_code=500, detail=caption)

    # Step 2: Build the prompt for the language model, incorporating the caption
    prompt_text = (
        f"Write a detailed story that is approximately 300 words long. "
        f"Ensure the story has a clear beginning, middle, and end about {story_theme}. "
        f"Incorporate the following details from an image description into the story: {caption}\n\nStory:"
    )
    # Step 3: Generate the story with the local Qwen 0.5B model via llama.cpp
    try:
        story = generate_story_qwen_0_5b(
            prompt_text,
            max_new_tokens=350,  # Request roughly 300 words of output
            temperature=0.7,     # Sampling parameters
            top_p=0.9,
            top_k=50
        )
        story = story.strip()  # Basic cleanup of the generated text
    except RuntimeError as e:
        # Raised by generate_story_qwen_0_5b when model loading or inference fails
        print(f"Language model generation error: {e}")
        # 503: the language model is unavailable or failed
        raise HTTPException(status_code=503, detail=f"Story generation failed (LLM): {e}")
    except Exception as e:
        # Any other unexpected error during story generation
        print(f"An unexpected error occurred during story generation: {e}")
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred during story generation: {type(e).__name__}: {e}")
    # Step 4: Translate the story if the target language is not English
    if language and language.lower() != "english":
        try:
            # GoogleTranslator from deep_translator, English source to the requested target
            translator = GoogleTranslator(source='english', target=language.lower())
            translated_story = translator.translate(story)
            # A None or empty result indicates a failed translation
            if not translated_story:
                print(f"Translation returned None or empty string for language: {language}")
                # Fall back to the English story with a warning for the client
                return {"story": story + "\n\n(Note: Automatic translation to your requested language failed.)"}
            story = translated_story
        except InvalidSourceOrTargetLanguage:
            print(f"Invalid target language requested: {language}")
            raise HTTPException(status_code=400, detail=f"Invalid target language: {language}")
        except Exception as e:
            # Other translation failures (network issues, API problems, ...)
            print(f"Translation failed for language {language}: {e}")
            raise HTTPException(status_code=500, detail=f"Translation failed: {type(e).__name__}: {e}")

    # Step 5: Return the final (possibly translated) story as JSON
    return {"story": story}
# --- Optional: Serve a simple HTML form for testing ---
# To use this, uncomment the imports related to HTMLResponse, StaticFiles,
# Jinja2Templates, and Request below (and at the top of the file), and either
# create a 'templates' directory with an 'index.html' file or use the inline
# HTML shown here.
# from fastapi import Request
# from fastapi.templating import Jinja2Templates
# from fastapi.staticfiles import StaticFiles
# templates = Jinja2Templates(directory="templates")
# app.mount("/static", StaticFiles(directory="static"), name="static")
#
# @app.get("/", response_class=HTMLResponse)
# async def read_root(request: Request):
#     # Simple HTML form to upload an image and specify the target language
#     html_content = """
#     <!DOCTYPE html>
#     <html>
#     <head><title>Story Generator</title></head>
#     <body>
#         <h1>Generate a Story from an Image</h1>
#         <form action="/generate-story/" method="post" enctype="multipart/form-data">
#             <input type="file" name="image_file" accept="image/*" required><br><br>
#             Target Language (e.g., english, french, spanish): <input type="text" name="language" value="english"><br><br>
#             <button type="submit">Generate Story</button>
#         </form>
#     </body>
#     </html>
#     """
#     # If using templates: return templates.TemplateResponse("index.html", {"request": request})
#     return HTMLResponse(content=html_content)  # Direct HTML keeps things simple if templates are not set up