Spaces:
Running
Running
| from fastapi import HTTPException | |
| from pydantic import BaseModel | |
| from modules import refiner | |
| from modules.api import utils as api_utils | |
| from modules.api.Api import APIManager | |
| from modules.normalization import text_normalize | |
class RefineTextRequest(BaseModel):
    # Request body accepted by refiner_prompt_post, which setup() registers at
    # POST /v1/prompt/refine. All fields except `normalize` are forwarded to
    # refiner.refine_text unchanged.

    # Input text to refine; passed through text_normalize first when
    # `normalize` is true.
    text: str
    # Control prompt forwarded to refiner.refine_text. NOTE(review): the tag
    # syntax of the default is interpreted by the refiner, not here — confirm
    # the meaning of each tag there.
    prompt: str = "[oral_2][laugh_0][break_6]"
    # Forwarded as-is; presumably -1 requests a random seed — TODO confirm in
    # refiner.refine_text.
    seed: int = -1
    # Sampling options forwarded unchanged to refiner.refine_text
    # (presumably nucleus / top-k sampling, temperature, repetition penalty
    # and generation length limit — confirm against the refiner).
    top_P: float = 0.7
    top_K: int = 20
    temperature: float = 0.7
    repetition_penalty: float = 1.0
    max_new_token: int = 384
    # When true, `text` is run through text_normalize before refinement.
    normalize: bool = True
async def refiner_prompt_post(request: RefineTextRequest):
    """
    Refine ``request.text`` with ``refiner.refine_text`` and return the result.

    When ``request.normalize`` is true the text is first passed through
    ``text_normalize``. The (possibly normalized) text is then refined using
    ``request.prompt`` and the sampling options carried by the request.

    Returns:
        ``{"message": "ok", "data": <refined text>}``.

    Raises:
        HTTPException: propagated unchanged if raised during refinement; any
            other exception is logged and converted into a 500 error whose
            detail is the exception message.
    """
    import logging

    try:
        text = request.text
        if request.normalize:
            text = text_normalize(text)
        refined_text = refiner.refine_text(
            text=text,
            prompt=request.prompt,
            seed=request.seed,
            top_P=request.top_P,
            top_K=request.top_K,
            temperature=request.temperature,
            repetition_penalty=request.repetition_penalty,
            max_new_token=request.max_new_token,
        )
    except HTTPException:
        # Deliberate HTTP errors already carry their status code; a bare
        # ``raise`` propagates them with the original traceback intact.
        raise
    except Exception as e:
        logging.exception(e)
        # Explicit chaining keeps the root cause attached to the 500 error.
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"message": "ok", "data": refined_text}
def setup(api_manager: APIManager):
    """Register ``refiner_prompt_post`` as POST /v1/prompt/refine on *api_manager*."""
    # api_manager.post(...) returns a decorator; apply it to the handler.
    register = api_manager.post(
        "/v1/prompt/refine",
        response_model=api_utils.BaseResponse,
    )
    register(refiner_prompt_post)