Spaces:
Runtime error
Runtime error
| # Copyright 2025 Google LLC | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import os | |
| import struct | |
| import re | |
| import logging | |
| import io | |
| from typing import Optional, Dict, Tuple, Union | |
| import google.generativeai as genai | |
| from pydub import AudioSegment | |
| from cache import cache | |
# --- Constants ---
# API key and feature flag are read from the process environment once, at
# import time.
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
# Speech generation is enabled only when the variable equals "true"
# (compared case-insensitively); unset or any other value disables it.
GENERATE_SPEECH = "true" == os.getenv("GENERATE_SPEECH", "false").lower()
# Model name handed to genai.GenerativeModel for speech requests.
TTS_MODEL = "gemini-2.5-flash-preview-tts"
# MIME type assumed for audio that arrives without one.
DEFAULT_RAW_AUDIO_MIME = "audio/L16;rate=24000"

# --- Setup ---
# Both calls below are import-time side effects: root logger configuration
# and Gemini client configuration.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
genai.configure(api_key=GEMINI_API_KEY)
class TTSGenerationError(Exception):
    """Signals that audio could not be produced by Gemini TTS.

    Raised by the synthesis code when speech generation is disabled, when the
    API request fails, or when the response contains no audio data.
    """
def parse_audio_mime_type(mime_type: str) -> Dict[str, int]:
    """
    Extract bits_per_sample and sampling rate from a raw-audio MIME string.

    e.g. "audio/L16;rate=24000" → {"bits_per_sample": 16, "rate": 24000}

    Parsing is best-effort and never raises for malformed input: the MIME
    string is split on ";" and each piece is inspected case-insensitively.
    Any value that is missing, not an integer, or not strictly positive is
    ignored and the default (16 bits, 24000 Hz) is kept, so the result can
    always be used to build an audio header.

    Args:
        mime_type: MIME string such as "audio/L16;rate=24000". An empty
            string or None yields the defaults.

    Returns:
        A dict with exactly the keys "bits_per_sample" and "rate".
    """
    parsed = {"bits_per_sample": 16, "rate": 24000}
    if not mime_type:
        return parsed

    for param in mime_type.split(";"):
        param = param.strip().lower()
        if param.startswith("rate="):
            field, raw_value = "rate", param.split("=", 1)[1]
        elif re.match(r"audio/l\d+", param):
            # "audio/l16" -> "16"; the text before the first "l" is "audio/".
            field, raw_value = "bits_per_sample", param.split("l", 1)[1]
        else:
            continue

        try:
            value = int(raw_value)
        except ValueError:
            # Unparseable number (e.g. "rate=abc"): keep the current value.
            continue

        # A zero or negative rate / bit depth cannot describe real audio and
        # would yield an invalid header downstream, so treat it as absent.
        if value > 0:
            parsed[field] = value

    return parsed
