# Modified from https://github.com/Wan-Video/Wan2.1/blob/main/wan/modules/t5.py
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders.single_file_model import FromOriginalModelMixin
from diffusers.models.modeling_utils import ModelMixin


def fp16_clamp(x):
    # Clamp fp16 activations just inside the finite range so that inf values
    # do not propagate through the residual additions below.
    if x.dtype == torch.float16 and torch.isinf(x).any():
        clamp = torch.finfo(x.dtype).max - 1000
        x = torch.clamp(x, min=-clamp, max=clamp)
    return x
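
# Illustrative check (not part of the original file): values that overflow to
# inf in fp16 are pulled back just inside the representable range.
#   t = torch.tensor([65504.0 * 2], dtype=torch.float16)  # overflows to inf
#   fp16_clamp(t)  # tensor([64504.], dtype=torch.float16)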


def init_weights(m):
    if isinstance(m, T5LayerNorm):
        nn.init.ones_(m.weight)
    elif isinstance(m, T5FeedForward):
        nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5)
        nn.init.normal_(m.fc1.weight, std=m.dim**-0.5)
        nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5)
    elif isinstance(m, T5Attention):
        nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn)**-0.5)
        nn.init.normal_(m.k.weight, std=m.dim**-0.5)
        nn.init.normal_(m.v.weight, std=m.dim**-0.5)
        nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn)**-0.5)
    elif isinstance(m, T5RelativeEmbedding):
        nn.init.normal_(
            m.embedding.weight, std=(2 * m.num_buckets * m.num_heads)**-0.5)


class GELU(nn.Module):

    def forward(self, x):
        # tanh approximation of GELU
        return 0.5 * x * (1.0 + torch.tanh(
            math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
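
# Illustrative note (not part of the original file): this is the same curve as
# PyTorch's built-in tanh-approximate GELU.
#   x = torch.linspace(-3.0, 3.0, steps=7)
#   torch.allclose(GELU()(x), F.gelu(x, approximate="tanh"))  # True (up to float tolerance)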


class T5LayerNorm(nn.Module):

    def __init__(self, dim, eps=1e-6):
        super(T5LayerNorm, self).__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # RMSNorm-style normalization: no mean subtraction and no bias,
        # with the statistics computed in fp32 for numerical stability.
        x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) +
                            self.eps)
        if self.weight.dtype in [torch.float16, torch.bfloat16]:
            x = x.type_as(self.weight)
        return self.weight * x


class T5Attention(nn.Module):

    def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
        assert dim_attn % num_heads == 0
        super(T5Attention, self).__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.num_heads = num_heads
        self.head_dim = dim_attn // num_heads

        # layers
        self.q = nn.Linear(dim, dim_attn, bias=False)
        self.k = nn.Linear(dim, dim_attn, bias=False)
        self.v = nn.Linear(dim, dim_attn, bias=False)
        self.o = nn.Linear(dim_attn, dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, context=None, mask=None, pos_bias=None):
        """
        x:       [B, L1, C].
        context: [B, L2, C] or None.
        mask:    [B, L2] or [B, L1, L2] or None.
        """
        # check inputs (self-attention when no context is given)
        context = x if context is None else context
        b, n, c = x.size(0), self.num_heads, self.head_dim

        # compute query, key, value
        q = self.q(x).view(b, -1, n, c)
        k = self.k(context).view(b, -1, n, c)
        v = self.v(context).view(b, -1, n, c)

        # attention bias
        attn_bias = x.new_zeros(b, n, q.size(1), k.size(1))
        if pos_bias is not None:
            attn_bias += pos_bias
        if mask is not None:
            assert mask.ndim in [2, 3]
            mask = mask.view(b, 1, 1,
                             -1) if mask.ndim == 2 else mask.unsqueeze(1)
            attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min)

        # compute attention (T5 does not use scaling)
        attn = torch.einsum('binc,bjnc->bnij', q, k) + attn_bias
        attn = F.softmax(attn.float(), dim=-1).type_as(attn)
        x = torch.einsum('bnij,bjnc->binc', attn, v)

        # output
        x = x.reshape(b, -1, n * c)
        x = self.o(x)
        x = self.dropout(x)
        return x
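
# Minimal shape check (illustrative, not part of the original file):
#   attn = T5Attention(dim=16, dim_attn=16, num_heads=4)
#   attn(torch.randn(2, 8, 16)).shape                         # self-attention:  [2, 8, 16]
#   attn(torch.randn(2, 8, 16), torch.randn(2, 5, 16)).shape  # cross-attention: [2, 8, 16]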


class T5FeedForward(nn.Module):

    def __init__(self, dim, dim_ffn, dropout=0.1):
        super(T5FeedForward, self).__init__()
        self.dim = dim
        self.dim_ffn = dim_ffn

        # layers (gated-GELU feed-forward, as in T5 v1.1)
        self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU())
        self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
        self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.fc1(x) * self.gate(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x
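
# Shape check (illustrative, not part of the original file): the gated
# feed-forward maps [B, L, dim] -> [B, L, dim] through a dim_ffn bottleneck.
#   ffn = T5FeedForward(dim=16, dim_ffn=64)
#   ffn(torch.randn(2, 8, 16)).shape  # torch.Size([2, 8, 16])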


class T5SelfAttention(nn.Module):

    def __init__(self,
                 dim,
                 dim_attn,
                 dim_ffn,
                 num_heads,
                 num_buckets,
                 shared_pos=True,
                 dropout=0.1):
        super(T5SelfAttention, self).__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos

        # layers
        self.norm1 = T5LayerNorm(dim)
        self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
        self.norm2 = T5LayerNorm(dim)
        self.ffn = T5FeedForward(dim, dim_ffn, dropout)
        self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
            num_buckets, num_heads, bidirectional=True)

    def forward(self, x, mask=None, pos_bias=None):
        e = pos_bias if self.shared_pos else self.pos_embedding(
            x.size(1), x.size(1))
        x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
        x = fp16_clamp(x + self.ffn(self.norm2(x)))
        return x


class T5CrossAttention(nn.Module):

    def __init__(self,
                 dim,
                 dim_attn,
                 dim_ffn,
                 num_heads,
                 num_buckets,
                 shared_pos=True,
                 dropout=0.1):
        super(T5CrossAttention, self).__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos

        # layers
        self.norm1 = T5LayerNorm(dim)
        self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout)
        self.norm2 = T5LayerNorm(dim)
        self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout)
        self.norm3 = T5LayerNorm(dim)
        self.ffn = T5FeedForward(dim, dim_ffn, dropout)
        self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
            num_buckets, num_heads, bidirectional=False)

    def forward(self,
                x,
                mask=None,
                encoder_states=None,
                encoder_mask=None,
                pos_bias=None):
        e = pos_bias if self.shared_pos else self.pos_embedding(
            x.size(1), x.size(1))
        x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e))
        x = fp16_clamp(x + self.cross_attn(
            self.norm2(x), context=encoder_states, mask=encoder_mask))
        x = fp16_clamp(x + self.ffn(self.norm3(x)))
        return x


class T5RelativeEmbedding(nn.Module):

    def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
        super(T5RelativeEmbedding, self).__init__()
        self.num_buckets = num_buckets
        self.num_heads = num_heads
        self.bidirectional = bidirectional
        self.max_dist = max_dist

        # layers
        self.embedding = nn.Embedding(num_buckets, num_heads)

    def forward(self, lq, lk):
        device = self.embedding.weight.device
        # avoid creating the index tensors on the meta device (e.g. when the
        # model was instantiated under accelerate.init_empty_weights())
        if device.type != "meta":
            rel_pos = torch.arange(lk, device=device).unsqueeze(0) - \
                torch.arange(lq, device=device).unsqueeze(1)
        else:
            rel_pos = torch.arange(lk).unsqueeze(0) - \
                torch.arange(lq).unsqueeze(1)
        rel_pos = self._relative_position_bucket(rel_pos)
        rel_pos_embeds = self.embedding(rel_pos)
        rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(
            0)  # [1, N, Lq, Lk]
        return rel_pos_embeds.contiguous()

    def _relative_position_bucket(self, rel_pos):
        # preprocess
        if self.bidirectional:
            num_buckets = self.num_buckets // 2
            rel_buckets = (rel_pos > 0).long() * num_buckets
            rel_pos = torch.abs(rel_pos)
        else:
            num_buckets = self.num_buckets
            rel_buckets = 0
            rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos))

        # embeddings for small and large positions
        max_exact = num_buckets // 2
        rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) /
                                     math.log(self.max_dist / max_exact) *
                                     (num_buckets - max_exact)).long()
        rel_pos_large = torch.min(
            rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1))
        rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)
        return rel_buckets
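
# Worked example (illustrative, not part of the original file): with
# num_buckets=32 and bidirectional=True, each direction gets 16 buckets;
# distances 0..7 map to their own bucket and larger distances are binned
# logarithmically up to max_dist.
#   emb = T5RelativeEmbedding(num_buckets=32, num_heads=8, bidirectional=True)
#   emb(4, 4).shape  # torch.Size([1, 8, 4, 4])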


class WanT5EncoderModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):

    # register __init__ args with ConfigMixin so the model config round-trips
    @register_to_config
    def __init__(self,
                 vocab,
                 dim,
                 dim_attn,
                 dim_ffn,
                 num_heads,
                 num_layers,
                 num_buckets,
                 shared_pos=True,
                 dropout=0.1):
        super(WanT5EncoderModel, self).__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos

        # layers
        self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \
            else nn.Embedding(vocab, dim)
        self.pos_embedding = T5RelativeEmbedding(
            num_buckets, num_heads, bidirectional=True) if shared_pos else None
        self.dropout = nn.Dropout(dropout)
        self.blocks = nn.ModuleList([
            T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
                            shared_pos, dropout) for _ in range(num_layers)
        ])
        self.norm = T5LayerNorm(dim)

        # initialize weights
        self.apply(init_weights)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
    ):
        x = self.token_embedding(input_ids)
        x = self.dropout(x)
        e = self.pos_embedding(x.size(1),
                               x.size(1)) if self.shared_pos else None
        for block in self.blocks:
            x = block(x, attention_mask, pos_bias=e)
        x = self.norm(x)
        x = self.dropout(x)
        return (x, )
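
    # Usage sketch (illustrative; the config values below are assumptions for
    # a toy model, not the real checkpoint configuration):
    #   model = WanT5EncoderModel(vocab=1000, dim=64, dim_attn=64, dim_ffn=256,
    #                             num_heads=4, num_layers=2, num_buckets=32)
    #   ids = torch.randint(0, 1000, (1, 16))
    #   mask = torch.ones(1, 16, dtype=torch.long)
    #   hidden, = model(ids, mask)  # hidden: [1, 16, 64]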

    @classmethod
    def from_pretrained(cls, pretrained_model_path, additional_kwargs={}, low_cpu_mem_usage=False, torch_dtype=torch.bfloat16):

        def filter_kwargs(cls, kwargs):
            import inspect
            sig = inspect.signature(cls.__init__)
            valid_params = set(sig.parameters.keys()) - {'self', 'cls'}
            filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params}
            return filtered_kwargs

        if low_cpu_mem_usage:
            try:
                import re

                from diffusers.models.modeling_utils import \
                    load_model_dict_into_meta
                from diffusers.utils import is_accelerate_available
                if is_accelerate_available():
                    import accelerate

                # instantiate the model with empty (meta-device) weights
                with accelerate.init_empty_weights():
                    model = cls(**filter_kwargs(cls, additional_kwargs))

                param_device = "cpu"
                if pretrained_model_path.endswith(".safetensors"):
                    from safetensors.torch import load_file
                    state_dict = load_file(pretrained_model_path)
                else:
                    state_dict = torch.load(pretrained_model_path, map_location="cpu")

                # move the params from the meta device to cpu
                missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
                if len(missing_keys) > 0:
                    raise ValueError(
                        f"Cannot load {cls} from {pretrained_model_path} because the following keys are"
                        f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
                        " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
                        " those weights or else make sure your checkpoint file is correct."
                    )

                unexpected_keys = load_model_dict_into_meta(
                    model,
                    state_dict,
                    device=param_device,
                    dtype=torch_dtype,
                    model_name_or_path=pretrained_model_path,
                )

                if getattr(cls, "_keys_to_ignore_on_load_unexpected", None) is not None:
                    for pat in cls._keys_to_ignore_on_load_unexpected:
                        unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

                if len(unexpected_keys) > 0:
                    print(
                        f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {', '.join(unexpected_keys)}"
                    )
                return model
            except Exception as e:
                print(
                    f"The low_cpu_mem_usage path failed because of: {e}. Falling back to low_cpu_mem_usage=False."
                )

        # default (eager) loading path; also reached as a fallback when the
        # low-memory path above fails
        model = cls(**filter_kwargs(cls, additional_kwargs))
        if pretrained_model_path.endswith(".safetensors"):
            from safetensors.torch import load_file
            state_dict = load_file(pretrained_model_path)
        else:
            state_dict = torch.load(pretrained_model_path, map_location="cpu")
        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
        print(f"missing keys: {len(missing_keys)}; unexpected keys: {len(unexpected_keys)}")
        return model.to(torch_dtype)
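
# Loading sketch (illustrative; the checkpoint path and config dict below are
# placeholders, not files or values shipped with this code):
#   config = dict(vocab=256384, dim=4096, dim_attn=4096, dim_ffn=10240,
#                 num_heads=64, num_layers=24, num_buckets=32)
#   text_encoder = WanT5EncoderModel.from_pretrained(
#       "path/to/umt5-xxl-enc.safetensors",
#       additional_kwargs=config,
#       low_cpu_mem_usage=True,
#       torch_dtype=torch.bfloat16,
#   )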