Spaces:
Running
Running
| from fastapi import APIRouter, Depends, HTTPException, status | |
| from fastapi.responses import FileResponse | |
| from sqlalchemy.orm import Session | |
| import os | |
| from backend.db.db_instance import get_db_session | |
| from backend.db.task.dao import ( | |
| get_task_status_from_db, | |
| get_all_tasks_status_from_db, | |
| delete_task_from_db, | |
| ) | |
| from backend.db.task.models import ( | |
| TasksResult, | |
| Task, | |
| TaskStatusResponse, | |
| TaskType | |
| ) | |
| from backend.common.models import ( | |
| Response, | |
| ) | |
| from backend.common.compresser import compress_files, find_file_by_hash | |
| from modules.utils.paths import BACKEND_CACHE_DIR | |
# Router collecting the task endpoints of this module: every route registered on
# it is mounted under the "/task" prefix and grouped under the "Tasks" tag.
task_router = APIRouter(
    prefix="/task",
    tags=["Tasks"],
)
async def get_task(
    identifier: str,
    session: Session = Depends(get_db_session),
) -> TaskStatusResponse:
    """
    Retrieve the specific task by its identifier.
    """
    # NOTE(review): no route decorator is visible on this handler here;
    # confirm it is registered on task_router.
    found = get_task_status_from_db(identifier=identifier, session=session)
    # Unknown identifiers are reported to the client as HTTP 404.
    if found is None:
        raise HTTPException(status_code=404, detail="Identifier not found")
    return found.to_response()
async def get_file_task(
    identifier: str,
    session: Session = Depends(get_db_session),
) -> FileResponse:
    """
    Retrieve the downloadable file response of a specific task by its identifier.
    The task's output files are bundled into a single ZIP archive.
    """
    # NOTE(review): no route decorator is visible on this handler here;
    # confirm it is registered on task_router.
    task = get_task_status_from_db(identifier=identifier, session=session)
    if task is None:
        raise HTTPException(status_code=404, detail="Identifier not found")

    # Only background-music separation tasks have downloadable output here.
    if task.task_type != TaskType.BGM_SEPARATION:
        raise HTTPException(status_code=404, detail=f"File download is only supported for bgm separation."
                                                    f" The given type is {task.task_type}")

    # NOTE(review): task.result is indexed directly; if it can be unset for a task
    # that has not finished, this raises and surfaces as HTTP 500 — confirm.
    instrumental_path = find_file_by_hash(
        os.path.join(BACKEND_CACHE_DIR, "UVR", "instrumental"),
        task.result["instrumental_hash"],
    )
    vocal_path = find_file_by_hash(
        os.path.join(BACKEND_CACHE_DIR, "UVR", "vocals"),
        task.result["vocal_hash"],
    )

    # NOTE(review): compression runs synchronously inside this async handler,
    # which may block the event loop for large files.
    output_zip_path = compress_files(
        [instrumental_path, vocal_path],
        os.path.join(BACKEND_CACHE_DIR, f"{identifier}_bgm_separation.zip"),
    )
    return FileResponse(
        path=output_zip_path,
        status_code=200,
        # `filename` becomes the Content-Disposition download name; pass only the
        # base name so the server's cache directory path is not exposed.
        filename=os.path.basename(output_zip_path),
        media_type="application/zip",
    )
# Delete endpoint: its route decorator is commented out by default because this endpoint is likely to require special permissions
| # @task_router.delete( | |
| # "/{identifier}", | |
| # response_model=Response, | |
| # status_code=status.HTTP_200_OK, | |
| # summary="Delete Task by Identifier", | |
| # description="Delete a task from the system using its identifier.", | |
| # ) | |
async def delete_task(
    identifier: str,
    session: Session = Depends(get_db_session),
) -> Response:
    """
    Delete a task by its identifier.
    """
    # delete_task_from_db reports success via a truthy return; a falsy result
    # is surfaced to the client as HTTP 404.
    if not delete_task_from_db(identifier, session):
        raise HTTPException(status_code=404, detail="Task not found")
    return Response(identifier=identifier, message="Task deleted")
# Get-all endpoint: its route decorator is commented out by default because this endpoint is likely to require special permissions
| # @task_router.get( | |
| # "/all", | |
| # response_model=TasksResult, | |
| # status_code=status.HTTP_200_OK, | |
| # summary="Retrieve All Task Statuses", | |
| # description="Retrieve the statuses of all tasks available in the system.", | |
| # ) | |
async def get_all_tasks_status(
    session: Session = Depends(get_db_session),
) -> TasksResult:
    """
    Retrieve all tasks.
    """
    # The DAO result is returned unchanged to the caller.
    all_tasks = get_all_tasks_status_from_db(session=session)
    return all_tasks