File size: 56,868 Bytes
b4feb07 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 
1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199 1200 1201 1202 1203 1204 1205 1206 1207 1208 1209 1210 1211 1212 1213 1214 1215 1216 1217 1218 1219 1220 1221 1222 1223 1224 1225 1226 1227 1228 1229 1230 |
import math
import random
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from diffusers.training_utils import compute_density_for_timestep_sampling
# Default chat-style prompt wrapper used by `encode_prompt` / `_get_llama_prompt_embeds`.
#
# * "template":   format string; the user prompt is substituted for the single `{}`
#                 placeholder via `str.format`. The adjacent string literals are
#                 concatenated as-is, so there is no separator between the numbered
#                 instructions.
# * "crop_start": number of leading token positions that `_get_llama_prompt_embeds`
#                 slices off the hidden states and the attention mask, and by which it
#                 extends the tokenizer `max_length`.
#                 NOTE(review): 95 presumably equals the token length of the system
#                 preamble under the Llama tokenizer - confirm if the template changes.
DEFAULT_PROMPT_TEMPLATE = {
    "template": (
        "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: "
        "1. The main content and theme of the video."
        "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
        "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
        "4. background environment, light, style and atmosphere."
        "5. camera angles, movements, and transitions used in the video:<|eot_id|>"
        "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
    ),
    "crop_start": 95,
}
def get_config_value(args, name):
    """Return the setting ``name`` from ``args`` or from ``args.training_config``.

    The attribute is looked up on ``args`` first; only when it is absent there
    is the nested ``training_config`` object (if any) consulted.

    Args:
        args: Object holding configuration attributes, optionally with a
            ``training_config`` attribute holding more of them.
        name: Attribute name to resolve.

    Returns:
        The attribute value from the first object that defines it.

    Raises:
        AttributeError: If neither ``args`` nor ``args.training_config``
            defines ``name``.
    """
    # Top-level attributes take precedence over the nested config.
    if hasattr(args, name):
        return getattr(args, name)

    if hasattr(args, "training_config"):
        nested_config = args.training_config
        if hasattr(nested_config, name):
            return getattr(nested_config, name)

    raise AttributeError(f"Neither args nor args.training_config has attribute '{name}'")
# Copied from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video.HunyuanVideoPipeline._get_llama_prompt_embeds
def _get_llama_prompt_embeds(
tokenizer,
text_encoder,
prompt: Union[str, List[str]],
prompt_template: Dict[str, Any],
num_videos_per_prompt: int = 1,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
max_sequence_length: int = 256,
num_hidden_layers_to_skip: int = 2,
) -> Tuple[torch.Tensor, torch.Tensor]:
device = device
dtype = dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
prompt = [prompt_template["template"].format(p) for p in prompt]
crop_start = prompt_template.get("crop_start", None)
if crop_start is None:
prompt_template_input = tokenizer(
prompt_template["template"],
padding="max_length",
return_tensors="pt",
return_length=False,
return_overflowing_tokens=False,
return_attention_mask=False,
)
crop_start = prompt_template_input["input_ids"].shape[-1]
# Remove <|eot_id|> token and placeholder {}
crop_start -= 2
max_sequence_length += crop_start
text_inputs = tokenizer(
prompt,
max_length=max_sequence_length,
padding="max_length",
truncation=True,
return_tensors="pt",
return_length=False,
return_overflowing_tokens=False,
return_attention_mask=True,
)
text_input_ids = text_inputs.input_ids.to(device=device)
prompt_attention_mask = text_inputs.attention_mask.to(device=device)
prompt_embeds = text_encoder(
input_ids=text_input_ids,
attention_mask=prompt_attention_mask,
output_hidden_states=True,
).hidden_states[-(num_hidden_layers_to_skip + 1)]
prompt_embeds = prompt_embeds.to(dtype=dtype)
if crop_start is not None and crop_start > 0:
prompt_embeds = prompt_embeds[:, crop_start:]
prompt_attention_mask = prompt_attention_mask[:, crop_start:]
# duplicate text embeddings for each generation per prompt, using mps friendly method
_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
prompt_attention_mask = prompt_attention_mask.repeat(1, num_videos_per_prompt)
prompt_attention_mask = prompt_attention_mask.view(batch_size * num_videos_per_prompt, seq_len)
return prompt_embeds, prompt_attention_mask
# Copied from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video.HunyuanVideoPipeline._get_clip_prompt_embeds
def _get_clip_prompt_embeds(
tokenizer_2,
text_encoder_2,
prompt: Union[str, List[str]],
num_videos_per_prompt: int = 1,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
max_sequence_length: int = 77,
) -> torch.Tensor:
device = device
dtype = dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
text_inputs = tokenizer_2(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
_ = tokenizer_2.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
prompt_embeds = text_encoder_2(text_input_ids.to(device), output_hidden_states=False).pooler_output
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt)
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, -1)
return prompt_embeds
# Copied from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video.HunyuanVideoPipeline.encode_prompt
def encode_prompt(
    tokenizer,
    text_encoder,
    tokenizer_2,
    text_encoder_2,
    prompt: Union[str, List[str]],
    prompt_2: Union[str, List[str]] = None,
    prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE,
    num_videos_per_prompt: int = 1,
    prompt_embeds: Optional[torch.Tensor] = None,
    pooled_prompt_embeds: Optional[torch.Tensor] = None,
    prompt_attention_mask: Optional[torch.Tensor] = None,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    max_sequence_length: int = 256,
):
    """Compute sequence and pooled text embeddings for ``prompt``.

    Either embedding can be supplied by the caller, in which case it is
    returned unchanged and the corresponding encoder is not invoked.

    Args:
        tokenizer, text_encoder: Passed to ``_get_llama_prompt_embeds``.
        tokenizer_2, text_encoder_2: Passed to ``_get_clip_prompt_embeds``.
        prompt: Prompt(s) for the first encoder.
        prompt_2: Prompt(s) for the second encoder; defaults to ``prompt``.
        prompt_template: Template mapping forwarded to the first encoder
            (not mutated here).
        num_videos_per_prompt: Duplication factor applied by both encoders.
        prompt_embeds: Precomputed sequence embeddings, if available.
        pooled_prompt_embeds: Precomputed pooled embeddings, if available.
        prompt_attention_mask: Mask matching a precomputed ``prompt_embeds``.
        device: Device for tokenized inputs.
        dtype: Dtype forwarded to both encoders.
        max_sequence_length: Token budget for the first encoder; the second
            encoder is always called with a limit of 77.

    Returns:
        ``(prompt_embeds, pooled_prompt_embeds, prompt_attention_mask)``.
    """
    if prompt_embeds is None:
        prompt_embeds, prompt_attention_mask = _get_llama_prompt_embeds(
            tokenizer,
            text_encoder,
            prompt,
            prompt_template,
            num_videos_per_prompt,
            device=device,
            dtype=dtype,
            max_sequence_length=max_sequence_length,
        )

    if pooled_prompt_embeds is None:
        if prompt_2 is None:
            prompt_2 = prompt
        # Fix: encode `prompt_2` here. Previously `prompt` was passed, so an
        # explicitly supplied `prompt_2` was silently ignored.
        pooled_prompt_embeds = _get_clip_prompt_embeds(
            tokenizer_2,
            text_encoder_2,
            prompt_2,
            num_videos_per_prompt,
            device=device,
            dtype=dtype,
            max_sequence_length=77,
        )

    return prompt_embeds, pooled_prompt_embeds, prompt_attention_mask
