diff --git a/app.py b/app.py index 21c4bfbd3669a90bee4bf6b69ab7025343a35dd4..69300337ef661db2214ed1ba05ed2e9b8f0fffc3 100644 --- a/app.py +++ b/app.py @@ -19,6 +19,9 @@ if __name__ == "__main__": # # "sequential_cpu_offload" means that each layer of the model will be moved to the CPU after use, # resulting in slower speeds but saving a large amount of GPU memory. + # + # EasyAnimateV1, V2 and V3 support "model_cpu_offload" "sequential_cpu_offload" + # EasyAnimateV4, V5 and V5.1 support "model_cpu_offload" "model_cpu_offload_and_qfloat8" "sequential_cpu_offload" GPU_memory_mode = "model_cpu_offload_and_qfloat8" # Use torch.float16 if GPU does not support torch.bfloat16 # ome graphics cards, such as v100, 2080ti, do not support torch.bfloat16 @@ -29,11 +32,11 @@ if __name__ == "__main__": server_port = 7860 # Params below is used when ui_mode = "modelscope" - edition = "v5" + edition = "v5.1" # Config - config_path = "config/easyanimate_video_v5_magvit_multi_text_encoder.yaml" + config_path = "config/easyanimate_video_v5.1_magvit_qwen.yaml" # Model path of the pretrained model - model_name = "models/Diffusion_Transformer/EasyAnimateV5-12b-zh-InP" + model_name = "models/Diffusion_Transformer/EasyAnimateV5.1-12b-zh-InP" # "Inpaint" or "Control" model_type = "Inpaint" # Save dir @@ -46,18 +49,6 @@ if __name__ == "__main__": else: demo, controller = ui(GPU_memory_mode, weight_dtype) - # launch gradio - app, _, _ = demo.queue(status_update_rate=1).launch( - server_name=server_name, - server_port=server_port, - prevent_thread_lock=True - ) - - # launch api - infer_forward_api(None, app, controller) - update_diffusion_transformer_api(None, app, controller) - update_edition_api(None, app, controller) - - # not close the python - while True: - time.sleep(5) \ No newline at end of file + demo.launch( + server_name=server_name, server_port=server_port + ) \ No newline at end of file diff --git a/config/easyanimate_video_v5.1_magvit_qwen.yaml b/config/easyanimate_video_v5.1_magvit_qwen.yaml new file mode 100644 index 0000000000000000000000000000000000000000..126b7d96641c8ca7ea5eece3a22d1cd5c0c06930 --- /dev/null +++ b/config/easyanimate_video_v5.1_magvit_qwen.yaml @@ -0,0 +1,21 @@ +transformer_additional_kwargs: + transformer_type: "EasyAnimateTransformer3DModel" + after_norm: false + time_position_encoding_type: "3d_rope" + resize_inpaint_mask_directly: true + enable_text_attention_mask: true + enable_clip_in_inpaint: false + add_ref_latent_in_control_model: true + +vae_kwargs: + vae_type: "AutoencoderKLMagvit" + mini_batch_encoder: 4 + mini_batch_decoder: 1 + slice_mag_vae: false + slice_compression_vae: false + cache_compression_vae: false + cache_mag_vae: true + +text_encoder_kwargs: + enable_multi_text_encoder: false + replace_t5_to_llm: true \ No newline at end of file diff --git a/easyanimate/api/api.py b/easyanimate/api/api.py index c60415179cb8edc918fc30ea602fda5dd3303b48..f2783529a6b71ab88cb9805e420f0f2c00b54ed1 100644 --- a/easyanimate/api/api.py +++ b/easyanimate/api/api.py @@ -93,7 +93,7 @@ def infer_forward_api(_: gr.Blocks, app: FastAPI, controller): lora_model_path = datas.get('lora_model_path', 'none') lora_alpha_slider = datas.get('lora_alpha_slider', 0.55) prompt_textbox = datas.get('prompt_textbox', None) - negative_prompt_textbox = datas.get('negative_prompt_textbox', 'Blurring, mutation, deformation, distortion, dark and solid, comics.') + negative_prompt_textbox = datas.get('negative_prompt_textbox', 'Blurring, mutation, deformation, distortion, dark and solid, comics, text subtitles, line art.') sampler_dropdown = datas.get('sampler_dropdown', 'Euler') sample_step_slider = datas.get('sample_step_slider', 30) resize_method = datas.get('resize_method', "Generate by") diff --git a/easyanimate/api/post_infer.py b/easyanimate/api/post_infer.py index 950e6df9ade98f4edc7d15345b8c63067d8c2ae3..38ef4faefe1308c279677c28f27612db0c0369ee 100644 --- a/easyanimate/api/post_infer.py +++ b/easyanimate/api/post_infer.py @@ -54,14 +54,14 @@ if __name__ == '__main__': # -------------------------- # # Step 1: update edition # -------------------------- # - edition = "v5" + edition = "v5.1" outputs = post_update_edition(edition) print('Output update edition: ', outputs) # -------------------------- # # Step 2: update edition # -------------------------- # - diffusion_transformer_path = "models/Diffusion_Transformer/EasyAnimateV5-12b-zh-InP" + diffusion_transformer_path = "models/Diffusion_Transformer/EasyAnimateV5.1-12b-zh-InP" outputs = post_diffusion_transformer(diffusion_transformer_path) print('Output update edition: ', outputs) diff --git a/easyanimate/data/dataset_image_video.py b/easyanimate/data/dataset_image_video.py index 3065838ee9391b54bc799de902ccb92bd02e9133..829cdad3ef457ff145bac48d7e82b773da6d5ec1 100644 --- a/easyanimate/data/dataset_image_video.py +++ b/easyanimate/data/dataset_image_video.py @@ -12,9 +12,12 @@ import albumentations import cv2 import numpy as np import torch +import torch.nn.functional as F import torchvision.transforms as transforms from decord import VideoReader +from einops import rearrange from func_timeout import FunctionTimedOut, func_timeout +from packaging import version as pver from PIL import Image from torch.utils.data import BatchSampler, Sampler from torch.utils.data.dataset import Dataset @@ -100,6 +103,152 @@ def get_random_mask(shape): else: raise ValueError(f"The mask_index {mask_index} is not define") return mask + +class Camera(object): + """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py + """ + def __init__(self, entry): + fx, fy, cx, cy = entry[1:5] + self.fx = fx + self.fy = fy + self.cx = cx + self.cy = cy + w2c_mat = np.array(entry[7:]).reshape(3, 4) + w2c_mat_4x4 = np.eye(4) + w2c_mat_4x4[:3, :] = w2c_mat + self.w2c_mat = w2c_mat_4x4 + self.c2w_mat = np.linalg.inv(w2c_mat_4x4) + +def custom_meshgrid(*args): + """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py + """ + # ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid + if pver.parse(torch.__version__) < pver.parse('1.10'): + return torch.meshgrid(*args) + else: + return torch.meshgrid(*args, indexing='ij') + +def get_relative_pose(cam_params): + """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py + """ + abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params] + abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params] + cam_to_origin = 0 + target_cam_c2w = np.array([ + [1, 0, 0, 0], + [0, 1, 0, -cam_to_origin], + [0, 0, 1, 0], + [0, 0, 0, 1] + ]) + abs2rel = target_cam_c2w @ abs_w2cs[0] + ret_poses = [target_cam_c2w, ] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]] + ret_poses = np.array(ret_poses, dtype=np.float32) + return ret_poses + +def ray_condition(K, c2w, H, W, device): + """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py + """ + # c2w: B, V, 4, 4 + # K: B, V, 4 + + B = K.shape[0] + + j, i = custom_meshgrid( + torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype), + torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype), + ) + i = i.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW] + j = j.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW] + + fx, fy, cx, cy = K.chunk(4, dim=-1) # B,V, 1 + + zs = torch.ones_like(i) # [B, HxW] + xs = (i - cx) / fx * zs + ys = (j - cy) / fy * zs + zs = zs.expand_as(ys) + + directions = torch.stack((xs, ys, zs), dim=-1) # B, V, HW, 3 + directions = directions / directions.norm(dim=-1, keepdim=True) # B, V, HW, 3 + + rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2) # B, V, 3, HW + rays_o = c2w[..., :3, 3] # B, V, 3 + rays_o = rays_o[:, :, None].expand_as(rays_d) # B, V, 3, HW + # c2w @ dirctions + rays_dxo = torch.cross(rays_o, rays_d) + plucker = torch.cat([rays_dxo, rays_d], dim=-1) + plucker = plucker.reshape(B, c2w.shape[1], H, W, 6) # B, V, H, W, 6 + # plucker = plucker.permute(0, 1, 4, 2, 3) + return plucker + +def process_pose_file(pose_file_path, width=672, height=384, original_pose_width=1280, original_pose_height=720, device='cpu', return_poses=False): + """Modified from https://github.com/hehao13/CameraCtrl/blob/main/inference.py + """ + with open(pose_file_path, 'r') as f: + poses = f.readlines() + + poses = [pose.strip().split(' ') for pose in poses[1:]] + cam_params = [[float(x) for x in pose] for pose in poses] + if return_poses: + return cam_params + else: + cam_params = [Camera(cam_param) for cam_param in cam_params] + + sample_wh_ratio = width / height + pose_wh_ratio = original_pose_width / original_pose_height # Assuming placeholder ratios, change as needed + + if pose_wh_ratio > sample_wh_ratio: + resized_ori_w = height * pose_wh_ratio + for cam_param in cam_params: + cam_param.fx = resized_ori_w * cam_param.fx / width + else: + resized_ori_h = width / pose_wh_ratio + for cam_param in cam_params: + cam_param.fy = resized_ori_h * cam_param.fy / height + + intrinsic = np.asarray([[cam_param.fx * width, + cam_param.fy * height, + cam_param.cx * width, + cam_param.cy * height] + for cam_param in cam_params], dtype=np.float32) + + K = torch.as_tensor(intrinsic)[None] # [1, 1, 4] + c2ws = get_relative_pose(cam_params) # Assuming this function is defined elsewhere + c2ws = torch.as_tensor(c2ws)[None] # [1, n_frame, 4, 4] + plucker_embedding = ray_condition(K, c2ws, height, width, device=device)[0].permute(0, 3, 1, 2).contiguous() # V, 6, H, W + plucker_embedding = plucker_embedding[None] + plucker_embedding = rearrange(plucker_embedding, "b f c h w -> b f h w c")[0] + return plucker_embedding + +def process_pose_params(cam_params, width=672, height=384, original_pose_width=1280, original_pose_height=720, device='cpu'): + """Modified from https://github.com/hehao13/CameraCtrl/blob/main/inference.py + """ + cam_params = [Camera(cam_param) for cam_param in cam_params] + + sample_wh_ratio = width / height + pose_wh_ratio = original_pose_width / original_pose_height # Assuming placeholder ratios, change as needed + + if pose_wh_ratio > sample_wh_ratio: + resized_ori_w = height * pose_wh_ratio + for cam_param in cam_params: + cam_param.fx = resized_ori_w * cam_param.fx / width + else: + resized_ori_h = width / pose_wh_ratio + for cam_param in cam_params: + cam_param.fy = resized_ori_h * cam_param.fy / height + + intrinsic = np.asarray([[cam_param.fx * width, + cam_param.fy * height, + cam_param.cx * width, + cam_param.cy * height] + for cam_param in cam_params], dtype=np.float32) + + K = torch.as_tensor(intrinsic)[None] # [1, 1, 4] + c2ws = get_relative_pose(cam_params) # Assuming this function is defined elsewhere + c2ws = torch.as_tensor(c2ws)[None] # [1, n_frame, 4, 4] + plucker_embedding = ray_condition(K, c2ws, height, width, device=device)[0].permute(0, 3, 1, 2).contiguous() # V, 6, H, W + plucker_embedding = plucker_embedding[None] + plucker_embedding = rearrange(plucker_embedding, "b f c h w -> b f h w c")[0] + return plucker_embedding class ImageVideoSampler(BatchSampler): """A sampler wrapper for grouping images with similar aspect ratio into a same batch. @@ -184,7 +333,7 @@ class ImageVideoDataset(Dataset): video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16, image_sample_size=512, video_repeat=0, - text_drop_ratio=-1, + text_drop_ratio=0.1, enable_bucket=False, video_length_drop_start=0.1, video_length_drop_end=0.9, @@ -355,7 +504,6 @@ class ImageVideoDataset(Dataset): return sample - class ImageVideoControlDataset(Dataset): def __init__( self, @@ -363,11 +511,12 @@ class ImageVideoControlDataset(Dataset): video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16, image_sample_size=512, video_repeat=0, - text_drop_ratio=-1, + text_drop_ratio=0.1, enable_bucket=False, video_length_drop_start=0.1, video_length_drop_end=0.9, enable_inpaint=False, + enable_camera_info=False, ): # Loading annotations from files print(f"loading annotations from {ann_path} ...") @@ -397,6 +546,7 @@ class ImageVideoControlDataset(Dataset): self.enable_bucket = enable_bucket self.text_drop_ratio = text_drop_ratio self.enable_inpaint = enable_inpaint + self.enable_camera_info = enable_camera_info self.video_length_drop_start = video_length_drop_start self.video_length_drop_end = video_length_drop_end @@ -412,6 +562,13 @@ class ImageVideoControlDataset(Dataset): transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), ] ) + if self.enable_camera_info: + self.video_transforms_camera = transforms.Compose( + [ + transforms.Resize(min(self.video_sample_size)), + transforms.CenterCrop(self.video_sample_size) + ] + ) # Image params self.image_sample_size = tuple(image_sample_size) if not isinstance(image_sample_size, int) else (image_sample_size, image_sample_size) @@ -484,33 +641,59 @@ class ImageVideoControlDataset(Dataset): else: control_video_id = os.path.join(self.data_root, control_video_id) - with VideoReader_contextmanager(control_video_id, num_threads=2) as control_video_reader: - try: - sample_args = (control_video_reader, batch_index) - control_pixel_values = func_timeout( - VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args - ) - resized_frames = [] - for i in range(len(control_pixel_values)): - frame = control_pixel_values[i] - resized_frame = resize_frame(frame, self.larger_side_of_image_and_video) - resized_frames.append(resized_frame) - control_pixel_values = np.array(resized_frames) - except FunctionTimedOut: - raise ValueError(f"Read {idx} timeout.") - except Exception as e: - raise ValueError(f"Failed to extract frames from video. Error is {e}.") - - if not self.enable_bucket: - control_pixel_values = torch.from_numpy(control_pixel_values).permute(0, 3, 1, 2).contiguous() - control_pixel_values = control_pixel_values / 255. - del control_video_reader + if self.enable_camera_info: + if control_video_id.lower().endswith('.txt'): + if not self.enable_bucket: + control_pixel_values = torch.zeros_like(pixel_values) + + control_camera_values = process_pose_file(control_video_id, width=self.video_sample_size[1], height=self.video_sample_size[0]) + control_camera_values = torch.from_numpy(control_camera_values).permute(0, 3, 1, 2).contiguous() + control_camera_values = F.interpolate(control_camera_values, size=(len(video_reader), control_camera_values.size(3)), mode='bilinear', align_corners=True) + control_camera_values = self.video_transforms_camera(control_camera_values) + else: + control_pixel_values = np.zeros_like(pixel_values) + + control_camera_values = process_pose_file(control_video_id, width=self.video_sample_size[1], height=self.video_sample_size[0], return_poses=True) + control_camera_values = torch.from_numpy(np.array(control_camera_values)).unsqueeze(0).unsqueeze(0) + control_camera_values = F.interpolate(control_camera_values, size=(len(video_reader), control_camera_values.size(3)), mode='bilinear', align_corners=True)[0][0] + control_camera_values = np.array([control_camera_values[index] for index in batch_index]) else: - control_pixel_values = control_pixel_values - - if not self.enable_bucket: - control_pixel_values = self.video_transforms(control_pixel_values) - return pixel_values, control_pixel_values, text, "video" + if not self.enable_bucket: + control_pixel_values = torch.zeros_like(pixel_values) + control_camera_values = None + else: + control_pixel_values = np.zeros_like(pixel_values) + control_camera_values = None + else: + with VideoReader_contextmanager(control_video_id, num_threads=2) as control_video_reader: + try: + sample_args = (control_video_reader, batch_index) + control_pixel_values = func_timeout( + VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args + ) + resized_frames = [] + for i in range(len(control_pixel_values)): + frame = control_pixel_values[i] + resized_frame = resize_frame(frame, self.larger_side_of_image_and_video) + resized_frames.append(resized_frame) + control_pixel_values = np.array(resized_frames) + except FunctionTimedOut: + raise ValueError(f"Read {idx} timeout.") + except Exception as e: + raise ValueError(f"Failed to extract frames from video. Error is {e}.") + + if not self.enable_bucket: + control_pixel_values = torch.from_numpy(control_pixel_values).permute(0, 3, 1, 2).contiguous() + control_pixel_values = control_pixel_values / 255. + del control_video_reader + else: + control_pixel_values = control_pixel_values + + if not self.enable_bucket: + control_pixel_values = self.video_transforms(control_pixel_values) + control_camera_values = None + + return pixel_values, control_pixel_values, control_camera_values, text, "video" else: image_path, text = data_info['file_path'], data_info['text'] if self.data_root is not None: @@ -536,7 +719,8 @@ class ImageVideoControlDataset(Dataset): control_image = self.image_transforms(control_image).unsqueeze(0) else: control_image = np.expand_dims(np.array(control_image), 0) - return image, control_image, text, 'image' + + return image, control_image, None, text, 'image' def __len__(self): return self.length @@ -552,13 +736,17 @@ class ImageVideoControlDataset(Dataset): if data_type_local != data_type: raise ValueError("data_type_local != data_type") - pixel_values, control_pixel_values, name, data_type = self.get_batch(idx) + pixel_values, control_pixel_values, control_camera_values, name, data_type = self.get_batch(idx) + sample["pixel_values"] = pixel_values sample["control_pixel_values"] = control_pixel_values sample["text"] = name sample["data_type"] = data_type sample["idx"] = idx - + + if self.enable_camera_info: + sample["control_camera_values"] = control_camera_values + if len(sample) > 0: break except Exception as e: diff --git a/easyanimate/models/__init__.py b/easyanimate/models/__init__.py index 5f2988ded0700ca058add72405e60ef15062eb1c..0b9b722df9f8c90cbb957afc482c1e1c4847bf0d 100644 --- a/easyanimate/models/__init__.py +++ b/easyanimate/models/__init__.py @@ -1,8 +1,7 @@ -from .autoencoder_magvit import (AutoencoderKLCogVideoX, AutoencoderKLMagvit, AutoencoderKL) +from .autoencoder_magvit import (AutoencoderKL, AutoencoderKLCogVideoX, + AutoencoderKLMagvit) from .transformer3d import (EasyAnimateTransformer3DModel, - HunyuanTransformer3DModel, - Transformer3DModel) - + HunyuanTransformer3DModel, Transformer3DModel) name_to_transformer3d = { "Transformer3DModel": Transformer3DModel, diff --git a/easyanimate/models/attention.py b/easyanimate/models/attention.py index 9e62da307dd34d1d4cb084cc226272dc46596d09..43b8ceae3495cff19d5a464038151819bba92455 100644 --- a/easyanimate/models/attention.py +++ b/easyanimate/models/attention.py @@ -29,7 +29,7 @@ from diffusers.models.embeddings import (SinusoidalPositionalEmbedding, get_3d_sincos_pos_embed) from diffusers.models.modeling_outputs import Transformer2DModelOutput from diffusers.models.modeling_utils import ModelMixin -from diffusers.models.normalization import (AdaLayerNorm, AdaLayerNormZero, +from diffusers.models.normalization import (AdaLayerNorm, AdaLayerNormZero, CogVideoXLayerNormZero) from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging from diffusers.utils.import_utils import is_xformers_available @@ -38,12 +38,11 @@ from einops import rearrange, repeat from torch import nn from .motion_module import PositionalEncoding, get_motion_module -from .norm import AdaLayerNormShift, FP32LayerNorm, EasyAnimateLayerNormZero +from .norm import AdaLayerNormShift, EasyAnimateLayerNormZero, FP32LayerNorm from .processor import (EasyAnimateAttnProcessor2_0, + EasyAnimateSWAttnProcessor2_0, LazyKVCompressionProcessor2_0) - - if is_xformers_available(): import xformers import xformers.ops @@ -1042,7 +1041,9 @@ class EasyAnimateDiTBlock(nn.Module): ff_bias: bool = True, qk_norm: bool = True, after_norm: bool = False, - norm_type: str="fp32_layer_norm" + norm_type: str="fp32_layer_norm", + is_mmdit_block: bool = True, + is_swa: bool = False, ): super().__init__() @@ -1051,6 +1052,7 @@ class EasyAnimateDiTBlock(nn.Module): time_embed_dim, dim, norm_elementwise_affine, norm_eps, norm_type=norm_type, bias=True ) + self.is_swa = is_swa self.attn1 = Attention( query_dim=dim, dim_head=attention_head_dim, @@ -1058,17 +1060,20 @@ class EasyAnimateDiTBlock(nn.Module): qk_norm="layer_norm" if qk_norm else None, eps=1e-6, bias=True, - processor=EasyAnimateAttnProcessor2_0(), - ) - self.attn2 = Attention( - query_dim=dim, - dim_head=attention_head_dim, - heads=num_attention_heads, - qk_norm="layer_norm" if qk_norm else None, - eps=1e-6, - bias=True, - processor=EasyAnimateAttnProcessor2_0(), + processor=EasyAnimateAttnProcessor2_0() if not is_swa else EasyAnimateSWAttnProcessor2_0(), ) + if is_mmdit_block: + self.attn2 = Attention( + query_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + qk_norm="layer_norm" if qk_norm else None, + eps=1e-6, + bias=True, + processor=EasyAnimateAttnProcessor2_0() if not is_swa else EasyAnimateSWAttnProcessor2_0(), + ) + else: + self.attn2 = None # FFN Part self.norm2 = EasyAnimateLayerNormZero( @@ -1082,14 +1087,18 @@ class EasyAnimateDiTBlock(nn.Module): inner_dim=ff_inner_dim, bias=ff_bias, ) - self.txt_ff = FeedForward( - dim, - dropout=dropout, - activation_fn=activation_fn, - final_dropout=final_dropout, - inner_dim=ff_inner_dim, - bias=ff_bias, - ) + if is_mmdit_block: + self.txt_ff = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + final_dropout=final_dropout, + inner_dim=ff_inner_dim, + bias=ff_bias, + ) + else: + self.txt_ff = None + if after_norm: self.norm3 = FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) else: @@ -1101,6 +1110,9 @@ class EasyAnimateDiTBlock(nn.Module): encoder_hidden_states: torch.Tensor, temb: torch.Tensor, image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + num_frames = None, + height = None, + width = None ) -> torch.Tensor: # Norm norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1( @@ -1108,12 +1120,23 @@ class EasyAnimateDiTBlock(nn.Module): ) # Attn - attn_hidden_states, attn_encoder_hidden_states = self.attn1( - hidden_states=norm_hidden_states, - encoder_hidden_states=norm_encoder_hidden_states, - image_rotary_emb=image_rotary_emb, - attn2=self.attn2, - ) + if self.is_swa: + attn_hidden_states, attn_encoder_hidden_states = self.attn1( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + attn2=self.attn2, + num_frames=num_frames, + height=height, + width=width, + ) + else: + attn_hidden_states, attn_encoder_hidden_states = self.attn1( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + image_rotary_emb=image_rotary_emb, + attn2=self.attn2 + ) hidden_states = hidden_states + gate_msa * attn_hidden_states encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states @@ -1125,10 +1148,16 @@ class EasyAnimateDiTBlock(nn.Module): # FFN if self.norm3 is not None: norm_hidden_states = self.norm3(self.ff(norm_hidden_states)) - norm_encoder_hidden_states = self.norm3(self.txt_ff(norm_encoder_hidden_states)) + if self.txt_ff is not None: + norm_encoder_hidden_states = self.norm3(self.txt_ff(norm_encoder_hidden_states)) + else: + norm_encoder_hidden_states = self.norm3(self.ff(norm_encoder_hidden_states)) else: norm_hidden_states = self.ff(norm_hidden_states) - norm_encoder_hidden_states = self.txt_ff(norm_encoder_hidden_states) + if self.txt_ff is not None: + norm_encoder_hidden_states = self.txt_ff(norm_encoder_hidden_states) + else: + norm_encoder_hidden_states = self.ff(norm_encoder_hidden_states) hidden_states = hidden_states + gate_ff * norm_hidden_states encoder_hidden_states = encoder_hidden_states + enc_gate_ff * norm_encoder_hidden_states return hidden_states, encoder_hidden_states \ No newline at end of file diff --git a/easyanimate/models/autoencoder_magvit.py b/easyanimate/models/autoencoder_magvit.py index 62ee173d01feb4f98a0b1abd26b6d84ac18edbde..1dc1c316a8f664b03dbefb6bb495864a8fd6b969 100644 --- a/easyanimate/models/autoencoder_magvit.py +++ b/easyanimate/models/autoencoder_magvit.py @@ -44,6 +44,7 @@ from ..vae.ldm.models.cogvideox_enc_dec import (CogVideoXCausalConv3d, CogVideoXDecoder3D, CogVideoXEncoder3D, CogVideoXSafeConv3d) +from ..vae.ldm.models.omnigen_enc_dec import CausalConv3d from ..vae.ldm.models.omnigen_enc_dec import Decoder as omnigen_Mag_Decoder from ..vae.ldm.models.omnigen_enc_dec import Encoder as omnigen_Mag_Encoder @@ -96,6 +97,7 @@ class AutoencoderKLMagvit(ModelMixin, ConfigMixin, FromOriginalVAEMixin): out_channels: int = 3, ch = 128, ch_mult = [ 1,2,4,4 ], + block_out_channels = [128, 256, 512, 512], use_gc_blocks = None, down_block_types: tuple = None, up_block_types: tuple = None, @@ -109,6 +111,7 @@ class AutoencoderKLMagvit(ModelMixin, ConfigMixin, FromOriginalVAEMixin): latent_channels: int = 4, norm_num_groups: int = 32, scaling_factor: float = 0.1825, + force_upcast: float = True, slice_mag_vae=True, slice_compression_vae=False, cache_compression_vae=False, @@ -130,8 +133,9 @@ class AutoencoderKLMagvit(ModelMixin, ConfigMixin, FromOriginalVAEMixin): in_channels=in_channels, out_channels=latent_channels, down_block_types=down_block_types, - ch = ch, - ch_mult = ch_mult, + ch=ch, + ch_mult=ch_mult, + block_out_channels=block_out_channels, use_gc_blocks=use_gc_blocks, mid_block_type=mid_block_type, mid_block_use_attention=mid_block_use_attention, @@ -154,8 +158,9 @@ class AutoencoderKLMagvit(ModelMixin, ConfigMixin, FromOriginalVAEMixin): in_channels=latent_channels, out_channels=out_channels, up_block_types=up_block_types, - ch = ch, - ch_mult = ch_mult, + ch=ch, + ch_mult=ch_mult, + block_out_channels=block_out_channels, use_gc_blocks=use_gc_blocks, mid_block_type=mid_block_type, mid_block_use_attention=mid_block_use_attention, @@ -196,81 +201,10 @@ class AutoencoderKLMagvit(ModelMixin, ConfigMixin, FromOriginalVAEMixin): if isinstance(module, (omnigen_Mag_Encoder, omnigen_Mag_Decoder)): module.gradient_checkpointing = value - @property - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors - def attn_processors(self) -> Dict[str, AttentionProcessor]: - r""" - Returns: - `dict` of attention processors: A dictionary containing all attention processors used in the model with - indexed by its weight name. - """ - # set recursively - processors = {} - - def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): - if hasattr(module, "get_processor"): - processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) - - for sub_name, child in module.named_children(): - fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) - - return processors - - for name, module in self.named_children(): - fn_recursive_add_processors(name, module, processors) - - return processors - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): - r""" - Sets the attention processor to use to compute attention. - - Parameters: - processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): - The instantiated processor class or a dictionary of processor classes that will be set as the processor - for **all** `Attention` layers. - - If `processor` is a dict, the key needs to define the path to the corresponding cross attention - processor. This is strongly recommended when setting trainable attention processors. - - """ - count = len(self.attn_processors.keys()) - - if isinstance(processor, dict) and len(processor) != count: - raise ValueError( - f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" - f" number of attention layers: {count}. Please make sure to pass {count} processor classes." - ) - - def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): - if hasattr(module, "set_processor"): - if not isinstance(processor, dict): - module.set_processor(processor) - else: - module.set_processor(processor.pop(f"{name}.processor")) - - for sub_name, child in module.named_children(): - fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) - - for name, module in self.named_children(): - fn_recursive_attn_processor(name, module, processor) - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor - def set_default_attn_processor(self): - """ - Disables custom attention processors and sets the default attention implementation. - """ - if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): - processor = AttnAddedKVProcessor() - elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): - processor = AttnProcessor() - else: - raise ValueError( - f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" - ) - - self.set_attn_processor(processor) + def _clear_conv_cache(self): + for name, module in self.named_modules(): + if isinstance(module, CausalConv3d): + module._clear_conv_cache() @apply_forward_hook def encode( @@ -308,6 +242,7 @@ class AutoencoderKLMagvit(ModelMixin, ConfigMixin, FromOriginalVAEMixin): moments = self.quant_conv(h) posterior = DiagonalGaussianDistribution(moments) + self._clear_conv_cache() if not return_dict: return (posterior,) @@ -355,6 +290,7 @@ class AutoencoderKLMagvit(ModelMixin, ConfigMixin, FromOriginalVAEMixin): else: decoded = self._decode(z).sample + self._clear_conv_cache() if not return_dict: return (decoded,) @@ -519,44 +455,6 @@ class AutoencoderKLMagvit(ModelMixin, ConfigMixin, FromOriginalVAEMixin): return DecoderOutput(sample=dec) - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections - def fuse_qkv_projections(self): - """ - Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, - key, value) are fused. For cross-attention modules, key and value projection matrices are fused. - - - - This API is 🧪 experimental. - - - """ - self.original_attn_processors = None - - for _, attn_processor in self.attn_processors.items(): - if "Added" in str(attn_processor.__class__.__name__): - raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") - - self.original_attn_processors = self.attn_processors - - for module in self.modules(): - if isinstance(module, Attention): - module.fuse_projections(fuse=True) - - # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections - def unfuse_qkv_projections(self): - """Disables the fused QKV projection if enabled. - - - - This API is 🧪 experimental. - - - - """ - if self.original_attn_processors is not None: - self.set_attn_processor(self.original_attn_processors) - @classmethod def from_pretrained(cls, pretrained_model_path, subfolder=None, **vae_additional_kwargs): import json diff --git a/easyanimate/models/embeddings.py b/easyanimate/models/embeddings.py index dddef56856bc11dd7ea4803b264224f9585d473f..8e7531ebcc435b547eebea63769a6e9ebdd91477 100644 --- a/easyanimate/models/embeddings.py +++ b/easyanimate/models/embeddings.py @@ -4,8 +4,9 @@ from typing import Optional import numpy as np import torch import torch.nn.functional as F -from diffusers.models.embeddings import (PixArtAlphaTextProjection, get_timestep_embedding, - TimestepEmbedding, Timesteps) +from diffusers.models.embeddings import (PixArtAlphaTextProjection, + TimestepEmbedding, Timesteps, + get_timestep_embedding) from einops import rearrange from torch import nn diff --git a/easyanimate/models/norm.py b/easyanimate/models/norm.py index 9bb6dc0a149286a7170220fead19619cbc4cac6f..cb9a814618c7465dcd3b9b660d6d0f8d28aae3e3 100644 --- a/easyanimate/models/norm.py +++ b/easyanimate/models/norm.py @@ -25,6 +25,22 @@ class FP32LayerNorm(nn.LayerNorm): inputs.float(), self.normalized_shape, None, None, self.eps ).to(origin_dtype) +class EasyAnimateRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module): """ For PixArt-Alpha. diff --git a/easyanimate/models/processor.py b/easyanimate/models/processor.py index 3f9224085976b068aa96f194d02363024532c097..0cea72eb0a1209ee93b5c5f32147d9fe0cbe6075 100644 --- a/easyanimate/models/processor.py +++ b/easyanimate/models/processor.py @@ -310,3 +310,149 @@ class EasyAnimateAttnProcessor2_0: hidden_states = attn.to_out[1](hidden_states) encoder_hidden_states = attn2.to_out[1](encoder_hidden_states) return hidden_states, encoder_hidden_states + +try: + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import pad_input, unpad_input +except: + print("Flash Attention is not installed. Please install with `pip install flash-attn`, if you want to use SWA.") + +class EasyAnimateSWAttnProcessor2_0: + def __init__(self, window_size=1024): + self.window_size = window_size + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + num_frames: int = None, + height: int = None, + width: int = None, + attn2: Attention = None, + ) -> torch.Tensor: + text_seq_length = encoder_hidden_states.size(1) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attn2 is None: + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + if attn2 is not None: + query_txt = attn2.to_q(encoder_hidden_states) + key_txt = attn2.to_k(encoder_hidden_states) + value_txt = attn2.to_v(encoder_hidden_states) + + inner_dim = key_txt.shape[-1] + head_dim = inner_dim // attn.heads + + query_txt = query_txt.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key_txt = key_txt.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value_txt = value_txt.view(batch_size, -1, attn.heads, head_dim) + + if attn2.norm_q is not None: + query_txt = attn2.norm_q(query_txt) + if attn2.norm_k is not None: + key_txt = attn2.norm_k(key_txt) + + query = torch.cat([query_txt, query], dim=2) + key = torch.cat([key_txt, key], dim=2) + value = torch.cat([value_txt, value], dim=1) + + # Apply RoPE if needed + if image_rotary_emb is not None: + query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb) + if not attn.is_cross_attention: + key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb) + + query = query.transpose(1, 2).to(value) + key = key.transpose(1, 2).to(value) + interval = max((query.size(1) - text_seq_length) // (self.window_size - text_seq_length), 1) + + cross_key = torch.cat([key[:, :text_seq_length], key[:, text_seq_length::interval]], dim=1) + cross_val = torch.cat([value[:, :text_seq_length], value[:, text_seq_length::interval]], dim=1) + cross_hidden_states = flash_attn_func(query, cross_key, cross_val, dropout_p=0.0, causal=False) + + # Split and rearrange to six directions + querys = torch.tensor_split(query[:, text_seq_length:], 6, 2) + keys = torch.tensor_split(key[:, text_seq_length:], 6, 2) + values = torch.tensor_split(value[:, text_seq_length:], 6, 2) + + new_querys = [querys[0]] + new_keys = [keys[0]] + new_values = [values[0]] + for index, mode in enumerate( + [ + "bs (f h w) hn hd -> bs (f w h) hn hd", + "bs (f h w) hn hd -> bs (h f w) hn hd", + "bs (f h w) hn hd -> bs (h w f) hn hd", + "bs (f h w) hn hd -> bs (w f h) hn hd", + "bs (f h w) hn hd -> bs (w h f) hn hd" + ] + ): + new_querys.append(rearrange(querys[index + 1], mode, f=num_frames, h=height, w=width)) + new_keys.append(rearrange(keys[index + 1], mode, f=num_frames, h=height, w=width)) + new_values.append(rearrange(values[index + 1], mode, f=num_frames, h=height, w=width)) + query = torch.cat(new_querys, dim=2) + key = torch.cat(new_keys, dim=2) + value = torch.cat(new_values, dim=2) + + # apply attention + hidden_states = flash_attn_func(query, key, value, dropout_p=0.0, causal=False, window_size=(self.window_size, self.window_size)) + + hidden_states = torch.tensor_split(hidden_states, 6, 2) + new_hidden_states = [hidden_states[0]] + for index, mode in enumerate( + [ + "bs (f w h) hn hd -> bs (f h w) hn hd", + "bs (h f w) hn hd -> bs (f h w) hn hd", + "bs (h w f) hn hd -> bs (f h w) hn hd", + "bs (w f h) hn hd -> bs (f h w) hn hd", + "bs (w h f) hn hd -> bs (f h w) hn hd" + ] + ): + new_hidden_states.append(rearrange(hidden_states[index + 1], mode, f=num_frames, h=height, w=width)) + hidden_states = torch.cat([cross_hidden_states[:, :text_seq_length], torch.cat(new_hidden_states, dim=2)], dim=1) + cross_hidden_states + + hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim) + + if attn2 is None: + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + encoder_hidden_states, hidden_states = hidden_states.split( + [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1 + ) + else: + encoder_hidden_states, hidden_states = hidden_states.split( + [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1 + ) + # linear proj + hidden_states = attn.to_out[0](hidden_states) + encoder_hidden_states = attn2.to_out[0](encoder_hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + encoder_hidden_states = attn2.to_out[1](encoder_hidden_states) + return hidden_states, encoder_hidden_states \ No newline at end of file diff --git a/easyanimate/models/transformer3d.py b/easyanimate/models/transformer3d.py index eff2760667b2e8b672a25d6891bdf20428e342e5..2fa518e504b33b99dc88fb5453dc0aa8e36bfa96 100644 --- a/easyanimate/models/transformer3d.py +++ b/easyanimate/models/transformer3d.py @@ -39,8 +39,9 @@ from torch import nn from .attention import (EasyAnimateDiTBlock, HunyuanDiTBlock, SelfAttentionTemporalTransformerBlock, TemporalTransformerBlock, zero_module) -from .embeddings import HunyuanCombinedTimestepTextSizeStyleEmbedding, TimePositionalEncoding -from .norm import AdaLayerNormSingle +from .embeddings import (HunyuanCombinedTimestepTextSizeStyleEmbedding, + TimePositionalEncoding) +from .norm import AdaLayerNormSingle, EasyAnimateRMSNorm from .patch import (CasualPatchEmbed3D, PatchEmbed3D, PatchEmbedF3D, TemporalUpsampler3D, UnPatch1D) from .resampler import Resampler @@ -142,6 +143,7 @@ class Transformer3DModel(ModelMixin, ConfigMixin): norm_eps: float = 1e-5, attention_type: str = "default", caption_channels: int = None, + n_query=8, # block type basic_block_type: str = "motionmodule", # enable_uvit @@ -168,6 +170,8 @@ class Transformer3DModel(ModelMixin, ConfigMixin): after_norm = False, resize_inpaint_mask_directly: bool = False, enable_clip_in_inpaint: bool = True, + position_of_clip_embedding: str = "head", + enable_zero_in_inpaint: bool = False, enable_text_attention_mask: bool = True, add_noise_in_inpaint_model: bool = False, ): @@ -192,6 +196,7 @@ class Transformer3DModel(ModelMixin, ConfigMixin): self.time_patch_size = self.patch_size if time_patch_size is None else time_patch_size interpolation_scale = self.config.sample_size // 64 # => 64 (= 512 pixart) has interpolation scale 1 interpolation_scale = max(interpolation_scale, 1) + self.n_query = n_query if self.casual_3d: self.pos_embed = CasualPatchEmbed3D( @@ -397,16 +402,22 @@ class Transformer3DModel(ModelMixin, ConfigMixin): def forward( self, hidden_states: torch.Tensor, + timestep: Optional[torch.LongTensor] = None, + timestep_cond = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + text_embedding_mask: Optional[torch.Tensor] = None, + encoder_hidden_states_t5: Optional[torch.Tensor] = None, + text_embedding_mask_t5: Optional[torch.Tensor] = None, + image_meta_size = None, + style = None, + image_rotary_emb: Optional[torch.Tensor] = None, inpaint_latents: torch.Tensor = None, control_latents: torch.Tensor = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - clip_encoder_hidden_states: Optional[torch.Tensor] = None, - timestep: Optional[torch.LongTensor] = None, added_cond_kwargs: Dict[str, torch.Tensor] = None, class_labels: Optional[torch.LongTensor] = None, cross_attention_kwargs: Dict[str, Any] = None, attention_mask: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, + clip_encoder_hidden_states: Optional[torch.Tensor] = None, clip_attention_mask: Optional[torch.Tensor] = None, return_dict: bool = True, ): @@ -432,7 +443,7 @@ class Transformer3DModel(ModelMixin, ConfigMixin): An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large negative values to the attention scores corresponding to "discard" tokens. - encoder_attention_mask ( `torch.Tensor`, *optional*): + text_embedding_mask ( `torch.Tensor`, *optional*): Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: * Mask `(batch, sequence_length)` True = keep, False = discard. @@ -466,11 +477,12 @@ class Transformer3DModel(ModelMixin, ConfigMixin): attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 attention_mask = attention_mask.unsqueeze(1) + text_embedding_mask = text_embedding_mask.squeeze(1) if clip_attention_mask is not None: - encoder_attention_mask = torch.cat([encoder_attention_mask, clip_attention_mask], dim=1) + text_embedding_mask = torch.cat([text_embedding_mask, clip_attention_mask], dim=1) # convert encoder_attention_mask to a bias the same way we do for attention_mask - if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: - encoder_attention_mask = (1 - encoder_attention_mask.to(encoder_hidden_states.dtype)) * -10000.0 + if text_embedding_mask is not None and text_embedding_mask.ndim == 2: + encoder_attention_mask = (1 - text_embedding_mask.to(encoder_hidden_states.dtype)) * -10000.0 encoder_attention_mask = encoder_attention_mask.unsqueeze(1) if inpaint_latents is not None: @@ -637,7 +649,10 @@ class Transformer3DModel(ModelMixin, ConfigMixin): return Transformer3DModelOutput(sample=output) @classmethod - def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, patch_size=2, transformer_additional_kwargs={}): + def from_pretrained_2d( + cls, pretrained_model_path, subfolder=None, patch_size=2, transformer_additional_kwargs={}, + low_cpu_mem_usage=False, torch_dtype=torch.bfloat16 + ): if subfolder is not None: pretrained_model_path = os.path.join(pretrained_model_path, subfolder) print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...") @@ -649,16 +664,73 @@ class Transformer3DModel(ModelMixin, ConfigMixin): config = json.load(f) from diffusers.utils import WEIGHTS_NAME - model = cls.from_config(config, **transformer_additional_kwargs) model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME) model_file_safetensors = model_file.replace(".bin", ".safetensors") - if os.path.exists(model_file_safetensors): + + if low_cpu_mem_usage: + try: + import re + + from diffusers.models.modeling_utils import \ + load_model_dict_into_meta + from diffusers.utils import is_accelerate_available + if is_accelerate_available(): + import accelerate + + # Instantiate model with empty weights + with accelerate.init_empty_weights(): + model = cls.from_config(config, **transformer_additional_kwargs) + + param_device = "cpu" + from safetensors.torch import load_file, safe_open + state_dict = load_file(model_file_safetensors) + model._convert_deprecated_attention_blocks(state_dict) + # move the params from meta device to cpu + missing_keys = set(model.state_dict().keys()) - set(state_dict.keys()) + if len(missing_keys) > 0: + raise ValueError( + f"Cannot load {cls} from {pretrained_model_path} because the following keys are" + f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass" + " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize" + " those weights or else make sure your checkpoint file is correct." + ) + + unexpected_keys = load_model_dict_into_meta( + model, + state_dict, + device=param_device, + dtype=torch_dtype, + model_name_or_path=pretrained_model_path, + ) + + if cls._keys_to_ignore_on_load_unexpected is not None: + for pat in cls._keys_to_ignore_on_load_unexpected: + unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] + + if len(unexpected_keys) > 0: + print( + f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}" + ) + return model + except Exception as e: + print( + f"The low_cpu_mem_usage mode is not work because {e}. Use low_cpu_mem_usage=False instead." + ) + + model = cls.from_config(config, **transformer_additional_kwargs) + if os.path.exists(model_file): + state_dict = torch.load(model_file, map_location="cpu") + elif os.path.exists(model_file_safetensors): from safetensors.torch import load_file, safe_open state_dict = load_file(model_file_safetensors) else: - if not os.path.isfile(model_file): - raise RuntimeError(f"{model_file} does not exist") - state_dict = torch.load(model_file, map_location="cpu") + from safetensors.torch import load_file, safe_open + model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors")) + state_dict = {} + for model_file_safetensors in model_files_safetensors: + _state_dict = load_file(model_file_safetensors) + for key in _state_dict: + state_dict[key] = _state_dict[key] if model.state_dict()['pos_embed.proj.weight'].size() != state_dict['pos_embed.proj.weight'].size(): new_shape = model.state_dict()['pos_embed.proj.weight'].size() @@ -692,6 +764,7 @@ class Transformer3DModel(ModelMixin, ConfigMixin): params = [p.numel() if "attn_temporal." in n else 0 for n, p in model.named_parameters()] print(f"### Attn temporal Parameters: {sum(params) / 1e6} M") + model = model.to(torch_dtype) return model class HunyuanTransformer3DModel(ModelMixin, ConfigMixin): @@ -769,6 +842,7 @@ class HunyuanTransformer3DModel(ModelMixin, ConfigMixin): after_norm = False, resize_inpaint_mask_directly: bool = False, enable_clip_in_inpaint: bool = True, + position_of_clip_embedding: str = "full", enable_text_attention_mask: bool = True, add_noise_in_inpaint_model: bool = False, ): @@ -909,6 +983,7 @@ class HunyuanTransformer3DModel(ModelMixin, ConfigMixin): control_latents: torch.Tensor = None, clip_encoder_hidden_states: Optional[torch.Tensor]=None, clip_attention_mask: Optional[torch.Tensor]=None, + added_cond_kwargs: Dict[str, torch.Tensor] = None, return_dict=True, ): """ @@ -1085,7 +1160,10 @@ class HunyuanTransformer3DModel(ModelMixin, ConfigMixin): return Transformer2DModelOutput(sample=output) @classmethod - def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, patch_size=2, transformer_additional_kwargs={}): + def from_pretrained_2d( + cls, pretrained_model_path, subfolder=None, transformer_additional_kwargs={}, + low_cpu_mem_usage=False, torch_dtype=torch.bfloat16 + ): if subfolder is not None: pretrained_model_path = os.path.join(pretrained_model_path, subfolder) print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...") @@ -1097,16 +1175,73 @@ class HunyuanTransformer3DModel(ModelMixin, ConfigMixin): config = json.load(f) from diffusers.utils import WEIGHTS_NAME - model = cls.from_config(config, **transformer_additional_kwargs) model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME) model_file_safetensors = model_file.replace(".bin", ".safetensors") - if os.path.exists(model_file_safetensors): + + if low_cpu_mem_usage: + try: + import re + + from diffusers.models.modeling_utils import \ + load_model_dict_into_meta + from diffusers.utils import is_accelerate_available + if is_accelerate_available(): + import accelerate + + # Instantiate model with empty weights + with accelerate.init_empty_weights(): + model = cls.from_config(config, **transformer_additional_kwargs) + + param_device = "cpu" + from safetensors.torch import load_file, safe_open + state_dict = load_file(model_file_safetensors) + model._convert_deprecated_attention_blocks(state_dict) + # move the params from meta device to cpu + missing_keys = set(model.state_dict().keys()) - set(state_dict.keys()) + if len(missing_keys) > 0: + raise ValueError( + f"Cannot load {cls} from {pretrained_model_path} because the following keys are" + f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass" + " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize" + " those weights or else make sure your checkpoint file is correct." + ) + + unexpected_keys = load_model_dict_into_meta( + model, + state_dict, + device=param_device, + dtype=torch_dtype, + model_name_or_path=pretrained_model_path, + ) + + if cls._keys_to_ignore_on_load_unexpected is not None: + for pat in cls._keys_to_ignore_on_load_unexpected: + unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] + + if len(unexpected_keys) > 0: + print( + f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}" + ) + return model + except Exception as e: + print( + f"The low_cpu_mem_usage mode is not work because {e}. Use low_cpu_mem_usage=False instead." + ) + + model = cls.from_config(config, **transformer_additional_kwargs) + if os.path.exists(model_file): + state_dict = torch.load(model_file, map_location="cpu") + elif os.path.exists(model_file_safetensors): from safetensors.torch import load_file, safe_open state_dict = load_file(model_file_safetensors) else: - if not os.path.isfile(model_file): - raise RuntimeError(f"{model_file} does not exist") - state_dict = torch.load(model_file, map_location="cpu") + from safetensors.torch import load_file, safe_open + model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors")) + state_dict = {} + for model_file_safetensors in model_files_safetensors: + _state_dict = load_file(model_file_safetensors) + for key in _state_dict: + state_dict[key] = _state_dict[key] if model.state_dict()['pos_embed.proj.weight'].size() != state_dict['pos_embed.proj.weight'].size(): new_shape = model.state_dict()['pos_embed.proj.weight'].size() @@ -1156,6 +1291,7 @@ class HunyuanTransformer3DModel(ModelMixin, ConfigMixin): params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()] print(f"### attn1 Parameters: {sum(params) / 1e6} M") + model = model.to(torch_dtype) return model class EasyAnimateTransformer3DModel(ModelMixin, ConfigMixin): @@ -1178,8 +1314,11 @@ class EasyAnimateTransformer3DModel(ModelMixin, ConfigMixin): timestep_activation_fn: str = "silu", freq_shift: int = 0, num_layers: int = 30, + mmdit_layers: int = 10000, + swa_layers: list = None, dropout: float = 0.0, time_embed_dim: int = 512, + add_norm_text_encoder: bool = False, text_embed_dim: int = 4096, text_embed_dim_t5: int = 4096, norm_eps: float = 1e-5, @@ -1191,8 +1330,10 @@ class EasyAnimateTransformer3DModel(ModelMixin, ConfigMixin): after_norm = False, resize_inpaint_mask_directly: bool = False, enable_clip_in_inpaint: bool = True, + position_of_clip_embedding: str = "full", enable_text_attention_mask: bool = True, add_noise_in_inpaint_model: bool = False, + add_ref_latent_in_control_model: bool = False, ): super().__init__() self.num_heads = num_attention_heads @@ -1211,8 +1352,20 @@ class EasyAnimateTransformer3DModel(ModelMixin, ConfigMixin): self.proj = nn.Conv2d( in_channels, self.inner_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=True ) - self.text_proj = nn.Linear(text_embed_dim, self.inner_dim) - self.text_proj_t5 = nn.Linear(text_embed_dim_t5, self.inner_dim) + if not add_norm_text_encoder: + self.text_proj = nn.Linear(text_embed_dim, self.inner_dim) + if text_embed_dim_t5 is not None: + self.text_proj_t5 = nn.Linear(text_embed_dim_t5, self.inner_dim) + else: + self.text_proj = nn.Sequential( + EasyAnimateRMSNorm(text_embed_dim), + nn.Linear(text_embed_dim, self.inner_dim) + ) + if text_embed_dim_t5 is not None: + self.text_proj_t5 = nn.Sequential( + EasyAnimateRMSNorm(text_embed_dim), + nn.Linear(text_embed_dim_t5, self.inner_dim) + ) if ref_channels is not None: self.ref_proj = nn.Conv2d( @@ -1224,23 +1377,45 @@ class EasyAnimateTransformer3DModel(ModelMixin, ConfigMixin): if clip_channels is not None: self.clip_proj = nn.Linear(clip_channels, self.inner_dim) - - self.transformer_blocks = nn.ModuleList( - [ - EasyAnimateDiTBlock( - dim=self.inner_dim, - num_attention_heads=num_attention_heads, - attention_head_dim=attention_head_dim, - time_embed_dim=time_embed_dim, - dropout=dropout, - activation_fn=activation_fn, - norm_elementwise_affine=norm_elementwise_affine, - norm_eps=norm_eps, - after_norm=after_norm - ) - for _ in range(num_layers) - ] - ) + + self.swa_layers = swa_layers + if swa_layers is not None: + self.transformer_blocks = nn.ModuleList( + [ + EasyAnimateDiTBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + time_embed_dim=time_embed_dim, + dropout=dropout, + activation_fn=activation_fn, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + after_norm=after_norm, + is_mmdit_block=True if index < mmdit_layers else False, + is_swa=True if index in swa_layers else False, + ) + for index in range(num_layers) + ] + ) + else: + self.transformer_blocks = nn.ModuleList( + [ + EasyAnimateDiTBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + time_embed_dim=time_embed_dim, + dropout=dropout, + activation_fn=activation_fn, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + after_norm=after_norm, + is_mmdit_block=True if _ < mmdit_layers else False, + ) + for _ in range(num_layers) + ] + ) self.norm_final = nn.LayerNorm(self.inner_dim, norm_eps, norm_elementwise_affine) # 5. Output blocks @@ -1275,6 +1450,7 @@ class EasyAnimateTransformer3DModel(ModelMixin, ConfigMixin): ref_latents: Optional[torch.Tensor] = None, clip_encoder_hidden_states: Optional[torch.Tensor] = None, clip_attention_mask: Optional[torch.Tensor] = None, + added_cond_kwargs: Dict[str, torch.Tensor] = None, return_dict=True, ): batch_size, channels, video_length, height, width = hidden_states.size() @@ -1343,6 +1519,9 @@ class EasyAnimateTransformer3DModel(ModelMixin, ConfigMixin): encoder_hidden_states, temb, image_rotary_emb, + video_length, + height // self.patch_size, + width // self.patch_size, **ckpt_kwargs, ) else: @@ -1351,6 +1530,9 @@ class EasyAnimateTransformer3DModel(ModelMixin, ConfigMixin): encoder_hidden_states=encoder_hidden_states, temb=temb, image_rotary_emb=image_rotary_emb, + num_frames=video_length, + height=height // self.patch_size, + width=width // self.patch_size ) hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) @@ -1371,7 +1553,10 @@ class EasyAnimateTransformer3DModel(ModelMixin, ConfigMixin): return Transformer2DModelOutput(sample=output) @classmethod - def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, transformer_additional_kwargs={}): + def from_pretrained_2d( + cls, pretrained_model_path, subfolder=None, transformer_additional_kwargs={}, + low_cpu_mem_usage=False, torch_dtype=torch.bfloat16 + ): if subfolder is not None: pretrained_model_path = os.path.join(pretrained_model_path, subfolder) print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...") @@ -1383,9 +1568,60 @@ class EasyAnimateTransformer3DModel(ModelMixin, ConfigMixin): config = json.load(f) from diffusers.utils import WEIGHTS_NAME - model = cls.from_config(config, **transformer_additional_kwargs) model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME) model_file_safetensors = model_file.replace(".bin", ".safetensors") + + if low_cpu_mem_usage: + try: + import re + + from diffusers.models.modeling_utils import \ + load_model_dict_into_meta + from diffusers.utils import is_accelerate_available + if is_accelerate_available(): + import accelerate + + # Instantiate model with empty weights + with accelerate.init_empty_weights(): + model = cls.from_config(config, **transformer_additional_kwargs) + + param_device = "cpu" + from safetensors.torch import load_file, safe_open + state_dict = load_file(model_file_safetensors) + model._convert_deprecated_attention_blocks(state_dict) + # move the params from meta device to cpu + missing_keys = set(model.state_dict().keys()) - set(state_dict.keys()) + if len(missing_keys) > 0: + raise ValueError( + f"Cannot load {cls} from {pretrained_model_path} because the following keys are" + f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass" + " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize" + " those weights or else make sure your checkpoint file is correct." + ) + + unexpected_keys = load_model_dict_into_meta( + model, + state_dict, + device=param_device, + dtype=torch_dtype, + model_name_or_path=pretrained_model_path, + ) + + if cls._keys_to_ignore_on_load_unexpected is not None: + for pat in cls._keys_to_ignore_on_load_unexpected: + unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] + + if len(unexpected_keys) > 0: + print( + f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}" + ) + return model + except Exception as e: + print( + f"The low_cpu_mem_usage mode is not work because {e}. Use low_cpu_mem_usage=False instead." + ) + + model = cls.from_config(config, **transformer_additional_kwargs) if os.path.exists(model_file): state_dict = torch.load(model_file, map_location="cpu") elif os.path.exists(model_file_safetensors): @@ -1433,4 +1669,5 @@ class EasyAnimateTransformer3DModel(ModelMixin, ConfigMixin): params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()] print(f"### attn1 Parameters: {sum(params) / 1e6} M") + model = model.to(torch_dtype) return model \ No newline at end of file diff --git a/easyanimate/pipeline/pipeline_easyanimate.py b/easyanimate/pipeline/pipeline_easyanimate.py index 9d0ff0091a436b96f116c073176bfc0d15b5c987..79b84f616474b069003484bdf16cdf41036585eb 100644 --- a/easyanimate/pipeline/pipeline_easyanimate.py +++ b/easyanimate/pipeline/pipeline_easyanimate.py @@ -1,4 +1,4 @@ -# Copyright 2023 PixArt-Alpha Authors and The HuggingFace Team. All rights reserved. +# Copyright 2024 EasyAnimate Authors and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,61 +12,113 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy -import html import inspect -import re -import urllib.parse as ul from dataclasses import dataclass -from typing import Callable, List, Optional, Tuple, Union +from typing import Callable, Dict, List, Optional, Tuple, Union import numpy as np import torch -from diffusers import DiffusionPipeline, ImagePipelineOutput +import torch.nn.functional as F +from diffusers import DiffusionPipeline +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback from diffusers.image_processor import VaeImageProcessor -from diffusers.models import AutoencoderKL -from diffusers.schedulers import DPMSolverMultistepScheduler +from diffusers.models import AutoencoderKL, HunyuanDiT2DModel +from diffusers.models.embeddings import (get_2d_rotary_pos_embed, + get_3d_rotary_pos_embed) +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput +from diffusers.pipelines.stable_diffusion.safety_checker import \ + StableDiffusionSafetyChecker +from diffusers.schedulers import DDIMScheduler, FlowMatchEulerDiscreteScheduler from diffusers.utils import (BACKENDS_MAPPING, BaseOutput, deprecate, - is_bs4_available, is_ftfy_available, logging, + is_bs4_available, is_ftfy_available, + is_torch_xla_available, logging, replace_example_docstring) from diffusers.utils.torch_utils import randn_tensor from einops import rearrange +from PIL import Image from tqdm import tqdm -from transformers import T5EncoderModel, T5Tokenizer +from transformers import (BertModel, BertTokenizer, CLIPImageProcessor, + Qwen2Tokenizer, Qwen2VLForConditionalGeneration, + T5EncoderModel, T5Tokenizer) -from ..models.transformer3d import Transformer3DModel +from ..models import AutoencoderKLMagvit, EasyAnimateTransformer3DModel +from .pipeline_easyanimate_inpaint import EasyAnimatePipelineOutput -logger = logging.get_logger(__name__) # pylint: disable=invalid-name +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm -if is_bs4_available(): - from bs4 import BeautifulSoup + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False -if is_ftfy_available(): - import ftfy +logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: - ```py + ```python >>> import torch >>> from diffusers import EasyAnimatePipeline - - >>> # You can replace the checkpoint id with "PixArt-alpha/PixArt-XL-2-512x512" too. - >>> pipe = EasyAnimatePipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16) - >>> # Enable memory optimizations. - >>> pipe.enable_model_cpu_offload() - - >>> prompt = "A small cactus with a happy face in the Sahara desert." - >>> image = pipe(prompt).images[0] + >>> from diffusers.utils import export_to_video + + >>> # Models: "alibaba-pai/EasyAnimateV5.1-12b-zh" or "alibaba-pai/EasyAnimateV5.1-7b-zh" + >>> pipe = EasyAnimatePipeline.from_pretrained("alibaba-pai/EasyAnimateV5.1-7b-zh", torch_dtype=torch.float16).to("cuda") + >>> prompt = ( + ... "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. " + ... "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other " + ... "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, " + ... "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. " + ... "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical " + ... "atmosphere of this unique musical performance." + ... ) + >>> video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).sample[0] + >>> export_to_video(video, "output.mp4", fps=8) ``` """ + +# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid +def get_resize_crop_region_for_grid(src, tgt_width, tgt_height): + tw = tgt_width + th = tgt_height + h, w = src + r = h / w + if r > (th / tw): + resize_height = th + resize_width = int(round(th / h * w)) + else: + resize_width = tw + resize_height = int(round(tw / w * h)) + + crop_top = int(round((th - resize_height) / 2.0)) + crop_left = int(round((tw - resize_width) / 2.0)) + + return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps def retrieve_timesteps( scheduler, num_inference_steps: Optional[int] = None, device: Optional[Union[str, torch.device]] = None, timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, **kwargs, ): """ @@ -77,19 +129,23 @@ def retrieve_timesteps( scheduler (`SchedulerMixin`): The scheduler to get timesteps from. num_inference_steps (`int`): - The number of diffusion steps used when generating samples with a pre-trained model. If used, - `timesteps` must be `None`. + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. timesteps (`List[int]`, *optional*): - Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default - timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` - must be `None`. + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. Returns: `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the second element is the number of inference steps. """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") if timesteps is not None: accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accepts_timesteps: @@ -100,86 +156,113 @@ def retrieve_timesteps( scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) else: scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) timesteps = scheduler.timesteps return timesteps, num_inference_steps -@dataclass -class EasyAnimatePipelineOutput(BaseOutput): - videos: Union[torch.Tensor, np.ndarray] class EasyAnimatePipeline(DiffusionPipeline): r""" - Pipeline for text-to-image generation using PixArt-Alpha. + Pipeline for text-to-video generation using EasyAnimate. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + EasyAnimate uses one text encoder [qwen2 vl](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) in V5.1. + EasyAnimate uses two text encoders: [mT5](https://huggingface.co/google/mt5-base) and [bilingual CLIP](fine-tuned by + HunyuanDiT team) in V5. + Args: - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`T5EncoderModel`]): - Frozen text-encoder. PixArt-Alpha uses - [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the - [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant. - tokenizer (`T5Tokenizer`): - Tokenizer of class - [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). - transformer ([`Transformer3DModel`]): - A text conditioned `Transformer3DModel` to denoise the encoded image latents. - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLMagvit`]): + Variational Auto-Encoder (VAE) Model to encode and decode video to and from latent representations. + text_encoder (Optional[`~transformers.Qwen2VLForConditionalGeneration`, `~transformers.BertModel`]): + EasyAnimate uses [qwen2 vl](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) in V5.1. + EasyAnimate uses [bilingual CLIP](https://huggingface.co/Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers) in V5. + tokenizer (Optional[`~transformers.Qwen2Tokenizer`, `~transformers.BertTokenizer`]): + A `Qwen2Tokenizer` or `BertTokenizer` to tokenize text. + transformer ([`EasyAnimateTransformer3DModel`]): + The EasyAnimate model designed by EasyAnimate Team. + text_encoder_2 (`T5EncoderModel`): + EasyAnimate does not use text_encoder_2 in V5.1. + EasyAnimate uses [mT5](https://huggingface.co/google/mt5-base) embedder in V5. + tokenizer_2 (`T5Tokenizer`): + The tokenizer for the mT5 embedder. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with EasyAnimate to denoise the encoded image latents. """ - bad_punct_regex = re.compile( - r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}" - ) # noqa - _optional_components = ["tokenizer", "text_encoder"] - model_cpu_offload_seq = "text_encoder->transformer->vae" + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _optional_components = [ + "text_encoder_2", + "tokenizer_2", + "text_encoder", + "tokenizer", + ] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "prompt_embeds_2", + "negative_prompt_embeds_2", + ] def __init__( self, - tokenizer: T5Tokenizer, - text_encoder: T5EncoderModel, - vae: AutoencoderKL, - transformer: Transformer3DModel, - scheduler: DPMSolverMultistepScheduler, + vae: AutoencoderKLMagvit, + text_encoder: Union[Qwen2VLForConditionalGeneration, BertModel], + tokenizer: Union[Qwen2Tokenizer, BertTokenizer], + text_encoder_2: Optional[Union[T5EncoderModel, Qwen2VLForConditionalGeneration]], + tokenizer_2: Optional[Union[T5Tokenizer, Qwen2Tokenizer]], + transformer: EasyAnimateTransformer3DModel, + scheduler: FlowMatchEulerDiscreteScheduler, ): super().__init__() self.register_modules( - tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + scheduler=scheduler, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) - self.enable_autocast_float8_transformer_flag = False - - # Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py - def mask_text_embeddings(self, emb, mask): - if emb.shape[0] == 1: - keep_index = mask.sum().item() - return emb[:, :, :keep_index, :], keep_index - else: - masked_feature = emb * mask[:, None, :, None] - return masked_feature, emb.shape[2] - # Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt + def enable_sequential_cpu_offload(self, *args, **kwargs): + super().enable_sequential_cpu_offload(*args, **kwargs) + if hasattr(self.transformer, "clip_projection") and self.transformer.clip_projection is not None: + import accelerate + accelerate.hooks.remove_hook_from_module(self.transformer.clip_projection, recurse=True) + self.transformer.clip_projection = self.transformer.clip_projection.to("cuda") + def encode_prompt( self, - prompt: Union[str, List[str]], - do_classifier_free_guidance: bool = True, - negative_prompt: str = "", + prompt: str, + device: torch.device, + dtype: torch.dtype, num_images_per_prompt: int = 1, - device: Optional[torch.device] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - prompt_attention_mask: Optional[torch.FloatTensor] = None, - negative_prompt_attention_mask: Optional[torch.FloatTensor] = None, - clean_caption: bool = False, - max_sequence_length: int = 120, - **kwargs, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[str] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + max_sequence_length: Optional[int] = None, + text_encoder_index: int = 0, + actual_max_sequence_length: int = 256 ): r""" Encodes the prompt into text encoder hidden states. @@ -187,33 +270,46 @@ class EasyAnimatePipeline(DiffusionPipeline): Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded - negative_prompt (`str` or `List[str]`, *optional*): - The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` - instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For - PixArt-Alpha, this should be "". - do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): - whether to use classifier free guidance or not - num_images_per_prompt (`int`, *optional*, defaults to 1): + device: (`torch.device`): + torch device + dtype (`torch.dtype`): + torch dtype + num_images_per_prompt (`int`): number of images that should be generated per prompt - device: (`torch.device`, *optional*): - torch device to place the resulting embeddings on - prompt_embeds (`torch.FloatTensor`, *optional*): + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. For PixArt-Alpha, it's should be the embeddings of the "" - string. - clean_caption (`bool`, defaults to `False`): - If `True`, the function will preprocess and clean the provided caption before encoding. - max_sequence_length (`int`, defaults to 120): Maximum sequence length to use for the prompt. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the prompt. Required when `prompt_embeds` is passed directly. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly. + max_sequence_length (`int`, *optional*): maximum sequence length to use for the prompt. + text_encoder_index (`int`, *optional*): + Index of the text encoder to use. `0` for clip and `1` for T5. """ + tokenizers = [self.tokenizer, self.tokenizer_2] + text_encoders = [self.text_encoder, self.text_encoder_2] - if "mask_feature" in kwargs: - deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version." - deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False) + tokenizer = tokenizers[text_encoder_index] + text_encoder = text_encoders[text_encoder_index] - if device is None: - device = self._execution_device + if max_sequence_length is None: + if text_encoder_index == 0: + max_length = min(self.tokenizer.model_max_length, actual_max_sequence_length) + if text_encoder_index == 1: + max_length = min(self.tokenizer_2.model_max_length, actual_max_sequence_length) + else: + max_length = max_sequence_length if prompt is not None and isinstance(prompt, str): batch_size = 1 @@ -222,74 +318,199 @@ class EasyAnimatePipeline(DiffusionPipeline): else: batch_size = prompt_embeds.shape[0] - # See Section 3.1. of the paper. - max_length = max_sequence_length - if prompt_embeds is None: - prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=max_length, - truncation=True, - add_special_tokens=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): - removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {max_length} tokens: {removed_text}" + if type(tokenizer) in [BertTokenizer, T5Tokenizer]: + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + if text_input_ids.shape[-1] > actual_max_sequence_length: + reprompt = tokenizer.batch_decode(text_input_ids[:, :actual_max_sequence_length], skip_special_tokens=True) + text_inputs = tokenizer( + reprompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + _actual_max_sequence_length = min(tokenizer.model_max_length, actual_max_sequence_length) + removed_text = tokenizer.batch_decode(untruncated_ids[:, _actual_max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {_actual_max_sequence_length} tokens: {removed_text}" + ) + + prompt_attention_mask = text_inputs.attention_mask.to(device) + + if self.transformer.config.enable_text_attention_mask: + prompt_embeds = text_encoder( + text_input_ids.to(device), + attention_mask=prompt_attention_mask, + ) + else: + prompt_embeds = text_encoder( + text_input_ids.to(device) + ) + prompt_embeds = prompt_embeds[0] + prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + else: + if prompt is not None and isinstance(prompt, str): + messages = [ + { + "role": "user", + "content": [{"type": "text", "text": prompt}], + } + ] + else: + messages = [ + { + "role": "user", + "content": [{"type": "text", "text": _prompt}], + } for _prompt in prompt + ] + text = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True ) - prompt_attention_mask = text_inputs.attention_mask - prompt_attention_mask = prompt_attention_mask.to(device) - - prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask) - prompt_embeds = prompt_embeds[0] - - if self.text_encoder is not None: - dtype = self.text_encoder.dtype - elif self.transformer is not None: - dtype = self.transformer.dtype - else: - dtype = None - + text_inputs = tokenizer( + text=[text], + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + padding_side="right", + return_tensors="pt", + ) + text_inputs = text_inputs.to(text_encoder.device) + + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + if self.transformer.config.enable_text_attention_mask: + # Inference: Generation of the output + prompt_embeds = text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_attention_mask, + output_hidden_states=True).hidden_states[-2] + else: + raise ValueError("LLM needs attention_mask") + prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1) - prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + prompt_attention_mask = prompt_attention_mask.to(device=device) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: - uncond_tokens = [negative_prompt] * batch_size - uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) - max_length = prompt_embeds.shape[1] - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_attention_mask=True, - add_special_tokens=True, - return_tensors="pt", - ) - negative_prompt_attention_mask = uncond_input.attention_mask - negative_prompt_attention_mask = negative_prompt_attention_mask.to(device) + if type(tokenizer) in [BertTokenizer, T5Tokenizer]: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + uncond_input_ids = uncond_input.input_ids + if uncond_input_ids.shape[-1] > actual_max_sequence_length: + reuncond_tokens = tokenizer.batch_decode(uncond_input_ids[:, :actual_max_sequence_length], skip_special_tokens=True) + uncond_input = tokenizer( + reuncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + return_tensors="pt", + ) + uncond_input_ids = uncond_input.input_ids + + negative_prompt_attention_mask = uncond_input.attention_mask.to(device) + if self.transformer.config.enable_text_attention_mask: + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + attention_mask=negative_prompt_attention_mask, + ) + else: + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device) + ) + negative_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) + else: + if negative_prompt is not None and isinstance(negative_prompt, str): + messages = [ + { + "role": "user", + "content": [{"type": "text", "text": negative_prompt}], + } + ] + else: + messages = [ + { + "role": "user", + "content": [{"type": "text", "text": _negative_prompt}], + } for _negative_prompt in negative_prompt + ] + text = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) - negative_prompt_embeds = self.text_encoder( - uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask - ) - negative_prompt_embeds = negative_prompt_embeds[0] + text_inputs = tokenizer( + text=[text], + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + padding_side="right", + return_tensors="pt", + ) + text_inputs = text_inputs.to(text_encoder.device) + + text_input_ids = text_inputs.input_ids + negative_prompt_attention_mask = text_inputs.attention_mask + if self.transformer.config.enable_text_attention_mask: + # Inference: Generation of the output + negative_prompt_embeds = text_encoder( + input_ids=text_input_ids, + attention_mask=negative_prompt_attention_mask, + output_hidden_states=True).hidden_states[-2] + else: + raise ValueError("LLM needs attention_mask") + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method @@ -299,14 +520,9 @@ class EasyAnimatePipeline(DiffusionPipeline): negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1) - negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) - else: - negative_prompt_embeds = None - negative_prompt_attention_mask = None - - return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + negative_prompt_attention_mask = negative_prompt_attention_mask.to(device=device) + + return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): @@ -331,20 +547,25 @@ class EasyAnimatePipeline(DiffusionPipeline): prompt, height, width, - negative_prompt, - callback_steps, + negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + prompt_embeds_2=None, + negative_prompt_embeds_2=None, + prompt_attention_mask_2=None, + negative_prompt_attention_mask_2=None, + callback_on_step_end_tensor_inputs=None, ): - if height % 8 != 0 or width % 8 != 0: + if height % 16 != 0 or width % 16 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" ) if prompt is not None and prompt_embeds is not None: @@ -356,14 +577,18 @@ class EasyAnimatePipeline(DiffusionPipeline): raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) + elif prompt is None and prompt_embeds_2 is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined." + ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - if prompt is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if prompt_embeds_2 is not None and prompt_attention_mask_2 is None: + raise ValueError("Must provide `prompt_attention_mask_2` when specifying `prompt_embeds_2`.") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( @@ -371,6 +596,13 @@ class EasyAnimatePipeline(DiffusionPipeline): f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if negative_prompt_embeds_2 is not None and negative_prompt_attention_mask_2 is None: + raise ValueError( + "Must provide `negative_prompt_attention_mask_2` when specifying `negative_prompt_embeds_2`." + ) if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( @@ -378,153 +610,25 @@ class EasyAnimatePipeline(DiffusionPipeline): f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." ) - - # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing - def _text_preprocessing(self, text, clean_caption=False): - if clean_caption and not is_bs4_available(): - logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) - logger.warn("Setting `clean_caption` to False...") - clean_caption = False - - if clean_caption and not is_ftfy_available(): - logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) - logger.warn("Setting `clean_caption` to False...") - clean_caption = False - - if not isinstance(text, (tuple, list)): - text = [text] - - def process(text: str): - if clean_caption: - text = self._clean_caption(text) - text = self._clean_caption(text) - else: - text = text.lower().strip() - return text - - return [process(t) for t in text] - - # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption - def _clean_caption(self, caption): - caption = str(caption) - caption = ul.unquote_plus(caption) - caption = caption.strip().lower() - caption = re.sub("", "person", caption) - # urls: - caption = re.sub( - r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa - "", - caption, - ) # regex for urls - caption = re.sub( - r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa - "", - caption, - ) # regex for urls - # html: - caption = BeautifulSoup(caption, features="html.parser").text - - # @ - caption = re.sub(r"@[\w\d]+\b", "", caption) - - # 31C0—31EF CJK Strokes - # 31F0—31FF Katakana Phonetic Extensions - # 3200—32FF Enclosed CJK Letters and Months - # 3300—33FF CJK Compatibility - # 3400—4DBF CJK Unified Ideographs Extension A - # 4DC0—4DFF Yijing Hexagram Symbols - # 4E00—9FFF CJK Unified Ideographs - caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) - caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) - caption = re.sub(r"[\u3200-\u32ff]+", "", caption) - caption = re.sub(r"[\u3300-\u33ff]+", "", caption) - caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) - caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) - caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) - ####################################################### - - # все виды тире / all types of dash --> "-" - caption = re.sub( - r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa - "-", - caption, - ) - - # кавычки к одному стандарту - caption = re.sub(r"[`´«»“”¨]", '"', caption) - caption = re.sub(r"[‘’]", "'", caption) - - # " - caption = re.sub(r""?", "", caption) - # & - caption = re.sub(r"&", "", caption) - - # ip adresses: - caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) - - # article ids: - caption = re.sub(r"\d:\d\d\s+$", "", caption) - - # \n - caption = re.sub(r"\\n", " ", caption) - - # "#123" - caption = re.sub(r"#\d{1,3}\b", "", caption) - # "#12345.." - caption = re.sub(r"#\d{5,}\b", "", caption) - # "123456.." - caption = re.sub(r"\b\d{6,}\b", "", caption) - # filenames: - caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) - - # - caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" - caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" - - caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT - caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " - - # this-is-my-cute-cat / this_is_my_cute_cat - regex2 = re.compile(r"(?:\-|\_)") - if len(re.findall(regex2, caption)) > 3: - caption = re.sub(regex2, " ", caption) - - caption = ftfy.fix_text(caption) - caption = html.unescape(html.unescape(caption)) - - caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 - caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc - caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 - - caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) - caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) - caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) - caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) - caption = re.sub(r"\bpage\s+\d+\b", "", caption) - - caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... - - caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) - - caption = re.sub(r"\b\s+\:\s+", r": ", caption) - caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) - caption = re.sub(r"\s+", " ", caption) - - caption.strip() - - caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) - caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) - caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) - caption = re.sub(r"^\.\S+$", "", caption) - - return caption.strip() + if prompt_embeds_2 is not None and negative_prompt_embeds_2 is not None: + if prompt_embeds_2.shape != negative_prompt_embeds_2.shape: + raise ValueError( + "`prompt_embeds_2` and `negative_prompt_embeds_2` must have the same shape when passed directly, but" + f" got: `prompt_embeds_2` {prompt_embeds_2.shape} != `negative_prompt_embeds_2`" + f" {negative_prompt_embeds_2.shape}." + ) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None): - if self.vae.quant_conv.weight.ndim==5: - mini_batch_encoder = self.vae.mini_batch_encoder - mini_batch_decoder = self.vae.mini_batch_decoder - shape = (batch_size, num_channels_latents, int(video_length // mini_batch_encoder * mini_batch_decoder) if video_length != 1 else 1, height // self.vae_scale_factor, width // self.vae_scale_factor) + if self.vae.quant_conv is None or self.vae.quant_conv.weight.ndim==5: + if self.vae.cache_mag_vae: + mini_batch_encoder = self.vae.mini_batch_encoder + mini_batch_decoder = self.vae.mini_batch_decoder + shape = (batch_size, num_channels_latents, int((video_length - 1) // mini_batch_encoder * mini_batch_decoder + 1) if video_length != 1 else 1, height // self.vae_scale_factor, width // self.vae_scale_factor) + else: + mini_batch_encoder = self.vae.mini_batch_encoder + mini_batch_decoder = self.vae.mini_batch_decoder + shape = (batch_size, num_channels_latents, int(video_length // mini_batch_encoder * mini_batch_decoder) if video_length != 1 else 1, height // self.vae_scale_factor, width // self.vae_scale_factor) else: shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor) @@ -538,11 +642,12 @@ class EasyAnimatePipeline(DiffusionPipeline): latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: latents = latents.to(device) - + # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma + if hasattr(self.scheduler, "init_noise_sigma"): + latents = latents * self.scheduler.init_noise_sigma return latents - + def smooth_output(self, video, mini_batch_encoder, mini_batch_decoder): if video.size()[2] <= mini_batch_encoder: return video @@ -558,16 +663,17 @@ class EasyAnimatePipeline(DiffusionPipeline): video[:, :, prefix_index_before:-prefix_index_after] = (video[:, :, prefix_index_before:-prefix_index_after] + middle_video) / 2 return video - + def decode_latents(self, latents): video_length = latents.shape[2] latents = 1 / self.vae.config.scaling_factor * latents - if self.vae.quant_conv.weight.ndim==5: + if self.vae.quant_conv is None or self.vae.quant_conv.weight.ndim==5: mini_batch_encoder = self.vae.mini_batch_encoder mini_batch_decoder = self.vae.mini_batch_decoder video = self.vae.decode(latents)[0] video = video.clamp(-1, 1) - video = self.smooth_output(video, mini_batch_encoder, mini_batch_decoder).cpu().clamp(-1, 1) + if not self.vae.cache_compression_vae and not self.vae.cache_mag_vae: + video = self.smooth_output(video, mini_batch_encoder, mini_batch_decoder).cpu().clamp(-1, 1) else: latents = rearrange(latents, "b c f h w -> (b f) c h w") video = [] @@ -580,8 +686,28 @@ class EasyAnimatePipeline(DiffusionPipeline): video = video.cpu().float().numpy() return video - def enable_autocast_float8_transformer(self): - self.enable_autocast_float8_transformer_flag = True + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) @@ -589,103 +715,131 @@ class EasyAnimatePipeline(DiffusionPipeline): self, prompt: Union[str, List[str]] = None, video_length: Optional[int] = None, - negative_prompt: str = "", - num_inference_steps: int = 20, - timesteps: List[int] = None, - guidance_scale: float = 4.5, - num_images_per_prompt: Optional[int] = 1, height: Optional[int] = None, width: Optional[int] = None, - eta: float = 0.0, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: Optional[float] = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - prompt_attention_mask: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_attention_mask: Optional[torch.FloatTensor] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_2: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds_2: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + prompt_attention_mask_2: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_attention_mask_2: Optional[torch.Tensor] = None, output_type: Optional[str] = "latent", return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: int = 1, - clean_caption: bool = True, - max_sequence_length: int = 120, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + guidance_rescale: float = 0.0, + original_size: Optional[Tuple[int, int]] = (1024, 1024), + target_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), comfyui_progressbar: bool = False, - **kwargs, - ) -> Union[EasyAnimatePipelineOutput, Tuple]: - """ - Function invoked when calling the pipeline for generation. - - Args: - prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. - instead. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is - less than `1`). - num_inference_steps (`int`, *optional*, defaults to 100): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - timesteps (`List[int]`, *optional*): - Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` - timesteps are used. Must be in descending order. - guidance_scale (`float`, *optional*, defaults to 7.0): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - height (`int`, *optional*, defaults to self.unet.config.sample_size): - The height in pixels of the generated image. - width (`int`, *optional*, defaults to self.unet.config.sample_size): - The width in pixels of the generated image. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator` or `List[torch.Generator]`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. - latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will ge generated by sampling using the supplied random `generator`. - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. For PixArt-Alpha this negative prompt should be "". If not - provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - clean_caption (`bool`, *optional*, defaults to `True`): - Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to - be installed. If the dependencies are not installed, the embeddings will be created from the raw - prompt. - mask_feature (`bool` defaults to `True`): If set to `True`, the text embeddings will be masked. + timesteps: Optional[List[int]] = None, + ): + r""" + Generates images or video using the EasyAnimate pipeline based on the provided prompts. Examples: + prompt (`str` or `List[str]`, *optional*): + Text prompts to guide the image or video generation. If not provided, use `prompt_embeds` instead. + video_length (`int`, *optional*): + Length of the generated video (in frames). + height (`int`, *optional*): + Height of the generated image in pixels. + width (`int`, *optional*): + Width of the generated image in pixels. + num_inference_steps (`int`, *optional*, defaults to 50): + Number of denoising steps during generation. More steps generally yield higher quality images but slow down inference. + guidance_scale (`float`, *optional*, defaults to 5.0): + Encourages the model to align outputs with prompts. A higher value may decrease image quality. + negative_prompt (`str` or `List[str]`, *optional*): + Prompts indicating what to exclude in generation. If not specified, use `negative_prompt_embeds`. + num_images_per_prompt (`int`, *optional*, defaults to 1): + Number of images to generate for each prompt. + eta (`float`, *optional*, defaults to 0.0): + Applies to DDIM scheduling. Controlled by the eta parameter from the related literature. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A generator to ensure reproducibility in image generation. + latents (`torch.Tensor`, *optional*): + Predefined latent tensors to condition generation. + prompt_embeds (`torch.Tensor`, *optional*): + Text embeddings for the prompts. Overrides prompt string inputs for more flexibility. + prompt_embeds_2 (`torch.Tensor`, *optional*): + Secondary text embeddings to supplement or replace the initial prompt embeddings. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Embeddings for negative prompts. Overrides string inputs if defined. + negative_prompt_embeds_2 (`torch.Tensor`, *optional*): + Secondary embeddings for negative prompts, similar to `negative_prompt_embeds`. + prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the primary prompt embeddings. + prompt_attention_mask_2 (`torch.Tensor`, *optional*): + Attention mask for the secondary prompt embeddings. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for negative prompt embeddings. + negative_prompt_attention_mask_2 (`torch.Tensor`, *optional*): + Attention mask for secondary negative prompt embeddings. + output_type (`str`, *optional*, defaults to "latent"): + Format of the generated output, either as a PIL image or as a NumPy array. + return_dict (`bool`, *optional*, defaults to `True`): + If `True`, returns a structured output. Otherwise returns a simple tuple. + callback_on_step_end (`Callable`, *optional*): + Functions called at the end of each denoising step. + callback_on_step_end_tensor_inputs (`List[str]`, *optional*): + Tensor names to be included in callback function calls. + guidance_rescale (`float`, *optional*, defaults to 0.0): + Adjusts noise levels based on guidance scale. + original_size (`Tuple[int, int]`, *optional*, defaults to `(1024, 1024)`): + Original dimensions of the output. + target_size (`Tuple[int, int]`, *optional*): + Desired output dimensions for calculations. + crops_coords_top_left (`Tuple[int, int]`, *optional*, defaults to `(0, 0)`): + Coordinates for cropping. Returns: - [`~pipelines.ImagePipelineOutput`] or `tuple`: - If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is - returned where the first element is a list with the generated images + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 0. default height and width + height = int((height // 16) * 16) + width = int((width // 16) * 16) + # 1. Check inputs. Raise error if not correct - height = height or self.transformer.config.sample_size * self.vae_scale_factor - width = width or self.transformer.config.sample_size * self.vae_scale_factor + self.check_inputs( + prompt, + height, + width, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + prompt_embeds_2, + negative_prompt_embeds_2, + prompt_attention_mask_2, + negative_prompt_attention_mask_2, + callback_on_step_end_tensor_inputs, + ) + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._interrupt = False - # 2. Default height and width to transformer + # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): @@ -694,136 +848,223 @@ class EasyAnimatePipeline(DiffusionPipeline): batch_size = prompt_embeds.shape[0] device = self._execution_device - - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.text_encoder_2 is not None: + dtype = self.text_encoder_2.dtype + else: + dtype = self.transformer.dtype # 3. Encode input prompt ( prompt_embeds, - prompt_attention_mask, negative_prompt_embeds, + prompt_attention_mask, negative_prompt_attention_mask, ) = self.encode_prompt( - prompt, - do_classifier_free_guidance, - negative_prompt=negative_prompt, - num_images_per_prompt=num_images_per_prompt, + prompt=prompt, device=device, + dtype=dtype, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, prompt_attention_mask=prompt_attention_mask, negative_prompt_attention_mask=negative_prompt_attention_mask, - clean_caption=clean_caption, - max_sequence_length=max_sequence_length, + text_encoder_index=0, ) - if do_classifier_free_guidance: - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) - prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + if self.tokenizer_2 is not None: + ( + prompt_embeds_2, + negative_prompt_embeds_2, + prompt_attention_mask_2, + negative_prompt_attention_mask_2, + ) = self.encode_prompt( + prompt=prompt, + device=device, + dtype=dtype, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds_2, + negative_prompt_embeds=negative_prompt_embeds_2, + prompt_attention_mask=prompt_attention_mask_2, + negative_prompt_attention_mask=negative_prompt_attention_mask_2, + text_encoder_index=1, + ) + else: + prompt_embeds_2 = None + negative_prompt_embeds_2 = None + prompt_attention_mask_2 = None + negative_prompt_attention_mask_2 = None # 4. Prepare timesteps - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler): + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, mu=1) + else: + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + if comfyui_progressbar: + from comfy.utils import ProgressBar + pbar = ProgressBar(num_inference_steps + 1) - # 5. Prepare latents. - latent_channels = self.transformer.config.in_channels + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels latents = self.prepare_latents( batch_size * num_images_per_prompt, - latent_channels, + num_channels_latents, video_length, height, width, - prompt_embeds.dtype, + dtype, device, generator, latents, ) + if comfyui_progressbar: + pbar.update(1) # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - # 6.1 Prepare micro-conditions. + # 7 create image_rotary_emb, style embedding & time ids + grid_height = height // 8 // self.transformer.config.patch_size + grid_width = width // 8 // self.transformer.config.patch_size + if self.transformer.config.get("time_position_encoding_type", "2d_rope") == "3d_rope": + base_size_width = 720 // 8 // self.transformer.config.patch_size + base_size_height = 480 // 8 // self.transformer.config.patch_size + + grid_crops_coords = get_resize_crop_region_for_grid( + (grid_height, grid_width), base_size_width, base_size_height + ) + image_rotary_emb = get_3d_rotary_pos_embed( + self.transformer.config.attention_head_dim, grid_crops_coords, grid_size=(grid_height, grid_width), + temporal_size=latents.size(2), use_real=True, + ) + else: + base_size = 512 // 8 // self.transformer.config.patch_size + grid_crops_coords = get_resize_crop_region_for_grid( + (grid_height, grid_width), base_size, base_size + ) + image_rotary_emb = get_2d_rotary_pos_embed( + self.transformer.config.attention_head_dim, grid_crops_coords, (grid_height, grid_width) + ) + + # Get other hunyuan params + target_size = target_size or (height, width) + add_time_ids = list(original_size + target_size + crops_coords_top_left) + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + style = torch.tensor([0], device=device) + + if self.do_classifier_free_guidance: + add_time_ids = torch.cat([add_time_ids] * 2, dim=0) + style = torch.cat([style] * 2, dim=0) + + # To latents.device + add_time_ids = add_time_ids.to(dtype=dtype, device=device).repeat( + batch_size * num_images_per_prompt, 1 + ) + style = style.to(device=device).repeat(batch_size * num_images_per_prompt) + + # Get other pixart params added_cond_kwargs = {"resolution": None, "aspect_ratio": None} - if self.transformer.config.sample_size == 128: + if self.transformer.config.get("sample_size", 64) == 128: resolution = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1) aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1) - resolution = resolution.to(dtype=prompt_embeds.dtype, device=device) - aspect_ratio = aspect_ratio.to(dtype=prompt_embeds.dtype, device=device) + resolution = resolution.to(dtype=dtype, device=device) + aspect_ratio = aspect_ratio.to(dtype=dtype, device=device) - if do_classifier_free_guidance: + if self.do_classifier_free_guidance: resolution = torch.cat([resolution, resolution], dim=0) aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0) added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio} - torch.cuda.empty_cache() - if self.enable_autocast_float8_transformer_flag: - origin_weight_dtype = self.transformer.dtype - self.transformer = self.transformer.to(torch.float8_e4m3fn) - - # 7. Denoising loop - num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) - if comfyui_progressbar: - from comfy.utils import ProgressBar - pbar = ProgressBar(num_inference_steps) + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask]) + if prompt_embeds_2 is not None: + prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2]) + prompt_attention_mask_2 = torch.cat([negative_prompt_attention_mask_2, prompt_attention_mask_2]) + + # To latents.device + prompt_embeds = prompt_embeds.to(device=device) + prompt_attention_mask = prompt_attention_mask.to(device=device) + if prompt_embeds_2 is not None: + prompt_embeds_2 = prompt_embeds_2.to(device=device) + prompt_attention_mask_2 = prompt_attention_mask_2.to(device=device) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - current_timestep = t - if not torch.is_tensor(current_timestep): - # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can - # This would be a good case for the `match` statement (Python 3.10+) - is_mps = latent_model_input.device.type == "mps" - if isinstance(current_timestep, float): - dtype = torch.float32 if is_mps else torch.float64 - else: - dtype = torch.int32 if is_mps else torch.int64 - current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device) - elif len(current_timestep.shape) == 0: - current_timestep = current_timestep[None].to(latent_model_input.device) - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - current_timestep = current_timestep.expand(latent_model_input.shape[0]) - - # predict noise model_output + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + if hasattr(self.scheduler, "scale_model_input"): + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # expand scalar t to 1-D tensor to match the 1st dim of latent_model_input + t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to( + dtype=latent_model_input.dtype + ) + + # predict the noise residual noise_pred = self.transformer( latent_model_input, + t_expand, encoder_hidden_states=prompt_embeds, - encoder_attention_mask=prompt_attention_mask, - timestep=current_timestep, + text_embedding_mask=prompt_attention_mask, + encoder_hidden_states_t5=prompt_embeds_2, + text_embedding_mask_t5=prompt_attention_mask_2, + image_meta_size=add_time_ids, + style=style, + image_rotary_emb=image_rotary_emb, added_cond_kwargs=added_cond_kwargs, return_dict=False, )[0] + + if noise_pred.size()[1] != self.vae.config.latent_channels: + noise_pred, _ = noise_pred.chunk(2, dim=1) # perform guidance - if do_classifier_free_guidance: + if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - # learned sigma - if self.transformer.config.out_channels // 2 == latent_channels: - noise_pred = noise_pred.chunk(2, dim=1)[0] - else: - noise_pred = noise_pred + if self.do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) - # compute previous image: x_t -> x_t-1 + # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] - # call the callback, if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + prompt_embeds_2 = callback_outputs.pop("prompt_embeds_2", prompt_embeds_2) + negative_prompt_embeds_2 = callback_outputs.pop( + "negative_prompt_embeds_2", negative_prompt_embeds_2 + ) + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() - if callback is not None and i % callback_steps == 0: - step_idx = i // getattr(self.scheduler, "order", 1) - callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() if comfyui_progressbar: pbar.update(1) - if self.enable_autocast_float8_transformer_flag: - self.transformer = self.transformer.to("cpu", origin_weight_dtype) - # Post-processing video = self.decode_latents(latents) @@ -831,7 +1072,10 @@ class EasyAnimatePipeline(DiffusionPipeline): if output_type == "latent": video = torch.from_numpy(video) + # Offload all models + self.maybe_free_model_hooks() + if not return_dict: return video - return EasyAnimatePipelineOutput(videos=video) \ No newline at end of file + return EasyAnimatePipelineOutput(frames=video) \ No newline at end of file diff --git a/easyanimate/pipeline/pipeline_easyanimate_multi_text_encoder_control.py b/easyanimate/pipeline/pipeline_easyanimate_control.py similarity index 62% rename from easyanimate/pipeline/pipeline_easyanimate_multi_text_encoder_control.py rename to easyanimate/pipeline/pipeline_easyanimate_control.py index b23502e8db80260d6824ba7d98e606cb62d252c0..f5ab704657966a745edeaa80276508dbe34df81f 100644 --- a/easyanimate/pipeline/pipeline_easyanimate_multi_text_encoder_control.py +++ b/easyanimate/pipeline/pipeline_easyanimate_control.py @@ -31,7 +31,8 @@ from diffusers.pipelines.pipeline_utils import DiffusionPipeline from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from diffusers.pipelines.stable_diffusion.safety_checker import \ StableDiffusionSafetyChecker -from diffusers.schedulers import DDIMScheduler, DPMSolverMultistepScheduler +from diffusers.schedulers import (DDIMScheduler, DPMSolverMultistepScheduler, + FlowMatchEulerDiscreteScheduler) from diffusers.utils import (BACKENDS_MAPPING, BaseOutput, deprecate, is_bs4_available, is_ftfy_available, is_torch_xla_available, logging, @@ -41,11 +42,12 @@ from einops import rearrange from PIL import Image from tqdm import tqdm from transformers import (BertModel, BertTokenizer, CLIPImageProcessor, - CLIPVisionModelWithProjection, - T5EncoderModel, T5Tokenizer) + CLIPVisionModelWithProjection, Qwen2Tokenizer, + Qwen2VLForConditionalGeneration, T5EncoderModel, + T5Tokenizer) from ..models import AutoencoderKLMagvit, EasyAnimateTransformer3DModel -from .pipeline_easyanimate import EasyAnimatePipelineOutput +from .pipeline_easyanimate_inpaint import EasyAnimatePipelineOutput if is_torch_xla_available(): import torch_xla.core.xla_model as xm @@ -64,6 +66,7 @@ EXAMPLE_DOC_STRING = """ ``` """ +# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid def get_resize_crop_region_for_grid(src, tgt_width, tgt_height): tw = tgt_width th = tgt_height @@ -97,44 +100,140 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): return noise_cfg -class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline): +# Resize mask information in magvit +def resize_mask(mask, latent, process_first_frame_only=True): + latent_size = latent.size() + + if process_first_frame_only: + target_size = list(latent_size[2:]) + target_size[0] = 1 + first_frame_resized = F.interpolate( + mask[:, :, 0:1, :, :], + size=target_size, + mode='trilinear', + align_corners=False + ) + + target_size = list(latent_size[2:]) + target_size[0] = target_size[0] - 1 + if target_size[0] != 0: + remaining_frames_resized = F.interpolate( + mask[:, :, 1:, :, :], + size=target_size, + mode='trilinear', + align_corners=False + ) + resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2) + else: + resized_mask = first_frame_resized + else: + target_size = list(latent_size[2:]) + resized_mask = F.interpolate( + mask, + size=target_size, + mode='trilinear', + align_corners=False + ) + return resized_mask + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class EasyAnimateControlPipeline(DiffusionPipeline): r""" Pipeline for text-to-video generation using EasyAnimate. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + EasyAnimate uses one text encoder [qwen2 vl](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) in V5.1. EasyAnimate uses two text encoders: [mT5](https://huggingface.co/google/mt5-base) and [bilingual CLIP](fine-tuned by - HunyuanDiT team) + HunyuanDiT team) in V5. Args: vae ([`AutoencoderKLMagvit`]): Variational Auto-Encoder (VAE) Model to encode and decode video to and from latent representations. - text_encoder (Optional[`~transformers.BertModel`, `~transformers.CLIPTextModel`]): - Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). - EasyAnimate uses a fine-tuned [bilingual CLIP]. - tokenizer (Optional[`~transformers.BertTokenizer`, `~transformers.CLIPTokenizer`]): - A `BertTokenizer` or `CLIPTokenizer` to tokenize text. + text_encoder (Optional[`~transformers.Qwen2VLForConditionalGeneration`, `~transformers.BertModel`]): + EasyAnimate uses [qwen2 vl](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) in V5.1. + EasyAnimate uses [bilingual CLIP](https://huggingface.co/Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers) in V5. + tokenizer (Optional[`~transformers.Qwen2Tokenizer`, `~transformers.BertTokenizer`]): + A `Qwen2Tokenizer` or `BertTokenizer` to tokenize text. transformer ([`EasyAnimateTransformer3DModel`]): - The EasyAnimate model designed by Tencent Hunyuan. + The EasyAnimate model designed by EasyAnimate Team. text_encoder_2 (`T5EncoderModel`): - The mT5 embedder. + EasyAnimate does not use text_encoder_2 in V5.1. + EasyAnimate uses [mT5](https://huggingface.co/google/mt5-base) embedder in V5. tokenizer_2 (`T5Tokenizer`): The tokenizer for the mT5 embedder. - scheduler ([`DDIMScheduler`]): + scheduler ([`FlowMatchEulerDiscreteScheduler`]): A scheduler to be used in combination with EasyAnimate to denoise the encoded image latents. """ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" _optional_components = [ - "safety_checker", - "feature_extractor", "text_encoder_2", "tokenizer_2", "text_encoder", "tokenizer", ] - _exclude_from_cpu_offload = ["safety_checker"] _callback_tensor_inputs = [ "latents", "prompt_embeds", @@ -146,53 +245,30 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline): def __init__( self, vae: AutoencoderKLMagvit, - text_encoder: BertModel, - tokenizer: BertTokenizer, - text_encoder_2: T5EncoderModel, - tokenizer_2: T5Tokenizer, + text_encoder: Union[Qwen2VLForConditionalGeneration, BertModel], + tokenizer: Union[Qwen2Tokenizer, BertTokenizer], + text_encoder_2: Optional[Union[T5EncoderModel, Qwen2VLForConditionalGeneration]], + tokenizer_2: Optional[Union[T5Tokenizer, Qwen2Tokenizer]], transformer: EasyAnimateTransformer3DModel, - scheduler: DDIMScheduler, - safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPImageProcessor, - requires_safety_checker: bool = True + scheduler: FlowMatchEulerDiscreteScheduler, ): super().__init__() self.register_modules( vae=vae, text_encoder=text_encoder, + text_encoder_2=text_encoder_2, tokenizer=tokenizer, tokenizer_2=tokenizer_2, transformer=transformer, scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - text_encoder_2=text_encoder_2 ) - if safety_checker is None and requires_safety_checker: - logger.warning( - f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" - " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" - " results in services or applications open to the public. Both the diffusers team and Hugging Face" - " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" - " it only for use-cases that involve analyzing network behavior or auditing its results. For more" - " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." - ) - - if safety_checker is not None and feature_extractor is None: - raise ValueError( - "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" - " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." - ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.mask_processor = VaeImageProcessor( vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True ) - self.enable_autocast_float8_transformer_flag = False - self.register_to_config(requires_safety_checker=requires_safety_checker) def enable_sequential_cpu_offload(self, *args, **kwargs): super().enable_sequential_cpu_offload(*args, **kwargs) @@ -272,19 +348,9 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline): batch_size = prompt_embeds.shape[0] if prompt_embeds is None: - text_inputs = tokenizer( - prompt, - padding="max_length", - max_length=max_length, - truncation=True, - return_attention_mask=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - if text_input_ids.shape[-1] > actual_max_sequence_length: - reprompt = tokenizer.batch_decode(text_input_ids[:, :actual_max_sequence_length], skip_special_tokens=True) + if type(tokenizer) in [BertTokenizer, T5Tokenizer]: text_inputs = tokenizer( - reprompt, + prompt, padding="max_length", max_length=max_length, truncation=True, @@ -292,91 +358,188 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline): return_tensors="pt", ) text_input_ids = text_inputs.input_ids - untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): - _actual_max_sequence_length = min(tokenizer.model_max_length, actual_max_sequence_length) - removed_text = tokenizer.batch_decode(untruncated_ids[:, _actual_max_sequence_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {_actual_max_sequence_length} tokens: {removed_text}" - ) - prompt_attention_mask = text_inputs.attention_mask.to(device) - if self.transformer.config.enable_text_attention_mask: - prompt_embeds = text_encoder( - text_input_ids.to(device), - attention_mask=prompt_attention_mask, - ) + if text_input_ids.shape[-1] > actual_max_sequence_length: + reprompt = tokenizer.batch_decode(text_input_ids[:, :actual_max_sequence_length], skip_special_tokens=True) + text_inputs = tokenizer( + reprompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + _actual_max_sequence_length = min(tokenizer.model_max_length, actual_max_sequence_length) + removed_text = tokenizer.batch_decode(untruncated_ids[:, _actual_max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {_actual_max_sequence_length} tokens: {removed_text}" + ) + + prompt_attention_mask = text_inputs.attention_mask.to(device) + + if self.transformer.config.enable_text_attention_mask: + prompt_embeds = text_encoder( + text_input_ids.to(device), + attention_mask=prompt_attention_mask, + ) + else: + prompt_embeds = text_encoder( + text_input_ids.to(device) + ) + prompt_embeds = prompt_embeds[0] + prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) else: - prompt_embeds = text_encoder( - text_input_ids.to(device) + if prompt is not None and isinstance(prompt, str): + messages = [ + { + "role": "user", + "content": [{"type": "text", "text": prompt}], + } + ] + else: + messages = [ + { + "role": "user", + "content": [{"type": "text", "text": _prompt}], + } for _prompt in prompt + ] + text = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True ) - prompt_embeds = prompt_embeds[0] - prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + text_inputs = tokenizer( + text=[text], + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + padding_side="right", + return_tensors="pt", + ) + text_inputs = text_inputs.to(text_encoder.device) + + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + if self.transformer.config.enable_text_attention_mask: + # Inference: Generation of the output + prompt_embeds = text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_attention_mask, + output_hidden_states=True).hidden_states[-2] + else: + raise ValueError("LLM needs attention_mask") + prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.to(device=device) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: - uncond_tokens: List[str] - if negative_prompt is None: - uncond_tokens = [""] * batch_size - elif prompt is not None and type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt] - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - else: - uncond_tokens = negative_prompt - - max_length = prompt_embeds.shape[1] - uncond_input = tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) - uncond_input_ids = uncond_input.input_ids - if uncond_input_ids.shape[-1] > actual_max_sequence_length: - reuncond_tokens = tokenizer.batch_decode(uncond_input_ids[:, :actual_max_sequence_length], skip_special_tokens=True) + if type(tokenizer) in [BertTokenizer, T5Tokenizer]: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = prompt_embeds.shape[1] uncond_input = tokenizer( - reuncond_tokens, + uncond_tokens, padding="max_length", max_length=max_length, truncation=True, - return_attention_mask=True, return_tensors="pt", ) uncond_input_ids = uncond_input.input_ids + if uncond_input_ids.shape[-1] > actual_max_sequence_length: + reuncond_tokens = tokenizer.batch_decode(uncond_input_ids[:, :actual_max_sequence_length], skip_special_tokens=True) + uncond_input = tokenizer( + reuncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + return_tensors="pt", + ) + uncond_input_ids = uncond_input.input_ids - negative_prompt_attention_mask = uncond_input.attention_mask.to(device) - if self.transformer.config.enable_text_attention_mask: - negative_prompt_embeds = text_encoder( - uncond_input.input_ids.to(device), - attention_mask=negative_prompt_attention_mask, - ) + negative_prompt_attention_mask = uncond_input.attention_mask.to(device) + if self.transformer.config.enable_text_attention_mask: + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + attention_mask=negative_prompt_attention_mask, + ) + else: + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device) + ) + negative_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) else: - negative_prompt_embeds = text_encoder( - uncond_input.input_ids.to(device) + if negative_prompt is not None and isinstance(negative_prompt, str): + messages = [ + { + "role": "user", + "content": [{"type": "text", "text": negative_prompt}], + } + ] + else: + messages = [ + { + "role": "user", + "content": [{"type": "text", "text": _negative_prompt}], + } for _negative_prompt in negative_prompt + ] + text = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True ) - negative_prompt_embeds = negative_prompt_embeds[0] - negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) + + text_inputs = tokenizer( + text=[text], + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + padding_side="right", + return_tensors="pt", + ) + text_inputs = text_inputs.to(text_encoder.device) + + text_input_ids = text_inputs.input_ids + negative_prompt_attention_mask = text_inputs.attention_mask + if self.transformer.config.enable_text_attention_mask: + # Inference: Generation of the output + negative_prompt_embeds = text_encoder( + input_ids=text_input_ids, + attention_mask=negative_prompt_attention_mask, + output_hidden_states=True).hidden_states[-2] + else: + raise ValueError("LLM needs attention_mask") + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method @@ -386,24 +549,10 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline): negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - + negative_prompt_attention_mask = negative_prompt_attention_mask.to(device=device) + return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker - def run_safety_checker(self, image, device, dtype): - if self.safety_checker is None: - has_nsfw_concept = None - else: - if torch.is_tensor(image): - feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") - else: - feature_extractor_input = self.image_processor.numpy_to_pil(image) - safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) - image, has_nsfw_concept = self.safety_checker( - images=image, clip_input=safety_checker_input.pixel_values.to(dtype) - ) - return image, has_nsfw_concept - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature @@ -438,8 +587,8 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline): negative_prompt_attention_mask_2=None, callback_on_step_end_tensor_inputs=None, ): - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs @@ -524,43 +673,44 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline): latents = latents.to(device) # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma + if hasattr(self.scheduler, "init_noise_sigma"): + latents = latents * self.scheduler.init_noise_sigma return latents def prepare_control_latents( - self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + self, control, control_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance ): - # resize the mask to latents shape as we concatenate the mask to the latents + # resize the control to latents shape as we concatenate the control to the latents # we do that before converting to dtype to avoid breaking in case we're using cpu_offload # and half precision - if mask is not None: - mask = mask.to(device=device, dtype=self.vae.dtype) + if control is not None: + control = control.to(device=device, dtype=dtype) bs = 1 - new_mask = [] - for i in range(0, mask.shape[0], bs): - mask_bs = mask[i : i + bs] - mask_bs = self.vae.encode(mask_bs)[0] - mask_bs = mask_bs.mode() - new_mask.append(mask_bs) - mask = torch.cat(new_mask, dim = 0) - mask = mask * self.vae.config.scaling_factor - - if masked_image is not None: - masked_image = masked_image.to(device=device, dtype=self.vae.dtype) + new_control = [] + for i in range(0, control.shape[0], bs): + control_bs = control[i : i + bs] + control_bs = self.vae.encode(control_bs)[0] + control_bs = control_bs.mode() + new_control.append(control_bs) + control = torch.cat(new_control, dim = 0) + control = control * self.vae.config.scaling_factor + + if control_image is not None: + control_image = control_image.to(device=device, dtype=dtype) bs = 1 - new_mask_pixel_values = [] - for i in range(0, masked_image.shape[0], bs): - mask_pixel_values_bs = masked_image[i : i + bs] - mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0] - mask_pixel_values_bs = mask_pixel_values_bs.mode() - new_mask_pixel_values.append(mask_pixel_values_bs) - masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0) - masked_image_latents = masked_image_latents * self.vae.config.scaling_factor + new_control_pixel_values = [] + for i in range(0, control_image.shape[0], bs): + control_pixel_values_bs = control_image[i : i + bs] + control_pixel_values_bs = self.vae.encode(control_pixel_values_bs)[0] + control_pixel_values_bs = control_pixel_values_bs.mode() + new_control_pixel_values.append(control_pixel_values_bs) + control_image_latents = torch.cat(new_control_pixel_values, dim = 0) + control_image_latents = control_image_latents * self.vae.config.scaling_factor else: - masked_image_latents = None + control_image_latents = None - return mask, masked_image_latents + return control, control_image_latents def smooth_output(self, video, mini_batch_encoder, mini_batch_decoder): if video.size()[2] <= mini_batch_encoder: @@ -623,9 +773,6 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline): def interrupt(self): return self._interrupt - def enable_autocast_float8_transformer(self): - self.enable_autocast_float8_transformer_flag = True - @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -635,6 +782,8 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline): height: Optional[int] = None, width: Optional[int] = None, control_video: Union[torch.FloatTensor] = None, + control_camera_video: Union[torch.FloatTensor] = None, + ref_image: Union[torch.FloatTensor] = None, num_inference_steps: Optional[int] = 50, guidance_scale: Optional[float] = 5.0, negative_prompt: Optional[Union[str, List[str]]] = None, @@ -661,6 +810,7 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline): target_size: Optional[Tuple[int, int]] = None, crops_coords_top_left: Tuple[int, int] = (0, 0), comfyui_progressbar: bool = False, + timesteps: Optional[List[int]] = None, ): r""" Generates images or video using the EasyAnimate pipeline based on the provided prompts. @@ -765,6 +915,12 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline): batch_size = prompt_embeds.shape[0] device = self._execution_device + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.text_encoder_2 is not None: + dtype = self.text_encoder_2.dtype + else: + dtype = self.transformer.dtype # 3. Encode input prompt ( @@ -775,7 +931,7 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline): ) = self.encode_prompt( prompt=prompt, device=device, - dtype=self.transformer.dtype, + dtype=dtype, num_images_per_prompt=num_images_per_prompt, do_classifier_free_guidance=self.do_classifier_free_guidance, negative_prompt=negative_prompt, @@ -785,28 +941,36 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline): negative_prompt_attention_mask=negative_prompt_attention_mask, text_encoder_index=0, ) - ( - prompt_embeds_2, - negative_prompt_embeds_2, - prompt_attention_mask_2, - negative_prompt_attention_mask_2, - ) = self.encode_prompt( - prompt=prompt, - device=device, - dtype=self.transformer.dtype, - num_images_per_prompt=num_images_per_prompt, - do_classifier_free_guidance=self.do_classifier_free_guidance, - negative_prompt=negative_prompt, - prompt_embeds=prompt_embeds_2, - negative_prompt_embeds=negative_prompt_embeds_2, - prompt_attention_mask=prompt_attention_mask_2, - negative_prompt_attention_mask=negative_prompt_attention_mask_2, - text_encoder_index=1, - ) - torch.cuda.empty_cache() + if self.tokenizer_2 is not None: + ( + prompt_embeds_2, + negative_prompt_embeds_2, + prompt_attention_mask_2, + negative_prompt_attention_mask_2, + ) = self.encode_prompt( + prompt=prompt, + device=device, + dtype=dtype, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds_2, + negative_prompt_embeds=negative_prompt_embeds_2, + prompt_attention_mask=prompt_attention_mask_2, + negative_prompt_attention_mask=negative_prompt_attention_mask_2, + text_encoder_index=1, + ) + else: + prompt_embeds_2 = None + negative_prompt_embeds_2 = None + prompt_attention_mask_2 = None + negative_prompt_attention_mask_2 = None # 4. Prepare timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) + if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler): + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, mu=1) + else: + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) timesteps = self.scheduler.timesteps if comfyui_progressbar: from comfy.utils import ProgressBar @@ -820,7 +984,7 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline): video_length, height, width, - prompt_embeds.dtype, + dtype, device, generator, latents, @@ -828,27 +992,69 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline): if comfyui_progressbar: pbar.update(1) - if control_video is not None: + if control_camera_video is not None: + control_video_latents = resize_mask(control_camera_video, latents, process_first_frame_only=True) + control_video_latents = control_video_latents * 6 + control_latents = ( + torch.cat([control_video_latents] * 2) if self.do_classifier_free_guidance else control_video_latents + ).to(device, dtype) + elif control_video is not None: video_length = control_video.shape[2] control_video = self.image_processor.preprocess(rearrange(control_video, "b c f h w -> (b f) c h w"), height=height, width=width) control_video = control_video.to(dtype=torch.float32) control_video = rearrange(control_video, "(b f) c h w -> b c f h w", f=video_length) + control_video_latents = self.prepare_control_latents( + None, + control_video, + batch_size, + height, + width, + dtype, + device, + generator, + self.do_classifier_free_guidance + )[1] + control_latents = ( + torch.cat([control_video_latents] * 2) if self.do_classifier_free_guidance else control_video_latents + ).to(device, dtype) else: - control_video = None - control_video_latents = self.prepare_control_latents( - None, - control_video, - batch_size, - height, - width, - prompt_embeds.dtype, - device, - generator, - self.do_classifier_free_guidance - )[1] - control_latents = ( - torch.cat([control_video_latents] * 2) if self.do_classifier_free_guidance else control_video_latents - ) + control_video_latents = torch.zeros_like(latents).to(device, dtype) + control_latents = ( + torch.cat([control_video_latents] * 2) if self.do_classifier_free_guidance else control_video_latents + ).to(device, dtype) + + if ref_image is not None: + video_length = ref_image.shape[2] + ref_image = self.image_processor.preprocess(rearrange(ref_image, "b c f h w -> (b f) c h w"), height=height, width=width) + ref_image = ref_image.to(dtype=torch.float32) + ref_image = rearrange(ref_image, "(b f) c h w -> b c f h w", f=video_length) + + ref_image_latentes = self.prepare_control_latents( + None, + ref_image, + batch_size, + height, + width, + prompt_embeds.dtype, + device, + generator, + self.do_classifier_free_guidance + )[1] + + ref_image_latentes_conv_in = torch.zeros_like(latents) + if latents.size()[2] != 1: + ref_image_latentes_conv_in[:, :, :1] = ref_image_latentes + ref_image_latentes_conv_in = ( + torch.cat([ref_image_latentes_conv_in] * 2) if self.do_classifier_free_guidance else ref_image_latentes_conv_in + ).to(device, dtype) + control_latents = torch.cat([control_latents, ref_image_latentes_conv_in], dim = 1) + else: + if self.transformer.config.get("add_ref_latent_in_control_model", False): + ref_image_latentes_conv_in = torch.zeros_like(latents) + ref_image_latentes_conv_in = ( + torch.cat([ref_image_latentes_conv_in] * 2) if self.do_classifier_free_guidance else ref_image_latentes_conv_in + ).to(device, dtype) + control_latents = torch.cat([control_latents, ref_image_latentes_conv_in], dim = 1) if comfyui_progressbar: pbar.update(1) @@ -880,34 +1086,49 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline): ) # Get other hunyuan params - style = torch.tensor([0], device=device) - target_size = target_size or (height, width) add_time_ids = list(original_size + target_size + crops_coords_top_left) - add_time_ids = torch.tensor([add_time_ids], dtype=prompt_embeds.dtype) + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + style = torch.tensor([0], device=device) if self.do_classifier_free_guidance: - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask]) - prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2]) - prompt_attention_mask_2 = torch.cat([negative_prompt_attention_mask_2, prompt_attention_mask_2]) add_time_ids = torch.cat([add_time_ids] * 2, dim=0) style = torch.cat([style] * 2, dim=0) # To latents.device - prompt_embeds = prompt_embeds.to(device=device) - prompt_attention_mask = prompt_attention_mask.to(device=device) - prompt_embeds_2 = prompt_embeds_2.to(device=device) - prompt_attention_mask_2 = prompt_attention_mask_2.to(device=device) - add_time_ids = add_time_ids.to(dtype=prompt_embeds.dtype, device=device).repeat( + add_time_ids = add_time_ids.to(dtype=dtype, device=device).repeat( batch_size * num_images_per_prompt, 1 ) style = style.to(device=device).repeat(batch_size * num_images_per_prompt) - torch.cuda.empty_cache() - if self.enable_autocast_float8_transformer_flag: - origin_weight_dtype = self.transformer.dtype - self.transformer = self.transformer.to(torch.float8_e4m3fn) + # Get other pixart params + added_cond_kwargs = {"resolution": None, "aspect_ratio": None} + if self.transformer.config.get("sample_size", 64) == 128: + resolution = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1) + aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1) + resolution = resolution.to(dtype=dtype, device=device) + aspect_ratio = aspect_ratio.to(dtype=dtype, device=device) + + if self.do_classifier_free_guidance: + resolution = torch.cat([resolution, resolution], dim=0) + aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0) + + added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio} + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask]) + if prompt_embeds_2 is not None: + prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2]) + prompt_attention_mask_2 = torch.cat([negative_prompt_attention_mask_2, prompt_attention_mask_2]) + + # To latents.device + prompt_embeds = prompt_embeds.to(device=device) + prompt_attention_mask = prompt_attention_mask.to(device=device) + if prompt_embeds_2 is not None: + prompt_embeds_2 = prompt_embeds_2.to(device=device) + prompt_attention_mask_2 = prompt_attention_mask_2.to(device=device) + # 8. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(timesteps) @@ -918,7 +1139,8 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + if hasattr(self.scheduler, "scale_model_input"): + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # expand scalar t to 1-D tensor to match the 1st dim of latent_model_input t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to( @@ -935,8 +1157,9 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline): image_meta_size=add_time_ids, style=style, image_rotary_emb=image_rotary_emb, - return_dict=False, + added_cond_kwargs=added_cond_kwargs, control_latents=control_latents, + return_dict=False, )[0] if noise_pred.size()[1] != self.vae.config.latent_channels: noise_pred, _ = noise_pred.chunk(2, dim=1) @@ -976,10 +1199,6 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline): if comfyui_progressbar: pbar.update(1) - if self.enable_autocast_float8_transformer_flag: - self.transformer = self.transformer.to("cpu", origin_weight_dtype) - - torch.cuda.empty_cache() # Post-processing video = self.decode_latents(latents) @@ -993,4 +1212,4 @@ class EasyAnimatePipeline_Multi_Text_Encoder_Control(DiffusionPipeline): if not return_dict: return video - return EasyAnimatePipelineOutput(videos=video) \ No newline at end of file + return EasyAnimatePipelineOutput(frames=video) \ No newline at end of file diff --git a/easyanimate/pipeline/pipeline_easyanimate_inpaint.py b/easyanimate/pipeline/pipeline_easyanimate_inpaint.py index 340b977479d0b42f1b3980188b2ae7dbb8208c02..ffc5c45bfef6ecb9530ebd13b4c3333ffad7ea41 100644 --- a/easyanimate/pipeline/pipeline_easyanimate_inpaint.py +++ b/easyanimate/pipeline/pipeline_easyanimate_inpaint.py @@ -1,4 +1,4 @@ -# Copyright 2023 PixArt-Alpha Authors and The HuggingFace Team. All rights reserved. +# Copyright 2024 EasyAnimate Authors and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,149 +12,336 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy -import gc -import html import inspect -import re -import urllib.parse as ul -from dataclasses import dataclass -from typing import Callable, List, Optional, Tuple, Union +from typing import Callable, Dict, List, Optional, Tuple, Union import numpy as np import torch import torch.nn.functional as F -from diffusers import DiffusionPipeline, ImagePipelineOutput +from dataclasses import dataclass +from diffusers import DiffusionPipeline +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback from diffusers.image_processor import VaeImageProcessor -from diffusers.models import AutoencoderKL -from diffusers.schedulers import DPMSolverMultistepScheduler +from diffusers.models import AutoencoderKL, HunyuanDiT2DModel +from diffusers.models.embeddings import (get_2d_rotary_pos_embed, + get_3d_rotary_pos_embed) +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion.safety_checker import \ + StableDiffusionSafetyChecker +from diffusers.schedulers import DDIMScheduler, FlowMatchEulerDiscreteScheduler from diffusers.utils import (BACKENDS_MAPPING, BaseOutput, deprecate, - is_bs4_available, is_ftfy_available, logging, + is_bs4_available, is_ftfy_available, + is_torch_xla_available, logging, replace_example_docstring) from diffusers.utils.torch_utils import randn_tensor from einops import rearrange from PIL import Image from tqdm import tqdm -from transformers import (CLIPImageProcessor, CLIPVisionModelWithProjection, +from transformers import (BertModel, BertTokenizer, CLIPImageProcessor, + Qwen2Tokenizer, Qwen2VLForConditionalGeneration, CLIPVisionModelWithProjection, T5EncoderModel, T5Tokenizer) -from ..models.transformer3d import Transformer3DModel +from ..models import AutoencoderKLMagvit, EasyAnimateTransformer3DModel -logger = logging.get_logger(__name__) # pylint: disable=invalid-name +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm -if is_bs4_available(): - from bs4 import BeautifulSoup + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False -if is_ftfy_available(): - import ftfy +logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch - >>> from diffusers import EasyAnimatePipeline - - >>> # You can replace the checkpoint id with "PixArt-alpha/PixArt-XL-2-512x512" too. - >>> pipe = EasyAnimatePipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16) - >>> # Enable memory optimizations. - >>> pipe.enable_model_cpu_offload() - - >>> prompt = "A small cactus with a happy face in the Sahara desert." - >>> image = pipe(prompt).images[0] + >>> from diffusers import EasyAnimateInpaintPipeline + >>> from diffusers.easyanimate.pipeline_easyanimate_inpaint import get_image_to_video_latent + >>> from diffusers.utils import export_to_video, load_image + + >>> pipe = EasyAnimateInpaintPipeline.from_pretrained("alibaba-pai/EasyAnimateV5.1-12b-zh-InP", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> prompt = "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot." + >>> validation_image_start = load_image( + ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg" + ... ) + >>> validation_image_end = None + >>> sample_size = (576, 448) + >>> video_length = 49 + >>> input_video, input_video_mask, _ = get_image_to_video_latent(validation_image_start, validation_image_end, video_length, sample_size) + >>> video = pipe(prompt, video_length=video_length, negative_prompt="bad detailed", height=sample_size[0], width=sample_size[1], input_video=input_video, mask_video=input_video_mask) + >>> export_to_video(video.sample[0], "output.mp4", fps=8) ``` """ -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents -def retrieve_latents(encoder_output, generator): - if hasattr(encoder_output, "latent_dist"): - return encoder_output.latent_dist.sample(generator) - elif hasattr(encoder_output, "latents"): - return encoder_output.latents + + +# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid +def get_resize_crop_region_for_grid(src, tgt_width, tgt_height): + tw = tgt_width + th = tgt_height + h, w = src + r = h / w + if r > (th / tw): + resize_height = th + resize_width = int(round(th / h * w)) + else: + resize_width = tw + resize_height = int(round(tw / w * h)) + + crop_top = int(round((th - resize_height) / 2.0)) + crop_left = int(round((tw - resize_width) / 2.0)) + + return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +# Resize mask information in magvit +def resize_mask(mask, latent, process_first_frame_only=True): + latent_size = latent.size() + + if process_first_frame_only: + target_size = list(latent_size[2:]) + target_size[0] = 1 + first_frame_resized = F.interpolate( + mask[:, :, 0:1, :, :], + size=target_size, + mode='trilinear', + align_corners=False + ) + + target_size = list(latent_size[2:]) + target_size[0] = target_size[0] - 1 + if target_size[0] != 0: + remaining_frames_resized = F.interpolate( + mask[:, :, 1:, :, :], + size=target_size, + mode='trilinear', + align_corners=False + ) + resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2) + else: + resized_mask = first_frame_resized else: - raise AttributeError("Could not access latents of provided encoder_output") + target_size = list(latent_size[2:]) + resized_mask = F.interpolate( + mask, + size=target_size, + mode='trilinear', + align_corners=False + ) + return resized_mask + + +## Add noise to reference video +def add_noise_to_reference_video(image, ratio=None): + if ratio is None: + sigma = torch.normal(mean=-3.0, std=0.5, size=(image.shape[0],)).to(image.device) + sigma = torch.exp(sigma).to(image.dtype) + else: + sigma = torch.ones((image.shape[0],)).to(image.device, image.dtype) * ratio + image_noise = torch.randn_like(image) * sigma[:, None, None, None, None] + image_noise = torch.where(image==-1, torch.zeros_like(image), image_noise) + image = image + image_noise + return image + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + @dataclass class EasyAnimatePipelineOutput(BaseOutput): - videos: Union[torch.Tensor, np.ndarray] + r""" + Output class for EasyAnimate pipelines. + + Args: + frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + frames: torch.Tensor + class EasyAnimateInpaintPipeline(DiffusionPipeline): r""" - Pipeline for text-to-image generation using PixArt-Alpha. + Pipeline for text-to-video generation using EasyAnimate. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + EasyAnimate uses one text encoder [qwen2 vl](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) in V5.1. + EasyAnimate uses two text encoders: [mT5](https://huggingface.co/google/mt5-base) and [bilingual CLIP](fine-tuned by + HunyuanDiT team) in V5. + Args: - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`T5EncoderModel`]): - Frozen text-encoder. PixArt-Alpha uses - [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the - [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant. - tokenizer (`T5Tokenizer`): - Tokenizer of class - [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). - transformer ([`Transformer3DModel`]): - A text conditioned `Transformer3DModel` to denoise the encoded image latents. - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLMagvit`]): + Variational Auto-Encoder (VAE) Model to encode and decode video to and from latent representations. + text_encoder (Optional[`~transformers.Qwen2VLForConditionalGeneration`, `~transformers.BertModel`]): + EasyAnimate uses [qwen2 vl](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) in V5.1. + EasyAnimate uses [bilingual CLIP](https://huggingface.co/Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers) in V5. + tokenizer (Optional[`~transformers.Qwen2Tokenizer`, `~transformers.BertTokenizer`]): + A `Qwen2Tokenizer` or `BertTokenizer` to tokenize text. + transformer ([`EasyAnimateTransformer3DModel`]): + The EasyAnimate model designed by EasyAnimate Team. + text_encoder_2 (`T5EncoderModel`): + EasyAnimate does not use text_encoder_2 in V5.1. + EasyAnimate uses [mT5](https://huggingface.co/google/mt5-base) embedder in V5. + tokenizer_2 (`T5Tokenizer`): + The tokenizer for the mT5 embedder. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with EasyAnimate to denoise the encoded image latents. + clip_image_processor (`CLIPImageProcessor`): + The CLIP image embedder. + clip_image_encoder (`CLIPVisionModelWithProjection`): + The image processor for the CLIP image embedder. """ - bad_punct_regex = re.compile( - r"[" + "#®•©™&@·º½¾¿¡§~" + "\)" + "\(" + "\]" + "\[" + "\}" + "\{" + "\|" + "\\" + "\/" + "\*" + r"]{1,}" - ) # noqa - _optional_components = ["tokenizer", "text_encoder"] - model_cpu_offload_seq = "text_encoder->transformer->vae" + model_cpu_offload_seq = "text_encoder->text_encoder_2->clip_image_encoder->transformer->vae" + _optional_components = [ + "text_encoder_2", + "tokenizer_2", + "text_encoder", + "tokenizer", + "clip_image_encoder", + ] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "prompt_embeds_2", + "negative_prompt_embeds_2", + ] def __init__( self, - tokenizer: T5Tokenizer, - text_encoder: T5EncoderModel, - vae: AutoencoderKL, - transformer: Transformer3DModel, - scheduler: DPMSolverMultistepScheduler, - clip_image_processor:CLIPImageProcessor = None, - clip_image_encoder:CLIPVisionModelWithProjection = None, + vae: AutoencoderKLMagvit, + text_encoder: Union[Qwen2VLForConditionalGeneration, BertModel], + tokenizer: Union[Qwen2Tokenizer, BertTokenizer], + text_encoder_2: Optional[Union[T5EncoderModel, Qwen2VLForConditionalGeneration]], + tokenizer_2: Optional[Union[T5Tokenizer, Qwen2Tokenizer]], + transformer: EasyAnimateTransformer3DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + clip_image_processor: CLIPImageProcessor = None, + clip_image_encoder: CLIPVisionModelWithProjection = None, ): super().__init__() self.register_modules( - tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, scheduler=scheduler, - clip_image_processor=clip_image_processor, clip_image_encoder=clip_image_encoder, + text_encoder_2=text_encoder_2, + clip_image_processor=clip_image_processor, + clip_image_encoder=clip_image_encoder, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=True) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.mask_processor = VaeImageProcessor( vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True ) - self.enable_autocast_float8_transformer_flag = False - # Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py - def mask_text_embeddings(self, emb, mask): - if emb.shape[0] == 1: - keep_index = mask.sum().item() - return emb[:, :, :keep_index, :], keep_index - else: - masked_feature = emb * mask[:, None, :, None] - return masked_feature, emb.shape[2] + def enable_sequential_cpu_offload(self, *args, **kwargs): + super().enable_sequential_cpu_offload(*args, **kwargs) + if hasattr(self.transformer, "clip_projection") and self.transformer.clip_projection is not None: + import accelerate + accelerate.hooks.remove_hook_from_module(self.transformer.clip_projection, recurse=True) + self.transformer.clip_projection = self.transformer.clip_projection.to("cuda") - # Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt def encode_prompt( self, - prompt: Union[str, List[str]], - do_classifier_free_guidance: bool = True, - negative_prompt: str = "", + prompt: str, + device: torch.device, + dtype: torch.dtype, num_images_per_prompt: int = 1, - device: Optional[torch.device] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - prompt_attention_mask: Optional[torch.FloatTensor] = None, - negative_prompt_attention_mask: Optional[torch.FloatTensor] = None, - clean_caption: bool = False, - max_sequence_length: int = 120, - **kwargs, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[str] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + max_sequence_length: Optional[int] = None, + text_encoder_index: int = 0, + actual_max_sequence_length: int = 256 ): r""" Encodes the prompt into text encoder hidden states. @@ -162,33 +349,46 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline): Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded - negative_prompt (`str` or `List[str]`, *optional*): - The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` - instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For - PixArt-Alpha, this should be "". - do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): - whether to use classifier free guidance or not - num_images_per_prompt (`int`, *optional*, defaults to 1): + device: (`torch.device`): + torch device + dtype (`torch.dtype`): + torch dtype + num_images_per_prompt (`int`): number of images that should be generated per prompt - device: (`torch.device`, *optional*): - torch device to place the resulting embeddings on - prompt_embeds (`torch.FloatTensor`, *optional*): + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. For PixArt-Alpha, it's should be the embeddings of the "" - string. - clean_caption (`bool`, defaults to `False`): - If `True`, the function will preprocess and clean the provided caption before encoding. - max_sequence_length (`int`, defaults to 120): Maximum sequence length to use for the prompt. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the prompt. Required when `prompt_embeds` is passed directly. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly. + max_sequence_length (`int`, *optional*): maximum sequence length to use for the prompt. + text_encoder_index (`int`, *optional*): + Index of the text encoder to use. `0` for clip and `1` for T5. """ + tokenizers = [self.tokenizer, self.tokenizer_2] + text_encoders = [self.text_encoder, self.text_encoder_2] - if "mask_feature" in kwargs: - deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version." - deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False) + tokenizer = tokenizers[text_encoder_index] + text_encoder = text_encoders[text_encoder_index] - if device is None: - device = self._execution_device + if max_sequence_length is None: + if text_encoder_index == 0: + max_length = min(self.tokenizer.model_max_length, actual_max_sequence_length) + if text_encoder_index == 1: + max_length = min(self.tokenizer_2.model_max_length, actual_max_sequence_length) + else: + max_length = max_sequence_length if prompt is not None and isinstance(prompt, str): batch_size = 1 @@ -197,74 +397,199 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline): else: batch_size = prompt_embeds.shape[0] - # See Section 3.1. of the paper. - max_length = max_sequence_length - if prompt_embeds is None: - prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=max_length, - truncation=True, - add_special_tokens=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): - removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {max_length} tokens: {removed_text}" + if type(tokenizer) in [BertTokenizer, T5Tokenizer]: + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + if text_input_ids.shape[-1] > actual_max_sequence_length: + reprompt = tokenizer.batch_decode(text_input_ids[:, :actual_max_sequence_length], skip_special_tokens=True) + text_inputs = tokenizer( + reprompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + _actual_max_sequence_length = min(tokenizer.model_max_length, actual_max_sequence_length) + removed_text = tokenizer.batch_decode(untruncated_ids[:, _actual_max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {_actual_max_sequence_length} tokens: {removed_text}" + ) + + prompt_attention_mask = text_inputs.attention_mask.to(device) + + if self.transformer.config.enable_text_attention_mask: + prompt_embeds = text_encoder( + text_input_ids.to(device), + attention_mask=prompt_attention_mask, + ) + else: + prompt_embeds = text_encoder( + text_input_ids.to(device) + ) + prompt_embeds = prompt_embeds[0] + prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + else: + if prompt is not None and isinstance(prompt, str): + messages = [ + { + "role": "user", + "content": [{"type": "text", "text": prompt}], + } + ] + else: + messages = [ + { + "role": "user", + "content": [{"type": "text", "text": _prompt}], + } for _prompt in prompt + ] + text = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True ) - prompt_attention_mask = text_inputs.attention_mask - prompt_attention_mask = prompt_attention_mask.to(device) - - prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask) - prompt_embeds = prompt_embeds[0] - - if self.text_encoder is not None: - dtype = self.text_encoder.dtype - elif self.transformer is not None: - dtype = self.transformer.dtype - else: - dtype = None - + text_inputs = tokenizer( + text=[text], + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + padding_side="right", + return_tensors="pt", + ) + text_inputs = text_inputs.to(text_encoder.device) + + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + if self.transformer.config.enable_text_attention_mask: + # Inference: Generation of the output + prompt_embeds = text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_attention_mask, + output_hidden_states=True).hidden_states[-2] + else: + raise ValueError("LLM needs attention_mask") + prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1) - prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + prompt_attention_mask = prompt_attention_mask.to(device=device) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: - uncond_tokens = [negative_prompt] * batch_size - uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) - max_length = prompt_embeds.shape[1] - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_attention_mask=True, - add_special_tokens=True, - return_tensors="pt", - ) - negative_prompt_attention_mask = uncond_input.attention_mask - negative_prompt_attention_mask = negative_prompt_attention_mask.to(device) + if type(tokenizer) in [BertTokenizer, T5Tokenizer]: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + uncond_input_ids = uncond_input.input_ids + if uncond_input_ids.shape[-1] > actual_max_sequence_length: + reuncond_tokens = tokenizer.batch_decode(uncond_input_ids[:, :actual_max_sequence_length], skip_special_tokens=True) + uncond_input = tokenizer( + reuncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + return_tensors="pt", + ) + uncond_input_ids = uncond_input.input_ids + + negative_prompt_attention_mask = uncond_input.attention_mask.to(device) + if self.transformer.config.enable_text_attention_mask: + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + attention_mask=negative_prompt_attention_mask, + ) + else: + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device) + ) + negative_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) + else: + if negative_prompt is not None and isinstance(negative_prompt, str): + messages = [ + { + "role": "user", + "content": [{"type": "text", "text": negative_prompt}], + } + ] + else: + messages = [ + { + "role": "user", + "content": [{"type": "text", "text": _negative_prompt}], + } for _negative_prompt in negative_prompt + ] + text = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) - negative_prompt_embeds = self.text_encoder( - uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask - ) - negative_prompt_embeds = negative_prompt_embeds[0] + text_inputs = tokenizer( + text=[text], + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + padding_side="right", + return_tensors="pt", + ) + text_inputs = text_inputs.to(text_encoder.device) + + text_input_ids = text_inputs.input_ids + negative_prompt_attention_mask = text_inputs.attention_mask + if self.transformer.config.enable_text_attention_mask: + # Inference: Generation of the output + negative_prompt_embeds = text_encoder( + input_ids=text_input_ids, + attention_mask=negative_prompt_attention_mask, + output_hidden_states=True).hidden_states[-2] + else: + raise ValueError("LLM needs attention_mask") + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method @@ -274,14 +599,9 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline): negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1) - negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) - else: - negative_prompt_embeds = None - negative_prompt_attention_mask = None - - return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + negative_prompt_attention_mask = negative_prompt_attention_mask.to(device=device) + + return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): @@ -306,20 +626,25 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline): prompt, height, width, - negative_prompt, - callback_steps, + negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + prompt_embeds_2=None, + negative_prompt_embeds_2=None, + prompt_attention_mask_2=None, + negative_prompt_attention_mask_2=None, + callback_on_step_end_tensor_inputs=None, ): - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" ) if prompt is not None and prompt_embeds is not None: @@ -331,14 +656,18 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline): raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) + elif prompt is None and prompt_embeds_2 is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined." + ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - if prompt is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if prompt_embeds_2 is not None and prompt_attention_mask_2 is None: + raise ValueError("Must provide `prompt_attention_mask_2` when specifying `prompt_embeds_2`.") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( @@ -346,6 +675,13 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline): f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if negative_prompt_embeds_2 is not None and negative_prompt_attention_mask_2 is None: + raise ValueError( + "Must provide `negative_prompt_attention_mask_2` when specifying `negative_prompt_embeds_2`." + ) if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( @@ -353,201 +689,83 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline): f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." ) + if prompt_embeds_2 is not None and negative_prompt_embeds_2 is not None: + if prompt_embeds_2.shape != negative_prompt_embeds_2.shape: + raise ValueError( + "`prompt_embeds_2` and `negative_prompt_embeds_2` must have the same shape when passed directly, but" + f" got: `prompt_embeds_2` {prompt_embeds_2.shape} != `negative_prompt_embeds_2`" + f" {negative_prompt_embeds_2.shape}." + ) - # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing - def _text_preprocessing(self, text, clean_caption=False): - if clean_caption and not is_bs4_available(): - logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) - logger.warn("Setting `clean_caption` to False...") - clean_caption = False - - if clean_caption and not is_ftfy_available(): - logger.warn(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) - logger.warn("Setting `clean_caption` to False...") - clean_caption = False - - if not isinstance(text, (tuple, list)): - text = [text] - - def process(text: str): - if clean_caption: - text = self._clean_caption(text) - text = self._clean_caption(text) - else: - text = text.lower().strip() - return text - - return [process(t) for t in text] - - # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption - def _clean_caption(self, caption): - caption = str(caption) - caption = ul.unquote_plus(caption) - caption = caption.strip().lower() - caption = re.sub("", "person", caption) - # urls: - caption = re.sub( - r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa - "", - caption, - ) # regex for urls - caption = re.sub( - r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa - "", - caption, - ) # regex for urls - # html: - caption = BeautifulSoup(caption, features="html.parser").text - - # @ - caption = re.sub(r"@[\w\d]+\b", "", caption) - - # 31C0—31EF CJK Strokes - # 31F0—31FF Katakana Phonetic Extensions - # 3200—32FF Enclosed CJK Letters and Months - # 3300—33FF CJK Compatibility - # 3400—4DBF CJK Unified Ideographs Extension A - # 4DC0—4DFF Yijing Hexagram Symbols - # 4E00—9FFF CJK Unified Ideographs - caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) - caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) - caption = re.sub(r"[\u3200-\u32ff]+", "", caption) - caption = re.sub(r"[\u3300-\u33ff]+", "", caption) - caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) - caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) - caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) - ####################################################### - - # все виды тире / all types of dash --> "-" - caption = re.sub( - r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa - "-", - caption, - ) - - # кавычки к одному стандарту - caption = re.sub(r"[`´«»“”¨]", '"', caption) - caption = re.sub(r"[‘’]", "'", caption) - - # " - caption = re.sub(r""?", "", caption) - # & - caption = re.sub(r"&", "", caption) - - # ip adresses: - caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) - - # article ids: - caption = re.sub(r"\d:\d\d\s+$", "", caption) - - # \n - caption = re.sub(r"\\n", " ", caption) - - # "#123" - caption = re.sub(r"#\d{1,3}\b", "", caption) - # "#12345.." - caption = re.sub(r"#\d{5,}\b", "", caption) - # "123456.." - caption = re.sub(r"\b\d{6,}\b", "", caption) - # filenames: - caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) - - # - caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" - caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" - - caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT - caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " - - # this-is-my-cute-cat / this_is_my_cute_cat - regex2 = re.compile(r"(?:\-|\_)") - if len(re.findall(regex2, caption)) > 3: - caption = re.sub(regex2, " ", caption) - - caption = ftfy.fix_text(caption) - caption = html.unescape(html.unescape(caption)) - - caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 - caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc - caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 - - caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) - caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) - caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) - caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) - caption = re.sub(r"\bpage\s+\d+\b", "", caption) - - caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... - - caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) - - caption = re.sub(r"\b\s+\:\s+", r": ", caption) - caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) - caption = re.sub(r"\s+", " ", caption) - - caption.strip() + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) - caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) - caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) - caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) - caption = re.sub(r"^\.\S+$", "", caption) + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] - return caption.strip() + return timesteps, num_inference_steps - t_start def prepare_mask_latents( - self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance, noise_aug_strength ): # resize the mask to latents shape as we concatenate the mask to the latents # we do that before converting to dtype to avoid breaking in case we're using cpu_offload # and half precision - video_length = mask.shape[2] - - mask = mask.to(device=device, dtype=self.vae.dtype) - if self.vae.quant_conv.weight.ndim==5: - bs = 1 - new_mask = [] - for i in range(0, mask.shape[0], bs): - mask_bs = mask[i : i + bs] - mask_bs = self.vae.encode(mask_bs)[0] - mask_bs = mask_bs.sample() - new_mask.append(mask_bs) - mask = torch.cat(new_mask, dim = 0) - mask = mask * self.vae.config.scaling_factor + if mask is not None: + mask = mask.to(device=device, dtype=dtype) + if self.vae.quant_conv is None or self.vae.quant_conv.weight.ndim==5: + bs = 1 + new_mask = [] + for i in range(0, mask.shape[0], bs): + mask_bs = mask[i : i + bs] + mask_bs = self.vae.encode(mask_bs)[0] + mask_bs = mask_bs.mode() + new_mask.append(mask_bs) + mask = torch.cat(new_mask, dim = 0) + mask = mask * self.vae.config.scaling_factor - else: - if mask.shape[1] == 4: - mask = mask else: - video_length = mask.shape[2] - mask = rearrange(mask, "b c f h w -> (b f) c h w") - mask = self._encode_vae_image(mask, generator=generator) - mask = rearrange(mask, "(b f) c h w -> b c f h w", f=video_length) - - masked_image = masked_image.to(device=device, dtype=self.vae.dtype) - if self.vae.quant_conv.weight.ndim==5: - bs = 1 - new_mask_pixel_values = [] - for i in range(0, masked_image.shape[0], bs): - mask_pixel_values_bs = masked_image[i : i + bs] - mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0] - mask_pixel_values_bs = mask_pixel_values_bs.sample() - new_mask_pixel_values.append(mask_pixel_values_bs) - masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0) - masked_image_latents = masked_image_latents * self.vae.config.scaling_factor + if mask.shape[1] == 4: + mask = mask + else: + video_length = mask.shape[2] + mask = rearrange(mask, "b c f h w -> (b f) c h w") + mask = self._encode_vae_image(mask, generator=generator) + mask = rearrange(mask, "(b f) c h w -> b c f h w", f=video_length) + + if masked_image is not None: + masked_image = masked_image.to(device=device, dtype=dtype) + if self.transformer.config.add_noise_in_inpaint_model: + masked_image = add_noise_to_reference_video(masked_image, ratio=noise_aug_strength) + if self.vae.quant_conv is None or self.vae.quant_conv.weight.ndim==5: + bs = 1 + new_mask_pixel_values = [] + for i in range(0, masked_image.shape[0], bs): + mask_pixel_values_bs = masked_image[i : i + bs] + mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0] + mask_pixel_values_bs = mask_pixel_values_bs.mode() + new_mask_pixel_values.append(mask_pixel_values_bs) + masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0) + masked_image_latents = masked_image_latents * self.vae.config.scaling_factor - else: - if masked_image.shape[1] == 4: - masked_image_latents = masked_image else: - video_length = mask.shape[2] - masked_image = rearrange(masked_image, "b c f h w -> (b f) c h w") - masked_image_latents = self._encode_vae_image(masked_image, generator=generator) - masked_image_latents = rearrange(masked_image_latents, "(b f) c h w -> b c f h w", f=video_length) + if masked_image.shape[1] == 4: + masked_image_latents = masked_image + else: + video_length = masked_image.shape[2] + masked_image = rearrange(masked_image, "b c f h w -> (b f) c h w") + masked_image_latents = self._encode_vae_image(masked_image, generator=generator) + masked_image_latents = rearrange(masked_image_latents, "(b f) c h w -> b c f h w", f=video_length) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + else: + masked_image_latents = None - # aligning device to prevent device errors when concating it with the latent model input - masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) return mask, masked_image_latents - + def prepare_latents( self, batch_size, @@ -565,10 +783,15 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline): return_noise=False, return_video_latents=False, ): - if self.vae.quant_conv.weight.ndim==5: - mini_batch_encoder = self.vae.mini_batch_encoder - mini_batch_decoder = self.vae.mini_batch_decoder - shape = (batch_size, num_channels_latents, int(video_length // mini_batch_encoder * mini_batch_decoder) if video_length != 1 else 1, height // self.vae_scale_factor, width // self.vae_scale_factor) + if self.vae.quant_conv is None or self.vae.quant_conv.weight.ndim==5: + if self.vae.cache_mag_vae: + mini_batch_encoder = self.vae.mini_batch_encoder + mini_batch_decoder = self.vae.mini_batch_decoder + shape = (batch_size, num_channels_latents, int((video_length - 1) // mini_batch_encoder * mini_batch_decoder + 1) if video_length != 1 else 1, height // self.vae_scale_factor, width // self.vae_scale_factor) + else: + mini_batch_encoder = self.vae.mini_batch_encoder + mini_batch_decoder = self.vae.mini_batch_decoder + shape = (batch_size, num_channels_latents, int(video_length // mini_batch_encoder * mini_batch_decoder) if video_length != 1 else 1, height // self.vae_scale_factor, width // self.vae_scale_factor) else: shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor) @@ -579,10 +802,9 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline): ) if return_video_latents or (latents is None and not is_strength_max): - video = video.to(device=device, dtype=self.vae.dtype) - if self.vae.quant_conv.weight.ndim==5: + video = video.to(device=device, dtype=dtype) + if self.vae.quant_conv is None or self.vae.quant_conv.weight.ndim==5: bs = 1 - mini_batch_encoder = self.vae.mini_batch_encoder new_video = [] for i in range(0, video.shape[0], bs): video_bs = video[i : i + bs] @@ -601,16 +823,24 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline): video = self._encode_vae_image(video, generator=generator) video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length) video_latents = video.repeat(batch_size // video.shape[0], 1, 1, 1, 1) + video_latents = video_latents.to(device=device, dtype=dtype) if latents is None: noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) # if strength is 1. then initialise the latents to noise, else initial to image + noise - latents = noise if is_strength_max else self.scheduler.add_noise(video_latents, noise, timestep) + if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler): + latents = noise if is_strength_max else self.scheduler.scale_noise(video_latents, timestep, noise) + else: + latents = noise if is_strength_max else self.scheduler.add_noise(video_latents, noise, timestep) # if pure noise then scale the initial latents by the Scheduler's init sigma - latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents + if hasattr(self.scheduler, "init_noise_sigma"): + latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents else: - noise = latents.to(device) - latents = noise * self.scheduler.init_noise_sigma + if hasattr(self.scheduler, "init_noise_sigma"): + noise = latents.to(device) + latents = noise * self.scheduler.init_noise_sigma + else: + latents = latents.to(device) # scale the initial noise by the standard deviation required by the scheduler outputs = (latents,) @@ -632,22 +862,23 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline): # Encode middle videos latents = self.vae.encode(pixel_values)[0] - latents = latents.sample() + latents = latents.mode() # Decode middle videos middle_video = self.vae.decode(latents)[0] video[:, :, prefix_index_before:-prefix_index_after] = (video[:, :, prefix_index_before:-prefix_index_after] + middle_video) / 2 return video - + def decode_latents(self, latents): video_length = latents.shape[2] latents = 1 / self.vae.config.scaling_factor * latents - if self.vae.quant_conv.weight.ndim==5: + if self.vae.quant_conv is None or self.vae.quant_conv.weight.ndim==5: mini_batch_encoder = self.vae.mini_batch_encoder mini_batch_decoder = self.vae.mini_batch_decoder video = self.vae.decode(latents)[0] video = video.clamp(-1, 1) - video = self.smooth_output(video, mini_batch_encoder, mini_batch_decoder).cpu().clamp(-1, 1) + if not self.vae.cache_compression_vae and not self.vae.cache_mag_vae: + video = self.smooth_output(video, mini_batch_encoder, mini_batch_decoder).cpu().clamp(-1, 1) else: latents = rearrange(latents, "b c f h w -> (b f) c h w") video = [] @@ -660,32 +891,28 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline): video = video.cpu().float().numpy() return video - def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): - if isinstance(generator, list): - image_latents = [ - retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) - for i in range(image.shape[0]) - ] - image_latents = torch.cat(image_latents, dim=0) - else: - image_latents = retrieve_latents(self.vae.encode(image), generator=generator) + @property + def guidance_scale(self): + return self._guidance_scale - image_latents = self.vae.config.scaling_factor * image_latents + @property + def guidance_rescale(self): + return self._guidance_rescale - return image_latents + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps - def get_timesteps(self, num_inference_steps, strength, device): - # get the original timestep using init_timestep - init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + @property + def num_timesteps(self): + return self._num_timesteps - t_start = max(num_inference_steps - init_timestep, 0) - timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] - - return timesteps, num_inference_steps - t_start - - def enable_autocast_float8_transformer(self): - self.enable_autocast_float8_transformer_flag = True + @property + def interrupt(self): + return self._interrupt @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) @@ -696,109 +923,167 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline): video: Union[torch.FloatTensor] = None, mask_video: Union[torch.FloatTensor] = None, masked_video_latents: Union[torch.FloatTensor] = None, - negative_prompt: str = "", - num_inference_steps: int = 20, - timesteps: List[int] = None, - guidance_scale: float = 4.5, - num_images_per_prompt: Optional[int] = 1, height: Optional[int] = None, width: Optional[int] = None, - strength: float = 1.0, - eta: float = 0.0, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: Optional[float] = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - prompt_attention_mask: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_attention_mask: Optional[torch.FloatTensor] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_2: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds_2: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + prompt_attention_mask_2: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_attention_mask_2: Optional[torch.Tensor] = None, output_type: Optional[str] = "latent", return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: int = 1, - clean_caption: bool = True, - mask_feature: bool = True, - max_sequence_length: int = 120, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + guidance_rescale: float = 0.0, + original_size: Optional[Tuple[int, int]] = (1024, 1024), + target_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), clip_image: Image = None, - clip_apply_ratio: float = 0.50, + clip_apply_ratio: float = 0.40, + strength: float = 1.0, + noise_aug_strength: float = 0.0563, comfyui_progressbar: bool = False, - **kwargs, - ) -> Union[EasyAnimatePipelineOutput, Tuple]: - """ - Function invoked when calling the pipeline for generation. + timesteps: Optional[List[int]] = None, + ): + r""" + The call function to the pipeline for generation with HunyuanDiT. - Args: + Examples: prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. - instead. + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + video_length (`int`, *optional*): + Length of the video to be generated in seconds. This parameter influences the number of frames and + continuity of generated content. + video (`torch.FloatTensor`, *optional*): + A tensor representing an input video, which can be modified depending on the prompts provided. + mask_video (`torch.FloatTensor`, *optional*): + A tensor to specify areas of the video to be masked (omitted from generation). + masked_video_latents (`torch.FloatTensor`, *optional*): + Latents from masked portions of the video, utilized during image generation. + height (`int`, *optional*): + The height in pixels of the generated image or video frames. + width (`int`, *optional*): + The width in pixels of the generated image or video frames. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image but slower + inference time. This parameter is modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 5.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is effective when `guidance_scale > 1`. negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is - less than `1`). - num_inference_steps (`int`, *optional*, defaults to 100): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - timesteps (`List[int]`, *optional*): - Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` - timesteps are used. Must be in descending order. - guidance_scale (`float`, *optional*, defaults to 7.0): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. + The prompt or prompts to guide what to exclude in image generation. If not defined, you need to + provide `negative_prompt_embeds`. This parameter is ignored when not using guidance (`guidance_scale < 1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. - height (`int`, *optional*, defaults to self.unet.config.sample_size): - The height in pixels of the generated image. - width (`int`, *optional*, defaults to self.unet.config.sample_size): - The width in pixels of the generated image. eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. + A parameter defined in the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies to the + [`~schedulers.DDIMScheduler`] and is ignored in other schedulers. It adjusts noise level during the + inference process. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. - latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will ge generated by sampling using the supplied random `generator`. - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. For PixArt-Alpha this negative prompt should be "". If not - provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) for setting + random seeds which helps in making generation deterministic. + latents (`torch.Tensor`, *optional*): + A pre-computed latent representation which can be used to guide the generation process. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, embeddings are generated from the `prompt` input argument. + prompt_embeds_2 (`torch.Tensor`, *optional*): + Secondary set of pre-generated text embeddings, useful for advanced prompt weighting. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings, aiding in fine-tuning what should not be represented in the outputs. + If not provided, embeddings are generated from the `negative_prompt` argument. + negative_prompt_embeds_2 (`torch.Tensor`, *optional*): + Secondary set of pre-generated negative text embeddings for further control. + prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask guiding the focus of the model on specific parts of the prompt text. Required when using + `prompt_embeds`. + prompt_attention_mask_2 (`torch.Tensor`, *optional*): + Attention mask for the secondary prompt embedding. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Attention mask for the negative prompt, needed when `negative_prompt_embeds` are used. + negative_prompt_attention_mask_2 (`torch.Tensor`, *optional*): + Attention mask for the secondary negative prompt embedding. + output_type (`str`, *optional*, defaults to `"latent"`): + The output format of the generated image. Choose between `PIL.Image` and `np.array` to define + how you want the results to be formatted. return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - clean_caption (`bool`, *optional*, defaults to `True`): - Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to - be installed. If the dependencies are not installed, the embeddings will be created from the raw - prompt. - mask_feature (`bool` defaults to `True`): If set to `True`, the text embeddings will be masked. + If set to `True`, a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] will be returned; + otherwise, a tuple containing the generated images and safety flags will be returned. + callback_on_step_end (`Callable[[int, int, Dict], None]`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A callback function (or a list of them) that will be executed at the end of each denoising step, + allowing for custom processing during generation. + callback_on_step_end_tensor_inputs (`List[str]`, *optional*): + Specifies which tensor inputs should be included in the callback function. If not defined, all tensor + inputs will be passed, facilitating enhanced logging or monitoring of the generation process. + guidance_rescale (`float`, *optional*, defaults to 0.0): + Rescale parameter for adjusting noise configuration based on guidance rescale. Based on findings from + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + original_size (`Tuple[int, int]`, *optional*, defaults to `(1024, 1024)`): + The original dimensions of the image. Used to compute time ids during the generation process. + target_size (`Tuple[int, int]`, *optional*): + The targeted dimensions of the generated image, also utilized in the time id calculations. + crops_coords_top_left (`Tuple[int, int]`, *optional*, defaults to `(0, 0)`): + Coordinates defining the top left corner of any cropping, utilized while calculating the time ids. + clip_image (`Image`, *optional*): + An optional image to assist in the generation process. It may be used as an additional visual cue. + clip_apply_ratio (`float`, *optional*, defaults to 0.40): + Ratio indicating how much influence the clip image should exert over the generated content. + strength (`float`, *optional*, defaults to 1.0): + Affects the overall styling or quality of the generated output. Values closer to 1 usually provide direct + adherence to prompts. + comfyui_progressbar (`bool`, *optional*, defaults to `False`): + Enables a progress bar in ComfyUI, providing visual feedback during the generation process. Examples: - + # Example usage of the function for generating images based on prompts. + Returns: - [`~pipelines.ImagePipelineOutput`] or `tuple`: - If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is - returned where the first element is a list with the generated images + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + Returns either a structured output containing generated images and their metadata when `return_dict` is + `True`, or a simpler tuple, where the first element is a list of generated images and the second + element indicates if any of them contain "not-safe-for-work" (NSFW) content. """ - # 1. Check inputs. Raise error if not correct - height = height or self.transformer.config.sample_size * self.vae_scale_factor - width = width or self.transformer.config.sample_size * self.vae_scale_factor + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 0. default height and width height = int(height // 16 * 16) width = int(width // 16 * 16) - # 2. Default height and width to transformer + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + prompt_embeds_2, + negative_prompt_embeds_2, + prompt_attention_mask_2, + negative_prompt_attention_mask_2, + callback_on_step_end_tensor_inputs, + ) + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._interrupt = False + + # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): @@ -807,40 +1092,68 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline): batch_size = prompt_embeds.shape[0] device = self._execution_device - - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.text_encoder_2 is not None: + dtype = self.text_encoder_2.dtype + else: + dtype = self.transformer.dtype + # 3. Encode input prompt ( prompt_embeds, - prompt_attention_mask, negative_prompt_embeds, + prompt_attention_mask, negative_prompt_attention_mask, ) = self.encode_prompt( - prompt, - do_classifier_free_guidance, - negative_prompt=negative_prompt, - num_images_per_prompt=num_images_per_prompt, + prompt=prompt, device=device, + dtype=dtype, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, prompt_attention_mask=prompt_attention_mask, negative_prompt_attention_mask=negative_prompt_attention_mask, - clean_caption=clean_caption, - max_sequence_length=max_sequence_length, + text_encoder_index=0, ) - if do_classifier_free_guidance: - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) - prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + if self.tokenizer_2 is not None: + ( + prompt_embeds_2, + negative_prompt_embeds_2, + prompt_attention_mask_2, + negative_prompt_attention_mask_2, + ) = self.encode_prompt( + prompt=prompt, + device=device, + dtype=dtype, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds_2, + negative_prompt_embeds=negative_prompt_embeds_2, + prompt_attention_mask=prompt_attention_mask_2, + negative_prompt_attention_mask=negative_prompt_attention_mask_2, + text_encoder_index=1, + ) + else: + prompt_embeds_2 = None + negative_prompt_embeds_2 = None + prompt_attention_mask_2 = None + negative_prompt_attention_mask_2 = None # 4. set timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) + if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler): + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, mu=1) + else: + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) timesteps, num_inference_steps = self.get_timesteps( num_inference_steps=num_inference_steps, strength=strength, device=device ) + if comfyui_progressbar: + from comfy.utils import ProgressBar + pbar = ProgressBar(num_inference_steps + 3) # at which timestep to set the initial noise (n.b. 50% if strength is 0.5) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise @@ -857,7 +1170,7 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline): # Prepare latent variables num_channels_latents = self.vae.config.latent_channels num_channels_transformer = self.transformer.config.in_channels - return_image_latents = True # num_channels_transformer == 4 + return_image_latents = num_channels_transformer == num_channels_latents # 5. Prepare latents. latents_outputs = self.prepare_latents( @@ -866,7 +1179,7 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline): height, width, video_length, - prompt_embeds.dtype, + dtype, device, generator, latents, @@ -880,91 +1193,153 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline): latents, noise, image_latents = latents_outputs else: latents, noise = latents_outputs - latents_dtype = latents.dtype + if comfyui_progressbar: + pbar.update(1) + + # 6. Prepare clip latents if it needs. + if clip_image is not None and self.transformer.enable_clip_in_inpaint: + inputs = self.clip_image_processor(images=clip_image, return_tensors="pt") + inputs["pixel_values"] = inputs["pixel_values"].to(device, dtype=dtype) + if self.transformer.config.get("position_of_clip_embedding", "full") == "full": + clip_encoder_hidden_states = self.clip_image_encoder(**inputs).last_hidden_state[:, 1:] + clip_encoder_hidden_states_neg = torch.zeros( + [ + batch_size, + int(self.clip_image_encoder.config.image_size / self.clip_image_encoder.config.patch_size) ** 2, + int(self.clip_image_encoder.config.hidden_size) + ] + ).to(device, dtype=dtype) + + else: + clip_encoder_hidden_states = self.clip_image_encoder(**inputs).image_embeds + clip_encoder_hidden_states_neg = torch.zeros([batch_size, 768]).to(device, dtype=dtype) + + clip_attention_mask = torch.ones([batch_size, self.transformer.n_query]).to(device, dtype=dtype) + clip_attention_mask_neg = torch.zeros([batch_size, self.transformer.n_query]).to(device, dtype=dtype) + + clip_encoder_hidden_states_input = torch.cat([clip_encoder_hidden_states_neg, clip_encoder_hidden_states]) if self.do_classifier_free_guidance else clip_encoder_hidden_states + clip_attention_mask_input = torch.cat([clip_attention_mask_neg, clip_attention_mask]) if self.do_classifier_free_guidance else clip_attention_mask + + elif clip_image is None and num_channels_transformer != num_channels_latents and self.transformer.enable_clip_in_inpaint: + if self.transformer.config.get("position_of_clip_embedding", "full") == "full": + clip_encoder_hidden_states = torch.zeros( + [ + batch_size, + int(self.clip_image_encoder.config.image_size / self.clip_image_encoder.config.patch_size) ** 2, + int(self.clip_image_encoder.config.hidden_size) + ] + ).to(device, dtype=dtype) + else: + clip_encoder_hidden_states = torch.zeros([batch_size, 768]).to(device, dtype=dtype) + + clip_attention_mask = torch.zeros([batch_size, self.transformer.n_query]) + clip_attention_mask = clip_attention_mask.to(device, dtype=dtype) + + clip_encoder_hidden_states_input = torch.cat([clip_encoder_hidden_states] * 2) if self.do_classifier_free_guidance else clip_encoder_hidden_states + clip_attention_mask_input = torch.cat([clip_attention_mask] * 2) if self.do_classifier_free_guidance else clip_attention_mask + + else: + clip_encoder_hidden_states_input = None + clip_attention_mask_input = None + if comfyui_progressbar: + pbar.update(1) + + # 7. Prepare inpaint latents if it needs. if mask_video is not None: - # Prepare mask latent variables - video_length = video.shape[2] - mask_condition = self.mask_processor.preprocess(rearrange(mask_video, "b c f h w -> (b f) c h w"), height=height, width=width) - mask_condition = mask_condition.to(dtype=torch.float32) - mask_condition = rearrange(mask_condition, "(b f) c h w -> b c f h w", f=video_length) - - if num_channels_transformer == 12: - mask_condition_tile = torch.tile(mask_condition, [1, 3, 1, 1, 1]) - if masked_video_latents is None: - masked_video = init_video * (mask_condition_tile < 0.5) + torch.ones_like(init_video) * (mask_condition_tile > 0.5) * -1 + if self.transformer.config.get("enable_zero_in_inpaint", True) and (mask_video == 255).all(): + # Use zero latents if we want to t2v. + mask = torch.zeros_like(latents).to(device, dtype) + if self.transformer.resize_inpaint_mask_directly: + mask_latents = torch.zeros_like(latents)[:, :1].to(device, dtype) else: - masked_video = masked_video_latents - - mask_latents, masked_video_latents = self.prepare_mask_latents( - mask_condition_tile, - masked_video, - batch_size, - height, - width, - prompt_embeds.dtype, - device, - generator, - do_classifier_free_guidance, - ) - mask = torch.tile(mask_condition, [1, num_channels_transformer // 3, 1, 1, 1]) - mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype) - - mask_input = torch.cat([mask_latents] * 2) if do_classifier_free_guidance else mask_latents + mask_latents = torch.zeros_like(latents).to(device, dtype) + masked_video_latents = torch.zeros_like(latents).to(device, dtype) + + mask_input = torch.cat([mask_latents] * 2) if self.do_classifier_free_guidance else mask_latents masked_video_latents_input = ( - torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents + torch.cat([masked_video_latents] * 2) if self.do_classifier_free_guidance else masked_video_latents ) - inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=1).to(latents.dtype) + inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=1).to(dtype) else: - mask = torch.tile(mask_condition, [1, num_channels_transformer, 1, 1, 1]) - mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype) - - inpaint_latents = None + # Prepare mask latent variables + video_length = video.shape[2] + mask_condition = self.mask_processor.preprocess(rearrange(mask_video, "b c f h w -> (b f) c h w"), height=height, width=width) + mask_condition = mask_condition.to(dtype=torch.float32) + mask_condition = rearrange(mask_condition, "(b f) c h w -> b c f h w", f=video_length) + + if num_channels_transformer != num_channels_latents: + mask_condition_tile = torch.tile(mask_condition, [1, 3, 1, 1, 1]) + if masked_video_latents is None: + masked_video = init_video * (mask_condition_tile < 0.5) + torch.ones_like(init_video) * (mask_condition_tile > 0.5) * -1 + else: + masked_video = masked_video_latents + + if self.transformer.resize_inpaint_mask_directly: + _, masked_video_latents = self.prepare_mask_latents( + None, + masked_video, + batch_size, + height, + width, + dtype, + device, + generator, + self.do_classifier_free_guidance, + noise_aug_strength=noise_aug_strength, + ) + mask_latents = resize_mask(1 - mask_condition, masked_video_latents, self.vae.cache_mag_vae) + mask_latents = mask_latents.to(device, dtype) * self.vae.config.scaling_factor + else: + mask_latents, masked_video_latents = self.prepare_mask_latents( + mask_condition_tile, + masked_video, + batch_size, + height, + width, + dtype, + device, + generator, + self.do_classifier_free_guidance, + noise_aug_strength=noise_aug_strength, + ) + + mask_input = torch.cat([mask_latents] * 2) if self.do_classifier_free_guidance else mask_latents + masked_video_latents_input = ( + torch.cat([masked_video_latents] * 2) if self.do_classifier_free_guidance else masked_video_latents + ) + inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=1).to(dtype) + else: + inpaint_latents = None + + mask = torch.tile(mask_condition, [1, num_channels_latents, 1, 1, 1]) + mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(device, dtype) else: - if num_channels_transformer == 12: - mask = torch.zeros_like(latents).to(latents.device, latents.dtype) - masked_video_latents = torch.zeros_like(latents).to(latents.device, latents.dtype) + if num_channels_transformer != num_channels_latents: + mask = torch.zeros_like(latents).to(device, dtype) + if self.transformer.resize_inpaint_mask_directly: + mask_latents = torch.zeros_like(latents)[:, :1].to(device, dtype) + else: + mask_latents = torch.zeros_like(latents).to(device, dtype) + masked_video_latents = torch.zeros_like(latents).to(device, dtype) - mask_input = torch.cat([mask] * 2) if do_classifier_free_guidance else mask + mask_input = torch.cat([mask_latents] * 2) if self.do_classifier_free_guidance else mask_latents masked_video_latents_input = ( - torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents + torch.cat([masked_video_latents] * 2) if self.do_classifier_free_guidance else masked_video_latents ) - inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=1).to(latents.dtype) + inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=1).to(dtype) else: mask = torch.zeros_like(init_video[:, :1]) - mask = torch.tile(mask, [1, num_channels_transformer, 1, 1, 1]) - mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype) + mask = torch.tile(mask, [1, num_channels_latents, 1, 1, 1]) + mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(device, dtype) inpaint_latents = None - - if clip_image is not None: - inputs = self.clip_image_processor(images=clip_image, return_tensors="pt") - inputs["pixel_values"] = inputs["pixel_values"].to(latents.device, dtype=latents.dtype) - clip_encoder_hidden_states = self.clip_image_encoder(**inputs).image_embeds - clip_encoder_hidden_states_neg = torch.zeros([batch_size, 768]).to(latents.device, dtype=latents.dtype) - - clip_attention_mask = torch.ones([batch_size, 8]).to(latents.device, dtype=latents.dtype) - clip_attention_mask_neg = torch.zeros([batch_size, 8]).to(latents.device, dtype=latents.dtype) - clip_encoder_hidden_states_input = torch.cat([clip_encoder_hidden_states_neg, clip_encoder_hidden_states]) if do_classifier_free_guidance else clip_encoder_hidden_states - clip_attention_mask_input = torch.cat([clip_attention_mask_neg, clip_attention_mask]) if do_classifier_free_guidance else clip_attention_mask - - elif clip_image is None and num_channels_transformer == 12: - clip_encoder_hidden_states = torch.zeros([batch_size, 768]).to(latents.device, dtype=latents.dtype) - - clip_attention_mask = torch.zeros([batch_size, 8]) - clip_attention_mask = clip_attention_mask.to(latents.device, dtype=latents.dtype) - - clip_encoder_hidden_states_input = torch.cat([clip_encoder_hidden_states] * 2) if do_classifier_free_guidance else clip_encoder_hidden_states - clip_attention_mask_input = torch.cat([clip_attention_mask] * 2) if do_classifier_free_guidance else clip_attention_mask - - else: - clip_encoder_hidden_states_input = None - clip_attention_mask_input = None + if comfyui_progressbar: + pbar.update(1) # Check that sizes of mask, masked image and latents match - if num_channels_transformer == 12: - # default case for runwayml/stable-diffusion-inpainting + if num_channels_transformer != num_channels_latents: num_channels_mask = mask_latents.shape[1] num_channels_masked_image = masked_video_latents.shape[1] if num_channels_latents + num_channels_mask + num_channels_masked_image != self.transformer.config.in_channels: @@ -975,45 +1350,89 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline): f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" " `pipeline.transformer` or your `mask_image` or `image` input." ) - elif num_channels_transformer != 4: - raise ValueError( - f"The transformer {self.transformer.__class__} should have 9 input channels, not {self.transformer.config.in_channels}." - ) - - # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + + # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - # 6.1 Prepare micro-conditions. + # 9 create image_rotary_emb, style embedding & time ids + grid_height = height // 8 // self.transformer.config.patch_size + grid_width = width // 8 // self.transformer.config.patch_size + if self.transformer.config.get("time_position_encoding_type", "2d_rope") == "3d_rope": + base_size_width = 720 // 8 // self.transformer.config.patch_size + base_size_height = 480 // 8 // self.transformer.config.patch_size + + grid_crops_coords = get_resize_crop_region_for_grid( + (grid_height, grid_width), base_size_width, base_size_height + ) + image_rotary_emb = get_3d_rotary_pos_embed( + self.transformer.config.attention_head_dim, grid_crops_coords, grid_size=(grid_height, grid_width), + temporal_size=latents.size(2), use_real=True, + ) + else: + base_size = 512 // 8 // self.transformer.config.patch_size + grid_crops_coords = get_resize_crop_region_for_grid( + (grid_height, grid_width), base_size, base_size + ) + image_rotary_emb = get_2d_rotary_pos_embed( + self.transformer.config.attention_head_dim, grid_crops_coords, (grid_height, grid_width) + ) + + # Get other hunyuan params + target_size = target_size or (height, width) + add_time_ids = list(original_size + target_size + crops_coords_top_left) + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + style = torch.tensor([0], device=device) + + if self.do_classifier_free_guidance: + add_time_ids = torch.cat([add_time_ids] * 2, dim=0) + style = torch.cat([style] * 2, dim=0) + + # To latents.device + add_time_ids = add_time_ids.to(dtype=dtype, device=device).repeat( + batch_size * num_images_per_prompt, 1 + ) + style = style.to(device=device).repeat(batch_size * num_images_per_prompt) + + # Get other pixart params added_cond_kwargs = {"resolution": None, "aspect_ratio": None} - if self.transformer.config.sample_size == 128: + if self.transformer.config.get("sample_size", 64) == 128: resolution = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1) aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1) - resolution = resolution.to(dtype=prompt_embeds.dtype, device=device) - aspect_ratio = aspect_ratio.to(dtype=prompt_embeds.dtype, device=device) + resolution = resolution.to(dtype=dtype, device=device) + aspect_ratio = aspect_ratio.to(dtype=dtype, device=device) - if do_classifier_free_guidance: + if self.do_classifier_free_guidance: resolution = torch.cat([resolution, resolution], dim=0) aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0) - + added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio} - gc.collect() - torch.cuda.empty_cache() - torch.cuda.ipc_collect() - if self.enable_autocast_float8_transformer_flag: - origin_weight_dtype = self.transformer.dtype - self.transformer = self.transformer.to(torch.float8_e4m3fn) + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask]) + if prompt_embeds_2 is not None: + prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2]) + prompt_attention_mask_2 = torch.cat([negative_prompt_attention_mask_2, prompt_attention_mask_2]) + + # To latents.device + prompt_embeds = prompt_embeds.to(device=device) + prompt_attention_mask = prompt_attention_mask.to(device=device) + if prompt_embeds_2 is not None: + prompt_embeds_2 = prompt_embeds_2.to(device=device) + prompt_attention_mask_2 = prompt_attention_mask_2.to(device=device) # 10. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(timesteps) - if comfyui_progressbar: - from comfy.utils import ProgressBar - pbar = ProgressBar(num_inference_steps) with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + if hasattr(self.scheduler, "scale_model_input"): + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) if i < len(timesteps) * (1 - clip_apply_ratio) and clip_encoder_hidden_states_input is not None: clip_encoder_hidden_states_actual_input = torch.zeros_like(clip_encoder_hidden_states_input) @@ -1021,74 +1440,83 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline): else: clip_encoder_hidden_states_actual_input = clip_encoder_hidden_states_input clip_attention_mask_actual_input = clip_attention_mask_input + + # expand scalar t to 1-D tensor to match the 1st dim of latent_model_input + t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to( + dtype=latent_model_input.dtype + ) - current_timestep = t - if not torch.is_tensor(current_timestep): - # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can - # This would be a good case for the `match` statement (Python 3.10+) - is_mps = latent_model_input.device.type == "mps" - if isinstance(current_timestep, float): - dtype = torch.float32 if is_mps else torch.float64 - else: - dtype = torch.int32 if is_mps else torch.int64 - current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device) - elif len(current_timestep.shape) == 0: - current_timestep = current_timestep[None].to(latent_model_input.device) - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - current_timestep = current_timestep.expand(latent_model_input.shape[0]) - - # predict noise model_output + # predict the noise residual noise_pred = self.transformer( latent_model_input, + t_expand, encoder_hidden_states=prompt_embeds, - encoder_attention_mask=prompt_attention_mask, - timestep=current_timestep, - added_cond_kwargs=added_cond_kwargs, + text_embedding_mask=prompt_attention_mask, + encoder_hidden_states_t5=prompt_embeds_2, + text_embedding_mask_t5=prompt_attention_mask_2, + image_meta_size=add_time_ids, + style=style, + image_rotary_emb=image_rotary_emb, inpaint_latents=inpaint_latents, clip_encoder_hidden_states=clip_encoder_hidden_states_actual_input, clip_attention_mask=clip_attention_mask_actual_input, + added_cond_kwargs=added_cond_kwargs, return_dict=False, )[0] + if noise_pred.size()[1] != self.vae.config.latent_channels: + noise_pred, _ = noise_pred.chunk(2, dim=1) # perform guidance - if do_classifier_free_guidance: + if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - # learned sigma - noise_pred = noise_pred.chunk(2, dim=1)[0] + if self.do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) - # compute previous image: x_t -> x_t-1 + # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] - if num_channels_transformer == 4: + if num_channels_transformer == num_channels_latents: init_latents_proper = image_latents init_mask = mask if i < len(timesteps) - 1: noise_timestep = timesteps[i + 1] - init_latents_proper = self.scheduler.add_noise( - init_latents_proper, noise, torch.tensor([noise_timestep]) - ) - + if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler): + init_latents_proper = self.scheduler.scale_noise( + init_latents_proper, torch.tensor([noise_timestep], noise) + ) + else: + init_latents_proper = self.scheduler.add_noise( + init_latents_proper, noise, torch.tensor([noise_timestep]) + ) + latents = (1 - init_mask) * init_latents_proper + init_mask * latents - # call the callback, if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + prompt_embeds_2 = callback_outputs.pop("prompt_embeds_2", prompt_embeds_2) + negative_prompt_embeds_2 = callback_outputs.pop( + "negative_prompt_embeds_2", negative_prompt_embeds_2 + ) + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() - if callback is not None and i % callback_steps == 0: - step_idx = i // getattr(self.scheduler, "order", 1) - callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() if comfyui_progressbar: pbar.update(1) - if self.enable_autocast_float8_transformer_flag: - self.transformer = self.transformer.to("cpu", origin_weight_dtype) - - gc.collect() - torch.cuda.empty_cache() - torch.cuda.ipc_collect() - # Post-processing video = self.decode_latents(latents) @@ -1096,7 +1524,10 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline): if output_type == "latent": video = torch.from_numpy(video) + # Offload all models + self.maybe_free_model_hooks() + if not return_dict: return video - return EasyAnimatePipelineOutput(videos=video) \ No newline at end of file + return EasyAnimatePipelineOutput(frames=video) \ No newline at end of file diff --git a/easyanimate/pipeline/pipeline_easyanimate_multi_text_encoder.py b/easyanimate/pipeline/pipeline_easyanimate_multi_text_encoder.py deleted file mode 100644 index 8adde44bfb5b01b12b9e08c0ef567253e89ce5d5..0000000000000000000000000000000000000000 --- a/easyanimate/pipeline/pipeline_easyanimate_multi_text_encoder.py +++ /dev/null @@ -1,925 +0,0 @@ -# Copyright 2024 EasyAnimate Authors and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -from typing import Callable, Dict, List, Optional, Tuple, Union - -import numpy as np -import torch -from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback -from diffusers.image_processor import VaeImageProcessor -from diffusers.models.embeddings import (get_2d_rotary_pos_embed, - get_3d_rotary_pos_embed) -from diffusers.pipelines.pipeline_utils import DiffusionPipeline -from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput -from diffusers.pipelines.stable_diffusion.safety_checker import \ - StableDiffusionSafetyChecker -from diffusers.schedulers import DDIMScheduler -from diffusers.utils import (is_torch_xla_available, logging, - replace_example_docstring) -from diffusers.utils.torch_utils import randn_tensor -from einops import rearrange -from tqdm import tqdm -from transformers import (BertModel, BertTokenizer, CLIPImageProcessor, - T5Tokenizer, T5EncoderModel) - -from .pipeline_easyanimate import EasyAnimatePipelineOutput -from ..models import AutoencoderKLMagvit, EasyAnimateTransformer3DModel - -if is_torch_xla_available(): - import torch_xla.core.xla_model as xm - - XLA_AVAILABLE = True -else: - XLA_AVAILABLE = False - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - -EXAMPLE_DOC_STRING = """ - Examples: - ```py - >>> pass - ``` -""" - - -def get_resize_crop_region_for_grid(src, tgt_width, tgt_height): - tw = tgt_width - th = tgt_height - h, w = src - r = h / w - if r > (th / tw): - resize_height = th - resize_width = int(round(th / h * w)) - else: - resize_width = tw - resize_height = int(round(tw / w * h)) - - crop_top = int(round((th - resize_height) / 2.0)) - crop_left = int(round((tw - resize_width) / 2.0)) - - return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) - - -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg -def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): - """ - Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and - Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 - """ - std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) - std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) - # rescale the results from guidance (fixes overexposure) - noise_pred_rescaled = noise_cfg * (std_text / std_cfg) - # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images - noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg - return noise_cfg - - -class EasyAnimatePipeline_Multi_Text_Encoder(DiffusionPipeline): - r""" - Pipeline for text-to-video generation using EasyAnimate. - - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - - EasyAnimate uses two text encoders: [mT5](https://huggingface.co/google/mt5-base) and [bilingual CLIP](fine-tuned by - HunyuanDiT team) - - Args: - vae ([`AutoencoderKLMagvit`]): - Variational Auto-Encoder (VAE) Model to encode and decode video to and from latent representations. - text_encoder (Optional[`~transformers.BertModel`, `~transformers.CLIPTextModel`]): - Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). - EasyAnimate uses a fine-tuned [bilingual CLIP]. - tokenizer (Optional[`~transformers.BertTokenizer`, `~transformers.CLIPTokenizer`]): - A `BertTokenizer` or `CLIPTokenizer` to tokenize text. - transformer ([`EasyAnimateTransformer3DModel`]): - The EasyAnimate model designed by Tencent Hunyuan. - text_encoder_2 (`T5EncoderModel`): - The mT5 embedder. - tokenizer_2 (`T5Tokenizer`): - The tokenizer for the mT5 embedder. - scheduler ([`DDIMScheduler`]): - A scheduler to be used in combination with EasyAnimate to denoise the encoded image latents. - """ - - model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" - _optional_components = [ - "safety_checker", - "feature_extractor", - "text_encoder_2", - "tokenizer_2", - "text_encoder", - "tokenizer", - ] - _exclude_from_cpu_offload = ["safety_checker"] - _callback_tensor_inputs = [ - "latents", - "prompt_embeds", - "negative_prompt_embeds", - "prompt_embeds_2", - "negative_prompt_embeds_2", - ] - - def __init__( - self, - vae: AutoencoderKLMagvit, - text_encoder: BertModel, - tokenizer: BertTokenizer, - text_encoder_2: T5EncoderModel, - tokenizer_2: T5Tokenizer, - transformer: EasyAnimateTransformer3DModel, - scheduler: DDIMScheduler, - safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPImageProcessor, - requires_safety_checker: bool = True, - ): - super().__init__() - - self.register_modules( - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - tokenizer_2=tokenizer_2, - transformer=transformer, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - text_encoder_2=text_encoder_2, - ) - - if safety_checker is None and requires_safety_checker: - logger.warning( - f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" - " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" - " results in services or applications open to the public. Both the diffusers team and Hugging Face" - " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" - " it only for use-cases that involve analyzing network behavior or auditing its results. For more" - " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." - ) - - if safety_checker is not None and feature_extractor is None: - raise ValueError( - "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" - " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." - ) - - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) - self.enable_autocast_float8_transformer_flag = False - self.register_to_config(requires_safety_checker=requires_safety_checker) - - def enable_sequential_cpu_offload(self, *args, **kwargs): - super().enable_sequential_cpu_offload(*args, **kwargs) - if hasattr(self.transformer, "clip_projection") and self.transformer.clip_projection is not None: - import accelerate - accelerate.hooks.remove_hook_from_module(self.transformer.clip_projection, recurse=True) - self.transformer.clip_projection = self.transformer.clip_projection.to("cuda") - - def encode_prompt( - self, - prompt: str, - device: torch.device, - dtype: torch.dtype, - num_images_per_prompt: int = 1, - do_classifier_free_guidance: bool = True, - negative_prompt: Optional[str] = None, - prompt_embeds: Optional[torch.Tensor] = None, - negative_prompt_embeds: Optional[torch.Tensor] = None, - prompt_attention_mask: Optional[torch.Tensor] = None, - negative_prompt_attention_mask: Optional[torch.Tensor] = None, - max_sequence_length: Optional[int] = None, - text_encoder_index: int = 0, - actual_max_sequence_length: int = 256 - ): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - device: (`torch.device`): - torch device - dtype (`torch.dtype`): - torch dtype - num_images_per_prompt (`int`): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is - less than `1`). - prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - prompt_attention_mask (`torch.Tensor`, *optional*): - Attention mask for the prompt. Required when `prompt_embeds` is passed directly. - negative_prompt_attention_mask (`torch.Tensor`, *optional*): - Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly. - max_sequence_length (`int`, *optional*): maximum sequence length to use for the prompt. - text_encoder_index (`int`, *optional*): - Index of the text encoder to use. `0` for clip and `1` for T5. - """ - tokenizers = [self.tokenizer, self.tokenizer_2] - text_encoders = [self.text_encoder, self.text_encoder_2] - - tokenizer = tokenizers[text_encoder_index] - text_encoder = text_encoders[text_encoder_index] - - if max_sequence_length is None: - if text_encoder_index == 0: - max_length = min(self.tokenizer.model_max_length, actual_max_sequence_length) - if text_encoder_index == 1: - max_length = min(self.tokenizer_2.model_max_length, actual_max_sequence_length) - else: - max_length = max_sequence_length - - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - if prompt_embeds is None: - text_inputs = tokenizer( - prompt, - padding="max_length", - max_length=max_length, - truncation=True, - return_attention_mask=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - if text_input_ids.shape[-1] > actual_max_sequence_length: - reprompt = tokenizer.batch_decode(text_input_ids[:, :actual_max_sequence_length], skip_special_tokens=True) - text_inputs = tokenizer( - reprompt, - padding="max_length", - max_length=max_length, - truncation=True, - return_attention_mask=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): - _actual_max_sequence_length = min(tokenizer.model_max_length, actual_max_sequence_length) - removed_text = tokenizer.batch_decode(untruncated_ids[:, _actual_max_sequence_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {_actual_max_sequence_length} tokens: {removed_text}" - ) - prompt_attention_mask = text_inputs.attention_mask.to(device) - - if self.transformer.config.enable_text_attention_mask: - prompt_embeds = text_encoder( - text_input_ids.to(device), - attention_mask=prompt_attention_mask, - ) - else: - prompt_embeds = text_encoder( - text_input_ids.to(device) - ) - prompt_embeds = prompt_embeds[0] - prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) - - prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) - - bs_embed, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance and negative_prompt_embeds is None: - uncond_tokens: List[str] - if negative_prompt is None: - uncond_tokens = [""] * batch_size - elif prompt is not None and type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt] - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - else: - uncond_tokens = negative_prompt - - max_length = prompt_embeds.shape[1] - uncond_input = tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) - uncond_input_ids = uncond_input.input_ids - if uncond_input_ids.shape[-1] > actual_max_sequence_length: - reuncond_tokens = tokenizer.batch_decode(uncond_input_ids[:, :actual_max_sequence_length], skip_special_tokens=True) - uncond_input = tokenizer( - reuncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_attention_mask=True, - return_tensors="pt", - ) - uncond_input_ids = uncond_input.input_ids - - negative_prompt_attention_mask = uncond_input.attention_mask.to(device) - if self.transformer.config.enable_text_attention_mask: - negative_prompt_embeds = text_encoder( - uncond_input.input_ids.to(device), - attention_mask=negative_prompt_attention_mask, - ) - else: - negative_prompt_embeds = text_encoder( - uncond_input.input_ids.to(device) - ) - negative_prompt_embeds = negative_prompt_embeds[0] - negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) - - if do_classifier_free_guidance: - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = negative_prompt_embeds.shape[1] - - negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) - - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker - def run_safety_checker(self, image, device, dtype): - if self.safety_checker is None: - has_nsfw_concept = None - else: - if torch.is_tensor(image): - feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") - else: - feature_extractor_input = self.image_processor.numpy_to_pil(image) - safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) - image, has_nsfw_concept = self.safety_checker( - images=image, clip_input=safety_checker_input.pixel_values.to(dtype) - ) - return image, has_nsfw_concept - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs - def prepare_extra_step_kwargs(self, generator, eta): - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. - # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) - if accepts_generator: - extra_step_kwargs["generator"] = generator - return extra_step_kwargs - - def check_inputs( - self, - prompt, - height, - width, - negative_prompt=None, - prompt_embeds=None, - negative_prompt_embeds=None, - prompt_attention_mask=None, - negative_prompt_attention_mask=None, - prompt_embeds_2=None, - negative_prompt_embeds_2=None, - prompt_attention_mask_2=None, - negative_prompt_attention_mask_2=None, - callback_on_step_end_tensor_inputs=None, - ): - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - - if callback_on_step_end_tensor_inputs is not None and not all( - k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs - ): - raise ValueError( - f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" - ) - - if prompt is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt is None and prompt_embeds is None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." - ) - elif prompt is None and prompt_embeds_2 is None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined." - ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if prompt_embeds is not None and prompt_attention_mask is None: - raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") - - if prompt_embeds_2 is not None and prompt_attention_mask_2 is None: - raise ValueError("Must provide `prompt_attention_mask_2` when specifying `prompt_embeds_2`.") - - if negative_prompt is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) - - if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: - raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") - - if negative_prompt_embeds_2 is not None and negative_prompt_attention_mask_2 is None: - raise ValueError( - "Must provide `negative_prompt_attention_mask_2` when specifying `negative_prompt_embeds_2`." - ) - if prompt_embeds is not None and negative_prompt_embeds is not None: - if prompt_embeds.shape != negative_prompt_embeds.shape: - raise ValueError( - "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" - f" {negative_prompt_embeds.shape}." - ) - if prompt_embeds_2 is not None and negative_prompt_embeds_2 is not None: - if prompt_embeds_2.shape != negative_prompt_embeds_2.shape: - raise ValueError( - "`prompt_embeds_2` and `negative_prompt_embeds_2` must have the same shape when passed directly, but" - f" got: `prompt_embeds_2` {prompt_embeds_2.shape} != `negative_prompt_embeds_2`" - f" {negative_prompt_embeds_2.shape}." - ) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents - def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None): - if self.vae.quant_conv is None or self.vae.quant_conv.weight.ndim==5: - if self.vae.cache_mag_vae: - mini_batch_encoder = self.vae.mini_batch_encoder - mini_batch_decoder = self.vae.mini_batch_decoder - shape = (batch_size, num_channels_latents, int((video_length - 1) // mini_batch_encoder * mini_batch_decoder + 1) if video_length != 1 else 1, height // self.vae_scale_factor, width // self.vae_scale_factor) - else: - mini_batch_encoder = self.vae.mini_batch_encoder - mini_batch_decoder = self.vae.mini_batch_decoder - shape = (batch_size, num_channels_latents, int(video_length // mini_batch_encoder * mini_batch_decoder) if video_length != 1 else 1, height // self.vae_scale_factor, width // self.vae_scale_factor) - else: - shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor) - - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - - if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - else: - latents = latents.to(device) - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma - return latents - - def smooth_output(self, video, mini_batch_encoder, mini_batch_decoder): - if video.size()[2] <= mini_batch_encoder: - return video - prefix_index_before = mini_batch_encoder // 2 - prefix_index_after = mini_batch_encoder - prefix_index_before - pixel_values = video[:, :, prefix_index_before:-prefix_index_after] - - # Encode middle videos - latents = self.vae.encode(pixel_values)[0] - latents = latents.mode() - # Decode middle videos - middle_video = self.vae.decode(latents)[0] - - video[:, :, prefix_index_before:-prefix_index_after] = (video[:, :, prefix_index_before:-prefix_index_after] + middle_video) / 2 - return video - - def decode_latents(self, latents): - video_length = latents.shape[2] - latents = 1 / self.vae.config.scaling_factor * latents - if self.vae.quant_conv is None or self.vae.quant_conv.weight.ndim==5: - mini_batch_encoder = self.vae.mini_batch_encoder - mini_batch_decoder = self.vae.mini_batch_decoder - video = self.vae.decode(latents)[0] - video = video.clamp(-1, 1) - if not self.vae.cache_compression_vae and not self.vae.cache_mag_vae: - video = self.smooth_output(video, mini_batch_encoder, mini_batch_decoder).cpu().clamp(-1, 1) - else: - latents = rearrange(latents, "b c f h w -> (b f) c h w") - video = [] - for frame_idx in tqdm(range(latents.shape[0])): - video.append(self.vae.decode(latents[frame_idx:frame_idx+1]).sample) - video = torch.cat(video) - video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length) - video = (video / 2 + 0.5).clamp(0, 1) - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 - video = video.cpu().float().numpy() - return video - - @property - def guidance_scale(self): - return self._guidance_scale - - @property - def guidance_rescale(self): - return self._guidance_rescale - - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - @property - def do_classifier_free_guidance(self): - return self._guidance_scale > 1 - - @property - def num_timesteps(self): - return self._num_timesteps - - @property - def interrupt(self): - return self._interrupt - - def enable_autocast_float8_transformer(self): - self.enable_autocast_float8_transformer_flag = True - - @torch.no_grad() - @replace_example_docstring(EXAMPLE_DOC_STRING) - def __call__( - self, - prompt: Union[str, List[str]] = None, - video_length: Optional[int] = None, - height: Optional[int] = None, - width: Optional[int] = None, - num_inference_steps: Optional[int] = 50, - guidance_scale: Optional[float] = 5.0, - negative_prompt: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: Optional[int] = 1, - eta: Optional[float] = 0.0, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.Tensor] = None, - prompt_embeds: Optional[torch.Tensor] = None, - prompt_embeds_2: Optional[torch.Tensor] = None, - negative_prompt_embeds: Optional[torch.Tensor] = None, - negative_prompt_embeds_2: Optional[torch.Tensor] = None, - prompt_attention_mask: Optional[torch.Tensor] = None, - prompt_attention_mask_2: Optional[torch.Tensor] = None, - negative_prompt_attention_mask: Optional[torch.Tensor] = None, - negative_prompt_attention_mask_2: Optional[torch.Tensor] = None, - output_type: Optional[str] = "latent", - return_dict: bool = True, - callback_on_step_end: Optional[ - Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] - ] = None, - callback_on_step_end_tensor_inputs: List[str] = ["latents"], - guidance_rescale: float = 0.0, - original_size: Optional[Tuple[int, int]] = (1024, 1024), - target_size: Optional[Tuple[int, int]] = None, - crops_coords_top_left: Tuple[int, int] = (0, 0), - comfyui_progressbar: bool = False, - ): - r""" - Generates images or video using the EasyAnimate pipeline based on the provided prompts. - - Examples: - prompt (`str` or `List[str]`, *optional*): - Text prompts to guide the image or video generation. If not provided, use `prompt_embeds` instead. - video_length (`int`, *optional*): - Length of the generated video (in frames). - height (`int`, *optional*): - Height of the generated image in pixels. - width (`int`, *optional*): - Width of the generated image in pixels. - num_inference_steps (`int`, *optional*, defaults to 50): - Number of denoising steps during generation. More steps generally yield higher quality images but slow down inference. - guidance_scale (`float`, *optional*, defaults to 5.0): - Encourages the model to align outputs with prompts. A higher value may decrease image quality. - negative_prompt (`str` or `List[str]`, *optional*): - Prompts indicating what to exclude in generation. If not specified, use `negative_prompt_embeds`. - num_images_per_prompt (`int`, *optional*, defaults to 1): - Number of images to generate for each prompt. - eta (`float`, *optional*, defaults to 0.0): - Applies to DDIM scheduling. Controlled by the eta parameter from the related literature. - generator (`torch.Generator` or `List[torch.Generator]`, *optional*): - A generator to ensure reproducibility in image generation. - latents (`torch.Tensor`, *optional*): - Predefined latent tensors to condition generation. - prompt_embeds (`torch.Tensor`, *optional*): - Text embeddings for the prompts. Overrides prompt string inputs for more flexibility. - prompt_embeds_2 (`torch.Tensor`, *optional*): - Secondary text embeddings to supplement or replace the initial prompt embeddings. - negative_prompt_embeds (`torch.Tensor`, *optional*): - Embeddings for negative prompts. Overrides string inputs if defined. - negative_prompt_embeds_2 (`torch.Tensor`, *optional*): - Secondary embeddings for negative prompts, similar to `negative_prompt_embeds`. - prompt_attention_mask (`torch.Tensor`, *optional*): - Attention mask for the primary prompt embeddings. - prompt_attention_mask_2 (`torch.Tensor`, *optional*): - Attention mask for the secondary prompt embeddings. - negative_prompt_attention_mask (`torch.Tensor`, *optional*): - Attention mask for negative prompt embeddings. - negative_prompt_attention_mask_2 (`torch.Tensor`, *optional*): - Attention mask for secondary negative prompt embeddings. - output_type (`str`, *optional*, defaults to "latent"): - Format of the generated output, either as a PIL image or as a NumPy array. - return_dict (`bool`, *optional*, defaults to `True`): - If `True`, returns a structured output. Otherwise returns a simple tuple. - callback_on_step_end (`Callable`, *optional*): - Functions called at the end of each denoising step. - callback_on_step_end_tensor_inputs (`List[str]`, *optional*): - Tensor names to be included in callback function calls. - guidance_rescale (`float`, *optional*, defaults to 0.0): - Adjusts noise levels based on guidance scale. - original_size (`Tuple[int, int]`, *optional*, defaults to `(1024, 1024)`): - Original dimensions of the output. - target_size (`Tuple[int, int]`, *optional*): - Desired output dimensions for calculations. - crops_coords_top_left (`Tuple[int, int]`, *optional*, defaults to `(0, 0)`): - Coordinates for cropping. - - Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, - otherwise a `tuple` is returned where the first element is a list with the generated images and the - second element is a list of `bool`s indicating whether the corresponding generated image contains - "not-safe-for-work" (nsfw) content. - """ - - if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): - callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs - - # 0. default height and width - height = int((height // 16) * 16) - width = int((width // 16) * 16) - - # 1. Check inputs. Raise error if not correct - self.check_inputs( - prompt, - height, - width, - negative_prompt, - prompt_embeds, - negative_prompt_embeds, - prompt_attention_mask, - negative_prompt_attention_mask, - prompt_embeds_2, - negative_prompt_embeds_2, - prompt_attention_mask_2, - negative_prompt_attention_mask_2, - callback_on_step_end_tensor_inputs, - ) - self._guidance_scale = guidance_scale - self._guidance_rescale = guidance_rescale - self._interrupt = False - - # 2. Define call parameters - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - device = self._execution_device - - # 3. Encode input prompt - ( - prompt_embeds, - negative_prompt_embeds, - prompt_attention_mask, - negative_prompt_attention_mask, - ) = self.encode_prompt( - prompt=prompt, - device=device, - dtype=self.transformer.dtype, - num_images_per_prompt=num_images_per_prompt, - do_classifier_free_guidance=self.do_classifier_free_guidance, - negative_prompt=negative_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - prompt_attention_mask=prompt_attention_mask, - negative_prompt_attention_mask=negative_prompt_attention_mask, - text_encoder_index=0, - ) - ( - prompt_embeds_2, - negative_prompt_embeds_2, - prompt_attention_mask_2, - negative_prompt_attention_mask_2, - ) = self.encode_prompt( - prompt=prompt, - device=device, - dtype=self.transformer.dtype, - num_images_per_prompt=num_images_per_prompt, - do_classifier_free_guidance=self.do_classifier_free_guidance, - negative_prompt=negative_prompt, - prompt_embeds=prompt_embeds_2, - negative_prompt_embeds=negative_prompt_embeds_2, - prompt_attention_mask=prompt_attention_mask_2, - negative_prompt_attention_mask=negative_prompt_attention_mask_2, - text_encoder_index=1, - ) - torch.cuda.empty_cache() - - # 4. Prepare timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.scheduler.timesteps - if comfyui_progressbar: - from comfy.utils import ProgressBar - pbar = ProgressBar(num_inference_steps + 1) - - # 5. Prepare latent variables - num_channels_latents = self.transformer.config.in_channels - latents = self.prepare_latents( - batch_size * num_images_per_prompt, - num_channels_latents, - video_length, - height, - width, - prompt_embeds.dtype, - device, - generator, - latents, - ) - if comfyui_progressbar: - pbar.update(1) - - # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - - # 7 create image_rotary_emb, style embedding & time ids - grid_height = height // 8 // self.transformer.config.patch_size - grid_width = width // 8 // self.transformer.config.patch_size - if self.transformer.config.get("time_position_encoding_type", "2d_rope") == "3d_rope": - base_size_width = 720 // 8 // self.transformer.config.patch_size - base_size_height = 480 // 8 // self.transformer.config.patch_size - - grid_crops_coords = get_resize_crop_region_for_grid( - (grid_height, grid_width), base_size_width, base_size_height - ) - image_rotary_emb = get_3d_rotary_pos_embed( - self.transformer.config.attention_head_dim, grid_crops_coords, grid_size=(grid_height, grid_width), - temporal_size=latents.size(2), use_real=True, - ) - else: - base_size = 512 // 8 // self.transformer.config.patch_size - grid_crops_coords = get_resize_crop_region_for_grid( - (grid_height, grid_width), base_size, base_size - ) - image_rotary_emb = get_2d_rotary_pos_embed( - self.transformer.config.attention_head_dim, grid_crops_coords, (grid_height, grid_width) - ) - - # Get other hunyuan params - style = torch.tensor([0], device=device) - - target_size = target_size or (height, width) - add_time_ids = list(original_size + target_size + crops_coords_top_left) - add_time_ids = torch.tensor([add_time_ids], dtype=prompt_embeds.dtype) - - if self.do_classifier_free_guidance: - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask]) - prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2]) - prompt_attention_mask_2 = torch.cat([negative_prompt_attention_mask_2, prompt_attention_mask_2]) - add_time_ids = torch.cat([add_time_ids] * 2, dim=0) - style = torch.cat([style] * 2, dim=0) - - # To latents.device - prompt_embeds = prompt_embeds.to(device=device) - prompt_attention_mask = prompt_attention_mask.to(device=device) - prompt_embeds_2 = prompt_embeds_2.to(device=device) - prompt_attention_mask_2 = prompt_attention_mask_2.to(device=device) - add_time_ids = add_time_ids.to(dtype=prompt_embeds.dtype, device=device).repeat( - batch_size * num_images_per_prompt, 1 - ) - style = style.to(device=device).repeat(batch_size * num_images_per_prompt) - - torch.cuda.empty_cache() - if self.enable_autocast_float8_transformer_flag: - origin_weight_dtype = self.transformer.dtype - self.transformer = self.transformer.to(torch.float8_e4m3fn) - # 8. Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - self._num_timesteps = len(timesteps) - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - if self.interrupt: - continue - - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - # expand scalar t to 1-D tensor to match the 1st dim of latent_model_input - t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to( - dtype=latent_model_input.dtype - ) - - # predict the noise residual - noise_pred = self.transformer( - latent_model_input, - t_expand, - encoder_hidden_states=prompt_embeds, - text_embedding_mask=prompt_attention_mask, - encoder_hidden_states_t5=prompt_embeds_2, - text_embedding_mask_t5=prompt_attention_mask_2, - image_meta_size=add_time_ids, - style=style, - image_rotary_emb=image_rotary_emb, - return_dict=False, - )[0] - - if noise_pred.size()[1] != self.vae.config.latent_channels: - noise_pred, _ = noise_pred.chunk(2, dim=1) - - # perform guidance - if self.do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - if self.do_classifier_free_guidance and guidance_rescale > 0.0: - # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf - noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] - - if callback_on_step_end is not None: - callback_kwargs = {} - for k in callback_on_step_end_tensor_inputs: - callback_kwargs[k] = locals()[k] - callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) - - latents = callback_outputs.pop("latents", latents) - prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) - prompt_embeds_2 = callback_outputs.pop("prompt_embeds_2", prompt_embeds_2) - negative_prompt_embeds_2 = callback_outputs.pop( - "negative_prompt_embeds_2", negative_prompt_embeds_2 - ) - - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - - if XLA_AVAILABLE: - xm.mark_step() - - if comfyui_progressbar: - pbar.update(1) - - if self.enable_autocast_float8_transformer_flag: - self.transformer = self.transformer.to("cpu", origin_weight_dtype) - - torch.cuda.empty_cache() - # Post-processing - video = self.decode_latents(latents) - - # Convert to tensor - if output_type == "latent": - video = torch.from_numpy(video) - - # Offload all models - self.maybe_free_model_hooks() - - if not return_dict: - return video - - return EasyAnimatePipelineOutput(videos=video) \ No newline at end of file diff --git a/easyanimate/pipeline/pipeline_easyanimate_multi_text_encoder_inpaint.py b/easyanimate/pipeline/pipeline_easyanimate_multi_text_encoder_inpaint.py deleted file mode 100644 index 2c45241bb5be766164ff5460500480d12d5c399c..0000000000000000000000000000000000000000 --- a/easyanimate/pipeline/pipeline_easyanimate_multi_text_encoder_inpaint.py +++ /dev/null @@ -1,1334 +0,0 @@ -# Copyright 2024 EasyAnimate Authors and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -from typing import Callable, Dict, List, Optional, Tuple, Union - -import torch -import torch.nn.functional as F -from diffusers import DiffusionPipeline -from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback -from diffusers.image_processor import VaeImageProcessor -from diffusers.models import AutoencoderKL, HunyuanDiT2DModel -from diffusers.models.embeddings import (get_2d_rotary_pos_embed, - get_3d_rotary_pos_embed) -from diffusers.pipelines.pipeline_utils import DiffusionPipeline -from diffusers.pipelines.stable_diffusion.safety_checker import \ - StableDiffusionSafetyChecker -from diffusers.schedulers import DDIMScheduler -from diffusers.utils import (is_torch_xla_available, logging, - replace_example_docstring) -from diffusers.utils.torch_utils import randn_tensor -from einops import rearrange -from PIL import Image -from tqdm import tqdm -from transformers import (BertModel, BertTokenizer, CLIPImageProcessor, - CLIPVisionModelWithProjection, T5Tokenizer, - T5EncoderModel) - -from .pipeline_easyanimate import EasyAnimatePipelineOutput -from ..models import AutoencoderKLMagvit, EasyAnimateTransformer3DModel - -if is_torch_xla_available(): - import torch_xla.core.xla_model as xm - - XLA_AVAILABLE = True -else: - XLA_AVAILABLE = False - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - -EXAMPLE_DOC_STRING = """ - Examples: - ```py - >>> pass - ``` -""" - - -def get_resize_crop_region_for_grid(src, tgt_width, tgt_height): - tw = tgt_width - th = tgt_height - h, w = src - r = h / w - if r > (th / tw): - resize_height = th - resize_width = int(round(th / h * w)) - else: - resize_width = tw - resize_height = int(round(tw / w * h)) - - crop_top = int(round((th - resize_height) / 2.0)) - crop_left = int(round((tw - resize_width) / 2.0)) - - return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) - - -# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg -def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): - """ - Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and - Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 - """ - std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) - std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) - # rescale the results from guidance (fixes overexposure) - noise_pred_rescaled = noise_cfg * (std_text / std_cfg) - # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images - noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg - return noise_cfg - - -def resize_mask(mask, latent, process_first_frame_only=True): - latent_size = latent.size() - - if process_first_frame_only: - target_size = list(latent_size[2:]) - target_size[0] = 1 - first_frame_resized = F.interpolate( - mask[:, :, 0:1, :, :], - size=target_size, - mode='trilinear', - align_corners=False - ) - - target_size = list(latent_size[2:]) - target_size[0] = target_size[0] - 1 - if target_size[0] != 0: - remaining_frames_resized = F.interpolate( - mask[:, :, 1:, :, :], - size=target_size, - mode='trilinear', - align_corners=False - ) - resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2) - else: - resized_mask = first_frame_resized - else: - target_size = list(latent_size[2:]) - resized_mask = F.interpolate( - mask, - size=target_size, - mode='trilinear', - align_corners=False - ) - return resized_mask - - -def add_noise_to_reference_video(image, ratio=None): - if ratio is None: - sigma = torch.normal(mean=-3.0, std=0.5, size=(image.shape[0],)).to(image.device) - sigma = torch.exp(sigma).to(image.dtype) - else: - sigma = torch.ones((image.shape[0],)).to(image.device, image.dtype) * ratio - - image_noise = torch.randn_like(image) * sigma[:, None, None, None, None] - image_noise = torch.where(image==-1, torch.zeros_like(image), image_noise) - image = image + image_noise - return image - - -class EasyAnimatePipeline_Multi_Text_Encoder_Inpaint(DiffusionPipeline): - r""" - Pipeline for text-to-video generation using EasyAnimate. - - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - - EasyAnimate uses two text encoders: [mT5](https://huggingface.co/google/mt5-base) and [bilingual CLIP](fine-tuned by - HunyuanDiT team) - - Args: - vae ([`AutoencoderKLMagvit`]): - Variational Auto-Encoder (VAE) Model to encode and decode video to and from latent representations. - text_encoder (Optional[`~transformers.BertModel`, `~transformers.CLIPTextModel`]): - Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). - EasyAnimate uses a fine-tuned [bilingual CLIP]. - tokenizer (Optional[`~transformers.BertTokenizer`, `~transformers.CLIPTokenizer`]): - A `BertTokenizer` or `CLIPTokenizer` to tokenize text. - transformer ([`EasyAnimateTransformer3DModel`]): - The EasyAnimate model designed by Tencent Hunyuan. - text_encoder_2 (`T5EncoderModel`): - The mT5 embedder. - tokenizer_2 (`T5Tokenizer`): - The tokenizer for the mT5 embedder. - scheduler ([`DDIMScheduler`]): - A scheduler to be used in combination with EasyAnimate to denoise the encoded image latents. - clip_image_processor (`CLIPImageProcessor`): - The CLIP image embedder. - clip_image_encoder (`CLIPVisionModelWithProjection`): - The image processor for the CLIP image embedder. - """ - - model_cpu_offload_seq = "text_encoder->text_encoder_2->clip_image_encoder->transformer->vae" - _optional_components = [ - "safety_checker", - "feature_extractor", - "text_encoder_2", - "tokenizer_2", - "text_encoder", - "tokenizer", - "clip_image_encoder", - ] - _exclude_from_cpu_offload = ["safety_checker"] - _callback_tensor_inputs = [ - "latents", - "prompt_embeds", - "negative_prompt_embeds", - "prompt_embeds_2", - "negative_prompt_embeds_2", - ] - - def __init__( - self, - vae: AutoencoderKLMagvit, - text_encoder: BertModel, - tokenizer: BertTokenizer, - text_encoder_2: T5EncoderModel, - tokenizer_2: T5Tokenizer, - transformer: EasyAnimateTransformer3DModel, - scheduler: DDIMScheduler, - safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPImageProcessor, - requires_safety_checker: bool = True, - clip_image_processor: CLIPImageProcessor = None, - clip_image_encoder: CLIPVisionModelWithProjection = None, - ): - super().__init__() - - self.register_modules( - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - tokenizer_2=tokenizer_2, - transformer=transformer, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - text_encoder_2=text_encoder_2, - clip_image_processor=clip_image_processor, - clip_image_encoder=clip_image_encoder, - ) - - if safety_checker is None and requires_safety_checker: - logger.warning( - f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" - " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" - " results in services or applications open to the public. Both the diffusers team and Hugging Face" - " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" - " it only for use-cases that involve analyzing network behavior or auditing its results. For more" - " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." - ) - - if safety_checker is not None and feature_extractor is None: - raise ValueError( - "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" - " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." - ) - - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) - self.mask_processor = VaeImageProcessor( - vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True - ) - self.enable_autocast_float8_transformer_flag = False - self.register_to_config(requires_safety_checker=requires_safety_checker) - - def enable_sequential_cpu_offload(self, *args, **kwargs): - super().enable_sequential_cpu_offload(*args, **kwargs) - if hasattr(self.transformer, "clip_projection") and self.transformer.clip_projection is not None: - import accelerate - accelerate.hooks.remove_hook_from_module(self.transformer.clip_projection, recurse=True) - self.transformer.clip_projection = self.transformer.clip_projection.to("cuda") - - def encode_prompt( - self, - prompt: str, - device: torch.device, - dtype: torch.dtype, - num_images_per_prompt: int = 1, - do_classifier_free_guidance: bool = True, - negative_prompt: Optional[str] = None, - prompt_embeds: Optional[torch.Tensor] = None, - negative_prompt_embeds: Optional[torch.Tensor] = None, - prompt_attention_mask: Optional[torch.Tensor] = None, - negative_prompt_attention_mask: Optional[torch.Tensor] = None, - max_sequence_length: Optional[int] = None, - text_encoder_index: int = 0, - actual_max_sequence_length: int = 256 - ): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - device: (`torch.device`): - torch device - dtype (`torch.dtype`): - torch dtype - num_images_per_prompt (`int`): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is - less than `1`). - prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - prompt_attention_mask (`torch.Tensor`, *optional*): - Attention mask for the prompt. Required when `prompt_embeds` is passed directly. - negative_prompt_attention_mask (`torch.Tensor`, *optional*): - Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly. - max_sequence_length (`int`, *optional*): maximum sequence length to use for the prompt. - text_encoder_index (`int`, *optional*): - Index of the text encoder to use. `0` for clip and `1` for T5. - """ - tokenizers = [self.tokenizer, self.tokenizer_2] - text_encoders = [self.text_encoder, self.text_encoder_2] - - tokenizer = tokenizers[text_encoder_index] - text_encoder = text_encoders[text_encoder_index] - - if max_sequence_length is None: - if text_encoder_index == 0: - max_length = min(self.tokenizer.model_max_length, actual_max_sequence_length) - if text_encoder_index == 1: - max_length = min(self.tokenizer_2.model_max_length, actual_max_sequence_length) - else: - max_length = max_sequence_length - - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - if prompt_embeds is None: - text_inputs = tokenizer( - prompt, - padding="max_length", - max_length=max_length, - truncation=True, - return_attention_mask=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - if text_input_ids.shape[-1] > actual_max_sequence_length: - reprompt = tokenizer.batch_decode(text_input_ids[:, :actual_max_sequence_length], skip_special_tokens=True) - text_inputs = tokenizer( - reprompt, - padding="max_length", - max_length=max_length, - truncation=True, - return_attention_mask=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): - _actual_max_sequence_length = min(tokenizer.model_max_length, actual_max_sequence_length) - removed_text = tokenizer.batch_decode(untruncated_ids[:, _actual_max_sequence_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {_actual_max_sequence_length} tokens: {removed_text}" - ) - prompt_attention_mask = text_inputs.attention_mask.to(device) - if self.transformer.config.enable_text_attention_mask: - prompt_embeds = text_encoder( - text_input_ids.to(device), - attention_mask=prompt_attention_mask, - ) - else: - prompt_embeds = text_encoder( - text_input_ids.to(device) - ) - prompt_embeds = prompt_embeds[0] - prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) - - prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) - - bs_embed, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance and negative_prompt_embeds is None: - uncond_tokens: List[str] - if negative_prompt is None: - uncond_tokens = [""] * batch_size - elif prompt is not None and type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt] - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - else: - uncond_tokens = negative_prompt - - max_length = prompt_embeds.shape[1] - uncond_input = tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) - uncond_input_ids = uncond_input.input_ids - if uncond_input_ids.shape[-1] > actual_max_sequence_length: - reuncond_tokens = tokenizer.batch_decode(uncond_input_ids[:, :actual_max_sequence_length], skip_special_tokens=True) - uncond_input = tokenizer( - reuncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_attention_mask=True, - return_tensors="pt", - ) - uncond_input_ids = uncond_input.input_ids - - negative_prompt_attention_mask = uncond_input.attention_mask.to(device) - if self.transformer.config.enable_text_attention_mask: - negative_prompt_embeds = text_encoder( - uncond_input.input_ids.to(device), - attention_mask=negative_prompt_attention_mask, - ) - else: - negative_prompt_embeds = text_encoder( - uncond_input.input_ids.to(device) - ) - negative_prompt_embeds = negative_prompt_embeds[0] - negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) - - if do_classifier_free_guidance: - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = negative_prompt_embeds.shape[1] - - negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) - - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker - def run_safety_checker(self, image, device, dtype): - if self.safety_checker is None: - has_nsfw_concept = None - else: - if torch.is_tensor(image): - feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") - else: - feature_extractor_input = self.image_processor.numpy_to_pil(image) - safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) - image, has_nsfw_concept = self.safety_checker( - images=image, clip_input=safety_checker_input.pixel_values.to(dtype) - ) - return image, has_nsfw_concept - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs - def prepare_extra_step_kwargs(self, generator, eta): - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. - # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) - if accepts_generator: - extra_step_kwargs["generator"] = generator - return extra_step_kwargs - - def check_inputs( - self, - prompt, - height, - width, - negative_prompt=None, - prompt_embeds=None, - negative_prompt_embeds=None, - prompt_attention_mask=None, - negative_prompt_attention_mask=None, - prompt_embeds_2=None, - negative_prompt_embeds_2=None, - prompt_attention_mask_2=None, - negative_prompt_attention_mask_2=None, - callback_on_step_end_tensor_inputs=None, - ): - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - - if callback_on_step_end_tensor_inputs is not None and not all( - k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs - ): - raise ValueError( - f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" - ) - - if prompt is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt is None and prompt_embeds is None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." - ) - elif prompt is None and prompt_embeds_2 is None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined." - ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if prompt_embeds is not None and prompt_attention_mask is None: - raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") - - if prompt_embeds_2 is not None and prompt_attention_mask_2 is None: - raise ValueError("Must provide `prompt_attention_mask_2` when specifying `prompt_embeds_2`.") - - if negative_prompt is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) - - if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: - raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") - - if negative_prompt_embeds_2 is not None and negative_prompt_attention_mask_2 is None: - raise ValueError( - "Must provide `negative_prompt_attention_mask_2` when specifying `negative_prompt_embeds_2`." - ) - if prompt_embeds is not None and negative_prompt_embeds is not None: - if prompt_embeds.shape != negative_prompt_embeds.shape: - raise ValueError( - "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" - f" {negative_prompt_embeds.shape}." - ) - if prompt_embeds_2 is not None and negative_prompt_embeds_2 is not None: - if prompt_embeds_2.shape != negative_prompt_embeds_2.shape: - raise ValueError( - "`prompt_embeds_2` and `negative_prompt_embeds_2` must have the same shape when passed directly, but" - f" got: `prompt_embeds_2` {prompt_embeds_2.shape} != `negative_prompt_embeds_2`" - f" {negative_prompt_embeds_2.shape}." - ) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps - def get_timesteps(self, num_inference_steps, strength, device): - # get the original timestep using init_timestep - init_timestep = min(int(num_inference_steps * strength), num_inference_steps) - - t_start = max(num_inference_steps - init_timestep, 0) - timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] - - return timesteps, num_inference_steps - t_start - - def prepare_mask_latents( - self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance, noise_aug_strength - ): - # resize the mask to latents shape as we concatenate the mask to the latents - # we do that before converting to dtype to avoid breaking in case we're using cpu_offload - # and half precision - if mask is not None: - mask = mask.to(device=device, dtype=self.vae.dtype) - if self.vae.quant_conv is None or self.vae.quant_conv.weight.ndim==5: - bs = 1 - new_mask = [] - for i in range(0, mask.shape[0], bs): - mask_bs = mask[i : i + bs] - mask_bs = self.vae.encode(mask_bs)[0] - mask_bs = mask_bs.mode() - new_mask.append(mask_bs) - mask = torch.cat(new_mask, dim = 0) - mask = mask * self.vae.config.scaling_factor - - else: - if mask.shape[1] == 4: - mask = mask - else: - video_length = mask.shape[2] - mask = rearrange(mask, "b c f h w -> (b f) c h w") - mask = self._encode_vae_image(mask, generator=generator) - mask = rearrange(mask, "(b f) c h w -> b c f h w", f=video_length) - - if masked_image is not None: - masked_image = masked_image.to(device=device, dtype=self.vae.dtype) - if self.transformer.config.add_noise_in_inpaint_model: - masked_image = add_noise_to_reference_video(masked_image, ratio=noise_aug_strength) - if self.vae.quant_conv is None or self.vae.quant_conv.weight.ndim==5: - bs = 1 - new_mask_pixel_values = [] - for i in range(0, masked_image.shape[0], bs): - mask_pixel_values_bs = masked_image[i : i + bs] - mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0] - mask_pixel_values_bs = mask_pixel_values_bs.mode() - new_mask_pixel_values.append(mask_pixel_values_bs) - masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0) - masked_image_latents = masked_image_latents * self.vae.config.scaling_factor - - else: - if masked_image.shape[1] == 4: - masked_image_latents = masked_image - else: - video_length = masked_image.shape[2] - masked_image = rearrange(masked_image, "b c f h w -> (b f) c h w") - masked_image_latents = self._encode_vae_image(masked_image, generator=generator) - masked_image_latents = rearrange(masked_image_latents, "(b f) c h w -> b c f h w", f=video_length) - - # aligning device to prevent device errors when concating it with the latent model input - masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) - else: - masked_image_latents = None - - return mask, masked_image_latents - - def prepare_latents( - self, - batch_size, - num_channels_latents, - height, - width, - video_length, - dtype, - device, - generator, - latents=None, - video=None, - timestep=None, - is_strength_max=True, - return_noise=False, - return_video_latents=False, - ): - if self.vae.quant_conv is None or self.vae.quant_conv.weight.ndim==5: - if self.vae.cache_mag_vae: - mini_batch_encoder = self.vae.mini_batch_encoder - mini_batch_decoder = self.vae.mini_batch_decoder - shape = (batch_size, num_channels_latents, int((video_length - 1) // mini_batch_encoder * mini_batch_decoder + 1) if video_length != 1 else 1, height // self.vae_scale_factor, width // self.vae_scale_factor) - else: - mini_batch_encoder = self.vae.mini_batch_encoder - mini_batch_decoder = self.vae.mini_batch_decoder - shape = (batch_size, num_channels_latents, int(video_length // mini_batch_encoder * mini_batch_decoder) if video_length != 1 else 1, height // self.vae_scale_factor, width // self.vae_scale_factor) - else: - shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor) - - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - - if return_video_latents or (latents is None and not is_strength_max): - video = video.to(device=device, dtype=self.vae.dtype) - if self.vae.quant_conv is None or self.vae.quant_conv.weight.ndim==5: - bs = 1 - new_video = [] - for i in range(0, video.shape[0], bs): - video_bs = video[i : i + bs] - video_bs = self.vae.encode(video_bs)[0] - video_bs = video_bs.sample() - new_video.append(video_bs) - video = torch.cat(new_video, dim = 0) - video = video * self.vae.config.scaling_factor - - else: - if video.shape[1] == 4: - video = video - else: - video_length = video.shape[2] - video = rearrange(video, "b c f h w -> (b f) c h w") - video = self._encode_vae_image(video, generator=generator) - video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length) - video_latents = video.repeat(batch_size // video.shape[0], 1, 1, 1, 1) - video_latents = video_latents.to(device=device, dtype=dtype) - - if latents is None: - noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - # if strength is 1. then initialise the latents to noise, else initial to image + noise - latents = noise if is_strength_max else self.scheduler.add_noise(video_latents, noise, timestep) - # if pure noise then scale the initial latents by the Scheduler's init sigma - latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents - else: - noise = latents.to(device) - latents = noise * self.scheduler.init_noise_sigma - - # scale the initial noise by the standard deviation required by the scheduler - outputs = (latents,) - - if return_noise: - outputs += (noise,) - - if return_video_latents: - outputs += (video_latents,) - - return outputs - - def smooth_output(self, video, mini_batch_encoder, mini_batch_decoder): - if video.size()[2] <= mini_batch_encoder: - return video - prefix_index_before = mini_batch_encoder // 2 - prefix_index_after = mini_batch_encoder - prefix_index_before - pixel_values = video[:, :, prefix_index_before:-prefix_index_after] - - # Encode middle videos - latents = self.vae.encode(pixel_values)[0] - latents = latents.mode() - # Decode middle videos - middle_video = self.vae.decode(latents)[0] - - video[:, :, prefix_index_before:-prefix_index_after] = (video[:, :, prefix_index_before:-prefix_index_after] + middle_video) / 2 - return video - - def decode_latents(self, latents): - video_length = latents.shape[2] - latents = 1 / self.vae.config.scaling_factor * latents - if self.vae.quant_conv is None or self.vae.quant_conv.weight.ndim==5: - mini_batch_encoder = self.vae.mini_batch_encoder - mini_batch_decoder = self.vae.mini_batch_decoder - video = self.vae.decode(latents)[0] - video = video.clamp(-1, 1) - if not self.vae.cache_compression_vae and not self.vae.cache_mag_vae: - video = self.smooth_output(video, mini_batch_encoder, mini_batch_decoder).cpu().clamp(-1, 1) - else: - latents = rearrange(latents, "b c f h w -> (b f) c h w") - video = [] - for frame_idx in tqdm(range(latents.shape[0])): - video.append(self.vae.decode(latents[frame_idx:frame_idx+1]).sample) - video = torch.cat(video) - video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length) - video = (video / 2 + 0.5).clamp(0, 1) - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 - video = video.cpu().float().numpy() - return video - - @property - def guidance_scale(self): - return self._guidance_scale - - @property - def guidance_rescale(self): - return self._guidance_rescale - - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - @property - def do_classifier_free_guidance(self): - return self._guidance_scale > 1 - - @property - def num_timesteps(self): - return self._num_timesteps - - @property - def interrupt(self): - return self._interrupt - - def enable_autocast_float8_transformer(self): - self.enable_autocast_float8_transformer_flag = True - - @torch.no_grad() - @replace_example_docstring(EXAMPLE_DOC_STRING) - def __call__( - self, - prompt: Union[str, List[str]] = None, - video_length: Optional[int] = None, - video: Union[torch.FloatTensor] = None, - mask_video: Union[torch.FloatTensor] = None, - masked_video_latents: Union[torch.FloatTensor] = None, - height: Optional[int] = None, - width: Optional[int] = None, - num_inference_steps: Optional[int] = 50, - guidance_scale: Optional[float] = 5.0, - negative_prompt: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: Optional[int] = 1, - eta: Optional[float] = 0.0, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.Tensor] = None, - prompt_embeds: Optional[torch.Tensor] = None, - prompt_embeds_2: Optional[torch.Tensor] = None, - negative_prompt_embeds: Optional[torch.Tensor] = None, - negative_prompt_embeds_2: Optional[torch.Tensor] = None, - prompt_attention_mask: Optional[torch.Tensor] = None, - prompt_attention_mask_2: Optional[torch.Tensor] = None, - negative_prompt_attention_mask: Optional[torch.Tensor] = None, - negative_prompt_attention_mask_2: Optional[torch.Tensor] = None, - output_type: Optional[str] = "latent", - return_dict: bool = True, - callback_on_step_end: Optional[ - Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] - ] = None, - callback_on_step_end_tensor_inputs: List[str] = ["latents"], - guidance_rescale: float = 0.0, - original_size: Optional[Tuple[int, int]] = (1024, 1024), - target_size: Optional[Tuple[int, int]] = None, - crops_coords_top_left: Tuple[int, int] = (0, 0), - clip_image: Image = None, - clip_apply_ratio: float = 0.40, - strength: float = 1.0, - noise_aug_strength: float = 0.0563, - comfyui_progressbar: bool = False, - ): - r""" - The call function to the pipeline for generation with HunyuanDiT. - - Examples: - prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. - video_length (`int`, *optional*): - Length of the video to be generated in seconds. This parameter influences the number of frames and - continuity of generated content. - video (`torch.FloatTensor`, *optional*): - A tensor representing an input video, which can be modified depending on the prompts provided. - mask_video (`torch.FloatTensor`, *optional*): - A tensor to specify areas of the video to be masked (omitted from generation). - masked_video_latents (`torch.FloatTensor`, *optional*): - Latents from masked portions of the video, utilized during image generation. - height (`int`, *optional*): - The height in pixels of the generated image or video frames. - width (`int`, *optional*): - The width in pixels of the generated image or video frames. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image but slower - inference time. This parameter is modulated by `strength`. - guidance_scale (`float`, *optional*, defaults to 5.0): - A higher guidance scale value encourages the model to generate images closely linked to the text - `prompt` at the expense of lower image quality. Guidance scale is effective when `guidance_scale > 1`. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide what to exclude in image generation. If not defined, you need to - provide `negative_prompt_embeds`. This parameter is ignored when not using guidance (`guidance_scale < 1`). - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - eta (`float`, *optional*, defaults to 0.0): - A parameter defined in the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies to the - [`~schedulers.DDIMScheduler`] and is ignored in other schedulers. It adjusts noise level during the - inference process. - generator (`torch.Generator` or `List[torch.Generator]`, *optional*): - A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) for setting - random seeds which helps in making generation deterministic. - latents (`torch.Tensor`, *optional*): - A pre-computed latent representation which can be used to guide the generation process. - prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not - provided, embeddings are generated from the `prompt` input argument. - prompt_embeds_2 (`torch.Tensor`, *optional*): - Secondary set of pre-generated text embeddings, useful for advanced prompt weighting. - negative_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated negative text embeddings, aiding in fine-tuning what should not be represented in the outputs. - If not provided, embeddings are generated from the `negative_prompt` argument. - negative_prompt_embeds_2 (`torch.Tensor`, *optional*): - Secondary set of pre-generated negative text embeddings for further control. - prompt_attention_mask (`torch.Tensor`, *optional*): - Attention mask guiding the focus of the model on specific parts of the prompt text. Required when using - `prompt_embeds`. - prompt_attention_mask_2 (`torch.Tensor`, *optional*): - Attention mask for the secondary prompt embedding. - negative_prompt_attention_mask (`torch.Tensor`, *optional*): - Attention mask for the negative prompt, needed when `negative_prompt_embeds` are used. - negative_prompt_attention_mask_2 (`torch.Tensor`, *optional*): - Attention mask for the secondary negative prompt embedding. - output_type (`str`, *optional*, defaults to `"latent"`): - The output format of the generated image. Choose between `PIL.Image` and `np.array` to define - how you want the results to be formatted. - return_dict (`bool`, *optional*, defaults to `True`): - If set to `True`, a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] will be returned; - otherwise, a tuple containing the generated images and safety flags will be returned. - callback_on_step_end (`Callable[[int, int, Dict], None]`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): - A callback function (or a list of them) that will be executed at the end of each denoising step, - allowing for custom processing during generation. - callback_on_step_end_tensor_inputs (`List[str]`, *optional*): - Specifies which tensor inputs should be included in the callback function. If not defined, all tensor - inputs will be passed, facilitating enhanced logging or monitoring of the generation process. - guidance_rescale (`float`, *optional*, defaults to 0.0): - Rescale parameter for adjusting noise configuration based on guidance rescale. Based on findings from - [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). - original_size (`Tuple[int, int]`, *optional*, defaults to `(1024, 1024)`): - The original dimensions of the image. Used to compute time ids during the generation process. - target_size (`Tuple[int, int]`, *optional*): - The targeted dimensions of the generated image, also utilized in the time id calculations. - crops_coords_top_left (`Tuple[int, int]`, *optional*, defaults to `(0, 0)`): - Coordinates defining the top left corner of any cropping, utilized while calculating the time ids. - clip_image (`Image`, *optional*): - An optional image to assist in the generation process. It may be used as an additional visual cue. - clip_apply_ratio (`float`, *optional*, defaults to 0.40): - Ratio indicating how much influence the clip image should exert over the generated content. - strength (`float`, *optional*, defaults to 1.0): - Affects the overall styling or quality of the generated output. Values closer to 1 usually provide direct - adherence to prompts. - comfyui_progressbar (`bool`, *optional*, defaults to `False`): - Enables a progress bar in ComfyUI, providing visual feedback during the generation process. - - Examples: - # Example usage of the function for generating images based on prompts. - - Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - Returns either a structured output containing generated images and their metadata when `return_dict` is - `True`, or a simpler tuple, where the first element is a list of generated images and the second - element indicates if any of them contain "not-safe-for-work" (NSFW) content. - """ - - if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): - callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs - - # 0. default height and width - height = int(height // 16 * 16) - width = int(width // 16 * 16) - - # 1. Check inputs. Raise error if not correct - self.check_inputs( - prompt, - height, - width, - negative_prompt, - prompt_embeds, - negative_prompt_embeds, - prompt_attention_mask, - negative_prompt_attention_mask, - prompt_embeds_2, - negative_prompt_embeds_2, - prompt_attention_mask_2, - negative_prompt_attention_mask_2, - callback_on_step_end_tensor_inputs, - ) - self._guidance_scale = guidance_scale - self._guidance_rescale = guidance_rescale - self._interrupt = False - - # 2. Define call parameters - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - device = self._execution_device - - # 3. Encode input prompt - ( - prompt_embeds, - negative_prompt_embeds, - prompt_attention_mask, - negative_prompt_attention_mask, - ) = self.encode_prompt( - prompt=prompt, - device=device, - dtype=self.transformer.dtype, - num_images_per_prompt=num_images_per_prompt, - do_classifier_free_guidance=self.do_classifier_free_guidance, - negative_prompt=negative_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - prompt_attention_mask=prompt_attention_mask, - negative_prompt_attention_mask=negative_prompt_attention_mask, - text_encoder_index=0, - ) - ( - prompt_embeds_2, - negative_prompt_embeds_2, - prompt_attention_mask_2, - negative_prompt_attention_mask_2, - ) = self.encode_prompt( - prompt=prompt, - device=device, - dtype=self.transformer.dtype, - num_images_per_prompt=num_images_per_prompt, - do_classifier_free_guidance=self.do_classifier_free_guidance, - negative_prompt=negative_prompt, - prompt_embeds=prompt_embeds_2, - negative_prompt_embeds=negative_prompt_embeds_2, - prompt_attention_mask=prompt_attention_mask_2, - negative_prompt_attention_mask=negative_prompt_attention_mask_2, - text_encoder_index=1, - ) - torch.cuda.empty_cache() - - # 4. set timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps, num_inference_steps = self.get_timesteps( - num_inference_steps=num_inference_steps, strength=strength, device=device - ) - if comfyui_progressbar: - from comfy.utils import ProgressBar - pbar = ProgressBar(num_inference_steps + 3) - # at which timestep to set the initial noise (n.b. 50% if strength is 0.5) - latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) - # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise - is_strength_max = strength == 1.0 - - if video is not None: - video_length = video.shape[2] - init_video = self.image_processor.preprocess(rearrange(video, "b c f h w -> (b f) c h w"), height=height, width=width) - init_video = init_video.to(dtype=torch.float32) - init_video = rearrange(init_video, "(b f) c h w -> b c f h w", f=video_length) - else: - init_video = None - - # Prepare latent variables - num_channels_latents = self.vae.config.latent_channels - num_channels_transformer = self.transformer.config.in_channels - return_image_latents = num_channels_transformer == num_channels_latents - - # 5. Prepare latents. - latents_outputs = self.prepare_latents( - batch_size * num_images_per_prompt, - num_channels_latents, - height, - width, - video_length, - prompt_embeds.dtype, - device, - generator, - latents, - video=init_video, - timestep=latent_timestep, - is_strength_max=is_strength_max, - return_noise=True, - return_video_latents=return_image_latents, - ) - if return_image_latents: - latents, noise, image_latents = latents_outputs - else: - latents, noise = latents_outputs - - if comfyui_progressbar: - pbar.update(1) - - # 6. Prepare clip latents if it needs. - if clip_image is not None and self.transformer.enable_clip_in_inpaint: - inputs = self.clip_image_processor(images=clip_image, return_tensors="pt") - inputs["pixel_values"] = inputs["pixel_values"].to(latents.device, dtype=latents.dtype) - clip_encoder_hidden_states = self.clip_image_encoder(**inputs).last_hidden_state[:, 1:] - clip_encoder_hidden_states_neg = torch.zeros( - [ - batch_size, - int(self.clip_image_encoder.config.image_size / self.clip_image_encoder.config.patch_size) ** 2, - int(self.clip_image_encoder.config.hidden_size) - ] - ).to(latents.device, dtype=latents.dtype) - - clip_attention_mask = torch.ones([batch_size, self.transformer.n_query]).to(latents.device, dtype=latents.dtype) - clip_attention_mask_neg = torch.zeros([batch_size, self.transformer.n_query]).to(latents.device, dtype=latents.dtype) - - clip_encoder_hidden_states_input = torch.cat([clip_encoder_hidden_states_neg, clip_encoder_hidden_states]) if self.do_classifier_free_guidance else clip_encoder_hidden_states - clip_attention_mask_input = torch.cat([clip_attention_mask_neg, clip_attention_mask]) if self.do_classifier_free_guidance else clip_attention_mask - - elif clip_image is None and num_channels_transformer != num_channels_latents and self.transformer.enable_clip_in_inpaint: - clip_encoder_hidden_states = torch.zeros( - [ - batch_size, - int(self.clip_image_encoder.config.image_size / self.clip_image_encoder.config.patch_size) ** 2, - int(self.clip_image_encoder.config.hidden_size) - ] - ).to(latents.device, dtype=latents.dtype) - - clip_attention_mask = torch.zeros([batch_size, self.transformer.n_query]) - clip_attention_mask = clip_attention_mask.to(latents.device, dtype=latents.dtype) - - clip_encoder_hidden_states_input = torch.cat([clip_encoder_hidden_states] * 2) if self.do_classifier_free_guidance else clip_encoder_hidden_states - clip_attention_mask_input = torch.cat([clip_attention_mask] * 2) if self.do_classifier_free_guidance else clip_attention_mask - - else: - clip_encoder_hidden_states_input = None - clip_attention_mask_input = None - if comfyui_progressbar: - pbar.update(1) - - # 7. Prepare inpaint latents if it needs. - if mask_video is not None: - if (mask_video == 255).all(): - # Use zero latents if we want to t2v. - if self.transformer.resize_inpaint_mask_directly: - mask_latents = torch.zeros_like(latents)[:, :1].to(latents.device, latents.dtype) - else: - mask_latents = torch.zeros_like(latents).to(latents.device, latents.dtype) - masked_video_latents = torch.zeros_like(latents).to(latents.device, latents.dtype) - - mask_input = torch.cat([mask_latents] * 2) if self.do_classifier_free_guidance else mask_latents - masked_video_latents_input = ( - torch.cat([masked_video_latents] * 2) if self.do_classifier_free_guidance else masked_video_latents - ) - inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=1).to(latents.dtype) - else: - # Prepare mask latent variables - video_length = video.shape[2] - mask_condition = self.mask_processor.preprocess(rearrange(mask_video, "b c f h w -> (b f) c h w"), height=height, width=width) - mask_condition = mask_condition.to(dtype=torch.float32) - mask_condition = rearrange(mask_condition, "(b f) c h w -> b c f h w", f=video_length) - - if num_channels_transformer != num_channels_latents: - mask_condition_tile = torch.tile(mask_condition, [1, 3, 1, 1, 1]) - if masked_video_latents is None: - masked_video = init_video * (mask_condition_tile < 0.5) + torch.ones_like(init_video) * (mask_condition_tile > 0.5) * -1 - else: - masked_video = masked_video_latents - - if self.transformer.resize_inpaint_mask_directly: - _, masked_video_latents = self.prepare_mask_latents( - None, - masked_video, - batch_size, - height, - width, - prompt_embeds.dtype, - device, - generator, - self.do_classifier_free_guidance, - noise_aug_strength=noise_aug_strength, - ) - mask_latents = resize_mask(1 - mask_condition, masked_video_latents, self.vae.cache_mag_vae) - mask_latents = mask_latents.to(masked_video_latents.device) * self.vae.config.scaling_factor - else: - mask_latents, masked_video_latents = self.prepare_mask_latents( - mask_condition_tile, - masked_video, - batch_size, - height, - width, - prompt_embeds.dtype, - device, - generator, - self.do_classifier_free_guidance, - noise_aug_strength=noise_aug_strength, - ) - - mask_input = torch.cat([mask_latents] * 2) if self.do_classifier_free_guidance else mask_latents - masked_video_latents_input = ( - torch.cat([masked_video_latents] * 2) if self.do_classifier_free_guidance else masked_video_latents - ) - inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=1).to(latents.dtype) - else: - inpaint_latents = None - - mask = torch.tile(mask_condition, [1, num_channels_latents, 1, 1, 1]) - mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype) - else: - if num_channels_transformer != num_channels_latents: - mask = torch.zeros_like(latents).to(latents.device, latents.dtype) - masked_video_latents = torch.zeros_like(latents).to(latents.device, latents.dtype) - - mask_input = torch.cat([mask] * 2) if self.do_classifier_free_guidance else mask - masked_video_latents_input = ( - torch.cat([masked_video_latents] * 2) if self.do_classifier_free_guidance else masked_video_latents - ) - inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=1).to(latents.dtype) - else: - mask = torch.zeros_like(init_video[:, :1]) - mask = torch.tile(mask, [1, num_channels_latents, 1, 1, 1]) - mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype) - - inpaint_latents = None - if comfyui_progressbar: - pbar.update(1) - - # Check that sizes of mask, masked image and latents match - if num_channels_transformer != num_channels_latents: - num_channels_mask = mask_latents.shape[1] - num_channels_masked_image = masked_video_latents.shape[1] - if num_channels_latents + num_channels_mask + num_channels_masked_image != self.transformer.config.in_channels: - raise ValueError( - f"Incorrect configuration settings! The config of `pipeline.transformer`: {self.transformer.config} expects" - f" {self.transformer.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" - f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" - f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" - " `pipeline.transformer` or your `mask_image` or `image` input." - ) - - # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - - # 9 create image_rotary_emb, style embedding & time ids - grid_height = height // 8 // self.transformer.config.patch_size - grid_width = width // 8 // self.transformer.config.patch_size - if self.transformer.config.get("time_position_encoding_type", "2d_rope") == "3d_rope": - base_size_width = 720 // 8 // self.transformer.config.patch_size - base_size_height = 480 // 8 // self.transformer.config.patch_size - - grid_crops_coords = get_resize_crop_region_for_grid( - (grid_height, grid_width), base_size_width, base_size_height - ) - image_rotary_emb = get_3d_rotary_pos_embed( - self.transformer.config.attention_head_dim, grid_crops_coords, grid_size=(grid_height, grid_width), - temporal_size=latents.size(2), use_real=True, - ) - else: - base_size = 512 // 8 // self.transformer.config.patch_size - grid_crops_coords = get_resize_crop_region_for_grid( - (grid_height, grid_width), base_size, base_size - ) - image_rotary_emb = get_2d_rotary_pos_embed( - self.transformer.config.attention_head_dim, grid_crops_coords, (grid_height, grid_width) - ) - - # Get other hunyuan params - style = torch.tensor([0], device=device) - - target_size = target_size or (height, width) - add_time_ids = list(original_size + target_size + crops_coords_top_left) - add_time_ids = torch.tensor([add_time_ids], dtype=prompt_embeds.dtype) - - if self.do_classifier_free_guidance: - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask]) - prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2]) - prompt_attention_mask_2 = torch.cat([negative_prompt_attention_mask_2, prompt_attention_mask_2]) - add_time_ids = torch.cat([add_time_ids] * 2, dim=0) - style = torch.cat([style] * 2, dim=0) - - prompt_embeds = prompt_embeds.to(device=device) - prompt_attention_mask = prompt_attention_mask.to(device=device) - prompt_embeds_2 = prompt_embeds_2.to(device=device) - prompt_attention_mask_2 = prompt_attention_mask_2.to(device=device) - add_time_ids = add_time_ids.to(dtype=prompt_embeds.dtype, device=device).repeat( - batch_size * num_images_per_prompt, 1 - ) - style = style.to(device=device).repeat(batch_size * num_images_per_prompt) - - torch.cuda.empty_cache() - if self.enable_autocast_float8_transformer_flag: - origin_weight_dtype = self.transformer.dtype - self.transformer = self.transformer.to(torch.float8_e4m3fn) - # 10. Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - self._num_timesteps = len(timesteps) - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - if self.interrupt: - continue - - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - if i < len(timesteps) * (1 - clip_apply_ratio) and clip_encoder_hidden_states_input is not None: - clip_encoder_hidden_states_actual_input = torch.zeros_like(clip_encoder_hidden_states_input) - clip_attention_mask_actual_input = torch.zeros_like(clip_attention_mask_input) - else: - clip_encoder_hidden_states_actual_input = clip_encoder_hidden_states_input - clip_attention_mask_actual_input = clip_attention_mask_input - - # expand scalar t to 1-D tensor to match the 1st dim of latent_model_input - t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to( - dtype=latent_model_input.dtype - ) - - # predict the noise residual - noise_pred = self.transformer( - latent_model_input, - t_expand, - encoder_hidden_states=prompt_embeds, - text_embedding_mask=prompt_attention_mask, - encoder_hidden_states_t5=prompt_embeds_2, - text_embedding_mask_t5=prompt_attention_mask_2, - image_meta_size=add_time_ids, - style=style, - image_rotary_emb=image_rotary_emb, - inpaint_latents=inpaint_latents, - clip_encoder_hidden_states=clip_encoder_hidden_states_actual_input, - clip_attention_mask=clip_attention_mask_actual_input, - return_dict=False, - )[0] - if noise_pred.size()[1] != self.vae.config.latent_channels: - noise_pred, _ = noise_pred.chunk(2, dim=1) - - # perform guidance - if self.do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - if self.do_classifier_free_guidance and guidance_rescale > 0.0: - # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf - noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] - - if num_channels_transformer == 4: - init_latents_proper = image_latents - init_mask = mask - if i < len(timesteps) - 1: - noise_timestep = timesteps[i + 1] - init_latents_proper = self.scheduler.add_noise( - init_latents_proper, noise, torch.tensor([noise_timestep]) - ) - - latents = (1 - init_mask) * init_latents_proper + init_mask * latents - - if callback_on_step_end is not None: - callback_kwargs = {} - for k in callback_on_step_end_tensor_inputs: - callback_kwargs[k] = locals()[k] - callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) - - latents = callback_outputs.pop("latents", latents) - prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) - prompt_embeds_2 = callback_outputs.pop("prompt_embeds_2", prompt_embeds_2) - negative_prompt_embeds_2 = callback_outputs.pop( - "negative_prompt_embeds_2", negative_prompt_embeds_2 - ) - - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - - if XLA_AVAILABLE: - xm.mark_step() - - if comfyui_progressbar: - pbar.update(1) - - if self.enable_autocast_float8_transformer_flag: - self.transformer = self.transformer.to("cpu", origin_weight_dtype) - - torch.cuda.empty_cache() - # Post-processing - video = self.decode_latents(latents) - - # Convert to tensor - if output_type == "latent": - video = torch.from_numpy(video) - - # Offload all models - self.maybe_free_model_hooks() - - if not return_dict: - return video - - return EasyAnimatePipelineOutput(videos=video) \ No newline at end of file diff --git a/easyanimate/ui/ui.py b/easyanimate/ui/ui.py index 50a424834fc414f7784a03b40b93cc0a3c4cc9ea..62daaaf78b14573f1f9524b9a90893f217de9c0f 100644 --- a/easyanimate/ui/ui.py +++ b/easyanimate/ui/ui.py @@ -17,41 +17,42 @@ import torch from diffusers import (AutoencoderKL, DDIMScheduler, DPMSolverMultistepScheduler, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, - PNDMScheduler) + FlowMatchEulerDiscreteScheduler, PNDMScheduler) from diffusers.utils.import_utils import is_xformers_available from omegaconf import OmegaConf from PIL import Image from safetensors import safe_open from transformers import (BertModel, BertTokenizer, CLIPImageProcessor, - CLIPVisionModelWithProjection, T5Tokenizer, - T5EncoderModel, T5Tokenizer) + CLIPVisionModelWithProjection, Qwen2Tokenizer, + Qwen2VLForConditionalGeneration, T5EncoderModel, + T5Tokenizer) -from easyanimate.data.bucket_sampler import ASPECT_RATIO_512, get_closest_ratio -from easyanimate.models import (name_to_autoencoder_magvit, +from ..data.bucket_sampler import ASPECT_RATIO_512, get_closest_ratio +from ..models import (name_to_autoencoder_magvit, name_to_transformer3d) -from easyanimate.models.autoencoder_magvit import AutoencoderKLMagvit -from easyanimate.models.transformer3d import (HunyuanTransformer3DModel, - Transformer3DModel) -from easyanimate.pipeline.pipeline_easyanimate import EasyAnimatePipeline -from easyanimate.pipeline.pipeline_easyanimate_inpaint import \ +from ..pipeline.pipeline_easyanimate import \ + EasyAnimatePipeline +from ..pipeline.pipeline_easyanimate_control import \ + EasyAnimateControlPipeline +from ..pipeline.pipeline_easyanimate_inpaint import \ EasyAnimateInpaintPipeline -from easyanimate.pipeline.pipeline_easyanimate_multi_text_encoder import \ - EasyAnimatePipeline_Multi_Text_Encoder -from easyanimate.pipeline.pipeline_easyanimate_multi_text_encoder_inpaint import \ - EasyAnimatePipeline_Multi_Text_Encoder_Inpaint -from easyanimate.utils.lora_utils import merge_lora, unmerge_lora -from easyanimate.utils.utils import ( +from ..utils.fp8_optimization import convert_weight_dtype_wrapper +from ..utils.lora_utils import merge_lora, unmerge_lora +from ..utils.utils import ( get_image_to_video_latent, get_video_to_video_latent, get_width_and_height_from_image_and_base_resolution, save_videos_grid) -from easyanimate.utils.fp8_optimization import convert_weight_dtype_wrapper -scheduler_dict = { +ddpm_scheduler_dict = { "Euler": EulerDiscreteScheduler, "Euler A": EulerAncestralDiscreteScheduler, "DPM++": DPMSolverMultistepScheduler, "PNDM": PNDMScheduler, "DDIM": DDIMScheduler, } +flow_scheduler_dict = { + "Flow": FlowMatchEulerDiscreteScheduler, +} +all_cheduler_dict = {**ddpm_scheduler_dict, **flow_scheduler_dict} gradio_version = pkg_resources.get_distribution("gradio").version gradio_version_is_above_4 = True if int(gradio_version.split('.')[0]) >= 4 else False @@ -98,8 +99,8 @@ class EasyAnimateController: self.GPU_memory_mode = GPU_memory_mode self.weight_dtype = weight_dtype - self.edition = "v5" - self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_v5_magvit_multi_text_encoder.yaml")) + self.edition = "v5.1" + self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_v5.1_magvit_qwen.yaml")) def refresh_diffusion_transformer(self): self.diffusion_transformer_list = sorted(glob(os.path.join(self.diffusion_transformer_dir, "*/"))) @@ -121,26 +122,37 @@ class EasyAnimateController: if edition == "v1": self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_v1_motion_module.yaml")) return gr.update(), gr.update(value="none"), gr.update(visible=True), gr.update(visible=True), \ + gr.update(choices=list(ddpm_scheduler_dict.keys()), value=list(ddpm_scheduler_dict.keys())[0]), \ gr.update(value=512, minimum=384, maximum=704, step=32), \ gr.update(value=512, minimum=384, maximum=704, step=32), gr.update(value=80, minimum=40, maximum=80, step=1) elif edition == "v2": self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_v2_magvit_motion_module.yaml")) return gr.update(), gr.update(value="none"), gr.update(visible=False), gr.update(visible=False), \ + gr.update(choices=list(ddpm_scheduler_dict.keys()), value=list(ddpm_scheduler_dict.keys())[0]), \ gr.update(value=672, minimum=128, maximum=1344, step=16), \ gr.update(value=384, minimum=128, maximum=1344, step=16), gr.update(value=144, minimum=9, maximum=144, step=9) elif edition == "v3": self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_v3_slicevae_motion_module.yaml")) return gr.update(), gr.update(value="none"), gr.update(visible=False), gr.update(visible=False), \ + gr.update(choices=list(ddpm_scheduler_dict.keys()), value=list(ddpm_scheduler_dict.keys())[0]), \ gr.update(value=672, minimum=128, maximum=1344, step=16), \ gr.update(value=384, minimum=128, maximum=1344, step=16), gr.update(value=144, minimum=8, maximum=144, step=8) elif edition == "v4": self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_v4_slicevae_multi_text_encoder.yaml")) return gr.update(), gr.update(value="none"), gr.update(visible=False), gr.update(visible=False), \ + gr.update(choices=list(ddpm_scheduler_dict.keys()), value=list(ddpm_scheduler_dict.keys())[0]), \ gr.update(value=672, minimum=128, maximum=1344, step=16), \ gr.update(value=384, minimum=128, maximum=1344, step=16), gr.update(value=144, minimum=8, maximum=144, step=8) elif edition == "v5": self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_v5_magvit_multi_text_encoder.yaml")) return gr.update(), gr.update(value="none"), gr.update(visible=False), gr.update(visible=False), \ + gr.update(choices=list(ddpm_scheduler_dict.keys()), value=list(ddpm_scheduler_dict.keys())[0]), \ + gr.update(value=672, minimum=128, maximum=1344, step=16), \ + gr.update(value=384, minimum=128, maximum=1344, step=16), gr.update(value=49, minimum=1, maximum=49, step=4) + elif edition == "v5.1": + self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_v5.1_magvit_qwen.yaml")) + return gr.update(), gr.update(value="none"), gr.update(visible=False), gr.update(visible=False), \ + gr.update(choices=list(flow_scheduler_dict.keys()), value=list(flow_scheduler_dict.keys())[0]), \ gr.update(value=672, minimum=128, maximum=1344, step=16), \ gr.update(value=384, minimum=128, maximum=1344, step=16), gr.update(value=49, minimum=1, maximum=49, step=4) @@ -170,33 +182,55 @@ class EasyAnimateController: self.transformer = Choosen_Transformer3DModel.from_pretrained_2d( diffusion_transformer_dropdown, subfolder="transformer", - transformer_additional_kwargs=transformer_additional_kwargs - ).to(self.weight_dtype) + transformer_additional_kwargs=transformer_additional_kwargs, + torch_dtype=torch.float8_e4m3fn if self.GPU_memory_mode == "model_cpu_offload_and_qfloat8" else self.weight_dtype, + low_cpu_mem_usage=True, + ) if self.inference_config['text_encoder_kwargs'].get('enable_multi_text_encoder', False): tokenizer = BertTokenizer.from_pretrained( diffusion_transformer_dropdown, subfolder="tokenizer" ) - tokenizer_2 = T5Tokenizer.from_pretrained( - diffusion_transformer_dropdown, subfolder="tokenizer_2" - ) + if self.inference_config['text_encoder_kwargs'].get('replace_t5_to_llm', False): + tokenizer_2 = Qwen2Tokenizer.from_pretrained( + os.path.join(diffusion_transformer_dropdown, "tokenizer_2") + ) + else: + tokenizer_2 = T5Tokenizer.from_pretrained( + diffusion_transformer_dropdown, subfolder="tokenizer_2" + ) else: - tokenizer = T5Tokenizer.from_pretrained( - diffusion_transformer_dropdown, subfolder="tokenizer" - ) + if self.inference_config['text_encoder_kwargs'].get('replace_t5_to_llm', False): + tokenizer = Qwen2Tokenizer.from_pretrained( + os.path.join(diffusion_transformer_dropdown, "tokenizer") + ) + else: + tokenizer = T5Tokenizer.from_pretrained( + diffusion_transformer_dropdown, subfolder="tokenizer" + ) tokenizer_2 = None if self.inference_config['text_encoder_kwargs'].get('enable_multi_text_encoder', False): text_encoder = BertModel.from_pretrained( diffusion_transformer_dropdown, subfolder="text_encoder", torch_dtype=self.weight_dtype ) - text_encoder_2 = T5EncoderModel.from_pretrained( - diffusion_transformer_dropdown, subfolder="text_encoder_2", torch_dtype=self.weight_dtype - ) - else: - text_encoder = T5EncoderModel.from_pretrained( - diffusion_transformer_dropdown, subfolder="text_encoder", torch_dtype=self.weight_dtype - ) + if self.inference_config['text_encoder_kwargs'].get('replace_t5_to_llm', False): + text_encoder_2 = Qwen2VLForConditionalGeneration.from_pretrained( + os.path.join(diffusion_transformer_dropdown, "text_encoder_2") + ) + else: + text_encoder_2 = T5EncoderModel.from_pretrained( + diffusion_transformer_dropdown, subfolder="text_encoder_2", torch_dtype=self.weight_dtype + ) + else: + if self.inference_config['text_encoder_kwargs'].get('replace_t5_to_llm', False): + text_encoder = Qwen2VLForConditionalGeneration.from_pretrained( + os.path.join(diffusion_transformer_dropdown, "text_encoder") + ) + else: + text_encoder = T5EncoderModel.from_pretrained( + diffusion_transformer_dropdown, subfolder="text_encoder", torch_dtype=self.weight_dtype + ) text_encoder_2 = None # Get pipeline @@ -212,23 +246,18 @@ class EasyAnimateController: clip_image_processor = None # Get Scheduler - Choosen_Scheduler = scheduler_dict = { - "Euler": EulerDiscreteScheduler, - "Euler A": EulerAncestralDiscreteScheduler, - "DPM++": DPMSolverMultistepScheduler, - "PNDM": PNDMScheduler, - "DDIM": DDIMScheduler, - }["Euler"] - + if self.edition in ["v5.1"]: + Choosen_Scheduler = all_cheduler_dict["Flow"] + else: + Choosen_Scheduler = all_cheduler_dict["Euler"] scheduler = Choosen_Scheduler.from_pretrained( diffusion_transformer_dropdown, subfolder="scheduler" ) - if self.inference_config['text_encoder_kwargs'].get('enable_multi_text_encoder', False): + if self.model_type == "Inpaint": if self.transformer.config.in_channels != self.vae.config.latent_channels: - self.pipeline = EasyAnimatePipeline_Multi_Text_Encoder_Inpaint.from_pretrained( - diffusion_transformer_dropdown, + self.pipeline = EasyAnimateInpaintPipeline( text_encoder=text_encoder, text_encoder_2=text_encoder_2, tokenizer=tokenizer, @@ -236,13 +265,11 @@ class EasyAnimateController: vae=self.vae, transformer=self.transformer, scheduler=scheduler, - torch_dtype=self.weight_dtype, clip_image_encoder=clip_image_encoder, clip_image_processor=clip_image_processor, - ) + ).to(self.weight_dtype) else: - self.pipeline = EasyAnimatePipeline_Multi_Text_Encoder.from_pretrained( - diffusion_transformer_dropdown, + self.pipeline = EasyAnimatePipeline( text_encoder=text_encoder, text_encoder_2=text_encoder_2, tokenizer=tokenizer, @@ -250,40 +277,25 @@ class EasyAnimateController: vae=self.vae, transformer=self.transformer, scheduler=scheduler, - torch_dtype=self.weight_dtype - ) + ).to(self.weight_dtype) else: - if self.transformer.config.in_channels != self.vae.config.latent_channels: - self.pipeline = EasyAnimateInpaintPipeline( - diffusion_transformer_dropdown, - text_encoder=text_encoder, - tokenizer=tokenizer, - vae=self.vae, - transformer=self.transformer, - scheduler=scheduler, - torch_dtype=self.weight_dtype, - clip_image_encoder=clip_image_encoder, - clip_image_processor=clip_image_processor, - ) - else: - self.pipeline = EasyAnimatePipeline( - diffusion_transformer_dropdown, - text_encoder=text_encoder, - tokenizer=tokenizer, - vae=self.vae, - transformer=self.transformer, - scheduler=scheduler, - torch_dtype=self.weight_dtype - ) + self.pipeline = EasyAnimateControlPipeline( + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + vae=self.vae, + transformer=self.transformer, + scheduler=scheduler, + ).to(self.weight_dtype) if self.GPU_memory_mode == "sequential_cpu_offload": self.pipeline.enable_sequential_cpu_offload() elif self.GPU_memory_mode == "model_cpu_offload_and_qfloat8": self.pipeline.enable_model_cpu_offload() - self.pipeline.enable_autocast_float8_transformer() convert_weight_dtype_wrapper(self.pipeline.transformer, self.weight_dtype) else: - self.GPU_memory_mode.enable_model_cpu_offload() + self.pipeline.enable_model_cpu_offload() print("Update diffusion transformer done") return gr.update() @@ -374,8 +386,10 @@ class EasyAnimateController: if self.base_model_path != base_model_dropdown: self.update_base_model(base_model_dropdown) + if self.motion_module_path != motion_module_dropdown: + self.update_motion_module(motion_module_dropdown) + if self.lora_model_path != lora_model_dropdown: - print("Update lora model") self.update_lora_model(lora_model_dropdown) if control_video is not None and self.model_type == "Inpaint": @@ -426,19 +440,21 @@ class EasyAnimateController: else: raise gr.Error(f"If specifying the ending image of the video, please specify a starting image of the video.") - fps = {"v1": 12, "v2": 24, "v3": 24, "v4": 24, "v5": 8}[self.edition] + fps = {"v1": 12, "v2": 24, "v3": 24, "v4": 24, "v5": 8, "v5.1": 8}[self.edition] is_image = True if generation_method == "Image Generation" else False - if is_xformers_available() and not self.inference_config['text_encoder_kwargs'].get('enable_multi_text_encoder', False): self.transformer.enable_xformers_memory_efficient_attention() + if int(seed_textbox) != -1 and seed_textbox != "": torch.manual_seed(int(seed_textbox)) + else: seed_textbox = np.random.randint(0, 1e10) + generator = torch.Generator(device="cuda").manual_seed(int(seed_textbox)) - self.pipeline.scheduler = scheduler_dict[sampler_dropdown].from_config(self.pipeline.scheduler.config) + if is_xformers_available() \ + and self.inference_config['transformer_additional_kwargs'].get('transformer_type', 'Transformer3DModel') == 'Transformer3DModel': + self.transformer.enable_xformers_memory_efficient_attention() + + self.pipeline.scheduler = all_cheduler_dict[sampler_dropdown].from_config(self.pipeline.scheduler.config) if self.lora_model_path != "none": # lora part self.pipeline = merge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider) - - if int(seed_textbox) != -1 and seed_textbox != "": torch.manual_seed(int(seed_textbox)) - else: seed_textbox = np.random.randint(0, 1e10) - generator = torch.Generator(device="cuda").manual_seed(int(seed_textbox)) try: if self.model_type == "Inpaint": @@ -480,7 +496,7 @@ class EasyAnimateController: video = input_video, mask_video = input_video_mask, strength = 1, - ).videos + ).frames if init_frames != 0: mix_ratio = torch.from_numpy( @@ -531,7 +547,7 @@ class EasyAnimateController: video = input_video, mask_video = input_video_mask, strength = strength, - ).videos + ).frames else: if self.vae.cache_mag_vae: length_slider = int((length_slider - 1) // self.vae.mini_batch_encoder * self.vae.mini_batch_encoder) + 1 @@ -547,7 +563,7 @@ class EasyAnimateController: height = height_slider, video_length = length_slider if not is_image else 1, generator = generator - ).videos + ).frames else: if self.vae.cache_mag_vae: length_slider = int((length_slider - 1) // self.vae.mini_batch_encoder * self.vae.mini_batch_encoder) + 1 @@ -566,7 +582,7 @@ class EasyAnimateController: generator = generator, control_video = input_video, - ).videos + ).frames except Exception as e: gc.collect() torch.cuda.empty_cache() @@ -676,8 +692,8 @@ def ui(GPU_memory_mode, weight_dtype): with gr.Row(): easyanimate_edition_dropdown = gr.Dropdown( label="The config of EasyAnimate Edition (EasyAnimate版本配置)", - choices=["v1", "v2", "v3", "v4", "v5"], - value="v5", + choices=["v1", "v2", "v3", "v4", "v5", "v5.1"], + value="v5.1", interactive=True, ) gr.Markdown( @@ -751,13 +767,22 @@ def ui(GPU_memory_mode, weight_dtype): """ ) - prompt_textbox = gr.Textbox(label="Prompt (正向提示词)", lines=2, value="A young woman with beautiful and clear eyes and blonde hair standing and white dress in a forest wearing a crown. She seems to be lost in thought, and the camera focuses on her face. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.") - negative_prompt_textbox = gr.Textbox(label="Negative prompt (负向提示词)", lines=2, value="Blurring, mutation, deformation, distortion, dark and solid, comics." ) + prompt_textbox = gr.Textbox(label="Prompt (正向提示词)", lines=2, value="A young woman with beautiful, clear eyes and blonde hair stands in the forest, wearing a white dress and a crown. Her expression is serene, reminiscent of a movie star, with fair and youthful skin. Her brown long hair flows in the wind. The video quality is very high, with a clear view. High quality, masterpiece, best quality, high resolution, ultra-fine, fantastical.") + gr.Markdown( + """ + Using longer neg prompt such as "Blurring, mutation, deformation, distortion, dark and solid, comics, text subtitles, line art." can increase stability. Adding words such as "quiet, solid" to the neg prompt can increase dynamism. + 使用更长的neg prompt如"模糊,突变,变形,失真,画面暗,文本字幕,画面固定,连环画,漫画,线稿,没有主体。",可以增加稳定性。在neg prompt中添加"安静,固定"等词语可以增加动态性。 + """ + ) + negative_prompt_textbox = gr.Textbox(label="Negative prompt (负向提示词)", lines=2, value="Twisted body, limb deformities, text captions, comic, static, ugly, error, messy code." ) with gr.Row(): with gr.Column(): with gr.Row(): - sampler_dropdown = gr.Dropdown(label="Sampling method (采样器种类)", choices=list(scheduler_dict.keys()), value=list(scheduler_dict.keys())[0]) + sampler_dropdown = gr.Dropdown( + label="Sampling method (采样器种类)", + choices=list(flow_scheduler_dict.keys()), value=list(flow_scheduler_dict.keys())[0] + ) sample_step_slider = gr.Slider(label="Sampling steps (生成步数)", value=50, minimum=10, maximum=100, step=1) resize_method = gr.Radio( @@ -794,11 +819,11 @@ def ui(GPU_memory_mode, weight_dtype): template_gallery_path = ["asset/1.png", "asset/2.png", "asset/3.png", "asset/4.png", "asset/5.png"] def select_template(evt: gr.SelectData): text = { - "asset/1.png": "The dog is shaking head. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.", - "asset/2.png": "a sailboat sailing in rough seas with a dramatic sunset. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.", - "asset/3.png": "a beautiful woman with long hair and a dress blowing in the wind. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.", - "asset/4.png": "a man in an astronaut suit playing a guitar. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.", - "asset/5.png": "fireworks display over night city. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.", + "asset/1.png": "A brown dog is shaking its head and sitting on a light colored sofa in a comfortable room. Behind the dog, there is a framed painting on the shelf surrounded by pink flowers. The soft and warm lighting in the room creates a comfortable atmosphere.", + "asset/2.png": "A sailboat navigates through moderately rough seas, with waves and ocean spray visible. The sailboat features a white hull and sails, accompanied by an orange sail catching the wind. The sky above shows dramatic, cloudy formations with a sunset or sunrise backdrop, casting warm colors across the scene. The water reflects the golden light, enhancing the visual contrast between the dark ocean and the bright horizon. The camera captures the scene with a dynamic and immersive angle, showcasing the movement of the boat and the energy of the ocean.", + "asset/3.png": "A stunningly beautiful woman with flowing long hair stands gracefully, her elegant dress rippling and billowing in the gentle wind. Petals falling off. Her serene expression and the natural movement of her attire create an enchanting and captivating scene, full of ethereal charm.", + "asset/4.png": "An astronaut, clad in a full space suit with a helmet, plays an electric guitar while floating in a cosmic environment filled with glowing particles and rocky textures. The scene is illuminated by a warm light source, creating dramatic shadows and contrasts. The background features a complex geometry, similar to a space station or an alien landscape, indicating a futuristic or otherworldly setting.", + "asset/5.png": "Fireworks light up the evening sky over a sprawling cityscape with gothic-style buildings featuring pointed towers and clock faces. The city is lit by both artificial lights from the buildings and the colorful bursts of the fireworks. The scene is viewed from an elevated angle, showcasing a vibrant urban environment set against a backdrop of a dramatic, partially cloudy sky at dusk.", }[template_gallery_path[evt.index]] return template_gallery_path[evt.index], text @@ -838,6 +863,7 @@ def ui(GPU_memory_mode, weight_dtype): gr.Markdown( """ Demo pose control video can be downloaded here [URL](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/v1.1/pose.mp4). + Only normal controls are supported in app.py; trajectory control and camera control need ComfyUI, as shown in https://github.com/aigc-apps/EasyAnimate/tree/main/comfyui. """ ) control_video = gr.Video( @@ -927,6 +953,7 @@ def ui(GPU_memory_mode, weight_dtype): diffusion_transformer_dropdown, motion_module_dropdown, motion_module_refresh_button, + sampler_dropdown, width_slider, height_slider, length_slider, @@ -1003,33 +1030,55 @@ class EasyAnimateController_Modelscope: self.transformer = Choosen_Transformer3DModel.from_pretrained_2d( model_name, subfolder="transformer", - transformer_additional_kwargs=transformer_additional_kwargs - ).to(self.weight_dtype) + transformer_additional_kwargs=transformer_additional_kwargs, + torch_dtype=torch.float8_e4m3fn if GPU_memory_mode == "model_cpu_offload_and_qfloat8" else weight_dtype, + low_cpu_mem_usage=True, + ) if self.inference_config['text_encoder_kwargs'].get('enable_multi_text_encoder', False): tokenizer = BertTokenizer.from_pretrained( model_name, subfolder="tokenizer" ) - tokenizer_2 = T5Tokenizer.from_pretrained( - model_name, subfolder="tokenizer_2" - ) + if self.inference_config['text_encoder_kwargs'].get('replace_t5_to_llm', False): + tokenizer_2 = Qwen2Tokenizer.from_pretrained( + os.path.join(model_name, "tokenizer_2") + ) + else: + tokenizer_2 = T5Tokenizer.from_pretrained( + model_name, subfolder="tokenizer_2" + ) else: - tokenizer = T5Tokenizer.from_pretrained( - model_name, subfolder="tokenizer" - ) + if self.inference_config['text_encoder_kwargs'].get('replace_t5_to_llm', False): + tokenizer = Qwen2Tokenizer.from_pretrained( + os.path.join(model_name, "tokenizer") + ) + else: + tokenizer = T5Tokenizer.from_pretrained( + model_name, subfolder="tokenizer" + ) tokenizer_2 = None if self.inference_config['text_encoder_kwargs'].get('enable_multi_text_encoder', False): text_encoder = BertModel.from_pretrained( model_name, subfolder="text_encoder", torch_dtype=self.weight_dtype ) - text_encoder_2 = T5EncoderModel.from_pretrained( - model_name, subfolder="text_encoder_2", torch_dtype=self.weight_dtype - ) - else: - text_encoder = T5EncoderModel.from_pretrained( - model_name, subfolder="text_encoder", torch_dtype=self.weight_dtype - ) + if self.inference_config['text_encoder_kwargs'].get('replace_t5_to_llm', False): + text_encoder_2 = Qwen2VLForConditionalGeneration.from_pretrained( + os.path.join(model_name, "text_encoder_2"), torch_dtype=self.weight_dtype + ) + else: + text_encoder_2 = T5EncoderModel.from_pretrained( + model_name, subfolder="text_encoder_2", torch_dtype=self.weight_dtype + ) + else: + if self.inference_config['text_encoder_kwargs'].get('replace_t5_to_llm', False): + text_encoder = Qwen2VLForConditionalGeneration.from_pretrained( + os.path.join(model_name, "text_encoder"), torch_dtype=self.weight_dtype + ) + else: + text_encoder = T5EncoderModel.from_pretrained( + model_name, subfolder="text_encoder", torch_dtype=self.weight_dtype + ) text_encoder_2 = None # Get pipeline @@ -1045,23 +1094,18 @@ class EasyAnimateController_Modelscope: clip_image_processor = None # Get Scheduler - Choosen_Scheduler = scheduler_dict = { - "Euler": EulerDiscreteScheduler, - "Euler A": EulerAncestralDiscreteScheduler, - "DPM++": DPMSolverMultistepScheduler, - "PNDM": PNDMScheduler, - "DDIM": DDIMScheduler, - }["Euler"] - + if self.edition in ["v5.1"]: + Choosen_Scheduler = all_cheduler_dict["Flow"] + else: + Choosen_Scheduler = all_cheduler_dict["Euler"] scheduler = Choosen_Scheduler.from_pretrained( model_name, subfolder="scheduler" ) - if self.inference_config['text_encoder_kwargs'].get('enable_multi_text_encoder', False): + if model_type == "Inpaint": if self.transformer.config.in_channels != self.vae.config.latent_channels: - self.pipeline = EasyAnimatePipeline_Multi_Text_Encoder_Inpaint.from_pretrained( - model_name, + self.pipeline = EasyAnimateInpaintPipeline( text_encoder=text_encoder, text_encoder_2=text_encoder_2, tokenizer=tokenizer, @@ -1069,51 +1113,34 @@ class EasyAnimateController_Modelscope: vae=self.vae, transformer=self.transformer, scheduler=scheduler, - torch_dtype=self.weight_dtype, clip_image_encoder=clip_image_encoder, clip_image_processor=clip_image_processor, - ) + ).to(weight_dtype) else: - self.pipeline = EasyAnimatePipeline_Multi_Text_Encoder.from_pretrained( - model_name, + self.pipeline = EasyAnimatePipeline( text_encoder=text_encoder, text_encoder_2=text_encoder_2, tokenizer=tokenizer, tokenizer_2=tokenizer_2, vae=self.vae, transformer=self.transformer, - scheduler=scheduler, - torch_dtype=self.weight_dtype - ) + scheduler=scheduler + ).to(weight_dtype) else: - if self.transformer.config.in_channels != self.vae.config.latent_channels: - self.pipeline = EasyAnimateInpaintPipeline( - model_name, - text_encoder=text_encoder, - tokenizer=tokenizer, - vae=self.vae, - transformer=self.transformer, - scheduler=scheduler, - torch_dtype=self.weight_dtype, - clip_image_encoder=clip_image_encoder, - clip_image_processor=clip_image_processor, - ) - else: - self.pipeline = EasyAnimatePipeline( - model_name, - text_encoder=text_encoder, - tokenizer=tokenizer, - vae=self.vae, - transformer=self.transformer, - scheduler=scheduler, - torch_dtype=self.weight_dtype - ) + self.pipeline = EasyAnimateControlPipeline( + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + vae=self.vae, + transformer=self.transformer, + scheduler=scheduler, + ).to(weight_dtype) if GPU_memory_mode == "sequential_cpu_offload": self.pipeline.enable_sequential_cpu_offload() elif GPU_memory_mode == "model_cpu_offload_and_qfloat8": self.pipeline.enable_model_cpu_offload() - self.pipeline.enable_autocast_float8_transformer() convert_weight_dtype_wrapper(self.pipeline.transformer, weight_dtype) else: GPU_memory_mode.enable_model_cpu_offload() @@ -1214,17 +1241,17 @@ class EasyAnimateController_Modelscope: else: raise gr.Error(f"If specifying the ending image of the video, please specify a starting image of the video.") - fps = {"v1": 12, "v2": 24, "v3": 24, "v4": 24, "v5": 8}[self.edition] + fps = {"v1": 12, "v2": 24, "v3": 24, "v4": 24, "v5": 8, "v5.1": 8}[self.edition] is_image = True if generation_method == "Image Generation" else False - self.pipeline.scheduler = scheduler_dict[sampler_dropdown].from_config(self.pipeline.scheduler.config) - if self.lora_model_path != "none": - # lora part - self.pipeline = merge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider) - if int(seed_textbox) != -1 and seed_textbox != "": torch.manual_seed(int(seed_textbox)) else: seed_textbox = np.random.randint(0, 1e10) generator = torch.Generator(device="cuda").manual_seed(int(seed_textbox)) + + self.pipeline.scheduler = all_cheduler_dict[sampler_dropdown].from_config(self.pipeline.scheduler.config) + if self.lora_model_path != "none": + # lora part + self.pipeline = merge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider) try: if self.model_type == "Inpaint": @@ -1254,7 +1281,7 @@ class EasyAnimateController_Modelscope: video = input_video, mask_video = input_video_mask, strength = strength, - ).videos + ).frames else: sample = self.pipeline( prompt_textbox, @@ -1265,7 +1292,7 @@ class EasyAnimateController_Modelscope: height = height_slider, video_length = length_slider if not is_image else 1, generator = generator - ).videos + ).frames else: if self.vae.cache_mag_vae: length_slider = int((length_slider - 1) // self.vae.mini_batch_encoder * self.vae.mini_batch_encoder) + 1 @@ -1285,7 +1312,7 @@ class EasyAnimateController_Modelscope: generator = generator, control_video = input_video, - ).videos + ).frames except Exception as e: gc.collect() torch.cuda.empty_cache() @@ -1406,13 +1433,28 @@ def ui_modelscope(model_type, edition, config_path, model_name, savedir_sample, """ ) - prompt_textbox = gr.Textbox(label="Prompt (正向提示词)", lines=2, value="A young woman with beautiful and clear eyes and blonde hair standing and white dress in a forest wearing a crown. She seems to be lost in thought, and the camera focuses on her face. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.") - negative_prompt_textbox = gr.Textbox(label="Negative prompt (负向提示词)", lines=2, value="Blurring, mutation, deformation, distortion, dark and solid, comics." ) + prompt_textbox = gr.Textbox(label="Prompt (正向提示词)", lines=2, value="A young woman with beautiful, clear eyes and blonde hair stands in the forest, wearing a white dress and a crown. Her expression is serene, reminiscent of a movie star, with fair and youthful skin. Her brown long hair flows in the wind. The video quality is very high, with a clear view. High quality, masterpiece, best quality, high resolution, ultra-fine, fantastical.") + gr.Markdown( + """ + Using longer neg prompt such as "Blurring, mutation, deformation, distortion, dark and solid, comics, text subtitles, line art." can increase stability. Adding words such as "quiet, solid" to the neg prompt can increase dynamism. + 使用更长的neg prompt如"模糊,突变,变形,失真,画面暗,文本字幕,画面固定,连环画,漫画,线稿,没有主体。",可以增加稳定性。在neg prompt中添加"安静,固定"等词语可以增加动态性。 + """ + ) + negative_prompt_textbox = gr.Textbox(label="Negative prompt (负向提示词)", lines=2, value="Twisted body, limb deformities, text captions, comic, static, ugly, error, messy code." ) with gr.Row(): with gr.Column(): with gr.Row(): - sampler_dropdown = gr.Dropdown(label="Sampling method (采样器种类)", choices=list(scheduler_dict.keys()), value=list(scheduler_dict.keys())[0]) + if edition in ["v5.1"]: + sampler_dropdown = gr.Dropdown( + label="Sampling method (采样器种类)", + choices=list(flow_scheduler_dict.keys()), value=list(flow_scheduler_dict.keys())[0] + ) + else: + sampler_dropdown = gr.Dropdown( + label="Sampling method (采样器种类)", + choices=list(ddpm_scheduler_dict.keys()), value=list(ddpm_scheduler_dict.keys())[0] + ) sample_step_slider = gr.Slider(label="Sampling steps (生成步数)", value=50, minimum=10, maximum=50, step=1, interactive=False) if edition == "v1": @@ -1466,11 +1508,11 @@ def ui_modelscope(model_type, edition, config_path, model_name, savedir_sample, template_gallery_path = ["asset/1.png", "asset/2.png", "asset/3.png", "asset/4.png", "asset/5.png"] def select_template(evt: gr.SelectData): text = { - "asset/1.png": "The dog is shaking head. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.", - "asset/2.png": "a sailboat sailing in rough seas with a dramatic sunset. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.", - "asset/3.png": "a beautiful woman with long hair and a dress blowing in the wind. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.", - "asset/4.png": "a man in an astronaut suit playing a guitar. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.", - "asset/5.png": "fireworks display over night city. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.", + "asset/1.png": "A brown dog is shaking its head and sitting on a light colored sofa in a comfortable room. Behind the dog, there is a framed painting on the shelf surrounded by pink flowers. The soft and warm lighting in the room creates a comfortable atmosphere.", + "asset/2.png": "A sailboat navigates through moderately rough seas, with waves and ocean spray visible. The sailboat features a white hull and sails, accompanied by an orange sail catching the wind. The sky above shows dramatic, cloudy formations with a sunset or sunrise backdrop, casting warm colors across the scene. The water reflects the golden light, enhancing the visual contrast between the dark ocean and the bright horizon. The camera captures the scene with a dynamic and immersive angle, showcasing the movement of the boat and the energy of the ocean.", + "asset/3.png": "A stunningly beautiful woman with flowing long hair stands gracefully, her elegant dress rippling and billowing in the gentle wind. Petals falling off. Her serene expression and the natural movement of her attire create an enchanting and captivating scene, full of ethereal charm.", + "asset/4.png": "An astronaut, clad in a full space suit with a helmet, plays an electric guitar while floating in a cosmic environment filled with glowing particles and rocky textures. The scene is illuminated by a warm light source, creating dramatic shadows and contrasts. The background features a complex geometry, similar to a space station or an alien landscape, indicating a futuristic or otherworldly setting.", + "asset/5.png": "Fireworks light up the evening sky over a sprawling cityscape with gothic-style buildings featuring pointed towers and clock faces. The city is lit by both artificial lights from the buildings and the colorful bursts of the fireworks. The scene is viewed from an elevated angle, showcasing a vibrant urban environment set against a backdrop of a dramatic, partially cloudy sky at dusk.", }[template_gallery_path[evt.index]] return template_gallery_path[evt.index], text @@ -1510,6 +1552,7 @@ def ui_modelscope(model_type, edition, config_path, model_name, savedir_sample, gr.Markdown( """ Demo pose control video can be downloaded here [URL](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/v1.1/pose.mp4). + Only normal controls are supported in app.py; trajectory control and camera control need ComfyUI, as shown in https://github.com/aigc-apps/EasyAnimate/tree/main/comfyui. """ ) control_video = gr.Video( @@ -1820,13 +1863,28 @@ def ui_eas(edition, config_path, model_name, savedir_sample): """ ) - prompt_textbox = gr.Textbox(label="Prompt", lines=2, value="A young woman with beautiful and clear eyes and blonde hair standing and white dress in a forest wearing a crown. She seems to be lost in thought, and the camera focuses on her face. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.") - negative_prompt_textbox = gr.Textbox(label="Negative prompt", lines=2, value="Blurring, mutation, deformation, distortion, dark and solid, comics." ) + prompt_textbox = gr.Textbox(label="Prompt", lines=2, value="A young woman with beautiful, clear eyes and blonde hair stands in the forest, wearing a white dress and a crown. Her expression is serene, reminiscent of a movie star, with fair and youthful skin. Her brown long hair flows in the wind. The video quality is very high, with a clear view. High quality, masterpiece, best quality, high resolution, ultra-fine, fantastical.") + gr.Markdown( + """ + Using longer neg prompt such as "Blurring, mutation, deformation, distortion, dark and solid, comics, text subtitles, line art." can increase stability. Adding words such as "quiet, solid" to the neg prompt can increase dynamism. + 使用更长的neg prompt如"模糊,突变,变形,失真,画面暗,文本字幕,画面固定,连环画,漫画,线稿,没有主体。",可以增加稳定性。在neg prompt中添加"安静,固定"等词语可以增加动态性。 + """ + ) + negative_prompt_textbox = gr.Textbox(label="Negative prompt", lines=2, value="Twisted body, limb deformities, text captions, comic, static, ugly, error, messy code." ) with gr.Row(): with gr.Column(): with gr.Row(): - sampler_dropdown = gr.Dropdown(label="Sampling method", choices=list(scheduler_dict.keys()), value=list(scheduler_dict.keys())[0]) + if edition in ["v5.1"]: + sampler_dropdown = gr.Dropdown( + label="Sampling method (采样器种类)", + choices=list(flow_scheduler_dict.keys()), value=list(flow_scheduler_dict.keys())[0] + ) + else: + sampler_dropdown = gr.Dropdown( + label="Sampling method (采样器种类)", + choices=list(ddpm_scheduler_dict.keys()), value=list(ddpm_scheduler_dict.keys())[0] + ) sample_step_slider = gr.Slider(label="Sampling steps", value=40, minimum=10, maximum=40, step=1, interactive=False) if edition == "v1": @@ -1875,11 +1933,11 @@ def ui_eas(edition, config_path, model_name, savedir_sample): template_gallery_path = ["asset/1.png", "asset/2.png", "asset/3.png", "asset/4.png", "asset/5.png"] def select_template(evt: gr.SelectData): text = { - "asset/1.png": "The dog is shaking head. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.", - "asset/2.png": "a sailboat sailing in rough seas with a dramatic sunset. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.", - "asset/3.png": "a beautiful woman with long hair and a dress blowing in the wind. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.", - "asset/4.png": "a man in an astronaut suit playing a guitar. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.", - "asset/5.png": "fireworks display over night city. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.", + "asset/1.png": "A brown dog is shaking its head and sitting on a light colored sofa in a comfortable room. Behind the dog, there is a framed painting on the shelf surrounded by pink flowers. The soft and warm lighting in the room creates a comfortable atmosphere.", + "asset/2.png": "A sailboat navigates through moderately rough seas, with waves and ocean spray visible. The sailboat features a white hull and sails, accompanied by an orange sail catching the wind. The sky above shows dramatic, cloudy formations with a sunset or sunrise backdrop, casting warm colors across the scene. The water reflects the golden light, enhancing the visual contrast between the dark ocean and the bright horizon. The camera captures the scene with a dynamic and immersive angle, showcasing the movement of the boat and the energy of the ocean.", + "asset/3.png": "A stunningly beautiful woman with flowing long hair stands gracefully, her elegant dress rippling and billowing in the gentle wind. Petals falling off. Her serene expression and the natural movement of her attire create an enchanting and captivating scene, full of ethereal charm.", + "asset/4.png": "An astronaut, clad in a full space suit with a helmet, plays an electric guitar while floating in a cosmic environment filled with glowing particles and rocky textures. The scene is illuminated by a warm light source, creating dramatic shadows and contrasts. The background features a complex geometry, similar to a space station or an alien landscape, indicating a futuristic or otherworldly setting.", + "asset/5.png": "Fireworks light up the evening sky over a sprawling cityscape with gothic-style buildings featuring pointed towers and clock faces. The city is lit by both artificial lights from the buildings and the colorful bursts of the fireworks. The scene is viewed from an elevated angle, showcasing a vibrant urban environment set against a backdrop of a dramatic, partially cloudy sky at dusk.", }[template_gallery_path[evt.index]] return template_gallery_path[evt.index], text diff --git a/easyanimate/utils/lora_utils.py b/easyanimate/utils/lora_utils.py index b50d11b33c60d40e1b73b9da4b7e4ff39cfb3f93..f6dab6f4f94022fd9174a4c584ed1f014984e704 100644 --- a/easyanimate/utils/lora_utils.py +++ b/easyanimate/utils/lora_utils.py @@ -369,7 +369,6 @@ def create_network( def merge_lora(pipeline, lora_path, multiplier, device='cpu', dtype=torch.float32, state_dict=None, transformer_only=False): LORA_PREFIX_TRANSFORMER = "lora_unet" LORA_PREFIX_TEXT_ENCODER = "lora_te" - SPECIAL_LAYER_NAME = ["text_proj_t5"] if state_dict is None: state_dict = load_file(lora_path, device=device) else: @@ -410,20 +409,25 @@ def merge_lora(pipeline, lora_path, multiplier, device='cpu', dtype=torch.float3 else: temp_name = layer_infos.pop(0) - weight_up = elems['lora_up.weight'].to(dtype) - weight_down = elems['lora_down.weight'].to(dtype) + origin_dtype = curr_layer.weight.data.dtype + origin_device = curr_layer.weight.data.device + + curr_layer = curr_layer.to(device, dtype) + weight_up = elems['lora_up.weight'].to(device, dtype) + weight_down = elems['lora_down.weight'].to(device, dtype) + if 'alpha' in elems.keys(): alpha = elems['alpha'].item() / weight_up.shape[1] else: alpha = 1.0 - curr_layer.weight.data = curr_layer.weight.data.to(device) if len(weight_up.shape) == 4: - curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up.squeeze(3).squeeze(2), - weight_down.squeeze(3).squeeze(2)).unsqueeze( - 2).unsqueeze(3) + curr_layer.weight.data += multiplier * alpha * torch.mm( + weight_up.squeeze(3).squeeze(2), weight_down.squeeze(3).squeeze(2) + ).unsqueeze(2).unsqueeze(3) else: curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up, weight_down) + curr_layer = curr_layer.to(origin_device, origin_dtype) return pipeline @@ -448,35 +452,43 @@ def unmerge_lora(pipeline, lora_path, multiplier=1, device="cpu", dtype=torch.fl layer_infos = layer.split(LORA_PREFIX_UNET + "_")[-1].split("_") curr_layer = pipeline.transformer - temp_name = layer_infos.pop(0) - print(layer, curr_layer) - while len(layer_infos) > -1: - try: - curr_layer = curr_layer.__getattr__(temp_name) - if len(layer_infos) > 0: - temp_name = layer_infos.pop(0) - elif len(layer_infos) == 0: - break - except Exception: - if len(layer_infos) == 0: - print('Error loading layer') - if len(temp_name) > 0: - temp_name += "_" + layer_infos.pop(0) - else: - temp_name = layer_infos.pop(0) - - weight_up = elems['lora_up.weight'].to(dtype) - weight_down = elems['lora_down.weight'].to(dtype) + try: + curr_layer = curr_layer.__getattr__("_".join(layer_infos[1:])) + except Exception: + temp_name = layer_infos.pop(0) + while len(layer_infos) > -1: + try: + curr_layer = curr_layer.__getattr__(temp_name) + if len(layer_infos) > 0: + temp_name = layer_infos.pop(0) + elif len(layer_infos) == 0: + break + except Exception: + if len(layer_infos) == 0: + print('Error loading layer') + if len(temp_name) > 0: + temp_name += "_" + layer_infos.pop(0) + else: + temp_name = layer_infos.pop(0) + + origin_dtype = curr_layer.weight.data.dtype + origin_device = curr_layer.weight.data.device + + curr_layer = curr_layer.to(device, dtype) + weight_up = elems['lora_up.weight'].to(device, dtype) + weight_down = elems['lora_down.weight'].to(device, dtype) + if 'alpha' in elems.keys(): alpha = elems['alpha'].item() / weight_up.shape[1] else: alpha = 1.0 - curr_layer.weight.data = curr_layer.weight.data.to(device) if len(weight_up.shape) == 4: - curr_layer.weight.data -= multiplier * alpha * torch.mm(weight_up.squeeze(3).squeeze(2), - weight_down.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) + curr_layer.weight.data -= multiplier * alpha * torch.mm( + weight_up.squeeze(3).squeeze(2), weight_down.squeeze(3).squeeze(2) + ).unsqueeze(2).unsqueeze(3) else: curr_layer.weight.data -= multiplier * alpha * torch.mm(weight_up, weight_down) + curr_layer = curr_layer.to(origin_device, origin_dtype) - return pipeline \ No newline at end of file + return pipeline diff --git a/easyanimate/utils/utils.py b/easyanimate/utils/utils.py index c1e2083a456053fdd6a50590d1c378572183f94b..bcfedf05789dd7c0b7ad71bebf6bc3a9587c7c82 100644 --- a/easyanimate/utils/utils.py +++ b/easyanimate/utils/utils.py @@ -169,47 +169,67 @@ def get_image_to_video_latent(validation_image_start, validation_image_end, vide return input_video, input_video_mask, clip_image def get_video_to_video_latent(input_video_path, video_length, sample_size, fps=None, validation_video_mask=None, ref_image=None): - if isinstance(input_video_path, str): - cap = cv2.VideoCapture(input_video_path) - input_video = [] + if input_video_path is not None: + if isinstance(input_video_path, str): + cap = cv2.VideoCapture(input_video_path) + input_video = [] - original_fps = cap.get(cv2.CAP_PROP_FPS) - frame_skip = 1 if fps is None else int(original_fps // fps) + original_fps = cap.get(cv2.CAP_PROP_FPS) + frame_skip = 1 if fps is None else int(original_fps // fps) - frame_count = 0 + frame_count = 0 - while True: - ret, frame = cap.read() - if not ret: - break + while True: + ret, frame = cap.read() + if not ret: + break - if frame_count % frame_skip == 0: - frame = cv2.resize(frame, (sample_size[1], sample_size[0])) - input_video.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) + if frame_count % frame_skip == 0: + frame = cv2.resize(frame, (sample_size[1], sample_size[0])) + input_video.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) - frame_count += 1 + frame_count += 1 - cap.release() - else: - input_video = input_video_path + cap.release() + else: + input_video = input_video_path + + input_video = torch.from_numpy(np.array(input_video))[:video_length] + input_video = input_video.permute([3, 0, 1, 2]).unsqueeze(0) / 255 - input_video = torch.from_numpy(np.array(input_video))[:video_length] - input_video = input_video.permute([3, 0, 1, 2]).unsqueeze(0) / 255 + if validation_video_mask is not None: + validation_video_mask = Image.open(validation_video_mask).convert('L').resize((sample_size[1], sample_size[0])) + input_video_mask = np.where(np.array(validation_video_mask) < 240, 0, 255) + + input_video_mask = torch.from_numpy(np.array(input_video_mask)).unsqueeze(0).unsqueeze(-1).permute([3, 0, 1, 2]).unsqueeze(0) + input_video_mask = torch.tile(input_video_mask, [1, 1, input_video.size()[2], 1, 1]) + input_video_mask = input_video_mask.to(input_video.device, input_video.dtype) + else: + input_video_mask = torch.zeros_like(input_video[:, :1]) + input_video_mask[:, :, :] = 255 + else: + input_video, input_video_mask = None, None if ref_image is not None: - ref_image = Image.open(ref_image) - ref_image = torch.from_numpy(np.array(ref_image)) - ref_image = ref_image.unsqueeze(0).permute([3, 0, 1, 2]).unsqueeze(0) / 255 + if isinstance(ref_image, str): + ref_image = Image.open(ref_image).convert("RGB") + ref_image = ref_image.resize((sample_size[1], sample_size[0])) + ref_image = torch.from_numpy(np.array(ref_image)) + ref_image = ref_image.unsqueeze(0).permute([3, 0, 1, 2]).unsqueeze(0) / 255 + else: + ref_image = torch.from_numpy(np.array(ref_image)) + ref_image = ref_image.unsqueeze(0).permute([3, 0, 1, 2]).unsqueeze(0) / 255 + return input_video, input_video_mask, ref_image - if validation_video_mask is not None: - validation_video_mask = Image.open(validation_video_mask).convert('L').resize((sample_size[1], sample_size[0])) - input_video_mask = np.where(np.array(validation_video_mask) < 240, 0, 255) - - input_video_mask = torch.from_numpy(np.array(input_video_mask)).unsqueeze(0).unsqueeze(-1).permute([3, 0, 1, 2]).unsqueeze(0) - input_video_mask = torch.tile(input_video_mask, [1, 1, input_video.size()[2], 1, 1]) - input_video_mask = input_video_mask.to(input_video.device, input_video.dtype) - else: - input_video_mask = torch.zeros_like(input_video[:, :1]) - input_video_mask[:, :, :] = 255 +def get_image_latent(ref_image=None, sample_size=None): + if ref_image is not None: + if isinstance(ref_image, str): + ref_image = Image.open(ref_image).convert("RGB") + ref_image = ref_image.resize((sample_size[1], sample_size[0])) + ref_image = torch.from_numpy(np.array(ref_image)) + ref_image = ref_image.unsqueeze(0).permute([3, 0, 1, 2]).unsqueeze(0) / 255 + else: + ref_image = torch.from_numpy(np.array(ref_image)) + ref_image = ref_image.unsqueeze(0).permute([3, 0, 1, 2]).unsqueeze(0) / 255 - return input_video, input_video_mask, ref_image \ No newline at end of file + return ref_image \ No newline at end of file diff --git a/easyanimate/vae/ldm/models/autoencoder.py b/easyanimate/vae/ldm/models/autoencoder.py index 7cc28aef70ea3489bea9b114b12f5c6233e0791e..7d7854fc8b1a142e32a7e5841b751661a2d16161 100644 --- a/easyanimate/vae/ldm/models/autoencoder.py +++ b/easyanimate/vae/ldm/models/autoencoder.py @@ -126,13 +126,13 @@ class AutoencoderKLMagvit(pl.LightningModule): def configure_optimizers(self): lr = self.learning_rate - opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ + opt_ae = torch.optim.AdamW(list(self.encoder.parameters())+ list(self.decoder.parameters())+ list(self.quant_conv.parameters())+ list(self.post_quant_conv.parameters()), - lr=lr, betas=(0.5, 0.9)) - opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), - lr=lr, betas=(0.5, 0.9)) + lr=lr, betas=(0.9, 0.999), weight_decay=5e-2) + opt_disc = torch.optim.AdamW(self.loss.discriminator.parameters(), + lr=lr, betas=(0.9, 0.999), weight_decay=5e-2) return [opt_ae, opt_disc], [] def get_last_layer(self): diff --git a/easyanimate/vae/ldm/models/casual3dcnn.py b/easyanimate/vae/ldm/models/casual3dcnn.py index a1e4a60ef24d9dc7c91697ee73252dae04c4b8f9..0c99c9a2449c74cbd4a10f63c295a37cbbe6d22f 100644 --- a/easyanimate/vae/ldm/models/casual3dcnn.py +++ b/easyanimate/vae/ldm/models/casual3dcnn.py @@ -279,13 +279,13 @@ class AutoencoderKL(pl.LightningModule): def configure_optimizers(self): lr = self.learning_rate - opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ + opt_ae = torch.optim.AdamW(list(self.encoder.parameters())+ list(self.decoder.parameters())+ list(self.quant_conv.parameters())+ - list(self.post_quant_conv.parameters()), - lr=lr, betas=(0.5, 0.9)) - opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), - lr=lr, betas=(0.5, 0.9)) + list(self.post_quant_conv.parameters()), \ + lr=lr, betas=(0.9, 0.999), weight_decay=5e-2) + opt_disc = torch.optim.AdamW(self.loss.discriminator.parameters(), + lr=lr, betas=(0.9, 0.999), weight_decay=5e-2) return [opt_ae, opt_disc], [] def get_last_layer(self): diff --git a/easyanimate/vae/ldm/models/cogvideox_casual3dcnn.py b/easyanimate/vae/ldm/models/cogvideox_casual3dcnn.py index a64c6160f52d6db2a7b9f60bef2ffe1fda615266..5a4db1731d13babd58ca635763a588d8487e253e 100644 --- a/easyanimate/vae/ldm/models/cogvideox_casual3dcnn.py +++ b/easyanimate/vae/ldm/models/cogvideox_casual3dcnn.py @@ -277,23 +277,23 @@ class AutoencoderKLMagvit_CogVideoX(pl.LightningModule): training_list = list(self.decoder.parameters()) + list(self.post_quant_conv.parameters()) else: training_list = list(self.decoder.parameters()) - opt_ae = torch.optim.Adam(training_list, lr=lr, betas=(0.5, 0.9)) + opt_ae = torch.optim.AdamW(training_list, lr=lr, betas=(0.9, 0.999), weight_decay=5e-2) elif self.train_encoder_only: if self.quant_conv is not None: training_list = list(self.encoder.parameters()) + list(self.quant_conv.parameters()) else: training_list = list(self.encoder.parameters()) - opt_ae = torch.optim.Adam(training_list, lr=lr, betas=(0.5, 0.9)) + opt_ae = torch.optim.AdamW(training_list, lr=lr, betas=(0.9, 0.999), weight_decay=5e-2) else: training_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) if self.quant_conv is not None: training_list = training_list + list(self.quant_conv.parameters()) if self.post_quant_conv is not None: training_list = training_list + list(self.post_quant_conv.parameters()) - opt_ae = torch.optim.Adam(training_list, lr=lr, betas=(0.5, 0.9)) - opt_disc = torch.optim.Adam( + opt_ae = torch.optim.AdamW(training_list, lr=lr, betas=(0.9, 0.999), weight_decay=5e-2) + opt_disc = torch.optim.AdamW( list(self.loss.discriminator3d.parameters()) + list(self.loss.discriminator.parameters()), - lr=lr, betas=(0.5, 0.9) + lr=lr, betas=(0.9, 0.999), weight_decay=5e-2 ) return [opt_ae, opt_disc], [] diff --git a/easyanimate/vae/ldm/models/omnigen_casual3dcnn.py b/easyanimate/vae/ldm/models/omnigen_casual3dcnn.py index 12692dadfc4a6aae9cc744f4cbba05096a8ac65d..37fee0e82adc1d7d356b7c8a1bcfd84865a64ca4 100644 --- a/easyanimate/vae/ldm/models/omnigen_casual3dcnn.py +++ b/easyanimate/vae/ldm/models/omnigen_casual3dcnn.py @@ -95,6 +95,7 @@ class AutoencoderKLMagvit_fromOmnigen(pl.LightningModule): out_channels: int = 3, ch = 128, ch_mult = [ 1,2,4,4 ], + block_out_channels = [128, 256, 512, 512], use_gc_blocks = None, down_block_types: tuple = None, up_block_types: tuple = None, @@ -129,8 +130,9 @@ class AutoencoderKLMagvit_fromOmnigen(pl.LightningModule): in_channels=in_channels, out_channels=latent_channels, down_block_types=down_block_types, - ch = ch, - ch_mult = ch_mult, + ch=ch, + ch_mult=ch_mult, + block_out_channels=block_out_channels, use_gc_blocks=use_gc_blocks, mid_block_type=mid_block_type, mid_block_use_attention=mid_block_use_attention, @@ -144,6 +146,7 @@ class AutoencoderKLMagvit_fromOmnigen(pl.LightningModule): slice_mag_vae=slice_mag_vae, slice_compression_vae=slice_compression_vae, cache_compression_vae=cache_compression_vae, + cache_mag_vae=cache_mag_vae, spatial_group_norm=spatial_group_norm, mini_batch_encoder=mini_batch_encoder, ) @@ -152,8 +155,9 @@ class AutoencoderKLMagvit_fromOmnigen(pl.LightningModule): in_channels=latent_channels, out_channels=out_channels, up_block_types=up_block_types, - ch = ch, - ch_mult = ch_mult, + ch=ch, + ch_mult=ch_mult, + block_out_channels=block_out_channels, use_gc_blocks=use_gc_blocks, mid_block_type=mid_block_type, mid_block_use_attention=mid_block_use_attention, @@ -292,23 +296,23 @@ class AutoencoderKLMagvit_fromOmnigen(pl.LightningModule): training_list = list(self.decoder.parameters()) + list(self.post_quant_conv.parameters()) else: training_list = list(self.decoder.parameters()) - opt_ae = torch.optim.Adam(training_list, lr=lr, betas=(0.5, 0.9)) + opt_ae = torch.optim.AdamW(training_list, lr=lr, betas=(0.9, 0.999), weight_decay=5e-2) elif self.train_encoder_only: if self.quant_conv is not None: training_list = list(self.encoder.parameters()) + list(self.quant_conv.parameters()) else: training_list = list(self.encoder.parameters()) - opt_ae = torch.optim.Adam(training_list, lr=lr, betas=(0.5, 0.9)) + opt_ae = torch.optim.AdamW(training_list, lr=lr, betas=(0.9, 0.999), weight_decay=5e-2) else: training_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) if self.quant_conv is not None: training_list = training_list + list(self.quant_conv.parameters()) if self.post_quant_conv is not None: training_list = training_list + list(self.post_quant_conv.parameters()) - opt_ae = torch.optim.Adam(training_list, lr=lr, betas=(0.5, 0.9)) - opt_disc = torch.optim.Adam( + opt_ae = torch.optim.AdamW(training_list, lr=lr, betas=(0.9, 0.999), weight_decay=5e-2) + opt_disc = torch.optim.AdamW( list(self.loss.discriminator3d.parameters()) + list(self.loss.discriminator.parameters()), - lr=lr, betas=(0.5, 0.9) + lr=lr, betas=(0.9, 0.999), weight_decay=5e-2 ) return [opt_ae, opt_disc], [] diff --git a/easyanimate/vae/ldm/models/omnigen_enc_dec.py b/easyanimate/vae/ldm/models/omnigen_enc_dec.py index ff501f14de342f0541d1023a870777bdd50703f5..3c1f35996449b85afc55e208dc78d24f93a59340 100644 --- a/easyanimate/vae/ldm/models/omnigen_enc_dec.py +++ b/easyanimate/vae/ldm/models/omnigen_enc_dec.py @@ -58,6 +58,7 @@ class Encoder(nn.Module): down_block_types = ("SpatialDownBlock3D",), ch = 128, ch_mult = [1,2,4,4,], + block_out_channels = [128, 256, 512, 512], use_gc_blocks = None, mid_block_type: str = "MidBlock3D", mid_block_use_attention: bool = True, @@ -77,7 +78,8 @@ class Encoder(nn.Module): verbose = False, ): super().__init__() - block_out_channels = [ch * i for i in ch_mult] + if block_out_channels is None: + block_out_channels = [ch * i for i in ch_mult] assert len(down_block_types) == len(block_out_channels), ( "Number of down block types must match number of block output channels." ) @@ -364,6 +366,7 @@ class Decoder(nn.Module): up_block_types = ("SpatialUpBlock3D",), ch = 128, ch_mult = [1,2,4,4,], + block_out_channels = [128, 256, 512, 512], use_gc_blocks = None, mid_block_type: str = "MidBlock3D", mid_block_use_attention: bool = True, @@ -382,7 +385,8 @@ class Decoder(nn.Module): verbose = False, ): super().__init__() - block_out_channels = [ch * i for i in ch_mult] + if block_out_channels is None: + block_out_channels = [ch * i for i in ch_mult] assert len(up_block_types) == len(block_out_channels), ( "Number of up block types must match number of block output channels." ) diff --git a/easyanimate/vae/ldm/modules/losses/contperceptual.py b/easyanimate/vae/ldm/modules/losses/contperceptual.py index c344005ad243e885030a5f8b25f6b8f8bee46ec2..205c5d5fcdef14625924e4cde8d1da0e77438363 100644 --- a/easyanimate/vae/ldm/modules/losses/contperceptual.py +++ b/easyanimate/vae/ldm/modules/losses/contperceptual.py @@ -9,7 +9,8 @@ from ..vaemodules.discriminator import Discriminator3D class LPIPSWithDiscriminator(nn.Module): def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0, disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, - perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, + perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, + outlier_penalty_loss_r=3.0, outlier_penalty_loss_weight=1e5, disc_loss="hinge", l2_loss_weight=0.0, l1_loss_weight=1.0): super().__init__() @@ -34,6 +35,8 @@ class LPIPSWithDiscriminator(nn.Module): self.disc_factor = disc_factor self.discriminator_weight = disc_weight self.disc_conditional = disc_conditional + self.outlier_penalty_loss_r = outlier_penalty_loss_r + self.outlier_penalty_loss_weight = outlier_penalty_loss_weight self.l1_loss_weight = l1_loss_weight self.l2_loss_weight = l2_loss_weight @@ -50,6 +53,18 @@ class LPIPSWithDiscriminator(nn.Module): d_weight = d_weight * self.discriminator_weight return d_weight + def outlier_penalty_loss(self, posteriors, r): + batch_size, channels, frames, height, width = posteriors.shape + mean_X = posteriors.mean(dim=(3, 4), keepdim=True) + std_X = posteriors.std(dim=(3, 4), keepdim=True) + + diff = torch.abs(posteriors - mean_X) + penalty = torch.maximum(diff - r * std_X, torch.zeros_like(diff)) + + opl = penalty.sum(dim=(3, 4)) / (height * width) + opl_final = opl.mean(dim=(0, 1, 2)) + return opl_final + def forward(self, inputs, reconstructions, posteriors, optimizer_idx, global_step, last_layer=None, cond=None, split="train", weights=None): @@ -86,6 +101,8 @@ class LPIPSWithDiscriminator(nn.Module): kl_loss = posteriors.kl() kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] + outlier_penalty_loss = self.outlier_penalty_loss(posteriors.mode(), self.outlier_penalty_loss_r) * self.outlier_penalty_loss_weight + # now the GAN part if optimizer_idx == 0: # generator update @@ -102,13 +119,13 @@ class LPIPSWithDiscriminator(nn.Module): try: d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) except RuntimeError: - assert not self.training + # assert not self.training d_weight = torch.tensor(0.0) else: d_weight = torch.tensor(0.0) disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) - loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss + loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss + outlier_penalty_loss log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(), "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(), diff --git a/easyanimate/vae/ldm/modules/vaemodules/__init__.py b/easyanimate/vae/ldm/modules/vaemodules/__init__.py old mode 100755 new mode 100644 diff --git a/easyanimate/vae/ldm/modules/vaemodules/activations.py b/easyanimate/vae/ldm/modules/vaemodules/activations.py old mode 100755 new mode 100644 diff --git a/easyanimate/vae/ldm/modules/vaemodules/common.py b/easyanimate/vae/ldm/modules/vaemodules/common.py old mode 100755 new mode 100644 index f85bf39652f0ff95585df31906eaef2bd215e479..f865527b2ccae8922adcb49ec7465a37d58fb4ca --- a/easyanimate/vae/ldm/modules/vaemodules/common.py +++ b/easyanimate/vae/ldm/modules/vaemodules/common.py @@ -8,6 +8,17 @@ from einops import rearrange, repeat from .activations import get_activation +try: + current_version = torch.__version__ + version_numbers = [int(x) for x in current_version.split('.')[:2]] + if version_numbers[0] < 2 or (version_numbers[0] == 2 and version_numbers[1] < 2): + need_to_float = True + else: + need_to_float = False +except Exception as e: + print("Encountered an error with Torch version. Set the data type to float in the VAE. ") + need_to_float = False + def cast_tuple(t, length = 1): return t if isinstance(t, tuple) else ((t,) * length) @@ -66,10 +77,15 @@ class CausalConv3d(nn.Conv3d): **kwargs, ) + def _clear_conv_cache(self): + del self.prev_features + self.prev_features = None + def forward(self, x: torch.Tensor) -> torch.Tensor: # x: (B, C, T, H, W) dtype = x.dtype - x = x.float() + if need_to_float: + x = x.float() if self.padding_flag == 0: x = F.pad( x, @@ -85,7 +101,11 @@ class CausalConv3d(nn.Conv3d): mode="replicate", # TODO: check if this is necessary ) x = x.to(dtype=dtype) - self.prev_features = x[:, :, -self.temporal_padding:] + + # Clear cache before + self._clear_conv_cache() + # We could move these to the cpu for a lower VRAM + self.prev_features = x[:, :, -self.temporal_padding:].clone() b, c, f, h, w = x.size() outputs = [] @@ -105,7 +125,11 @@ class CausalConv3d(nn.Conv3d): [self.prev_features, x], dim = 2 ) x = x.to(dtype=dtype) - self.prev_features = x[:, :, -self.temporal_padding:] + + # Clear cache before + self._clear_conv_cache() + # We could move these to the cpu for a lower VRAM + self.prev_features = x[:, :, -self.temporal_padding:].clone() b, c, f, h, w = x.size() outputs = [] @@ -122,7 +146,12 @@ class CausalConv3d(nn.Conv3d): mode="replicate", # TODO: check if this is necessary ) x = x.to(dtype=dtype) - self.prev_features = x[:, :, -self.temporal_padding:] + + # Clear cache before + self._clear_conv_cache() + # We could move these to the cpu for a lower VRAM + self.prev_features = x[:, :, -self.temporal_padding:].clone() + return super().forward(x) elif self.padding_flag == 6: if self.t_stride == 2: @@ -133,7 +162,12 @@ class CausalConv3d(nn.Conv3d): x = torch.concat( [self.prev_features, x], dim = 2 ) - self.prev_features = x[:, :, -self.temporal_padding:] + + # Clear cache before + self._clear_conv_cache() + # We could move these to the cpu for a lower VRAM + self.prev_features = x[:, :, -self.temporal_padding:].clone() + x = x.to(dtype=dtype) return super().forward(x) else: diff --git a/easyanimate/vae/ldm/modules/vaemodules/down_blocks.py b/easyanimate/vae/ldm/modules/vaemodules/down_blocks.py old mode 100755 new mode 100644 diff --git a/easyanimate/vae/ldm/modules/vaemodules/mid_blocks.py b/easyanimate/vae/ldm/modules/vaemodules/mid_blocks.py old mode 100755 new mode 100644 diff --git a/easyanimate/vae/ldm/modules/vaemodules/up_blocks.py b/easyanimate/vae/ldm/modules/vaemodules/up_blocks.py old mode 100755 new mode 100644 diff --git a/requirements.txt b/requirements.txt index 4ee48530662b177c0877c7f281b01ba199f794f5..7cd6fb5195c4560639b037d620d2b0f61c7e24a9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,7 +6,6 @@ tomesd torch>=2.1.2 torchdiffeq torchsde -xformers decord datasets numpy @@ -21,8 +20,6 @@ tensorboard beautifulsoup4 ftfy func_timeout -deepspeed accelerate>=0.25.0 -gradio>=3.41.2 -diffusers>=0.30.1 -transformers>=4.37.2 \ No newline at end of file +diffusers==0.30.1 +transformers==4.46.2 \ No newline at end of file