# finetrainers/models/cogview4/control_specification.py

import functools
import os
from typing import Any, Dict, List, Optional, Tuple

import safetensors.torch
import torch
from accelerate import init_empty_weights
from diffusers import AutoencoderKL, CogView4Pipeline, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler
from transformers import AutoTokenizer, GlmModel

import finetrainers.functional as FF
from finetrainers.data import ImageArtifact
from finetrainers.models.modeling_utils import ControlModelSpecification
from finetrainers.models.utils import DiagonalGaussianDistribution, _expand_linear_with_zeroed_weights
from finetrainers.patches.dependencies.diffusers.control import control_channel_concat
from finetrainers.processors import CogView4GLMProcessor, ProcessorMixin
from finetrainers.typing import ArtifactType, SchedulerType
from finetrainers.utils import _enable_vae_memory_optimizations, get_non_null_items, safetensors_torch_save_function

from .base_specification import CogView4LatentEncodeProcessor


class CogView4ControlModelSpecification(ControlModelSpecification):
    def __init__(
        self,
        pretrained_model_name_or_path: str = "THUDM/CogView4-6B",
        tokenizer_id: Optional[str] = None,
        text_encoder_id: Optional[str] = None,
        transformer_id: Optional[str] = None,
        vae_id: Optional[str] = None,
        text_encoder_dtype: torch.dtype = torch.bfloat16,
        transformer_dtype: torch.dtype = torch.bfloat16,
        vae_dtype: torch.dtype = torch.bfloat16,
        revision: Optional[str] = None,
        cache_dir: Optional[str] = None,
        condition_model_processors: List[ProcessorMixin] = None,
        latent_model_processors: List[ProcessorMixin] = None,
        control_model_processors: List[ProcessorMixin] = None,
        **kwargs,
    ) -> None:
        super().__init__(
            pretrained_model_name_or_path=pretrained_model_name_or_path,
            tokenizer_id=tokenizer_id,
            text_encoder_id=text_encoder_id,
            transformer_id=transformer_id,
            vae_id=vae_id,
            text_encoder_dtype=text_encoder_dtype,
            transformer_dtype=transformer_dtype,
            vae_dtype=vae_dtype,
            revision=revision,
            cache_dir=cache_dir,
        )

        if condition_model_processors is None:
            condition_model_processors = [CogView4GLMProcessor(["encoder_hidden_states"])]
        if latent_model_processors is None:
            latent_model_processors = [
                CogView4LatentEncodeProcessor(["latents", "original_size", "target_size", "crop_coords"])
            ]
        if control_model_processors is None:
            control_model_processors = [
                CogView4LatentEncodeProcessor(["control_latents", "original_size", "target_size", "crop_coords"])
            ]

        self.condition_model_processors = condition_model_processors
        self.latent_model_processors = latent_model_processors
        self.control_model_processors = control_model_processors
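
    # Control conditioning is injected by widening the transformer's patch embedding projection
    # (`patch_embed.proj`, named below) so that control latents can be concatenated to the noisy
    # latents along the channel dimension; see `load_diffusion_models`, `forward`, and `validation`.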

    @property
    def control_injection_layer_name(self):
        return "patch_embed.proj"

    @property
    def _resolution_dim_keys(self):
        return {"latents": (2, 3)}

    def load_condition_models(self) -> Dict[str, torch.nn.Module]:
        common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir}

        if self.tokenizer_id is not None:
            tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_id, **common_kwargs)
        else:
            tokenizer = AutoTokenizer.from_pretrained(
                self.pretrained_model_name_or_path, subfolder="tokenizer", **common_kwargs
            )

        if self.text_encoder_id is not None:
            text_encoder = GlmModel.from_pretrained(
                self.text_encoder_id, torch_dtype=self.text_encoder_dtype, **common_kwargs
            )
        else:
            text_encoder = GlmModel.from_pretrained(
                self.pretrained_model_name_or_path,
                subfolder="text_encoder",
                torch_dtype=self.text_encoder_dtype,
                **common_kwargs,
            )

        return {"tokenizer": tokenizer, "text_encoder": text_encoder}

    def load_latent_models(self) -> Dict[str, torch.nn.Module]:
        common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir}

        if self.vae_id is not None:
            vae = AutoencoderKL.from_pretrained(self.vae_id, torch_dtype=self.vae_dtype, **common_kwargs)
        else:
            vae = AutoencoderKL.from_pretrained(
                self.pretrained_model_name_or_path, subfolder="vae", torch_dtype=self.vae_dtype, **common_kwargs
            )

        return {"vae": vae}

    def load_diffusion_models(self, new_in_features: int) -> Dict[str, torch.nn.Module]:
        common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir}

        if self.transformer_id is not None:
            transformer = CogView4Transformer2DModel.from_pretrained(
                self.transformer_id, torch_dtype=self.transformer_dtype, **common_kwargs
            )
        else:
            transformer = CogView4Transformer2DModel.from_pretrained(
                self.pretrained_model_name_or_path,
                subfolder="transformer",
                torch_dtype=self.transformer_dtype,
                **common_kwargs,
            )

        actual_new_in_features = new_in_features * transformer.config.patch_size**2
        transformer.patch_embed.proj = _expand_linear_with_zeroed_weights(
            transformer.patch_embed.proj, new_in_features=actual_new_in_features
        )
        transformer.register_to_config(in_channels=new_in_features)

        scheduler = FlowMatchEulerDiscreteScheduler()
        return {"transformer": transformer, "scheduler": scheduler}

    def load_pipeline(
        self,
        tokenizer: Optional[AutoTokenizer] = None,
        text_encoder: Optional[GlmModel] = None,
        transformer: Optional[CogView4Transformer2DModel] = None,
        vae: Optional[AutoencoderKL] = None,
        scheduler: Optional[FlowMatchEulerDiscreteScheduler] = None,
        enable_slicing: bool = False,
        enable_tiling: bool = False,
        enable_model_cpu_offload: bool = False,
        training: bool = False,
        **kwargs,
    ) -> CogView4Pipeline:
        components = {
            "tokenizer": tokenizer,
            "text_encoder": text_encoder,
            "transformer": transformer,
            "vae": vae,
            # Load the scheduler based on CogView4's config instead of using the default initialization being used for training
            # "scheduler": scheduler,
        }
        components = get_non_null_items(components)

        pipe = CogView4Pipeline.from_pretrained(
            self.pretrained_model_name_or_path, **components, revision=self.revision, cache_dir=self.cache_dir
        )
        pipe.text_encoder.to(self.text_encoder_dtype)
        pipe.vae.to(self.vae_dtype)

        _enable_vae_memory_optimizations(pipe.vae, enable_slicing, enable_tiling)
        if not training:
            pipe.transformer.to(self.transformer_dtype)
        if enable_model_cpu_offload:
            pipe.enable_model_cpu_offload()
        return pipe
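
    # The two preparation methods below follow the same pattern: pass the models, raw inputs, and
    # kwargs through the configured processors, then drop the original input keys so that only the
    # newly computed tensors (e.g. "encoder_hidden_states", "latents", "control_latents") are returned.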

    def prepare_conditions(
        self,
        tokenizer: AutoTokenizer,
        text_encoder: GlmModel,
        caption: str,
        max_sequence_length: int = 1024,
        **kwargs,
    ) -> Dict[str, Any]:
        conditions = {
            "tokenizer": tokenizer,
            "text_encoder": text_encoder,
            "caption": caption,
            "max_sequence_length": max_sequence_length,
            **kwargs,
        }
        input_keys = set(conditions.keys())
        conditions = super().prepare_conditions(**conditions)
        conditions = {k: v for k, v in conditions.items() if k not in input_keys}
        return conditions

    def prepare_latents(
        self,
        vae: AutoencoderKL,
        image: Optional[torch.Tensor] = None,
        video: Optional[torch.Tensor] = None,
        control_image: Optional[torch.Tensor] = None,
        control_video: Optional[torch.Tensor] = None,
        generator: Optional[torch.Generator] = None,
        compute_posterior: bool = True,
        _original_height: Optional[int] = None,
        _original_width: Optional[int] = None,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        common_kwargs = {
            "vae": vae,
            "generator": generator,
            "compute_posterior": compute_posterior,
            "_original_height": _original_height,
            "_original_width": _original_width,
            **kwargs,
        }

        conditions = {"image": image, "video": video, **common_kwargs}
        input_keys = set(conditions.keys())
        conditions = ControlModelSpecification.prepare_latents(self, self.latent_model_processors, **conditions)
        conditions = {k: v for k, v in conditions.items() if k not in input_keys}

        control_conditions = {"image": control_image, "video": control_video, **common_kwargs}
        input_keys = set(control_conditions.keys())
        control_conditions = ControlModelSpecification.prepare_latents(
            self, self.control_model_processors, **control_conditions
        )
        control_conditions = {k: v for k, v in control_conditions.items() if k not in input_keys}

        return {**control_conditions, **conditions}
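
    # Training objective used in `forward` below, stated as equations. The exact definitions of
    # FF.flow_match_xt / FF.flow_match_target live in finetrainers.functional; the forms given here
    # are the standard rectified-flow ones and are an assumption about those helpers:
    #
    #     x_t    = (1 - sigma') * x_0 + sigma' * noise        # FF.flow_match_xt
    #     target = noise - x_0                                 # FF.flow_match_target
    #
    # where sigma' is a resolution-shifted sigma. With a base sequence length of 256 tokens,
    #
    #     mu     = sqrt(seq_len / 256) * max_shift + base_shift
    #     sigma' = mu / (mu + (1 / sigma - 1))
    #
    # so larger images (longer token sequences) are trained at effectively noisier timesteps.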

    def forward(
        self,
        transformer: CogView4Transformer2DModel,
        condition_model_conditions: Dict[str, torch.Tensor],
        latent_model_conditions: Dict[str, torch.Tensor],
        sigmas: torch.Tensor,
        generator: Optional[torch.Generator] = None,
        compute_posterior: bool = True,
        **kwargs,
    ) -> Tuple[torch.Tensor, ...]:
        base_image_sequence_length = 256
        base_shift = 0.25
        max_shift = 0.75

        if compute_posterior:
            latents = latent_model_conditions.pop("latents")
            control_latents = latent_model_conditions.pop("control_latents")
        else:
            posterior = DiagonalGaussianDistribution(latent_model_conditions.pop("latents"))
            latents = posterior.sample(generator=generator)
            del posterior

            control_posterior = DiagonalGaussianDistribution(latent_model_conditions.pop("control_latents"))
            control_latents = control_posterior.sample(generator=generator)
            del control_posterior

        if getattr(self.vae_config, "shift_factor", None) is not None:
            latents = (latents - self.vae_config.shift_factor) * self.vae_config.scaling_factor
            control_latents = (control_latents - self.vae_config.shift_factor) * self.vae_config.scaling_factor

        noise = torch.zeros_like(latents).normal_(generator=generator)
        timesteps = (sigmas.flatten() * 1000.0).long()

        image_sequence_length = latents.size(2) * latents.size(3) // self.transformer_config.patch_size**2
        mu = (image_sequence_length / base_image_sequence_length) ** 0.5
        mu = mu * max_shift + base_shift
        shifted_sigmas = mu / (mu + (1 / sigmas - 1) ** 1.0)
        noisy_latents = FF.flow_match_xt(latents, noise, shifted_sigmas)
        noisy_latents = torch.cat([noisy_latents, control_latents], dim=1)

        latent_model_conditions["hidden_states"] = noisy_latents.to(latents)

        pred = transformer(
            **latent_model_conditions,
            **condition_model_conditions,
            timestep=timesteps,
            return_dict=False,
        )[0]
        target = FF.flow_match_target(noise, latents)

        # NOTE: shifted_sigmas loss weighting seems to work better than sigmas. Needs more investigation
        # but let's keep it this way for now. Longer training runs should reveal more insights.
        # return pred, target, sigmas
        return pred, target, shifted_sigmas
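
    # At inference time the control latents cannot simply be concatenated once up front, because the
    # pipeline rebuilds `hidden_states` at every denoising step. The `control_channel_concat` context
    # manager used in `validation` below patches the transformer so that, on each call, the given
    # control latents are concatenated to the named argument ("hidden_states") along the given dim (1),
    # mirroring the `torch.cat([noisy_latents, control_latents], dim=1)` done during training. (This
    # description of the helper's mechanics is inferred from its usage here, not from its implementation.)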

    def validation(
        self,
        pipeline: CogView4Pipeline,
        prompt: str,
        control_image: torch.Tensor,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        generator: Optional[torch.Generator] = None,
        **kwargs,
    ) -> List[ArtifactType]:
        with torch.no_grad():
            dtype = pipeline.vae.dtype
            device = pipeline._execution_device
            in_channels = self.transformer_config.in_channels  # We need to use the original in_channels
            latents = pipeline.prepare_latents(1, in_channels, height, width, dtype, device, generator)

            control_image = pipeline.image_processor.preprocess(control_image, height=height, width=width)
            control_image = control_image.to(device=device, dtype=dtype)
            control_latents = pipeline.vae.encode(control_image).latent_dist.sample(generator=generator)
            if getattr(self.vae_config, "shift_factor", None) is not None:
                control_latents = (control_latents - self.vae_config.shift_factor) * self.vae_config.scaling_factor

        generation_kwargs = {
            "latents": latents,
            "prompt": prompt,
            "height": height,
            "width": width,
            "num_inference_steps": num_inference_steps,
            "generator": generator,
            "return_dict": True,
            "output_type": "pil",
        }
        generation_kwargs = get_non_null_items(generation_kwargs)

        with control_channel_concat(pipeline.transformer, ["hidden_states"], [control_latents], dims=[1]):
            image = pipeline(**generation_kwargs).images[0]

        return [ImageArtifact(value=image)]

    def _save_lora_weights(
        self,
        directory: str,
        transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
        norm_state_dict: Optional[Dict[str, torch.Tensor]] = None,
        scheduler: Optional[SchedulerType] = None,
        metadata: Optional[Dict[str, str]] = None,
        *args,
        **kwargs,
    ) -> None:
        # TODO(aryan): this needs refactoring
        if transformer_state_dict is not None:
            CogView4Pipeline.save_lora_weights(
                directory,
                transformer_state_dict,
                save_function=functools.partial(safetensors_torch_save_function, metadata=metadata),
                safe_serialization=True,
            )
        if norm_state_dict is not None:
            safetensors.torch.save_file(norm_state_dict, os.path.join(directory, "norm_state_dict.safetensors"))
        if scheduler is not None:
            scheduler.save_pretrained(os.path.join(directory, "scheduler"))

    def _save_model(
        self,
        directory: str,
        transformer: CogView4Transformer2DModel,
        transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
        scheduler: Optional[SchedulerType] = None,
    ) -> None:
        # TODO(aryan): this needs refactoring
        if transformer_state_dict is not None:
            with init_empty_weights():
                transformer_copy = CogView4Transformer2DModel.from_config(transformer.config)
            transformer_copy.load_state_dict(transformer_state_dict, strict=True, assign=True)
            transformer_copy.save_pretrained(os.path.join(directory, "transformer"))
        if scheduler is not None:
            scheduler.save_pretrained(os.path.join(directory, "scheduler"))

    @property
    def _original_control_layer_in_features(self):
        return self.transformer_config.in_channels

    @property
    def _original_control_layer_out_features(self):
        return self.transformer_config.num_attention_heads * self.transformer_config.attention_head_dim

    @property
    def _qk_norm_identifiers(self):
        return ["attn1.norm_q", "attn1.norm_k"]