Spaces:
Build error
Build error
| import asyncio | |
| import base64 | |
| import os | |
| import time | |
| from io import BytesIO | |
| import gradio as gr | |
| import numpy as np | |
| import websockets | |
| from dotenv import load_dotenv | |
| from fastrtc import ( | |
| AsyncAudioVideoStreamHandler, | |
| Stream, | |
| WebRTC, | |
| get_cloudflare_turn_credentials_async, | |
| wait_for_item, | |
| ) | |
| from google import genai | |
| from gradio.utils import get_space | |
| from PIL import Image | |
| load_dotenv() | |
| def encode_audio(data: np.ndarray) -> dict: | |
| """Encode Audio data to send to the server""" | |
| return { | |
| "mime_type": "audio/pcm", | |
| "data": base64.b64encode(data.tobytes()).decode("UTF-8"), | |
| } | |
| def encode_image(data: np.ndarray) -> dict: | |
| with BytesIO() as output_bytes: | |
| pil_image = Image.fromarray(data) | |
| pil_image.save(output_bytes, "JPEG") | |
| bytes_data = output_bytes.getvalue() | |
| base64_str = str(base64.b64encode(bytes_data), "utf-8") | |
| return {"mime_type": "image/jpeg", "data": base64_str} | |
| class GeminiHandler(AsyncAudioVideoStreamHandler): | |
| def __init__( | |
| self, | |
| ) -> None: | |
| super().__init__( | |
| "mono", | |
| output_sample_rate=24000, | |
| input_sample_rate=16000, | |
| ) | |
| self.audio_queue = asyncio.Queue() | |
| self.video_queue = asyncio.Queue() | |
| self.session = None | |
| self.last_frame_time = 0 | |
| self.quit = asyncio.Event() | |
| def copy(self) -> "GeminiHandler": | |
| return GeminiHandler() | |
| async def start_up(self): | |
| await self.wait_for_args() | |
| api_key = self.latest_args[3] | |
| client = genai.Client( | |
| api_key=api_key, http_options={"api_version": "v1alpha"} | |
| ) | |
| config = {"response_modalities": ["AUDIO"], "system_instruction": "You are an art history teacher that will describe the artwork passed in as an image to the user. Describe the history and significance of the artwork."} | |
| async with client.aio.live.connect( | |
| model="gemini-2.0-flash-exp", | |
| config=config, # type: ignore | |
| ) as session: | |
| self.session = session | |
| while not self.quit.is_set(): | |
| turn = self.session.receive() | |
| try: | |
| async for response in turn: | |
| if data := response.data: | |
| audio = np.frombuffer(data, dtype=np.int16).reshape(1, -1) | |
| self.audio_queue.put_nowait(audio) | |
| except websockets.exceptions.ConnectionClosedOK: | |
| print("connection closed") | |
| break | |
| async def video_receive(self, frame: np.ndarray): | |
| self.video_queue.put_nowait(frame) | |
| if self.session: | |
| # send image every 1 second | |
| print(time.time() - self.last_frame_time) | |
| if time.time() - self.last_frame_time > 1: | |
| self.last_frame_time = time.time() | |
| await self.session.send(input=encode_image(frame)) | |
| if self.latest_args[2] is not None: | |
| await self.session.send(input=encode_image(self.latest_args[2])) | |
| async def video_emit(self): | |
| frame = await wait_for_item(self.video_queue, 0.01) | |
| if frame is not None: | |
| return frame | |
| else: | |
| return np.zeros((100, 100, 3), dtype=np.uint8) | |
| async def receive(self, frame: tuple[int, np.ndarray]) -> None: | |
| _, array = frame | |
| array = array.squeeze() | |
| audio_message = encode_audio(array) | |
| if self.session: | |
| await self.session.send(input=audio_message) | |
| async def emit(self): | |
| array = await wait_for_item(self.audio_queue, 0.01) | |
| if array is not None: | |
| return (self.output_sample_rate, array) | |
| return array | |
| async def shutdown(self) -> None: | |
| if self.session: | |
| self.quit.set() | |
| await self.session.close() | |
| self.quit.clear() | |
| stream = Stream( | |
| handler=GeminiHandler(), | |
| modality="audio-video", | |
| mode="send-receive", | |
| rtc_configuration=get_cloudflare_turn_credentials_async if get_space() else None, | |
| time_limit=180 if get_space() else None, | |
| additional_inputs=[ | |
| gr.Markdown( | |
| "## 🎨 Art History Teacher\n\n" | |
| "Provide an image of the artwork and Gemini will describe it to you." | |
| "To get a Gemini API key, please visit the [Gemini API Key](https://console.cloud.google.com/apis/api/generativelanguage.googleapis.com/credentials) page." | |
| ), | |
| gr.Image(label="Image", type="numpy", sources=["upload", "clipboard"]), | |
| gr.Textbox(label="Gemini API Key"), | |
| ], | |
| ui_args={ | |
| "icon": "https://www.gstatic.com/lamda/images/gemini_favicon_f069958c85030456e93de685481c559f160ea06b.png", | |
| "pulse_color": "rgb(255, 255, 255)", | |
| "icon_button_color": "rgb(255, 255, 255)", | |
| "title": "Gemini Audio Video Chat", | |
| }, | |
| ) | |
| if __name__ == "__main__": | |
| stream.ui.launch() | |