def encode_image(
    feature_extractor,
    image_encoder,
    image: torch.Tensor,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
):
    """Run ``image`` through the feature extractor and image encoder.

    Args:
        feature_extractor: Callable preprocessor; its result must support
            ``.to(device=..., dtype=...)`` and ``**`` unpacking.
        image_encoder: Callable encoder with a ``dtype`` attribute; its output
            must expose ``last_hidden_state``.
        image: Image tensor with values in [-1, 1].
        device: Device the preprocessed inputs are moved to.
        dtype: Dtype of the returned embeddings.

    Returns:
        The encoder's ``last_hidden_state`` cast to ``dtype``.
    """
    # [-1, 1] -> [0, 1]; rescaling inside the extractor is disabled below.
    unit_range_image = (image + 1) / 2.0

    extracted = feature_extractor(images=unit_range_image, return_tensors="pt", do_rescale=False)
    encoder_inputs = extracted.to(device=device, dtype=image_encoder.dtype)

    hidden_states = image_encoder(**encoder_inputs).last_hidden_state
    return hidden_states.to(dtype=dtype)
def _t2v_vanilla_history(video_lat, target_start_f, history_len):
    """Return the ``history_len`` latent frames preceding ``target_start_f``, zero-padded on the left."""
    history = torch.zeros(
        size=(
            1,
            video_lat.shape[1],  # follow the latent's channel count instead of assuming 16
            history_len,
            video_lat.shape[-2],
            video_lat.shape[-1],
        ),
        device=video_lat.device,
        dtype=video_lat.dtype,
    )
    available = min(target_start_f, history_len)
    if available > 0:
        history[:, :, history_len - available :] = video_lat[
            :, :, target_start_f - available : target_start_f, :, :
        ]
    return history


def get_framepack_input_t2v(
    vae,
    pixel_values,  # [-1, 1], (B, C, F, H, W)
    latent_window_size: int = 9,
    vanilla_sampling: bool = False,
    dtype: Optional[torch.dtype] = None,
    is_keep_x0=False,
):
    """Build per-section FramePack training inputs for text-to-video.

    The clip is trimmed to a whole number of ``latent_window_size`` latent
    sections, VAE-encoded, and split into sections. Each section gets the
    preceding latents as clean history at three scales (16 "4x" frames,
    2 "2x" frames and 2 "1x" frames), zero-padded where no history exists.

    Args:
        vae: Object providing ``device``, ``dtype``, ``encode`` and
            ``config.scaling_factor``.
        pixel_values: Video tensor ``(B, C, F, H, W)`` in [-1, 1].
        latent_window_size: Latent frames per section.
        vanilla_sampling: Must be ``True``; the only layout implemented here.
        dtype: Dtype the encoded latents are cast to.
        is_keep_x0: When true, sections after the first use the first latent
            frame plus the single most recent frame as the "1x" clean latents
            (and 19 instead of 20 history frames).

    Returns:
        Tuple ``(target_latents, target_latent_indices, clean_latents,
        clean_latent_indices, clean_latents_2x, clean_latent_2x_indices,
        clean_latents_4x, clean_latent_4x_indices, section_to_video_idx)``.
        Tensors are concatenated over all sections of all videos along dim 0;
        ``section_to_video_idx`` maps each row back to its batch index.

    Raises:
        ValueError: If the clip has fewer latent frames than one section.
        NotImplementedError: If ``vanilla_sampling`` is false.
    """
    num_frames = pixel_values.shape[2]
    # calculate latent frame count from original frame count (4n+1)
    latent_f = (num_frames - 1) // 4 + 1
    total_latent_sections = latent_f // latent_window_size
    if total_latent_sections < 1:
        # One full section of latents is the minimum for this (t2v) layout.
        min_frames_needed = (latent_window_size - 1) * 4 + 1
        raise ValueError(
            f"Not enough frames for FramePack: {num_frames} frames ({latent_f} latent frames), "
            f"minimum required: {min_frames_needed} frames ({latent_window_size} latent frames)"
        )
    if not vanilla_sampling:
        # Fail before the VAE pass; previously this path produced empty lists
        # and crashed inside torch.cat. NotImplementedError is a RuntimeError
        # subclass, so callers catching the old error type are unaffected.
        raise NotImplementedError(
            "get_framepack_input_t2v only supports vanilla_sampling=True"
        )

    # align to section boundaries and trim surplus frames
    latent_f_aligned = total_latent_sections * latent_window_size
    frame_count_aligned = (latent_f_aligned - 1) * 4 + 1
    if frame_count_aligned != num_frames:
        print(
            f"Frame count mismatch: required={frame_count_aligned} != actual={num_frames}, trimming to {frame_count_aligned}"
        )
        pixel_values = pixel_values[:, :, :frame_count_aligned, :, :]

    # VAE encode
    pixel_values = pixel_values.to(device=vae.device, dtype=vae.dtype)
    latents = vae.encode(pixel_values).latent_dist.sample()
    latents = latents * vae.config.scaling_factor
    latents = latents.to(dtype=dtype)

    all_target_latents = []
    all_target_latent_indices = []
    all_clean_latents = []
    all_clean_latent_indices = []
    all_clean_latents_2x = []
    all_clean_latent_2x_indices = []
    all_clean_latents_4x = []
    all_clean_latent_4x_indices = []
    section_to_video_idx = []

    for b in range(latents.shape[0]):
        video_lat = latents[b : b + 1]  # keep batch dim: (1, C, F_aligned, H, W)
        start_latent = video_lat[:, :, 0:1, :, :]

        for section_index in range(total_latent_sections):
            target_start_f = section_index * latent_window_size
            target_latents = video_lat[:, :, target_start_f : target_start_f + latent_window_size, :, :]

            if is_keep_x0 and section_index > 0:
                # Anchor on the first frame: [x0 | 4x | 2x | 1x | targets]
                history_latents = _t2v_vanilla_history(video_lat, target_start_f, 1 + 2 + 16)
                sizes = [1, 16, 2, 1, latent_window_size]
                (
                    clean_latent_indices_start,
                    clean_latent_4x_indices,
                    clean_latent_2x_indices,
                    clean_latent_1x_indices,
                    latent_indices,
                ) = torch.arange(0, sum(sizes)).unsqueeze(0).split(sizes, dim=1)
                clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices], dim=1)
                clean_latents_4x, clean_latents_2x, clean_latents_1x = history_latents.split([16, 2, 1], dim=2)
                clean_latents = torch.cat([start_latent, clean_latents_1x], dim=2)
            else:
                # Plain layout: [4x | 2x | 1x | targets]
                history_latents = _t2v_vanilla_history(video_lat, target_start_f, 2 + 2 + 16)
                sizes = [16, 2, 2, latent_window_size]
                (
                    clean_latent_4x_indices,
                    clean_latent_2x_indices,
                    clean_latent_indices,
                    latent_indices,
                ) = torch.arange(0, sum(sizes)).unsqueeze(0).split(sizes, dim=1)
                clean_latents_4x, clean_latents_2x, clean_latents = history_latents.split([16, 2, 2], dim=2)

            all_target_latents.append(target_latents)
            all_target_latent_indices.append(latent_indices)
            all_clean_latents.append(clean_latents)
            all_clean_latent_indices.append(clean_latent_indices)
            all_clean_latents_2x.append(clean_latents_2x)
            all_clean_latent_2x_indices.append(clean_latent_2x_indices)
            all_clean_latents_4x.append(clean_latents_4x)
            all_clean_latent_4x_indices.append(clean_latent_4x_indices)
            section_to_video_idx.append(b)

    # Stack all sections into batches
    return (
        torch.cat(all_target_latents, dim=0),
        torch.cat(all_target_latent_indices, dim=0),
        torch.cat(all_clean_latents, dim=0),
        torch.cat(all_clean_latent_indices, dim=0),
        torch.cat(all_clean_latents_2x, dim=0),
        torch.cat(all_clean_latent_2x_indices, dim=0),
        torch.cat(all_clean_latents_4x, dim=0),
        torch.cat(all_clean_latent_4x_indices, dim=0),
        section_to_video_idx,
    )
