Created September 2, 2025 08:32
Revisions

cloneofsimo created this gist on Sep 2, 2025.
"""Low-contention all-to-all on top of PyTorch symmetric memory.

Registers a custom op ``symm_mem2::_low_contention_all2all`` with a Meta
(shape-propagation) kernel and a CUDA kernel. The CUDA kernel uses the
rendezvous (pull) path when the input tensor itself lives in symmetric
memory, and otherwise stages data through a symmetric-memory workspace
(push path).
"""

import os
from typing import List

import torch
import torch.distributed as dist
import torch.distributed.distributed_c10d as c10d
from torch._C._distributed_c10d import _SymmetricMemory  # type: ignore
from torch.distributed import _symmetric_memory as _sm_mod  # type: ignore
from torch.distributed._symmetric_memory import (  # type: ignore
    get_symm_mem_workspace as _sm_get_workspace,
    rendezvous as _sm_rendezvous,
)

lib = torch.library.Library("symm_mem2", "DEF")  # LOL
lib.define("_low_contention_all2all(Tensor tensor, str group_name) -> Tensor")


@torch.library.impl(lib, "_low_contention_all2all", "Meta")
def _low_contention_all2all_meta(
    tensor: torch.Tensor,
    group_name: str,
) -> torch.Tensor:
    # Shape-only kernel: validate that the leading dim splits evenly across
    # the group and return an empty tensor with the same shape.
    try:
        group_size = c10d._get_group_size_by_name(group_name)
    except Exception:
        if dist.is_initialized():
            group_size = dist.get_world_size()
        else:
            group_size = 2
    if tensor.shape[0] % group_size != 0:
        raise ValueError(
            f"_low_contention_all2all: the leading dim ({tensor.shape[0]}) "
            f"must be divisible by group_size ({group_size})"
        )
    return tensor.new_empty(tensor.shape)


def _all2all_pull_with_symm_mem_input(
    tensor: torch.Tensor,
    symm_mem: _SymmetricMemory,
) -> torch.Tensor:
    # Pull path: the input already lives in symmetric memory, so every rank
    # reads its chunk directly out of each peer's buffer.
    rank = symm_mem.rank
    world_size = symm_mem.world_size
    assert tensor.shape[0] % world_size == 0
    chunk_numel = tensor.numel() // world_size

    output = torch.empty_like(tensor)
    out_chunks = output.chunk(world_size, dim=0)

    _sm_mod._get_backend_stream().wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(_sm_mod._get_backend_stream()):
        symm_mem.barrier()
        for step in range(world_size):
            # Stagger the remote rank per step to spread contention.
            remote_rank = (rank - step) % world_size
            src_buf = symm_mem.get_buffer(
                remote_rank,
                out_chunks[0].shape,
                out_chunks[0].dtype,
                storage_offset=chunk_numel * rank,
            )
            out_chunks[remote_rank].copy_(src_buf)
        symm_mem.barrier()
        torch._C._distributed_c10d._register_work(output, _sm_mod.Work())
    return output


def _all2all_push_with_workspace(
    tensor: torch.Tensor,
    workspace: _SymmetricMemory,
) -> torch.Tensor:
    # Push path: stage chunks into a symmetric-memory workspace. Each rank
    # writes chunk `dest_rank` into dest_rank's buffer at the slot reserved
    # for this rank, then copies its own assembled buffer into the output.
    rank = workspace.rank
    world_size = workspace.world_size
    assert tensor.shape[0] % world_size == 0
    chunk_size = tensor.shape[0] // world_size
    chunk_numel = tensor.numel() // world_size
    chunks = tensor.chunk(world_size, dim=0)

    _sm_mod._get_backend_stream().wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(_sm_mod._get_backend_stream()):
        workspace.barrier()
        for dest_rank in range(world_size):
            dst_buf = workspace.get_buffer(
                dest_rank,
                chunks[0].shape,
                chunks[0].dtype,
                storage_offset=chunk_numel * rank,
            )
            dst_buf.copy_(chunks[dest_rank])
        workspace.barrier()
        buf = workspace.get_buffer(rank, tensor.shape, tensor.dtype)
        output = torch.empty_like(tensor)
        output.copy_(buf)
        torch._C._distributed_c10d._register_work(output, _sm_mod.Work())
    return output


@torch.library.impl(lib, "_low_contention_all2all", "CUDA")
def _low_contention_all2all_cuda(
    tensor: torch.Tensor,
    group_name: str,
) -> torch.Tensor:
    # Prefer the pull path when the input tensor rendezvous succeeds;
    # otherwise fall back to staging through a workspace.
    symm_mem = _sm_rendezvous(tensor, group_name)
    if symm_mem is not None:
        return _all2all_pull_with_symm_mem_input(tensor, symm_mem)
    else:
        workspace = _sm_get_workspace(
            group_name, tensor.numel() * tensor.element_size()
        )
        return _all2all_push_with_workspace(tensor, workspace)


def _low_contention_all2all_workspace_cuda(
    tensor: torch.Tensor,
    workspace: _SymmetricMemory,
) -> torch.Tensor:
    return _all2all_push_with_workspace(tensor, workspace)
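
Below is a minimal usage sketch, not part of the gist. It assumes a single-node torchrun launch with the NCCL backend, a recent PyTorch build that exposes torch.distributed._symmetric_memory.empty and ProcessGroup.group_name, and that the file above is saved as low_contention_all2all.py (a name chosen here for illustration). Allocating the input via the symmetric-memory allocator lets the CUDA kernel take the rendezvous (pull) path; a plain CUDA tensor would instead fall back to the workspace (push) path.

# demo_all2all.py -- hypothetical driver for the gist above.
# Launch with, e.g.:  torchrun --nproc-per-node=2 demo_all2all.py
import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

import low_contention_all2all  # noqa: F401  (assumed module name; registers torch.ops.symm_mem2.*)


def main() -> None:
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    torch.cuda.set_device(rank)
    group_name = dist.group.WORLD.group_name  # assumes group_name is exposed on the default group

    # Symmetric-memory allocation so the CUDA impl takes the rendezvous/pull path.
    # Leading dim must be divisible by world_size; each row-chunk goes to one peer.
    x = symm_mem.empty(world_size * 4, 8, dtype=torch.float32, device="cuda")
    x.fill_(float(rank))

    y = torch.ops.symm_mem2._low_contention_all2all(x, group_name)
    torch.cuda.synchronize()

    # Chunk r of the output holds rank r's contribution, so this prints [0.0, 1.0, ...].
    print(f"rank {rank}:", [c[0, 0].item() for c in y.chunk(world_size, dim=0)])

    dist.destroy_process_group()


if __name__ == "__main__":
    main()

The op is reachable as torch.ops.symm_mem2._low_contention_all2all because lib.define registers it under the symm_mem2 namespace.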