import torch
from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, WanTransformer3DModel
from transformers import AutoTokenizer, T5EncoderModel

from finetrainers.models.utils import _expand_conv3d_with_zeroed_weights
from finetrainers.models.wan import WanControlModelSpecification


class DummyWanControlModelSpecification(WanControlModelSpecification):
    """Dummy Wan control specification that builds tiny text encoder, VAE, and transformer components so tests run quickly."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # This needs to be updated for the test to work correctly.
        # TODO(aryan): it will not be needed if we hosted the dummy model so that the correct config could be loaded
        # with ModelSpecification::_load_configs
        self.transformer_config.in_channels = 16

    def load_condition_models(self):
        text_encoder = T5EncoderModel.from_pretrained(
            "hf-internal-testing/tiny-random-t5", torch_dtype=self.text_encoder_dtype
        )
        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
        return {"text_encoder": text_encoder, "tokenizer": tokenizer}

    def load_latent_models(self):
        torch.manual_seed(0)
        vae = AutoencoderKLWan(
            base_dim=3,
            z_dim=16,
            dim_mult=[1, 1, 1, 1],
            num_res_blocks=1,
            temperal_downsample=[False, True, True],
        )
        # TODO(aryan): Upload dummy checkpoints to the Hub so that we don't have to do this.
        # Doing so overrides things like _keep_in_fp32_modules
        vae.to(self.vae_dtype)
        self.vae_config = vae.config
        return {"vae": vae}

    def load_diffusion_models(self, new_in_features: int):
        torch.manual_seed(0)
        transformer = WanTransformer3DModel(
            patch_size=(1, 2, 2),
            num_attention_heads=2,
            attention_head_dim=12,
            in_channels=16,
            out_channels=16,
            text_dim=32,
            freq_dim=256,
            ffn_dim=32,
            num_layers=2,
            cross_attn_norm=True,
            qk_norm="rms_norm_across_heads",
            rope_max_seq_len=32,
        ).to(self.transformer_dtype)
        # Widen the patch embedding so it accepts the additional control channels;
        # the new input slices start out with zeroed weights.
        transformer.patch_embedding = _expand_conv3d_with_zeroed_weights(
            transformer.patch_embedding, new_in_channels=new_in_features
        )
        transformer.register_to_config(in_channels=new_in_features)
        # TODO(aryan): Upload dummy checkpoints to the Hub so that we don't have to do this.
        # Doing so overrides things like _keep_in_fp32_modules
        transformer.to(self.transformer_dtype)
        scheduler = FlowMatchEulerDiscreteScheduler()
        return {"transformer": transformer, "scheduler": scheduler}