def _i2v_vanilla_history(video_lat, target_start_f, history_len):
    """Return the ``history_len`` latent frames preceding ``target_start_f``, zero-padded on the left."""
    history = torch.zeros(
        size=(
            1,
            video_lat.shape[1],  # follow the latent's channel count instead of assuming 16
            history_len,
            video_lat.shape[-2],
            video_lat.shape[-1],
        ),
        device=video_lat.device,
        dtype=video_lat.dtype,
    )
    available = min(target_start_f, history_len)
    if available > 0:
        history[:, :, history_len - available :] = video_lat[
            :, :, target_start_f - available : target_start_f, :, :
        ]
    return history


def get_framepack_input_i2v(
    vae,
    pixel_values,  # [-1, 1], (B, C, F, H, W)
    latent_window_size: int = 9,
    vanilla_sampling: bool = False,
    dtype: Optional[torch.dtype] = None,
):
    """Build per-section FramePack training inputs for image-to-video.

    The first latent frame is the conditioning image; the remaining latent
    frames are trimmed to a whole number of ``latent_window_size`` sections.

    * ``vanilla_sampling=True``: sections are emitted in temporal order. Each
      one sees the start frame plus the preceding 19 latents (16 "4x", 2 "2x",
      1 "1x"), zero-padded where no history exists.
    * ``vanilla_sampling=False``: sections are emitted last-to-first, with
      padding indices between the start frame and the target window. The
      history holds the ground-truth latents of the sections already emitted
      (i.e. later in time), starting from zeros.

    Args:
        vae: Object providing ``device``, ``dtype``, ``encode`` and
            ``config.scaling_factor``.
        pixel_values: Video tensor ``(B, C, F, H, W)`` in [-1, 1].
        latent_window_size: Latent frames per section.
        vanilla_sampling: Selects the layout described above.
        dtype: Dtype the encoded latents are cast to.

    Returns:
        Tuple ``(target_latents, target_latent_indices, clean_latents,
        clean_latent_indices, clean_latents_2x, clean_latent_2x_indices,
        clean_latents_4x, clean_latent_4x_indices, section_to_video_idx)``.
        Tensors are concatenated over all sections of all videos along dim 0;
        ``section_to_video_idx`` maps each row back to its batch index.

    Raises:
        ValueError: If the clip is shorter than the start frame plus one section.
    """
    num_frames = pixel_values.shape[2]
    # calculate latent frame count from original frame count (4n+1)
    latent_f = (num_frames - 1) // 4 + 1
    # number of sections, excluding the first (conditioning) latent frame
    total_latent_sections = (latent_f - 1) // latent_window_size
    if total_latent_sections < 1:
        min_frames_needed = latent_window_size * 4 + 1
        raise ValueError(
            f"Not enough frames for FramePack: {pixel_values.shape[2]} frames ({latent_f} latent frames), minimum required: {min_frames_needed} frames ({latent_window_size + 1} latent frames)"
        )

    # align to section boundaries (+1 for the start frame) and trim surplus frames
    latent_f_aligned = total_latent_sections * latent_window_size + 1
    frame_count_aligned = (latent_f_aligned - 1) * 4 + 1
    if frame_count_aligned != num_frames:
        print(
            f"Frame count mismatch: required={frame_count_aligned} != actual={num_frames}, trimming to {frame_count_aligned}"
        )
        pixel_values = pixel_values[:, :, :frame_count_aligned, :, :]

    # VAE encode
    pixel_values = pixel_values.to(device=vae.device, dtype=vae.dtype)
    latents = vae.encode(pixel_values).latent_dist.sample()
    latents = latents * vae.config.scaling_factor
    latents = latents.to(dtype=dtype)

    all_target_latents = []
    all_target_latent_indices = []
    all_clean_latents = []
    all_clean_latent_indices = []
    all_clean_latents_2x = []
    all_clean_latent_2x_indices = []
    all_clean_latents_4x = []
    all_clean_latent_4x_indices = []
    section_to_video_idx = []

    def _collect(b, target, target_idx, clean, clean_idx, clean_2x, clean_2x_idx, clean_4x, clean_4x_idx):
        # Record one section's tensors in the accumulator lists.
        all_target_latents.append(target)
        all_target_latent_indices.append(target_idx)
        all_clean_latents.append(clean)
        all_clean_latent_indices.append(clean_idx)
        all_clean_latents_2x.append(clean_2x)
        all_clean_latent_2x_indices.append(clean_2x_idx)
        all_clean_latents_4x.append(clean_4x)
        all_clean_latent_4x_indices.append(clean_4x_idx)
        section_to_video_idx.append(b)

    if vanilla_sampling:
        # Layout: [start | 4x | 2x | 1x | targets]
        sizes = [1, 16, 2, 1, latent_window_size]
        for b in range(latents.shape[0]):
            video_lat = latents[b : b + 1]  # keep batch dim: (1, C, F_aligned, H, W)
            start_latent = video_lat[:, :, 0:1, :, :]

            for section_index in range(total_latent_sections):
                target_start_f = section_index * latent_window_size + 1
                target_latents = video_lat[:, :, target_start_f : target_start_f + latent_window_size, :, :]

                history_latents = _i2v_vanilla_history(video_lat, target_start_f, 1 + 2 + 16)

                (
                    clean_latent_indices_start,
                    clean_latent_4x_indices,
                    clean_latent_2x_indices,
                    clean_latent_1x_indices,
                    latent_indices,
                ) = torch.arange(0, sum(sizes)).unsqueeze(0).split(sizes, dim=1)
                clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices], dim=1)

                clean_latents_4x, clean_latents_2x, clean_latents_1x = history_latents.split([16, 2, 1], dim=2)
                clean_latents = torch.cat([start_latent, clean_latents_1x], dim=2)

                _collect(
                    b,
                    target_latents,
                    latent_indices,
                    clean_latents,
                    clean_latent_indices,
                    clean_latents_2x,
                    clean_latent_2x_indices,
                    clean_latents_4x,
                    clean_latent_4x_indices,
                )
    else:
        # padding is reversed for inference (future to past)
        latent_paddings = list(reversed(range(total_latent_sections)))
        # Note: The padding trick for inference. See the paper for details.
        if total_latent_sections > 4:
            latent_paddings = [3] + [2] * (total_latent_sections - 3) + [1, 0]

        for b in range(latents.shape[0]):
            video_lat = latents[b : b + 1]  # keep batch dim: (1, C, F, H, W)
            start_latent = video_lat[:, :, 0:1, :, :]

            # Placeholder history (zeros) until later-in-time sections are emitted.
            history_latents = torch.zeros(
                (
                    1,
                    video_lat.shape[1],
                    1 + 2 + 16,
                    video_lat.shape[3],
                    video_lat.shape[4],
                ),
                dtype=video_lat.dtype,
                device=video_lat.device,
            )

            # Walk sections from last to first; one padding value per section.
            section_order = range(total_latent_sections - 1, -1, -1)
            for section_index, latent_padding in zip(section_order, latent_paddings):
                # Section k covers latent frames [k * window + 1, (k + 1) * window],
                # so the temporally first section always starts at frame 1.
                latent_f_index = section_index * latent_window_size + 1
                is_first_in_time = section_index == 0
                latent_padding_size = latent_padding * latent_window_size

                # Layout: [start | padding | targets | 1x | 2x | 4x]
                sizes = [1, latent_padding_size, latent_window_size, 1, 2, 16]
                (
                    clean_latent_indices_pre,
                    _blank_indices,
                    latent_indices,
                    clean_latent_indices_post,
                    clean_latent_2x_indices,
                    clean_latent_4x_indices,
                ) = torch.arange(0, sum(sizes)).unsqueeze(0).split(sizes, dim=1)
                clean_latent_indices = torch.cat([clean_latent_indices_pre, clean_latent_indices_post], dim=1)

                clean_latents_post, clean_latents_2x, clean_latents_4x = history_latents[
                    :, :, : 1 + 2 + 16, :, :
                ].split([1, 2, 16], dim=2)
                clean_latents = torch.cat([start_latent, clean_latents_post], dim=2)

                # Ground-truth latents for this section
                target_latents = video_lat[:, :, latent_f_index : latent_f_index + latent_window_size, :, :]

                _collect(
                    b,
                    target_latents,
                    latent_indices,
                    clean_latents,
                    clean_latent_indices,
                    clean_latents_2x,
                    clean_latent_2x_indices,
                    clean_latents_4x,
                    clean_latent_4x_indices,
                )

                if is_first_in_time:
                    # Prepend the start frame together with the first section.
                    generated_latents_for_history = video_lat[:, :, : latent_window_size + 1, :, :]
                else:
                    # True latents stand in for what inference would have generated.
                    generated_latents_for_history = target_latents
                history_latents = torch.cat([generated_latents_for_history, history_latents], dim=2)

    # Stack all sections into batches
    return (
        torch.cat(all_target_latents, dim=0),
        torch.cat(all_target_latent_indices, dim=0),
        torch.cat(all_clean_latents, dim=0),
        torch.cat(all_clean_latent_indices, dim=0),
        torch.cat(all_clean_latents_2x, dim=0),
        torch.cat(all_clean_latent_2x_indices, dim=0),
        torch.cat(all_clean_latents_4x, dim=0),
        torch.cat(all_clean_latent_4x_indices, dim=0),
        section_to_video_idx,
    )
