@wconstab
Created April 16, 2024 18:09
test shows hang when using single microbatch
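To reproduce, run the test below on a host with 4 CUDA GPUs (world_size is 4 and the backend is NCCL), e.g. pytest test_composability.py -vsk test_backward_single_mb.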
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
import unittest
import copy

import torch
import torch.distributed as dist
import torch.nn as nn
from pippy.ManualPipelineStage import ManualPipelineStage
from pippy.PipelineSchedule import ScheduleGPipe
from torch.distributed._composable.fsdp.fully_shard import (
    fully_shard,
    MixedPrecisionPolicy,
)
from torch.distributed._tensor import DTensor
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh

# torch.testing._internal.common_distributed requires "expecttest"
from torch.testing._internal.common_distributed import MultiProcessTestCase
from torch.testing._internal.common_utils import FILE_SCHEMA


class DDPAROnce(torch.nn.Module):
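    """Hand-rolled data-parallel helper: broadcasts the initial parameters over
    ``group``, keeps all gradients in one flat buffer, and all-reduces that
    buffer once per step.
    """
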
    def __init__(
        self, module: torch.nn.Module, group: dist.ProcessGroup, dtype=None
    ):
        super().__init__()
        self.module = module
        self.group = group
        self.dtype = dtype
        # Broadcast the init state of the module from source rank (rank 0)
        global_rank = dist.get_global_rank(self.group, self.group.rank())
        for param in self.module.parameters():
            dist.broadcast(
                param.data,
                src=global_rank,
                group=self.group,
            )
        # Create buffer as 1D tensor
        self.buffer = (
            torch.zeros(
                sum([p.numel() for p in module.parameters()]),
            ).cuda(global_rank)
            if self.dtype is None
            else torch.zeros(
                sum([p.numel() for p in module.parameters()]),
            )
            .to(self.dtype)
            .cuda(global_rank)
        )

    def zero_grad(self):
        self.buffer.zero_()
        offset = 0
        for p in self.module.parameters():
            p.grad = self.buffer[offset : (offset + p.numel())].view(p.shape)
            offset = offset + p.numel()

    def all_reduce_async(self, norm_factor: int):
        self.buffer.div_(norm_factor * self.group.size())
        work = dist.all_reduce(self.buffer, async_op=True, group=self.group)
        return work

    def all_reduce(self, norm_factor: int):
        self.buffer.div_(norm_factor * self.group.size())
        work = dist.all_reduce(self.buffer, async_op=True, group=self.group)
        work.wait()

    def forward(self, *args, **kwargs):
        return self.module.forward(*args, **kwargs)
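

# NOTE: DDPAROnce is not exercised by the test below. A hypothetical usage
# sketch (assuming a data-parallel process group, e.g. device_mesh["dp"].get_group()):
#
#   ddp_model = DDPAROnce(stage_module, dp_group)
#   ddp_model.zero_grad()                 # install flat-buffer views as .grad
#   ddp_model(x).sum().backward()         # grads accumulate into the buffer
#   ddp_model.all_reduce(norm_factor=1)   # average the buffer across dp ranks
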
# python -m unittest test_composability.TestPipelineComposability.<test>
# or
# pytest test_composability.py -vsk <test>
class TestPipelineComposability(MultiProcessTestCase):
    @property
    def world_size(self) -> int:
        # covers first_stage, middle_stage, last_stage cases
        return 4

    @property
    def init_method(self) -> str:
        return f"{FILE_SCHEMA}{self.file_name}"

    def setUp(self):
        super().setUp()
        # starts world_size processes
        self._spawn_processes()

    def _create_manual_pipeline_stage(
        self,
        model,
        stage_id,
        num_stages,
        device,
        group,
        inputs,
        num_microbatches,
    ):
        return ManualPipelineStage(
            module=model,
            stage_id=stage_id,
            num_stages=num_stages,
            device=device,
            group=group,
            num_microbatches=num_microbatches,
            input_args=inputs,
        )

    def _init_device_mesh(self, mesh_shape, mesh_dim_names):
        device = f"cuda:{self.rank}"
        torch.cuda.set_device(device)
        dist.init_process_group(
            init_method=self.init_method,
            backend="nccl",
            rank=self.rank,
            world_size=self.world_size,
        )
        # TODO(whc) there is a bug in our helpers (DeviceMesh: _get_device_handle) where passing `cuda:1` fails
        #   File "/data/users/whc/pytorch/torch/distributed/_composable/fsdp/_fsdp_init.py",
        #   line 69, in _get_device_from_mesh
        #     return torch.device(mesh.device_type, device_handle.current_device())
        #   AttributeError: 'NoneType' object has no attribute 'current_device'
        device_mesh = init_device_mesh(
            "cuda", mesh_shape=mesh_shape, mesh_dim_names=mesh_dim_names
        )
        return device_mesh, device

    def test_backward_single_mb(self):
        device_mesh, device = self._init_device_mesh(
            mesh_shape=(2, 2), mesh_dim_names=("dp", "pp")
        )
        pp_group = device_mesh["pp"].get_group()
        dp_mesh = device_mesh["dp"]
        assert type(pp_group) == dist.ProcessGroup
        assert type(dp_mesh) == DeviceMesh
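
        # NOTE: the 2x2 mesh gives a 2-way "dp" dimension and a 2-way "pp"
        # dimension; only the pp group is exercised in this repro
        # (fully_shard and DTensor are imported but not applied).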

        # create "entire model"
        pp_group_size = pp_group.size()

        # 8 layers
        layers_per_model = 4
        dim = 10
        full_model = nn.ModuleList(
            [
                nn.Linear(dim, dim)
                for _ in range(pp_group_size * layers_per_model)
            ]
        )

        def weights_init(m):
            if isinstance(m, nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight)

        full_model.apply(weights_init)

        ref_model = nn.Sequential(*copy.deepcopy(full_model))
        ref_model.to(device)

        # divide the model (8 layers) by the number of pipeline ranks (2)
        partial_model = nn.Sequential(
            *full_model[
                pp_group.rank()
                * layers_per_model : (pp_group.rank() + 1)
                * layers_per_model
            ]
        )
        partial_model.to(device)

        # apply PP
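        # Single microbatch: per the gist description, this is the
        # configuration that hangs.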
        num_microbatches = 1
        input = torch.rand((num_microbatches, dim), device=device)
        input_mb = [
            [input[i].reshape((1, dim))] for i in range(num_microbatches)
        ]
        pipeline_stage = self._create_manual_pipeline_stage(
            partial_model,
            pp_group.rank(),
            pp_group.size(),
            device,
            pp_group,
            input_mb[0],
            num_microbatches,
        )
        pipeline_schedule = ScheduleGPipe(
            pipeline_stage,
            n_microbatches=num_microbatches,
            # dummy loss needed just to force backwards to run in schedule step
            loss_fn=lambda y, t: y.sum() / 1000000.0,
        )
        pipeline_schedule.step_microbatches(
            arg_mbs=input_mb, target_mbs=input_mb
        )
        print(f"{self.rank} finished pipeline step")
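
        # Reference: run the same input through the full, unsharded model and
        # compare the resulting gradients against this stage's local grads.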
        (ref_model(input).sum() / 1000000.0).backward()
        ref_parameters = dict(ref_model.named_parameters())
        for name, p in partial_model.named_parameters():
            print(f"validating param {name}")
            ref_p = ref_parameters[name]
            # self.assertTrue(isinstance(p.grad, DTensor))
            # self.assertEqual(ref_p.grad, p.grad.full_tensor())
            self.assertEqual(ref_p.grad, p.grad)


if __name__ == "__main__":
    unittest.main()