@cloneofsimo
Created September 2, 2025 08:32
symmem-all2all
import os
from typing import List

import torch
import torch.distributed.distributed_c10d as c10d
import torch.distributed as dist
from torch._C._distributed_c10d import _SymmetricMemory  # type: ignore
from torch.distributed._symmetric_memory import (  # type: ignore
    rendezvous as _sm_rendezvous,
    get_symm_mem_workspace as _sm_get_workspace,
)
from torch.distributed import _symmetric_memory as _sm_mod  # type: ignore

lib = torch.library.Library("symm_mem2", "DEF")  # LOL
lib.define("_low_contention_all2all(Tensor tensor, str group_name) -> Tensor")

@torch.library.impl(lib, "_low_contention_all2all", "Meta")
def _low_contention_all2all_meta(
    tensor: torch.Tensor,
    group_name: str,
) -> torch.Tensor:
    # Shape-only implementation used for tracing/compilation.
    try:
        group_size = c10d._get_group_size_by_name(group_name)
    except Exception:
        # The group may not be resolvable under fake/meta execution;
        # fall back to the default world size (or 2 as a last resort).
        if dist.is_initialized():
            group_size = dist.get_world_size()
        else:
            group_size = 2
    if tensor.shape[0] % group_size != 0:
        raise ValueError(
            f"_low_contention_all2all: the leading dim ({tensor.shape[0]}) "
            f"must be divisible by group_size ({group_size})"
        )
    return tensor.new_empty(tensor.shape)

def _all2all_pull_with_symm_mem_input(
    tensor: torch.Tensor,
    symm_mem: _SymmetricMemory,
) -> torch.Tensor:
    # Pull-based variant: the input already lives in symmetric memory, so
    # every rank reads its chunk directly out of each peer's buffer.
    rank = symm_mem.rank
    world_size = symm_mem.world_size
    assert tensor.shape[0] % world_size == 0
    chunk_numel = tensor.numel() // world_size
    output = torch.empty_like(tensor)
    out_chunks = output.chunk(world_size, dim=0)
    _sm_mod._get_backend_stream().wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(_sm_mod._get_backend_stream()):
        symm_mem.barrier()
        for step in range(world_size):
            # Stagger the source ranks so peers do not all read from the
            # same buffer at the same time (hence "low contention").
            remote_rank = (rank - step) % world_size
            src_buf = symm_mem.get_buffer(
                remote_rank,
                out_chunks[0].shape,
                out_chunks[0].dtype,
                storage_offset=chunk_numel * rank,
            )
            out_chunks[remote_rank].copy_(src_buf)
        symm_mem.barrier()
    torch._C._distributed_c10d._register_work(output, _sm_mod.Work())
    return output

def _all2all_push_with_workspace(
    tensor: torch.Tensor,
    workspace: _SymmetricMemory,
) -> torch.Tensor:
    # Push-based variant: the input is an ordinary CUDA tensor, so each rank
    # first pushes its chunks into the peers' workspace buffers, then copies
    # the assembled result out of its own buffer.
    rank = workspace.rank
    world_size = workspace.world_size
    assert tensor.shape[0] % world_size == 0
    chunk_size = tensor.shape[0] // world_size
    chunk_numel = tensor.numel() // world_size
    chunks = tensor.chunk(world_size, dim=0)
    _sm_mod._get_backend_stream().wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(_sm_mod._get_backend_stream()):
        workspace.barrier()
        for dest_rank in range(world_size):
            dst_buf = workspace.get_buffer(
                dest_rank,
                chunks[0].shape,
                chunks[0].dtype,
                storage_offset=chunk_numel * rank,
            )
            dst_buf.copy_(chunks[dest_rank])
        workspace.barrier()
        # Read back this rank's fully assembled buffer on the backend stream.
        buf = workspace.get_buffer(rank, tensor.shape, tensor.dtype)
        output = torch.empty_like(tensor)
        output.copy_(buf)
    torch._C._distributed_c10d._register_work(output, _sm_mod.Work())
    return output

@torch.library.impl(lib, "_low_contention_all2all", "CUDA")
def _low_contention_all2all_cuda(
    tensor: torch.Tensor,
    group_name: str,
) -> torch.Tensor:
    symm_mem = _sm_rendezvous(tensor, group_name)
    if symm_mem is not None:
        # Input was allocated in symmetric memory: peers can pull from it.
        return _all2all_pull_with_symm_mem_input(tensor, symm_mem)
    else:
        # Otherwise stage the exchange through the group's shared workspace.
        workspace = _sm_get_workspace(
            group_name, tensor.numel() * tensor.element_size()
        )
        return _all2all_push_with_workspace(tensor, workspace)

def _low_contention_all2all_workspace_cuda(
    tensor: torch.Tensor,
    workspace: _SymmetricMemory,
) -> torch.Tensor:
    return _all2all_push_with_workspace(tensor, workspace)
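
For reference, a minimal usage sketch of the op registered above, assuming a recent PyTorch build that provides torch.distributed._symmetric_memory (including symm_mem.empty) and dist.group.WORLD.group_name, at least two CUDA GPUs, and a torchrun launch. The input is allocated in symmetric memory so that the rendezvous succeeds and the pull path is taken; an ordinary CUDA tensor would be routed to the workspace (push) fallback instead.

# demo_all2all.py -- untested sketch; launch with:
#   torchrun --nproc_per_node=2 demo_all2all.py
# Assumes the definitions above have already been executed in this process.
import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

dist.init_process_group("nccl")
rank = dist.get_rank()
world_size = dist.get_world_size()
torch.cuda.set_device(rank)
group_name = dist.group.WORLD.group_name  # assumption: default group's name

# Allocate the input in symmetric memory so _sm_rendezvous succeeds and the
# pull path is used; each rank fills its rows with its own rank id.
x = symm_mem.empty(world_size * 4, 16, dtype=torch.float32, device=f"cuda:{rank}")
x.fill_(rank)

y = torch.ops.symm_mem2._low_contention_all2all(x, group_name)
torch.cuda.synchronize()  # crude sync; also covers the backend stream
print(f"rank {rank}: {y[:, 0].tolist()}")

dist.destroy_process_group()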