symmem-all2all
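A small PyTorch gist that registers a custom symm_mem2::_low_contention_all2all op on top of torch.distributed._symmetric_memory: a Meta implementation for shape propagation, plus a CUDA implementation that pulls chunks directly out of peers' buffers when the input tensor is itself backed by symmetric memory, and otherwise pushes through a shared symmetric-memory workspace. A hedged usage sketch follows the listing.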
import os
from typing import List

import torch
import torch.distributed.distributed_c10d as c10d
import torch.distributed as dist
from torch._C._distributed_c10d import _SymmetricMemory  # type: ignore
from torch.distributed._symmetric_memory import (  # type: ignore
    rendezvous as _sm_rendezvous,
    get_symm_mem_workspace as _sm_get_workspace,
)
from torch.distributed import _symmetric_memory as _sm_mod  # type: ignore
# Register a custom op so it is callable as
# torch.ops.symm_mem2._low_contention_all2all(tensor, group_name).
lib = torch.library.Library("symm_mem2", "DEF")  # LOL
lib.define("_low_contention_all2all(Tensor tensor, str group_name) -> Tensor")
@torch.library.impl(lib, "_low_contention_all2all", "Meta")
def _low_contention_all2all_meta(
    tensor: torch.Tensor,
    group_name: str,
) -> torch.Tensor:
    # Shape-only implementation for tracing/compilation: resolve the group
    # size, falling back to the default group (or 2) if the name cannot be
    # resolved.
    try:
        group_size = c10d._get_group_size_by_name(group_name)
    except Exception:
        if dist.is_initialized():
            group_size = dist.get_world_size()
        else:
            group_size = 2
    if tensor.shape[0] % group_size != 0:
        raise ValueError(
            f"_low_contention_all2all: the leading dim ({tensor.shape[0]}) "
            f"must be divisible by group_size ({group_size})"
        )
    return tensor.new_empty(tensor.shape)
def _all2all_pull_with_symm_mem_input(
    tensor: torch.Tensor,
    symm_mem: _SymmetricMemory,
) -> torch.Tensor:
    # Pull variant: the input already lives in symmetric memory, so each rank
    # reads its own chunk directly out of every peer's buffer.
    rank = symm_mem.rank
    world_size = symm_mem.world_size
    assert tensor.shape[0] % world_size == 0
    chunk_numel = tensor.numel() // world_size
    output = torch.empty_like(tensor)
    out_chunks = output.chunk(world_size, dim=0)
    _sm_mod._get_backend_stream().wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(_sm_mod._get_backend_stream()):
        symm_mem.barrier()
        for step in range(world_size):
            # Stagger the peer order per rank to reduce contention.
            remote_rank = (rank - step) % world_size
            src_buf = symm_mem.get_buffer(
                remote_rank,
                out_chunks[0].shape,
                out_chunks[0].dtype,
                storage_offset=chunk_numel * rank,
            )
            out_chunks[remote_rank].copy_(src_buf)
        symm_mem.barrier()
        torch._C._distributed_c10d._register_work(output, _sm_mod.Work())
    return output
def _all2all_push_with_workspace(
    tensor: torch.Tensor,
    workspace: _SymmetricMemory,
) -> torch.Tensor:
    # Push variant: the input is a regular CUDA tensor, so each rank pushes its
    # outgoing chunks into the destination ranks' symmetric-memory workspace,
    # then copies its own assembled workspace buffer into the output.
    rank = workspace.rank
    world_size = workspace.world_size
    assert tensor.shape[0] % world_size == 0
    chunk_size = tensor.shape[0] // world_size
    chunk_numel = tensor.numel() // world_size
    chunks = tensor.chunk(world_size, dim=0)
    _sm_mod._get_backend_stream().wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(_sm_mod._get_backend_stream()):
        workspace.barrier()
        for dest_rank in range(world_size):
            # Write this rank's chunk for dest_rank at offset chunk_numel * rank
            # of dest_rank's workspace, so the receiver ends up rank-ordered.
            dst_buf = workspace.get_buffer(
                dest_rank,
                chunks[0].shape,
                chunks[0].dtype,
                storage_offset=chunk_numel * rank,
            )
            dst_buf.copy_(chunks[dest_rank])
        workspace.barrier()
        buf = workspace.get_buffer(rank, tensor.shape, tensor.dtype)
        output = torch.empty_like(tensor)
        output.copy_(buf)
        torch._C._distributed_c10d._register_work(output, _sm_mod.Work())
    return output
@torch.library.impl(lib, "_low_contention_all2all", "CUDA")
def _low_contention_all2all_cuda(
    tensor: torch.Tensor,
    group_name: str,
) -> torch.Tensor:
    # Prefer the pull path when the input tensor itself is backed by symmetric
    # memory; otherwise stage the exchange through a shared workspace.
    symm_mem = _sm_rendezvous(tensor, group_name)
    if symm_mem is not None:
        return _all2all_pull_with_symm_mem_input(tensor, symm_mem)
    else:
        workspace = _sm_get_workspace(group_name, tensor.numel() * tensor.element_size())
        return _all2all_push_with_workspace(tensor, workspace)


def _low_contention_all2all_workspace_cuda(
    tensor: torch.Tensor,
    workspace: _SymmetricMemory,
) -> torch.Tensor:
    # Convenience entry point for callers that already hold a workspace.
    return _all2all_push_with_workspace(tensor, workspace)