LogicGoInfotechSpaces commited on
Commit
57338fa
·
1 Parent(s): c2c4f84

fix(space): lazy-load model to avoid import-time errors; pin huggingface-hub 0.14.1 (cached_download)

Browse files
Files changed (2) hide show
  1. requirements.txt +1 -1
  2. server.py +5 -3
requirements.txt CHANGED
@@ -8,7 +8,7 @@ opencv-python
8
  pillow
9
  numpy
10
  scipy
11
- huggingface-hub==0.17.3
12
  controlnet-aux==0.0.10
13
  safetensors
14
  einops
 
8
  pillow
9
  numpy
10
  scipy
11
+ huggingface-hub==0.14.1
12
  controlnet-aux==0.0.10
13
  safetensors
14
  einops
server.py CHANGED
@@ -12,7 +12,7 @@ import torch
12
  import numpy as np
13
  from PIL import Image
14
 
15
- from infer_full import StableHair
16
 
17
 
18
  LOGGER = logging.getLogger("hair_server")
@@ -53,15 +53,17 @@ class HairSwapRequest(BaseModel):
53
 
54
 
55
  # Initialize model lazily on first request
56
- _model: Optional[StableHair] = None
57
 
58
 
59
- def get_model() -> StableHair:
60
  global _model
61
  if _model is None:
62
  LOGGER.info("Loading StableHair model ...")
63
  device = "cuda" if torch.cuda.is_available() else "cpu"
64
  dtype = torch.float16 if device == "cuda" else torch.float32
 
 
65
  _model = StableHair(config="./configs/hair_transfer.yaml", device=device, weight_dtype=dtype)
66
  LOGGER.info("Model loaded")
67
  return _model
 
12
  import numpy as np
13
  from PIL import Image
14
 
15
+ # Lazy import performed in get_model() to avoid import-time failures on Space
16
 
17
 
18
  LOGGER = logging.getLogger("hair_server")
 
53
 
54
 
55
  # Initialize model lazily on first request
56
+ _model = None # type: ignore[assignment]
57
 
58
 
59
+ def get_model():
60
  global _model
61
  if _model is None:
62
  LOGGER.info("Loading StableHair model ...")
63
  device = "cuda" if torch.cuda.is_available() else "cpu"
64
  dtype = torch.float16 if device == "cuda" else torch.float32
65
+ # Import here to defer importing diffusers/transformers until needed
66
+ from infer_full import StableHair # noqa: WPS433
67
  _model = StableHair(config="./configs/hair_transfer.yaml", device=device, weight_dtype=dtype)
68
  LOGGER.info("Model loaded")
69
  return _model