from constants import IMAGE_PER_CONVERSATION_LIMIT, DEFAULT_SYSTEM_PREAMBLE_TOKEN_COUNT, VISION_COHERE_MODEL_NAME, VISION_MODEL_TOKEN_LIMIT
from prompt_examples import AYA_VISION_PROMPT_EXAMPLES

import base64
from io import BytesIO
from PIL import Image
import logging
import cohere
import os
import traceback
import random
import gradio as gr
from google.cloud.sql.connector import Connector, IPTypes
import pg8000
from datetime import datetime
import sqlalchemy

# from dotenv import load_dotenv
# load_dotenv()

MULTIMODAL_API_KEY = os.getenv("AYA_VISION_API_KEY")
logger = logging.getLogger(__name__)

aya_vision_client = cohere.ClientV2(
    api_key=MULTIMODAL_API_KEY,
    client_name="c4ai-aya-vision-hf-space",
)

def cohere_vision_chat(chat_history, model=VISION_COHERE_MODEL_NAME):
    response = aya_vision_client.chat(
        messages=chat_history,
        model=model,
    )
    return response.message.content[0].text
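
# Example (illustrative only): cohere_vision_chat expects Chat API v2 style
# messages; a minimal text-only turn looks like this. Assumes AYA_VISION_API_KEY is set.
#
#   history = [{"role": "user", "content": [{"type": "text", "text": "Hello!"}]}]
#   reply = cohere_vision_chat(history)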

def get_aya_vision_prompt_example(language):
    example = AYA_VISION_PROMPT_EXAMPLES[language]
    return example[0], example[1]

def get_base64_from_local_file(file_path):
    try:
        logger.debug("loading image")
        with open(file_path, "rb") as image_file:
            base64_image = base64.b64encode(image_file.read()).decode("utf-8")
        return base64_image
    except Exception as e:
        logger.error(f"Error converting local image to base64 string: {e}")
        return None

def get_aya_vision_response(incoming_message, image_filepath, max_size_mb=5):
    max_size_bytes = max_size_mb * 1024 * 1024

    # Infer the MIME type from the file extension.
    image_ext = image_filepath.lower()
    if image_ext.endswith(".jpg") or image_ext.endswith(".jpeg"):
        image_type = "image/jpeg"
    elif image_ext.endswith(".png"):
        image_type = "image/png"
    elif image_ext.endswith(".webp"):
        image_type = "image/webp"
    elif image_ext.endswith(".gif"):
        image_type = "image/gif"
    else:
        raise gr.Error("Unsupported image format. Please upload a JPEG, PNG, WEBP or GIF image.")

    chat_history = []
    logger.debug("converting image to base64")
    base64_image = get_base64_from_local_file(image_filepath)
    image = f"data:{image_type};base64,{base64_image}"

    # Prevent the Cohere API from raising an error on an empty message.
    if not incoming_message:
        incoming_message = "."

    chat_history.append(
        {
            "role": "user",
            "content": [
                {"type": "text", "text": incoming_message},
                {"type": "image_url", "image_url": {"url": image}},
            ],
        }
    )

    image_size_bytes = get_base64_image_size(image)
    if image_size_bytes >= max_size_bytes:
        raise gr.Error(f"Please upload an image with size under {max_size_mb}MB")

    res = aya_vision_client.chat_stream(
        messages=chat_history,
        model=VISION_COHERE_MODEL_NAME,
    )
    output = ""
    for event in res:
        if event and event.type == "content-delta":
            output += event.delta.message.content.text
            yield output
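
# Example (illustrative only): get_aya_vision_response is a generator that yields
# the growing response text, so it can be streamed straight into a Gradio output.
# "sample.jpg" is a hypothetical local file; AYA_VISION_API_KEY must be set.
#
#   for partial in get_aya_vision_response("Describe this image", "sample.jpg"):
#       print(partial, end="\r")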

def get_base64_image_size(base64_string):
    # Strip the data-URL prefix if present, e.g. "data:image/png;base64,".
    if "," in base64_string:
        base64_data = base64_string.split(",", 1)[1]
    else:
        base64_data = base64_string
    base64_data = base64_data.replace("\n", "").replace("\r", "").replace(" ", "")
    # Every 4 base64 characters encode 3 bytes; padding characters carry no data.
    padding = base64_data.count("=")
    size_bytes = (len(base64_data) * 3) // 4 - padding
    return size_bytes
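
# Quick sanity check (illustrative): the estimate matches the decoded payload size.
#
#   encoded = base64.b64encode(b"hello world").decode("utf-8")
#   assert get_base64_image_size(encoded) == len(b"hello world")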

def insert_aya_audio(connection, user_prompt, text_response, audio_response_file_path, input_audio_file_path):
    with connection.begin():
        connection.execute(
            sqlalchemy.text("""
                INSERT INTO aya_audio (user_prompt, text_response, input_audio_file_path, audio_response_file_path, timestamp)
                VALUES (:user_prompt, :text_response, :input_audio_file_path, :audio_response_file_path, :timestamp)
            """),
            {
                "user_prompt": user_prompt,
                "text_response": text_response,
                "input_audio_file_path": input_audio_file_path,
                "audio_response_file_path": audio_response_file_path,
                "timestamp": datetime.now(),
            },
        )

def insert_aya_image(connection, user_prompt, generated_img_desc, image_response_file_path):
    with connection.begin():
        connection.execute(
            sqlalchemy.text("""
                INSERT INTO aya_image (user_prompt, generated_img_desc, image_response_file_path, timestamp)
                VALUES (:user_prompt, :generated_img_desc, :image_response_file_path, :timestamp)
            """),
            {
                "user_prompt": user_prompt,
                "generated_img_desc": generated_img_desc,
                "image_response_file_path": image_response_file_path,
                "timestamp": datetime.now(),
            },
        )

def connect_with_connector() -> sqlalchemy.engine.base.Connection:
    # Cloud SQL connection settings come from the environment.
    instance_connection_name = os.environ["INSTANCE_CONNECTION_NAME"]
    db_user = os.environ["DB_USER"]
    db_pass = os.environ["DB_PASS"]
    db_name = os.environ["DB_NAME"]
    ip_type = IPTypes.PRIVATE if os.environ.get("PRIVATE_IP") else IPTypes.PUBLIC

    connector = Connector(refresh_strategy="LAZY")

    def getconn() -> pg8000.dbapi.Connection:
        conn: pg8000.dbapi.Connection = connector.connect(
            instance_connection_name,
            "pg8000",
            user=db_user,
            password=db_pass,
            db=db_name,
            ip_type=ip_type,
        )
        return conn

    pool = sqlalchemy.create_engine(
        "postgresql+pg8000://",
        creator=getconn,
    )
    # Return an open connection (not the engine), which is what the
    # insert_aya_* helpers above expect.
    connection = pool.connect()
    return connection
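
# Example (illustrative only): logging a generated image description. Assumes the
# Cloud SQL environment variables above are set and the aya_image table exists;
# the file path is hypothetical.
#
#   conn = connect_with_connector()
#   insert_aya_image(conn, "a cat on a sofa", "A tabby cat lying on a grey sofa.", "/tmp/out.png")
#   conn.close()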