"""FastAPI service exposing StableHair hair-transfer inference behind bearer auth."""
import io
import logging
import os
import secrets
import threading
import uuid
from typing import Optional

import numpy as np
import torch
from fastapi import Depends, FastAPI, File, Header, HTTPException, UploadFile
from fastapi.responses import FileResponse, JSONResponse
from PIL import Image
from pydantic import BaseModel

from infer_full import StableHair

LOGGER = logging.getLogger("hair_server")
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
)

# Shared secret expected in "Authorization: Bearer <token>".
# The env var lets deployments rotate it without a code change; an empty value
# is treated as unset so an empty expected token can never be configured.
# NOTE(review): the literal fallback is a credential committed to source —
# set HAIR_API_BEARER_TOKEN in deployment and consider dropping the default.
EXPECTED_BEARER = os.environ.get("HAIR_API_BEARER_TOKEN") or "logicgo@123"


def verify_bearer(authorization: Optional[str] = Header(None)) -> bool:
    """FastAPI dependency that enforces bearer-token authentication.

    Args:
        authorization: Raw ``Authorization`` header value, injected by FastAPI.

    Returns:
        ``True`` when the header carries the expected bearer token.

    Raises:
        HTTPException: 401 if the header is missing, has no space between
            scheme and token, uses a scheme other than ``Bearer``, or carries
            the wrong token.
    """
    if not authorization:
        raise HTTPException(status_code=401, detail="Missing Authorization header")

    scheme, sep, token = authorization.partition(" ")
    if not sep:
        raise HTTPException(status_code=401, detail="Invalid Authorization header format")
    if scheme.lower() != "bearer":
        raise HTTPException(status_code=401, detail="Invalid auth scheme")

    # compare_digest runs in time independent of how many leading characters
    # match, so the token cannot be recovered byte-by-byte via timing.
    # Comparing UTF-8 bytes (not str) avoids a TypeError on non-ASCII headers.
    provided = token.strip().encode("utf-8")
    expected = EXPECTED_BEARER.encode("utf-8")
    if not secrets.compare_digest(provided, expected):
        raise HTTPException(status_code=401, detail="Invalid token")
    return True


app = FastAPI(title="Hair Swap API", version="1.0.0")


@app.get("/health")
def health():
    """Liveness probe; returns a static payload and does not load the model."""
    return {"status": "healthy"}


class HairSwapRequest(BaseModel):
    """JSON body of ``POST /get-hairswap``."""

    source_id: str  # id returned by POST /upload
    reference_id: str  # id returned by POST /upload
    # NOTE(review): converter_scale is accepted but not forwarded to the model
    # by the /get-hairswap handler in this file — confirm whether it should be.
    converter_scale: float = 1.0
    scale: float = 1.0
    guidance_scale: float = 1.5
    controlnet_conditioning_scale: float = 1.0


# Model is loaded lazily on first use. Sync endpoints run in FastAPI's
# threadpool, so the lock keeps concurrent first requests from each building
# their own StableHair instance (and doubling device memory).
_model: Optional[StableHair] = None
_model_lock = threading.Lock()


def get_model() -> StableHair:
    """Return the process-wide StableHair instance, loading it on first call.

    Uses CUDA with float16 weights when available, otherwise CPU with float32.
    """
    global _model
    if _model is None:
        with _model_lock:
            # Re-check: another thread may have finished loading while we waited.
            if _model is None:
                LOGGER.info("Loading StableHair model ...")
                device = "cuda" if torch.cuda.is_available() else "cpu"
                dtype = torch.float16 if device == "cuda" else torch.float32
                _model = StableHair(
                    config="./configs/hair_transfer.yaml",
                    device=device,
                    weight_dtype=dtype,
                )
                LOGGER.info("Model loaded")
    return _model


# Storage directories, resolved against the working directory at import time.
UPLOAD_DIR = os.path.join(os.getcwd(), "uploads")
RESULTS_DIR = os.path.join(os.getcwd(), "results")
LOGS_DIR = os.path.join(os.getcwd(), "logs")
os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(RESULTS_DIR, exist_ok=True)
os.makedirs(LOGS_DIR, exist_ok=True)

# Client-supplied extensions kept as-is (lower-cased); any other suffix is
# replaced with ".png" so arbitrary attacker-chosen suffixes never reach disk.
_ALLOWED_UPLOAD_EXTS = frozenset(
    {".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif", ".tif", ".tiff"}
)