def convert_to_wav(audio_data: bytes, mime_type: str) -> bytes:
    """Prefix raw PCM bytes with a 44-byte RIFF/WAVE header (single channel).

    The sample rate and sample width are taken from ``mime_type`` via
    ``parse_audio_mime_type``; the payload itself is not modified.

    Args:
        audio_data: Raw PCM sample bytes.
        mime_type: MIME string describing the PCM data,
            e.g. "audio/L16;rate=24000".

    Returns:
        The header followed by ``audio_data``.
    """
    audio_format = parse_audio_mime_type(mime_type)
    sample_rate = audio_format["rate"]
    sample_bits = audio_format["bits_per_sample"]

    channels = 1
    # Bytes occupied by one sample frame across all channels.
    frame_bytes = channels * (sample_bits // 8)
    payload_len = len(audio_data)

    # RIFF descriptor: the size field counts everything after itself, i.e.
    # the remaining 36 header bytes plus the payload.
    riff_chunk = struct.pack("<4sI4s", b"RIFF", 36 + payload_len, b"WAVE")
    # "fmt " sub-chunk: 16-byte body, audio format tag 1.
    fmt_chunk = struct.pack(
        "<4sIHHIIHH",
        b"fmt ",
        16,
        1,
        channels,
        sample_rate,
        sample_rate * frame_bytes,
        frame_bytes,
        sample_bits,
    )
    # "data" sub-chunk header, immediately followed by the samples.
    data_header = struct.pack("<4sI", b"data", payload_len)

    return riff_chunk + fmt_chunk + data_header + audio_data
def _synthesize_gemini_tts_impl(text: str, gemini_voice_name: str) -> Tuple[bytes, str]:
    """Request speech audio for ``text`` from Gemini TTS.

    The first part of the first response candidate is taken as the audio
    payload. A payload whose MIME type is missing, or is not one of the
    recognised containers (WAV / MPEG / Ogg / Opus), is treated as raw PCM
    and wrapped in a WAV header. WAV audio is then compressed to MP3 on a
    best-effort basis.

    NOTE(review): the disabled-speech wrapper in this module looks up
    ``_synthesize_gemini_tts_impl.__cache_key__``, which suggests this
    function is meant to be wrapped by a memoizing decorator. None is applied
    here — confirm whether one should be.

    Args:
        text: Text to be spoken.
        gemini_voice_name: Name of the prebuilt Gemini voice to use.

    Returns:
        ``(audio_bytes, mime_type)``. MP3 bytes with ``"audio/mpeg"`` when
        compression succeeds; otherwise the uncompressed audio together with
        its MIME type.

    Raises:
        TTSGenerationError: If speech generation is disabled, the API call or
            the extraction of the audio part fails, or the audio is empty.
    """
    if not GENERATE_SPEECH:
        raise TTSGenerationError("GENERATE_SPEECH not enabled in environment.")

    try:
        model = genai.GenerativeModel(TTS_MODEL)
        response = model.generate_content(
            contents=[text],
            generation_config={
                "response_modalities": ["AUDIO"],
                "speech_config": {
                    "voice_config": {
                        "prebuilt_voice_config": {"voice_name": gemini_voice_name}
                    }
                },
            },
        )
        audio_part = response.candidates[0].content.parts[0]
        raw_data = audio_part.inline_data.data
        mime = audio_part.inline_data.mime_type
    except Exception as e:
        # Boundary with the external service: any failure (network, quota,
        # unexpected response shape) is reported as a single domain error,
        # chained so the original traceback stays available.
        logging.error("Gemini TTS API error: %s", e)
        raise TTSGenerationError(f"TTS request failed: {e}") from e

    if not raw_data:
        raise TTSGenerationError("Empty audio data from Gemini.")

    # Convert raw audio to WAV if needed
    mime_lower = mime.lower() if mime else ""
    if not mime_lower:
        logging.warning("MIME missing; defaulting to WAV")
        raw_data = convert_to_wav(raw_data, DEFAULT_RAW_AUDIO_MIME)
        mime = "audio/wav"
    elif mime_lower.startswith("audio/l") or not mime_lower.startswith(
        ("audio/wav", "audio/mpeg", "audio/ogg", "audio/opus")
    ):
        # "audio/L<bits>" (or anything unrecognised) is handled as headerless PCM.
        raw_data = convert_to_wav(raw_data, mime_lower)
        mime = "audio/wav"

    # The transcode below decodes its input as WAV. Audio that is already in
    # another container would not parse as WAV, so hand it back unchanged
    # instead of attempting the decode and logging a spurious failure.
    if not mime.lower().startswith("audio/wav"):
        return raw_data, mime

    # Attempt MP3 compression
    try:
        segment = AudioSegment.from_file(io.BytesIO(raw_data), format="wav")
        buf = io.BytesIO()
        segment.export(buf, format="mp3")
        return buf.getvalue(), "audio/mpeg"
    except Exception as e:
        # Best-effort: an unavailable or failing encoder must not lose the audio.
        logging.warning("MP3 conversion failed (%s); returning WAV", e)
        return raw_data, mime
# Choose wrapper based on GENERATE_SPEECH flag
if GENERATE_SPEECH:

    def synthesize_gemini_tts(text: str, voice: str) -> Tuple[Optional[bytes], Optional[str]]:
        """Synthesize speech for ``text`` with the given voice.

        Delegates to ``_synthesize_gemini_tts_impl`` and turns a
        ``TTSGenerationError`` into a logged ``(None, None)`` result so that
        callers can continue without audio.

        Returns:
            ``(audio_bytes, mime_type)`` on success, ``(None, None)`` on failure.
        """
        try:
            return _synthesize_gemini_tts_impl(text, voice)
        except TTSGenerationError as e:
            logging.error("TTS failed: %s; skipping audio", e)
            return None, None

else:

    def synthesize_gemini_tts(text: str, voice: str) -> Tuple[Optional[bytes], Optional[str]]:
        """Return previously cached audio for ``text``/``voice``, if any.

        Speech generation is disabled in this mode, so no request is made;
        only ``cache`` is consulted.

        Returns:
            The cached value when one exists (presumably an
            ``(audio_bytes, mime_type)`` pair — verify against what populates
            the cache), otherwise ``(None, None)``.
        """
        # `__cache_key__` only exists when the implementation has been wrapped
        # by a memoizing decorator. Without it there is no way to derive a
        # cache key, so degrade to "no audio" instead of raising
        # AttributeError on every call.
        make_key = getattr(_synthesize_gemini_tts_impl, "__cache_key__", None)
        if make_key is None:
            logging.warning("TTS implementation is not memoized; cannot look up cached audio.")
            return None, None

        result = cache.get(make_key(text, voice))
        if result is not None:
            return result
        logging.info("No cached audio; speech disabled.")
        return None, None