drbh committed
Commit 3bdb4b8 · Parent(s): 89e2950

feat: bump build for shared experts
Browse files

- build/torch26-cxx11-cu118-x86_64-linux/megablocks/{_megablocks_76c7de7.abi3.so → _megablocks_89e2950.abi3.so} +1 -1
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/_megablocks_9a1816c.abi3.so +0 -3
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/_ops.py +3 -3
- build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers.py +277 -1
- build/torch26-cxx11-cu124-x86_64-linux/megablocks/{_megablocks_76c7de7.abi3.so → _megablocks_89e2950.abi3.so} +1 -1
- build/torch26-cxx11-cu124-x86_64-linux/megablocks/_megablocks_9a1816c.abi3.so +0 -3
- build/torch26-cxx11-cu124-x86_64-linux/megablocks/_ops.py +3 -3
- build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers.py +277 -1
- build/torch26-cxx11-cu126-x86_64-linux/megablocks/{_megablocks_76c7de7.abi3.so → _megablocks_89e2950.abi3.so} +1 -1
- build/torch26-cxx11-cu126-x86_64-linux/megablocks/_megablocks_9a1816c.abi3.so +0 -3
- build/torch26-cxx11-cu126-x86_64-linux/megablocks/_ops.py +3 -3
- build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers.py +277 -1
- build/torch26-cxx98-cu118-x86_64-linux/megablocks/{_megablocks_76c7de7.abi3.so → _megablocks_89e2950.abi3.so} +1 -1
- build/torch26-cxx98-cu118-x86_64-linux/megablocks/_megablocks_9a1816c.abi3.so +0 -3
- build/torch26-cxx98-cu118-x86_64-linux/megablocks/_ops.py +3 -3
- build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers.py +277 -1
- build/torch26-cxx98-cu124-x86_64-linux/megablocks/{_megablocks_76c7de7.abi3.so → _megablocks_89e2950.abi3.so} +1 -1
- build/torch26-cxx98-cu124-x86_64-linux/megablocks/_megablocks_9a1816c.abi3.so +0 -3
- build/torch26-cxx98-cu124-x86_64-linux/megablocks/_ops.py +3 -3
- build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers.py +277 -1
- build/torch26-cxx98-cu126-x86_64-linux/megablocks/_megablocks_76c7de7.abi3.so +0 -3
- build/torch26-cxx98-cu126-x86_64-linux/megablocks/{_megablocks_9a1816c.abi3.so → _megablocks_89e2950.abi3.so} +1 -1
- build/torch26-cxx98-cu126-x86_64-linux/megablocks/_ops.py +3 -3
- build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers.py +277 -1
- build/torch27-cxx11-cu118-x86_64-linux/megablocks/{_megablocks_76c7de7.abi3.so → _megablocks_89e2950.abi3.so} +1 -1
- build/torch27-cxx11-cu118-x86_64-linux/megablocks/_megablocks_9a1816c.abi3.so +0 -3
- build/torch27-cxx11-cu118-x86_64-linux/megablocks/_ops.py +3 -3
- build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers.py +277 -1
- build/torch27-cxx11-cu126-x86_64-linux/megablocks/{_megablocks_76c7de7.abi3.so → _megablocks_89e2950.abi3.so} +1 -1
- build/torch27-cxx11-cu126-x86_64-linux/megablocks/_megablocks_9a1816c.abi3.so +0 -3
- build/torch27-cxx11-cu126-x86_64-linux/megablocks/_ops.py +3 -3
- build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers.py +277 -1
- build/torch27-cxx11-cu128-x86_64-linux/megablocks/{_megablocks_76c7de7.abi3.so → _megablocks_89e2950.abi3.so} +1 -1
- build/torch27-cxx11-cu128-x86_64-linux/megablocks/_megablocks_9a1816c.abi3.so +0 -3
- build/torch27-cxx11-cu128-x86_64-linux/megablocks/_ops.py +3 -3
- build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers.py +277 -1
build/torch26-cxx11-cu118-x86_64-linux/megablocks/{_megablocks_76c7de7.abi3.so → _megablocks_89e2950.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:070067fec0e735e865610caf4fc33b384fe8c9c47a002c365f740c82c5af1bab
 size 10517576

build/torch26-cxx11-cu118-x86_64-linux/megablocks/_megablocks_9a1816c.abi3.so
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:4e4c48e189572141f6a140dd83f9eca19eaebbc20c5cd686aa0263aafec14533
-size 10517576

build/torch26-cxx11-cu118-x86_64-linux/megablocks/_ops.py
CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import
-ops = torch.ops.
+from . import _megablocks_89e2950
+ops = torch.ops._megablocks_89e2950
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"
+    return f"_megablocks_89e2950::{op_name}"

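The _ops.py change only re-points the Python shim at the renamed extension, so op calls resolve under the namespace that matches the new build hash. A minimal sketch of how the helper behaves, assuming the build's megablocks package is importable; the op name "sort" is purely illustrative and not a claim about which kernels this extension registers:

    from megablocks import _ops  # assumed import path for the built package

    _ops.add_op_namespace_prefix("sort")  # -> "_megablocks_89e2950::sort"
    # _ops.ops is torch.ops._megablocks_89e2950, so any registered kernel
    # would be reached as _ops.ops.<op_name>(...)
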
build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers.py
CHANGED
@@ -152,6 +152,66 @@ def mlp_forward(
     return torch.bmm(x, w2) + w2_bias[..., None, :]
 
 
+# Shared expert MLP forward pass
+def shared_mlp_forward(
+    x: torch.Tensor,
+    up_proj_weight: torch.Tensor,
+    down_proj_weight: torch.Tensor,
+    up_proj_bias: Optional[torch.Tensor] = None,
+    down_proj_bias: Optional[torch.Tensor] = None,
+    activation_fn: Optional[Any] = None,
+    gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+    # Default activation function
+    if activation_fn is None:
+        activation_fn = torch.nn.functional.gelu
+
+    # Scale weights
+    up_proj_weight = scale_grad(up_proj_weight, gradient_scale)
+    down_proj_weight = scale_grad(down_proj_weight, gradient_scale)
+    if up_proj_bias is not None:
+        up_proj_bias = scale_grad(up_proj_bias, gradient_scale)
+    if down_proj_bias is not None:
+        down_proj_bias = scale_grad(down_proj_bias, gradient_scale)
+
+    # Resolve dtensors
+    up_proj_weight = _layers.mlp.resolve_dtensor(up_proj_weight)
+    down_proj_weight = _layers.mlp.resolve_dtensor(down_proj_weight)
+    if up_proj_bias is not None:
+        up_proj_bias = _layers.mlp.resolve_dtensor(up_proj_bias)
+    if down_proj_bias is not None:
+        down_proj_bias = _layers.mlp.resolve_dtensor(down_proj_bias)
+
+    # Up projection
+    x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias)
+
+    # Activation
+    x = activation_fn(x)
+
+    # Down projection
+    x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias)
+
+    return x
+
+
+# Combine outputs from shared expert and regular experts
+def combine_expert_shared_outputs(
+    shared_expert_out: torch.Tensor,
+    expert_out: torch.Tensor,
+    shared_expert_weighted_sum: bool = False,
+    moe_top_k: int = 1,
+) -> torch.Tensor:
+    if shared_expert_weighted_sum:
+        # Weighted sum based on number of experts used
+        total_experts = moe_top_k + 1
+        shared_weight = 1.0 / total_experts
+        expert_weight = moe_top_k / total_experts
+        return shared_expert_out * shared_weight + expert_out * expert_weight
+    else:
+        # Simple addition
+        return shared_expert_out + expert_out
+
+
 # Global variable to store load balancing loss
 _LOAD_BALANCING_LOSS = []
 
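A quick numeric check of the weighted-sum branch added above, as a minimal sketch (the tensors and values are placeholders, not from the commit): with moe_top_k=4, total_experts is 5, so the shared output is scaled by 1/5 and the routed output by 4/5.

    import torch
    from megablocks.layers import combine_expert_shared_outputs  # assumed import path

    shared = torch.ones(2, 3)           # placeholder shared-expert output
    routed = torch.full((2, 3), 2.0)    # placeholder routed-experts output
    out = combine_expert_shared_outputs(
        shared, routed, shared_expert_weighted_sum=True, moe_top_k=4
    )
    # every element equals 1.0 * (1/5) + 2.0 * (4/5) == 1.8
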
@@ -680,6 +740,136 @@ def moe_forward(
     return x, expert_weights, router_scores
 
 
+def moe_forward_with_shared_expert(
+    x: torch.Tensor,
+    router_weight: torch.Tensor,
+    moe_top_k: int,
+    moe_num_experts: int,
+    moe_jitter_eps: float = None,
+    moe_normalize_expert_weights: int = None,
+    uniform_expert_assignment: bool = False,
+    training: bool = False,
+    w1: torch.Tensor = None,
+    w2: torch.Tensor = None,
+    w1_bias: torch.Tensor = None,
+    w2_bias: torch.Tensor = None,
+    gradient_scale: Optional[float] = None,
+    alpha: float = 1.702,
+    sort_end_bit: int = 0,
+    expert_parallel_group: torch.distributed.ProcessGroup = None,
+    moe_capacity_factor: float = 1.0,
+    moe_expert_model_parallelism: bool = False,
+    forward_fn: Any = None,
+    hidden_size: int = None,
+    mlp_impl: str = "grouped",
+    # Shared expert parameters
+    shared_up_proj_weight: Optional[torch.Tensor] = None,
+    shared_down_proj_weight: Optional[torch.Tensor] = None,
+    shared_up_proj_bias: Optional[torch.Tensor] = None,
+    shared_down_proj_bias: Optional[torch.Tensor] = None,
+    shared_expert_weighted_sum: bool = False,
+    shared_activation_fn: Optional[Any] = None,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+    # First, compute regular MoE forward pass
+    expert_out, expert_weights, router_scores = moe_forward(
+        x=x,
+        router_weight=router_weight,
+        moe_top_k=moe_top_k,
+        moe_num_experts=moe_num_experts,
+        moe_jitter_eps=moe_jitter_eps,
+        moe_normalize_expert_weights=moe_normalize_expert_weights,
+        uniform_expert_assignment=uniform_expert_assignment,
+        training=training,
+        w1=w1,
+        w2=w2,
+        w1_bias=w1_bias,
+        w2_bias=w2_bias,
+        gradient_scale=gradient_scale,
+        alpha=alpha,
+        sort_end_bit=sort_end_bit,
+        expert_parallel_group=expert_parallel_group,
+        moe_capacity_factor=moe_capacity_factor,
+        moe_expert_model_parallelism=moe_expert_model_parallelism,
+        forward_fn=forward_fn,
+        hidden_size=hidden_size,
+        mlp_impl=mlp_impl,
+    )
+
+    # If shared expert weights provided, compute shared expert output
+    if shared_up_proj_weight is not None and shared_down_proj_weight is not None:
+        shared_expert_out = shared_mlp_forward(
+            x=x,
+            up_proj_weight=shared_up_proj_weight,
+            down_proj_weight=shared_down_proj_weight,
+            up_proj_bias=shared_up_proj_bias,
+            down_proj_bias=shared_down_proj_bias,
+            activation_fn=shared_activation_fn,
+            gradient_scale=gradient_scale,
+        )
+
+        # Combine expert outputs
+        combined_out = combine_expert_shared_outputs(
+            shared_expert_out=shared_expert_out,
+            expert_out=expert_out,
+            shared_expert_weighted_sum=shared_expert_weighted_sum,
+            moe_top_k=moe_top_k,
+        )
+
+        return combined_out, expert_weights, router_scores
+
+    # Return regular MoE output if no shared expert
+    return expert_out, expert_weights, router_scores
+
+
+def create_shared_expert_weights(
+    hidden_size: int,
+    shared_expert_hidden_size: int,
+    device: torch.device,
+    dtype: torch.dtype,
+    init_method: Any,
+    output_layer_init_method: Any = None,
+) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+
+    if output_layer_init_method is None:
+        output_layer_init_method = init_method
+
+    # Create weight tensors
+    up_proj_weight = torch.empty(
+        shared_expert_hidden_size,
+        hidden_size,
+        device=device,
+        dtype=dtype,
+    )
+    down_proj_weight = torch.empty(
+        hidden_size,
+        shared_expert_hidden_size,
+        device=device,
+        dtype=dtype,
+    )
+
+    # Initialize weights
+    init_method(up_proj_weight)
+    output_layer_init_method(down_proj_weight)
+
+    # No bias by default
+    return up_proj_weight, down_proj_weight, None, None
+
+# HACK: Extract device_mesh from pre-hook closure - required for transformers integration
+# This exists because device_mesh is trapped in hook closures with no model attribute
+# Fragile - breaks if hook structure changes or Python internals change
+# TODO: Replace with a more robust solution when available
+def get_device_mesh(model):
+    # Extract device_mesh from child's unused pre_hook closure
+    try:
+        # Find the pre-hook that contains 'device_mesh' in its closure
+        hook = next(h for h in model.experts._forward_pre_hooks.values() if 'device_mesh' in h.__code__.co_freevars)
+        # Extract the device_mesh from the closure
+        return hook.__closure__[hook.__code__.co_freevars.index('device_mesh')].cell_contents
+    except Exception:
+        return None
+
+
 class MegaBlocksMoeMLP(torch.nn.Module):
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -691,8 +881,12 @@ class MegaBlocksMoeMLP(torch.nn.Module):
         moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
         moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
         uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
-
+
         expert_parallel_group = getattr(self, "expert_parallel_group", None)
+        if expert_parallel_group is None:
+            device_mesh = get_device_mesh(self)
+            expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
         has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1
         forward_fn = parallel_forward_once if has_parallel else forward_once
 
@@ -722,4 +916,86 @@ class MegaBlocksMoeMLP(torch.nn.Module):
             hidden_size=self.experts.hidden_size,
             mlp_impl=mlp_impl,
         )
+        return output, expert_weights_out
+
+
+class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
+
+    def __init__(self):
+        super().__init__()
+        # Shared expert weights will be set by the user
+        self.shared_up_proj_weight = None
+        self.shared_down_proj_weight = None
+        self.shared_up_proj_bias = None
+        self.shared_down_proj_bias = None
+        self.shared_expert_weighted_sum = False
+        self.shared_activation_fn = None
+
+    def set_shared_expert_weights(
+        self,
+        up_proj_weight: torch.Tensor,
+        down_proj_weight: torch.Tensor,
+        up_proj_bias: Optional[torch.Tensor] = None,
+        down_proj_bias: Optional[torch.Tensor] = None,
+        weighted_sum: bool = False,
+        activation_fn: Optional[Any] = None,
+    ):
+        self.shared_up_proj_weight = up_proj_weight
+        self.shared_down_proj_weight = down_proj_weight
+        self.shared_up_proj_bias = up_proj_bias
+        self.shared_down_proj_bias = down_proj_bias
+        self.shared_expert_weighted_sum = weighted_sum
+        self.shared_activation_fn = activation_fn
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        moe_top_k = getattr(self.router, "top_k", 4)
+        moe_num_experts = getattr(self.experts, "num_experts", 128)
+        gradient_scale = getattr(self.experts, "gradient_scale", None)
+        alpha = getattr(self.experts, "alpha", 1.0)
+        moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+        moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+        moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
+        uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+        expert_parallel_group = getattr(self, "expert_parallel_group", None)
+        if expert_parallel_group is None:
+            device_mesh = get_device_mesh(self)
+            expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+        has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1
+        forward_fn = parallel_forward_once if has_parallel else forward_once
+
+        sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1)
+        mlp_impl = getattr(self, "mlp_impl", "grouped")
+
+        output, expert_weights_out, *_ = moe_forward_with_shared_expert(
+            x=x,
+            router_weight=self.router.weight,
+            moe_top_k=moe_top_k,
+            moe_num_experts=moe_num_experts,
+            moe_jitter_eps=moe_jitter_eps,
+            moe_normalize_expert_weights=moe_normalize_expert_weights,
+            uniform_expert_assignment=uniform_expert_assignment,
+            training=self.training,
+            w1=self.experts.gate_up_proj,
+            w2=self.experts.down_proj,
+            w1_bias=self.experts.gate_up_proj_bias,
+            w2_bias=self.experts.down_proj_bias,
+            gradient_scale=gradient_scale,
+            alpha=alpha,
+            sort_end_bit=sort_end_bit,
+            expert_parallel_group=expert_parallel_group,
+            moe_capacity_factor=moe_capacity_factor,
+            moe_expert_model_parallelism=has_parallel,
+            forward_fn=forward_fn,
+            hidden_size=self.experts.hidden_size,
+            mlp_impl=mlp_impl,
+            # Shared expert parameters
+            shared_up_proj_weight=self.shared_up_proj_weight,
+            shared_down_proj_weight=self.shared_down_proj_weight,
+            shared_up_proj_bias=self.shared_up_proj_bias,
+            shared_down_proj_bias=self.shared_down_proj_bias,
+            shared_expert_weighted_sum=self.shared_expert_weighted_sum,
+            shared_activation_fn=self.shared_activation_fn,
+        )
         return output, expert_weights_out
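Since the diff only adds the new class and helpers without showing how they are wired up, a minimal usage sketch follows. Everything outside the diff is an assumption: the import path, the sizes, the choice of init function, and the expectation that the router and experts submodules are attached by the surrounding integration (as with MegaBlocksMoeMLP) rather than created here. When weighted_sum is True, the forward pass blends the shared expert at 1/(moe_top_k+1) against moe_top_k/(moe_top_k+1) for the routed experts; otherwise the two outputs are simply added.

    import torch
    from megablocks.layers import (  # assumed import path for the built package
        MegaBlocksMoeMLPWithSharedExpert,
        create_shared_expert_weights,
    )

    hidden_size, shared_hidden_size = 1024, 2048  # hypothetical sizes

    mlp = MegaBlocksMoeMLPWithSharedExpert()
    # mlp.router and mlp.experts are expected to be attached by the host
    # model integration; they are not created in this sketch.

    up_w, down_w, up_b, down_b = create_shared_expert_weights(
        hidden_size=hidden_size,
        shared_expert_hidden_size=shared_hidden_size,
        device=torch.device("cpu"),
        dtype=torch.float32,
        init_method=torch.nn.init.xavier_uniform_,
    )
    mlp.set_shared_expert_weights(
        up_proj_weight=up_w,
        down_proj_weight=down_w,
        up_proj_bias=up_b,        # None by default from create_shared_expert_weights
        down_proj_bias=down_b,    # None as well
        weighted_sum=True,        # 1/(top_k+1) shared vs top_k/(top_k+1) routed
        activation_fn=torch.nn.functional.silu,  # optional; shared_mlp_forward defaults to gelu
    )
    # output, expert_weights = mlp(x)  # once router/experts are populated
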
build/torch26-cxx11-cu124-x86_64-linux/megablocks/{_megablocks_76c7de7.abi3.so → _megablocks_89e2950.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:02dffd561ef226c1ec17c99e462c3c771879f078dde9b1e5cd8bd5992be5b3da
 size 11869392

build/torch26-cxx11-cu124-x86_64-linux/megablocks/_megablocks_9a1816c.abi3.so
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:3d958a0c77589a5ede72336d1cab80ea9d6324ef6f8a9a187af2da4db74e1894
-size 11869392

build/torch26-cxx11-cu124-x86_64-linux/megablocks/_ops.py
CHANGED
(identical change to build/torch26-cxx11-cu118-x86_64-linux/megablocks/_ops.py shown above)

build/torch26-cxx11-cu124-x86_64-linux/megablocks/layers.py
CHANGED
(identical change to build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers.py shown above)

build/torch26-cxx11-cu126-x86_64-linux/megablocks/{_megablocks_76c7de7.abi3.so → _megablocks_89e2950.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:b5aa4e066ddbd863693ca8a5ec37fba34996226442dfa407e4a49b779497001d
 size 11931048

build/torch26-cxx11-cu126-x86_64-linux/megablocks/_megablocks_9a1816c.abi3.so
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:d41a4f5bbc160f51b058d3ba36e9087e9f15d35ae4782f36c984dd7199ee8ede
-size 11931048

build/torch26-cxx11-cu126-x86_64-linux/megablocks/_ops.py
CHANGED
(identical change to build/torch26-cxx11-cu118-x86_64-linux/megablocks/_ops.py shown above)

build/torch26-cxx11-cu126-x86_64-linux/megablocks/layers.py
CHANGED
(identical change to build/torch26-cxx11-cu118-x86_64-linux/megablocks/layers.py shown above)

build/torch26-cxx98-cu118-x86_64-linux/megablocks/{_megablocks_76c7de7.abi3.so → _megablocks_89e2950.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:fababa7e0d2c20c98afaebef6165a8145b33d80cdadba28f895c14dd2a7b2823
 size 10510040

build/torch26-cxx98-cu118-x86_64-linux/megablocks/_megablocks_9a1816c.abi3.so
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:01f0c774e900380d3c0721dfe15591c67be5d5eb5ad687af6c89a88ecdff4f2a
-size 10510040

build/torch26-cxx98-cu118-x86_64-linux/megablocks/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import
-ops = torch.ops.
+from . import _megablocks_89e2950
+ops = torch.ops._megablocks_89e2950

 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"
+    return f"_megablocks_89e2950::{op_name}"
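A minimal sketch of the namespacing that the updated _ops.py performs, assuming the compiled extension registers its ops under the new build hash; the op name "my_op" is hypothetical and used only to show the string shape:

# Namespacing logic as in _ops.py above; "my_op" is a made-up example name.
NAMESPACE = "_megablocks_89e2950"

def add_op_namespace_prefix(op_name: str) -> str:
    # Prefix op by namespace, e.g. when building schema strings for registration.
    return f"{NAMESPACE}::{op_name}"

assert add_op_namespace_prefix("my_op") == "_megablocks_89e2950::my_op"
# At runtime the compiled ops are reached through torch.ops._megablocks_89e2950.<op_name>,
# which is what the module-level `ops` handle in _ops.py points at.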
build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers.py CHANGED
@@ -152,6 +152,66 @@ def mlp_forward(
     return torch.bmm(x, w2) + w2_bias[..., None, :]


+# Shared expert MLP forward pass
+def shared_mlp_forward(
+    x: torch.Tensor,
+    up_proj_weight: torch.Tensor,
+    down_proj_weight: torch.Tensor,
+    up_proj_bias: Optional[torch.Tensor] = None,
+    down_proj_bias: Optional[torch.Tensor] = None,
+    activation_fn: Optional[Any] = None,
+    gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+    # Default activation function
+    if activation_fn is None:
+        activation_fn = torch.nn.functional.gelu
+
+    # Scale weights
+    up_proj_weight = scale_grad(up_proj_weight, gradient_scale)
+    down_proj_weight = scale_grad(down_proj_weight, gradient_scale)
+    if up_proj_bias is not None:
+        up_proj_bias = scale_grad(up_proj_bias, gradient_scale)
+    if down_proj_bias is not None:
+        down_proj_bias = scale_grad(down_proj_bias, gradient_scale)
+
+    # Resolve dtensors
+    up_proj_weight = _layers.mlp.resolve_dtensor(up_proj_weight)
+    down_proj_weight = _layers.mlp.resolve_dtensor(down_proj_weight)
+    if up_proj_bias is not None:
+        up_proj_bias = _layers.mlp.resolve_dtensor(up_proj_bias)
+    if down_proj_bias is not None:
+        down_proj_bias = _layers.mlp.resolve_dtensor(down_proj_bias)
+
+    # Up projection
+    x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias)
+
+    # Activation
+    x = activation_fn(x)
+
+    # Down projection
+    x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias)
+
+    return x
+
+
+# Combine outputs from shared expert and regular experts
+def combine_expert_shared_outputs(
+    shared_expert_out: torch.Tensor,
+    expert_out: torch.Tensor,
+    shared_expert_weighted_sum: bool = False,
+    moe_top_k: int = 1,
+) -> torch.Tensor:
+    if shared_expert_weighted_sum:
+        # Weighted sum based on number of experts used
+        total_experts = moe_top_k + 1
+        shared_weight = 1.0 / total_experts
+        expert_weight = moe_top_k / total_experts
+        return shared_expert_out * shared_weight + expert_out * expert_weight
+    else:
+        # Simple addition
+        return shared_expert_out + expert_out
+
+
 # Global variable to store load balancing loss
 _LOAD_BALANCING_LOSS = []

@@ -680,6 +740,136 @@ def moe_forward(
     return x, expert_weights, router_scores


+def moe_forward_with_shared_expert(
+    x: torch.Tensor,
+    router_weight: torch.Tensor,
+    moe_top_k: int,
+    moe_num_experts: int,
+    moe_jitter_eps: float = None,
+    moe_normalize_expert_weights: int = None,
+    uniform_expert_assignment: bool = False,
+    training: bool = False,
+    w1: torch.Tensor = None,
+    w2: torch.Tensor = None,
+    w1_bias: torch.Tensor = None,
+    w2_bias: torch.Tensor = None,
+    gradient_scale: Optional[float] = None,
+    alpha: float = 1.702,
+    sort_end_bit: int = 0,
+    expert_parallel_group: torch.distributed.ProcessGroup = None,
+    moe_capacity_factor: float = 1.0,
+    moe_expert_model_parallelism: bool = False,
+    forward_fn: Any = None,
+    hidden_size: int = None,
+    mlp_impl: str = "grouped",
+    # Shared expert parameters
+    shared_up_proj_weight: Optional[torch.Tensor] = None,
+    shared_down_proj_weight: Optional[torch.Tensor] = None,
+    shared_up_proj_bias: Optional[torch.Tensor] = None,
+    shared_down_proj_bias: Optional[torch.Tensor] = None,
+    shared_expert_weighted_sum: bool = False,
+    shared_activation_fn: Optional[Any] = None,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+    # First, compute regular MoE forward pass
+    expert_out, expert_weights, router_scores = moe_forward(
+        x=x,
+        router_weight=router_weight,
+        moe_top_k=moe_top_k,
+        moe_num_experts=moe_num_experts,
+        moe_jitter_eps=moe_jitter_eps,
+        moe_normalize_expert_weights=moe_normalize_expert_weights,
+        uniform_expert_assignment=uniform_expert_assignment,
+        training=training,
+        w1=w1,
+        w2=w2,
+        w1_bias=w1_bias,
+        w2_bias=w2_bias,
+        gradient_scale=gradient_scale,
+        alpha=alpha,
+        sort_end_bit=sort_end_bit,
+        expert_parallel_group=expert_parallel_group,
+        moe_capacity_factor=moe_capacity_factor,
+        moe_expert_model_parallelism=moe_expert_model_parallelism,
+        forward_fn=forward_fn,
+        hidden_size=hidden_size,
+        mlp_impl=mlp_impl,
+    )
+
+    # If shared expert weights provided, compute shared expert output
+    if shared_up_proj_weight is not None and shared_down_proj_weight is not None:
+        shared_expert_out = shared_mlp_forward(
+            x=x,
+            up_proj_weight=shared_up_proj_weight,
+            down_proj_weight=shared_down_proj_weight,
+            up_proj_bias=shared_up_proj_bias,
+            down_proj_bias=shared_down_proj_bias,
+            activation_fn=shared_activation_fn,
+            gradient_scale=gradient_scale,
+        )
+
+        # Combine expert outputs
+        combined_out = combine_expert_shared_outputs(
+            shared_expert_out=shared_expert_out,
+            expert_out=expert_out,
+            shared_expert_weighted_sum=shared_expert_weighted_sum,
+            moe_top_k=moe_top_k,
+        )
+
+        return combined_out, expert_weights, router_scores
+
+    # Return regular MoE output if no shared expert
+    return expert_out, expert_weights, router_scores
+
+
+def create_shared_expert_weights(
+    hidden_size: int,
+    shared_expert_hidden_size: int,
+    device: torch.device,
+    dtype: torch.dtype,
+    init_method: Any,
+    output_layer_init_method: Any = None,
+) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+
+    if output_layer_init_method is None:
+        output_layer_init_method = init_method
+
+    # Create weight tensors
+    up_proj_weight = torch.empty(
+        shared_expert_hidden_size,
+        hidden_size,
+        device=device,
+        dtype=dtype,
+    )
+    down_proj_weight = torch.empty(
+        hidden_size,
+        shared_expert_hidden_size,
+        device=device,
+        dtype=dtype,
+    )
+
+    # Initialize weights
+    init_method(up_proj_weight)
+    output_layer_init_method(down_proj_weight)
+
+    # No bias by default
+    return up_proj_weight, down_proj_weight, None, None
+
+# HACK: Extract device_mesh from pre-hook closure - required for transformers integration
+# This exists because device_mesh is trapped in hook closures with no model attribute
+# Fragile - breaks if hook structure changes or Python internals change
+# TODO: Replace with a more robust solution when available
+def get_device_mesh(model):
+    # Extract device_mesh from child's unused pre_hook closure
+    try:
+        # Find the pre-hook that contains 'device_mesh' in its closure
+        hook = next(h for h in model.experts._forward_pre_hooks.values() if 'device_mesh' in h.__code__.co_freevars)
+        # Extract the device_mesh from the closure
+        return hook.__closure__[hook.__code__.co_freevars.index('device_mesh')].cell_contents
+    except Exception:
+        return None
+
+
 class MegaBlocksMoeMLP(torch.nn.Module):

     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -691,8 +881,12 @@ class MegaBlocksMoeMLP(torch.nn.Module):
         moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
         moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
         uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
-
+
         expert_parallel_group = getattr(self, "expert_parallel_group", None)
+        if expert_parallel_group is None:
+            device_mesh = get_device_mesh(self)
+            expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
         has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1
         forward_fn = parallel_forward_once if has_parallel else forward_once

@@ -722,4 +916,86 @@ class MegaBlocksMoeMLP(torch.nn.Module):
             hidden_size=self.experts.hidden_size,
             mlp_impl=mlp_impl,
         )
+        return output, expert_weights_out
+
+
+class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
+
+    def __init__(self):
+        super().__init__()
+        # Shared expert weights will be set by the user
+        self.shared_up_proj_weight = None
+        self.shared_down_proj_weight = None
+        self.shared_up_proj_bias = None
+        self.shared_down_proj_bias = None
+        self.shared_expert_weighted_sum = False
+        self.shared_activation_fn = None
+
+    def set_shared_expert_weights(
+        self,
+        up_proj_weight: torch.Tensor,
+        down_proj_weight: torch.Tensor,
+        up_proj_bias: Optional[torch.Tensor] = None,
+        down_proj_bias: Optional[torch.Tensor] = None,
+        weighted_sum: bool = False,
+        activation_fn: Optional[Any] = None,
+    ):
+        self.shared_up_proj_weight = up_proj_weight
+        self.shared_down_proj_weight = down_proj_weight
+        self.shared_up_proj_bias = up_proj_bias
+        self.shared_down_proj_bias = down_proj_bias
+        self.shared_expert_weighted_sum = weighted_sum
+        self.shared_activation_fn = activation_fn
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        moe_top_k = getattr(self.router, "top_k", 4)
+        moe_num_experts = getattr(self.experts, "num_experts", 128)
+        gradient_scale = getattr(self.experts, "gradient_scale", None)
+        alpha = getattr(self.experts, "alpha", 1.0)
+        moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+        moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+        moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
+        uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+        expert_parallel_group = getattr(self, "expert_parallel_group", None)
+        if expert_parallel_group is None:
+            device_mesh = get_device_mesh(self)
+            expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+        has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1
+        forward_fn = parallel_forward_once if has_parallel else forward_once
+
+        sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1)
+        mlp_impl = getattr(self, "mlp_impl", "grouped")
+
+        output, expert_weights_out, *_ = moe_forward_with_shared_expert(
+            x=x,
+            router_weight=self.router.weight,
+            moe_top_k=moe_top_k,
+            moe_num_experts=moe_num_experts,
+            moe_jitter_eps=moe_jitter_eps,
+            moe_normalize_expert_weights=moe_normalize_expert_weights,
+            uniform_expert_assignment=uniform_expert_assignment,
+            training=self.training,
+            w1=self.experts.gate_up_proj,
+            w2=self.experts.down_proj,
+            w1_bias=self.experts.gate_up_proj_bias,
+            w2_bias=self.experts.down_proj_bias,
+            gradient_scale=gradient_scale,
+            alpha=alpha,
+            sort_end_bit=sort_end_bit,
+            expert_parallel_group=expert_parallel_group,
+            moe_capacity_factor=moe_capacity_factor,
+            moe_expert_model_parallelism=has_parallel,
+            forward_fn=forward_fn,
+            hidden_size=self.experts.hidden_size,
+            mlp_impl=mlp_impl,
+            # Shared expert parameters
+            shared_up_proj_weight=self.shared_up_proj_weight,
+            shared_down_proj_weight=self.shared_down_proj_weight,
+            shared_up_proj_bias=self.shared_up_proj_bias,
+            shared_down_proj_bias=self.shared_down_proj_bias,
+            shared_expert_weighted_sum=self.shared_expert_weighted_sum,
+            shared_activation_fn=self.shared_activation_fn,
+        )
         return output, expert_weights_out
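A minimal usage sketch for the class added by this diff; the sizes, dtype, device, and initializer below are hypothetical examples, and the router/experts sub-modules are assumed to be attached the same way they are for the base MegaBlocksMoeMLP:

import torch
from megablocks.layers import (  # package path as laid out in this repo's build dirs
    MegaBlocksMoeMLPWithSharedExpert,
    create_shared_expert_weights,
)

mlp = MegaBlocksMoeMLPWithSharedExpert()
# ... attach mlp.router and mlp.experts here exactly as for MegaBlocksMoeMLP ...

up_w, down_w, up_b, down_b = create_shared_expert_weights(
    hidden_size=1024,                            # example value
    shared_expert_hidden_size=4096,              # example value
    device=torch.device("cuda"),
    dtype=torch.bfloat16,
    init_method=torch.nn.init.xavier_uniform_,   # example initializer
)
mlp.set_shared_expert_weights(
    up_proj_weight=up_w,
    down_proj_weight=down_w,
    up_proj_bias=up_b,
    down_proj_bias=down_b,
    weighted_sum=True,                           # use the 1/(top_k+1) merge instead of plain addition
    activation_fn=torch.nn.functional.silu,      # shared_mlp_forward defaults to gelu when None
)
# output, expert_weights = mlp(x)                # x shaped (..., hidden_size), once router/experts exist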
build/torch26-cxx98-cu124-x86_64-linux/megablocks/{_megablocks_76c7de7.abi3.so → _megablocks_89e2950.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:9e3663f46030f07e030efe94c26495d17b2703551a46c0ca3acf8b25ecb2a238
 size 11857920
build/torch26-cxx98-cu124-x86_64-linux/megablocks/_megablocks_9a1816c.abi3.so DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:09a5f57ae37af9f5b14c4a0f21d1679e32f5b7424973c36dac9bbbecbfbf7374
-size 11857920
build/torch26-cxx98-cu124-x86_64-linux/megablocks/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import
-ops = torch.ops.
+from . import _megablocks_89e2950
+ops = torch.ops._megablocks_89e2950

 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"
+    return f"_megablocks_89e2950::{op_name}"
build/torch26-cxx98-cu124-x86_64-linux/megablocks/layers.py CHANGED
@@ -152,6 +152,66 @@ def mlp_forward(
@@ -680,6 +740,136 @@ def moe_forward(
@@ -691,8 +881,12 @@ class MegaBlocksMoeMLP(torch.nn.Module):
@@ -722,4 +916,86 @@ class MegaBlocksMoeMLP(torch.nn.Module):
(Hunk contents are identical to the layers.py diff for build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers.py shown above; every build variant carries the same copy of layers.py.)
build/torch26-cxx98-cu126-x86_64-linux/megablocks/_megablocks_76c7de7.abi3.so DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:a3f893773ec7b8157a4531a57821807f5f27ac48ceaa695c342cc7a39ad318dc
-size 11927768
build/torch26-cxx98-cu126-x86_64-linux/megablocks/{_megablocks_9a1816c.abi3.so → _megablocks_89e2950.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:d1571732c5954914d5ddf0f12ebc4074d88d907130d71d898de43958e3b9a5d1
 size 11923672
build/torch26-cxx98-cu126-x86_64-linux/megablocks/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import
-ops = torch.ops.
+from . import _megablocks_89e2950
+ops = torch.ops._megablocks_89e2950

 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"
+    return f"_megablocks_89e2950::{op_name}"
build/torch26-cxx98-cu126-x86_64-linux/megablocks/layers.py CHANGED
@@ -152,6 +152,66 @@ def mlp_forward(
@@ -680,6 +740,136 @@ def moe_forward(
@@ -691,8 +881,12 @@ class MegaBlocksMoeMLP(torch.nn.Module):
@@ -722,4 +916,86 @@ class MegaBlocksMoeMLP(torch.nn.Module):
(Hunk contents are identical to the layers.py diff for build/torch26-cxx98-cu118-x86_64-linux/megablocks/layers.py shown above; every build variant carries the same copy of layers.py.)
build/torch27-cxx11-cu118-x86_64-linux/megablocks/{_megablocks_76c7de7.abi3.so → _megablocks_89e2950.abi3.so} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:a39b315c5359b79a67282160b5b344853aa06b5a5c9d8efafb903eb4f249b645
 size 10517816
build/torch27-cxx11-cu118-x86_64-linux/megablocks/_megablocks_9a1816c.abi3.so DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:002c2687dbc5693308fe32eaebe2f45ed3c85454fd45bc06d7b30e9c1a6d8949
-size 10517816
build/torch27-cxx11-cu118-x86_64-linux/megablocks/_ops.py CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import
-ops = torch.ops.
+from . import _megablocks_89e2950
+ops = torch.ops._megablocks_89e2950

 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"
+    return f"_megablocks_89e2950::{op_name}"
build/torch27-cxx11-cu118-x86_64-linux/megablocks/layers.py
CHANGED
@@ -152,6 +152,66 @@ def mlp_forward(
     return torch.bmm(x, w2) + w2_bias[..., None, :]
 
 
+# Shared expert MLP forward pass
+def shared_mlp_forward(
+    x: torch.Tensor,
+    up_proj_weight: torch.Tensor,
+    down_proj_weight: torch.Tensor,
+    up_proj_bias: Optional[torch.Tensor] = None,
+    down_proj_bias: Optional[torch.Tensor] = None,
+    activation_fn: Optional[Any] = None,
+    gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+    # Default activation function
+    if activation_fn is None:
+        activation_fn = torch.nn.functional.gelu
+
+    # Scale weights
+    up_proj_weight = scale_grad(up_proj_weight, gradient_scale)
+    down_proj_weight = scale_grad(down_proj_weight, gradient_scale)
+    if up_proj_bias is not None:
+        up_proj_bias = scale_grad(up_proj_bias, gradient_scale)
+    if down_proj_bias is not None:
+        down_proj_bias = scale_grad(down_proj_bias, gradient_scale)
+
+    # Resolve dtensors
+    up_proj_weight = _layers.mlp.resolve_dtensor(up_proj_weight)
+    down_proj_weight = _layers.mlp.resolve_dtensor(down_proj_weight)
+    if up_proj_bias is not None:
+        up_proj_bias = _layers.mlp.resolve_dtensor(up_proj_bias)
+    if down_proj_bias is not None:
+        down_proj_bias = _layers.mlp.resolve_dtensor(down_proj_bias)
+
+    # Up projection
+    x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias)
+
+    # Activation
+    x = activation_fn(x)
+
+    # Down projection
+    x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias)
+
+    return x
+
+
+# Combine outputs from shared expert and regular experts
+def combine_expert_shared_outputs(
+    shared_expert_out: torch.Tensor,
+    expert_out: torch.Tensor,
+    shared_expert_weighted_sum: bool = False,
+    moe_top_k: int = 1,
+) -> torch.Tensor:
+    if shared_expert_weighted_sum:
+        # Weighted sum based on number of experts used
+        total_experts = moe_top_k + 1
+        shared_weight = 1.0 / total_experts
+        expert_weight = moe_top_k / total_experts
+        return shared_expert_out * shared_weight + expert_out * expert_weight
+    else:
+        # Simple addition
+        return shared_expert_out + expert_out
+
+
 # Global variable to store load balancing loss
 _LOAD_BALANCING_LOSS = []
 
@@ -680,6 +740,136 @@ def moe_forward(
     return x, expert_weights, router_scores
 
 
+def moe_forward_with_shared_expert(
+    x: torch.Tensor,
+    router_weight: torch.Tensor,
+    moe_top_k: int,
+    moe_num_experts: int,
+    moe_jitter_eps: float = None,
+    moe_normalize_expert_weights: int = None,
+    uniform_expert_assignment: bool = False,
+    training: bool = False,
+    w1: torch.Tensor = None,
+    w2: torch.Tensor = None,
+    w1_bias: torch.Tensor = None,
+    w2_bias: torch.Tensor = None,
+    gradient_scale: Optional[float] = None,
+    alpha: float = 1.702,
+    sort_end_bit: int = 0,
+    expert_parallel_group: torch.distributed.ProcessGroup = None,
+    moe_capacity_factor: float = 1.0,
+    moe_expert_model_parallelism: bool = False,
+    forward_fn: Any = None,
+    hidden_size: int = None,
+    mlp_impl: str = "grouped",
+    # Shared expert parameters
+    shared_up_proj_weight: Optional[torch.Tensor] = None,
+    shared_down_proj_weight: Optional[torch.Tensor] = None,
+    shared_up_proj_bias: Optional[torch.Tensor] = None,
+    shared_down_proj_bias: Optional[torch.Tensor] = None,
+    shared_expert_weighted_sum: bool = False,
+    shared_activation_fn: Optional[Any] = None,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+    # First, compute regular MoE forward pass
+    expert_out, expert_weights, router_scores = moe_forward(
+        x=x,
+        router_weight=router_weight,
+        moe_top_k=moe_top_k,
+        moe_num_experts=moe_num_experts,
+        moe_jitter_eps=moe_jitter_eps,
+        moe_normalize_expert_weights=moe_normalize_expert_weights,
+        uniform_expert_assignment=uniform_expert_assignment,
+        training=training,
+        w1=w1,
+        w2=w2,
+        w1_bias=w1_bias,
+        w2_bias=w2_bias,
+        gradient_scale=gradient_scale,
+        alpha=alpha,
+        sort_end_bit=sort_end_bit,
+        expert_parallel_group=expert_parallel_group,
+        moe_capacity_factor=moe_capacity_factor,
+        moe_expert_model_parallelism=moe_expert_model_parallelism,
+        forward_fn=forward_fn,
+        hidden_size=hidden_size,
+        mlp_impl=mlp_impl,
+    )
+
+    # If shared expert weights provided, compute shared expert output
+    if shared_up_proj_weight is not None and shared_down_proj_weight is not None:
+        shared_expert_out = shared_mlp_forward(
+            x=x,
+            up_proj_weight=shared_up_proj_weight,
+            down_proj_weight=shared_down_proj_weight,
+            up_proj_bias=shared_up_proj_bias,
+            down_proj_bias=shared_down_proj_bias,
+            activation_fn=shared_activation_fn,
+            gradient_scale=gradient_scale,
+        )
+
+        # Combine expert outputs
+        combined_out = combine_expert_shared_outputs(
+            shared_expert_out=shared_expert_out,
+            expert_out=expert_out,
+            shared_expert_weighted_sum=shared_expert_weighted_sum,
+            moe_top_k=moe_top_k,
+        )
+
+        return combined_out, expert_weights, router_scores
+
+    # Return regular MoE output if no shared expert
+    return expert_out, expert_weights, router_scores
+
+
+def create_shared_expert_weights(
+    hidden_size: int,
+    shared_expert_hidden_size: int,
+    device: torch.device,
+    dtype: torch.dtype,
+    init_method: Any,
+    output_layer_init_method: Any = None,
+) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+
+    if output_layer_init_method is None:
+        output_layer_init_method = init_method
+
+    # Create weight tensors
+    up_proj_weight = torch.empty(
+        shared_expert_hidden_size,
+        hidden_size,
+        device=device,
+        dtype=dtype,
+    )
+    down_proj_weight = torch.empty(
+        hidden_size,
+        shared_expert_hidden_size,
+        device=device,
+        dtype=dtype,
+    )
+
+    # Initialize weights
+    init_method(up_proj_weight)
+    output_layer_init_method(down_proj_weight)
+
+    # No bias by default
+    return up_proj_weight, down_proj_weight, None, None
+
+
+# HACK: Extract device_mesh from pre-hook closure - required for transformers integration
+# This exists because device_mesh is trapped in hook closures with no model attribute
+# Fragile - breaks if hook structure changes or Python internals change
+# TODO: Replace with a more robust solution when available
+def get_device_mesh(model):
+    # Extract device_mesh from child's unused pre_hook closure
+    try:
+        # Find the pre-hook that contains 'device_mesh' in its closure
+        hook = next(h for h in model.experts._forward_pre_hooks.values() if 'device_mesh' in h.__code__.co_freevars)
+        # Extract the device_mesh from the closure
+        return hook.__closure__[hook.__code__.co_freevars.index('device_mesh')].cell_contents
+    except Exception:
+        return None
+
+
 class MegaBlocksMoeMLP(torch.nn.Module):
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -691,8 +881,12 @@ class MegaBlocksMoeMLP(torch.nn.Module):
         moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
         moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
         uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
-
+
         expert_parallel_group = getattr(self, "expert_parallel_group", None)
+        if expert_parallel_group is None:
+            device_mesh = get_device_mesh(self)
+            expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
         has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1
         forward_fn = parallel_forward_once if has_parallel else forward_once
 
@@ -722,4 +916,86 @@ class MegaBlocksMoeMLP(torch.nn.Module):
             hidden_size=self.experts.hidden_size,
             mlp_impl=mlp_impl,
         )
+        return output, expert_weights_out
+
+
+class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
+
+    def __init__(self):
+        super().__init__()
+        # Shared expert weights will be set by the user
+        self.shared_up_proj_weight = None
+        self.shared_down_proj_weight = None
+        self.shared_up_proj_bias = None
+        self.shared_down_proj_bias = None
+        self.shared_expert_weighted_sum = False
+        self.shared_activation_fn = None
+
+    def set_shared_expert_weights(
+        self,
+        up_proj_weight: torch.Tensor,
+        down_proj_weight: torch.Tensor,
+        up_proj_bias: Optional[torch.Tensor] = None,
+        down_proj_bias: Optional[torch.Tensor] = None,
+        weighted_sum: bool = False,
+        activation_fn: Optional[Any] = None,
+    ):
+        self.shared_up_proj_weight = up_proj_weight
+        self.shared_down_proj_weight = down_proj_weight
+        self.shared_up_proj_bias = up_proj_bias
+        self.shared_down_proj_bias = down_proj_bias
+        self.shared_expert_weighted_sum = weighted_sum
+        self.shared_activation_fn = activation_fn
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        moe_top_k = getattr(self.router, "top_k", 4)
+        moe_num_experts = getattr(self.experts, "num_experts", 128)
+        gradient_scale = getattr(self.experts, "gradient_scale", None)
+        alpha = getattr(self.experts, "alpha", 1.0)
+        moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+        moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+        moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
+        uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+        expert_parallel_group = getattr(self, "expert_parallel_group", None)
+        if expert_parallel_group is None:
+            device_mesh = get_device_mesh(self)
+            expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+        has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1
+        forward_fn = parallel_forward_once if has_parallel else forward_once
+
+        sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1)
+        mlp_impl = getattr(self, "mlp_impl", "grouped")
+
+        output, expert_weights_out, *_ = moe_forward_with_shared_expert(
+            x=x,
+            router_weight=self.router.weight,
+            moe_top_k=moe_top_k,
+            moe_num_experts=moe_num_experts,
+            moe_jitter_eps=moe_jitter_eps,
+            moe_normalize_expert_weights=moe_normalize_expert_weights,
+            uniform_expert_assignment=uniform_expert_assignment,
+            training=self.training,
+            w1=self.experts.gate_up_proj,
+            w2=self.experts.down_proj,
+            w1_bias=self.experts.gate_up_proj_bias,
+            w2_bias=self.experts.down_proj_bias,
+            gradient_scale=gradient_scale,
+            alpha=alpha,
+            sort_end_bit=sort_end_bit,
+            expert_parallel_group=expert_parallel_group,
+            moe_capacity_factor=moe_capacity_factor,
+            moe_expert_model_parallelism=has_parallel,
+            forward_fn=forward_fn,
+            hidden_size=self.experts.hidden_size,
+            mlp_impl=mlp_impl,
+            # Shared expert parameters
+            shared_up_proj_weight=self.shared_up_proj_weight,
+            shared_down_proj_weight=self.shared_down_proj_weight,
+            shared_up_proj_bias=self.shared_up_proj_bias,
+            shared_down_proj_bias=self.shared_down_proj_bias,
+            shared_expert_weighted_sum=self.shared_expert_weighted_sum,
+            shared_activation_fn=self.shared_activation_fn,
+        )
         return output, expert_weights_out
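The layers.py change above boils down to a dense shared-expert MLP applied to every token plus a merge with the routed-expert output. The sketch below re-implements just that data path with plain torch calls so the shapes are easy to follow; it deliberately omits the scale_grad / resolve_dtensor plumbing from the diff, and all sizes are illustrative assumptions.

import torch
import torch.nn.functional as F

def shared_expert_path(x, up_w, down_w, routed_out, top_k=4, weighted_sum=False):
    # Dense shared expert: up-projection, GELU, down-projection (cf. shared_mlp_forward).
    h = F.gelu(F.linear(x, up_w))
    shared_out = F.linear(h, down_w)
    # Merge with the routed experts' output (cf. combine_expert_shared_outputs).
    if weighted_sum:
        total = top_k + 1
        return shared_out * (1.0 / total) + routed_out * (top_k / total)
    return shared_out + routed_out

# Illustrative sizes: hidden=16, shared-expert hidden=32, four tokens.
x = torch.randn(4, 16)
up_w = torch.randn(32, 16)       # (shared_expert_hidden_size, hidden_size)
down_w = torch.randn(16, 32)     # (hidden_size, shared_expert_hidden_size)
routed_out = torch.randn(4, 16)  # stand-in for the MoE experts' combined output
print(shared_expert_path(x, up_w, down_w, routed_out).shape)  # torch.Size([4, 16])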
build/torch27-cxx11-cu126-x86_64-linux/megablocks/{_megablocks_76c7de7.abi3.so → _megablocks_89e2950.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:4870e4a9a831c30c7177b9b23b2b20d64f47242f16d818be1884b4e130e063c1
 size 11931080
build/torch27-cxx11-cu126-x86_64-linux/megablocks/_megablocks_9a1816c.abi3.so
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:ef9197ea269734d4e0528887ab3c353fa8ba10ccf9a82c9abe85b72bc0ea3553
-size 11931080
build/torch27-cxx11-cu126-x86_64-linux/megablocks/_ops.py
CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import 
-ops = torch.ops.
+from . import _megablocks_89e2950
+ops = torch.ops._megablocks_89e2950
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"
+    return f"_megablocks_89e2950::{op_name}"
build/torch27-cxx11-cu126-x86_64-linux/megablocks/layers.py
CHANGED
@@ -152,6 +152,66 @@ def mlp_forward(
     return torch.bmm(x, w2) + w2_bias[..., None, :]
 
 
+# Shared expert MLP forward pass
+def shared_mlp_forward(
+    x: torch.Tensor,
+    up_proj_weight: torch.Tensor,
+    down_proj_weight: torch.Tensor,
+    up_proj_bias: Optional[torch.Tensor] = None,
+    down_proj_bias: Optional[torch.Tensor] = None,
+    activation_fn: Optional[Any] = None,
+    gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+    # Default activation function
+    if activation_fn is None:
+        activation_fn = torch.nn.functional.gelu
+
+    # Scale weights
+    up_proj_weight = scale_grad(up_proj_weight, gradient_scale)
+    down_proj_weight = scale_grad(down_proj_weight, gradient_scale)
+    if up_proj_bias is not None:
+        up_proj_bias = scale_grad(up_proj_bias, gradient_scale)
+    if down_proj_bias is not None:
+        down_proj_bias = scale_grad(down_proj_bias, gradient_scale)
+
+    # Resolve dtensors
+    up_proj_weight = _layers.mlp.resolve_dtensor(up_proj_weight)
+    down_proj_weight = _layers.mlp.resolve_dtensor(down_proj_weight)
+    if up_proj_bias is not None:
+        up_proj_bias = _layers.mlp.resolve_dtensor(up_proj_bias)
+    if down_proj_bias is not None:
+        down_proj_bias = _layers.mlp.resolve_dtensor(down_proj_bias)
+
+    # Up projection
+    x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias)
+
+    # Activation
+    x = activation_fn(x)
+
+    # Down projection
+    x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias)
+
+    return x
+
+
+# Combine outputs from shared expert and regular experts
+def combine_expert_shared_outputs(
+    shared_expert_out: torch.Tensor,
+    expert_out: torch.Tensor,
+    shared_expert_weighted_sum: bool = False,
+    moe_top_k: int = 1,
+) -> torch.Tensor:
+    if shared_expert_weighted_sum:
+        # Weighted sum based on number of experts used
+        total_experts = moe_top_k + 1
+        shared_weight = 1.0 / total_experts
+        expert_weight = moe_top_k / total_experts
+        return shared_expert_out * shared_weight + expert_out * expert_weight
+    else:
+        # Simple addition
+        return shared_expert_out + expert_out
+
+
 # Global variable to store load balancing loss
 _LOAD_BALANCING_LOSS = []
 
@@ -680,6 +740,136 @@ def moe_forward(
     return x, expert_weights, router_scores
 
 
+def moe_forward_with_shared_expert(
+    x: torch.Tensor,
+    router_weight: torch.Tensor,
+    moe_top_k: int,
+    moe_num_experts: int,
+    moe_jitter_eps: float = None,
+    moe_normalize_expert_weights: int = None,
+    uniform_expert_assignment: bool = False,
+    training: bool = False,
+    w1: torch.Tensor = None,
+    w2: torch.Tensor = None,
+    w1_bias: torch.Tensor = None,
+    w2_bias: torch.Tensor = None,
+    gradient_scale: Optional[float] = None,
+    alpha: float = 1.702,
+    sort_end_bit: int = 0,
+    expert_parallel_group: torch.distributed.ProcessGroup = None,
+    moe_capacity_factor: float = 1.0,
+    moe_expert_model_parallelism: bool = False,
+    forward_fn: Any = None,
+    hidden_size: int = None,
+    mlp_impl: str = "grouped",
+    # Shared expert parameters
+    shared_up_proj_weight: Optional[torch.Tensor] = None,
+    shared_down_proj_weight: Optional[torch.Tensor] = None,
+    shared_up_proj_bias: Optional[torch.Tensor] = None,
+    shared_down_proj_bias: Optional[torch.Tensor] = None,
+    shared_expert_weighted_sum: bool = False,
+    shared_activation_fn: Optional[Any] = None,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+    # First, compute regular MoE forward pass
+    expert_out, expert_weights, router_scores = moe_forward(
+        x=x,
+        router_weight=router_weight,
+        moe_top_k=moe_top_k,
+        moe_num_experts=moe_num_experts,
+        moe_jitter_eps=moe_jitter_eps,
+        moe_normalize_expert_weights=moe_normalize_expert_weights,
+        uniform_expert_assignment=uniform_expert_assignment,
+        training=training,
+        w1=w1,
+        w2=w2,
+        w1_bias=w1_bias,
+        w2_bias=w2_bias,
+        gradient_scale=gradient_scale,
+        alpha=alpha,
+        sort_end_bit=sort_end_bit,
+        expert_parallel_group=expert_parallel_group,
+        moe_capacity_factor=moe_capacity_factor,
+        moe_expert_model_parallelism=moe_expert_model_parallelism,
+        forward_fn=forward_fn,
+        hidden_size=hidden_size,
+        mlp_impl=mlp_impl,
+    )
+
+    # If shared expert weights provided, compute shared expert output
+    if shared_up_proj_weight is not None and shared_down_proj_weight is not None:
+        shared_expert_out = shared_mlp_forward(
+            x=x,
+            up_proj_weight=shared_up_proj_weight,
+            down_proj_weight=shared_down_proj_weight,
+            up_proj_bias=shared_up_proj_bias,
+            down_proj_bias=shared_down_proj_bias,
+            activation_fn=shared_activation_fn,
+            gradient_scale=gradient_scale,
+        )
+
+        # Combine expert outputs
+        combined_out = combine_expert_shared_outputs(
+            shared_expert_out=shared_expert_out,
+            expert_out=expert_out,
+            shared_expert_weighted_sum=shared_expert_weighted_sum,
+            moe_top_k=moe_top_k,
+        )
+
+        return combined_out, expert_weights, router_scores
+
+    # Return regular MoE output if no shared expert
+    return expert_out, expert_weights, router_scores
+
+
+def create_shared_expert_weights(
+    hidden_size: int,
+    shared_expert_hidden_size: int,
+    device: torch.device,
+    dtype: torch.dtype,
+    init_method: Any,
+    output_layer_init_method: Any = None,
+) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+
+    if output_layer_init_method is None:
+        output_layer_init_method = init_method
+
+    # Create weight tensors
+    up_proj_weight = torch.empty(
+        shared_expert_hidden_size,
+        hidden_size,
+        device=device,
+        dtype=dtype,
+    )
+    down_proj_weight = torch.empty(
+        hidden_size,
+        shared_expert_hidden_size,
+        device=device,
+        dtype=dtype,
+    )
+
+    # Initialize weights
+    init_method(up_proj_weight)
+    output_layer_init_method(down_proj_weight)
+
+    # No bias by default
+    return up_proj_weight, down_proj_weight, None, None
+
+
+# HACK: Extract device_mesh from pre-hook closure - required for transformers integration
+# This exists because device_mesh is trapped in hook closures with no model attribute
+# Fragile - breaks if hook structure changes or Python internals change
+# TODO: Replace with a more robust solution when available
+def get_device_mesh(model):
+    # Extract device_mesh from child's unused pre_hook closure
+    try:
+        # Find the pre-hook that contains 'device_mesh' in its closure
+        hook = next(h for h in model.experts._forward_pre_hooks.values() if 'device_mesh' in h.__code__.co_freevars)
+        # Extract the device_mesh from the closure
+        return hook.__closure__[hook.__code__.co_freevars.index('device_mesh')].cell_contents
+    except Exception:
+        return None
+
+
 class MegaBlocksMoeMLP(torch.nn.Module):
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -691,8 +881,12 @@ class MegaBlocksMoeMLP(torch.nn.Module):
         moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
         moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
        uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
-
+
         expert_parallel_group = getattr(self, "expert_parallel_group", None)
+        if expert_parallel_group is None:
+            device_mesh = get_device_mesh(self)
+            expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
         has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1
         forward_fn = parallel_forward_once if has_parallel else forward_once
 
@@ -722,4 +916,86 @@ class MegaBlocksMoeMLP(torch.nn.Module):
             hidden_size=self.experts.hidden_size,
             mlp_impl=mlp_impl,
         )
+        return output, expert_weights_out
+
+
+class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
+
+    def __init__(self):
+        super().__init__()
+        # Shared expert weights will be set by the user
+        self.shared_up_proj_weight = None
+        self.shared_down_proj_weight = None
+        self.shared_up_proj_bias = None
+        self.shared_down_proj_bias = None
+        self.shared_expert_weighted_sum = False
+        self.shared_activation_fn = None
+
+    def set_shared_expert_weights(
+        self,
+        up_proj_weight: torch.Tensor,
+        down_proj_weight: torch.Tensor,
+        up_proj_bias: Optional[torch.Tensor] = None,
+        down_proj_bias: Optional[torch.Tensor] = None,
+        weighted_sum: bool = False,
+        activation_fn: Optional[Any] = None,
+    ):
+        self.shared_up_proj_weight = up_proj_weight
+        self.shared_down_proj_weight = down_proj_weight
+        self.shared_up_proj_bias = up_proj_bias
+        self.shared_down_proj_bias = down_proj_bias
+        self.shared_expert_weighted_sum = weighted_sum
+        self.shared_activation_fn = activation_fn
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        moe_top_k = getattr(self.router, "top_k", 4)
+        moe_num_experts = getattr(self.experts, "num_experts", 128)
+        gradient_scale = getattr(self.experts, "gradient_scale", None)
+        alpha = getattr(self.experts, "alpha", 1.0)
+        moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+        moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+        moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
+        uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+        expert_parallel_group = getattr(self, "expert_parallel_group", None)
+        if expert_parallel_group is None:
+            device_mesh = get_device_mesh(self)
+            expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+        has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1
+        forward_fn = parallel_forward_once if has_parallel else forward_once
+
+        sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1)
+        mlp_impl = getattr(self, "mlp_impl", "grouped")
+
+        output, expert_weights_out, *_ = moe_forward_with_shared_expert(
+            x=x,
+            router_weight=self.router.weight,
+            moe_top_k=moe_top_k,
+            moe_num_experts=moe_num_experts,
+            moe_jitter_eps=moe_jitter_eps,
+            moe_normalize_expert_weights=moe_normalize_expert_weights,
+            uniform_expert_assignment=uniform_expert_assignment,
+            training=self.training,
+            w1=self.experts.gate_up_proj,
+            w2=self.experts.down_proj,
+            w1_bias=self.experts.gate_up_proj_bias,
+            w2_bias=self.experts.down_proj_bias,
+            gradient_scale=gradient_scale,
+            alpha=alpha,
+            sort_end_bit=sort_end_bit,
+            expert_parallel_group=expert_parallel_group,
+            moe_capacity_factor=moe_capacity_factor,
+            moe_expert_model_parallelism=has_parallel,
+            forward_fn=forward_fn,
+            hidden_size=self.experts.hidden_size,
+            mlp_impl=mlp_impl,
+            # Shared expert parameters
+            shared_up_proj_weight=self.shared_up_proj_weight,
+            shared_down_proj_weight=self.shared_down_proj_weight,
+            shared_up_proj_bias=self.shared_up_proj_bias,
+            shared_down_proj_bias=self.shared_down_proj_bias,
+            shared_expert_weighted_sum=self.shared_expert_weighted_sum,
+            shared_activation_fn=self.shared_activation_fn,
+        )
         return output, expert_weights_out
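One detail worth noting in combine_expert_shared_outputs (repeated above for this build variant): with shared_expert_weighted_sum=True the shared expert gets weight 1/(moe_top_k + 1) and the routed experts get moe_top_k/(moe_top_k + 1), so the two coefficients always sum to 1. A quick check of that bookkeeping, with an arbitrarily chosen top-k:

moe_top_k = 4
total_experts = moe_top_k + 1
shared_weight = 1.0 / total_experts        # 0.2
expert_weight = moe_top_k / total_experts  # 0.8
assert abs(shared_weight + expert_weight - 1.0) < 1e-9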
build/torch27-cxx11-cu128-x86_64-linux/megablocks/{_megablocks_76c7de7.abi3.so → _megablocks_89e2950.abi3.so}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:37844f7b2972aae75a1eeb8cda3b573a93ef27dd5a73b2cfb95fca1f41da07d9
 size 17892624
build/torch27-cxx11-cu128-x86_64-linux/megablocks/_megablocks_9a1816c.abi3.so
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:b071dec56af72c9e6b8408106b97fb42355b08e94cc1200bb6f4d3f42ba0e97e
-size 17892624
build/torch27-cxx11-cu128-x86_64-linux/megablocks/_ops.py
CHANGED
@@ -1,9 +1,9 @@
 import torch
-from . import 
-ops = torch.ops.
+from . import _megablocks_89e2950
+ops = torch.ops._megablocks_89e2950
 
 def add_op_namespace_prefix(op_name: str):
     """
     Prefix op by namespace.
     """
-    return f"
+    return f"_megablocks_89e2950::{op_name}"
build/torch27-cxx11-cu128-x86_64-linux/megablocks/layers.py
CHANGED
@@ -152,6 +152,66 @@ def mlp_forward(
     return torch.bmm(x, w2) + w2_bias[..., None, :]
 
 
+# Shared expert MLP forward pass
+def shared_mlp_forward(
+    x: torch.Tensor,
+    up_proj_weight: torch.Tensor,
+    down_proj_weight: torch.Tensor,
+    up_proj_bias: Optional[torch.Tensor] = None,
+    down_proj_bias: Optional[torch.Tensor] = None,
+    activation_fn: Optional[Any] = None,
+    gradient_scale: Optional[float] = None,
+) -> torch.Tensor:
+    # Default activation function
+    if activation_fn is None:
+        activation_fn = torch.nn.functional.gelu
+
+    # Scale weights
+    up_proj_weight = scale_grad(up_proj_weight, gradient_scale)
+    down_proj_weight = scale_grad(down_proj_weight, gradient_scale)
+    if up_proj_bias is not None:
+        up_proj_bias = scale_grad(up_proj_bias, gradient_scale)
+    if down_proj_bias is not None:
+        down_proj_bias = scale_grad(down_proj_bias, gradient_scale)
+
+    # Resolve dtensors
+    up_proj_weight = _layers.mlp.resolve_dtensor(up_proj_weight)
+    down_proj_weight = _layers.mlp.resolve_dtensor(down_proj_weight)
+    if up_proj_bias is not None:
+        up_proj_bias = _layers.mlp.resolve_dtensor(up_proj_bias)
+    if down_proj_bias is not None:
+        down_proj_bias = _layers.mlp.resolve_dtensor(down_proj_bias)
+
+    # Up projection
+    x = torch.nn.functional.linear(x, up_proj_weight, up_proj_bias)
+
+    # Activation
+    x = activation_fn(x)
+
+    # Down projection
+    x = torch.nn.functional.linear(x, down_proj_weight, down_proj_bias)
+
+    return x
+
+
+# Combine outputs from shared expert and regular experts
+def combine_expert_shared_outputs(
+    shared_expert_out: torch.Tensor,
+    expert_out: torch.Tensor,
+    shared_expert_weighted_sum: bool = False,
+    moe_top_k: int = 1,
+) -> torch.Tensor:
+    if shared_expert_weighted_sum:
+        # Weighted sum based on number of experts used
+        total_experts = moe_top_k + 1
+        shared_weight = 1.0 / total_experts
+        expert_weight = moe_top_k / total_experts
+        return shared_expert_out * shared_weight + expert_out * expert_weight
+    else:
+        # Simple addition
+        return shared_expert_out + expert_out
+
+
 # Global variable to store load balancing loss
 _LOAD_BALANCING_LOSS = []
 
@@ -680,6 +740,136 @@ def moe_forward(
     return x, expert_weights, router_scores
 
 
+def moe_forward_with_shared_expert(
+    x: torch.Tensor,
+    router_weight: torch.Tensor,
+    moe_top_k: int,
+    moe_num_experts: int,
+    moe_jitter_eps: float = None,
+    moe_normalize_expert_weights: int = None,
+    uniform_expert_assignment: bool = False,
+    training: bool = False,
+    w1: torch.Tensor = None,
+    w2: torch.Tensor = None,
+    w1_bias: torch.Tensor = None,
+    w2_bias: torch.Tensor = None,
+    gradient_scale: Optional[float] = None,
+    alpha: float = 1.702,
+    sort_end_bit: int = 0,
+    expert_parallel_group: torch.distributed.ProcessGroup = None,
+    moe_capacity_factor: float = 1.0,
+    moe_expert_model_parallelism: bool = False,
+    forward_fn: Any = None,
+    hidden_size: int = None,
+    mlp_impl: str = "grouped",
+    # Shared expert parameters
+    shared_up_proj_weight: Optional[torch.Tensor] = None,
+    shared_down_proj_weight: Optional[torch.Tensor] = None,
+    shared_up_proj_bias: Optional[torch.Tensor] = None,
+    shared_down_proj_bias: Optional[torch.Tensor] = None,
+    shared_expert_weighted_sum: bool = False,
+    shared_activation_fn: Optional[Any] = None,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+    # First, compute regular MoE forward pass
+    expert_out, expert_weights, router_scores = moe_forward(
+        x=x,
+        router_weight=router_weight,
+        moe_top_k=moe_top_k,
+        moe_num_experts=moe_num_experts,
+        moe_jitter_eps=moe_jitter_eps,
+        moe_normalize_expert_weights=moe_normalize_expert_weights,
+        uniform_expert_assignment=uniform_expert_assignment,
+        training=training,
+        w1=w1,
+        w2=w2,
+        w1_bias=w1_bias,
+        w2_bias=w2_bias,
+        gradient_scale=gradient_scale,
+        alpha=alpha,
+        sort_end_bit=sort_end_bit,
+        expert_parallel_group=expert_parallel_group,
+        moe_capacity_factor=moe_capacity_factor,
+        moe_expert_model_parallelism=moe_expert_model_parallelism,
+        forward_fn=forward_fn,
+        hidden_size=hidden_size,
+        mlp_impl=mlp_impl,
+    )
+
+    # If shared expert weights provided, compute shared expert output
+    if shared_up_proj_weight is not None and shared_down_proj_weight is not None:
+        shared_expert_out = shared_mlp_forward(
+            x=x,
+            up_proj_weight=shared_up_proj_weight,
+            down_proj_weight=shared_down_proj_weight,
+            up_proj_bias=shared_up_proj_bias,
+            down_proj_bias=shared_down_proj_bias,
+            activation_fn=shared_activation_fn,
+            gradient_scale=gradient_scale,
+        )
+
+        # Combine expert outputs
+        combined_out = combine_expert_shared_outputs(
+            shared_expert_out=shared_expert_out,
+            expert_out=expert_out,
+            shared_expert_weighted_sum=shared_expert_weighted_sum,
+            moe_top_k=moe_top_k,
+        )
+
+        return combined_out, expert_weights, router_scores
+
+    # Return regular MoE output if no shared expert
+    return expert_out, expert_weights, router_scores
+
+
+def create_shared_expert_weights(
+    hidden_size: int,
+    shared_expert_hidden_size: int,
+    device: torch.device,
+    dtype: torch.dtype,
+    init_method: Any,
+    output_layer_init_method: Any = None,
+) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
+
+    if output_layer_init_method is None:
+        output_layer_init_method = init_method
+
+    # Create weight tensors
+    up_proj_weight = torch.empty(
+        shared_expert_hidden_size,
+        hidden_size,
+        device=device,
+        dtype=dtype,
+    )
+    down_proj_weight = torch.empty(
+        hidden_size,
+        shared_expert_hidden_size,
+        device=device,
+        dtype=dtype,
+    )
+
+    # Initialize weights
+    init_method(up_proj_weight)
+    output_layer_init_method(down_proj_weight)
+
+    # No bias by default
+    return up_proj_weight, down_proj_weight, None, None
+
+
+# HACK: Extract device_mesh from pre-hook closure - required for transformers integration
+# This exists because device_mesh is trapped in hook closures with no model attribute
+# Fragile - breaks if hook structure changes or Python internals change
+# TODO: Replace with a more robust solution when available
+def get_device_mesh(model):
+    # Extract device_mesh from child's unused pre_hook closure
+    try:
+        # Find the pre-hook that contains 'device_mesh' in its closure
+        hook = next(h for h in model.experts._forward_pre_hooks.values() if 'device_mesh' in h.__code__.co_freevars)
+        # Extract the device_mesh from the closure
+        return hook.__closure__[hook.__code__.co_freevars.index('device_mesh')].cell_contents
+    except Exception:
+        return None
+
+
 class MegaBlocksMoeMLP(torch.nn.Module):
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -691,8 +881,12 @@ class MegaBlocksMoeMLP(torch.nn.Module):
         moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
         moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
         uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
-
+
         expert_parallel_group = getattr(self, "expert_parallel_group", None)
+        if expert_parallel_group is None:
+            device_mesh = get_device_mesh(self)
+            expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
         has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1
         forward_fn = parallel_forward_once if has_parallel else forward_once
 
@@ -722,4 +916,86 @@ class MegaBlocksMoeMLP(torch.nn.Module):
             hidden_size=self.experts.hidden_size,
             mlp_impl=mlp_impl,
         )
+        return output, expert_weights_out
+
+
+class MegaBlocksMoeMLPWithSharedExpert(MegaBlocksMoeMLP):
+
+    def __init__(self):
+        super().__init__()
+        # Shared expert weights will be set by the user
+        self.shared_up_proj_weight = None
+        self.shared_down_proj_weight = None
+        self.shared_up_proj_bias = None
+        self.shared_down_proj_bias = None
+        self.shared_expert_weighted_sum = False
+        self.shared_activation_fn = None
+
+    def set_shared_expert_weights(
+        self,
+        up_proj_weight: torch.Tensor,
+        down_proj_weight: torch.Tensor,
+        up_proj_bias: Optional[torch.Tensor] = None,
+        down_proj_bias: Optional[torch.Tensor] = None,
+        weighted_sum: bool = False,
+        activation_fn: Optional[Any] = None,
+    ):
+        self.shared_up_proj_weight = up_proj_weight
+        self.shared_down_proj_weight = down_proj_weight
+        self.shared_up_proj_bias = up_proj_bias
+        self.shared_down_proj_bias = down_proj_bias
+        self.shared_expert_weighted_sum = weighted_sum
+        self.shared_activation_fn = activation_fn
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        moe_top_k = getattr(self.router, "top_k", 4)
+        moe_num_experts = getattr(self.experts, "num_experts", 128)
+        gradient_scale = getattr(self.experts, "gradient_scale", None)
+        alpha = getattr(self.experts, "alpha", 1.0)
+        moe_capacity_factor = getattr(self.experts, "capacity_factor", 1.0)
+        moe_jitter_eps = getattr(self.experts, "jitter_eps", None)
+        moe_normalize_expert_weights = getattr(self.experts, "normalize_expert_weights", None)
+        uniform_expert_assignment = getattr(self, "uniform_expert_assignment", False)
+
+        expert_parallel_group = getattr(self, "expert_parallel_group", None)
+        if expert_parallel_group is None:
+            device_mesh = get_device_mesh(self)
+            expert_parallel_group = device_mesh.get_group() if device_mesh else None
+
+        has_parallel = expert_parallel_group is not None and dist.is_initialized() and dist.get_world_size(expert_parallel_group) > 1
+        forward_fn = parallel_forward_once if has_parallel else forward_once
+
+        sort_end_bit = max(int(torch.ceil(torch.log2(torch.tensor(moe_num_experts)))), 1)
+        mlp_impl = getattr(self, "mlp_impl", "grouped")
+
+        output, expert_weights_out, *_ = moe_forward_with_shared_expert(
+            x=x,
+            router_weight=self.router.weight,
+            moe_top_k=moe_top_k,
+            moe_num_experts=moe_num_experts,
+            moe_jitter_eps=moe_jitter_eps,
+            moe_normalize_expert_weights=moe_normalize_expert_weights,
+            uniform_expert_assignment=uniform_expert_assignment,
+            training=self.training,
+            w1=self.experts.gate_up_proj,
+            w2=self.experts.down_proj,
+            w1_bias=self.experts.gate_up_proj_bias,
+            w2_bias=self.experts.down_proj_bias,
+            gradient_scale=gradient_scale,
+            alpha=alpha,
+            sort_end_bit=sort_end_bit,
+            expert_parallel_group=expert_parallel_group,
+            moe_capacity_factor=moe_capacity_factor,
+            moe_expert_model_parallelism=has_parallel,
+            forward_fn=forward_fn,
+            hidden_size=self.experts.hidden_size,
+            mlp_impl=mlp_impl,
+            # Shared expert parameters
+            shared_up_proj_weight=self.shared_up_proj_weight,
+            shared_down_proj_weight=self.shared_down_proj_weight,
+            shared_up_proj_bias=self.shared_up_proj_bias,
+            shared_down_proj_bias=self.shared_down_proj_bias,
+            shared_expert_weighted_sum=self.shared_expert_weighted_sum,
+            shared_activation_fn=self.shared_activation_fn,
+        )
         return output, expert_weights_out
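To close, a hedged sketch of how the new class is meant to be wired up: create_shared_expert_weights allocates and initializes the two projection matrices, and set_shared_expert_weights attaches them to a MegaBlocksMoeMLPWithSharedExpert. The import path, sizes, dtype, and init function below are assumptions for illustration; the module's router and experts attributes still have to be configured (exactly as for the base MegaBlocksMoeMLP) before forward can be called.

import torch
from megablocks.layers import (  # assumed import path for the rebuilt package
    MegaBlocksMoeMLPWithSharedExpert,
    create_shared_expert_weights,
)

up_w, down_w, up_b, down_b = create_shared_expert_weights(
    hidden_size=1024,                # illustrative
    shared_expert_hidden_size=4096,  # illustrative
    device=torch.device("cpu"),
    dtype=torch.float32,
    init_method=torch.nn.init.xavier_uniform_,
)

mlp = MegaBlocksMoeMLPWithSharedExpert()
mlp.set_shared_expert_weights(
    up_proj_weight=up_w,
    down_proj_weight=down_w,
    weighted_sum=True,                       # use the 1/(top_k+1) blend
    activation_fn=torch.nn.functional.silu,  # override the GELU default
)
# mlp.router / mlp.experts must still be attached before calling mlp.forward(x).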