def _build_clean_latent_pyramid(latents, pyramid_stage_num):
    """Return the clean latents of every pyramid stage, ordered from low to high resolution.

    Each extra stage halves height and width (floor division) with a per-frame
    bilinear interpolation of the previous level.
    """
    pyramid_latent_list = [latents]
    num_frames, height, width = latents.shape[-3], latents.shape[-2], latents.shape[-1]
    for _ in range(pyramid_stage_num - 1):
        height //= 2
        width //= 2
        latents = rearrange(latents, "b c t h w -> (b t) c h w")
        latents = F.interpolate(latents, size=(height, width), mode="bilinear")
        latents = rearrange(latents, "(b t) c h w -> b c t h w", t=num_frames)
        pyramid_latent_list.append(latents)
    return list(reversed(pyramid_latent_list))


def _build_noise_pyramid(noise, pyramid_stage_num):
    """Return the noise of every pyramid stage, ordered from low to high resolution.

    Every lower level is the bilinear downsampling of the level above, multiplied by 2
    (presumably to compensate the variance lost by averaging neighbouring pixels).
    """
    latent_frame_num = noise.shape[2]
    height, width = noise.shape[-2], noise.shape[-1]
    noise_list = [noise]
    cur_noise = noise
    for _ in range(pyramid_stage_num - 1):
        height //= 2
        width //= 2
        cur_noise = rearrange(cur_noise, "b c t h w -> (b t) c h w")
        cur_noise = F.interpolate(cur_noise, size=(height, width), mode="bilinear") * 2
        cur_noise = rearrange(cur_noise, "(b t) c h w -> b c t h w", t=latent_frame_num)
        noise_list.append(cur_noise)
    return list(reversed(noise_list))  # make sure from low res to high res


def _pyramid_stage_points(
    scheduler,
    pyramid_latent_list,
    noise_list,
    i_s,
    pyramid_stage_num,
    latent_frame_num,
    batch_slice=slice(None),
):
    """Return the flow-matching ``(start_point, end_point)`` pair of pyramid stage ``i_s``.

    ``batch_slice`` selects the samples (along dim 0) that take part in this stage.
    """
    stage_noise = noise_list[i_s][batch_slice]
    clean_latent = pyramid_latent_list[i_s][batch_slice]  # [bs, c, t, h, w]
    start_sigma = scheduler.start_sigmas[i_s]
    end_sigma = scheduler.end_sigmas[i_s]

    if i_s == 0:
        # The lowest-resolution stage starts from pure noise.
        start_point = stage_noise
    else:
        # Later stages start from the 2x nearest-neighbour upsampling of the previous
        # stage's clean latent, blended with this stage's noise.
        last_clean_latent = pyramid_latent_list[i_s - 1][batch_slice]
        last_clean_latent = rearrange(last_clean_latent, "b c t h w -> (b t) c h w")
        last_clean_latent = F.interpolate(
            last_clean_latent,
            size=(
                last_clean_latent.shape[-2] * 2,
                last_clean_latent.shape[-1] * 2,
            ),
            mode="nearest",
        )
        last_clean_latent = rearrange(last_clean_latent, "(b t) c h w -> b c t h w", t=latent_frame_num)
        start_point = start_sigma * stage_noise + (1 - start_sigma) * last_clean_latent

    if i_s == pyramid_stage_num - 1:
        # The highest-resolution stage ends at the clean latent itself.
        end_point = clean_latent
    else:
        end_point = end_sigma * stage_noise + (1 - end_sigma) * clean_latent
    return start_point, end_point


def _sample_stage_timesteps(args, scheduler, i_s, num_samples, device):
    """Draw ``num_samples`` random training timesteps of stage ``i_s`` and their sigmas."""
    training_steps = scheduler.config.num_train_timesteps
    # Sample a random timestep for each image
    # for weighting schemes where we sample timesteps non-uniformly
    u = compute_density_for_timestep_sampling(
        weighting_scheme=get_config_value(args, 'weighting_scheme'),
        batch_size=num_samples,
        logit_mean=get_config_value(args, 'logit_mean'),
        logit_std=get_config_value(args, 'logit_std'),
        mode_scale=get_config_value(args, 'mode_scale'),
    )
    indices = (u * training_steps).long()  # index into the per-stage training timestep table
    indices = indices.clamp(0, training_steps - 1)
    timesteps = scheduler.timesteps_per_stage[i_s][indices].to(device=device)
    sigmas = scheduler.sigmas_per_stage[i_s][indices].to(device=device)
    return timesteps, sigmas


