Spaces:
Paused
Paused
| from contextlib import asynccontextmanager | |
| from fastapi import FastAPI, WebSocket, WebSocketDisconnect, UploadFile, File | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import JSONResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.responses import FileResponse | |
| import asyncio | |
| import logging | |
| import os | |
| import traceback | |
| import argparse | |
| import uvicorn | |
| import numpy as np | |
| import librosa | |
| import io | |
| import tempfile | |
| from core import WhisperLiveKit | |
| from audio_processor import AudioProcessor | |
| from language_detector import LanguageDetector | |
# Root logging: basicConfig sets the record format, then the root logger's
# threshold is raised to WARNING.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logging.getLogger().setLevel(logging.WARNING)

# This module's own logger is kept at DEBUG.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# Shared singletons; both start unset and are assigned by lifespan().
language_detector = None
kit = None
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan handler: build the shared singletons at startup.

    Everything before the ``yield`` runs before the application starts
    serving; nothing runs after it, so there is no shutdown step here.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).
    """
    global kit, language_detector
    kit = WhisperLiveKit()
    language_detector = LanguageDetector(model_name="turbo")
    yield
# Application instance; startup work is delegated to lifespan().
app = FastAPI(lifespan=lifespan)

# CORS: any origin, method and header is accepted, with credentials enabled.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive — confirm this is intended for the deployment.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)

# Expose the relative "static" directory under the /static URL prefix.
app.mount("/static", StaticFiles(directory="static"), name="static")
async def read_root():
    """Return the front-end entry page, ``static/index.html``, as a file response."""
    # NOTE(review): no route decorator is visible on this handler in this
    # view — confirm how it is registered with ``app``.
    index_page = FileResponse("static/index.html")
    return index_page
async def health_check():
    """Return a constant JSON body, ``{"status": "healthy"}``."""
    # NOTE(review): no route decorator is visible on this handler in this
    # view — confirm how it is registered with ``app``.
    payload = {"status": "healthy"}
    return JSONResponse(payload)
async def detect_language(file: UploadFile = File(...)):
    """Detect the language of an uploaded file and report it as JSON.

    The upload is written to a temporary file on disk because the detector
    is called with a file path. The temporary file is removed on every exit
    path.

    Args:
        file: The uploaded file to analyse.

    Returns:
        A JSONResponse with ``language``, ``confidence`` and ``probabilities``
        on success, or ``{"error": ...}`` with status 500 when the detector
        is not initialised or detection fails.
    """
    # NOTE(review): no route decorator is visible on this handler in this
    # view — confirm how it is registered with ``app``.

    # Check the detector before touching the disk so that an uninitialised
    # detector can never leave a temporary file behind.
    if not language_detector:
        return JSONResponse(
            {"error": "Language detector not initialized"},
            status_code=500,
        )

    file_path = None
    try:
        # delete=False: the file must outlive the ``with`` block so the
        # detector can reopen it by path.
        with tempfile.NamedTemporaryFile(delete=False) as temp_file:
            file_path = temp_file.name
            temp_file.write(await file.read())

        detected_lang, confidence, probs = (
            language_detector.detect_language_from_file(file_path)
        )
        return JSONResponse({
            "language": detected_lang,
            "confidence": float(confidence),
            "probabilities": {lang: float(prob) for lang, prob in probs.items()},
        })
    except Exception as e:
        logger.exception("Error in language detection: %s", e)
        return JSONResponse(
            {"error": str(e)},
            status_code=500,
        )
    finally:
        # Single cleanup point for success and failure alike.
        if file_path is not None and os.path.exists(file_path):
            try:
                os.remove(file_path)
            except OSError as cleanup_error:
                # Best effort: a failed delete must not mask the response.
                logger.warning(
                    "Could not remove temporary file %s: %s",
                    file_path,
                    cleanup_error,
                )
async def handle_websocket_results(websocket, results_generator):
    """Forward every item from ``results_generator`` to ``websocket`` as JSON.

    Dict items carrying ``buffer_transcription`` or ``full_transcription``
    are reduced to that single key (the former takes precedence); other
    dicts are sent unchanged; non-dict items are sent as ``{"text": str(item)}``.
    A failure while sending is logged and re-raised, which ends the loop;
    the outer handler then logs it as a warning and the coroutine returns.
    """
    try:
        async for response in results_generator:
            try:
                logger.debug(f"Sending response: {response}")
                # Pick the outgoing payload first, then send it exactly once.
                if not isinstance(response, dict):
                    outgoing = {"text": str(response)}
                elif 'buffer_transcription' in response:
                    outgoing = {
                        'buffer_transcription': response['buffer_transcription']
                    }
                elif 'full_transcription' in response:
                    outgoing = {
                        'full_transcription': response['full_transcription']
                    }
                else:
                    outgoing = response
                await websocket.send_json(outgoing)
            except Exception as send_error:
                logger.error(f"Error sending message: {send_error}")
                logger.error(f"Traceback: {traceback.format_exc()}")
                raise
    except Exception as handler_error:
        logger.warning(f"Error in WebSocket results handler: {handler_error}")
        logger.warning(f"Traceback: {traceback.format_exc()}")
async def websocket_endpoint(websocket: WebSocket):
    """Accept a WebSocket, feed received audio bytes to an AudioProcessor,
    and send its results back through a background task.

    The receive loop ends on client disconnect or on the first processing
    error. On exit the processor is cleaned up and the background sender
    task is cancelled and awaited.

    Args:
        websocket: The incoming WebSocket connection.
    """
    # NOTE(review): no route decorator is visible on this handler in this
    # view — confirm how it is registered with ``app``.
    logger.info("New WebSocket connection request")
    audio_processor = None
    websocket_task = None
    try:
        await websocket.accept()
        logger.info("WebSocket connection accepted")
        audio_processor = AudioProcessor()
        results_generator = await audio_processor.create_tasks()
        websocket_task = asyncio.create_task(
            handle_websocket_results(websocket, results_generator)
        )
        while True:
            try:
                message = await websocket.receive_bytes()
                logger.debug("Received audio chunk of size: %d", len(message))
                await audio_processor.process_audio(message)
            except WebSocketDisconnect:
                logger.info("WebSocket connection closed")
                break
            except Exception as e:
                logger.exception("Error processing WebSocket message: %s", e)
                break
    except Exception as e:
        logger.exception("Error in WebSocket endpoint: %s", e)
    finally:
        # Nested try/finally: the sender task must be cancelled even when
        # the processor's cleanup raises, otherwise the task would leak.
        try:
            if audio_processor is not None:
                await audio_processor.cleanup()
        finally:
            if websocket_task is not None:
                websocket_task.cancel()
                try:
                    await websocket_task
                except asyncio.CancelledError:
                    pass
if __name__ == "__main__":
    # Command-line entry point: parse the options, then serve ``app``.
    # NOTE(review): --model, --backend and --task are parsed but never read
    # in this block — confirm whether they should be passed on somewhere.
    cli = argparse.ArgumentParser()
    option_specs = (
        ("--host", str, "0.0.0.0", "Host to run the server on"),
        ("--port", int, 8000, "Port to run the server on"),
        ("--model", str, "base", "Whisper model to use"),
        ("--backend", str, "faster-whisper", "Backend to use"),
        ("--task", str, "transcribe", "Task to perform"),
    )
    for flag, value_type, fallback, description in option_specs:
        cli.add_argument(flag, type=value_type, default=fallback, help=description)
    options = cli.parse_args()
    uvicorn.run(app, host=options.host, port=options.port)