import torch
import torch.distributed as dist
import torch.distributed.distributed_c10d as c10d
from torch._C._distributed_c10d import _SymmetricMemory  # type: ignore
from torch.distributed import _symmetric_memory as _sm_mod  # type: ignore
from torch.distributed._symmetric_memory import (  # type: ignore
    get_symm_mem_workspace as _sm_get_workspace,
    rendezvous as _sm_rendezvous,
)

# Register the op under a separate "symm_mem2" namespace.
lib = torch.library.Library("symm_mem2", "DEF")

lib.define("_low_contention_all2all(Tensor tensor, str group_name) -> Tensor")


@torch.library.impl(lib, "_low_contention_all2all", "Meta")
def _low_contention_all2all_meta(
    tensor: torch.Tensor,
    group_name: str,
) -> torch.Tensor:
    # Resolve the group size by name; fall back to the default group size
    # (or 2) if the name is not registered with c10d.
    try:
        group_size = c10d._get_group_size_by_name(group_name)
    except Exception:
        group_size = dist.get_world_size() if dist.is_initialized() else 2
    if tensor.shape[0] % group_size != 0:
        raise ValueError(
            f"_low_contention_all2all: the leading dim ({tensor.shape[0]}) "
            f"must be divisible by group_size ({group_size})"
        )
    return tensor.new_empty(tensor.shape)


def _all2all_pull_with_symm_mem_input(
    tensor: torch.Tensor,
    symm_mem: _SymmetricMemory,
) -> torch.Tensor:
    """Pull-based all-to-all: the input already lives in symmetric memory,
    so each rank reads its chunks directly out of every peer's buffer."""
    rank = symm_mem.rank
    world_size = symm_mem.world_size
    assert tensor.shape[0] % world_size == 0
    chunk_numel = tensor.numel() // world_size

    output = torch.empty_like(tensor)
    out_chunks = output.chunk(world_size, dim=0)

    _sm_mod._get_backend_stream().wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(_sm_mod._get_backend_stream()):
        symm_mem.barrier()
        for step in range(world_size):
            # Stagger the peer order so ranks don't all read from the same
            # peer at the same time (the "low contention" part).
            remote_rank = (rank - step) % world_size
            src_buf = symm_mem.get_buffer(
                remote_rank,
                out_chunks[0].shape,
                out_chunks[0].dtype,
                storage_offset=chunk_numel * rank,
            )
            out_chunks[remote_rank].copy_(src_buf)
        symm_mem.barrier()
        # Attach a Work handle so callers can synchronize with the backend
        # stream before consuming the output.
        torch._C._distributed_c10d._register_work(output, _sm_mod.Work())
        return output


def _all2all_push_with_workspace(
    tensor: torch.Tensor,
    workspace: _SymmetricMemory,
) -> torch.Tensor:
    """Push-based all-to-all: the input is a regular tensor, so each rank
    pushes its chunks into the peers' symmetric-memory workspace."""
    rank = workspace.rank
    world_size = workspace.world_size
    assert tensor.shape[0] % world_size == 0
    chunk_numel = tensor.numel() // world_size
    chunks = tensor.chunk(world_size, dim=0)

    _sm_mod._get_backend_stream().wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(_sm_mod._get_backend_stream()):
        workspace.barrier()
        for dest_rank in range(world_size):
            # Write our chunk for dest_rank into the slot that dest_rank's
            # workspace reserves for data coming from this rank.
            dst_buf = workspace.get_buffer(
                dest_rank,
                chunks[0].shape,
                chunks[0].dtype,
                storage_offset=chunk_numel * rank,
            )
            dst_buf.copy_(chunks[dest_rank])
        workspace.barrier()
        # After the barrier, the local workspace holds one chunk from every
        # rank, in rank order; copy it out into a private output tensor.
        buf = workspace.get_buffer(rank, tensor.shape, tensor.dtype)
        output = torch.empty_like(tensor)
        output.copy_(buf)
        torch._C._distributed_c10d._register_work(output, _sm_mod.Work())
        return output


@torch.library.impl(lib, "_low_contention_all2all", "CUDA")
def _low_contention_all2all_cuda(
    tensor: torch.Tensor,
    group_name: str,
) -> torch.Tensor:
    # If the input was allocated from symmetric memory, use the pull path;
    # otherwise stage it through a shared workspace with the push path.
    symm_mem = _sm_rendezvous(tensor, group_name)
    if symm_mem is not None:
        return _all2all_pull_with_symm_mem_input(tensor, symm_mem)
    else:
        workspace = _sm_get_workspace(
            group_name, tensor.numel() * tensor.element_size()
        )
        return _all2all_push_with_workspace(tensor, workspace)


def _low_contention_all2all_workspace_cuda(
    tensor: torch.Tensor,
    workspace: _SymmetricMemory,
) -> torch.Tensor:
    return _all2all_push_with_workspace(tensor, workspace)
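

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the library): run under torchrun
# with one GPU per rank on a single node, e.g.
#   torchrun --nproc-per-node=2 this_file.py
# A plain CUDA tensor exercises the push/workspace path above. Details such
# as reading the group name from the default process group via `group_name`
# are assumptions about the local PyTorch setup, not part of this module.
if __name__ == "__main__":
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    torch.cuda.set_device(rank)  # assumes one node, one GPU per rank

    # Each rank contributes `world_size` rows; row i is destined for rank i.
    inp = torch.full((world_size, 1024), float(rank), device="cuda")
    group_name = dist.group.WORLD.group_name  # assumes a named default group

    out = torch.ops.symm_mem2._low_contention_all2all(inp, group_name)
    torch.cuda.synchronize()  # wait for the backend stream to finish
    # Row i of `out` now holds the value i (the chunk sent by rank i).
    print(f"rank {rank}: {out[:, 0].tolist()}")

    dist.destroy_process_group()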