import torch
import torch.distributed as dist

# xfuser is an optional dependency: it provides Ulysses/ring sequence
# parallelism for multi-GPU inference. When it is missing, the symbols
# below are stubbed to None and only single-GPU mode is available.
try:
    import xfuser
    from xfuser.core.distributed import (get_sequence_parallel_rank,
                                         get_sequence_parallel_world_size,
                                         get_sp_group, get_world_group,
                                         init_distributed_environment,
                                         initialize_model_parallel)
    from xfuser.core.long_ctx_attention import xFuserLongContextAttention
except Exception:
    get_sequence_parallel_world_size = None
    get_sequence_parallel_rank = None
    xFuserLongContextAttention = None
    get_sp_group = None
    get_world_group = None
    init_distributed_environment = None
    initialize_model_parallel = None
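
# Hedged sketch (not part of the original file): the two getters imported
# above are typically used to shard the sequence dimension across the
# sequence-parallel group. `shard_sequence` and `hidden_states` are
# hypothetical names for illustration; the actual attention sharding is
# handled inside xFuserLongContextAttention.
def shard_sequence(hidden_states):
    # hidden_states: [batch, seq_len, dim]; each rank keeps one contiguous
    # chunk of seq_len. Assumes seq_len splits evenly into sp_size chunks.
    sp_size = get_sequence_parallel_world_size()
    sp_rank = get_sequence_parallel_rank()
    return torch.chunk(hidden_states, sp_size, dim=1)[sp_rank]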

def set_multi_gpus_devices(ulysses_degree, ring_degree):
    """Initialize distributed state for multi-GPU inference and return the
    device this process should use. Falls back to plain "cuda" when neither
    Ulysses nor ring parallelism is requested."""
    if ulysses_degree > 1 or ring_degree > 1:
        if get_sp_group is None:
            raise RuntimeError("xfuser is not installed.")
        dist.init_process_group("nccl")
        print('parallel inference enabled: ulysses_degree=%d ring_degree=%d rank=%d world_size=%d' % (
            ulysses_degree, ring_degree, dist.get_rank(), dist.get_world_size()))
        assert dist.get_world_size() == ring_degree * ulysses_degree, \
            "number of GPUs (%d) should be equal to ring_degree * ulysses_degree." % dist.get_world_size()
        init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size())
        # The whole world forms one sequence-parallel group, factored into
        # ring_degree * ulysses_degree.
        initialize_model_parallel(sequence_parallel_degree=dist.get_world_size(),
                                  ring_degree=ring_degree,
                                  ulysses_degree=ulysses_degree)
        # Bind each process to its node-local GPU; the global rank
        # (dist.get_rank()) would index past the GPUs on any node but the
        # first, hence local_rank here.
        device = torch.device(f"cuda:{get_world_group().local_rank}")
        print('rank=%d device=%s' % (get_world_group().rank, str(device)))
    else:
        device = "cuda"
    return device
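
# Usage sketch (illustrative assumption, not in the original file): launch
# under torchrun so torch.distributed can read RANK/WORLD_SIZE, e.g.
#   torchrun --nproc_per_node=4 infer.py
# ulysses_degree * ring_degree must equal the GPU count; 2 * 2 fits 4 GPUs.
if __name__ == "__main__":
    device = set_multi_gpus_devices(ulysses_degree=2, ring_degree=2)
    # Any per-rank tensor or model is then placed on this process's GPU.
    x = torch.randn(2, 8, device=device)
    print(x.device)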