Commit 57338fa
Parent(s): c2c4f84

fix(space): lazy-load model to avoid import-time errors; pin huggingface-hub 0.14.1 (cached_download)

Files changed:
- requirements.txt (+1, -1)
- server.py (+5, -3)
requirements.txt CHANGED

@@ -8,7 +8,7 @@ opencv-python
 pillow
 numpy
 scipy
-huggingface-hub==0.
+huggingface-hub==0.14.1
 controlnet-aux==0.0.10
 safetensors
 einops
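The pin is load-bearing: per the commit message, something in the dependency chain still imports huggingface_hub.cached_download, a helper that ships in 0.14.1 but was deprecated and removed in later huggingface-hub releases. A minimal sketch of the failure mode the pin avoids (the assumption here, taken from the commit message, is that a dependency performs this import at module load):

# Hedged sketch: `cached_download` exists in huggingface-hub 0.14.1 but is
# gone from newer releases, so this import raises ImportError as soon as the
# module loads -- which on a Space surfaces as an import-time crash.
import huggingface_hub

print(huggingface_hub.__version__)  # "0.14.1" once the pin is in effect

from huggingface_hub import cached_download  # ImportError on newer releases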
server.py CHANGED

@@ -12,7 +12,7 @@ import torch
 import numpy as np
 from PIL import Image

-from infer_full import StableHair
+# Lazy import performed in get_model() to avoid import-time failures on Space


 LOGGER = logging.getLogger("hair_server")

@@ -53,15 +53,17 @@ class HairSwapRequest(BaseModel):


 # Initialize model lazily on first request
-_model: StableHair = None
+_model = None  # type: ignore[assignment]


-def get_model() -> StableHair:
+def get_model():
     global _model
     if _model is None:
         LOGGER.info("Loading StableHair model ...")
         device = "cuda" if torch.cuda.is_available() else "cpu"
         dtype = torch.float16 if device == "cuda" else torch.float32
+        # Import here to defer importing diffusers/transformers until needed
+        from infer_full import StableHair  # noqa: WPS433
         _model = StableHair(config="./configs/hair_transfer.yaml", device=device, weight_dtype=dtype)
         LOGGER.info("Model loaded")
     return _model
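For context, the change turns the module-level model into a lazily initialized singleton: nothing heavy (diffusers, transformers, the checkpoint itself) is touched until the first request calls get_model(). A self-contained sketch of the same pattern follows. The names, config path, and device/dtype selection come from the diff; the threading.Lock is an assumption not present in the diff, added because two concurrent first requests could otherwise each load the model:

# Sketch of the lazy-singleton pattern from server.py, with an added lock.
import threading

import torch

_model = None
_lock = threading.Lock()


def get_model():
    global _model
    if _model is None:  # fast path once the model is loaded
        with _lock:
            if _model is None:  # double-checked: another request may have won
                device = "cuda" if torch.cuda.is_available() else "cpu"
                dtype = torch.float16 if device == "cuda" else torch.float32
                # Heavy import deferred until first use, as in the diff
                from infer_full import StableHair
                _model = StableHair(
                    config="./configs/hair_transfer.yaml",
                    device=device,
                    weight_dtype=dtype,
                )
    return _model

The trade-off of this pattern is a slow first request instead of a slow (or crashing) import, which is usually the right choice on a Space where the container must come up healthy before any GPU work happens.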