|
|
import hmac
import io
import logging
import os
import threading
import uuid
from typing import Optional

from fastapi import FastAPI, UploadFile, File, HTTPException, Depends, Header
from fastapi.responses import FileResponse, JSONResponse
from pydantic import BaseModel
import torch
import numpy as np
from PIL import Image

from infer_full import StableHair
|
|
|
|
|
|
|
|
LOGGER = logging.getLogger("hair_server")
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(name)s - %(message)s")

# Shared bearer token that every protected endpoint requires (checked in
# verify_bearer). It is read from the environment so the secret can be rotated
# without a code change. The previously hard-coded value is kept only as a
# backward-compatible fallback; remove it once every deployment sets the
# variable.
_env_bearer = os.environ.get("HAIR_SERVER_BEARER", "").strip()
if _env_bearer:
    EXPECTED_BEARER = _env_bearer
else:
    LOGGER.warning("HAIR_SERVER_BEARER is not set; falling back to the built-in default bearer token")
    EXPECTED_BEARER = "logicgo@123"
|
|
|
|
|
|
|
|
def verify_bearer(authorization: Optional[str] = Header(None)):
    """FastAPI dependency enforcing ``Authorization: Bearer <token>``.

    Args:
        authorization: raw value of the ``Authorization`` request header.

    Returns:
        True when the header carries the expected bearer token.

    Raises:
        HTTPException: 401 when the header is missing, malformed, uses a
            scheme other than ``Bearer``, or carries the wrong token. The
            response includes ``WWW-Authenticate: Bearer`` as RFC 6750 asks.
    """
    challenge = {"WWW-Authenticate": "Bearer"}

    if not authorization:
        raise HTTPException(status_code=401, detail="Missing Authorization header", headers=challenge)

    # partition() never raises. Stripping tolerates extra spaces between the
    # scheme and the token.
    scheme, sep, token = authorization.strip().partition(" ")
    token = token.strip()
    if not sep or not token:
        raise HTTPException(status_code=401, detail="Invalid Authorization header format", headers=challenge)

    if scheme.lower() != "bearer":
        raise HTTPException(status_code=401, detail="Invalid auth scheme", headers=challenge)

    # Constant-time comparison, so response timing does not reveal how much of
    # the token matched. Compare bytes because compare_digest rejects
    # non-ASCII str inputs.
    if not hmac.compare_digest(token.encode("utf-8"), EXPECTED_BEARER.encode("utf-8")):
        raise HTTPException(status_code=401, detail="Invalid token", headers=challenge)

    return True
|
|
|
|
|
|
|
|
# Application object; every route below registers on it via decorators.
app = FastAPI(
    title="Hair Swap API",
    version="1.0.0",
)
|
|
|
|
|
|
|
|
@app.get("/health")
def health():
    """Liveness probe; requires no Authorization header."""
    return dict(status="healthy")
|
|
|
|
|
|
|
|
class HairSwapRequest(BaseModel):
    """JSON body accepted by ``POST /get-hairswap``.

    ``source_id`` and ``reference_id`` are ids returned by ``POST /upload``.
    The float fields are optional and fall back to the defaults shown.
    """

    # Upload id whose file is passed to Hair_Transfer as ``source_image``.
    source_id: str

    # Upload id whose file is passed to Hair_Transfer as ``reference_image``.
    reference_id: str

    # NOTE(review): accepted here but not forwarded to the model by the
    # /get-hairswap handler; confirm whether it should be.
    converter_scale: float = 1.0

    # Forwarded to Hair_Transfer as ``scale``.
    scale: float = 1.0

    # Forwarded to Hair_Transfer as ``guidance_scale``.
    guidance_scale: float = 1.5

    # Forwarded to Hair_Transfer as ``controlnet_conditioning_scale``.
    controlnet_conditioning_scale: float = 1.0
|
|
|
|
|
|
|
|
|
|
|
# Process-wide cached model instance, created lazily by get_model().
_model: Optional[StableHair] = None

# Serializes the one-time model construction. /get-hairswap is a plain ``def``
# endpoint, which FastAPI runs in a worker threadpool, so two first requests
# could otherwise both see ``_model is None`` and load the model twice.
_model_lock = threading.Lock()


def get_model() -> StableHair:
    """Return the shared StableHair model, building it on first use.

    The model is built from ``./configs/hair_transfer.yaml`` (relative to the
    current working directory). It runs on CUDA in float16 when CUDA is
    available, otherwise on CPU in float32.
    """
    global _model

    # Fast path: skip the lock once the model exists.
    if _model is None:
        with _model_lock:
            # Re-check under the lock; another thread may have finished loading
            # while this one was waiting.
            if _model is None:
                LOGGER.info("Loading StableHair model ...")
                device = "cuda" if torch.cuda.is_available() else "cpu"
                dtype = torch.float16 if device == "cuda" else torch.float32
                _model = StableHair(config="./configs/hair_transfer.yaml", device=device, weight_dtype=dtype)
                LOGGER.info("Model loaded")

    return _model
|
|
|
|
|
|
|
|
# Working directories, all resolved against the process's current directory:
# raw uploads, generated PNG results, and a logs folder.
UPLOAD_DIR = os.path.join(os.getcwd(), "uploads")
RESULTS_DIR = os.path.join(os.getcwd(), "results")
LOGS_DIR = os.path.join(os.getcwd(), "logs")

