# kandinsky/utils.py
import os
from typing import Optional, Union

import torch
from torch.distributed.device_mesh import init_device_mesh
from huggingface_hub import snapshot_download
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
from safetensors.torch import load_file

from .models.dit import get_dit
from .models.text_embedders import get_text_embedder
from .models.vae import build_vae
from .models.parallelize import parallelize_dit
from .t2v_pipeline import Kandinsky5T2VPipeline
from .magcache_utils import set_magcache_params

# Fall back to eager execution instead of raising when torch.compile
# encounters an unsupported op.
torch._dynamo.config.suppress_errors = True

def get_T2V_pipeline(
    device_map: Union[str, torch.device, dict],
    resolution: int = 512,
    cache_dir: str = "./weights/",
    dit_path: Optional[str] = None,
    text_encoder_path: Optional[str] = None,
    text_encoder2_path: Optional[str] = None,
    vae_path: Optional[str] = None,
    conf_path: Optional[str] = None,
    offload: bool = False,
    magcache: bool = False,
) -> Kandinsky5T2VPipeline:
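    """Build a Kandinsky5T2VPipeline, downloading any missing weights.

    Any checkpoint path left as None (with no custom `conf_path`) is fetched
    from the Hugging Face Hub into `cache_dir`.

    Minimal single-GPU usage sketch (the pipeline call signature below is an
    assumption; see Kandinsky5T2VPipeline for the actual interface):

        pipeline = get_T2V_pipeline(device_map="cuda:0")
        video = pipeline(text="a cat walking on the beach")
    """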
    assert resolution in [512], "Only 512 resolution is currently supported"
    if not isinstance(device_map, dict):
        device_map = {
            "dit": device_map,
            "vae": device_map,
            "text_embedder": device_map,
        }

    # Detect a distributed launch (torchrun sets LOCAL_RANK and WORLD_SIZE);
    # fall back to single-process inference otherwise.
    try:
        local_rank = int(os.environ["LOCAL_RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
    except KeyError:
        local_rank, world_size = 0, 1

    assert not (world_size > 1 and offload), "Offloading is only available for non-parallel inference"
    if world_size > 1:
        # One-dimensional device mesh for tensor parallelism; pin every
        # submodule to this process's GPU.
        device_mesh = init_device_mesh(
            "cuda", (world_size,), mesh_dim_names=("tensor_parallel",)
        )
        device_map["dit"] = torch.device(f"cuda:{local_rank}")
        device_map["vae"] = torch.device(f"cuda:{local_rank}")
        device_map["text_embedder"] = torch.device(f"cuda:{local_rank}")
os.makedirs(cache_dir, exist_ok=True)
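
    # Resolve checkpoints: anything not supplied explicitly (and not coming
    # from a custom config) is downloaded from the Hugging Face Hub.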
    if dit_path is None and conf_path is None:
        snapshot_download(
            repo_id="ai-forever/Kandinsky-5.0-T2V-Lite-sft-5s",
            allow_patterns="model/*",
            local_dir=cache_dir,
        )
        dit_path = os.path.join(cache_dir, "model/kandinsky5lite_t2v_sft_5s.safetensors")
    if vae_path is None and conf_path is None:
        snapshot_download(
            repo_id="hunyuanvideo-community/HunyuanVideo",
            allow_patterns="vae/*",
            local_dir=cache_dir,
        )
        vae_path = os.path.join(cache_dir, "vae/")
    if text_encoder_path is None and conf_path is None:
        text_encoder_path = snapshot_download(
            repo_id="Qwen/Qwen2.5-VL-7B-Instruct",
            local_dir=os.path.join(cache_dir, "text_encoder/"),
        )
    if text_encoder2_path is None and conf_path is None:
        text_encoder2_path = snapshot_download(
            repo_id="openai/clip-vit-large-patch14",
            local_dir=os.path.join(cache_dir, "text_encoder2/"),
        )
if conf_path is None:
conf = get_default_conf(
dit_path, vae_path, text_encoder_path, text_encoder2_path
)
else:
conf = OmegaConf.load(conf_path)
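
    # Instantiate the submodules: text embedders, VAE, and the DiT backbone.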
    text_embedder = get_text_embedder(conf.model.text_embedder)
    if not offload:
        text_embedder = text_embedder.to(device=device_map["text_embedder"])

    vae = build_vae(conf.model.vae)
    vae = vae.eval()
    if not offload:
        vae = vae.to(device=device_map["vae"])
dit = get_dit(conf.model.dit_params)
    if magcache:
        # MagCache calibration ratios come from a `magcache` section in the
        # config, which a custom conf must provide (the default conf built
        # by get_default_conf has none).
        mag_ratios = conf.magcache.mag_ratios
        num_steps = conf.model.num_steps
        # Classifier-free guidance is effectively disabled at weight 1.0.
        no_cfg = conf.model.guidance_weight == 1.0
        set_magcache_params(dit, mag_ratios, num_steps, no_cfg)
    # `assign=True` adopts the checkpoint tensors directly instead of
    # copying them into pre-allocated parameters.
    state_dict = load_file(conf.model.checkpoint_path)
    dit.load_state_dict(state_dict, assign=True)
if not offload:
dit = dit.to(device_map["dit"])
    if world_size > 1:
        # Shard the DiT across the tensor-parallel mesh created above.
        dit = parallelize_dit(dit, device_mesh["tensor_parallel"])
return Kandinsky5T2VPipeline(
device_map=device_map,
dit=dit,
text_embedder=text_embedder,
vae=vae,
resolution=resolution,
local_dit_rank=local_rank,
world_size=world_size,
conf=conf,
offload=offload,
)

def get_default_conf(
    dit_path: str,
    vae_path: str,
    text_encoder_path: str,
    text_encoder2_path: str,
) -> DictConfig:
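    """Assemble the default OmegaConf config for the Kandinsky 5.0 T2V Lite
    sft-5s checkpoint; only the checkpoint paths vary between calls.
    """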
dit_params = {
"in_visual_dim": 16,
"out_visual_dim": 16,
"time_dim": 512,
"patch_size": [1, 2, 2],
"model_dim": 1792,
"ff_dim": 7168,
"num_text_blocks": 2,
"num_visual_blocks": 32,
"axes_dims": [16, 24, 24],
"visual_cond": True,
"in_text_dim": 3584,
"in_text_dim2": 768,
}
attention = {
"type": "flash",
"causal": False,
"local": False,
"glob": False,
"window": 3,
}
vae = {
"checkpoint_path": vae_path,
"name": "hunyuan",
}
text_embedder = {
"qwen": {
"emb_size": 3584,
"checkpoint_path": text_encoder_path,
"max_length": 256,
},
"clip": {
"checkpoint_path": text_encoder2_path,
"emb_size": 768,
"max_length": 77,
},
}
conf = {
"model": {
"checkpoint_path": dit_path,
"vae": vae,
"text_embedder": text_embedder,
"dit_params": dit_params,
"attention": attention,
"num_steps": 50,
"guidance_weight": 5.0,
},
"metrics": {"scale_factor": (1, 2, 2)},
"resolution": 512,
}
return DictConfig(conf)