Spaces:
Running
on
Zero
Running
on
Zero
tight-inversion
committed on
Commit
·
4d0ddc3
1
Parent(s):
10d3d92
Align with pulid demo
Browse files- app.py +1 -0
- flux/util.py +3 -22
app.py
CHANGED
|
@@ -431,6 +431,7 @@ if __name__ == "__main__":
|
|
| 431 |
args.offload = True
|
| 432 |
|
| 433 |
print(f"Using device: {args.device}")
|
|
|
|
| 434 |
print(f"Offload: {args.offload}")
|
| 435 |
|
| 436 |
demo = create_demo(args, args.name, args.device, args.offload, args.aggressive_offload)
|
|
|
|
| 431 |
args.offload = True
|
| 432 |
|
| 433 |
print(f"Using device: {args.device}")
|
| 434 |
+
print(f"fp8: {args.fp8}")
|
| 435 |
print(f"Offload: {args.offload}")
|
| 436 |
|
| 437 |
demo = create_demo(args, args.name, args.device, args.offload, args.aggressive_offload)
|
flux/util.py
CHANGED
|
@@ -123,36 +123,17 @@ def load_flow_model(name: str, device: str = "cuda", hf_download: bool = True):
|
|
| 123 |
):
|
| 124 |
ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow, local_dir='models')
|
| 125 |
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
model = Flux(configs[name].params)
|
| 129 |
-
model = model.to_empty(device=device)
|
| 130 |
|
| 131 |
if ckpt_path is not None:
|
| 132 |
print("Loading checkpoint")
|
| 133 |
-
#
|
| 134 |
sd = load_sft(ckpt_path, device=str(device))
|
| 135 |
-
# Load the state dictionary into the model
|
| 136 |
missing, unexpected = model.load_state_dict(sd, strict=False)
|
| 137 |
print_load_warning(missing, unexpected)
|
| 138 |
-
model.to(torch.bfloat16)
|
| 139 |
return model
|
| 140 |
|
| 141 |
-
# from XLabs-AI https://github.com/XLabs-AI/x-flux/blob/1f8ef54972105ad9062be69fe6b7f841bce02a08/src/flux/util.py#L330
|
| 142 |
-
def load_flow_model_quintized(name: str, device: str = "cuda", hf_download: bool = True):
|
| 143 |
-
# Loading Flux
|
| 144 |
-
print("Init model")
|
| 145 |
-
ckpt_path = 'models/flux-dev-fp8.safetensors'
|
| 146 |
-
if (
|
| 147 |
-
not os.path.exists(ckpt_path)
|
| 148 |
-
and hf_download
|
| 149 |
-
):
|
| 150 |
-
print("Downloading model")
|
| 151 |
-
ckpt_path = hf_hub_download("XLabs-AI/flux-dev-fp8", "flux-dev-fp8.safetensors")
|
| 152 |
-
print("Model downloaded to", ckpt_path)
|
| 153 |
-
json_path = hf_hub_download("XLabs-AI/flux-dev-fp8", 'flux_dev_quantization_map.json')
|
| 154 |
-
|
| 155 |
-
model = Flux(configs[name].params).to(torch.bfloat16)
|
| 156 |
def load_flow_model_quintized(
|
| 157 |
name: str,
|
| 158 |
device: str = "cuda",
|
|
|
|
| 123 |
):
|
| 124 |
ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow, local_dir='models')
|
| 125 |
|
| 126 |
+
with torch.device(device):
|
| 127 |
+
model = Flux(configs[name].params).to(torch.bfloat16)
|
|
|
|
|
|
|
| 128 |
|
| 129 |
if ckpt_path is not None:
|
| 130 |
print("Loading checkpoint")
|
| 131 |
+
# load_sft doesn't support torch.device
|
| 132 |
sd = load_sft(ckpt_path, device=str(device))
|
|
|
|
| 133 |
missing, unexpected = model.load_state_dict(sd, strict=False)
|
| 134 |
print_load_warning(missing, unexpected)
|
|
|
|
| 135 |
return model
|
| 136 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 137 |
def load_flow_model_quintized(
|
| 138 |
name: str,
|
| 139 |
device: str = "cuda",
|