|
|
import io |
|
|
import os |
|
|
import uuid |
|
|
import logging |
|
|
from typing import Optional |
|
|
|
|
|
from fastapi import FastAPI, UploadFile, File, HTTPException, Depends, Header |
|
|
from fastapi.responses import FileResponse, JSONResponse |
|
|
from pydantic import BaseModel |
|
|
import torch |
|
|
|
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from mongodb_logging import setup_mongodb_logging, get_logs_from_mongodb, clear_logs_from_mongodb |
|
|
|
|
|
# Shared secret expected in "Authorization: Bearer <token>" headers.
# It can be supplied through the API_BEARER_TOKEN environment variable so the
# secret does not have to live in source control; the historical literal is
# kept as the fallback so existing clients keep working. An empty env value
# also falls back, since an empty expected token could let a header carrying
# an empty token ("Bearer ") pass a plain equality check.
EXPECTED_BEARER = os.environ.get("API_BEARER_TOKEN") or "logicgo@123"
|
|
|
|
|
|
|
|
from pymongo import MongoClient


# MongoDB is optional: when MONGO_URI is unset every handle below is None.
MONGO_URI = os.environ.get("MONGO_URI", "")

mongo_client = MongoClient(MONGO_URI) if MONGO_URI else None
mongo_db = mongo_client.get_database("HairSwapDB") if mongo_client is not None else None
uploads_col = mongo_db.get_collection("uploads") if mongo_db is not None else None
results_col = mongo_db.get_collection("results") if mongo_db is not None else None
logs_col = mongo_db.get_collection("logs") if mongo_db is not None else None

LOGGER = logging.getLogger("hair_server")

# Configure the root logger BEFORE attaching the MongoDB log handler.
# logging.basicConfig() does nothing once the root logger already has a
# handler. In the previous order, if setup_mongodb_logging attached a root
# handler (not visible here -- TODO confirm), basicConfig was silently skipped,
# the root level stayed at WARNING and LOGGER.info(...) output was dropped.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(name)s - %(message)s")

if MONGO_URI:
    setup_mongodb_logging(MONGO_URI, "HairSwapDB", "logs")
|
|
|
|
|
|
|
|
def verify_bearer(authorization: Optional[str] = Header(None)):
    """Validate an ``Authorization: Bearer <token>`` request header.

    Written as a FastAPI dependency (the header is injected via ``Header``).
    NOTE(review): no route visible in this file uses it through
    ``Depends(verify_bearer)``, so the endpoints are currently unauthenticated.

    Args:
        authorization: Raw ``Authorization`` header value, or ``None`` when
            the header is absent.

    Returns:
        ``True`` when the header carries the expected bearer token.

    Raises:
        HTTPException: 401 if the header is missing, malformed, uses a scheme
            other than ``Bearer``, or carries an empty or wrong token.
    """
    import secrets  # local import keeps this block self-contained

    if not authorization:
        raise HTTPException(status_code=401, detail="Missing Authorization header")

    try:
        scheme, token = authorization.split(" ", 1)
    except ValueError:
        raise HTTPException(status_code=401, detail="Invalid Authorization header format") from None

    if scheme.lower() != "bearer":
        raise HTTPException(status_code=401, detail="Invalid auth scheme")

    # Tolerate extra whitespace around the token ("Bearer  abc") and never
    # accept an empty one.
    token = token.strip()
    # Constant-time comparison so response timing does not reveal how much of
    # a guessed token was correct. Bytes are compared because compare_digest
    # raises TypeError for non-ASCII str arguments.
    if not token or not secrets.compare_digest(
        token.encode("utf-8"), EXPECTED_BEARER.encode("utf-8")
    ):
        raise HTTPException(status_code=401, detail="Invalid token")

    return True
|
|
|
|
|
|
|
|
# The single ASGI application object; every route in this module registers on it.
app = FastAPI(version="1.0.0", title="Hair Swap API")
|
|
|
|
|
|
|
|
@app.get("/health")
def health():
    """Liveness probe returning a constant payload (no model or DB check)."""
    payload = {"status": "healthy"}
    return payload