def get_pyramid_input(
    args,
    scheduler,
    latents,  # [b c t h w]
    pyramid_stage_num=3,
    pyramid_sample_ratios=[1, 2, 1],
    pyramid_sample_mode="efficient",  # ["efficient", "full", "diffusion_forcing", "stream_sample"]
    pyramid_stream_inference_steps=[10, 10, 10],
    stream_chunk_size=5,
):
    """Build noisy pyramid inputs and flow-matching targets for one training step.

    The clean ``latents`` are downsampled into ``pyramid_stage_num`` resolutions, a matching
    noise pyramid is drawn, and for every stage a start point and an end point are derived
    from the scheduler's ``start_sigmas`` / ``end_sigmas``. Noisy latents are the
    interpolation ``sigma * start_point + (1 - sigma) * end_point`` and the regression
    target is ``start_point - end_point``.

    Args:
        args: Configuration object read through ``get_config_value`` for
            ``weighting_scheme``, ``logit_mean``, ``logit_std`` and ``mode_scale``.
        scheduler: Pyramid flow-matching scheduler; must expose ``start_sigmas``,
            ``end_sigmas``, ``timesteps_per_stage``, ``sigmas_per_stage`` and
            ``config.num_train_timesteps``.
        latents: Clean latents laid out as ``[b, c, t, h, w]``.
        pyramid_stage_num: Number of resolution stages.
        pyramid_sample_ratios: Per-stage sample counts; its length must equal
            ``pyramid_stage_num``. The default list is never mutated.
        pyramid_sample_mode: One of ``"efficient"``, ``"full"``, ``"diffusion_forcing"``
            or ``"stream_sample"``.
        pyramid_stream_inference_steps: Per-stage inference steps, only used by
            ``"stream_sample"``. The default list is never mutated.
        stream_chunk_size: Number of consecutive frames sharing one noise level in the
            ``"diffusion_forcing"`` and ``"stream_sample"`` modes.

    Returns:
        ``(noisy_latents_list, sigmas_list, timesteps_list, targets_list)``:

        * ``"efficient"``: one entry per batch column; every ``noisy_latents_list`` entry
          is a single-element list holding the noisy latent of that column.
        * ``"full"`` / ``"diffusion_forcing"``: one tensor entry per (stage, repetition).
        * ``"stream_sample"``: single-entry lists; ``noisy_latents_list[0]`` and
          ``targets_list[0]`` are lists with one frame chunk per stage in use.

    Raises:
        ValueError: If ``pyramid_sample_mode`` is not one of the supported modes.
        AssertionError: If the ratios do not match ``pyramid_stage_num``, the batch is not
            divisible by ``sum(pyramid_sample_ratios)`` (``"efficient"``), or the frame
            count is not divisible by ``stream_chunk_size`` (``"diffusion_forcing"``).
    """
    assert pyramid_stage_num == len(pyramid_sample_ratios)
    if pyramid_sample_mode not in ["efficient", "full", "diffusion_forcing", "stream_sample"]:
        raise ValueError(
            f"Invalid pyramid_sample_mode: {pyramid_sample_mode}. Must be one of ['efficient', 'full', 'diffusion_forcing', 'stream_sample']."
        )

    # Get clean pyramid latent list (low res -> high res)
    pyramid_latent_list = _build_clean_latent_pyramid(latents, pyramid_stage_num)

    # Get pyramid noise list (low res -> high res)
    noise = torch.randn_like(pyramid_latent_list[-1])
    device = noise.device
    dtype = pyramid_latent_list[-1].dtype
    latent_frame_num = noise.shape[2]
    input_video_num = noise.shape[0]
    noise_list = _build_noise_pyramid(noise, pyramid_stage_num)

    # Outputs, filled from low resolution to high resolution
    noisy_latents_list = []
    sigmas_list = []
    targets_list = []
    timesteps_list = []

    # Get pyramid target list
    if pyramid_sample_mode == "efficient":
        assert input_video_num % (int(sum(pyramid_sample_ratios))) == 0
        # To calculate the padding batchsize and column size
        bsz = input_video_num // int(sum(pyramid_sample_ratios))
        column_size = int(sum(pyramid_sample_ratios))
        # Every batch column is assigned to exactly one stage according to the ratios
        column_to_stage = {}
        i_sum = 0
        for i_s, column_num in enumerate(pyramid_sample_ratios):
            for index in range(i_sum, i_sum + column_num):
                column_to_stage[index] = i_s
            i_sum += column_num

        for index in range(column_size):
            i_s = column_to_stage[index]
            # Samples index, index + column_size, ... belong to this column
            start_point, end_point = _pyramid_stage_points(
                scheduler,
                pyramid_latent_list,
                noise_list,
                i_s,
                pyramid_stage_num,
                latent_frame_num,
                batch_slice=slice(index, None, column_size),
            )
            timesteps, sigmas = _sample_stage_timesteps(args, scheduler, i_s, bsz, device)

            # Add noise according to flow matching.
            # zt = (1 - texp) * x + texp * z1
            while len(sigmas.shape) < start_point.ndim:
                sigmas = sigmas.unsqueeze(-1)
            noisy_latents = sigmas * start_point + (1 - sigmas) * end_point

            # [stage1_latent, stage2_latent, ..., stagen_latent], which will be concat after patching
            noisy_latents_list.append([noisy_latents.to(dtype)])
            sigmas_list.append(sigmas.to(dtype))
            timesteps_list.append(timesteps.to(dtype))
            targets_list.append(start_point - end_point)  # The standard rectified flow matching objective

    elif pyramid_sample_mode == "full":
        # To calculate the batchsize
        bsz = input_video_num

        for i_s, cur_sample_ratio in zip(range(pyramid_stage_num), pyramid_sample_ratios):
            start_point, end_point = _pyramid_stage_points(
                scheduler, pyramid_latent_list, noise_list, i_s, pyramid_stage_num, latent_frame_num
            )
            # The whole batch is noised cur_sample_ratio times at this stage
            for _ in range(cur_sample_ratio):
                timesteps, sigmas = _sample_stage_timesteps(args, scheduler, i_s, bsz, device)

                # Add noise according to flow matching.
                # zt = (1 - texp) * x + texp * z1
                while len(sigmas.shape) < start_point.ndim:
                    sigmas = sigmas.unsqueeze(-1)
                noisy_latents = sigmas * start_point + (1 - sigmas) * end_point

                # [stage1_latent, stage2_latent, ..., stagen_latent]
                noisy_latents_list.append(noisy_latents.to(dtype))
                sigmas_list.append(sigmas.to(dtype))
                timesteps_list.append(timesteps.to(dtype))
                targets_list.append(start_point - end_point)  # The standard rectified flow matching objective

    elif pyramid_sample_mode == "diffusion_forcing":
        # To calculate the batchsize
        bsz = input_video_num
        latent_chunk_num = latent_frame_num // stream_chunk_size
        assert latent_frame_num % stream_chunk_size == 0

        # Frame -> chunk lookup; it does not depend on the stage, so it is built once
        chunk_index = (
            torch.arange(latent_frame_num, device=device).unsqueeze(0).expand(bsz, -1) // stream_chunk_size
        )
        chunk_index = chunk_index.clamp(max=latent_chunk_num - 1)

        for i_s, cur_sample_ratio in zip(range(pyramid_stage_num), pyramid_sample_ratios):
            start_point, end_point = _pyramid_stage_points(
                scheduler, pyramid_latent_list, noise_list, i_s, pyramid_stage_num, latent_frame_num
            )
            for _ in range(cur_sample_ratio):
                # One independent noise level per (sample, chunk)
                timesteps, sigmas = _sample_stage_timesteps(args, scheduler, i_s, bsz * latent_chunk_num, device)
                timesteps = timesteps.view(bsz, latent_chunk_num)  # [bsz, latent_chunk_num]
                sigmas = sigmas.view(bsz, latent_chunk_num)  # [bsz, latent_chunk_num]

                # Broadcast every chunk's value to all of its frames
                sigmas = torch.gather(sigmas, 1, chunk_index)  # [bsz, t]
                timesteps = torch.gather(timesteps, 1, chunk_index)

                # Add noise according to flow matching.
                # zt = (1 - texp) * x + texp * z1
                sigmas = (
                    sigmas.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
                )  # reshape to [bsz, 1, t, 1, 1] for broadcasting
                noisy_latents = sigmas * start_point + (1 - sigmas) * end_point

                # [stage1_latent, stage2_latent, ..., stagen_latent]
                noisy_latents_list.append(noisy_latents.to(dtype))  # e.g. torch.Size([2, 16, 10, 12, 20])
                sigmas_list.append(sigmas.to(dtype))  # e.g. torch.Size([2, 1, 10, 1, 1])
                timesteps_list.append(timesteps.to(dtype))  # e.g. torch.Size([2, 10])
                targets_list.append(start_point - end_point)  # The standard rectified flow matching objective

    elif pyramid_sample_mode == "stream_sample":
        # training_all_progressive_timesteps
        # skip 0. (1, max_inference_steps):[1.3850, 44.1200, 86.8550, 129.5900, 172.3250,
        # 215.0600, 257.7950, 300.5300, 343.2650, 386.0000,
        # 386.3580, 426.0960, 465.8340, 505.5720, 545.3100,
        # 585.0480, 624.7860, 664.5240, 704.2620, 744.0000,
        # 744.2560, 772.6720, 801.0880, 829.5040, 857.9200,
        # 886.3360, 914.7520, 943.1680, 971.5840, 1000.0000]
        # progressive_timesteps_stages
        # stream_chunk_size=3:
        # [ 386., 386., 386., 744., 744., 744., 1000., 1000., 1000.] high, mid, low
        # [343.2650, 343.2650, 343.2650, 704.2620, 704.2620, 704.2620, 971.5840, 971.5840, 971.5840] high, mid, low
        # [300.5300, 300.5300, 300.5300, 664.5240, 664.5240, 664.5240, 943.1680, 943.1680, 943.1680] high, mid, low
        # [257.7950, 257.7950, 257.7950, 624.7860, 624.7860, 624.7860, 914.7520, 914.7520, 914.7520] high, mid, low
        # [215.0600, 215.0600, 215.0600, 585.0480, 585.0480, 585.0480, 886.3360, 886.3360, 886.3360] high, mid, low
        # [172.3250, 172.3250, 172.3250, 545.3100, 545.3100, 545.3100, 857.9200, 857.9200, 857.9200] high, mid, low
        # [129.5900, 129.5900, 129.5900, 505.5720, 505.5720, 505.5720, 829.5040, 829.5040, 829.5040] high, mid, low
        # [ 86.8550, 86.8550, 86.8550, 465.8340, 465.8340, 465.8340, 801.0880, 801.0880, 801.0880] high, mid, low
        # [ 44.1200, 44.1200, 44.1200, 426.0960, 426.0960, 426.0960, 772.6720, 772.6720, 772.6720] high, mid, low
        # [ 1.3850, 1.3850, 1.3850, 386.3580, 386.3580, 386.3580, 744.2560, 744.2560, 744.2560] high, mid, low
        # stream_chunk_size=5, shape = (training_num_steps_to_be_saved, latent_frame_num):
        # [545.3100, 545.3100, 545.3100, 545.3100, 545.3100, 1000.0000, 1000.0000, 1000.0000, 1000.0000, 1000.0000] mid, low
        # [505.5720, 505.5720, 505.5720, 505.5720, 505.5720, 971.5840, 971.5840, 971.5840, 971.5840, 971.5840] mid, low
        # [465.8340, 465.8340, 465.8340, 465.8340, 465.8340, 943.1680, 943.1680, 943.1680, 943.1680, 943.1680] mid, low
        # [426.0960, 426.0960, 426.0960, 426.0960, 426.0960, 914.7520, 914.7520, 914.7520, 914.7520, 914.7520] mid, low
        # [386.3580, 386.3580, 386.3580, 386.3580, 386.3580, 886.3360, 886.3360, 886.3360, 886.3360, 886.3360] mid, low
        # [386.0000, 386.0000, 386.0000, 386.0000, 386.0000, 857.9200, 857.9200, 857.9200, 857.9200, 857.9200] high, low
        # [343.2650, 343.2650, 343.2650, 343.2650, 343.2650, 829.5040, 829.5040, 829.5040, 829.5040, 829.5040] high, low
        # [300.5300, 300.5300, 300.5300, 300.5300, 300.5300, 801.0880, 801.0880, 801.0880, 801.0880, 801.0880] high, low
        # [257.7950, 257.7950, 257.7950, 257.7950, 257.7950, 772.6720, 772.6720, 772.6720, 772.6720, 772.6720] high, low
        # [215.0600, 215.0600, 215.0600, 215.0600, 215.0600, 744.2560, 744.2560, 744.2560, 744.2560, 744.2560] high, low
        # [172.3250, 172.3250, 172.3250, 172.3250, 172.3250, 744.0000, 744.0000, 744.0000, 744.0000, 744.0000] high, mid
        # [129.5900, 129.5900, 129.5900, 129.5900, 129.5900, 704.2620, 704.2620, 704.2620, 704.2620, 704.2620] high, mid
        # [ 86.8550, 86.8550, 86.8550, 86.8550, 86.8550, 664.5240, 664.5240, 664.5240, 664.5240, 664.5240] high, mid
        # [ 44.1200, 44.1200, 44.1200, 44.1200, 44.1200, 624.7860, 624.7860, 624.7860, 624.7860, 624.7860] high, mid
        # [ 1.3850, 1.3850, 1.3850, 1.3850, 1.3850, 585.0480, 585.0480, 585.0480, 585.0480, 585.0480] high, mid

        # To calculate the batchsize
        bsz = input_video_num

        # Get multi stage timesteps for streamgen
        (
            training_num_steps_to_be_saved,
            training_all_timesteps_stage_ids,
            training_all_progressive_timesteps,
            progressive_timesteps_stages,
        ) = get_stream_sample(
            scheduler=scheduler,
            max_latent_frame_num=latent_frame_num,
            stream_chunk_size=stream_chunk_size,
            pyramid_stage_num=pyramid_stage_num,
            pyramid_stream_inference_steps=pyramid_stream_inference_steps,
        )
        timestep_to_stage = {
            float(t.item()): int(stage.item())
            for t, stage in zip(training_all_progressive_timesteps[0], training_all_timesteps_stage_ids[0])
        }

        # Draw (initialization, termination) until they are not both True
        while True:
            initialization = random.choice([True, False])
            termination = random.choice([True, False])
            if not (initialization and termination):  # Make sure not both are True
                break

        # Pick one of the saved denoising steps at random and share it across the batch
        stage_i = random.randint(0, training_num_steps_to_be_saved - 1)
        timesteps = progressive_timesteps_stages[stage_i].clone().repeat(bsz, 1)  # (b, f)
        # NOTE(review): every progressive_timesteps_stages entry is built for
        # max_latent_frame_num == latent_frame_num, so these slices appear to keep all columns.
        if initialization:  # get the ending timesteps, [999]x5 from [91, 192, ..., 999]x5
            timesteps = timesteps[:, -latent_frame_num:]
        elif termination:  # get the starting timesteps, [91]x5 from [91, ..., 999]x5
            timesteps = timesteps[:, :latent_frame_num]

        # For stage mapping / Get sigmas
        sigmas, stage_latent_mapping = get_sigmas_from_pyramid_timesteps(scheduler, timesteps, timestep_to_stage)

        # To device
        timesteps = timesteps.to(device)
        sigmas = sigmas.to(device)

        # Get pyramid stage points
        stage_point_list = [
            _pyramid_stage_points(scheduler, pyramid_latent_list, noise_list, i_s, pyramid_stage_num, latent_frame_num)
            for i_s in range(pyramid_stage_num)
        ]

        temp_noisy_latents_list = []  # one chunk per stage in use, e.g. torch.Size([2, 16, 5, 12, 20])
        temp_targets_list = []
        unique_elements = list(map(int, torch.unique(stage_latent_mapping)))
        for cur_stage in reversed(unique_elements):
            # First frame whose timestep belongs to cur_stage; its chunk spans stream_chunk_size frames
            stage_indices = torch.nonzero(stage_latent_mapping == cur_stage, as_tuple=True)
            start_index = stage_indices[1][0].item()
            end_index = start_index + stream_chunk_size

            start_point, end_point = stage_point_list[cur_stage]
            start_point_slice = start_point[:, :, start_index:end_index, :, :]
            end_point_slice = end_point[:, :, start_index:end_index, :, :]
            sigmas_slice = sigmas[:, :, start_index:end_index, :, :]

            noisy_latents = sigmas_slice * start_point_slice + (1 - sigmas_slice) * end_point_slice
            target = start_point_slice - end_point_slice
            temp_noisy_latents_list.append(noisy_latents.to(dtype))
            temp_targets_list.append(target)

        noisy_latents_list.append(temp_noisy_latents_list)
        targets_list.append(temp_targets_list)
        sigmas_list.append(sigmas.to(dtype))  # e.g. torch.Size([2, 1, 10, 1, 1])
        timesteps_list.append(timesteps.to(dtype=dtype))  # e.g. torch.Size([2, 10])

    return noisy_latents_list, sigmas_list, timesteps_list, targets_list
