import gradio as gr
import spaces
import time
import os
import torch
from PIL import Image
from threading import Thread
from transformers import TextIteratorStreamer, AutoConfig, AutoModelForCausalLM

from constants import (
    IMAGE_TOKEN_INDEX,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
)
from conversation import conv_templates
from eval_utils import load_maya_model
from utils import disable_torch_init
from mm_utils import tokenizer_image_token, process_images
from huggingface_hub._login import _login

# Import LLaVA modules to register model types
from model import *
from model.language_model.llava_cohere import LlavaCohereForCausalLM, LlavaCohereConfig

# Register model type and config
AutoConfig.register("llava_cohere", LlavaCohereConfig)
AutoModelForCausalLM.register(LlavaCohereConfig, LlavaCohereForCausalLM)

hf_token = os.getenv("hf_token")
_login(token=hf_token, add_to_git_credential=False)
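
# NOTE (assumption): the Space is expected to provide an "hf_token" secret in its
# environment; without it the login call above fails and gated checkpoints such as
# the Aya base weights cannot be downloaded.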

# Global Variables
MODEL_BASE = "CohereForAI/aya-23-8B"
MODEL_PATH = "maya-multimodal/maya"
MODE = "finetuned"


def load_model():
    """Load the Maya model and required components"""
    model, tokenizer, image_processor, _ = load_maya_model(
        MODEL_BASE, MODEL_PATH, None, MODE
    )
    model = model.cuda()
    model.eval()
    return model, tokenizer, image_processor


# Load model globally
print("Loading model...")
model, tokenizer, image_processor = load_model()
print("Model loaded successfully!")


def validate_image_file(image_path):
    """Validate that the image file exists and is in a supported format."""
    if not os.path.isfile(image_path):
        raise gr.Error(f"Error: File {image_path} does not exist.")
    try:
        with Image.open(image_path) as img:
            img.verify()
        return True
    except (IOError, SyntaxError) as e:
        raise gr.Error(f"Error: {image_path} is not a valid image file. {e}")


def process_chat_stream(message, history):
    print(message)
    print("History:", history)

    image = None  # Initialize image variable first

    # First try to get an image from the current message
    # (gr.MultimodalTextbox delivers uploads under message["files"])
    if message.get("files", []):
        current_files = message["files"]
        if current_files:
            last_file = current_files[-1]
            image = last_file["path"] if isinstance(last_file, dict) else last_file

    # If no image in the current message, try to get one from the history
    # (file messages appear either as (filepath, ...) tuples or as dicts with "files")
    if image is None and history:
        for hist in reversed(history):
            print("Processing history item:", hist)
            if isinstance(hist["content"], tuple):
                image = hist["content"][0]
                break
            elif isinstance(hist["content"], dict) and hist["content"].get("files"):
                hist_files = hist["content"]["files"]
                if hist_files:
                    first_file = hist_files[0]
                    image = first_file["path"] if isinstance(first_file, dict) else first_file
                    break

    # Check if we found an image
    if image is None:
        raise gr.Error("Please upload an image to start the conversation.")

    # Validate and load the image
    validate_image_file(image)
    image = Image.open(image).convert("RGB")

    # Preprocess the image for the model
    image_tensor = process_images([image], image_processor, model.config)
    if image_tensor is None:
        raise gr.Error("Failed to process image")
    image_tensor = image_tensor.cuda()

    # Prepare the conversation template
    conv = conv_templates["aya"].copy()

    # Add conversation history
    for hist in history:
        # Handle user messages
        if hist["role"] == "user":
            # Extract text content based on format
            if isinstance(hist["content"], str):
                human_text = hist["content"]
            elif isinstance(hist["content"], tuple):
                human_text = hist["content"][1] if len(hist["content"]) > 1 else ""
            else:
                human_text = hist["content"]
            conv.append_message(conv.roles[0], human_text)
        # Handle assistant messages
        elif hist["role"] == "assistant":
            conv.append_message(conv.roles[1], hist["content"])

    # Format the current message; the image placeholder token is inserted only on
    # the first turn, since the image itself is passed to generate() separately and
    # tokenizer_image_token later replaces the placeholder with IMAGE_TOKEN_INDEX
    current_message = message["text"]
    if not history:
        if model.config.mm_use_im_start_end:
            current_message = f"{DEFAULT_IM_START_TOKEN}{DEFAULT_IMAGE_TOKEN}{DEFAULT_IM_END_TOKEN}\n{current_message}"
        else:
            current_message = f"{DEFAULT_IMAGE_TOKEN}\n{current_message}"

    # Add the current message to the conversation, leaving the assistant slot empty
    conv.append_message(conv.roles[0], current_message)
    conv.append_message(conv.roles[1], None)

    # Get the prompt and ensure input_ids are properly created
    prompt = conv.get_prompt()
    # print("PROMPT: ", prompt)

    try:
        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
        if input_ids is None:
            raise ValueError("Tokenization returned None")
        # Ensure input_ids is a 2D (batch, seq_len) tensor
        if len(input_ids.shape) == 1:
            input_ids = input_ids.unsqueeze(0)
        input_ids = input_ids.cuda()

        # Validate the vision tower and image tensor before starting generation
        if not hasattr(model, 'get_vision_tower') or model.get_vision_tower() is None:
            raise ValueError("Model's vision tower is not properly initialized")
        if image_tensor is None:
            raise ValueError("Image tensor is None")

        # Set up the streamer and generation arguments
        streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True)
        generation_kwargs = {
            "inputs": input_ids,
            "images": image_tensor,
            "image_sizes": [image.size],
            "streamer": streamer,
            "temperature": 0.3,
            "do_sample": True,
            "top_p": 0.9,
            "num_beams": 1,
            "max_new_tokens": 4096,
            "use_cache": True,
        }

        # Run generation in a background thread so tokens can be streamed as they
        # arrive. Note: an exception raised inside this thread is only printed to
        # the logs; it does not propagate to the Gradio UI.
        def generate_with_error_handling():
            try:
                model.generate(**generation_kwargs)
            except Exception as e:
                import traceback
                error_msg = f"Generation error: {str(e)}\nTraceback:\n{traceback.format_exc()}"
                raise gr.Error(error_msg)

        thread = Thread(target=generate_with_error_handling)
        thread.start()
    except Exception as e:
        import traceback
        error_msg = f"Setup error: {str(e)}\nTraceback:\n{traceback.format_exc()}"
        raise gr.Error(error_msg)

    partial_message = ""
    for new_token in streamer:
        partial_message += new_token
        time.sleep(0.1)
        yield {"role": "assistant", "content": partial_message}

# Create Gradio interface
chatbot = gr.Chatbot(
    show_label=False,
    height=450,
    show_share_button=False,
    show_copy_button=False,
    avatar_images=None,
    container=True,
    render_markdown=True,
    scale=1,
    type="messages",
)
chat_input = gr.MultimodalTextbox(
    interactive=True,
    file_types=["image"],
    placeholder="Enter message or upload file...",
    show_label=False,
)

with gr.Blocks(fill_height=True) as demo:
    gr.ChatInterface(
        fn=process_chat_stream,
        title="Maya: Multilingual Multimodal Model",
        examples=[
            {"text": "Describe this photo in detail.", "files": ["./asian_food.jpg"]},
            {"text": "What is the name of this famous sight in the photo?", "files": ["./hawaii.jpg"]},
        ],
        description=(
            "Upload an image and start chatting about it, or simply try one of the examples below. "
            "If you don't upload an image, you will receive an error. "
            "[Read the research paper](https://huggingface.co/papers/2412.07112)\n\nTeam π Maya"
        ),
        stop_btn="Stop Generation",
        multimodal=True,
        textbox=chat_input,
        chatbot=chatbot,
    )

if __name__ == "__main__":
    demo.queue(api_open=False)
    demo.launch(show_api=False, share=False)
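
# To run this script locally (a sketch; assumes the Maya repo modules -- constants,
# conversation, eval_utils, utils, mm_utils, model -- are importable and a CUDA GPU
# is available):
#
#   export hf_token=<your-hugging-face-token>
#   python app.py   # or whatever name this file is saved under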