# Create them at import time; existing directories are left untouched.
for _dir in (UPLOAD_DIR, RESULTS_DIR, LOGS_DIR):
    os.makedirs(_dir, exist_ok=True)
del _dir
|
|
|
|
|
|
|
|
@app.post("/upload")
async def upload_image(image: UploadFile = File(...), _=Depends(verify_bearer)):
    """Save an uploaded image under a fresh UUID and return that id.

    The bytes are stored exactly as received (not re-encoded). The file
    extension is taken from the client-supplied filename, defaulting to
    ``.png`` when there is none.

    Returns:
        ``{"id": <uuid>, "filename": <stored file name>}``

    Raises:
        HTTPException: 400 when no filename is given or the bytes cannot be
            decoded as an image.
    """
    original_name = image.filename
    if not original_name:
        raise HTTPException(status_code=400, detail="No file name provided")

    payload = await image.read()

    # Fully decode once purely as a validity check; the pixels are discarded.
    try:
        Image.open(io.BytesIO(payload)).convert("RGB")
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid image file")

    new_id = str(uuid.uuid4())
    suffix = os.path.splitext(original_name)[1]
    if not suffix:
        suffix = ".png"
    destination = os.path.join(UPLOAD_DIR, new_id + suffix)

    with open(destination, "wb") as handle:
        handle.write(payload)

    return {"id": new_id, "filename": os.path.basename(destination)}
|
|
|
|
|
|
|
|
@app.post("/get-hairswap")
def get_hairswap(req: HairSwapRequest, _=Depends(verify_bearer)):
    """Run hair transfer on two uploaded images and save the output as PNG.

    Args:
        req: upload ids of the source and reference images plus model knobs.

    Returns:
        ``{"result": "<uuid>.png"}``: the name of the file written to
        RESULTS_DIR.

    Raises:
        HTTPException: 404 if either id does not match an upload; 500 if the
            model call fails.
    """

    def find_file(image_id: str) -> str:
        # Uploads are saved as "<uuid><ext>", so match the stem exactly rather
        # than by prefix. A prefix test would let an empty or partial id
        # resolve to an arbitrary upload.
        if image_id:
            for name in os.listdir(UPLOAD_DIR):
                if os.path.splitext(name)[0] == image_id:
                    return os.path.join(UPLOAD_DIR, name)
        raise HTTPException(status_code=404, detail=f"Image id not found: {image_id}")

    source_path = find_file(req.source_id)
    reference_path = find_file(req.reference_id)

    model = get_model()

    # NOTE(review): req.converter_scale is accepted by HairSwapRequest but not
    # passed to Hair_Transfer; confirm whether it should be.
    # NOTE(review): this sync endpoint runs in FastAPI's threadpool, so
    # concurrent requests share one model instance; confirm StableHair is safe
    # to call concurrently.
    try:
        # Only the second output (the transferred image) is used.
        _, out_np, _, _ = model.Hair_Transfer(
            source_image=source_path,
            reference_image=reference_path,
            random_seed=-1,
            step=30,
            guidance_scale=req.guidance_scale,
            scale=req.scale,
            controlnet_conditioning_scale=req.controlnet_conditioning_scale,
            size=512,
        )
    except Exception as exc:
        LOGGER.exception(
            "Hair transfer failed (source=%s, reference=%s)", req.source_id, req.reference_id
        )
        raise HTTPException(status_code=500, detail="Hair transfer failed") from exc

    # The output is scaled by 255, i.e. presumably a float image in [0, 1].
    # Clip first so out-of-range values saturate instead of wrapping around
    # in the uint8 cast.
    pixels = (np.clip(out_np, 0.0, 1.0) * 255.0).astype(np.uint8)

    filename = f"{uuid.uuid4()}.png"
    out_path = os.path.join(RESULTS_DIR, filename)
    Image.fromarray(pixels).save(out_path)

    return {"result": filename}
|
|
|
|
|
|
|
|
@app.get("/download/{filename}")
def download(filename: str, _=Depends(verify_bearer)):
    """Return a generated result image from RESULTS_DIR as ``image/png``.

    Args:
        filename: name returned in the ``result`` field of /get-hairswap.

    Raises:
        HTTPException: 404 unless ``filename`` names a regular file located
            directly inside RESULTS_DIR.
    """
    results_root = os.path.realpath(RESULTS_DIR)
    path = os.path.realpath(os.path.join(results_root, filename))

    # Results are written flat into RESULTS_DIR, so anything that resolves
    # elsewhere is refused: "..", backslash separators on Windows, symlinks.
    # isfile() also rejects directories, which a plain exists() check would
    # have handed to FileResponse.
    if os.path.dirname(path) != results_root or not os.path.isfile(path):
        raise HTTPException(status_code=404, detail="File not found")

    return FileResponse(path, media_type="image/png", filename=filename)
|
|
|
|
|
|
|
|
@app.get("/logs")
def logs(_=Depends(verify_bearer)):
    """Return a fixed status message to authenticated callers.

    Nothing is read from LOGS_DIR here; the payload is a constant.
    """
    payload = {"logs": ["service running"]}
    return JSONResponse(content=payload)
|
|
|
|
|
|
|
|
|