
@cloneofsimo
Created September 2, 2025 08:32
impl.py
import os
from typing import List

import torch
import torch.distributed.distributed_c10d as c10d
import torch.distributed as dist

from torch._C._distributed_c10d import _SymmetricMemory  # type: ignore
from torch.distributed._symmetric_memory import (  # type: ignore
    rendezvous as _sm_rendezvous,
    get_symm_mem_workspace as _sm_get_workspace,
)
from torch.distributed import _symmetric_memory as _sm_mod  # type: ignore

# Register the custom op in its own "symm_mem2" namespace (the built-in
# symmetric-memory ops live under "symm_mem").
lib = torch.library.Library("symm_mem2", "DEF")  # LOL
lib.define("_low_contention_all2all(Tensor tensor, str group_name) -> Tensor")

@torch.library.impl(lib, "_low_contention_all2all", "Meta")
def _low_contention_all2all_meta(
    tensor: torch.Tensor,
    group_name: str,
) -> torch.Tensor:
    # Shape-only implementation for tracing/compilation: check that the leading
    # dim splits evenly across the group and return an empty tensor of the same
    # shape.
    try:
        group_size = c10d._get_group_size_by_name(group_name)
    except Exception:
        # Fall back to the global world size (or a placeholder of 2) if the
        # group name cannot be resolved under fake/meta tracing.
        if dist.is_initialized():
            group_size = dist.get_world_size()
        else:
            group_size = 2

    if tensor.shape[0] % group_size != 0:
        raise ValueError(
            f"_low_contention_all2all: the leading dim ({tensor.shape[0]}) "
            f"must be divisible by group_size ({group_size})"
        )
    return tensor.new_empty(tensor.shape)

def _all2all_pull_with_symm_mem_input(
    tensor: torch.Tensor,
    symm_mem: _SymmetricMemory,
) -> torch.Tensor:
    # Pull-based path: the input tensor itself lives in symmetric memory, so
    # every rank can read its chunk directly out of each peer's buffer.
    rank = symm_mem.rank
    world_size = symm_mem.world_size
    assert tensor.shape[0] % world_size == 0
    chunk_numel = tensor.numel() // world_size

    output = torch.empty_like(tensor)
    out_chunks = output.chunk(world_size, dim=0)

    # Run the copies on the symmetric-memory backend stream so they can overlap
    # with work on the caller's stream.
    _sm_mod._get_backend_stream().wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(_sm_mod._get_backend_stream()):
        symm_mem.barrier()
        for step in range(world_size):
            # Staggered schedule: each rank starts at a different peer so reads
            # are spread across the group ("low contention").
            remote_rank = (rank - step) % world_size
            src_buf = symm_mem.get_buffer(
                remote_rank,
                out_chunks[0].shape,
                out_chunks[0].dtype,
                storage_offset=chunk_numel * rank,
            )
            out_chunks[remote_rank].copy_(src_buf)
        symm_mem.barrier()
    # Attach a Work object so consumers of `output` wait on the backend stream.
    torch._C._distributed_c10d._register_work(output, _sm_mod.Work())
    return output

def _all2all_push_with_workspace(
    tensor: torch.Tensor,
    workspace: _SymmetricMemory,
) -> torch.Tensor:
    # Push-based path: the input is ordinary CUDA memory, so each rank writes
    # the chunk destined for every peer into that peer's workspace buffer, then
    # reads its own assembled slot back out.
    rank = workspace.rank
    world_size = workspace.world_size
    assert tensor.shape[0] % world_size == 0
    chunk_numel = tensor.numel() // world_size

    chunks = tensor.chunk(world_size, dim=0)

    _sm_mod._get_backend_stream().wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(_sm_mod._get_backend_stream()):
        workspace.barrier()
        for dest_rank in range(world_size):
            # Write the chunk destined for dest_rank into slot `rank` of
            # dest_rank's workspace buffer.
            dst_buf = workspace.get_buffer(
                dest_rank,
                chunks[0].shape,
                chunks[0].dtype,
                storage_offset=chunk_numel * rank,
            )
            dst_buf.copy_(chunks[dest_rank])
        workspace.barrier()
        # After the barrier, this rank's workspace holds one chunk from every
        # peer; copy it into a regular output tensor.
        buf = workspace.get_buffer(rank, tensor.shape, tensor.dtype)
        output = torch.empty_like(tensor)
        output.copy_(buf)
    torch._C._distributed_c10d._register_work(output, _sm_mod.Work())
    return output

@torch.library.impl(lib, "_low_contention_all2all", "CUDA")
def _low_contention_all2all_cuda(
    tensor: torch.Tensor,
    group_name: str,
) -> torch.Tensor:
    # If the input was allocated in symmetric memory, rendezvous succeeds and
    # we pull directly from peers; otherwise stage the data through a shared
    # workspace and push.
    symm_mem = _sm_rendezvous(tensor, group_name)
    if symm_mem is not None:
        return _all2all_pull_with_symm_mem_input(tensor, symm_mem)
    else:
        workspace = _sm_get_workspace(
            group_name, tensor.numel() * tensor.element_size()
        )
        return _all2all_push_with_workspace(tensor, workspace)


def _low_contention_all2all_workspace_cuda(
    tensor: torch.Tensor,
    workspace: _SymmetricMemory,
) -> torch.Tensor:
    # Variant that reuses a caller-provided workspace.
    return _all2all_push_with_workspace(tensor, workspace)
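
For reference, a minimal usage sketch, not part of the gist: it assumes a recent PyTorch that exposes torch.distributed._symmetric_memory.empty and ProcessGroup.group_name, an NCCL multi-GPU launch via torchrun, and that impl.py above has been saved next to a hypothetical demo.py; the shapes are made up.

# demo.py -- illustrative only; run with e.g. `torchrun --nproc-per-node=2 demo.py`
import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

import impl  # noqa: F401  (importing registers symm_mem2::_low_contention_all2all)


def main() -> None:
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    torch.cuda.set_device(rank % torch.cuda.device_count())

    # Name of the default group; the exact accessor may differ across versions.
    group_name = dist.group.WORLD.group_name

    # Allocating via symm_mem.empty lets the CUDA impl take the zero-copy pull
    # path; an ordinary torch.empty(..., device="cuda") tensor would instead be
    # staged through the shared workspace (push path).
    x = symm_mem.empty(world_size * 4, 16, dtype=torch.bfloat16, device="cuda")
    x.fill_(float(rank))

    out = torch.ops.symm_mem2._low_contention_all2all(x, group_name)
    torch.cuda.synchronize()  # a real pipeline would wait on the registered Work instead

    # Chunk i of the output came from rank i, so its entries equal float(i).
    print(f"rank {rank}:", [out[i * 4, 0].item() for i in range(world_size)])

    dist.destroy_process_group()


if __name__ == "__main__":
    main()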