|
|
|
|
|
|
|
|
@app.get("/")
def root():
    """Root endpoint; always answers with a constant ``status: ok`` payload."""
    return dict(status="ok")
|
|
|
|
|
|
|
|
class HairSwapRequest(BaseModel):
    """Request body for ``POST /get-hairswap``.

    ``source_id`` and ``reference_id`` are ``id`` values returned by ``POST /upload``.
    """

    # Upload id resolved to a file and passed to Hair_Transfer as source_image
    # (presumably the person whose hair is changed -- confirm with StableHair).
    source_id: str

    # Upload id resolved to a file and passed to Hair_Transfer as reference_image.
    reference_id: str

    # NOTE(review): accepted but never forwarded to Hair_Transfer by get_hairswap.
    converter_scale: float = 1.0

    # Forwarded unchanged as Hair_Transfer(scale=...).
    scale: float = 1.0

    # Forwarded unchanged as Hair_Transfer(guidance_scale=...).
    guidance_scale: float = 1.5

    # Forwarded unchanged as Hair_Transfer(controlnet_conditioning_scale=...).
    controlnet_conditioning_scale: float = 1.0
|
|
|
|
|
|
|
|
|
|
|
import threading

# Lazily created StableHair instance shared by all requests (see get_model).
_model = None
# Serialises the one-time load: sync FastAPI routes run in a worker thread
# pool, so two early requests could otherwise both load the model.
_model_lock = threading.Lock()


def get_model():
    """Return the process-wide StableHair model, loading it on first use.

    The model uses CUDA with float16 when ``torch.cuda.is_available()``, and
    CPU with float32 otherwise. It is built from ``./configs/hair_transfer.yaml``.

    Returns:
        The cached ``StableHair`` instance.

    Raises:
        Exception: if loading fails; the underlying error is chained as
            ``__cause__``. Nothing is cached on failure, so the next call retries.
    """
    global _model

    # Fast path: no locking once the model exists.
    if _model is not None:
        return _model

    with _model_lock:
        # Re-check: another thread may have finished loading while we waited.
        if _model is None:
            try:
                _model = _load_model()
            except Exception as e:
                LOGGER.exception(f"Failed to load model: {str(e)}")
                raise Exception(f"Model loading failed: {str(e)}") from e
    return _model


