# Kandinsky: kandinsky/generation_utils.py
import os

# Disable HF tokenizers parallelism before any tokenizers are instantiated to
# avoid fork-related warnings and deadlocks in worker processes.
os.environ["TOKENIZERS_PARALLELISM"] = "False"

import torch
from tqdm import tqdm

from .models.utils import fast_sta_nabla


def get_sparse_params(conf, batch_embeds, device):
    """Build sparse (NaBLA) attention parameters for the DiT, or return None for dense attention."""
    assert conf.model.dit_params.patch_size[0] == 1
    T, H, W, _ = batch_embeds["visual"].shape
    # Token grid after patchification.
    T, H, W = (
        T // conf.model.dit_params.patch_size[0],
        H // conf.model.dit_params.patch_size[1],
        W // conf.model.dit_params.patch_size[2],
    )
    if conf.model.attention.type == "nabla":
        # STA mask for the NaBLA sparse attention, built on an 8x downsampled spatial grid.
        sta_mask = fast_sta_nabla(
            T,
            H // 8,
            W // 8,
            conf.model.attention.wT,
            conf.model.attention.wH,
            conf.model.attention.wW,
            device=device,
        )
        sparse_params = {
            "sta_mask": sta_mask.unsqueeze_(0).unsqueeze_(0),
            "attention_type": conf.model.attention.type,
            "to_fractal": True,
            "P": conf.model.attention.P,
            "wT": conf.model.attention.wT,
            "wW": conf.model.attention.wW,
            "wH": conf.model.attention.wH,
            "add_sta": conf.model.attention.add_sta,
            "visual_shape": (T, H, W),
            "method": getattr(conf.model.attention, "method", "topcdf"),
        }
    else:
        sparse_params = None
    return sparse_params


@torch.no_grad()
def get_velocity(
    dit,
    x,
    t,
    text_embeds,
    null_text_embeds,
    visual_rope_pos,
    text_rope_pos,
    null_text_rope_pos,
    guidance_weight,
    conf,
    sparse_params=None,
):
    """Predict the velocity at time t, applying classifier-free guidance when guidance_weight != 1."""
    # Conditional prediction; t is in [0, 1] and the DiT expects it scaled by 1000.
    pred_velocity = dit(
        x,
        text_embeds["text_embeds"],
        text_embeds["pooled_embed"],
        t * 1000,
        visual_rope_pos,
        text_rope_pos,
        scale_factor=conf.metrics.scale_factor,
        sparse_params=sparse_params,
    )
    if abs(guidance_weight - 1.0) > 1e-6:
        # Unconditional (null-prompt) prediction for classifier-free guidance.
        uncond_pred_velocity = dit(
            x,
            null_text_embeds["text_embeds"],
            null_text_embeds["pooled_embed"],
            t * 1000,
            visual_rope_pos,
            null_text_rope_pos,
            scale_factor=conf.metrics.scale_factor,
            sparse_params=sparse_params,
        )
        pred_velocity = uncond_pred_velocity + guidance_weight * (
            pred_velocity - uncond_pred_velocity
        )
    return pred_velocity
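

# Illustrative sketch (not part of the original module): the classifier-free
# guidance blend used in `get_velocity`, shown on plain tensors. With
# guidance_weight == 1.0 the blend is exactly the conditional prediction;
# larger weights push the result further away from the unconditional one.
def _cfg_blend_example(guidance_weight=5.0):
    cond = torch.tensor([1.0, 2.0])    # stand-in for the conditional velocity
    uncond = torch.tensor([0.5, 1.0])  # stand-in for the unconditional velocity
    # For the default weight this returns tensor([3., 6.]).
    return uncond + guidance_weight * (cond - uncond)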


@torch.no_grad()
def generate(
    model,
    device,
    shape,
    num_steps,
    text_embeds,
    null_text_embeds,
    visual_rope_pos,
    text_rope_pos,
    null_text_rope_pos,
    guidance_weight,
    scheduler_scale,
    conf,
    progress=False,
    seed=6554,
):
    """Euler-integrate the predicted velocity from t=1 (noise) to t=0 and return the latents."""
    g = torch.Generator(device="cuda")
    g.manual_seed(seed)
    img = torch.randn(*shape, device=device, generator=g)
    sparse_params = get_sparse_params(conf, {"visual": img}, device)
    # Uniform schedule on [1, 0], warped by scheduler_scale: t -> s*t / (1 + (s-1)*t).
    timesteps = torch.linspace(1, 0, num_steps + 1, device=device)
    timesteps = scheduler_scale * timesteps / (1 + (scheduler_scale - 1) * timesteps)
    steps = list(zip(timesteps[:-1], torch.diff(timesteps)))
    for timestep, timestep_diff in tqdm(steps, disable=not progress):
        time = timestep.unsqueeze(0)
        if model.visual_cond:
            # Visual-conditioning models expect zeroed conditioning latents and a
            # zero mask concatenated along the channel dimension.
            visual_cond = torch.zeros_like(img)
            visual_cond_mask = torch.zeros(
                [*img.shape[:-1], 1], dtype=img.dtype, device=img.device
            )
            model_input = torch.cat([img, visual_cond, visual_cond_mask], dim=-1)
        else:
            model_input = img
        pred_velocity = get_velocity(
            model,
            model_input,
            time,
            text_embeds,
            null_text_embeds,
            visual_rope_pos,
            text_rope_pos,
            null_text_rope_pos,
            guidance_weight,
            conf,
            sparse_params=sparse_params,
        )
        # Euler step; timestep_diff is negative, moving the latents toward t=0.
        img = img + timestep_diff * pred_velocity
    return img
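

# Illustrative sketch (not part of the original module): the shifted timestep
# schedule used in `generate`, i.e. a uniform grid on [1, 0] warped by
# t -> s*t / (1 + (s - 1)*t). With scheduler_scale == 1 this is the plain
# linear schedule; larger values concentrate steps at high noise levels.
def _example_shifted_timesteps(num_steps=4, scheduler_scale=3.0):
    t = torch.linspace(1, 0, num_steps + 1)
    # For the defaults this returns tensor([1.0000, 0.9000, 0.7500, 0.5000, 0.0000]).
    return scheduler_scale * t / (1 + (scheduler_scale - 1) * t)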


def generate_sample(
    shape,
    caption,
    dit,
    vae,
    conf,
    text_embedder,
    num_steps=25,
    guidance_weight=5.0,
    scheduler_scale=1,
    negative_caption="",
    seed=6554,
    device="cuda",
    vae_device="cuda",
    text_embedder_device="cuda",
    progress=True,
    offload=False,
):
    """Encode the prompts, sample latents with the DiT, and decode them to uint8 frames with the VAE."""
    bs, duration, height, width, dim = shape
    # A single latent frame is treated as an image, anything longer as a video.
    if duration == 1:
        type_of_content = "image"
    else:
        type_of_content = "video"

    # Encode the positive and negative (null) prompts.
    with torch.no_grad():
        bs_text_embed, text_cu_seqlens = text_embedder.encode(
            [caption], type_of_content=type_of_content
        )
        bs_null_text_embed, null_text_cu_seqlens = text_embedder.encode(
            [negative_caption], type_of_content=type_of_content
        )
    if offload:
        text_embedder = text_embedder.to("cpu")

    # Move the embeddings to the sampling device; keep only the total token
    # counts (last entries of the cumulative sequence lengths).
    for key in bs_text_embed:
        bs_text_embed[key] = bs_text_embed[key].to(device=device)
        bs_null_text_embed[key] = bs_null_text_embed[key].to(device=device)
    text_cu_seqlens = text_cu_seqlens.to(device=device)[-1].item()
    null_text_cu_seqlens = null_text_cu_seqlens.to(device=device)[-1].item()

    # RoPE positions for the visual token grid (per axis) and the text tokens.
    visual_rope_pos = [
        torch.arange(duration),
        torch.arange(shape[-3] // conf.model.dit_params.patch_size[1]),
        torch.arange(shape[-2] // conf.model.dit_params.patch_size[2]),
    ]
    text_rope_pos = torch.arange(text_cu_seqlens)
    null_text_rope_pos = torch.arange(null_text_cu_seqlens)

    if offload:
        dit.to(device, non_blocking=True)
    with torch.no_grad():
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            latent_visual = generate(
                dit,
                device,
                (bs * duration, height, width, dim),
                num_steps,
                bs_text_embed,
                bs_null_text_embed,
                visual_rope_pos,
                text_rope_pos,
                null_text_rope_pos,
                guidance_weight,
                scheduler_scale,
                conf,
                seed=seed,
                progress=progress,
            )
    if offload:
        dit = dit.to("cpu", non_blocking=True)
        torch.cuda.empty_cache()
    if offload:
        vae = vae.to(vae_device, non_blocking=True)

    # Decode: undo the VAE scaling, move channels first, decode, and map [-1, 1] to uint8.
    with torch.no_grad():
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            images = latent_visual.reshape(
                bs,
                -1,
                latent_visual.shape[-3],
                latent_visual.shape[-2],
                latent_visual.shape[-1],
            )
            images = images.to(device=vae_device)
            images = (images / vae.config.scaling_factor).permute(0, 4, 1, 2, 3)
            images = vae.decode(images).sample
            images = ((images.clamp(-1.0, 1.0) + 1.0) * 127.5).to(torch.uint8)
    if offload:
        vae = vae.to("cpu", non_blocking=True)
        torch.cuda.empty_cache()
    return images
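

# Illustrative usage sketch (not part of the original module). The `dit`, `vae`,
# `conf`, and `text_embedder` arguments are assumed to be already-loaded
# Kandinsky components, and the latent shape below is only a placeholder that
# must match the configured patch size and VAE compression of the checkpoint.
def _example_usage(dit, vae, conf, text_embedder):
    latent_shape = (1, 1, 64, 64, 16)  # (bs, duration, height, width, dim): example values only
    return generate_sample(
        latent_shape,
        "a cat sitting on a windowsill at sunset",
        dit,
        vae,
        conf,
        text_embedder,
        num_steps=25,
        guidance_weight=5.0,
        scheduler_scale=1,
        seed=6554,
        offload=True,
    )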