import os
import time
import random
import functools
from typing import List, Optional, Tuple, Union
from pathlib import Path
from einops import rearrange
import torch
import torch.distributed as dist
from hyvideo.constants import PROMPT_TEMPLATE, NEGATIVE_PROMPT, PRECISION_TO_TYPE, NEGATIVE_PROMPT_I2V
from hyvideo.vae import load_vae
from hyvideo.modules import load_model
from hyvideo.text_encoder import TextEncoder
from hyvideo.utils.data_utils import align_to, get_closest_ratio, generate_crop_size_list
from hyvideo.modules.posemb_layers import get_nd_rotary_pos_embed, get_nd_rotary_pos_embed_new
from hyvideo.diffusion.schedulers import FlowMatchDiscreteScheduler
from hyvideo.diffusion.pipelines import HunyuanVideoPipeline
from hyvideo.diffusion.pipelines import HunyuanVideoAudioPipeline
from PIL import Image
import numpy as np
import torchvision.transforms as transforms
import cv2
from wan.utils.utils import resize_lanczos, calculate_new_dimensions
from hyvideo.data_kits.audio_preprocessor import encode_audio, get_facemask
from transformers import WhisperModel, AutoFeatureExtractor
from hyvideo.data_kits.face_align import AlignImage
import librosa

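# Extract Whisper input features for an audio file. The audio is loaded at 16 kHz
# and processed in windows of 750 * 640 samples (30 s each); the returned frame
# count assumes 640 audio samples per video frame, i.e. 25 fps at 16 kHz.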
def get_audio_feature(feature_extractor, audio_path, duration):
    audio_input, sampling_rate = librosa.load(audio_path, duration=duration, sr=16000)
    assert sampling_rate == 16000

    audio_features = []
    window = 750 * 640
    for i in range(0, len(audio_input), window):
        audio_feature = feature_extractor(audio_input[i:i + window],
                                          sampling_rate=sampling_rate,
                                          return_tensors="pt",
                                          device="cuda"
                                          ).input_features
        audio_features.append(audio_feature)

    audio_features = torch.cat(audio_features, dim=-1)
    return audio_features, len(audio_input) // 640

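# Resize an image to fit inside `size` (width, height) while keeping its aspect
# ratio, then pad the borders with a constant color to reach the exact target size.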
def pad_image(crop_img, size, color=(255, 255, 255), resize_ratio=1):
    crop_h, crop_w = crop_img.shape[:2]
    target_w, target_h = size
    scale_h, scale_w = target_h / crop_h, target_w / crop_w
    if scale_w > scale_h:
        resize_h = int(target_h * resize_ratio)
        resize_w = int(crop_w / crop_h * resize_h)
    else:
        resize_w = int(target_w * resize_ratio)
        resize_h = int(crop_h / crop_w * resize_w)
    crop_img = cv2.resize(crop_img, (resize_w, resize_h))
    pad_left = (target_w - resize_w) // 2
    pad_top = (target_h - resize_h) // 2
    pad_right = target_w - resize_w - pad_left
    pad_bottom = target_h - resize_h - pad_top
    crop_img = cv2.copyMakeBorder(crop_img, pad_top, pad_bottom, pad_left, pad_right, cv2.BORDER_CONSTANT, value=color)
    return crop_img

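# Legacy LLaVA merging helper: expands each <image> placeholder token into the
# corresponding image patch embeddings and rebuilds the attention mask, labels and
# position ids for the merged sequence. Restored here because recent transformers
# releases removed this path from LlavaForConditionalGeneration.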
def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
    num_images, num_image_patches, embed_dim = image_features.shape
    batch_size, sequence_length = input_ids.shape
    left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
    # 1. Create a mask to know where special image tokens are
    special_image_token_mask = input_ids == self.config.image_token_index
    num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
    # Compute the maximum embed dimension
    max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
    batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index)

    # 2. Compute the positions where text should be written
    # Calculate new positions for text tokens in merged image-text sequence.
    # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
    # `torch.cumsum` computes how each image token shifts subsequent text token positions.
    # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
    new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
    nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
    if left_padding:
        new_token_positions += nb_image_pad[:, None]  # offset for left padding
    text_to_overwrite = new_token_positions[batch_indices, non_image_indices]

    # 3. Create the full embedding, already padded to the maximum position
    final_embedding = torch.zeros(
        batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
    )
    final_attention_mask = torch.zeros(
        batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
    )
    if labels is not None:
        final_labels = torch.full(
            (batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
        )
    # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
    # set the corresponding tensors into their correct target device.
    target_device = inputs_embeds.device
    batch_indices, non_image_indices, text_to_overwrite = (
        batch_indices.to(target_device),
        non_image_indices.to(target_device),
        text_to_overwrite.to(target_device),
    )
    attention_mask = attention_mask.to(target_device)

    # 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
    # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
    final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
    final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
    if labels is not None:
        final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]

    # 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
    image_to_overwrite = torch.full(
        (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
    )
    image_to_overwrite[batch_indices, text_to_overwrite] = False
    image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)

    if image_to_overwrite.sum() != image_features.shape[:-1].numel():
        raise ValueError(
            f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
            f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
        )

    final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
    final_attention_mask |= image_to_overwrite
    position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)

    # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
    batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
    indices_to_mask = new_token_positions[batch_indices, pad_indices]
    final_embedding[batch_indices, indices_to_mask] = 0

    if labels is None:
        final_labels = None

    return final_embedding, final_attention_mask, final_labels, position_ids

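# Replacement forward for LlavaForConditionalGeneration that keeps the legacy
# merge-based image/text path above, so the LLaVA text encoder still works with
# transformers releases newer than 4.47 (both methods are monkey-patched onto the
# class in Inference.from_pretrained below).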
def patched_llava_forward(
    self,
    input_ids: torch.LongTensor = None,
    pixel_values: torch.FloatTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    vision_feature_layer: Optional[int] = None,
    vision_feature_select_strategy: Optional[str] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    num_logits_to_keep: int = 0,
):
    from transformers.models.llava.modeling_llava import LlavaCausalLMOutputWithPast

    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
    vision_feature_layer = (
        vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
    )
    vision_feature_select_strategy = (
        vision_feature_select_strategy
        if vision_feature_select_strategy is not None
        else self.config.vision_feature_select_strategy
    )

    if (input_ids is None) ^ (inputs_embeds is not None):
        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

    if pixel_values is not None and inputs_embeds is not None:
        raise ValueError(
            "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
        )

    if inputs_embeds is None:
        inputs_embeds = self.get_input_embeddings()(input_ids)

    image_features = None
    if pixel_values is not None:
        image_features = self.get_image_features(
            pixel_values=pixel_values,
            vision_feature_layer=vision_feature_layer,
            vision_feature_select_strategy=vision_feature_select_strategy,
        )
        inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
            image_features, inputs_embeds, input_ids, attention_mask, labels
        )
        cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)

    outputs = self.language_model(
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
        num_logits_to_keep=num_logits_to_keep,
    )

    logits = outputs[0]
    loss = None

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return LlavaCausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        image_hidden_states=image_features if pixel_values is not None else None,
    )

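# Attach each audio adapter module to its matching double-stream transformer block
# (following model.double_stream_map), then drop the standalone adapter container.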
def adapt_model(model, audio_block_name):
    modules_dict = {k: m for k, m in model.named_modules()}
    for model_layer, avatar_layer in model.double_stream_map.items():
        module = modules_dict[f"{audio_block_name}.{avatar_layer}"]
        target = modules_dict[f"double_blocks.{model_layer}"]
        setattr(target, "audio_adapter", module)
    delattr(model, audio_block_name)

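# Prepares the reference image inputs: a 336x336 CLIP-normalized tensor for the
# LLaVA encoder, an all-white "unconditional" counterpart, and the raw reference
# tensor in [0, 1] used for VAE encoding.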
class DataPreprocess(object):
    def __init__(self):
        self.llava_size = (336, 336)
        self.llava_transform = transforms.Compose(
            [
                transforms.Resize(self.llava_size, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.ToTensor(),
                transforms.Normalize((0.48145466, 0.4578275, 0.4082107), (0.26862954, 0.26130258, 0.27577711)),
            ]
        )

    def get_batch(self, image, size, pad=False):
        image = np.asarray(image)
        if pad:
            llava_item_image = pad_image(image.copy(), self.llava_size)
        else:
            llava_item_image = image.copy()
        uncond_llava_item_image = np.ones_like(llava_item_image) * 255

        if pad:
            cat_item_image = pad_image(image.copy(), size)
        else:
            cat_item_image = image.copy()
        llava_item_tensor = self.llava_transform(Image.fromarray(llava_item_image.astype(np.uint8)))
        uncond_llava_item_tensor = self.llava_transform(Image.fromarray(uncond_llava_item_image))
        cat_item_tensor = torch.from_numpy(cat_item_image.copy()).permute((2, 0, 1)) / 255.0
        # batch = {
        #     "pixel_value_llava": llava_item_tensor.unsqueeze(0),
        #     "uncond_pixel_value_llava": uncond_llava_item_tensor.unsqueeze(0),
        #     "pixel_value_ref": cat_item_tensor.unsqueeze(0),
        # }
        return llava_item_tensor.unsqueeze(0), uncond_llava_item_tensor.unsqueeze(0), cat_item_tensor.unsqueeze(0)

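# Container for all loaded HunyuanVideo components (transformer, VAE, text encoders
# and the optional audio / face-alignment models); from_pretrained builds one for a
# given base model type.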
class Inference(object):
    def __init__(
        self,
        i2v,
        custom,
        avatar,
        enable_cfg,
        vae,
        vae_kwargs,
        text_encoder,
        model,
        text_encoder_2=None,
        pipeline=None,
        feature_extractor=None,
        wav2vec=None,
        align_instance=None,
        device=None,
    ):
        self.i2v = i2v
        self.custom = custom
        self.avatar = avatar
        self.enable_cfg = enable_cfg
        self.vae = vae
        self.vae_kwargs = vae_kwargs
        self.text_encoder = text_encoder
        self.text_encoder_2 = text_encoder_2
        self.model = model
        self.pipeline = pipeline
        self.feature_extractor = feature_extractor
        self.wav2vec = wav2vec
        self.align_instance = align_instance
        self.device = "cuda"

    @classmethod
    def from_pretrained(cls, model_filepath, model_type, base_model_type, text_encoder_filepath, dtype=torch.bfloat16, VAE_dtype=torch.float16, mixed_precision_transformer=torch.bfloat16, quantizeTransformer=False, save_quantized=False, **kwargs):
        device = "cuda"

        import transformers
        # Force the legacy LLaVA behaviour so the text encoder keeps working with transformers > 4.47
        transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.forward = patched_llava_forward
        transformers.models.llava.modeling_llava.LlavaForConditionalGeneration._merge_input_ids_with_image_features = _merge_input_ids_with_image_features

        torch.set_grad_enabled(False)
        text_len = 512
        latent_channels = 16
        precision = "bf16"
        vae_precision = "fp32" if VAE_dtype == torch.float32 else "bf16"
        embedded_cfg_scale = 6
        filepath = model_filepath[0]
        i2v_condition_type = None
        i2v_mode = False
        custom = False
        custom_audio = False
        avatar = False
        if base_model_type == "hunyuan_i2v":
            model_id = "HYVideo-T/2"
            i2v_condition_type = "token_replace"
            i2v_mode = True
        elif base_model_type == "hunyuan_custom":
            model_id = "HYVideo-T/2-custom"
            custom = True
        elif base_model_type == "hunyuan_custom_audio":
            model_id = "HYVideo-T/2-custom-audio"
            custom_audio = True
            custom = True
        elif base_model_type == "hunyuan_custom_edit":
            model_id = "HYVideo-T/2-custom-edit"
            custom = True
        elif base_model_type == "hunyuan_avatar":
            model_id = "HYVideo-T/2-avatar"
            text_len = 256
            avatar = True
        else:
            model_id = "HYVideo-T/2-cfgdistill"

        if i2v_mode and i2v_condition_type == "latent_concat":
            in_channels = latent_channels * 2 + 1
            image_embed_interleave = 2
        elif i2v_mode and i2v_condition_type == "token_replace":
            in_channels = latent_channels
            image_embed_interleave = 4
        else:
            in_channels = latent_channels
            image_embed_interleave = 1
        out_channels = latent_channels
        pinToMemory = kwargs.pop("pinToMemory", False)
        partialPinning = kwargs.pop("partialPinning", False)
        factor_kwargs = kwargs | {"device": "meta", "dtype": PRECISION_TO_TYPE[precision]}

        if embedded_cfg_scale and i2v_mode:
            factor_kwargs["guidance_embed"] = True

        model = load_model(
            model=model_id,
            i2v_condition_type=i2v_condition_type,
            in_channels=in_channels,
            out_channels=out_channels,
            factor_kwargs=factor_kwargs,
        )

        from mmgp import offload
        # model = Inference.load_state_dict(args, model, model_filepath)
        # model_filepath = "c:/temp/hc/mp_rank_00_model_states_video.pt"
        offload.load_model_data(model, model_filepath, do_quantize=quantizeTransformer and not save_quantized, pinToMemory=pinToMemory, partialPinning=partialPinning)
        # offload.save_model(model, "hunyuan_video_avatar_edit_720_bf16.safetensors")
        # offload.save_model(model, "hunyuan_video_avatar_edit_720_quanto_bf16_int8.safetensors", do_quantize=True)
        if save_quantized:
            from wgp import save_quantized_model
            save_quantized_model(model, model_type, filepath, dtype, None)

        model.mixed_precision = mixed_precision_transformer
        if model.mixed_precision:
            model._lock_dtype = torch.float32
            model.lock_layers_dtypes(torch.float32)
        model.eval()

        # ============================= Build extra models ========================
        # VAE
        if custom or avatar:
            vae_configpath = "ckpts/hunyuan_video_custom_VAE_config.json"
            vae_filepath = "ckpts/hunyuan_video_custom_VAE_fp32.safetensors"
        # elif avatar:
        #     vae_configpath = "ckpts/config_vae_avatar.json"
        #     vae_filepath = "ckpts/vae_avatar.pt"
        else:
            vae_configpath = "ckpts/hunyuan_video_VAE_config.json"
            vae_filepath = "ckpts/hunyuan_video_VAE_fp32.safetensors"

        # config = AutoencoderKLCausal3D.load_config("ckpts/hunyuan_video_VAE_config.json")
        # config = AutoencoderKLCausal3D.load_config("c:/temp/hvae/config_vae.json")
        vae, _, s_ratio, t_ratio = load_vae("884-16c-hy", vae_path=vae_filepath, vae_config_path=vae_configpath, vae_precision=vae_precision, device="cpu")
        vae._model_dtype = torch.float32 if VAE_dtype == torch.float32 else (torch.float16 if avatar else torch.bfloat16)
        vae._model_dtype = torch.float32 if VAE_dtype == torch.float32 else torch.bfloat16
        vae_kwargs = {"s_ratio": s_ratio, "t_ratio": t_ratio}

        enable_cfg = False
        # Text encoder
        if i2v_mode:
            text_encoder = "llm-i2v"
            tokenizer = "llm-i2v"
            prompt_template = "dit-llm-encode-i2v"
            prompt_template_video = "dit-llm-encode-video-i2v"
        elif custom or avatar:
            text_encoder = "llm-i2v"
            tokenizer = "llm-i2v"
            prompt_template = "dit-llm-encode"
            prompt_template_video = "dit-llm-encode-video"
            enable_cfg = True
        else:
            text_encoder = "llm"
            tokenizer = "llm"
            prompt_template = "dit-llm-encode"
            prompt_template_video = "dit-llm-encode-video"

        if prompt_template_video is not None:
            crop_start = PROMPT_TEMPLATE[prompt_template_video].get("crop_start", 0)
        elif prompt_template is not None:
            crop_start = PROMPT_TEMPLATE[prompt_template].get("crop_start", 0)
        else:
            crop_start = 0
        max_length = text_len + crop_start

        # prompt_template
        prompt_template = PROMPT_TEMPLATE[prompt_template] if prompt_template is not None else None

        # prompt_template_video
        prompt_template_video = PROMPT_TEMPLATE[prompt_template_video] if prompt_template_video is not None else None

        text_encoder = TextEncoder(
            text_encoder_type=text_encoder,
            max_length=max_length,
            text_encoder_precision="fp16",
            tokenizer_type=tokenizer,
            i2v_mode=i2v_mode,
            prompt_template=prompt_template,
            prompt_template_video=prompt_template_video,
            hidden_state_skip_layer=2,
            apply_final_norm=False,
            reproduce=True,
            device="cpu",
            image_embed_interleave=image_embed_interleave,
            text_encoder_path=text_encoder_filepath,
        )

        text_encoder_2 = TextEncoder(
            text_encoder_type="clipL",
            max_length=77,
            text_encoder_precision="fp16",
            tokenizer_type="clipL",
            reproduce=True,
            device="cpu",
        )

        feature_extractor = None
        wav2vec = None
        align_instance = None

        if avatar or custom_audio:
            feature_extractor = AutoFeatureExtractor.from_pretrained("ckpts/whisper-tiny/")
            wav2vec = WhisperModel.from_pretrained("ckpts/whisper-tiny/").to(device="cpu", dtype=torch.float32)
            wav2vec._model_dtype = torch.float32
            wav2vec.requires_grad_(False)
        if avatar:
            align_instance = AlignImage("cuda", det_path="ckpts/det_align/detface.pt")
            align_instance.facedet.model.to("cpu")
            adapt_model(model, "audio_adapter_blocks")
        elif custom_audio:
            adapt_model(model, "audio_models")

        return cls(
            i2v=i2v_mode,
            custom=custom,
            avatar=avatar,
            enable_cfg=enable_cfg,
            vae=vae,
            vae_kwargs=vae_kwargs,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            model=model,
            feature_extractor=feature_extractor,
            wav2vec=wav2vec,
            align_instance=align_instance,
            device=device,
        )

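# High-level sampler: builds the diffusion pipeline on top of the loaded components,
# prepares RoPE frequencies and conditioning inputs, and runs generation.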
class HunyuanVideoSampler(Inference):
    def __init__(
        self,
        i2v,
        custom,
        avatar,
        enable_cfg,
        vae,
        vae_kwargs,
        text_encoder,
        model,
        text_encoder_2=None,
        pipeline=None,
        feature_extractor=None,
        wav2vec=None,
        align_instance=None,
        device=0,
    ):
        super().__init__(
            i2v,
            custom,
            avatar,
            enable_cfg,
            vae,
            vae_kwargs,
            text_encoder,
            model,
            text_encoder_2=text_encoder_2,
            pipeline=pipeline,
            feature_extractor=feature_extractor,
            wav2vec=wav2vec,
            align_instance=align_instance,
            device=device,
        )

        self.i2v_mode = i2v
        self.enable_cfg = enable_cfg
        self.pipeline = self.load_diffusion_pipeline(
            avatar=self.avatar,
            vae=self.vae,
            text_encoder=self.text_encoder,
            text_encoder_2=self.text_encoder_2,
            model=self.model,
            device=self.device,
        )

        if self.i2v_mode:
            self.default_negative_prompt = NEGATIVE_PROMPT_I2V
        else:
            self.default_negative_prompt = NEGATIVE_PROMPT

    @property
    def _interrupt(self):
        return self.pipeline._interrupt

    @_interrupt.setter
    def _interrupt(self, value):
        self.pipeline._interrupt = value

    def load_diffusion_pipeline(
        self,
        avatar,
        vae,
        text_encoder,
        text_encoder_2,
        model,
        scheduler=None,
        device=None,
        progress_bar_config=None,
        # data_type="video",
    ):
        """Build the diffusion pipeline (and its default flow-match scheduler) for inference."""
        if scheduler is None:
            scheduler = FlowMatchDiscreteScheduler(
                shift=6.0,
                reverse=True,
                solver="euler",
            )

        if avatar:
            pipeline = HunyuanVideoAudioPipeline(
                vae=vae,
                text_encoder=text_encoder,
                text_encoder_2=text_encoder_2,
                transformer=model,
                scheduler=scheduler,
                progress_bar_config=progress_bar_config,
            )
        else:
            pipeline = HunyuanVideoPipeline(
                vae=vae,
                text_encoder=text_encoder,
                text_encoder_2=text_encoder_2,
                transformer=model,
                scheduler=scheduler,
                progress_bar_config=progress_bar_config,
            )

        return pipeline

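    # Build 3D rotary position embeddings over the latent grid (time, height, width);
    # `concat_dict` describes how the reference conditioning is concatenated
    # (used by the custom / avatar variants).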
    def get_rotary_pos_embed_new(self, video_length, height, width, concat_dict={}, enable_riflex=False):
        target_ndim = 3
        ndim = 5 - 2
        latents_size = [(video_length - 1) // 4 + 1, height // 8, width // 8]

        if isinstance(self.model.patch_size, int):
            assert all(s % self.model.patch_size == 0 for s in latents_size), \
                f"Latent size(last {ndim} dimensions) should be divisible by patch size({self.model.patch_size}), " \
                f"but got {latents_size}."
            rope_sizes = [s // self.model.patch_size for s in latents_size]
        elif isinstance(self.model.patch_size, list):
            assert all(s % self.model.patch_size[idx] == 0 for idx, s in enumerate(latents_size)), \
                f"Latent size(last {ndim} dimensions) should be divisible by patch size({self.model.patch_size}), " \
                f"but got {latents_size}."
            rope_sizes = [s // self.model.patch_size[idx] for idx, s in enumerate(latents_size)]

        if len(rope_sizes) != target_ndim:
            rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes  # time axis
        head_dim = self.model.hidden_size // self.model.heads_num
        rope_dim_list = self.model.rope_dim_list
        if rope_dim_list is None:
            rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
        assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"

        freqs_cos, freqs_sin = get_nd_rotary_pos_embed_new(rope_dim_list,
                                                           rope_sizes,
                                                           theta=256,
                                                           use_real=True,
                                                           theta_rescale_factor=1,
                                                           concat_dict=concat_dict,
                                                           L_test=(video_length - 1) // 4 + 1,
                                                           enable_riflex=enable_riflex
                                                           )
        return freqs_cos, freqs_sin

    def get_rotary_pos_embed(self, video_length, height, width, enable_riflex=False):
        target_ndim = 3
        ndim = 5 - 2
        # 884
        vae = "884-16c-hy"
        if "884" in vae:
            latents_size = [(video_length - 1) // 4 + 1, height // 8, width // 8]
        elif "888" in vae:
            latents_size = [(video_length - 1) // 8 + 1, height // 8, width // 8]
        else:
            latents_size = [video_length, height // 8, width // 8]

        if isinstance(self.model.patch_size, int):
            assert all(s % self.model.patch_size == 0 for s in latents_size), (
                f"Latent size(last {ndim} dimensions) should be divisible by patch size({self.model.patch_size}), "
                f"but got {latents_size}."
            )
            rope_sizes = [s // self.model.patch_size for s in latents_size]
        elif isinstance(self.model.patch_size, list):
            assert all(
                s % self.model.patch_size[idx] == 0
                for idx, s in enumerate(latents_size)
            ), (
                f"Latent size(last {ndim} dimensions) should be divisible by patch size({self.model.patch_size}), "
                f"but got {latents_size}."
            )
            rope_sizes = [
                s // self.model.patch_size[idx] for idx, s in enumerate(latents_size)
            ]

        if len(rope_sizes) != target_ndim:
            rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes  # time axis
        head_dim = self.model.hidden_size // self.model.heads_num
        rope_dim_list = self.model.rope_dim_list
        if rope_dim_list is None:
            rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
        assert (
            sum(rope_dim_list) == head_dim
        ), "sum(rope_dim_list) should equal to head_dim of attention layer"
        freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
            rope_dim_list,
            rope_sizes,
            theta=256,
            use_real=True,
            theta_rescale_factor=1,
            L_test=(video_length - 1) // 4 + 1,
            enable_riflex=enable_riflex
        )
        return freqs_cos, freqs_sin

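    # Main entry point: resolves seeds and target dimensions, prepares the prompts,
    # scheduler, RoPE frequencies and reference / background / audio conditioning,
    # then calls the underlying pipeline and returns the decoded samples.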
    def generate(
        self,
        input_prompt,
        input_ref_images=None,
        audio_guide=None,
        input_frames=None,
        input_masks=None,
        input_video=None,
        fps=24,
        height=192,
        width=336,
        frame_num=129,
        seed=None,
        n_prompt=None,
        sampling_steps=50,
        guide_scale=1.0,
        shift=5.0,
        embedded_guidance_scale=6.0,
        batch_size=1,
        num_videos_per_prompt=1,
        i2v_resolution="720p",
        image_start=None,
        enable_RIFLEx=False,
        i2v_condition_type: str = "token_replace",
        i2v_stability=True,
        VAE_tile_size=None,
        joint_pass=False,
        cfg_star_switch=False,
        fit_into_canvas=True,
        conditioning_latents_size=0,
        **kwargs,
    ):

        if VAE_tile_size != None:
            self.vae.tile_sample_min_tsize = VAE_tile_size["tile_sample_min_tsize"]
            self.vae.tile_latent_min_tsize = VAE_tile_size["tile_latent_min_tsize"]
            self.vae.tile_sample_min_size = VAE_tile_size["tile_sample_min_size"]
            self.vae.tile_latent_min_size = VAE_tile_size["tile_latent_min_size"]
            self.vae.tile_overlap_factor = VAE_tile_size["tile_overlap_factor"]
            self.vae.enable_tiling()

        i2v_mode = self.i2v_mode
        if not self.enable_cfg:
            guide_scale = 1.0

        # ========================================================================
        # Arguments: seed
        # ========================================================================
        if isinstance(seed, torch.Tensor):
            seed = seed.tolist()
        if seed is None:
            seeds = [
                random.randint(0, 1_000_000)
                for _ in range(batch_size * num_videos_per_prompt)
            ]
        elif isinstance(seed, int):
            seeds = [
                seed + i
                for _ in range(batch_size)
                for i in range(num_videos_per_prompt)
            ]
        elif isinstance(seed, (list, tuple)):
            if len(seed) == batch_size:
                seeds = [
                    int(seed[i]) + j
                    for i in range(batch_size)
                    for j in range(num_videos_per_prompt)
                ]
            elif len(seed) == batch_size * num_videos_per_prompt:
                seeds = [int(s) for s in seed]
            else:
                raise ValueError(
                    f"Length of seed must be equal to number of prompt(batch_size) or "
                    f"batch_size * num_videos_per_prompt ({batch_size} * {num_videos_per_prompt}), got {seed}."
                )
        else:
            raise ValueError(
                f"Seed must be an integer, a list of integers, or None, got {seed}."
            )
        from wan.utils.utils import seed_everything
        seed_everything(seed)
        generator = [torch.Generator("cuda").manual_seed(seed) for seed in seeds]
        # generator = [torch.Generator(self.device).manual_seed(seed) for seed in seeds]

        # ========================================================================
        # Arguments: target_width, target_height, target_frame_num
        # ========================================================================
        if width <= 0 or height <= 0 or frame_num <= 0:
            raise ValueError(
                f"`height` and `width` and `frame_num` must be positive integers, got height={height}, width={width}, frame_num={frame_num}"
            )
        if (frame_num - 1) % 4 != 0:
            raise ValueError(
                f"`frame_num-1` must be a multiple of 4, got {frame_num}"
            )

        target_height = align_to(height, 16)
        target_width = align_to(width, 16)
        target_frame_num = frame_num
        audio_strength = 1

        if input_ref_images != None:
            # ip_cfg_scale = 3.0
            ip_cfg_scale = 0
            denoise_strength = 1
            # guide_scale = 7.5
            # shift = 13
            name = "person"
            input_ref_images = input_ref_images[0]

        # ========================================================================
        # Arguments: prompt, new_prompt, negative_prompt
        # ========================================================================
        if not isinstance(input_prompt, str):
            raise TypeError(f"`prompt` must be a string, but got {type(input_prompt)}")
        input_prompt = [input_prompt.strip()]

        # negative prompt
        if n_prompt is None or n_prompt == "":
            n_prompt = self.default_negative_prompt
        if guide_scale == 1.0:
            n_prompt = ""
        if not isinstance(n_prompt, str):
            raise TypeError(
                f"`negative_prompt` must be a string, but got {type(n_prompt)}"
            )
        n_prompt = [n_prompt.strip()]

        # ========================================================================
        # Scheduler
        # ========================================================================
        scheduler = FlowMatchDiscreteScheduler(
            shift=shift,
            reverse=True,
            solver="euler"
        )
        self.pipeline.scheduler = scheduler

        # ---------------------------------
        # Reference condition
        # ---------------------------------
        img_latents = None
        semantic_images = None
        denoise_strength = 0
        ip_cfg_scale = 0
        if i2v_mode:
            if i2v_resolution == "720p":
                bucket_hw_base_size = 960
            elif i2v_resolution == "540p":
                bucket_hw_base_size = 720
            elif i2v_resolution == "360p":
                bucket_hw_base_size = 480
            else:
                raise ValueError(f"i2v_resolution: {i2v_resolution} must be in [360p, 540p, 720p]")

            # semantic_images = [Image.open(i2v_image_path).convert('RGB')]
            semantic_images = [image_start.convert('RGB')]
            origin_size = semantic_images[0].size
            h, w = origin_size
            h, w = calculate_new_dimensions(height, width, h, w, fit_into_canvas)
            closest_size = (w, h)
            # crop_size_list = generate_crop_size_list(bucket_hw_base_size, 32)
            # aspect_ratios = np.array([round(float(h)/float(w), 5) for h, w in crop_size_list])
            # closest_size, closest_ratio = get_closest_ratio(origin_size[1], origin_size[0], aspect_ratios, crop_size_list)
            ref_image_transform = transforms.Compose([
                transforms.Resize(closest_size),
                transforms.CenterCrop(closest_size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5])
            ])

            semantic_image_pixel_values = [ref_image_transform(semantic_image) for semantic_image in semantic_images]
            semantic_image_pixel_values = torch.cat(semantic_image_pixel_values).unsqueeze(0).unsqueeze(2).to(self.device)

            with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=True):
                img_latents = self.pipeline.vae.encode(semantic_image_pixel_values).latent_dist.mode()  # B, C, F, H, W
                img_latents.mul_(self.pipeline.vae.config.scaling_factor)

            target_height, target_width = closest_size

        # ========================================================================
        # Build Rope freqs
        # ========================================================================
        if input_ref_images == None:
            freqs_cos, freqs_sin = self.get_rotary_pos_embed(target_frame_num, target_height, target_width, enable_RIFLEx)
        else:
            if self.avatar:
                w, h = input_ref_images.size
                target_height, target_width = calculate_new_dimensions(target_height, target_width, h, w, fit_into_canvas)
                if target_width != w or target_height != h:
                    input_ref_images = input_ref_images.resize((target_width, target_height), resample=Image.Resampling.LANCZOS)
                concat_dict = {'mode': 'timecat', 'bias': -1}
                freqs_cos, freqs_sin = self.get_rotary_pos_embed_new(129, target_height, target_width, concat_dict)
            else:
                if input_frames != None:
                    target_height, target_width = input_frames.shape[-3:-1]
                elif input_video != None:
                    target_height, target_width = input_video.shape[-2:]
                concat_dict = {'mode': 'timecat-w', 'bias': -1}
                freqs_cos, freqs_sin = self.get_rotary_pos_embed_new(target_frame_num, target_height, target_width, concat_dict, enable_RIFLEx)

        n_tokens = freqs_cos.shape[0]

        callback = kwargs.pop("callback", None)
        callback_steps = kwargs.pop("callback_steps", None)

        # ========================================================================
        # Pipeline inference
        # ========================================================================
        pixel_value_llava, uncond_pixel_value_llava, pixel_value_ref = None, None, None
        if input_ref_images == None:
            name = None
        else:
            pixel_value_llava, uncond_pixel_value_llava, pixel_value_ref = DataPreprocess().get_batch(input_ref_images, (target_width, target_height), pad=self.custom)

        ref_latents, uncond_audio_prompts, audio_prompts, face_masks, motion_exp, motion_pose = None, None, None, None, None, None
        bg_latents = None

        if input_video != None:
            pixel_value_bg = input_video.unsqueeze(0)
            pixel_value_mask = torch.zeros_like(input_video).unsqueeze(0)
        if input_frames != None:
            pixel_value_video_bg = input_frames.permute(-1, 0, 1, 2).unsqueeze(0).float()
            pixel_value_video_mask = input_masks.unsqueeze(-1).repeat(1, 1, 1, 3).permute(-1, 0, 1, 2).unsqueeze(0).float()
            pixel_value_video_bg = pixel_value_video_bg.div_(127.5).add_(-1.)
            if input_video != None:
                pixel_value_bg = torch.cat([pixel_value_bg, pixel_value_video_bg], dim=2)
                pixel_value_mask = torch.cat([pixel_value_mask, pixel_value_video_mask], dim=2)
            else:
                pixel_value_bg = pixel_value_video_bg
                pixel_value_mask = pixel_value_video_mask
            pixel_value_video_mask, pixel_value_video_bg = None, None

        if input_video != None or input_frames != None:
            if pixel_value_bg.shape[2] < frame_num:
                padding_shape = list(pixel_value_bg.shape[0:2]) + [frame_num - pixel_value_bg.shape[2]] + list(pixel_value_bg.shape[3:])
                pixel_value_bg = torch.cat([pixel_value_bg, torch.full(padding_shape, -1, dtype=pixel_value_bg.dtype, device=pixel_value_bg.device)], dim=2)
                pixel_value_mask = torch.cat([pixel_value_mask, torch.full(padding_shape, 255, dtype=pixel_value_mask.dtype, device=pixel_value_mask.device)], dim=2)

            bg_latents = self.vae.encode(pixel_value_bg).latent_dist.sample()
            pixel_value_mask = pixel_value_mask.div_(127.5).add_(-1.)
            mask_latents = self.vae.encode(pixel_value_mask).latent_dist.sample()
            bg_latents = torch.cat([bg_latents, mask_latents], dim=1)
            bg_latents.mul_(self.vae.config.scaling_factor)

        if self.avatar:
            if n_prompt == None or len(n_prompt) == 0:
                n_prompt = "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion, blurring, Lens changes"

            uncond_pixel_value_llava = pixel_value_llava.clone()

            pixel_value_ref = pixel_value_ref.unsqueeze(0)
            self.align_instance.facedet.model.to("cuda")
            face_masks = get_facemask(pixel_value_ref.to("cuda") * 255, self.align_instance, area=3.0)
            # iii = (face_masks.squeeze(0).squeeze(0).permute(1,2,0).repeat(1,1,3)*255).cpu().numpy().astype(np.uint8)
            # image = Image.fromarray(iii)
            # image.save("mask.png")
            # jjj = (pixel_value_ref.squeeze(0).squeeze(0).permute(1,2,0)*255).cpu().numpy().astype(np.uint8)
            self.align_instance.facedet.model.to("cpu")

            # pixel_value_ref = pixel_value_ref.clone().repeat(1,129,1,1,1)
            pixel_value_ref = pixel_value_ref.repeat(1, 1 + 4 * 2, 1, 1, 1)
            pixel_value_ref = pixel_value_ref * 2 - 1
            pixel_value_ref_for_vae = rearrange(pixel_value_ref, "b f c h w -> b c f h w")

            vae_dtype = self.vae.dtype
            with torch.autocast(device_type="cuda", dtype=vae_dtype, enabled=vae_dtype != torch.float32):
                ref_latents = self.vae.encode(pixel_value_ref_for_vae).latent_dist.sample()
                ref_latents = torch.cat([ref_latents[:, :, :1], ref_latents[:, :, 1:2].repeat(1, 1, 31, 1, 1), ref_latents[:, :, -1:]], dim=2)
            pixel_value_ref, pixel_value_ref_for_vae = None, None

            if hasattr(self.vae.config, 'shift_factor') and self.vae.config.shift_factor:
                ref_latents.sub_(self.vae.config.shift_factor).mul_(self.vae.config.scaling_factor)
            else:
                ref_latents.mul_(self.vae.config.scaling_factor)

            # out_latents = ref_latents / self.vae.config.scaling_factor
            # image = self.vae.decode(out_latents, return_dict=False, generator=generator)[0]
            # image = image.clamp(-1, 1)
            # from wan.utils.utils import cache_video
            # cache_video(tensor=image, save_file="decode.mp4", fps=25, nrow=1, normalize=True, value_range=(-1, 1))

            motion_pose = np.array([25] * 4)
            motion_exp = np.array([30] * 4)
            motion_pose = torch.from_numpy(motion_pose).unsqueeze(0)
            motion_exp = torch.from_numpy(motion_exp).unsqueeze(0)

            face_masks = torch.nn.functional.interpolate(face_masks.float().squeeze(2),
                                                         (ref_latents.shape[-2],
                                                          ref_latents.shape[-1]),
                                                         mode="bilinear").unsqueeze(2).to(dtype=ref_latents.dtype)

        if audio_guide != None:
            audio_input, audio_len = get_audio_feature(self.feature_extractor, audio_guide, duration=frame_num / fps)
            audio_prompts = audio_input[0]
            weight_dtype = audio_prompts.dtype
            if self.custom:
                audio_len = min(audio_len, frame_num)
                audio_input = audio_input[:, :audio_len]
            audio_prompts = encode_audio(self.wav2vec, audio_prompts.to(dtype=self.wav2vec.dtype), fps, num_frames=audio_len)
            audio_prompts = audio_prompts.to(self.model.dtype)
            segment_size = 129 if self.avatar else frame_num
            if audio_prompts.shape[1] <= segment_size:
                audio_prompts = torch.cat([audio_prompts, torch.zeros_like(audio_prompts[:, :1]).repeat(1, segment_size - audio_prompts.shape[1], 1, 1, 1)], dim=1)
            else:
                audio_prompts = torch.cat([audio_prompts, torch.zeros_like(audio_prompts[:, :1]).repeat(1, 5, 1, 1, 1)], dim=1)
            uncond_audio_prompts = torch.zeros_like(audio_prompts[:, :129])

        samples = self.pipeline(
            prompt=input_prompt,
            height=target_height,
            width=target_width,
            video_length=target_frame_num,
            num_inference_steps=sampling_steps,
            guidance_scale=guide_scale,
            negative_prompt=n_prompt,
            num_videos_per_prompt=num_videos_per_prompt,
            generator=generator,
            output_type="pil",
            name=name,
            pixel_value_ref=pixel_value_ref,
            ref_latents=ref_latents,  # [1, 16, 1, h//8, w//8]
            pixel_value_llava=pixel_value_llava,  # [1, 3, 336, 336]
            uncond_pixel_value_llava=uncond_pixel_value_llava,
            face_masks=face_masks,  # [b f h w]
            audio_prompts=audio_prompts,
            uncond_audio_prompts=uncond_audio_prompts,
            motion_exp=motion_exp,
            motion_pose=motion_pose,
            fps=torch.from_numpy(np.array(fps)),
            bg_latents=bg_latents,
            audio_strength=audio_strength,
            denoise_strength=denoise_strength,
            ip_cfg_scale=ip_cfg_scale,
            freqs_cis=(freqs_cos, freqs_sin),
            n_tokens=n_tokens,
            embedded_guidance_scale=embedded_guidance_scale,
            data_type="video" if target_frame_num > 1 else "image",
            is_progress_bar=True,
            vae_ver="884-16c-hy",
            enable_tiling=True,
            i2v_mode=i2v_mode,
            i2v_condition_type=i2v_condition_type,
            i2v_stability=i2v_stability,
            img_latents=img_latents,
            semantic_images=semantic_images,
            joint_pass=joint_pass,
            cfg_star_rescale=cfg_star_switch,
            callback=callback,
            callback_steps=callback_steps,
        )[0]

        if samples == None:
            return None
        samples = samples.squeeze(0)
        return samples