Spaces: Running on CPU Upgrade
| import os | |
| import base64 | |
| from collections.abc import Iterator | |
| import gradio as gr | |
| from cohere import ClientV2 | |
# Cohere model served by this demo.
model_id = "command-a-vision-07-2025"

# The API key comes from the environment; refuse to start without one.
api_key = os.getenv("COHERE_API_KEY")
if api_key is None or api_key == "":
    raise ValueError("COHERE_API_KEY environment variable is required")
client = ClientV2(client_name="hf-command-a-vision-07-2025", api_key=api_key)

# Extensions treated as images, both for the upload filter and for counting.
IMAGE_FILE_TYPES = (".jpg", ".jpeg", ".png", ".webp")
| def count_files_in_new_message(paths: list[str]) -> int: | |
| image_count = 0 | |
| for path in paths: | |
| if path.endswith(IMAGE_FILE_TYPES): | |
| image_count += 1 | |
| return image_count | |
def validate_media_constraints(message: dict) -> bool:
    """Check the new message's attachments against the per-message image cap.

    Shows a Gradio warning and returns False when more than 10 images are
    attached; otherwise returns True.
    """
    over_limit = count_files_in_new_message(message["files"]) > 10
    if over_limit:
        gr.Warning("Maximum 10 images are supported.")
    return not over_limit
def encode_image_to_base64(image_path: str) -> str:
    """Read an image from disk and return it as a base64 ``data:`` URL.

    The MIME type is chosen from the file extension, ignoring case. Any
    extension not listed below falls back to ``image/jpeg``.
    """
    with open(image_path, "rb") as handle:
        payload = base64.b64encode(handle.read()).decode("utf-8")

    lowered = image_path.lower()
    mime_type = "image/jpeg"  # fallback for unrecognised extensions
    for suffixes, candidate in (
        ((".png",), "image/png"),
        ((".jpg", ".jpeg"), "image/jpeg"),
        ((".webp",), "image/webp"),
    ):
        if lowered.endswith(suffixes):
            mime_type = candidate
            break
    return f"data:{mime_type};base64,{payload}"
def _history_to_cohere_messages(history: list[dict]) -> tuple[list[dict], int]:
    """Convert Gradio "messages"-format history into Cohere chat messages.

    Returns the converted messages and the number of history images that
    could not be read (e.g. files already purged by Gradio's ``delete_cache``).
    """
    messages: list[dict] = []
    missing_images = 0
    for item in history:
        content = item["content"]
        if item["role"] == "assistant":
            messages.append({"role": "assistant", "content": content})
        elif isinstance(content, str):
            messages.append({"role": "user", "content": [{"type": "text", "text": content}]})
        else:
            # Non-string user content is treated as a file message whose first
            # element is the file path (assumes Gradio's tuple form — TODO confirm).
            try:
                url = encode_image_to_base64(content[0])
            except OSError:
                # Uploads can be deleted from Gradio's cache while the chat is
                # still open; drop the stale image instead of failing the turn.
                missing_images += 1
                continue
            # For file-only messages, don't include empty text.
            messages.append({
                "role": "user",
                "content": [{"type": "image_url", "image_url": {"url": url}}],
            })
    return messages, missing_images


def _current_message_content(message: dict) -> list[dict]:
    """Build Cohere content parts for the new turn: text first, then images."""
    content: list[dict] = []
    if message["text"]:
        content.append({"type": "text", "text": message["text"]})
    content.extend(
        {"type": "image_url", "image_url": {"url": encode_image_to_base64(path)}}
        for path in message["files"]
    )
    return content


def _delta_text(event) -> str:
    """Return the text carried by a streamed ``content-delta`` event, or ""."""
    # Walk event.delta.message.content.text tolerating missing/None links, so a
    # sparse event cannot raise AttributeError or add None to the output.
    delta = getattr(event, "delta", None)
    delta_message = getattr(delta, "message", None)
    delta_content = getattr(delta_message, "content", None)
    return getattr(delta_content, "text", None) or ""


def generate(message: dict, history: list[dict], max_new_tokens: int = 512) -> Iterator[str]:
    """Stream a Command A Vision reply to *message*, given the chat *history*.

    Args:
        message: Multimodal textbox value with "text" and "files" keys.
        history: Prior turns in Gradio "messages" format.
        max_new_tokens: Forwarded to Cohere as ``max_tokens``.

    Yields:
        The reply accumulated so far, once per streamed content delta. A single
        "" is yielded when the message breaks the media constraints. If the API
        call fails, a warning is shown and the partial reply is yielded again.
    """
    if not validate_media_constraints(message):
        yield ""
        return

    messages, missing_images = _history_to_cohere_messages(history)
    if missing_images:
        gr.Warning(
            f"{missing_images} earlier image(s) are no longer available "
            "and were left out of the conversation context."
        )

    current_content = _current_message_content(message)
    # Only add the message if there's content.
    if current_content:
        messages.append({"role": "user", "content": current_content})

    # Defined before the try so the error path can keep what was already streamed.
    output = ""
    try:
        response = client.chat_stream(
            model=model_id,
            messages=messages,
            temperature=0.3,
            max_tokens=max_new_tokens,
        )
        for event in response:
            # Only content-delta events carry generated text.
            if getattr(event, "type", None) == "content-delta":
                output += _delta_text(event)
                yield output
    except Exception as e:  # UI boundary: surface any API failure to the user
        gr.Warning(f"Error calling Cohere API: {e}")
        # Keep the partial reply visible instead of blanking the chat bubble.
        yield output
# Example prompts for the chat UI. Each row is a one-element list holding a
# multimodal-textbox value: the prompt text plus image paths relative to the app.
examples = [
    [
        {
            "text": "Write a COBOL function to reverse a string",
            "files": [],
        }
    ],
    [
        {
            "text": "Como sair de um helicóptero que caiu na água?",
            "files": [],
        }
    ],
    [
        {
            "text": "What is the total amount of the invoice with and without tax?",
            "files": ["assets/invoice-1.jpg"],
        }
    ],
    [
        {
            "text": "¿Contra qué modelo gana más Aya Vision 8B?",
            "files": ["assets/aya-vision-win-rates.png"],
        }
    ],
    [
        {
            "text": "Erläutern Sie die Ergebnisse in der Tabelle",
            # NOTE(review): "longbech" must match the asset's on-disk name;
            # confirm before renaming either side.
            "files": ["assets/command-a-longbech-v2.png"],
        }
    ],
    [
        {
            # French prompt: the verb is now French too (was English "Explain").
            "text": "Expliquez la théorie de la relativité en français",
            "files": [],
        }
    ],
]
# Chat UI: a multimodal textbox restricted to image uploads, plus a slider
# feeding generate()'s max_new_tokens argument.
demo = gr.ChatInterface(
    fn=generate,
    type="messages",
    multimodal=True,
    title="Command A Vision",
    textbox=gr.MultimodalTextbox(
        autofocus=True,
        file_count="multiple",
        file_types=list(IMAGE_FILE_TYPES),
    ),
    additional_inputs=[
        gr.Slider(
            label="Max New Tokens",
            value=700,
            minimum=100,
            maximum=2000,
            step=10,
        ),
    ],
    examples=examples,
    cache_examples=False,
    run_examples_on_click=False,
    stop_btn=False,
    css_paths="style.css",
    # Every 1800 s, delete cached temp files older than 1800 s.
    delete_cache=(1800, 1800),
)

if __name__ == "__main__":
    demo.launch()