Commit
·
d7bae69
1
Parent(s):
f6e7388
feat: add detailed error handling and logging to get-hairswap endpoint
Browse files
server.py
CHANGED
|
@@ -72,13 +72,19 @@ _model = None # type: ignore[assignment]
|
|
| 72 |
def get_model():
|
| 73 |
global _model
|
| 74 |
if _model is None:
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
return _model
|
| 83 |
|
| 84 |
|
|
@@ -118,48 +124,76 @@ async def upload_image(image: UploadFile = File(...), _=Depends(verify_bearer)):
|
|
| 118 |
|
| 119 |
@app.post("/get-hairswap")
|
| 120 |
def get_hairswap(req: HairSwapRequest, _=Depends(verify_bearer)):
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
reference_image=reference_path,
|
| 136 |
-
random_seed=-1,
|
| 137 |
-
step=30,
|
| 138 |
-
guidance_scale=req.guidance_scale,
|
| 139 |
-
scale=req.scale,
|
| 140 |
-
controlnet_conditioning_scale=req.controlnet_conditioning_scale,
|
| 141 |
-
size=512,
|
| 142 |
-
)
|
| 143 |
-
|
| 144 |
-
# Save result
|
| 145 |
-
result_id = str(uuid.uuid4())
|
| 146 |
-
out_img = Image.fromarray((out_np * 255.).astype(np.uint8))
|
| 147 |
-
filename = f"{result_id}.png"
|
| 148 |
-
out_path = os.path.join(RESULTS_DIR, filename)
|
| 149 |
-
out_img.save(out_path)
|
| 150 |
-
if results_col:
|
| 151 |
try:
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
"reference_id": req.reference_id,
|
| 158 |
-
})
|
| 159 |
-
except Exception:
|
| 160 |
-
pass
|
| 161 |
|
| 162 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
|
| 164 |
|
| 165 |
@app.get("/download/{filename}")
|
|
|
|
| 72 |
def get_model():
|
| 73 |
global _model
|
| 74 |
if _model is None:
|
| 75 |
+
try:
|
| 76 |
+
LOGGER.info("Loading StableHair model ...")
|
| 77 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 78 |
+
dtype = torch.float16 if device == "cuda" else torch.float32
|
| 79 |
+
LOGGER.info(f"Using device: {device}, dtype: {dtype}")
|
| 80 |
+
|
| 81 |
+
# Import here to defer importing diffusers/transformers until needed
|
| 82 |
+
from infer_full import StableHair # noqa: WPS433
|
| 83 |
+
_model = StableHair(config="./configs/hair_transfer.yaml", device=device, weight_dtype=dtype)
|
| 84 |
+
LOGGER.info("Model loaded successfully")
|
| 85 |
+
except Exception as e:
|
| 86 |
+
LOGGER.error(f"Failed to load model: {str(e)}")
|
| 87 |
+
raise Exception(f"Model loading failed: {str(e)}")
|
| 88 |
return _model
|
| 89 |
|
| 90 |
|
|
|
|
| 124 |
|
| 125 |
@app.post("/get-hairswap")
|
| 126 |
def get_hairswap(req: HairSwapRequest, _=Depends(verify_bearer)):
|
| 127 |
+
try:
|
| 128 |
+
# Resolve file paths
|
| 129 |
+
def find_file(image_id: str) -> str:
|
| 130 |
+
for name in os.listdir(UPLOAD_DIR):
|
| 131 |
+
if name.startswith(image_id):
|
| 132 |
+
return os.path.join(UPLOAD_DIR, name)
|
| 133 |
+
raise HTTPException(status_code=404, detail=f"Image id not found: {image_id}")
|
| 134 |
+
|
| 135 |
+
source_path = find_file(req.source_id)
|
| 136 |
+
reference_path = find_file(req.reference_id)
|
| 137 |
+
|
| 138 |
+
LOGGER.info(f"Found source: {source_path}, reference: {reference_path}")
|
| 139 |
+
|
| 140 |
+
# Load model with error handling
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
try:
|
| 142 |
+
model = get_model()
|
| 143 |
+
LOGGER.info("Model loaded successfully")
|
| 144 |
+
except Exception as e:
|
| 145 |
+
LOGGER.error(f"Model loading failed: {str(e)}")
|
| 146 |
+
raise HTTPException(status_code=500, detail=f"Model loading failed: {str(e)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
|
| 148 |
+
# Perform hair transfer with error handling
|
| 149 |
+
try:
|
| 150 |
+
LOGGER.info("Starting hair transfer...")
|
| 151 |
+
id_np, out_np, bald_np, ref_np = model.Hair_Transfer(
|
| 152 |
+
source_image=source_path,
|
| 153 |
+
reference_image=reference_path,
|
| 154 |
+
random_seed=-1,
|
| 155 |
+
step=30,
|
| 156 |
+
guidance_scale=req.guidance_scale,
|
| 157 |
+
scale=req.scale,
|
| 158 |
+
controlnet_conditioning_scale=req.controlnet_conditioning_scale,
|
| 159 |
+
size=512,
|
| 160 |
+
)
|
| 161 |
+
LOGGER.info("Hair transfer completed successfully")
|
| 162 |
+
except Exception as e:
|
| 163 |
+
LOGGER.error(f"Hair transfer failed: {str(e)}")
|
| 164 |
+
raise HTTPException(status_code=500, detail=f"Hair transfer failed: {str(e)}")
|
| 165 |
+
|
| 166 |
+
# Save result
|
| 167 |
+
try:
|
| 168 |
+
result_id = str(uuid.uuid4())
|
| 169 |
+
out_img = Image.fromarray((out_np * 255.).astype(np.uint8))
|
| 170 |
+
filename = f"{result_id}.png"
|
| 171 |
+
out_path = os.path.join(RESULTS_DIR, filename)
|
| 172 |
+
out_img.save(out_path)
|
| 173 |
+
LOGGER.info(f"Result saved: {out_path}")
|
| 174 |
+
|
| 175 |
+
if results_col:
|
| 176 |
+
try:
|
| 177 |
+
results_col.insert_one({
|
| 178 |
+
"_id": result_id,
|
| 179 |
+
"filename": filename,
|
| 180 |
+
"path": out_path,
|
| 181 |
+
"source_id": req.source_id,
|
| 182 |
+
"reference_id": req.reference_id,
|
| 183 |
+
})
|
| 184 |
+
except Exception as e:
|
| 185 |
+
LOGGER.warning(f"MongoDB save failed: {str(e)}")
|
| 186 |
+
|
| 187 |
+
return {"result": filename}
|
| 188 |
+
except Exception as e:
|
| 189 |
+
LOGGER.error(f"Result saving failed: {str(e)}")
|
| 190 |
+
raise HTTPException(status_code=500, detail=f"Result saving failed: {str(e)}")
|
| 191 |
+
|
| 192 |
+
except HTTPException:
|
| 193 |
+
raise
|
| 194 |
+
except Exception as e:
|
| 195 |
+
LOGGER.error(f"Unexpected error in get_hairswap: {str(e)}")
|
| 196 |
+
raise HTTPException(status_code=500, detail=f"Unexpected error: {str(e)}")
|
| 197 |
|
| 198 |
|
| 199 |
@app.get("/download/{filename}")
|