def _load_model():
    """Build a StableHair instance on the best available device."""
    LOGGER.info("Loading StableHair model ...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32
    LOGGER.info(f"Using device: {device}, dtype: {dtype}")

    _export_hf_token()
    _patch_huggingface_hub()

    # Heavy import, deferred until the model is actually needed.
    from infer_full import StableHair

    model = StableHair(config="./configs/hair_transfer.yaml", device=device, weight_dtype=dtype)
    LOGGER.info("Model loaded successfully")
    return model


def _export_hf_token():
    """Propagate HUGGINGFACEHUB_API_TOKEN to the token variables used for downloads."""
    legacy_token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
    if not legacy_token:
        return
    # Original alias, kept in case something else (e.g. infer_full) reads it.
    if not os.environ.get("HUGGINGFACE_HUB_TOKEN"):
        os.environ["HUGGINGFACE_HUB_TOKEN"] = legacy_token
    # huggingface_hub reads HF_TOKEN (or the older HUGGING_FACE_HUB_TOKEN),
    # not HUGGINGFACE_HUB_TOKEN. Only fill HF_TOKEN if neither is configured.
    if not os.environ.get("HF_TOKEN") and not os.environ.get("HUGGING_FACE_HUB_TOKEN"):
        os.environ["HF_TOKEN"] = legacy_token


def _patch_huggingface_hub():
    """Best-effort shim for huggingface_hub builds lacking split_torch_state_dict_into_shards."""
    try:
        import huggingface_hub as _hfh

        if not hasattr(_hfh, "split_torch_state_dict_into_shards"):
            def _split_torch_state_dict_into_shards(state_dict, max_shard_size="10GB"):
                # Minimal stand-in: everything goes into a single shard.
                return {"pytorch_model.bin": state_dict}

            _hfh.split_torch_state_dict_into_shards = _split_torch_state_dict_into_shards
    except Exception as e:
        # Deliberately non-fatal, but no longer silent.
        LOGGER.debug(f"huggingface_hub shim not applied: {str(e)}")
|
|
|
|
|
|
|
|
|
|
|
# Root for everything this service writes to disk; override with DATA_DIR.
BASE_DATA_DIR = os.environ.get("DATA_DIR", "/data")
UPLOAD_DIR, RESULTS_DIR, LOGS_DIR = (
    os.path.join(BASE_DATA_DIR, sub) for sub in ("uploads", "results", "logs")
)

# Create the directories at import time so handlers can write immediately.
for _dir in (UPLOAD_DIR, RESULTS_DIR, LOGS_DIR):
    os.makedirs(_dir, exist_ok=True)
del _dir
|
|
|
|
|
|
|
|
@app.post("/upload")
async def upload_image(image: UploadFile = File(...)):
    """Store an uploaded image and return the id used to reference it later.

    The payload must decode with Pillow. It is then written unchanged to
    ``UPLOAD_DIR/<uuid4><ext>``, where ``ext`` is the client filename's
    extension (``.png`` when it has none). When MongoDB is configured, a
    metadata record is inserted as well, on a best-effort basis.

    Returns:
        ``{"id": <uuid4 string>, "filename": <stored file name>}``

    Raises:
        HTTPException: 400 if no filename was sent or the bytes are not an image.
    """
    if not image.filename:
        raise HTTPException(status_code=400, detail="No file name provided")

    contents = await image.read()
    try:
        # Image.open is lazy; convert() forces a full decode of the pixel data.
        Image.open(io.BytesIO(contents)).convert("RGB")
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid image file")

    image_id = str(uuid.uuid4())
    ext = os.path.splitext(image.filename)[1] or ".png"
    filename = image_id + ext
    path = os.path.join(UPLOAD_DIR, filename)
    with open(path, "wb") as f:
        f.write(contents)

    if uploads_col is not None:
        try:
            uploads_col.insert_one({"_id": image_id, "filename": filename, "path": path})
        except Exception as e:
            # Metadata is optional. The file is already on disk, so the upload
            # still succeeds, but the failure is now visible in the logs
            # instead of being swallowed.
            LOGGER.warning(f"MongoDB upload record failed for {image_id}: {str(e)}")

    return {"id": image_id, "filename": filename}
|
|
|
|
|
|
|
|
def _find_upload(image_id: str) -> str:
    """Return the path of the stored upload identified by ``image_id``.

    Uploads are saved as ``<id><ext>`` by ``upload_image``. An entry matches
    when its name without extension equals ``image_id``, or when the whole
    file name equals it, so the ``filename`` returned by ``/upload`` also
    works. The previous ``startswith`` test accepted any prefix, including an
    empty id, and returned whichever upload happened to match first.

    Raises:
        HTTPException: 404 if no upload matches.
    """
    if image_id:
        for name in sorted(os.listdir(UPLOAD_DIR)):
            if name == image_id or os.path.splitext(name)[0] == image_id:
                return os.path.join(UPLOAD_DIR, name)
    raise HTTPException(status_code=404, detail=f"Image id not found: {image_id}")


def _save_result(out_np, source_id: str, reference_id: str) -> str:
    """Save ``out_np`` as ``RESULTS_DIR/<uuid4>.png``, record it, and return the file name.

    ``out_np`` is scaled by 255 and cast to uint8, so it is assumed to hold
    floats in [0, 1] (TODO confirm against Hair_Transfer). Values are clipped
    first because an out-of-range float wraps or is undefined when cast to
    uint8. In-range values are converted exactly as before.
    """
    result_id = str(uuid.uuid4())
    pixels = np.clip(out_np * 255.0, 0, 255).astype(np.uint8)
    filename = f"{result_id}.png"
    out_path = os.path.join(RESULTS_DIR, filename)
    Image.fromarray(pixels).save(out_path)
    LOGGER.info(f"Result saved: {out_path}")

    if results_col is not None:
        try:
            results_col.insert_one({
                "_id": result_id,
                "filename": filename,
                "path": out_path,
                "source_id": source_id,
                "reference_id": reference_id,
            })
        except Exception as e:
            # Best-effort metadata; the PNG is already written.
            LOGGER.warning(f"MongoDB save failed: {str(e)}")
    return filename


@app.post("/get-hairswap")
def get_hairswap(req: HairSwapRequest):
    """Run StableHair hair transfer between two previously uploaded images.

    Both uploads are resolved on disk. ``Hair_Transfer`` is then run with
    step=30, size=512 and random_seed=-1, plus the request's scale parameters.
    The output image is saved to RESULTS_DIR.

    Returns:
        ``{"result": "<uuid4>.png"}``, servable via ``GET /download/{filename}``.

    Raises:
        HTTPException: 404 for an unknown image id; 500 when model loading,
            inference, saving or anything else fails.
    """
    try:
        source_path = _find_upload(req.source_id)
        reference_path = _find_upload(req.reference_id)
        LOGGER.info(f"Found source: {source_path}, reference: {reference_path}")

        try:
            model = get_model()
            LOGGER.info("Model loaded successfully")
        except Exception as e:
            LOGGER.error(f"Model loading failed: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Model loading failed: {str(e)}")

        try:
            LOGGER.info("Starting hair transfer...")
            # Only the generated image (second element) is used.
            _, out_np, _, _ = model.Hair_Transfer(
                source_image=source_path,
                reference_image=reference_path,
                random_seed=-1,
                step=30,
                guidance_scale=req.guidance_scale,
                scale=req.scale,
                controlnet_conditioning_scale=req.controlnet_conditioning_scale,
                size=512,
            )
            LOGGER.info("Hair transfer completed successfully")
        except Exception as e:
            import traceback
            tb = traceback.format_exc()
            LOGGER.error(f"Hair transfer failed: {str(e)} | device={model.device if hasattr(model, 'device') else 'n/a'} cuda_available={torch.cuda.is_available()}\n{tb}")
            raise HTTPException(status_code=500, detail=f"Hair transfer failed: {str(e)}")

        try:
            filename = _save_result(out_np, req.source_id, req.reference_id)
        except Exception as e:
            LOGGER.error(f"Result saving failed: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Result saving failed: {str(e)}")

        return {"result": filename}

    except HTTPException:
        # Already carries the right status code and message.
        raise
    except Exception as e:
        LOGGER.error(f"Unexpected error in get_hairswap: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Unexpected error: {str(e)}")
|
|
|
|
|
|
|
|
@app.get("/download/{filename}")
def download(filename: str):
    """Serve a generated result image from RESULTS_DIR as ``image/png``.

    Raises:
        HTTPException: 404 unless ``filename`` names a regular file located
            directly inside RESULTS_DIR.
    """
    results_root = os.path.realpath(RESULTS_DIR)
    path = os.path.realpath(os.path.join(results_root, filename))
    # Refuse anything that resolves outside RESULTS_DIR: "..", absolute or
    # drive-qualified names, backslash separators on Windows, or symlinks
    # pointing elsewhere. Also refuse directories. The old exists() check let
    # "/download/.." reach FileResponse with a directory.
    if os.path.dirname(path) != results_root or not os.path.isfile(path):
        raise HTTPException(status_code=404, detail="File not found")
    return FileResponse(path, media_type="image/png", filename=filename)
|
|
|
|
|
|
|
|
@app.get("/logs")
def logs(limit: int = 50, level: Optional[str] = None, logger_name: Optional[str] = None):
    """Get logs from MongoDB including both metadata and application logs.

    Response keys:
        metadata: up to 20 ``uploads`` and 20 ``results`` records (id and
            filename fields; results also carry source/reference ids). These
            are empty lists when MongoDB is not configured or the query fails.
        application_logs: result of ``get_logs_from_mongodb`` for the given
            ``limit``/``level``/``logger_name`` filters ([] on error).
        mongodb_status: "connected", "not_configured" or "error: <message>".
    """
    response_data = {}

    uploads, results = [], []
    if uploads_col is not None and results_col is not None:
        # Previously unguarded: a MongoDB outage turned the whole endpoint
        # into a 500, even though the application-log section below degrades
        # gracefully.
        try:
            uploads = list(uploads_col.find({}, {"_id": 1, "filename": 1}).limit(20))
            results = list(results_col.find({}, {"_id": 1, "filename": 1, "source_id": 1, "reference_id": 1}).limit(20))
        except Exception as e:
            LOGGER.warning(f"Failed to read upload/result metadata: {str(e)}")
            uploads, results = [], []
    response_data["metadata"] = {"uploads": uploads, "results": results}

    if MONGO_URI:
        try:
            app_logs = get_logs_from_mongodb(MONGO_URI, "HairSwapDB", "logs", limit, level, logger_name)
            response_data["application_logs"] = app_logs
            response_data["mongodb_status"] = "connected"
        except Exception as e:
            response_data["application_logs"] = []
            response_data["mongodb_status"] = f"error: {str(e)}"
    else:
        response_data["application_logs"] = []
        response_data["mongodb_status"] = "not_configured"

    return JSONResponse(response_data)
|
|
|
|
|
|
|
|
@app.get("/logs/clear")
def clear_logs(days_older_than: int = None):
    """Delete application logs from MongoDB through ``clear_logs_from_mongodb``.

    ``days_older_than`` is handed straight to that helper. What ``None`` means
    is defined there, presumably "delete everything"; verify before relying on it.

    NOTE(review): this is a destructive action on an unauthenticated GET route.
    Consider POST/DELETE plus ``Depends(verify_bearer)``.

    Raises:
        HTTPException: 400 when MongoDB is not configured, 500 when deletion fails.
    """
    if MONGO_URI:
        try:
            removed = clear_logs_from_mongodb(MONGO_URI, "HairSwapDB", "logs", days_older_than)
            body = {"message": f"Cleared {removed} logs", "days_older_than": days_older_than}
            return JSONResponse(body)
        except Exception as err:
            raise HTTPException(status_code=500, detail=f"Failed to clear logs: {str(err)}")
    raise HTTPException(status_code=400, detail="MongoDB not configured")
|
|
|
|
|
|
|
|
@app.get("/logs/stats")
def logs_stats():
    """Get logging statistics.

    Returns the total log count, counts grouped by the ``level`` field, and
    the 10 most frequent ``logger`` values. On any MongoDB error it returns
    ``mongodb_status: "error: ..."`` with ``total_logs: 0`` instead of failing.
    """
    if not MONGO_URI or logs_col is None:
        return JSONResponse({"mongodb_status": "not_configured"})

    try:
        # Reuse the module-level handle to HairSwapDB.logs. The previous
        # version built a new MongoClient, with its own connection pool and
        # monitor threads, on every request and never closed it.
        total_logs = logs_col.count_documents({})

        logs_by_level = list(logs_col.aggregate([
            {"$group": {"_id": "$level", "count": {"$sum": 1}}},
            {"$sort": {"count": -1}},
        ]))

        logs_by_logger = list(logs_col.aggregate([
            {"$group": {"_id": "$logger", "count": {"$sum": 1}}},
            {"$sort": {"count": -1}},
            {"$limit": 10},
        ]))

        return JSONResponse({
            "total_logs": total_logs,
            "logs_by_level": logs_by_level,
            "top_loggers": logs_by_logger,
            "mongodb_status": "connected",
        })
    except Exception as e:
        return JSONResponse({
            "mongodb_status": f"error: {str(e)}",
            "total_logs": 0,
        })
|
|
|
|
|
|
|
|
|