import os

import fastapi
import torch
import uvicorn
from fastapi import File, UploadFile, Form, HTTPException
from fastapi.responses import JSONResponse, FileResponse

from load_models import get_nllb_model_and_tokenizer, get_xtts_model
from inference_functions import translate, just_inference

# Cap this process at 75% of GPU 0's memory so both models fit alongside CUDA overhead
torch.cuda.set_per_process_memory_fraction(0.75, 0)

# Load both models once at startup
model_nllb, tokenizer_nllb = get_nllb_model_and_tokenizer()
model_xtts = get_xtts_model()

app = fastapi.FastAPI()
# NOTE: the route decorators were missing from the original listing; the paths
# used here and on the handlers below are assumed.
@app.get("/health")
def health_check():
    return {"status": "ok"}
@app.post("/translate")
def translate_text(text: str = Form(...), target_lang: str = Form(...)):
    translation = translate(model_nllb, tokenizer_nllb, text, target_lang)
    return {"translation": translation}
@app.post("/inference_audio")
def inference_audio(original_path: UploadFile = File(...), text: str = Form(...), lang: str = Form(...)):
    # Save the uploaded reference audio to a temporary file
    file_location = f"/tmp/{original_path.filename}"
    with open(file_location, "wb") as file:
        file.write(original_path.file.read())

    # Despite the original name, this is the output *file* path, not a directory
    output_path = f"/tmp/generated_audio_{os.path.basename(file_location)}.wav"
    torch.cuda.empty_cache()  # free cached GPU memory before synthesis
    just_inference(model_xtts, file_location, output_path, text, lang)
    return {"path_to_save": output_path}
@app.post("/process_audio")
async def process_audio(original_path: UploadFile = File(...), text: str = Form(...), lang: str = Form(...), target_lang: str = Form(...)):
    print(f"original_path: {original_path.filename}")
    print(f"text: {text}")
    print(f"lang: {lang}")
    print(f"target_lang: {target_lang}")

    # Validate target language (ISO codes, matching the form values)
    if target_lang not in ["es", "en"]:
        print("Unsupported language")
        raise HTTPException(status_code=400, detail="Unsupported language. Use 'es' or 'en'.")

    try:
        # Translate the text first
        translated_text = translate(model_nllb, tokenizer_nllb, text, target_lang)
        print(f"translated_text: {translated_text}")

        # Save the uploaded reference audio
        file_location = f"/tmp/{original_path.filename}"
        with open(file_location, "wb") as file:
            file.write(original_path.file.read())

        output_path = f"/tmp/generated_audio_{os.path.basename(file_location)}.wav"
        torch.cuda.empty_cache()  # free cached GPU memory before synthesis
        just_inference(model_xtts, file_location, output_path, translated_text, target_lang)
        return JSONResponse(content={"audio_path": output_path, "translation": translated_text})
    except Exception as e:
        print(f"Error during processing: {e}")
        raise HTTPException(status_code=500, detail="Error during processing")
@app.get("/download")
def download_audio(file_path: str):
    if not os.path.exists(file_path):
        raise HTTPException(status_code=404, detail="File not found")
    return FileResponse(file_path, media_type="audio/wav", filename=os.path.basename(file_path))
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
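
For reference, here is a minimal client sketch showing how these endpoints could be exercised once the app is running on localhost:8000. The route paths (/process_audio, /download) are the ones assumed above, and sample.wav is a hypothetical local reference clip; adjust both to match the actual deployment.

# client sketch: assumed routes, hypothetical sample.wav
import requests

BASE_URL = "http://localhost:8000"  # host/port from uvicorn.run above

# Translate the text and synthesize it in the cloned voice in one call
with open("sample.wav", "rb") as f:  # hypothetical reference clip
    resp = requests.post(
        f"{BASE_URL}/process_audio",
        files={"original_path": ("sample.wav", f, "audio/wav")},
        data={"text": "Hello, world", "lang": "en", "target_lang": "es"},
    )
resp.raise_for_status()
result = resp.json()
print(result["translation"])

# Fetch the generated audio back from the server's /tmp path
audio = requests.get(f"{BASE_URL}/download", params={"file_path": result["audio_path"]})
audio.raise_for_status()
with open("generated.wav", "wb") as out:
    out.write(audio.content)

Note that /process_audio returns a server-side path, so a separate /download round trip is needed to retrieve the actual bytes; on a single machine you could read the file directly instead.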