import os
from typing import Optional, Union

import torch
from torch.distributed.device_mesh import init_device_mesh
from huggingface_hub import snapshot_download
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
from safetensors.torch import load_file

from .models.dit import get_dit
from .models.text_embedders import get_text_embedder
from .models.vae import build_vae
from .models.parallelize import parallelize_dit
from .t2v_pipeline import Kandinsky5T2VPipeline
from .magcache_utils import set_magcache_params

# Fall back to eager execution instead of raising when torch.compile fails.
torch._dynamo.config.suppress_errors = True


def get_T2V_pipeline(
    device_map: Union[str, torch.device, dict],
    resolution: int = 512,
    cache_dir: str = "./weights/",
    dit_path: Optional[str] = None,
    text_encoder_path: Optional[str] = None,
    text_encoder2_path: Optional[str] = None,
    vae_path: Optional[str] = None,
    conf_path: Optional[str] = None,
    offload: bool = False,
    magcache: bool = False,
) -> Kandinsky5T2VPipeline:
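    """Assemble a Kandinsky 5.0 text-to-video pipeline.

    Any checkpoint path left as ``None`` (DiT, HunyuanVideo VAE, Qwen2.5-VL
    and CLIP text encoders) is downloaded from the Hugging Face Hub into
    ``cache_dir``. ``device_map`` is either a single device for every
    component or a dict with ``"dit"``, ``"vae"`` and ``"text_embedder"``
    keys. Set ``offload=True`` to keep weights on CPU for single-GPU use;
    ``magcache=True`` requires a config with a ``magcache`` section.
    """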
    assert resolution in [512], "Only 512 resolution is currently supported"

    # Use the same device for every component if a single device was given.
    if not isinstance(device_map, dict):
        device_map = {
            "dit": device_map,
            "vae": device_map,
            "text_embedder": device_map,
        }

    # Detect a distributed launch (e.g. via torchrun); default to one process.
    try:
        local_rank = int(os.environ["LOCAL_RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
    except KeyError:
        local_rank, world_size = 0, 1

    assert not (
        world_size > 1 and offload
    ), "Offloading is only supported for non-parallel (single-GPU) inference"
    if world_size > 1:
        # Build a 1-D tensor-parallel mesh and pin every component to this rank's GPU.
        device_mesh = init_device_mesh(
            "cuda", (world_size,), mesh_dim_names=("tensor_parallel",)
        )
        device_map["dit"] = torch.device(f"cuda:{local_rank}")
        device_map["vae"] = torch.device(f"cuda:{local_rank}")
        device_map["text_embedder"] = torch.device(f"cuda:{local_rank}")
    os.makedirs(cache_dir, exist_ok=True)

    # Download any missing default checkpoints from the Hugging Face Hub.
    if dit_path is None and conf_path is None:
        snapshot_download(
            repo_id="ai-forever/Kandinsky-5.0-T2V-Lite-sft-5s",
            allow_patterns="model/*",
            local_dir=cache_dir,
        )
        dit_path = os.path.join(
            cache_dir, "model/kandinsky5lite_t2v_sft_5s.safetensors"
        )
    if vae_path is None and conf_path is None:
        snapshot_download(
            repo_id="hunyuanvideo-community/HunyuanVideo",
            allow_patterns="vae/*",
            local_dir=cache_dir,
        )
        vae_path = os.path.join(cache_dir, "vae/")
    if text_encoder_path is None and conf_path is None:
        text_encoder_path = snapshot_download(
            repo_id="Qwen/Qwen2.5-VL-7B-Instruct",
            local_dir=os.path.join(cache_dir, "text_encoder/"),
        )
    if text_encoder2_path is None and conf_path is None:
        text_encoder2_path = snapshot_download(
            repo_id="openai/clip-vit-large-patch14",
            local_dir=os.path.join(cache_dir, "text_encoder2/"),
        )

    if conf_path is None:
        conf = get_default_conf(
            dit_path, vae_path, text_encoder_path, text_encoder2_path
        )
    else:
        conf = OmegaConf.load(conf_path)
    text_embedder = get_text_embedder(conf.model.text_embedder)
    if not offload:
        text_embedder = text_embedder.to(device=device_map["text_embedder"])

    vae = build_vae(conf.model.vae)
    vae = vae.eval()
    if not offload:
        vae = vae.to(device=device_map["vae"])
    dit = get_dit(conf.model.dit_params)
    if magcache:
        # MagCache needs a `magcache` section in the config; the default
        # conf from get_default_conf() does not define one.
        mag_ratios = conf.magcache.mag_ratios
        num_steps = conf.model.num_steps
        # Classifier-free guidance is effectively disabled at weight 1.0.
        no_cfg = conf.model.guidance_weight == 1.0
        set_magcache_params(dit, mag_ratios, num_steps, no_cfg)

    # Load DiT weights; assign=True adopts the checkpoint tensors in place
    # of the module's parameters instead of copying into them.
    state_dict = load_file(conf.model.checkpoint_path)
    dit.load_state_dict(state_dict, assign=True)
    if not offload:
        dit = dit.to(device_map["dit"])
    if world_size > 1:
        dit = parallelize_dit(dit, device_mesh["tensor_parallel"])

    return Kandinsky5T2VPipeline(
        device_map=device_map,
        dit=dit,
        text_embedder=text_embedder,
        vae=vae,
        resolution=resolution,
        local_dit_rank=local_rank,
        world_size=world_size,
        conf=conf,
        offload=offload,
    )
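

# A minimal usage sketch (assumes a single CUDA device; the generation call
# itself is defined by Kandinsky5T2VPipeline and is not shown in this module,
# and the import path below is illustrative):
#
#     from kandinsky import get_T2V_pipeline
#
#     pipe = get_T2V_pipeline(device_map="cuda:0", offload=True)
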

def get_default_conf(
    dit_path: str,
    vae_path: str,
    text_encoder_path: str,
    text_encoder2_path: str,
) -> DictConfig:
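    """Build the default Kandinsky 5.0 T2V Lite config for the given checkpoints."""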
    dit_params = {
        "in_visual_dim": 16,
        "out_visual_dim": 16,
        "time_dim": 512,
        "patch_size": [1, 2, 2],
        "model_dim": 1792,
        "ff_dim": 7168,
        "num_text_blocks": 2,
        "num_visual_blocks": 32,
        "axes_dims": [16, 24, 24],
        "visual_cond": True,
        "in_text_dim": 3584,
        "in_text_dim2": 768,
    }
    attention = {
        "type": "flash",
        "causal": False,
        "local": False,
        "glob": False,
        "window": 3,
    }
    vae = {
        "checkpoint_path": vae_path,
        "name": "hunyuan",
    }
    text_embedder = {
        "qwen": {
            "emb_size": 3584,
            "checkpoint_path": text_encoder_path,
            "max_length": 256,
        },
        "clip": {
            "checkpoint_path": text_encoder2_path,
            "emb_size": 768,
            "max_length": 77,
        },
    }
    conf = {
        "model": {
            "checkpoint_path": dit_path,
            "vae": vae,
            "text_embedder": text_embedder,
            "dit_params": dit_params,
            "attention": attention,
            "num_steps": 50,
            "guidance_weight": 5.0,
        },
        "metrics": {"scale_factor": (1, 2, 2)},
        "resolution": 512,
    }
    return DictConfig(conf)
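

# A sketch of running with a customized config (file names are hypothetical):
#
#     conf = get_default_conf(dit_path, vae_path, text_encoder_path, text_encoder2_path)
#     conf.model.num_steps = 28            # fewer sampling steps than the default 50
#     OmegaConf.save(conf, "my_conf.yaml")
#     pipe = get_T2V_pipeline(device_map="cuda:0", conf_path="my_conf.yaml")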