import json
import os
from dataclasses import dataclass

import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file as load_sft

from flux.model import Flux, FluxParams
from flux.modules.autoencoder import AutoEncoder, AutoEncoderParams
from flux.modules.conditioner import HFEmbedder


@dataclass
class SamplingOptions:
    prompt: str
    width: int
    height: int
    num_steps: int
    guidance: float
    seed: int


@dataclass
class ModelSpec:
    params: FluxParams
    ae_params: AutoEncoderParams
    # ckpt_path and ae_path may come from unset environment variables, so they
    # can legitimately be None (the loaders then fall back to an HF download).
    ckpt_path: str | None
    ae_path: str | None
    repo_id: str
    repo_flow: str
    repo_ae: str


configs = {
    "flux-dev": ModelSpec(
        repo_id="black-forest-labs/FLUX.1-dev",
        repo_flow="flux1-dev.safetensors",
        repo_ae="ae.safetensors",
        ckpt_path="models/flux1-dev.safetensors",
        params=FluxParams(
            in_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=True,
        ),
        ae_path="models/ae.safetensors",
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
    "flux-schnell": ModelSpec(
        repo_id="black-forest-labs/FLUX.1-schnell",
        repo_flow="flux1-schnell.safetensors",
        repo_ae="ae.safetensors",
        ckpt_path=os.getenv("FLUX_SCHNELL"),
        params=FluxParams(
            in_channels=64,
            vec_in_dim=768,
            context_in_dim=4096,
            hidden_size=3072,
            mlp_ratio=4.0,
            num_heads=24,
            depth=19,
            depth_single_blocks=38,
            axes_dim=[16, 56, 56],
            theta=10_000,
            qkv_bias=True,
            guidance_embed=False,
        ),
        ae_path=os.getenv("AE"),
        ae_params=AutoEncoderParams(
            resolution=256,
            in_channels=3,
            ch=128,
            out_ch=3,
            ch_mult=[1, 2, 4, 4],
            num_res_blocks=2,
            z_channels=16,
            scale_factor=0.3611,
            shift_factor=0.1159,
        ),
    ),
}
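

# The two entries above are the only valid `name` keys accepted by the loaders
# below. A small guard like this one (a hypothetical helper, not part of the
# original file) fails fast on an unknown name instead of raising a bare
# KeyError deep inside a loader:
def get_config(name: str) -> ModelSpec:
    if name not in configs:
        raise ValueError(f"Unknown model name {name!r}; expected one of {list(configs)}")
    return configs[name]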


def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
    if len(missing) > 0:
        print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
    if len(missing) > 0 and len(unexpected) > 0:
        print("\n" + "-" * 79 + "\n")
    if len(unexpected) > 0:
        print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))


def load_flow_model(name: str, device: str = "cuda", hf_download: bool = True) -> Flux:
    # Loading Flux
    print("Init model")
    ckpt_path = configs[name].ckpt_path
    # ckpt_path can be None (e.g. flux-schnell with FLUX_SCHNELL unset), so
    # guard it before os.path.exists, which would raise a TypeError on None.
    if (
        (ckpt_path is None or not os.path.exists(ckpt_path))
        and configs[name].repo_id is not None
        and configs[name].repo_flow is not None
        and hf_download
    ):
        ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow, local_dir="models")

    with torch.device(device):
        model = Flux(configs[name].params).to(torch.bfloat16)

    if ckpt_path is not None:
        print("Loading checkpoint")
        # load_sft doesn't support torch.device
        sd = load_sft(ckpt_path, device=str(device))
        missing, unexpected = model.load_state_dict(sd, strict=False)
        print_load_warning(missing, unexpected)
    return model
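

# Usage sketch (an illustrative assumption, not part of the original file):
# assumes a CUDA device with enough memory for the bf16 weights and, for
# "flux-dev", access to the gated black-forest-labs/FLUX.1-dev repo.
#
#     flow = load_flow_model("flux-schnell", device="cuda")
#     flow.eval()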


def load_flow_model_quantized(
    name: str,
    device: str = "cuda",
    hf_download: bool = True,
    cache_path: str | None = None,
) -> Flux:
    """
    Loads (or downloads) a FLUX fp8 checkpoint, performs quantization once,
    and caches the quantized model to disk. Future calls load from the cache.

    :param name: model name key in `configs` (e.g. "flux-dev")
    :param device: torch device string ("cuda" or "cpu")
    :param hf_download: whether to download from HF if the local ckpt is missing
    :param cache_path: filepath for the cached quantized model
    :return: a quantized FLUX model on the specified device
    """
    if cache_path is None:
        cache_path = os.path.join(os.path.expanduser("~"), ".cache/flux_dev_fp8_quantized_model.pth")

    # 1) Check if we already have a cached, quantized model.
    if os.path.exists(cache_path):
        print(f"Loading cached quantized model from '{cache_path}'...")
        # The cache stores a fully pickled model, so weights_only must stay off.
        model = torch.load(cache_path, map_location=device, weights_only=False)
        return model.to(device)

    # 2) If no cache, build and quantize for the first time.
    print("No cached model found. Initializing + quantizing from scratch.")

    # (A) Download or locate the checkpoint and quantization map.
    ckpt_path = "models/flux-dev-fp8.safetensors"
    if not os.path.exists(ckpt_path) and hf_download:
        print("Downloading model checkpoint from HF...")
        ckpt_path = hf_hub_download("XLabs-AI/flux-dev-fp8", "flux-dev-fp8.safetensors")
        print("Model downloaded to:", ckpt_path)
    # The quantization map is needed on every cold start, not only after a fresh
    # checkpoint download, so fetch it unconditionally (huggingface_hub caches it).
    json_path = hf_hub_download("XLabs-AI/flux-dev-fp8", "flux_dev_quantization_map.json")

    # (B) Build the unquantized model.
    print("Initializing model in bfloat16...")
    model = Flux(configs[name].params).to(torch.bfloat16)

    # (C) Load the unquantized weights.
    print("Loading unquantized checkpoint to CPU...")
    sd = load_sft(ckpt_path, device="cpu")  # CPU load

    # (D) Load the quantization map.
    with open(json_path, "r") as f:
        quantization_map = json.load(f)

    # (E) Quantize. Imported lazily so optimum-quanto is only required on this path.
    print("Starting quantization process...")
    from optimum.quanto import requantize

    requantize(model, sd, quantization_map, device=device)
    print("Quantization complete.")

    # (F) Cache the fully quantized model to disk.
    print(f"Saving the quantized model to '{cache_path}'...")
    torch.save(model, cache_path)
    print("Model saved. Future runs will load from cache.")
    return model.to(device)
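

# Usage sketch (an illustrative assumption, not part of the original file):
# this path additionally requires the optimum-quanto package, and the first
# call is slow because it quantizes and caches the model.
#
#     flow_fp8 = load_flow_model_quantized("flux-dev", device="cuda")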


def load_t5(device: str = "cuda", max_length: int = 512) -> HFEmbedder:
    # A max_length of 64, 128, 256, or 512 should work (if your sequences are short enough).
    return HFEmbedder("xlabs-ai/xflux_text_encoders", max_length=max_length, torch_dtype=torch.bfloat16).to(device)


def load_clip(device: str = "cuda") -> HFEmbedder:
    return HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.bfloat16).to(device)


def load_ae(name: str, device: str = "cuda", hf_download: bool = True) -> AutoEncoder:
    ckpt_path = configs[name].ae_path
    # ae_path comes from os.getenv for flux-schnell, so guard against None
    # before os.path.exists, which would raise a TypeError on None.
    if (
        (ckpt_path is None or not os.path.exists(ckpt_path))
        and configs[name].repo_id is not None
        and configs[name].repo_ae is not None
        and hf_download
    ):
        ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_ae, local_dir="models")

    # Loading the autoencoder
    print("Init AE")
    with torch.device(device):
        ae = AutoEncoder(configs[name].ae_params)

    if ckpt_path is not None:
        sd = load_sft(ckpt_path, device=str(device))
        missing, unexpected = ae.load_state_dict(sd, strict=False)
        print_load_warning(missing, unexpected)
    return ae
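

# Minimal end-to-end smoke test (a sketch under assumptions, not part of the
# original file): assumes a CUDA device with enough free memory for all four
# models in bf16 and network access to the Hugging Face Hub on the first run.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    t5 = load_t5(device, max_length=256)
    clip = load_clip(device)
    flow = load_flow_model("flux-schnell", device=device)
    ae = load_ae("flux-schnell", device=device)
    print("Loaded T5, CLIP, flow model, and autoencoder on", device)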