LogicGoInfotechSpaces committed
Commit 770b663 · 1 Parent(s): e78b9f1

fix: add runtime shim for split_torch_state_dict_into_shards to unblock diffusers import

Files changed (1)
  1. server.py +13 -0
server.py CHANGED
@@ -81,6 +81,19 @@ def get_model():
     # Ensure HF token env var is where downstream libs expect it
     if os.environ.get("HUGGINGFACEHUB_API_TOKEN") and not os.environ.get("HUGGINGFACE_HUB_TOKEN"):
         os.environ["HUGGINGFACE_HUB_TOKEN"] = os.environ["HUGGINGFACEHUB_API_TOKEN"]
+
+    # Backward-compat shim: some diffusers versions import a helper only present in newer hub versions.
+    try:
+        import huggingface_hub as _hfh  # type: ignore
+
+        if not hasattr(_hfh, "split_torch_state_dict_into_shards"):
+            def _split_torch_state_dict_into_shards(state_dict, max_shard_size="10GB"):
+                # Minimal shim: return a single shard mapping expected by callers
+                return {"pytorch_model.bin": state_dict}
+
+            _hfh.split_torch_state_dict_into_shards = _split_torch_state_dict_into_shards  # type: ignore[attr-defined]
+    except Exception:
+        pass
 
     # Import here to defer importing diffusers/transformers until needed
     from infer_full import StableHair  # noqa: WPS433
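
For reference, a minimal smoke test (hypothetical, not part of this commit) once the shim has run inside get_model(), reproducing the import that fails on older huggingface_hub releases:

    # Succeeds either natively (newer huggingface_hub) or via the shim above.
    from huggingface_hub import split_torch_state_dict_into_shards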