def get_sigmas_from_pyramid_timesteps(scheduler, timesteps, timestep_to_stage):
    """Look up, for every timestep, the pyramid stage it belongs to and its sigma.

    Args:
        scheduler: Object whose ``timesteps_per_stage`` and ``sigmas_per_stage`` can be
            indexed by a stage id and hold aligned timestep / sigma tensors.
        timesteps: 2-D tensor of timesteps (indexed as ``[row, column]``).
        timestep_to_stage: Mapping from a timestep (as a Python float) to its stage id.

    Returns:
        ``(sigmas, stage_latent_mapping)``: ``sigmas`` has the dtype of ``timesteps`` and
        shape ``[rows, 1, columns, 1, 1]``; ``stage_latent_mapping`` has the shape of
        ``timesteps`` and holds the stage id of every entry.

    NOTE(review): a timestep missing from ``timestep_to_stage`` is mapped to stage ``-1``,
    which is then used as a stage index as-is. Every timestep must also match exactly one
    entry of its stage's timestep table, otherwise ``.item()`` raises.
    """
    # Stage id of every timestep, laid out like ``timesteps``
    stage_ids = [timestep_to_stage.get(float(value.item()), -1) for value in timesteps.flatten()]
    stage_latent_mapping = torch.tensor(stage_ids, device=timesteps.device).view(timesteps.shape)

    # Sigma of every timestep; -1.0 is only a placeholder that each iteration overwrites
    sigmas = torch.full_like(timesteps, -1.0)
    num_rows, num_cols = timesteps.shape[0], timesteps.shape[1]
    for row in range(num_rows):
        for col in range(num_cols):
            stage = int(stage_latent_mapping[row, col])
            stage_timesteps = scheduler.timesteps_per_stage[stage]
            query = timesteps[row, col].clone().detach().to(stage_timesteps.dtype)
            # Position of this timestep inside the stage's timestep table
            matches = torch.isclose(stage_timesteps, query).nonzero(as_tuple=True)[0]
            position = matches.item()
            sigmas[row, col] = scheduler.sigmas_per_stage[stage][position]

    # Make the result broadcastable against [b, c, t, h, w] latents
    broadcastable_sigmas = sigmas.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
    return broadcastable_sigmas, stage_latent_mapping
