Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| import math | |
| import torch.distributed as dist | |
| def _all_to_all( | |
| input_: torch.Tensor, | |
| world_size: int, | |
| group: dist.ProcessGroup, | |
| scatter_dim: int, | |
| gather_dim: int, | |
| ): | |
| if world_size == 1: | |
| return input_ | |
| input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)] | |
| output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)] | |
| dist.all_to_all(output_list, input_list, group=group) | |
| return torch.cat(output_list, dim=gather_dim).contiguous() | |
class _AllToAll(torch.autograd.Function):
    """Autograd wrapper around ``_all_to_all``.

    The forward pass runs ``_all_to_all`` with the given scatter/gather
    dimensions; the backward pass runs it again on the incoming gradient with
    the two dimensions swapped.
    """

    @staticmethod
    def forward(ctx, input_, process_group, world_size, scatter_dim, gather_dim):
        """Run the exchange and stash its configuration on ``ctx``.

        Args:
            ctx: Autograd context; receives the non-tensor arguments so that
                ``backward`` can reuse them.
            input_: Tensor to redistribute.
            process_group: Group handed to ``_all_to_all``.
            world_size: Number of pieces handed to ``_all_to_all``.
            scatter_dim: Dimension split in the forward pass.
            gather_dim: Dimension concatenated in the forward pass.

        Returns:
            The result of ``_all_to_all`` for these arguments.
        """
        ctx.process_group = process_group
        ctx.scatter_dim = scatter_dim
        ctx.gather_dim = gather_dim
        ctx.world_size = world_size
        return _all_to_all(input_, world_size, process_group, scatter_dim, gather_dim)

    @staticmethod
    def backward(ctx, grad_output):
        """Route the gradient back with scatter and gather dimensions swapped.

        Returns:
            A 5-tuple matching the arguments of ``forward``: the gradient for
            ``input_`` followed by ``None`` for the four non-tensor arguments.
        """
        grad_input = _all_to_all(
            grad_output,
            ctx.world_size,
            ctx.process_group,
            ctx.gather_dim,  # forward's gather dim is the one split here
            ctx.scatter_dim,  # forward's scatter dim is the one concatenated here
        )
        return grad_input, None, None, None, None
def all_to_all(
    input_: torch.Tensor,
    process_group: dist.ProcessGroup,
    world_size: int = 1,
    scatter_dim: int = 2,
    gather_dim: int = 1,
):
    """Apply the ``_AllToAll`` autograd function to ``input_``.

    Args:
        input_: Tensor to redistribute.
        process_group: Process group forwarded to ``_AllToAll``.
        world_size: Forwarded to ``_AllToAll``; defaults to 1.
        scatter_dim: Forwarded to ``_AllToAll``; defaults to 2.
        gather_dim: Forwarded to ``_AllToAll``; defaults to 1.

    Returns:
        Whatever ``_AllToAll.apply`` returns for these arguments.
    """
    # Positional order must match ``_AllToAll.forward`` (after ``ctx``).
    exchange_args = (input_, process_group, world_size, scatter_dim, gather_dim)
    return _AllToAll.apply(*exchange_args)