@app.post("/upload")
async def upload_image(image: UploadFile = File(...), _=Depends(verify_bearer)):
    """Validate and store an uploaded image under a fresh UUID.

    Returns:
        ``{"id": <uuid>, "filename": <stored file name>}``. Pass ``id`` as
        ``source_id`` / ``reference_id`` to ``POST /get-hairswap``.

    Raises:
        HTTPException: 400 if no filename is given or the bytes cannot be
            decoded as an image.
    """
    if not image.filename:
        raise HTTPException(status_code=400, detail="No file name provided")
    contents = await image.read()
    try:
        # convert() forces a full decode, so truncated/corrupt files are
        # rejected here rather than at inference time.
        with Image.open(io.BytesIO(contents)) as probe:
            probe.convert("RGB")
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid image file") from None

    ext = os.path.splitext(image.filename)[1].lower()
    if ext not in _ALLOWED_UPLOAD_EXTS:
        ext = ".png"
    image_id = str(uuid.uuid4())
    path = os.path.join(UPLOAD_DIR, image_id + ext)
    with open(path, "wb") as f:
        f.write(contents)
    return {"id": image_id, "filename": os.path.basename(path)}


def _find_upload(image_id: str) -> str:
    """Resolve an upload id (or the stored filename) to its path in UPLOAD_DIR.

    The id must parse as a UUID and equal a stored file's stem exactly. A
    prefix match would let an empty or partial id resolve to another upload.

    Raises:
        HTTPException: 404 if the id is malformed or no such upload exists.
    """
    not_found = HTTPException(status_code=404, detail=f"Image id not found: {image_id}")
    # Accept both the bare "id" and the "filename" value returned by /upload.
    stem = os.path.splitext(image_id)[0]
    try:
        wanted = str(uuid.UUID(stem))
    except ValueError:
        raise not_found from None
    for name in os.listdir(UPLOAD_DIR):
        if os.path.splitext(name)[0] == wanted:
            return os.path.join(UPLOAD_DIR, name)
    raise not_found


@app.post("/get-hairswap")
def get_hairswap(req: HairSwapRequest, _=Depends(verify_bearer)):
    """Run hair transfer between two uploaded images and save the result PNG.

    Returns:
        ``{"result": <png filename>}``, downloadable via ``GET /download/{filename}``.

    Raises:
        HTTPException: 404 if either id does not resolve to an upload;
            500 if the model call raises.
    """
    source_path = _find_upload(req.source_id)
    reference_path = _find_upload(req.reference_id)
    model = get_model()

    # NOTE(review): req.converter_scale is accepted by the request model but
    # not forwarded here — confirm whether Hair_Transfer should receive it.
    try:
        _, out_np, _, _ = model.Hair_Transfer(
            source_image=source_path,
            reference_image=reference_path,
            random_seed=-1,
            step=30,
            guidance_scale=req.guidance_scale,
            scale=req.scale,
            controlnet_conditioning_scale=req.controlnet_conditioning_scale,
            size=512,
        )
    except Exception:
        LOGGER.exception(
            "Hair transfer failed (source=%s, reference=%s)",
            req.source_id,
            req.reference_id,
        )
        raise HTTPException(status_code=500, detail="Hair transfer failed") from None

    # The *255 scaling assumes out_np is float in [0, 1] (TODO confirm against
    # StableHair); clip first so any overshoot saturates instead of wrapping
    # around when cast to uint8.
    pixels = (np.clip(out_np, 0.0, 1.0) * 255.0).astype(np.uint8)
    filename = f"{uuid.uuid4()}.png"
    Image.fromarray(pixels).save(os.path.join(RESULTS_DIR, filename))
    return {"result": filename}


@app.get("/download/{filename}")
def download(filename: str, _=Depends(verify_bearer)):
    """Serve a result PNG from RESULTS_DIR.

    Names containing a directory component, or not referring to a regular
    file (such as ``..``), are reported as 404.
    """
    path = os.path.join(RESULTS_DIR, filename)
    if os.path.basename(filename) != filename or not os.path.isfile(path):
        raise HTTPException(status_code=404, detail="File not found")
    return FileResponse(path, media_type="image/png", filename=filename)


@app.get("/logs")
def logs(_=Depends(verify_bearer)):
    """Placeholder log endpoint; returns a static message, not real log output."""
    return JSONResponse({"logs": ["service running"]})