def get_stream_sample(
    scheduler,
    max_latent_frame_num,
    stream_chunk_size,
    pyramid_stage_num=3,
    pyramid_stream_inference_steps=[10, 10, 10],
):
    """Build the progressive (per-chunk) timestep schedule used for stream sampling.

    The inference timesteps of all pyramid stages are concatenated and flipped into one
    sequence. Every group of ``stream_chunk_size`` frames shares a noise level, and the
    sequence is split into ``num_steps_to_be_saved`` interleaved schedules with one
    timestep per frame.

    Args:
        scheduler: Scheduler providing ``set_timesteps(num_steps, stage_idx)`` and a
            ``timesteps`` tensor afterwards. It is left configured for the last stage.
        max_latent_frame_num: Number of latent frames; must be a multiple of
            ``stream_chunk_size``.
        stream_chunk_size: Number of consecutive frames sharing one timestep.
        pyramid_stage_num: Number of pyramid stages to schedule.
        pyramid_stream_inference_steps: Inference steps per stage. Only the first
            ``pyramid_stage_num`` entries are used. The default list is never mutated.

    Returns:
        ``(num_steps_to_be_saved, all_timesteps_stage_ids, all_progressive_timesteps,
        progressive_timesteps_stages)``. The two ``all_*`` tensors have shape ``(1, T)``.
        ``progressive_timesteps_stages`` is a list of ``num_steps_to_be_saved`` tensors,
        each with one timestep per frame (assuming ``set_timesteps`` yields exactly the
        requested number of timesteps).

    Raises:
        AssertionError: If the frame count is not a multiple of ``stream_chunk_size`` or
            the total step count is not a multiple of the number of chunks.
    """
    # Count only the stages that are scheduled below, so the total matches the length of
    # the concatenated timestep sequence even if the steps list is longer than needed.
    max_inference_steps = sum(pyramid_stream_inference_steps[:pyramid_stage_num])

    # Set training all progressive timesteps and stage mapping
    all_progressive_timesteps_list = []
    timestep_stage_list = []
    for stage_idx in range(pyramid_stage_num):
        scheduler.set_timesteps(pyramid_stream_inference_steps[stage_idx], stage_idx)
        temp_timesteps = scheduler.timesteps  # shape: (n_i,)
        all_progressive_timesteps_list.append(temp_timesteps)
        timestep_stage_list.append(
            torch.full_like(temp_timesteps, fill_value=stage_idx)
        )  # same shape, filled with stage_idx
    # Flip so the sequence runs in the reverse of the order the scheduler produced it
    all_progressive_timesteps = torch.cat(all_progressive_timesteps_list).unsqueeze(0).flip(1)  # (1, T)
    all_timesteps_stage_ids = torch.cat(timestep_stage_list).unsqueeze(0).flip(1)

    # Set training progressive timesteps stages
    # every stream_chunk_size frames is treated as one, using the same noise level. f' = f / c
    assert max_latent_frame_num % stream_chunk_size == 0, (
        f"num_frames should be multiple of stream_chunk_size, {max_latent_frame_num} % {stream_chunk_size} != 0"
    )
    assert max_inference_steps % (max_latent_frame_num // stream_chunk_size) == 0, (
        f"max_inference_steps should be multiple of max_latent_frame_num // stream_chunk_size, {max_inference_steps} % {max_latent_frame_num // stream_chunk_size} != 0"
    )
    num_steps_to_be_saved = max_inference_steps // (
        max_latent_frame_num // stream_chunk_size
    )  # every m steps, save stream_chunk_size frames. m = t / f' = t / (f / c) = c * (t / f)

    # (b, t) -> [(b, t / m) in reverse range(m)] -> [(b, f) in reverse range(m)]
    # repeat_interleave copies each chunk timestep to the chunk's stream_chunk_size frames.
    progressive_timesteps_stages = [
        all_progressive_timesteps[:, (num_steps_to_be_saved - 1) - s :: num_steps_to_be_saved].repeat_interleave(
            stream_chunk_size, dim=1
        )
        for s in range(num_steps_to_be_saved)
    ]
    return num_steps_to_be_saved, all_timesteps_stage_ids, all_progressive_timesteps, progressive_timesteps_stages
