# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the repsective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.

# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.

"""Mixture-of-Experts feed-forward block with a softmax top-k router.

``MoEBlock`` routes every token to ``moe_top_k`` of ``num_experts``
``FeedForward`` experts (chosen by ``MoEGate``), mixes their outputs with the
gate weights and adds the output of one shared expert that sees every token.
In training mode the gate's load-balancing loss is attached to the output via
``AddAuxiliaryLoss`` so it is back-propagated without being returned.
"""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.models.attention import FeedForward


class AddAuxiliaryLoss(torch.autograd.Function):
    """
    The trick function of adding auxiliary (aux) loss, which includes the
    gradient of the aux loss during backpropagation.

    ``forward`` returns ``x`` unchanged, so the activations are not altered.
    ``backward`` passes ``grad_output`` straight through to ``x`` and, when
    ``loss`` requires grad, gives ``loss`` a gradient of ones - i.e. the aux
    loss is back-propagated as if it had been added to the final objective
    with weight 1. Any scaling (e.g. ``alpha``) must already be in ``loss``.
    """

    @staticmethod
    def forward(ctx, x, loss):
        # The aux loss must be a single value; its incoming gradient is a constant 1.
        assert loss.numel() == 1
        ctx.dtype = loss.dtype
        ctx.required_aux_loss = loss.requires_grad
        return x

    @staticmethod
    def backward(ctx, grad_output):
        grad_loss = None
        if ctx.required_aux_loss:
            grad_loss = torch.ones(1, dtype=ctx.dtype, device=grad_output.device)
        return grad_output, grad_loss


