Hair_stable_new / server.py
LogicGoInfotechSpaces's picture
feat(api): add FastAPI server with hair swap endpoints and auth; update requirements
83f52c2
raw
history blame
4.47 kB
import io
import logging
import os
import secrets
import threading
import uuid
from typing import Optional

import numpy as np
import torch
from fastapi import FastAPI, UploadFile, File, HTTPException, Depends, Header
from fastapi.responses import FileResponse, JSONResponse
from PIL import Image
from pydantic import BaseModel

from infer_full import StableHair
LOGGER = logging.getLogger("hair_server")
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(name)s - %(message)s")

# Shared bearer token required by every protected endpoint.
# It can be overridden via the HAIR_API_BEARER_TOKEN environment variable so a
# deployment does not have to rely on the value committed to source. An unset OR
# empty variable falls back to the legacy default -- an empty expected token must
# never be possible, since "Bearer " with no token would then authenticate.
# NOTE(review): the hard-coded fallback is effectively a published credential;
# drop it once every deployment sets the environment variable.
EXPECTED_BEARER = os.environ.get("HAIR_API_BEARER_TOKEN") or "logicgo@123"
def verify_bearer(authorization: Optional[str] = Header(None)):
    """FastAPI dependency enforcing an ``Authorization: Bearer <token>`` header.

    Returns:
        True when the header carries the expected token.

    Raises:
        HTTPException(401): header missing, malformed, using a non-Bearer
            scheme, or carrying the wrong token. Each response includes a
            ``WWW-Authenticate: Bearer`` challenge as RFC 6750 requires.
    """
    challenge = {"WWW-Authenticate": "Bearer"}
    if not authorization:
        raise HTTPException(status_code=401, detail="Missing Authorization header", headers=challenge)
    try:
        # split(None, 1) tolerates repeated/irregular whitespace between the
        # scheme and the token; fewer than two parts fails the unpacking.
        scheme, token = authorization.strip().split(None, 1)
    except ValueError:
        raise HTTPException(
            status_code=401, detail="Invalid Authorization header format", headers=challenge
        ) from None
    if scheme.lower() != "bearer":
        raise HTTPException(status_code=401, detail="Invalid auth scheme", headers=challenge)
    # Constant-time comparison so response timing cannot reveal how many leading
    # characters were correct. Compare bytes: compare_digest raises TypeError on
    # non-ASCII str input, which would otherwise surface as a 500.
    if not secrets.compare_digest(token.encode("utf-8"), EXPECTED_BEARER.encode("utf-8")):
        raise HTTPException(status_code=401, detail="Invalid token", headers=challenge)
    return True
app = FastAPI(version="1.0.0", title="Hair Swap API")


@app.get("/health")
def health():
    """Liveness probe: report that the process is serving requests.

    Unauthenticated and independent of the model, so it answers even before
    the first inference request has loaded anything.
    """
    payload = {"status": "healthy"}
    return payload
class HairSwapRequest(BaseModel):
    """JSON body accepted by POST /get-hairswap.

    ``source_id`` and ``reference_id`` are ids previously returned by
    POST /upload; the remaining fields are numeric tuning knobs with the
    defaults shown.
    """

    # id of an uploaded image; resolved to a file and passed as source_image
    source_id: str
    # id of an uploaded image; resolved to a file and passed as reference_image
    reference_id: str
    # NOTE(review): accepted but NOT forwarded to Hair_Transfer by get_hairswap as written
    converter_scale: float = 1.0
    # forwarded as Hair_Transfer(scale=...)
    scale: float = 1.0
    # forwarded as Hair_Transfer(guidance_scale=...)
    guidance_scale: float = 1.5
    # forwarded as Hair_Transfer(controlnet_conditioning_scale=...)
    controlnet_conditioning_scale: float = 1.0
# Initialize model lazily on first request (it is heavy, so not at import time).
_model: Optional[StableHair] = None
# Guards the one-time load. FastAPI runs plain ``def`` endpoints in a worker
# threadpool, so without this two concurrent first requests could both see
# ``_model is None`` and load the model twice (doubling GPU/CPU memory).
_model_lock = threading.Lock()


def get_model() -> StableHair:
    """Return the process-wide StableHair instance, loading it on first use.

    Chooses CUDA with float16 weights when ``torch.cuda.is_available()``,
    otherwise CPU with float32. The config path is relative to the current
    working directory. If loading raises, ``_model`` stays None and the next
    call retries.
    """
    global _model
    if _model is None:
        with _model_lock:
            # Re-check under the lock: another thread may have finished the
            # load while this one was waiting.
            if _model is None:
                LOGGER.info("Loading StableHair model ...")
                device = "cuda" if torch.cuda.is_available() else "cpu"
                dtype = torch.float16 if device == "cuda" else torch.float32
                _model = StableHair(config="./configs/hair_transfer.yaml", device=device, weight_dtype=dtype)
                LOGGER.info("Model loaded")
    return _model
# Working directories, resolved against the process's current working
# directory at import time and created up front (existing ones are kept).
UPLOAD_DIR, RESULTS_DIR, LOGS_DIR = (
    os.path.join(os.getcwd(), subdir) for subdir in ("uploads", "results", "logs")
)

for _required_dir in (UPLOAD_DIR, RESULTS_DIR, LOGS_DIR):
    os.makedirs(_required_dir, exist_ok=True)