if __name__ == "__main__":
    # Manual smoke test: builds the pyramid scheduler, loads the HunyuanVideo VAE from a
    # local checkpoint and runs get_framepack_input_i2v on random pixels. Needs a CUDA
    # device and the checkpoint path below.
    import argparse

    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--weighting_scheme",
        type=str,
        default="logit_normal",
        choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"],
        help=(
            'Timestep weighting scheme (default: "logit_normal"). '
            'Use "none" for uniform sampling and uniform loss.'
        ),
    )
    parser.add_argument(
        "--logit_mean",
        type=float,
        default=0.0,
        help="mean to use when using the `'logit_normal'` weighting scheme.",
    )
    parser.add_argument(
        "--logit_std",
        type=float,
        default=1.0,
        help="std to use when using the `'logit_normal'` weighting scheme.",
    )
    parser.add_argument(
        "--mode_scale",
        type=float,
        default=1.29,
        help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
    )
    args = parser.parse_args()

    device = "cuda"

    import sys

    # The scheduler package lives one directory up from this file
    sys.path.append("../")
    from scheduler.scheduling_flow_matching_pyramid import PyramidFlowMatchEulerDiscreteScheduler

    stages = [1, 2, 4]
    timestep_shift = 1.0
    stage_range = [0, 1 / 3, 2 / 3, 1]
    scheduler_gamma = 1 / 3
    scheduler = PyramidFlowMatchEulerDiscreteScheduler(
        shift=timestep_shift,
        stages=len(stages),
        stage_range=stage_range,
        gamma=scheduler_gamma,
    )
    print(
        f"The start sigmas and end sigmas of each stage is Start: {scheduler.start_sigmas}, End: {scheduler.end_sigmas}, Ori_start: {scheduler.ori_start_sigmas}"
    )

    # Test get_framepack_input
    from diffusers import AutoencoderKLHunyuanVideo

    # 5: (21, 41, 61, 81, 101)
    # 6: (25, 49, 73, 97, 121)
    # 7: (29, 57, 85, 113, 141)
    # 8: (33, 65, 97, 129, 161)
    # 9: (37, 73, 109, 145, 181)
    # 10: (41, 81, 121, 161, 201)
    # 11: (45, 89, 133, 177, 221)
    # 12: (49, 97, 145, 193, 241)
    pixel_values = torch.randn([2, 3, 241, 384, 640], device=device).clamp(-1, 1)
    pixel_values = pixel_values.to(torch.bfloat16)
    # `torch_dtype` is the diffusers loading argument for the weight dtype; the former
    # `weight_dtype` keyword is not a from_pretrained parameter.
    vae = AutoencoderKLHunyuanVideo.from_pretrained(
        "/mnt/workspace/checkpoints/hunyuanvideo-community/HunyuanVideo/",
        subfolder="vae",
        torch_dtype=torch.bfloat16,
    ).to(device)
    vae.requires_grad_(False)
    vae.eval()

    # NOTE(review): the shape comments below appear to come from an earlier run with a
    # 480x832 input, not the 384x640 tensor built above; treat them as indicative only.
    (
        model_input,  # torch.Size([2, 16, 9, 60, 104])
        indices_latents,  # torch.Size([2, 9])
        latents_clean,  # torch.Size([2, 16, 2, 60, 104])
        indices_clean_latents,  # torch.Size([2, 2])
        latents_history_2x,  # torch.Size([2, 16, 2, 60, 104])
        indices_latents_history_2x,  # torch.Size([2, 2])
        latents_history_4x,  # torch.Size([2, 16, 16, 60, 104])
        indices_latents_history_4x,  # torch.Size([2, 16])
        section_to_video_idx,
    ) = get_framepack_input_i2v(
        vae=vae,
        pixel_values=pixel_values,
        latent_window_size=12,
        vanilla_sampling=False,
        dtype=torch.bfloat16,
    )
    print(indices_latents, "\n", indices_clean_latents, "\n", indices_latents_history_2x, "\n", indices_latents_history_4x)

    # Test get_pyramid_input (disabled; kept as a usage example)
    # model_input = torch.randn([2, 16, 10, 48, 80], device=device)
    # noisy_model_input_list, sigmas_list, timesteps_list, targets_list = get_pyramid_input(
    #     args=args,
    #     scheduler=scheduler,
    #     latents=model_input,
    #     pyramid_stage_num=3,
    #     pyramid_sample_ratios=[1, 2, 1],
    #     pyramid_sample_mode="stream_sample",
    #     stream_chunk_size=3,
    #     pyramid_stream_inference_steps=[10, 10, 10],
    # )
    # if isinstance(noisy_model_input_list[0], list):
    #     total_sample_count = sum(y.shape[0] for x in noisy_model_input_list for y in x)
    # else:
    #     total_sample_count = sum(x.shape[0] for x in noisy_model_input_list)
    # batch_size = model_input.shape[0]
|