class MoEGate(nn.Module):
    """Softmax top-k router that picks ``num_experts_per_tok`` experts per token.

    Args:
        embed_dim: Size of the last dimension of the incoming hidden states.
        num_experts: Number of routed experts to score.
        num_experts_per_tok: Number of experts (k) each token is sent to.
        aux_loss_alpha: Weight of the load-balancing auxiliary loss. The loss
            is only computed in training mode and only when this is > 0.
    """

    def __init__(
        self, embed_dim, num_experts=16, num_experts_per_tok=2, aux_loss_alpha=0.01
    ):
        super().__init__()
        self.top_k = num_experts_per_tok
        self.n_routed_experts = num_experts

        self.scoring_func = "softmax"
        self.alpha = aux_loss_alpha
        # False: balance statistics over the whole flattened batch;
        # True: compute them per sequence and average over the batch.
        self.seq_aux = False

        # topk selection algorithm
        # When True (and top_k > 1) the selected weights are rescaled to sum to 1.
        self.norm_topk_prob = False
        self.gating_dim = embed_dim
        # Router projection: logits = hidden_states @ weight.T, one row per expert.
        self.weight = nn.Parameter(
            torch.empty((self.n_routed_experts, self.gating_dim))
        )
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Initialize the router weight with Kaiming-uniform (``a=sqrt(5)``)."""
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, hidden_states: torch.Tensor):
        """Route every token to its top-k experts.

        Args:
            hidden_states: Tensor of shape ``(batch, seq_len, embed_dim)``.

        Returns:
            ``(topk_idx, topk_weight, aux_loss)``. ``topk_idx`` and
            ``topk_weight`` have shape ``(batch * seq_len, top_k)`` and hold
            the chosen expert ids and their softmax scores (not sorted by
            score). ``aux_loss`` is a scalar tensor in training mode with
            ``alpha > 0``, otherwise ``None``.

        Raises:
            NotImplementedError: if ``scoring_func`` is not ``"softmax"``.
        """
        bsz, seq_len, h = hidden_states.shape
        # compute gating score; reshape (unlike view) also accepts non-contiguous input
        hidden_states = hidden_states.reshape(-1, h)
        logits = F.linear(hidden_states, self.weight, None)
        if self.scoring_func == "softmax":
            scores = logits.softmax(dim=-1)
        else:
            raise NotImplementedError(
                f"insupportable scoring function for MoE gating: {self.scoring_func}"
            )

        # select top-k experts
        topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)

        # norm gate to sum 1 (epsilon guards against an all-zero row)
        if self.top_k > 1 and self.norm_topk_prob:
            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
            topk_weight = topk_weight / denominator

        # expert-level computation auxiliary loss
        if self.training and self.alpha > 0.0:
            aux_loss = self._aux_loss(scores, topk_idx, bsz, seq_len)
        else:
            aux_loss = None
        return topk_idx, topk_weight, aux_loss

    def _aux_loss(self, scores, topk_idx, bsz, seq_len):
        """Return ``alpha`` * load-balancing loss for one routing decision.

        ``scores`` is ``(bsz * seq_len, n_routed_experts)`` softmax output and
        ``topk_idx`` is ``(bsz * seq_len, top_k)`` chosen expert ids.
        """
        # always compute aux loss based on the naive greedy topk method
        topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
        if self.seq_aux:
            scores_for_seq_aux = scores.view(bsz, seq_len, -1)
            # ce[b, e]: number of top-k slots of sequence b routed to expert e,
            # divided so that a perfectly uniform assignment gives 1 per expert.
            ce = torch.zeros(bsz, self.n_routed_experts, device=scores.device)
            ce.scatter_add_(
                1,
                topk_idx_for_aux_loss,
                torch.ones(bsz, seq_len * self.top_k, device=scores.device),
            ).div_(seq_len * self.top_k / self.n_routed_experts)
            aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean()
            return aux_loss * self.alpha

        # fi: share of top-k slots per expert times n_experts (1 when uniform);
        # Pi: mean routing probability per expert over all tokens.
        mask_ce = F.one_hot(
            topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts
        )
        ce = mask_ce.float().mean(0)
        Pi = scores.mean(0)
        fi = ce * self.n_routed_experts
        return (Pi * fi).sum() * self.alpha


class MoEBlock(nn.Module):
    """Mixture-of-experts feed-forward layer with one always-on shared expert.

    Args:
        dim: Size of the last dimension of the input (and output).
        num_experts: Number of routed ``FeedForward`` experts.
        moe_top_k: Number of routed experts each token is sent to.
        activation_fn, dropout, final_dropout, ff_inner_dim, ff_bias:
            Forwarded to every ``FeedForward`` (routed and shared).
    """

    def __init__(
        self,
        dim,
        num_experts=8,
        moe_top_k=2,
        activation_fn="gelu",
        dropout=0.0,
        final_dropout=False,
        ff_inner_dim=None,
        ff_bias=True,
    ):
        super().__init__()
        self.moe_top_k = moe_top_k

        def make_ffn():
            # Every expert, routed or shared, uses the same FeedForward configuration.
            return FeedForward(
                dim,
                dropout=dropout,
                activation_fn=activation_fn,
                final_dropout=final_dropout,
                inner_dim=ff_inner_dim,
                bias=ff_bias,
            )

        # Registration order (experts, gate, shared_experts) is kept as-is so
        # parameter order and seeded initialization stay unchanged.
        self.experts = nn.ModuleList([make_ffn() for _ in range(num_experts)])
        self.gate = MoEGate(
            embed_dim=dim, num_experts=num_experts, num_experts_per_tok=moe_top_k
        )
        self.shared_experts = make_ffn()

    def initialize_weight(self):
        """No-op hook kept for interface compatibility; submodules init themselves."""
        pass

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the routed experts plus the shared expert.

        Args:
            hidden_states: Tensor of shape ``(batch, seq_len, dim)`` (the gate
                unpacks exactly three dimensions).

        Returns:
            Tensor with the same shape as ``hidden_states``.
        """
        identity = hidden_states
        orig_shape = hidden_states.shape
        topk_idx, topk_weight, aux_loss = self.gate(hidden_states)

        hidden_states = hidden_states.reshape(-1, hidden_states.shape[-1])
        flat_topk_idx = topk_idx.view(-1)
        if self.training:
            # One row per (token, slot): token t occupies rows t*k .. t*k+k-1,
            # matching the row-major flattening of topk_idx.
            hidden_states = hidden_states.repeat_interleave(self.moe_top_k, dim=0)
            y = torch.empty_like(hidden_states)
            for i, expert in enumerate(self.experts):
                mask = flat_topk_idx == i
                y[mask] = expert(hidden_states[mask]).to(hidden_states.dtype)
            y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
            y = y.view(*orig_shape)
            # The gate returns None when its alpha is <= 0; there is nothing to attach then.
            if aux_loss is not None:
                y = AddAuxiliaryLoss.apply(y, aux_loss)
        else:
            y = self.moe_infer(
                hidden_states, flat_topk_idx, topk_weight.view(-1, 1)
            ).view(*orig_shape)
        y = y + self.shared_experts(identity)
        return y

    @torch.no_grad()
    def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
        """Inference-time routing that runs each expert once on its tokens.

        Args:
            x: ``(num_tokens, dim)`` token features.
            flat_expert_indices: ``(num_tokens * moe_top_k,)`` expert id per
                (token, slot), token-major.
            flat_expert_weights: ``(num_tokens * moe_top_k, 1)`` gate weight
                for each (token, slot).

        Returns:
            ``(num_tokens, dim)`` weighted sum of each token's expert outputs.
        """
        expert_cache = torch.zeros_like(x)
        # Group (token, slot) positions by expert so each expert's work is contiguous.
        idxs = flat_expert_indices.argsort()
        # End offset (exclusive) of each expert's segment in ``idxs``.
        tokens_per_expert = flat_expert_indices.bincount().cumsum(0).tolist()
        # Flat (token, slot) position -> originating token row in ``x``.
        token_idxs = idxs // self.moe_top_k
        start_idx = 0
        for i, end_idx in enumerate(tokens_per_expert):
            if start_idx != end_idx:
                exp_token_idx = token_idxs[start_idx:end_idx]
                expert_out = self.experts[i](x[exp_token_idx])
                expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
                # for fp16 and other dtype
                expert_cache = expert_cache.to(expert_out.dtype)
                # A token routed to several experts accumulates all contributions.
                expert_cache.scatter_reduce_(
                    0,
                    exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]),
                    expert_out,
                    reduce="sum",
                )
            start_idx = end_idx
        return expert_cache