Test showing a hang in ManualPipelineStage + ScheduleGPipe when run with a single microbatch.
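To reproduce: save the file as test_composability.py (the name used in its own run instructions) and launch it directly, e.g. pytest test_composability.py -vsk test_backward_single_mb, or python -m unittest test_composability.TestPipelineComposability.test_backward_single_mb. MultiProcessTestCase spawns the four worker processes itself, so no torchrun launcher is needed; the test initializes NCCL with world_size=4 and pins each rank to cuda:{rank}, so it expects four CUDA devices.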
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
import copy
import unittest

import torch
import torch.distributed as dist
import torch.nn as nn
from pippy.ManualPipelineStage import ManualPipelineStage
from pippy.PipelineSchedule import ScheduleGPipe
from torch.distributed._composable.fsdp.fully_shard import (
    fully_shard,
    MixedPrecisionPolicy,
)
from torch.distributed._tensor import DTensor
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh

# torch.testing._internal.common_distributed requires "expecttest"
from torch.testing._internal.common_distributed import MultiProcessTestCase
from torch.testing._internal.common_utils import FILE_SCHEMA
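

# DDPAROnce is a minimal data-parallel wrapper: it broadcasts the initial
# parameters from the group's first rank and accumulates every parameter's
# .grad as a view into one flat buffer that is all-reduced once per step.
# It is kept here for composability experiments but is not exercised by the
# test below.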
class DDPAROnce(torch.nn.Module):
    def __init__(
        self, module: torch.nn.Module, group: dist.ProcessGroup, dtype=None
    ):
        super().__init__()
        self.module = module
        self.group = group
        self.dtype = dtype
        # Broadcast the init state of the module from the source rank (rank 0 of the group)
        global_rank = dist.get_global_rank(self.group, self.group.rank())
        src_rank = dist.get_global_rank(self.group, 0)
        for param in self.module.parameters():
            dist.broadcast(
                param.data,
                src=src_rank,
                group=self.group,
            )
        # Create the flat 1D gradient buffer covering all parameters
        buffer_size = sum(p.numel() for p in module.parameters())
        self.buffer = torch.zeros(buffer_size, dtype=self.dtype).cuda(global_rank)

    def zero_grad(self):
        self.buffer.zero_()
        offset = 0
        for p in self.module.parameters():
            p.grad = self.buffer[offset : (offset + p.numel())].view(p.shape)
            offset = offset + p.numel()

    def all_reduce_async(self, norm_factor: int):
        self.buffer.div_(norm_factor * self.group.size())
        work = dist.all_reduce(self.buffer, async_op=True, group=self.group)
        return work

    def all_reduce(self, norm_factor: int):
        self.buffer.div_(norm_factor * self.group.size())
        work = dist.all_reduce(self.buffer, async_op=True, group=self.group)
        work.wait()

    def forward(self, *args, **kwargs):
        return self.module.forward(*args, **kwargs)


# python -m unittest test_composability.TestPipelineComposability.<test>
# or
# pytest test_composability.py -vsk <test>
class TestPipelineComposability(MultiProcessTestCase):
    @property
    def world_size(self) -> int:
        # covers first_stage, middle_stage, last_stage cases
        return 4

    @property
    def init_method(self) -> str:
        return f"{FILE_SCHEMA}{self.file_name}"

    def setUp(self):
        super().setUp()
        # starts world_size processes
        self._spawn_processes()

    def _create_manual_pipeline_stage(
        self,
        model,
        stage_id,
        num_stages,
        device,
        group,
        inputs,
        num_microbatches,
    ):
        return ManualPipelineStage(
            module=model,
            stage_id=stage_id,
            num_stages=num_stages,
            device=device,
            group=group,
            num_microbatches=num_microbatches,
            input_args=inputs,
        )

    def _init_device_mesh(self, mesh_shape, mesh_dim_names):
        device = f"cuda:{self.rank}"
        torch.cuda.set_device(device)
        dist.init_process_group(
            init_method=self.init_method,
            backend="nccl",
            rank=self.rank,
            world_size=self.world_size,
        )
        # TODO(whc) there is a bug in our helpers (DeviceMesh: _get_device_handle) where passing `cuda:1` fails
        #   File "/data/users/whc/pytorch/torch/distributed/_composable/fsdp/_fsdp_init.py",
        #     line 69, in _get_device_from_mesh
        #     return torch.device(mesh.device_type, device_handle.current_device())
        #   AttributeError: 'NoneType' object has no attribute 'current_device'
        device_mesh = init_device_mesh(
            "cuda", mesh_shape=mesh_shape, mesh_dim_names=mesh_dim_names
        )
        return device_mesh, device
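
    # Repro scenario: a 2x2 device mesh (dp=2, pp=2), an 8-layer Linear model
    # split into two pipeline stages of 4 layers each, and a GPipe schedule
    # driven with a single microbatch; after the schedule step, the local
    # stage's gradients are compared against an unsharded reference model.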
    def test_backward_single_mb(self):
        device_mesh, device = self._init_device_mesh(
            mesh_shape=(2, 2), mesh_dim_names=("dp", "pp")
        )
        pp_group = device_mesh["pp"].get_group()
        dp_mesh = device_mesh["dp"]
        assert type(pp_group) == dist.ProcessGroup
        assert type(dp_mesh) == DeviceMesh

        # create the "entire model": 8 layers total
        # (pp_group_size=2 stages x layers_per_model=4 layers per stage)
        pp_group_size = pp_group.size()
        layers_per_model = 4
        dim = 10
        full_model = nn.ModuleList(
            [
                nn.Linear(dim, dim)
                for _ in range(pp_group_size * layers_per_model)
            ]
        )

        def weights_init(m):
            if isinstance(m, nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight)

        full_model.apply(weights_init)

        # unsharded reference copy of the full model for gradient comparison
        ref_model = nn.Sequential(*copy.deepcopy(full_model))
        ref_model.to(device)

        # divide the model (8 layers) across the pipeline ranks (4 layers each)
        partial_model = nn.Sequential(
            *full_model[
                pp_group.rank()
                * layers_per_model : (pp_group.rank() + 1)
                * layers_per_model
            ]
        )
        partial_model.to(device)

        # apply PP: a single microbatch, split along the batch dimension
        num_microbatches = 1
        input = torch.rand((num_microbatches, dim), device=device)
        input_mb = [
            [input[i].reshape((1, dim))] for i in range(num_microbatches)
        ]

        pipeline_stage = self._create_manual_pipeline_stage(
            partial_model,
            pp_group.rank(),
            pp_group.size(),
            device,
            pp_group,
            input_mb[0],
            num_microbatches,
        )
        pipeline_schedule = ScheduleGPipe(
            pipeline_stage,
            n_microbatches=num_microbatches,
            # dummy loss needed just to force backwards to run in schedule step
            loss_fn=lambda y, t: y.sum() / 1000000.0,
        )
        pipeline_schedule.step_microbatches(
            arg_mbs=input_mb, target_mbs=input_mb
        )
        print(f"{self.rank} finished pipeline step")

        (ref_model(input).sum() / 1000000.0).backward()
        # validate the local stage's gradients against the reference model;
        # partial_model re-numbers its layers from 0, so shift the layer index
        # back to the corresponding position in ref_model
        ref_parameters = dict(ref_model.named_parameters())
        layer_offset = pp_group.rank() * layers_per_model
        for name, p in partial_model.named_parameters():
            print(f"validating param {name}")
            layer_idx, param_name = name.split(".", 1)
            ref_p = ref_parameters[f"{int(layer_idx) + layer_offset}.{param_name}"]
            # self.assertTrue(isinstance(p.grad, DTensor))
            # self.assertEqual(ref_p.grad, p.grad.full_tensor())
            self.assertEqual(ref_p.grad, p.grad)


if __name__ == "__main__":
    unittest.main()