del _required_dir
@app.post("/upload")
async def upload_image(image: UploadFile = File(...), _=Depends(verify_bearer)):
    """Validate and store an uploaded image under a freshly generated id.

    The payload is fully decoded with Pillow so non-images are rejected before
    anything is written. The bytes are stored unchanged in UPLOAD_DIR as
    ``<uuid><ext>``.

    Returns:
        ``{"id": <uuid str>, "filename": <stored file name>}``.

    Raises:
        HTTPException(400): no filename was supplied, or the bytes are not a
            decodable image.
    """
    if not image.filename:
        raise HTTPException(status_code=400, detail="No file name provided")
    contents = await image.read()
    try:
        # Image.open is lazy; convert() forces a full decode so truncated or
        # corrupt payloads are caught here rather than at inference time.
        with Image.open(io.BytesIO(contents)) as decoded:
            decoded.convert("RGB")
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid image file") from None

    image_id = str(uuid.uuid4())
    # The extension comes from the client. Keep it only if Pillow recognises it
    # as an image extension; otherwise fall back to ".png" so arbitrary
    # suffixes (".html", ".php", odd separators, ...) never land on disk.
    ext = os.path.splitext(image.filename)[1].lower()
    if ext not in Image.registered_extensions():
        ext = ".png"
    path = os.path.join(UPLOAD_DIR, image_id + ext)
    with open(path, "wb") as f:
        f.write(contents)
    return {"id": image_id, "filename": os.path.basename(path)}
def _find_upload(image_id: str) -> str:
    """Return the path of the upload whose file stem equals *image_id* exactly.

    Raises:
        HTTPException(404): no stored upload has that id.
    """
    # Exact stem comparison: the previous prefix test let an empty or partial
    # id (e.g. "" or "a") resolve to some unrelated upload.
    if image_id:
        for name in os.listdir(UPLOAD_DIR):
            if os.path.splitext(name)[0] == image_id:
                return os.path.join(UPLOAD_DIR, name)
    raise HTTPException(status_code=404, detail=f"Image id not found: {image_id}")


@app.post("/get-hairswap")
def get_hairswap(req: HairSwapRequest, _=Depends(verify_bearer)):
    """Run a hair transfer between two previously uploaded images.

    Resolves ``req.source_id`` / ``req.reference_id`` to files in UPLOAD_DIR
    and runs ``StableHair.Hair_Transfer`` (30 steps, size 512). The generated
    image is saved to RESULTS_DIR as PNG. Declared as a plain ``def`` so
    FastAPI runs the blocking inference in its threadpool.

    Returns:
        ``{"result": <filename>}``, retrievable via GET /download/{filename}.

    Raises:
        HTTPException(404): either id does not match an upload.
    """
    source_path = _find_upload(req.source_id)
    reference_path = _find_upload(req.reference_id)

    model = get_model()
    # NOTE(review): req.converter_scale is accepted by the request model but is
    # not passed here; confirm whether Hair_Transfer supports it.
    # random_seed=-1 presumably means "pick a random seed" -- confirm in infer_full.
    _, out_np, _, _ = model.Hair_Transfer(
        source_image=source_path,
        reference_image=reference_path,
        random_seed=-1,
        step=30,
        guidance_scale=req.guidance_scale,
        scale=req.scale,
        controlnet_conditioning_scale=req.controlnet_conditioning_scale,
        size=512,
    )

    # out_np is scaled by 255, so it is presumably float in [0, 1] (TODO confirm).
    # Clip before the uint8 cast: casting out-of-range floats to uint8 gives
    # platform-dependent results (typically wrap-around, e.g. 257 -> 1), which
    # shows up as speckle noise in the saved image.
    pixels = np.clip(np.asarray(out_np) * 255.0, 0, 255).astype(np.uint8)
    out_img = Image.fromarray(pixels)

    result_id = str(uuid.uuid4())
    filename = f"{result_id}.png"
    out_path = os.path.join(RESULTS_DIR, filename)
    out_img.save(out_path)
    return {"result": filename}
@app.get("/download/{filename}")
def download(filename: str, _=Depends(verify_bearer)):
    """Stream a generated result image from RESULTS_DIR.

    Only regular files that are direct children of RESULTS_DIR are served.
    Anything else is reported as 404 so probing reveals nothing about the
    filesystem.

    Raises:
        HTTPException(404): the name does not resolve to such a file.
    """
    results_root = os.path.normcase(os.path.realpath(RESULTS_DIR))
    path = os.path.realpath(os.path.join(RESULTS_DIR, filename))
    # ``filename`` is client-controlled. Names such as "..", "..\\x" (a path
    # separator on Windows) or symlinks could otherwise escape RESULTS_DIR.
    # Directories passed the old exists() check but are not servable files.
    if os.path.normcase(os.path.dirname(path)) != results_root or not os.path.isfile(path):
        raise HTTPException(status_code=404, detail="File not found")
    return FileResponse(path, media_type="image/png", filename=filename)
@app.get("/logs")
def logs(_=Depends(verify_bearer)):
    """Return service log lines (placeholder).

    Requires a valid bearer token. For now this always returns a fixed
    single-entry list; nothing is read from disk.
    """
    entries = ["service running"]
    body = {"logs": entries}
    return JSONResponse(body)