Spaces:
Runtime error
Runtime error
| import sys | |
| import os | |
| import asyncio | |
| import base64 | |
| import logging | |
| from io import BytesIO | |
| from pathlib import Path | |
| import uvicorn | |
| from config import Config | |
| from fastapi import FastAPI | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.staticfiles import StaticFiles | |
| from PIL import Image | |
| from pydantic import BaseModel | |
| sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..", "..")) | |
| from utils.wrapper import StreamDiffusionWrapper | |
# Logger named "uvicorn" -- presumably so this module's messages share the
# handlers/format uvicorn configures for its own logger.
logger = logging.getLogger("uvicorn")
# Parent of the directory containing this file. ``resolve()`` makes the path
# absolute first: with a relative ``__file__`` (e.g. "main.py") a bare
# ``.parent.parent`` collapses to Path(".") and silently depends on the CWD.
PROJECT_DIR = Path(__file__).resolve().parent.parent
class PredictInputModel(BaseModel):
    """Request body accepted by the ``/api/predict`` endpoint."""

    # Text prompt handed to the diffusion pipeline.
    prompt: str
class PredictResponseModel(BaseModel):
    """Response body returned by the ``/api/predict`` endpoint."""

    # Generated image, base64-encoded as an ASCII string.
    base64_image: str
class UpdatePromptResponseModel(BaseModel):
    """Response body intended for an ``/update_prompt`` endpoint.

    NOTE(review): no such route is registered by ``Api`` in this file.
    """

    # The prompt echoed back to the caller.
    prompt: str
class Api:
    """
    HTTP front end for a StreamDiffusion text-to-image pipeline.

    Serves ``POST /api/predict`` (prompt in, base64 image out) and the
    static front-end build under ``/``.
    """

    def __init__(self, config: Config) -> None:
        """
        Initialize the API.

        Builds the diffusion wrapper, registers the prediction route, enables
        CORS for all origins, and mounts the static front-end.

        Parameters
        ----------
        config : Config
            The configuration.
        """
        self.config = config
        self.stream_diffusion = StreamDiffusionWrapper(
            mode=config.mode,
            model_id_or_path=config.model_id_or_path,
            lcm_lora_id=config.lcm_lora_id,
            vae_id=config.vae_id,
            device=config.device,
            dtype=config.dtype,
            acceleration=config.acceleration,
            t_index_list=config.t_index_list,
            warmup=config.warmup,
            use_safety_checker=config.use_safety_checker,
            cfg_type="none",
        )
        self.app = FastAPI()
        # Registered before the "/" mount below: routes are matched in order,
        # and the catch-all static mount would otherwise shadow this path.
        self.app.add_api_route(
            "/api/predict",
            self._predict,
            methods=["POST"],
            response_model=PredictResponseModel,
        )
        self.app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )
        # Locate the front-end build relative to this file, not the current
        # working directory. This is the directory the old "../view/build"
        # pointed to when launched from this file's folder, but startup no
        # longer depends on where the process is started from.
        static_dir = Path(__file__).resolve().parent.parent / "view" / "build"
        self.app.mount(
            "/", StaticFiles(directory=str(static_dir), html=True), name="public"
        )
        # Serializes access to the single shared diffusion pipeline.
        self._predict_lock = asyncio.Lock()
        # Not used by any method in this class; kept for compatibility.
        self._update_prompt_lock = asyncio.Lock()

    async def _predict(self, inp: PredictInputModel) -> PredictResponseModel:
        """
        Predict an image and return.

        Parameters
        ----------
        inp : PredictInputModel
            The input.

        Returns
        -------
        PredictResponseModel
            The prediction result.
        """
        async with self._predict_lock:
            # The diffusion call is synchronous (and presumably long-running);
            # run it in the default thread pool so the event loop keeps
            # serving other requests meanwhile. The lock still guarantees
            # only one generation runs at a time.
            loop = asyncio.get_running_loop()
            image = await loop.run_in_executor(
                None, lambda: self.stream_diffusion(prompt=inp.prompt)
            )
            return PredictResponseModel(base64_image=self._pil_to_base64(image))

    def _pil_to_base64(self, image: Image.Image, format: str = "JPEG") -> str:
        """
        Convert a PIL image to a base64 string.

        Parameters
        ----------
        image : Image.Image
            The PIL image. It is converted to RGB before saving, since JPEG
            cannot store an alpha channel.
        format : str
            The image format, by default "JPEG".

        Returns
        -------
        str
            The base64-encoded image data as an ASCII string (no ``data:``
            URI prefix).
        """
        buffered = BytesIO()
        image.convert("RGB").save(buffered, format=format)
        return base64.b64encode(buffered.getvalue()).decode("ascii")

    def _base64_to_pil(self, base64_image: str) -> Image.Image:
        """
        Convert a base64 image to PIL.

        Accepts either raw base64 data or a data URI; everything up to and
        including the first "base64," marker is discarded before decoding.

        Parameters
        ----------
        base64_image : str
            The base64 image.

        Returns
        -------
        Image.Image
            The decoded image, converted to RGB.
        """
        if "base64," in base64_image:
            base64_image = base64_image.split("base64,")[1]
        return Image.open(BytesIO(base64.b64decode(base64_image))).convert("RGB")
if __name__ == "__main__":
    # Build the configuration and the application, then hand it to uvicorn.
    config = Config()
    api = Api(config)
    # NOTE(review): uvicorn only honours workers > 1 when given an import
    # string rather than an app object -- confirm config.workers stays at 1.
    uvicorn.run(
        api.app,
        host=config.host,
        port=config.port,
        workers=config.workers,
    )