LogicGoInfotechSpaces commited on
Commit
d7bae69
·
1 Parent(s): f6e7388

feat: add detailed error handling and logging to get-hairswap endpoint

Browse files
Files changed (1) hide show
  1. server.py +81 -47
server.py CHANGED
@@ -72,13 +72,19 @@ _model = None # type: ignore[assignment]
72
  def get_model():
73
  global _model
74
  if _model is None:
75
- LOGGER.info("Loading StableHair model ...")
76
- device = "cuda" if torch.cuda.is_available() else "cpu"
77
- dtype = torch.float16 if device == "cuda" else torch.float32
78
- # Import here to defer importing diffusers/transformers until needed
79
- from infer_full import StableHair # noqa: WPS433
80
- _model = StableHair(config="./configs/hair_transfer.yaml", device=device, weight_dtype=dtype)
81
- LOGGER.info("Model loaded")
 
 
 
 
 
 
82
  return _model
83
 
84
 
@@ -118,48 +124,76 @@ async def upload_image(image: UploadFile = File(...), _=Depends(verify_bearer)):
118
 
119
  @app.post("/get-hairswap")
120
  def get_hairswap(req: HairSwapRequest, _=Depends(verify_bearer)):
121
- # Resolve file paths
122
- def find_file(image_id: str) -> str:
123
- for name in os.listdir(UPLOAD_DIR):
124
- if name.startswith(image_id):
125
- return os.path.join(UPLOAD_DIR, name)
126
- raise HTTPException(status_code=404, detail=f"Image id not found: {image_id}")
127
-
128
- source_path = find_file(req.source_id)
129
- reference_path = find_file(req.reference_id)
130
-
131
- model = get_model()
132
- # Prepare kwargs similar to infer_full
133
- id_np, out_np, bald_np, ref_np = model.Hair_Transfer(
134
- source_image=source_path,
135
- reference_image=reference_path,
136
- random_seed=-1,
137
- step=30,
138
- guidance_scale=req.guidance_scale,
139
- scale=req.scale,
140
- controlnet_conditioning_scale=req.controlnet_conditioning_scale,
141
- size=512,
142
- )
143
-
144
- # Save result
145
- result_id = str(uuid.uuid4())
146
- out_img = Image.fromarray((out_np * 255.).astype(np.uint8))
147
- filename = f"{result_id}.png"
148
- out_path = os.path.join(RESULTS_DIR, filename)
149
- out_img.save(out_path)
150
- if results_col:
151
  try:
152
- results_col.insert_one({
153
- "_id": result_id,
154
- "filename": filename,
155
- "path": out_path,
156
- "source_id": req.source_id,
157
- "reference_id": req.reference_id,
158
- })
159
- except Exception:
160
- pass
161
 
162
- return {"result": filename}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
 
164
 
165
  @app.get("/download/{filename}")
 
72
  def get_model():
73
  global _model
74
  if _model is None:
75
+ try:
76
+ LOGGER.info("Loading StableHair model ...")
77
+ device = "cuda" if torch.cuda.is_available() else "cpu"
78
+ dtype = torch.float16 if device == "cuda" else torch.float32
79
+ LOGGER.info(f"Using device: {device}, dtype: {dtype}")
80
+
81
+ # Import here to defer importing diffusers/transformers until needed
82
+ from infer_full import StableHair # noqa: WPS433
83
+ _model = StableHair(config="./configs/hair_transfer.yaml", device=device, weight_dtype=dtype)
84
+ LOGGER.info("Model loaded successfully")
85
+ except Exception as e:
86
+ LOGGER.error(f"Failed to load model: {str(e)}")
87
+ raise Exception(f"Model loading failed: {str(e)}")
88
  return _model
89
 
90
 
 
124
 
125
  @app.post("/get-hairswap")
126
  def get_hairswap(req: HairSwapRequest, _=Depends(verify_bearer)):
127
+ try:
128
+ # Resolve file paths
129
+ def find_file(image_id: str) -> str:
130
+ for name in os.listdir(UPLOAD_DIR):
131
+ if name.startswith(image_id):
132
+ return os.path.join(UPLOAD_DIR, name)
133
+ raise HTTPException(status_code=404, detail=f"Image id not found: {image_id}")
134
+
135
+ source_path = find_file(req.source_id)
136
+ reference_path = find_file(req.reference_id)
137
+
138
+ LOGGER.info(f"Found source: {source_path}, reference: {reference_path}")
139
+
140
+ # Load model with error handling
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
  try:
142
+ model = get_model()
143
+ LOGGER.info("Model loaded successfully")
144
+ except Exception as e:
145
+ LOGGER.error(f"Model loading failed: {str(e)}")
146
+ raise HTTPException(status_code=500, detail=f"Model loading failed: {str(e)}")
 
 
 
 
147
 
148
+ # Perform hair transfer with error handling
149
+ try:
150
+ LOGGER.info("Starting hair transfer...")
151
+ id_np, out_np, bald_np, ref_np = model.Hair_Transfer(
152
+ source_image=source_path,
153
+ reference_image=reference_path,
154
+ random_seed=-1,
155
+ step=30,
156
+ guidance_scale=req.guidance_scale,
157
+ scale=req.scale,
158
+ controlnet_conditioning_scale=req.controlnet_conditioning_scale,
159
+ size=512,
160
+ )
161
+ LOGGER.info("Hair transfer completed successfully")
162
+ except Exception as e:
163
+ LOGGER.error(f"Hair transfer failed: {str(e)}")
164
+ raise HTTPException(status_code=500, detail=f"Hair transfer failed: {str(e)}")
165
+
166
+ # Save result
167
+ try:
168
+ result_id = str(uuid.uuid4())
169
+ out_img = Image.fromarray((out_np * 255.).astype(np.uint8))
170
+ filename = f"{result_id}.png"
171
+ out_path = os.path.join(RESULTS_DIR, filename)
172
+ out_img.save(out_path)
173
+ LOGGER.info(f"Result saved: {out_path}")
174
+
175
+ if results_col:
176
+ try:
177
+ results_col.insert_one({
178
+ "_id": result_id,
179
+ "filename": filename,
180
+ "path": out_path,
181
+ "source_id": req.source_id,
182
+ "reference_id": req.reference_id,
183
+ })
184
+ except Exception as e:
185
+ LOGGER.warning(f"MongoDB save failed: {str(e)}")
186
+
187
+ return {"result": filename}
188
+ except Exception as e:
189
+ LOGGER.error(f"Result saving failed: {str(e)}")
190
+ raise HTTPException(status_code=500, detail=f"Result saving failed: {str(e)}")
191
+
192
+ except HTTPException:
193
+ raise
194
+ except Exception as e:
195
+ LOGGER.error(f"Unexpected error in get_hairswap: {str(e)}")
196
+ raise HTTPException(status_code=500, detail=f"Unexpected error: {str(e)}")
197
 
198
 
199
  @app.get("/download/{filename}")