Last active: March 27, 2024 02:21
Revisions
yf225 revised this gist
Mar 27, 2024. 1 changed file with 275 additions and 334 deletions.
@@ -1,389 +1,330 @@
TRACED GRAPH ===== AFTER POST GRAD ===== /data/users/willfeng/pytorch_yf225/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module): def forward(self, primals_1: "f32[8, 32]", primals_2: "f32[2048]", primals_3: "f32[64]", primals_4: "f32[2048]", primals_5: "f32[16]", primals_6: "f32[128, 32]", primals_7: "f32[128]", primals_8: "f32[32, 128]", primals_9: "f32[32]", primals_10, primals_11: "f32[2048]", primals_12: "f32[64]", primals_13: "f32[2048]", primals_14: "f32[16]", primals_15: "f32[128, 32]", primals_16: "f32[128]", primals_17: "f32[32, 128]", primals_18: "f32[32]", primals_19: "f32[2048]", primals_20: "f32[64]", primals_21: "f32[2048]", primals_22: "f32[16]", primals_23: "f32[128, 32]", primals_24: "f32[128]", primals_25: "f32[32, 128]", primals_26: "f32[32]"): # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:47 in foreach_all_gather, code: all_gather_output = torch.empty( empty: "f32[8352]" = torch.ops.aten.empty.memory_format([8352], dtype = torch.float32, device = device(type='cuda', index=1), pin_memory = False) # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:50 in foreach_all_gather, code: all_gather_input = all_gather_output.narrow( slice_1: "f32[4176]" = torch.ops.aten.slice.Tensor(empty, 0, 4176, 8352) # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:53 in foreach_all_gather, code: foreach_copy_dsts = torch.split(all_gather_input, inp_split_sizes) split_with_sizes = torch.ops.aten.split_with_sizes.default(slice_1, [2048, 64, 2048, 16]); slice_1 = None getitem: "f32[2048]" = split_with_sizes[0] getitem_1: "f32[64]" = split_with_sizes[1] getitem_2: "f32[2048]" = split_with_sizes[2] getitem_3: "f32[16]" = split_with_sizes[3]; split_with_sizes = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs) _foreach_copy = torch.ops.aten._foreach_copy.default([getitem, getitem_1, getitem_2, getitem_3], [primals_2, primals_3, primals_4, primals_5]); primals_2 = primals_3 = primals_4 = primals_5 = None getitem_4: "f32[2048]" = _foreach_copy[0] getitem_5: "f32[64]" = _foreach_copy[1] getitem_6: "f32[2048]" = _foreach_copy[2] getitem_7: "f32[16]" = _foreach_copy[3]; _foreach_copy = None # No stacktrace found for following nodes slice_tensor: "f32[4176]" = torch.ops.aten.slice.Tensor(empty, 0, 4176, 8352) slice_scatter_default: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor, getitem_4, 0, 0, 2048); slice_tensor = getitem_4 = None slice_scatter_default_1: "f32[8352]" = torch.ops.aten.slice_scatter.default(empty, slice_scatter_default, 0, 4176, 8352); slice_scatter_default = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs) slice_3: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_1, 0, 4176, 8352) # No stacktrace found for following nodes slice_tensor_1: "f32[4176]" = 
torch.ops.aten.slice.Tensor(slice_scatter_default_1, 0, 4176, 8352) slice_scatter_default_2: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_1, getitem_5, 0, 2048, 2112); slice_tensor_1 = getitem_5 = None slice_scatter_default_3: "f32[8352]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_1, slice_scatter_default_2, 0, 4176, 8352); slice_scatter_default_1 = slice_scatter_default_2 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs) slice_4: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_3, 0, 4176, 8352) # No stacktrace found for following nodes slice_tensor_2: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_3, 0, 4176, 8352) slice_scatter_default_4: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_2, getitem_6, 0, 2112, 4160); slice_tensor_2 = getitem_6 = None slice_scatter_default_5: "f32[8352]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_3, slice_scatter_default_4, 0, 4176, 8352); slice_scatter_default_3 = slice_scatter_default_4 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs) slice_5: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_5, 0, 4176, 8352) # No stacktrace found for following nodes slice_tensor_3: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_5, 0, 4176, 8352) slice_scatter_default_6: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_3, getitem_7, 0, 4160, 4176); slice_tensor_3 = getitem_7 = None slice_scatter_default_7: "f32[8352]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_5, slice_scatter_default_6, 0, 4176, 8352); slice_scatter_default_5 = slice_scatter_default_6 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor( slice_10: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_7, 0, 4176, 8352); slice_scatter_default_7 = None all_gather_into_tensor: "f32[8352]" = torch.ops._c10d_functional.all_gather_into_tensor.default(slice_10, 2, '0'); slice_10 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined] wait_tensor: "f32[8352]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor); all_gather_into_tensor = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()), view_1: "f32[2, 4176]" = torch.ops.aten.reshape.default(wait_tensor, [2, -1]); wait_tensor = None split_with_sizes_6 = torch.ops.aten.split_with_sizes.default(view_1, [2048, 64, 2048, 16], 1); view_1 = None getitem_28: "f32[2, 2048]" = split_with_sizes_6[0] clone: "f32[2, 2048]" = torch.ops.aten.clone.default(getitem_28, memory_format = torch.contiguous_format); getitem_28 = None view_2: "f32[4096]" = torch.ops.aten.reshape.default(clone, [4096]); clone = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided( 
as_strided: "f32[128, 32]" = torch.ops.aten.as_strided.default(view_2, [128, 32], [32, 1], 0); view_2 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()), getitem_33: "f32[2, 64]" = split_with_sizes_6[1] clone_1: "f32[2, 64]" = torch.ops.aten.clone.default(getitem_33, memory_format = torch.contiguous_format); getitem_33 = None view_4: "f32[128]" = torch.ops.aten.reshape.default(clone_1, [128]); clone_1 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided( as_strided_1: "f32[128]" = torch.ops.aten.as_strided.default(view_4, [128], [1], 0); view_4 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()), getitem_38: "f32[2, 2048]" = split_with_sizes_6[2] clone_2: "f32[2, 2048]" = torch.ops.aten.clone.default(getitem_38, memory_format = torch.contiguous_format); getitem_38 = None view_6: "f32[4096]" = torch.ops.aten.reshape.default(clone_2, [4096]); clone_2 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided( as_strided_2: "f32[32, 128]" = torch.ops.aten.as_strided.default(view_6, [32, 128], [128, 1], 0); view_6 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()), getitem_43: "f32[2, 16]" = split_with_sizes_6[3]; split_with_sizes_6 = None clone_3: "f32[2, 16]" = torch.ops.aten.clone.default(getitem_43, memory_format = torch.contiguous_format); getitem_43 = None view_8: "f32[32]" = torch.ops.aten.reshape.default(clone_3, [32]); clone_3 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided( as_strided_3: "f32[32]" = torch.ops.aten.as_strided.default(view_8, [32], [1], 0); view_8 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:138 in foreach_all_gather_copy_out, code: torch._foreach_copy_(out, splits_unpadded) _foreach_copy_1 = torch.ops.aten._foreach_copy.default([primals_6, primals_7, primals_8, primals_9], [as_strided, as_strided_1, as_strided_2, as_strided_3]); primals_6 = primals_7 = primals_9 = as_strided = as_strided_1 = as_strided_2 = as_strided_3 = None getitem_44: "f32[128, 32]" = _foreach_copy_1[0] getitem_45: "f32[128]" = _foreach_copy_1[1] getitem_46: "f32[32, 128]" = _foreach_copy_1[2] getitem_47: "f32[32]" = _foreach_copy_1[3]; _foreach_copy_1 = None # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) permute_1: "f32[32, 128]" = torch.ops.aten.permute.default(getitem_44, [1, 0]); getitem_44 = None # No stacktrace found for following nodes mm_default_5: "f32[8, 128]" = torch.ops.aten.mm.default(primals_1, permute_1); permute_1 = None add_tensor_5: "f32[8, 128]" = torch.ops.aten.add.Tensor(mm_default_5, getitem_45); mm_default_5 = getitem_45 = None # File: /data/users/willfeng/pytorch_yf225/torch/testing/_internal/common_fsdp.py:855 in forward, code: z = F.relu(z) relu: "f32[8, 128]" = 
torch.ops.aten.relu.default(add_tensor_5); add_tensor_5 = None # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) permute_3: "f32[128, 32]" = torch.ops.aten.permute.default(getitem_46, [1, 0]); getitem_46 = None # No stacktrace found for following nodes mm_default_4: "f32[8, 32]" = torch.ops.aten.mm.default(relu, permute_3); permute_3 = None add_tensor_4: "f32[8, 32]" = torch.ops.aten.add.Tensor(mm_default_4, getitem_47); mm_default_4 = getitem_47 = None # File: /data/users/willfeng/pytorch_yf225/torch/testing/_internal/common_fsdp.py:857 in forward, code: z = F.relu(z) relu_1: "f32[8, 32]" = torch.ops.aten.relu.default(add_tensor_4); add_tensor_4 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs) _foreach_copy_2 = torch.ops.aten._foreach_copy.default([getitem, getitem_1, getitem_2, getitem_3], [primals_11, primals_12, primals_13, primals_14]); primals_11 = primals_12 = primals_13 = primals_14 = None getitem_52: "f32[2048]" = _foreach_copy_2[0] getitem_53: "f32[64]" = _foreach_copy_2[1] getitem_54: "f32[2048]" = _foreach_copy_2[2] getitem_55: "f32[16]" = _foreach_copy_2[3]; _foreach_copy_2 = None # No stacktrace found for following nodes slice_tensor_4: "f32[4176]" = torch.ops.aten.slice.Tensor(empty, 0, 4176, 8352) slice_scatter_default_8: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_4, getitem_52, 0, 0, 2048); slice_tensor_4 = getitem_52 = None slice_scatter_default_9: "f32[8352]" = torch.ops.aten.slice_scatter.default(empty, slice_scatter_default_8, 0, 4176, 8352); slice_scatter_default_8 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs) slice_13: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_9, 0, 4176, 8352) # No stacktrace found for following nodes slice_tensor_5: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_9, 0, 4176, 8352) slice_scatter_default_10: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_5, getitem_53, 0, 2048, 2112); slice_tensor_5 = getitem_53 = None slice_scatter_default_11: "f32[8352]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_9, slice_scatter_default_10, 0, 4176, 8352); slice_scatter_default_9 = slice_scatter_default_10 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs) slice_14: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_11, 0, 4176, 8352) # No stacktrace found for following nodes slice_tensor_6: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_11, 0, 4176, 8352) slice_scatter_default_12: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_6, getitem_54, 0, 2112, 4160); slice_tensor_6 = getitem_54 = None slice_scatter_default_13: "f32[8352]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_11, slice_scatter_default_12, 0, 4176, 8352); slice_scatter_default_11 = slice_scatter_default_12 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs) 
slice_15: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_13, 0, 4176, 8352) # No stacktrace found for following nodes slice_tensor_7: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_13, 0, 4176, 8352) slice_scatter_default_14: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_7, getitem_55, 0, 4160, 4176); slice_tensor_7 = getitem_55 = None slice_scatter_default_15: "f32[8352]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_13, slice_scatter_default_14, 0, 4176, 8352); slice_scatter_default_13 = slice_scatter_default_14 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor( slice_20: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_15, 0, 4176, 8352); slice_scatter_default_15 = None all_gather_into_tensor_1: "f32[8352]" = torch.ops._c10d_functional.all_gather_into_tensor.default(slice_20, 2, '0'); slice_20 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined] wait_tensor_1: "f32[8352]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_1); all_gather_into_tensor_1 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()), view_10: "f32[2, 4176]" = torch.ops.aten.reshape.default(wait_tensor_1, [2, -1]); wait_tensor_1 = None split_with_sizes_16 = torch.ops.aten.split_with_sizes.default(view_10, [2048, 64, 2048, 16], 1); view_10 = None getitem_76: "f32[2, 2048]" = split_with_sizes_16[0] clone_4: "f32[2, 2048]" = torch.ops.aten.clone.default(getitem_76, memory_format = torch.contiguous_format); getitem_76 = None view_11: "f32[4096]" = torch.ops.aten.reshape.default(clone_4, [4096]); clone_4 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided( as_strided_4: "f32[128, 32]" = torch.ops.aten.as_strided.default(view_11, [128, 32], [32, 1], 0); view_11 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()), getitem_81: "f32[2, 64]" = split_with_sizes_16[1] clone_5: "f32[2, 64]" = torch.ops.aten.clone.default(getitem_81, memory_format = torch.contiguous_format); getitem_81 = None view_13: "f32[128]" = torch.ops.aten.reshape.default(clone_5, [128]); clone_5 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided( as_strided_5: "f32[128]" = torch.ops.aten.as_strided.default(view_13, [128], [1], 0); view_13 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()), getitem_86: "f32[2, 2048]" = split_with_sizes_16[2] clone_6: "f32[2, 2048]" = torch.ops.aten.clone.default(getitem_86, memory_format = torch.contiguous_format); getitem_86 = None view_15: "f32[4096]" = torch.ops.aten.reshape.default(clone_6, [4096]); clone_6 = None # File: 
/data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided( as_strided_6: "f32[32, 128]" = torch.ops.aten.as_strided.default(view_15, [32, 128], [128, 1], 0); view_15 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()), getitem_91: "f32[2, 16]" = split_with_sizes_16[3]; split_with_sizes_16 = None clone_7: "f32[2, 16]" = torch.ops.aten.clone.default(getitem_91, memory_format = torch.contiguous_format); getitem_91 = None view_17: "f32[32]" = torch.ops.aten.reshape.default(clone_7, [32]); clone_7 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided( as_strided_7: "f32[32]" = torch.ops.aten.as_strided.default(view_17, [32], [1], 0); view_17 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:138 in foreach_all_gather_copy_out, code: torch._foreach_copy_(out, splits_unpadded) _foreach_copy_3 = torch.ops.aten._foreach_copy.default([primals_15, primals_16, primals_17, primals_18], [as_strided_4, as_strided_5, as_strided_6, as_strided_7]); primals_16 = primals_18 = as_strided_4 = as_strided_5 = as_strided_6 = as_strided_7 = None getitem_92: "f32[128, 32]" = _foreach_copy_3[0] getitem_93: "f32[128]" = _foreach_copy_3[1] getitem_94: "f32[32, 128]" = _foreach_copy_3[2] getitem_95: "f32[32]" = _foreach_copy_3[3]; _foreach_copy_3 = None # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) permute_5: "f32[32, 128]" = torch.ops.aten.permute.default(getitem_92, [1, 0]); getitem_92 = None # No stacktrace found for following nodes mm_default_3: "f32[8, 128]" = torch.ops.aten.mm.default(relu_1, permute_5); permute_5 = None add_tensor_3: "f32[8, 128]" = torch.ops.aten.add.Tensor(mm_default_3, getitem_93); mm_default_3 = getitem_93 = None # File: /data/users/willfeng/pytorch_yf225/torch/testing/_internal/common_fsdp.py:855 in forward, code: z = F.relu(z) relu_2: "f32[8, 128]" = torch.ops.aten.relu.default(add_tensor_3); add_tensor_3 = None # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) permute_7: "f32[128, 32]" = torch.ops.aten.permute.default(getitem_94, [1, 0]); getitem_94 = None # No stacktrace found for following nodes mm_default_2: "f32[8, 32]" = torch.ops.aten.mm.default(relu_2, permute_7); permute_7 = None add_tensor_2: "f32[8, 32]" = torch.ops.aten.add.Tensor(mm_default_2, getitem_95); mm_default_2 = getitem_95 = None # File: /data/users/willfeng/pytorch_yf225/torch/testing/_internal/common_fsdp.py:857 in forward, code: z = F.relu(z) relu_3: "f32[8, 32]" = torch.ops.aten.relu.default(add_tensor_2); add_tensor_2 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs) _foreach_copy_4 = torch.ops.aten._foreach_copy.default([getitem, getitem_1, getitem_2, getitem_3], [primals_19, primals_20, primals_21, primals_22]); getitem = getitem_1 = getitem_2 = getitem_3 = primals_19 = primals_20 = primals_21 = primals_22 = None getitem_100: "f32[2048]" = _foreach_copy_4[0] getitem_101: "f32[64]" = 
_foreach_copy_4[1] getitem_102: "f32[2048]" = _foreach_copy_4[2] getitem_103: "f32[16]" = _foreach_copy_4[3]; _foreach_copy_4 = None # No stacktrace found for following nodes slice_tensor_8: "f32[4176]" = torch.ops.aten.slice.Tensor(empty, 0, 4176, 8352) slice_scatter_default_16: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_8, getitem_100, 0, 0, 2048); slice_tensor_8 = getitem_100 = None slice_scatter_default_17: "f32[8352]" = torch.ops.aten.slice_scatter.default(empty, slice_scatter_default_16, 0, 4176, 8352); empty = slice_scatter_default_16 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs) slice_23: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_17, 0, 4176, 8352) # No stacktrace found for following nodes slice_tensor_9: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_17, 0, 4176, 8352) slice_scatter_default_18: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_9, getitem_101, 0, 2048, 2112); slice_tensor_9 = getitem_101 = None slice_scatter_default_19: "f32[8352]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_17, slice_scatter_default_18, 0, 4176, 8352); slice_scatter_default_17 = slice_scatter_default_18 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs) slice_24: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_19, 0, 4176, 8352) # No stacktrace found for following nodes slice_tensor_10: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_19, 0, 4176, 8352) slice_scatter_default_20: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_10, getitem_102, 0, 2112, 4160); slice_tensor_10 = getitem_102 = None slice_scatter_default_21: "f32[8352]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_19, slice_scatter_default_20, 0, 4176, 8352); slice_scatter_default_19 = slice_scatter_default_20 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs) slice_25: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_21, 0, 4176, 8352) # No stacktrace found for following nodes slice_tensor_11: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_21, 0, 4176, 8352) slice_scatter_default_22: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_11, getitem_103, 0, 4160, 4176); slice_tensor_11 = getitem_103 = None slice_scatter_default_23: "f32[8352]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_21, slice_scatter_default_22, 0, 4176, 8352); slice_scatter_default_21 = slice_scatter_default_22 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor( slice_30: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_23, 0, 4176, 8352); slice_scatter_default_23 = None all_gather_into_tensor_2: "f32[8352]" = torch.ops._c10d_functional.all_gather_into_tensor.default(slice_30, 2, '0'); slice_30 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return 
torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined] wait_tensor_2: "f32[8352]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_2); all_gather_into_tensor_2 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()), view_19: "f32[2, 4176]" = torch.ops.aten.reshape.default(wait_tensor_2, [2, -1]); wait_tensor_2 = None split_with_sizes_26 = torch.ops.aten.split_with_sizes.default(view_19, [2048, 64, 2048, 16], 1); view_19 = None getitem_124: "f32[2, 2048]" = split_with_sizes_26[0] clone_8: "f32[2, 2048]" = torch.ops.aten.clone.default(getitem_124, memory_format = torch.contiguous_format); getitem_124 = None view_20: "f32[4096]" = torch.ops.aten.reshape.default(clone_8, [4096]); clone_8 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided( as_strided_8: "f32[128, 32]" = torch.ops.aten.as_strided.default(view_20, [128, 32], [32, 1], 0); view_20 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()), getitem_129: "f32[2, 64]" = split_with_sizes_26[1] clone_9: "f32[2, 64]" = torch.ops.aten.clone.default(getitem_129, memory_format = torch.contiguous_format); getitem_129 = None view_22: "f32[128]" = torch.ops.aten.reshape.default(clone_9, [128]); clone_9 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided( as_strided_9: "f32[128]" = torch.ops.aten.as_strided.default(view_22, [128], [1], 0); view_22 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()), getitem_134: "f32[2, 2048]" = split_with_sizes_26[2] clone_10: "f32[2, 2048]" = torch.ops.aten.clone.default(getitem_134, memory_format = torch.contiguous_format); getitem_134 = None view_24: "f32[4096]" = torch.ops.aten.reshape.default(clone_10, [4096]); clone_10 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided( as_strided_10: "f32[32, 128]" = torch.ops.aten.as_strided.default(view_24, [32, 128], [128, 1], 0); view_24 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()), getitem_139: "f32[2, 16]" = split_with_sizes_26[3]; split_with_sizes_26 = None clone_11: "f32[2, 16]" = torch.ops.aten.clone.default(getitem_139, memory_format = torch.contiguous_format); getitem_139 = None view_26: "f32[32]" = torch.ops.aten.reshape.default(clone_11, [32]); clone_11 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided( as_strided_11: "f32[32]" = torch.ops.aten.as_strided.default(view_26, [32], [1], 0); view_26 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:138 in foreach_all_gather_copy_out, code: torch._foreach_copy_(out, splits_unpadded) _foreach_copy_5 = 
torch.ops.aten._foreach_copy.default([primals_23, primals_24, primals_25, primals_26], [as_strided_8, as_strided_9, as_strided_10, as_strided_11]); primals_24 = primals_26 = as_strided_8 = as_strided_9 = as_strided_10 = as_strided_11 = None getitem_140: "f32[128, 32]" = _foreach_copy_5[0] getitem_141: "f32[128]" = _foreach_copy_5[1] getitem_142: "f32[32, 128]" = _foreach_copy_5[2] getitem_143: "f32[32]" = _foreach_copy_5[3]; _foreach_copy_5 = None # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) permute_9: "f32[32, 128]" = torch.ops.aten.permute.default(getitem_140, [1, 0]); getitem_140 = None # No stacktrace found for following nodes mm_default_1: "f32[8, 128]" = torch.ops.aten.mm.default(relu_3, permute_9); permute_9 = None add_tensor_1: "f32[8, 128]" = torch.ops.aten.add.Tensor(mm_default_1, getitem_141); mm_default_1 = getitem_141 = None # File: /data/users/willfeng/pytorch_yf225/torch/testing/_internal/common_fsdp.py:855 in forward, code: z = F.relu(z) relu_4: "f32[8, 128]" = torch.ops.aten.relu.default(add_tensor_1); add_tensor_1 = None # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) permute_11: "f32[128, 32]" = torch.ops.aten.permute.default(getitem_142, [1, 0]); getitem_142 = None # No stacktrace found for following nodes mm_default: "f32[8, 32]" = torch.ops.aten.mm.default(relu_4, permute_11); permute_11 = None add_tensor: "f32[8, 32]" = torch.ops.aten.add.Tensor(mm_default, getitem_143); mm_default = getitem_143 = None # File: /data/users/willfeng/pytorch_yf225/torch/testing/_internal/common_fsdp.py:857 in forward, code: z = F.relu(z) relu_5: "f32[8, 32]" = torch.ops.aten.relu.default(add_tensor); add_tensor = None le: "b8[8, 32]" = torch.ops.aten.le.Scalar(relu_5, 0) return [relu_5, primals_1, primals_8, primals_15, primals_17, primals_23, primals_25, relu, relu_1, relu_2, relu_3, relu_4, le] -
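For context, a post-grad graph like the one above can usually be reproduced by wrapping a small MLP with FSDP2's fully_shard and running the forward/backward under torch.compile. The sketch below is an illustrative assumption, not the script that produced this gist: the module sizes only loosely mirror the f32[128, 32] / f32[32, 128] weights in the graph, and the logging hint and launch command are guesses.

# Hypothetical repro sketch (assumption, not taken from this gist):
# shard a small MLP with FSDP2's fully_shard, compile it, and dump the
# post-grad graphs, e.g. via TORCH_LOGS="post_grad_graphs".
# Launch with something like: torchrun --nproc_per_node=2 repro.py
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard

def main() -> None:
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank)

    # Layer widths loosely mirror the weights seen in the traced graph above.
    model = nn.Sequential(
        nn.Linear(32, 128), nn.ReLU(),
        nn.Linear(128, 32), nn.ReLU(),
    ).cuda()
    fully_shard(model)               # FSDP2 per-parameter sharding
    compiled = torch.compile(model)  # traces forward, including the all-gather copy-in/copy-out

    x = torch.randn(8, 32, device="cuda")
    loss = compiled(x).sum()
    loss.backward()

    dist.destroy_process_group()

if __name__ == "__main__":
    main()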
yf225 revised this gist
Mar 27, 2024. 1 changed file with 277 additions and 218 deletions.
@@ -1,330 +1,389 @@
TRACED GRAPH ===== AFTER POST GRAD ===== /data/users/willfeng/pytorch_yf225/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module): def forward(self, primals_1: "f32[8, 32]", primals_2: "f32[1024]", primals_3: "f32[32]", primals_4: "f32[64, 32]", primals_5: "f32[64]", primals_6, primals_7: "f32[4096]", primals_8: "f32[64]", primals_9: "f32[128, 64]", primals_10: "f32[128]", primals_11: "f32[16384]", primals_12: "f32[128]", primals_13: "f32[256, 128]", primals_14: "f32[256]"): # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:47 in foreach_all_gather, code: all_gather_output = torch.empty( empty: "f32[2112]" = torch.ops.aten.empty.memory_format([2112], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False) # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:50 in foreach_all_gather, code: all_gather_input = all_gather_output.narrow( slice_1: "f32[1056]" = torch.ops.aten.slice.Tensor(empty, 0, 0, 1056) # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:53 in foreach_all_gather, code: foreach_copy_dsts = torch.split(all_gather_input, inp_split_sizes) split_with_sizes = torch.ops.aten.split_with_sizes.default(slice_1, [1024, 32]); slice_1 = None getitem: "f32[1024]" = split_with_sizes[0] getitem_1: "f32[32]" = split_with_sizes[1]; split_with_sizes = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs) _foreach_copy = torch.ops.aten._foreach_copy.default([getitem, getitem_1], [primals_2, primals_3]); getitem = getitem_1 = primals_2 = primals_3 = None getitem_2: "f32[1024]" = _foreach_copy[0] getitem_3: "f32[32]" = _foreach_copy[1]; _foreach_copy = None # No stacktrace found for following nodes slice_tensor: "f32[1056]" = torch.ops.aten.slice.Tensor(empty, 0, 0, 1056) slice_scatter_default: "f32[1056]" = torch.ops.aten.slice_scatter.default(slice_tensor, getitem_2, 0, 0, 1024); slice_tensor = getitem_2 = None slice_scatter_default_1: "f32[2112]" = torch.ops.aten.slice_scatter.default(empty, slice_scatter_default, 0, 0, 1056); empty = slice_scatter_default = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs) slice_3: "f32[1056]" = torch.ops.aten.slice.Tensor(slice_scatter_default_1, 0, 0, 1056) # No stacktrace found for following nodes slice_tensor_1: "f32[1056]" = torch.ops.aten.slice.Tensor(slice_scatter_default_1, 0, 0, 1056) slice_scatter_default_2: "f32[1056]" = torch.ops.aten.slice_scatter.default(slice_tensor_1, getitem_3, 0, 1024, 1056); slice_tensor_1 = getitem_3 = None slice_scatter_default_3: "f32[2112]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_1, slice_scatter_default_2, 0, 0, 1056); slice_scatter_default_1 = slice_scatter_default_2 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor 
= torch.ops._c10d_functional.all_gather_into_tensor( slice_6: "f32[1056]" = torch.ops.aten.slice.Tensor(slice_scatter_default_3, 0, 0, 1056); slice_scatter_default_3 = None all_gather_into_tensor: "f32[2112]" = torch.ops._c10d_functional.all_gather_into_tensor.default(slice_6, 2, '0'); slice_6 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined] wait_tensor: "f32[2112]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor); all_gather_into_tensor = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()), view_1: "f32[2, 1056]" = torch.ops.aten.reshape.default(wait_tensor, [2, -1]); wait_tensor = None split_with_sizes_4 = torch.ops.aten.split_with_sizes.default(view_1, [1024, 32], 1); view_1 = None getitem_10: "f32[2, 1024]" = split_with_sizes_4[0] clone: "f32[2, 1024]" = torch.ops.aten.clone.default(getitem_10, memory_format = torch.contiguous_format); getitem_10 = None view_2: "f32[2048]" = torch.ops.aten.reshape.default(clone, [2048]); clone = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided( as_strided: "f32[64, 32]" = torch.ops.aten.as_strided.default(view_2, [64, 32], [32, 1], 0); view_2 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()), getitem_13: "f32[2, 32]" = split_with_sizes_4[1]; split_with_sizes_4 = None clone_1: "f32[2, 32]" = torch.ops.aten.clone.default(getitem_13, memory_format = torch.contiguous_format); getitem_13 = None view_4: "f32[64]" = torch.ops.aten.reshape.default(clone_1, [64]); clone_1 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided( as_strided_1: "f32[64]" = torch.ops.aten.as_strided.default(view_4, [64], [1], 0); view_4 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:138 in foreach_all_gather_copy_out, code: torch._foreach_copy_(out, splits_unpadded) _foreach_copy_1 = torch.ops.aten._foreach_copy.default([primals_4, primals_5], [as_strided, as_strided_1]); primals_4 = primals_5 = as_strided = as_strided_1 = None getitem_14: "f32[64, 32]" = _foreach_copy_1[0] getitem_15: "f32[64]" = _foreach_copy_1[1]; _foreach_copy_1 = None # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) permute_1: "f32[32, 64]" = torch.ops.aten.permute.default(getitem_14, [1, 0]); getitem_14 = None addmm: "f32[8, 64]" = torch.ops.aten.addmm.default(getitem_15, primals_1, permute_1); getitem_15 = permute_1 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:47 in foreach_all_gather, code: all_gather_output = torch.empty( empty_1: "f32[8320]" = torch.ops.aten.empty.memory_format([8320], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False) # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:50 in foreach_all_gather, code: all_gather_input = 
all_gather_output.narrow( slice_7: "f32[4160]" = torch.ops.aten.slice.Tensor(empty_1, 0, 0, 4160) # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:53 in foreach_all_gather, code: foreach_copy_dsts = torch.split(all_gather_input, inp_split_sizes) split_with_sizes_6 = torch.ops.aten.split_with_sizes.default(slice_7, [4096, 64]); slice_7 = None getitem_16: "f32[4096]" = split_with_sizes_6[0] getitem_17: "f32[64]" = split_with_sizes_6[1]; split_with_sizes_6 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs) _foreach_copy_2 = torch.ops.aten._foreach_copy.default([getitem_16, getitem_17], [primals_7, primals_8]); getitem_16 = getitem_17 = primals_7 = primals_8 = None getitem_18: "f32[4096]" = _foreach_copy_2[0] getitem_19: "f32[64]" = _foreach_copy_2[1]; _foreach_copy_2 = None # No stacktrace found for following nodes slice_tensor_2: "f32[4160]" = torch.ops.aten.slice.Tensor(empty_1, 0, 0, 4160) slice_scatter_default_4: "f32[4160]" = torch.ops.aten.slice_scatter.default(slice_tensor_2, getitem_18, 0, 0, 4096); slice_tensor_2 = getitem_18 = None slice_scatter_default_5: "f32[8320]" = torch.ops.aten.slice_scatter.default(empty_1, slice_scatter_default_4, 0, 0, 4160); empty_1 = slice_scatter_default_4 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs) slice_9: "f32[4160]" = torch.ops.aten.slice.Tensor(slice_scatter_default_5, 0, 0, 4160) # No stacktrace found for following nodes slice_tensor_3: "f32[4160]" = torch.ops.aten.slice.Tensor(slice_scatter_default_5, 0, 0, 4160) slice_scatter_default_6: "f32[4160]" = torch.ops.aten.slice_scatter.default(slice_tensor_3, getitem_19, 0, 4096, 4160); slice_tensor_3 = getitem_19 = None slice_scatter_default_7: "f32[8320]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_5, slice_scatter_default_6, 0, 0, 4160); slice_scatter_default_5 = slice_scatter_default_6 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor( slice_12: "f32[4160]" = torch.ops.aten.slice.Tensor(slice_scatter_default_7, 0, 0, 4160); slice_scatter_default_7 = None all_gather_into_tensor_1: "f32[8320]" = torch.ops._c10d_functional.all_gather_into_tensor.default(slice_12, 2, '0'); slice_12 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined] wait_tensor_1: "f32[8320]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_1); all_gather_into_tensor_1 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()), view_6: "f32[2, 4160]" = torch.ops.aten.reshape.default(wait_tensor_1, [2, -1]); wait_tensor_1 = None split_with_sizes_10 = torch.ops.aten.split_with_sizes.default(view_6, [4096, 64], 1); view_6 = None getitem_26: "f32[2, 4096]" = split_with_sizes_10[0] clone_2: "f32[2, 4096]" = torch.ops.aten.clone.default(getitem_26, memory_format = torch.contiguous_format); getitem_26 = None view_7: "f32[8192]" = 
torch.ops.aten.reshape.default(clone_2, [8192]); clone_2 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided( as_strided_2: "f32[128, 64]" = torch.ops.aten.as_strided.default(view_7, [128, 64], [64, 1], 0); view_7 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()), getitem_29: "f32[2, 64]" = split_with_sizes_10[1]; split_with_sizes_10 = None clone_3: "f32[2, 64]" = torch.ops.aten.clone.default(getitem_29, memory_format = torch.contiguous_format); getitem_29 = None view_9: "f32[128]" = torch.ops.aten.reshape.default(clone_3, [128]); clone_3 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided( as_strided_3: "f32[128]" = torch.ops.aten.as_strided.default(view_9, [128], [1], 0); view_9 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:138 in foreach_all_gather_copy_out, code: torch._foreach_copy_(out, splits_unpadded) _foreach_copy_3 = torch.ops.aten._foreach_copy.default([primals_9, primals_10], [as_strided_2, as_strided_3]); primals_10 = as_strided_2 = as_strided_3 = None getitem_30: "f32[128, 64]" = _foreach_copy_3[0] getitem_31: "f32[128]" = _foreach_copy_3[1]; _foreach_copy_3 = None # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) permute_3: "f32[64, 128]" = torch.ops.aten.permute.default(getitem_30, [1, 0]); getitem_30 = None addmm_1: "f32[8, 128]" = torch.ops.aten.addmm.default(getitem_31, addmm, permute_3); getitem_31 = permute_3 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:47 in foreach_all_gather, code: all_gather_output = torch.empty( empty_2: "f32[33024]" = torch.ops.aten.empty.memory_format([33024], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False) # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:50 in foreach_all_gather, code: all_gather_input = all_gather_output.narrow( slice_13: "f32[16512]" = torch.ops.aten.slice.Tensor(empty_2, 0, 0, 16512) # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:53 in foreach_all_gather, code: foreach_copy_dsts = torch.split(all_gather_input, inp_split_sizes) split_with_sizes_12 = torch.ops.aten.split_with_sizes.default(slice_13, [16384, 128]); slice_13 = None getitem_32: "f32[16384]" = split_with_sizes_12[0] getitem_33: "f32[128]" = split_with_sizes_12[1]; split_with_sizes_12 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs) _foreach_copy_4 = torch.ops.aten._foreach_copy.default([getitem_32, getitem_33], [primals_11, primals_12]); getitem_32 = getitem_33 = primals_11 = primals_12 = None getitem_34: "f32[16384]" = _foreach_copy_4[0] getitem_35: "f32[128]" = _foreach_copy_4[1]; _foreach_copy_4 = None # No stacktrace found for following nodes slice_tensor_4: "f32[16512]" = torch.ops.aten.slice.Tensor(empty_2, 0, 0, 16512) slice_scatter_default_8: "f32[16512]" = 
torch.ops.aten.slice_scatter.default(slice_tensor_4, getitem_34, 0, 0, 16384); slice_tensor_4 = getitem_34 = None slice_scatter_default_9: "f32[33024]" = torch.ops.aten.slice_scatter.default(empty_2, slice_scatter_default_8, 0, 0, 16512); empty_2 = slice_scatter_default_8 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs) slice_15: "f32[16512]" = torch.ops.aten.slice.Tensor(slice_scatter_default_9, 0, 0, 16512) # No stacktrace found for following nodes slice_tensor_5: "f32[16512]" = torch.ops.aten.slice.Tensor(slice_scatter_default_9, 0, 0, 16512) slice_scatter_default_10: "f32[16512]" = torch.ops.aten.slice_scatter.default(slice_tensor_5, getitem_35, 0, 16384, 16512); slice_tensor_5 = getitem_35 = None slice_scatter_default_11: "f32[33024]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_9, slice_scatter_default_10, 0, 0, 16512); slice_scatter_default_9 = slice_scatter_default_10 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor( slice_18: "f32[16512]" = torch.ops.aten.slice.Tensor(slice_scatter_default_11, 0, 0, 16512); slice_scatter_default_11 = None all_gather_into_tensor_2: "f32[33024]" = torch.ops._c10d_functional.all_gather_into_tensor.default(slice_18, 2, '0'); slice_18 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined] wait_tensor_2: "f32[33024]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_2); all_gather_into_tensor_2 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()), view_11: "f32[2, 16512]" = torch.ops.aten.reshape.default(wait_tensor_2, [2, -1]); wait_tensor_2 = None split_with_sizes_16 = torch.ops.aten.split_with_sizes.default(view_11, [16384, 128], 1); view_11 = None getitem_42: "f32[2, 16384]" = split_with_sizes_16[0] clone_4: "f32[2, 16384]" = torch.ops.aten.clone.default(getitem_42, memory_format = torch.contiguous_format); getitem_42 = None view_12: "f32[32768]" = torch.ops.aten.reshape.default(clone_4, [32768]); clone_4 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided( as_strided_4: "f32[256, 128]" = torch.ops.aten.as_strided.default(view_12, [256, 128], [128, 1], 0); view_12 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()), getitem_45: "f32[2, 128]" = split_with_sizes_16[1]; split_with_sizes_16 = None clone_5: "f32[2, 128]" = torch.ops.aten.clone.default(getitem_45, memory_format = torch.contiguous_format); getitem_45 = None view_14: "f32[256]" = torch.ops.aten.reshape.default(clone_5, [256]); clone_5 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided( as_strided_5: "f32[256]" = torch.ops.aten.as_strided.default(view_14, [256], [1], 0); view_14 = None # File: 
/data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:138 in foreach_all_gather_copy_out, code: torch._foreach_copy_(out, splits_unpadded) _foreach_copy_5 = torch.ops.aten._foreach_copy.default([primals_13, primals_14], [as_strided_4, as_strided_5]); primals_14 = as_strided_4 = as_strided_5 = None getitem_46: "f32[256, 128]" = _foreach_copy_5[0] getitem_47: "f32[256]" = _foreach_copy_5[1]; _foreach_copy_5 = None # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) permute_5: "f32[128, 256]" = torch.ops.aten.permute.default(getitem_46, [1, 0]); getitem_46 = None addmm_2: "f32[8, 256]" = torch.ops.aten.addmm.default(getitem_47, addmm_1, permute_5); getitem_47 = permute_5 = None return [addmm_2, primals_1, primals_9, primals_13, addmm, addmm_1] TRACED GRAPH ===== AFTER POST GRAD ===== /data/users/willfeng/pytorch_yf225/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module): def forward(self, primals_1: "f32[8, 32]", primals_2: "f32[1024]", primals_3: "f32[32]", primals_4: "f32[64, 32]", primals_5: "f32[64]", primals_6, primals_7: "f32[4096]", primals_8: "f32[64]", primals_9: "f32[128, 64]", primals_10: "f32[128]", primals_11: "f32[16384]", primals_12: "f32[128]", primals_13: "f32[256, 128]", primals_14: "f32[256]"): # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:47 in foreach_all_gather, code: all_gather_output = torch.empty( empty: "f32[2112]" = torch.ops.aten.empty.memory_format([2112], dtype = torch.float32, device = device(type='cuda', index=1), pin_memory = False) # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:50 in foreach_all_gather, code: all_gather_input = all_gather_output.narrow( slice_1: "f32[1056]" = torch.ops.aten.slice.Tensor(empty, 0, 1056, 2112) # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:53 in foreach_all_gather, code: foreach_copy_dsts = torch.split(all_gather_input, inp_split_sizes) split_with_sizes = torch.ops.aten.split_with_sizes.default(slice_1, [1024, 32]); slice_1 = None getitem: "f32[1024]" = split_with_sizes[0] getitem_1: "f32[32]" = split_with_sizes[1]; split_with_sizes = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs) _foreach_copy = torch.ops.aten._foreach_copy.default([getitem, getitem_1], [primals_2, primals_3]); getitem = getitem_1 = primals_2 = primals_3 = None getitem_2: "f32[1024]" = _foreach_copy[0] getitem_3: "f32[32]" = _foreach_copy[1]; _foreach_copy = None # No stacktrace found for following nodes slice_tensor: "f32[1056]" = torch.ops.aten.slice.Tensor(empty, 0, 1056, 2112) slice_scatter_default: "f32[1056]" = torch.ops.aten.slice_scatter.default(slice_tensor, getitem_2, 0, 0, 1024); slice_tensor = getitem_2 = None slice_scatter_default_1: "f32[2112]" = torch.ops.aten.slice_scatter.default(empty, slice_scatter_default, 0, 1056, 2112); empty = slice_scatter_default = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs) slice_3: "f32[1056]" = torch.ops.aten.slice.Tensor(slice_scatter_default_1, 0, 1056, 2112) # No stacktrace found for following nodes slice_tensor_1: 
"f32[1056]" = torch.ops.aten.slice.Tensor(slice_scatter_default_1, 0, 1056, 2112) slice_scatter_default_2: "f32[1056]" = torch.ops.aten.slice_scatter.default(slice_tensor_1, getitem_3, 0, 1024, 1056); slice_tensor_1 = getitem_3 = None slice_scatter_default_3: "f32[2112]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_1, slice_scatter_default_2, 0, 1056, 2112); slice_scatter_default_1 = slice_scatter_default_2 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor( slice_6: "f32[1056]" = torch.ops.aten.slice.Tensor(slice_scatter_default_3, 0, 1056, 2112); slice_scatter_default_3 = None all_gather_into_tensor: "f32[2112]" = torch.ops._c10d_functional.all_gather_into_tensor.default(slice_6, 2, '0'); slice_6 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined] wait_tensor: "f32[2112]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor); all_gather_into_tensor = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()), view_1: "f32[2, 1056]" = torch.ops.aten.reshape.default(wait_tensor, [2, -1]); wait_tensor = None split_with_sizes_4 = torch.ops.aten.split_with_sizes.default(view_1, [1024, 32], 1); view_1 = None getitem_10: "f32[2, 1024]" = split_with_sizes_4[0] clone: "f32[2, 1024]" = torch.ops.aten.clone.default(getitem_10, memory_format = torch.contiguous_format); getitem_10 = None view_2: "f32[2048]" = torch.ops.aten.reshape.default(clone, [2048]); clone = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided( as_strided: "f32[64, 32]" = torch.ops.aten.as_strided.default(view_2, [64, 32], [32, 1], 0); view_2 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()), getitem_13: "f32[2, 32]" = split_with_sizes_4[1]; split_with_sizes_4 = None clone_1: "f32[2, 32]" = torch.ops.aten.clone.default(getitem_13, memory_format = torch.contiguous_format); getitem_13 = None view_4: "f32[64]" = torch.ops.aten.reshape.default(clone_1, [64]); clone_1 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided( as_strided_1: "f32[64]" = torch.ops.aten.as_strided.default(view_4, [64], [1], 0); view_4 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:138 in foreach_all_gather_copy_out, code: torch._foreach_copy_(out, splits_unpadded) _foreach_copy_1 = torch.ops.aten._foreach_copy.default([primals_4, primals_5], [as_strided, as_strided_1]); primals_4 = primals_5 = as_strided = as_strided_1 = None getitem_14: "f32[64, 32]" = _foreach_copy_1[0] getitem_15: "f32[64]" = _foreach_copy_1[1]; _foreach_copy_1 = None # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) permute_1: "f32[32, 64]" = torch.ops.aten.permute.default(getitem_14, [1, 0]); getitem_14 = None addmm: "f32[8, 64]" = 
torch.ops.aten.addmm.default(getitem_15, primals_1, permute_1); getitem_15 = permute_1 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:47 in foreach_all_gather, code: all_gather_output = torch.empty( empty_1: "f32[8320]" = torch.ops.aten.empty.memory_format([8320], dtype = torch.float32, device = device(type='cuda', index=1), pin_memory = False) # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:50 in foreach_all_gather, code: all_gather_input = all_gather_output.narrow( slice_7: "f32[4160]" = torch.ops.aten.slice.Tensor(empty_1, 0, 4160, 8320) # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:53 in foreach_all_gather, code: foreach_copy_dsts = torch.split(all_gather_input, inp_split_sizes) split_with_sizes_6 = torch.ops.aten.split_with_sizes.default(slice_7, [4096, 64]); slice_7 = None getitem_16: "f32[4096]" = split_with_sizes_6[0] getitem_17: "f32[64]" = split_with_sizes_6[1]; split_with_sizes_6 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs) _foreach_copy_2 = torch.ops.aten._foreach_copy.default([getitem_16, getitem_17], [primals_7, primals_8]); getitem_16 = getitem_17 = primals_7 = primals_8 = None getitem_18: "f32[4096]" = _foreach_copy_2[0] getitem_19: "f32[64]" = _foreach_copy_2[1]; _foreach_copy_2 = None # No stacktrace found for following nodes slice_tensor_2: "f32[4160]" = torch.ops.aten.slice.Tensor(empty_1, 0, 4160, 8320) slice_scatter_default_4: "f32[4160]" = torch.ops.aten.slice_scatter.default(slice_tensor_2, getitem_18, 0, 0, 4096); slice_tensor_2 = getitem_18 = None slice_scatter_default_5: "f32[8320]" = torch.ops.aten.slice_scatter.default(empty_1, slice_scatter_default_4, 0, 4160, 8320); empty_1 = slice_scatter_default_4 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs) slice_9: "f32[4160]" = torch.ops.aten.slice.Tensor(slice_scatter_default_5, 0, 4160, 8320) # No stacktrace found for following nodes slice_tensor_3: "f32[4160]" = torch.ops.aten.slice.Tensor(slice_scatter_default_5, 0, 4160, 8320) slice_scatter_default_6: "f32[4160]" = torch.ops.aten.slice_scatter.default(slice_tensor_3, getitem_19, 0, 4096, 4160); slice_tensor_3 = getitem_19 = None slice_scatter_default_7: "f32[8320]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_5, slice_scatter_default_6, 0, 4160, 8320); slice_scatter_default_5 = slice_scatter_default_6 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor( slice_12: "f32[4160]" = torch.ops.aten.slice.Tensor(slice_scatter_default_7, 0, 4160, 8320); slice_scatter_default_7 = None all_gather_into_tensor_1: "f32[8320]" = torch.ops._c10d_functional.all_gather_into_tensor.default(slice_12, 2, '0'); slice_12 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined] wait_tensor_1: "f32[8320]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_1); all_gather_into_tensor_1 = None # File: 
/data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()), view_6: "f32[2, 4160]" = torch.ops.aten.reshape.default(wait_tensor_1, [2, -1]); wait_tensor_1 = None split_with_sizes_10 = torch.ops.aten.split_with_sizes.default(view_6, [4096, 64], 1); view_6 = None getitem_26: "f32[2, 4096]" = split_with_sizes_10[0] clone_2: "f32[2, 4096]" = torch.ops.aten.clone.default(getitem_26, memory_format = torch.contiguous_format); getitem_26 = None view_7: "f32[8192]" = torch.ops.aten.reshape.default(clone_2, [8192]); clone_2 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided( as_strided_2: "f32[128, 64]" = torch.ops.aten.as_strided.default(view_7, [128, 64], [64, 1], 0); view_7 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()), getitem_29: "f32[2, 64]" = split_with_sizes_10[1]; split_with_sizes_10 = None clone_3: "f32[2, 64]" = torch.ops.aten.clone.default(getitem_29, memory_format = torch.contiguous_format); getitem_29 = None view_9: "f32[128]" = torch.ops.aten.reshape.default(clone_3, [128]); clone_3 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided( as_strided_3: "f32[128]" = torch.ops.aten.as_strided.default(view_9, [128], [1], 0); view_9 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:138 in foreach_all_gather_copy_out, code: torch._foreach_copy_(out, splits_unpadded) _foreach_copy_3 = torch.ops.aten._foreach_copy.default([primals_9, primals_10], [as_strided_2, as_strided_3]); primals_10 = as_strided_2 = as_strided_3 = None getitem_30: "f32[128, 64]" = _foreach_copy_3[0] getitem_31: "f32[128]" = _foreach_copy_3[1]; _foreach_copy_3 = None # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) permute_3: "f32[64, 128]" = torch.ops.aten.permute.default(getitem_30, [1, 0]); getitem_30 = None addmm_1: "f32[8, 128]" = torch.ops.aten.addmm.default(getitem_31, addmm, permute_3); getitem_31 = permute_3 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:47 in foreach_all_gather, code: all_gather_output = torch.empty( empty_2: "f32[33024]" = torch.ops.aten.empty.memory_format([33024], dtype = torch.float32, device = device(type='cuda', index=1), pin_memory = False) # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:50 in foreach_all_gather, code: all_gather_input = all_gather_output.narrow( slice_13: "f32[16512]" = torch.ops.aten.slice.Tensor(empty_2, 0, 16512, 33024) # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:53 in foreach_all_gather, code: foreach_copy_dsts = torch.split(all_gather_input, inp_split_sizes) split_with_sizes_12 = torch.ops.aten.split_with_sizes.default(slice_13, [16384, 128]); slice_13 = None getitem_32: "f32[16384]" = split_with_sizes_12[0] getitem_33: "f32[128]" = split_with_sizes_12[1]; split_with_sizes_12 = None # File: 
/data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs) _foreach_copy_4 = torch.ops.aten._foreach_copy.default([getitem_32, getitem_33], [primals_11, primals_12]); getitem_32 = getitem_33 = primals_11 = primals_12 = None getitem_34: "f32[16384]" = _foreach_copy_4[0] getitem_35: "f32[128]" = _foreach_copy_4[1]; _foreach_copy_4 = None # No stacktrace found for following nodes slice_tensor_4: "f32[16512]" = torch.ops.aten.slice.Tensor(empty_2, 0, 16512, 33024) slice_scatter_default_8: "f32[16512]" = torch.ops.aten.slice_scatter.default(slice_tensor_4, getitem_34, 0, 0, 16384); slice_tensor_4 = getitem_34 = None slice_scatter_default_9: "f32[33024]" = torch.ops.aten.slice_scatter.default(empty_2, slice_scatter_default_8, 0, 16512, 33024); empty_2 = slice_scatter_default_8 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs) slice_15: "f32[16512]" = torch.ops.aten.slice.Tensor(slice_scatter_default_9, 0, 16512, 33024) # No stacktrace found for following nodes slice_tensor_5: "f32[16512]" = torch.ops.aten.slice.Tensor(slice_scatter_default_9, 0, 16512, 33024) slice_scatter_default_10: "f32[16512]" = torch.ops.aten.slice_scatter.default(slice_tensor_5, getitem_35, 0, 16384, 16512); slice_tensor_5 = getitem_35 = None slice_scatter_default_11: "f32[33024]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_9, slice_scatter_default_10, 0, 16512, 33024); slice_scatter_default_9 = slice_scatter_default_10 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor( slice_18: "f32[16512]" = torch.ops.aten.slice.Tensor(slice_scatter_default_11, 0, 16512, 33024); slice_scatter_default_11 = None all_gather_into_tensor_2: "f32[33024]" = torch.ops._c10d_functional.all_gather_into_tensor.default(slice_18, 2, '0'); slice_18 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined] wait_tensor_2: "f32[33024]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_2); all_gather_into_tensor_2 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()), view_11: "f32[2, 16512]" = torch.ops.aten.reshape.default(wait_tensor_2, [2, -1]); wait_tensor_2 = None split_with_sizes_16 = torch.ops.aten.split_with_sizes.default(view_11, [16384, 128], 1); view_11 = None getitem_42: "f32[2, 16384]" = split_with_sizes_16[0] clone_4: "f32[2, 16384]" = torch.ops.aten.clone.default(getitem_42, memory_format = torch.contiguous_format); getitem_42 = None view_12: "f32[32768]" = torch.ops.aten.reshape.default(clone_4, [32768]); clone_4 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided( as_strided_4: "f32[256, 128]" = torch.ops.aten.as_strided.default(view_12, [256, 128], [128, 1], 0); view_12 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in 
foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()), getitem_45: "f32[2, 128]" = split_with_sizes_16[1]; split_with_sizes_16 = None clone_5: "f32[2, 128]" = torch.ops.aten.clone.default(getitem_45, memory_format = torch.contiguous_format); getitem_45 = None view_14: "f32[256]" = torch.ops.aten.reshape.default(clone_5, [256]); clone_5 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided( as_strided_5: "f32[256]" = torch.ops.aten.as_strided.default(view_14, [256], [1], 0); view_14 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:138 in foreach_all_gather_copy_out, code: torch._foreach_copy_(out, splits_unpadded) _foreach_copy_5 = torch.ops.aten._foreach_copy.default([primals_13, primals_14], [as_strided_4, as_strided_5]); primals_14 = as_strided_4 = as_strided_5 = None getitem_46: "f32[256, 128]" = _foreach_copy_5[0] getitem_47: "f32[256]" = _foreach_copy_5[1]; _foreach_copy_5 = None # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) permute_5: "f32[128, 256]" = torch.ops.aten.permute.default(getitem_46, [1, 0]); getitem_46 = None addmm_2: "f32[8, 256]" = torch.ops.aten.addmm.default(getitem_47, addmm_1, permute_5); getitem_47 = permute_5 = None return [addmm_2, primals_1, primals_9, primals_13, addmm, addmm_1]
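(Reading aid, not part of the trace: the graph above is the per-parameter all-gather pattern emitted by foreach_all_gather / foreach_all_gather_copy_out in _fsdp_collectives.py, followed by one addmm per F.linear. A rough eager-mode sketch of that pattern follows; the helper name all_gather_params, the shapes, and the omission of dtype/padding handling are illustrative assumptions, not the actual FSDP2 implementation.)

# Illustrative sketch only; `all_gather_params` is a hypothetical helper and
# dtype/padding handling is omitted. The real code lives in
# torch/distributed/_composable/fsdp/_fsdp_collectives.py.
import torch
import torch.distributed as dist

def all_gather_params(param_shards, unsharded_params, group):
    # `param_shards` are this rank's flattened shards; `unsharded_params` are
    # plain preallocated tensors (no autograd) that receive the gathered values.
    world_size = dist.get_world_size(group)
    rank = dist.get_rank(group)
    shard_numel = sum(p.numel() for p in param_shards)  # e.g. 4176 in the trace
    # Flat buffer covering every rank's shard region (aten.empty above).
    all_gather_output = torch.empty(shard_numel * world_size,
                                    device=param_shards[0].device)
    # This rank's slice of the buffer (the slice at offset rank * shard_numel).
    all_gather_input = all_gather_output.narrow(0, rank * shard_numel, shard_numel)
    # Pack the local shards into that slice (split_with_sizes + _foreach_copy).
    dsts = torch.split(all_gather_input, [p.numel() for p in param_shards])
    torch._foreach_copy_(list(dsts), [p.reshape(-1) for p in param_shards])
    # Collective plus wait (_c10d_functional.all_gather_into_tensor / wait_tensor).
    dist.all_gather_into_tensor(all_gather_output, all_gather_input, group=group)
    # Copy-out: view per rank, split per parameter, and write into the unsharded
    # parameters (reshape / split_with_sizes / clone / as_strided / _foreach_copy).
    per_rank = all_gather_output.view(world_size, -1)
    splits = torch.split(per_rank, [p.numel() for p in param_shards], dim=1)
    gathered = [s.contiguous().view(-1)[: out.numel()].view(out.shape)
                for s, out in zip(splits, unsharded_params)]
    torch._foreach_copy_(unsharded_params, gathered)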
-
yf225 created this gist
Mar 27, 2024
@@ -0,0 +1,330 @@ TRACED GRAPH ===== AFTER POST GRAD ===== /data/users/willfeng/pytorch_yf225/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module): def forward(self, primals_1: "f32[8, 32]", primals_2: "f32[2048]", primals_3: "f32[64]", primals_4: "f32[2048]", primals_5: "f32[16]", primals_6: "f32[128, 32]", primals_7: "f32[128]", primals_8: "f32[32, 128]", primals_9: "f32[32]", primals_10, primals_11: "f32[2048]", primals_12: "f32[64]", primals_13: "f32[2048]", primals_14: "f32[16]", primals_15: "f32[128, 32]", primals_16: "f32[128]", primals_17: "f32[32, 128]", primals_18: "f32[32]", primals_19: "f32[2048]", primals_20: "f32[64]", primals_21: "f32[2048]", primals_22: "f32[16]", primals_23: "f32[128, 32]", primals_24: "f32[128]", primals_25: "f32[32, 128]", primals_26: "f32[32]"): # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:47 in foreach_all_gather, code: all_gather_output = torch.empty( empty: "f32[8352]" = torch.ops.aten.empty.memory_format([8352], dtype = torch.float32, device = device(type='cuda', index=1), pin_memory = False) # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:50 in foreach_all_gather, code: all_gather_input = all_gather_output.narrow( slice_1: "f32[4176]" = torch.ops.aten.slice.Tensor(empty, 0, 4176, 8352) # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:53 in foreach_all_gather, code: foreach_copy_dsts = torch.split(all_gather_input, inp_split_sizes) split_with_sizes = torch.ops.aten.split_with_sizes.default(slice_1, [2048, 64, 2048, 16]); slice_1 = None getitem: "f32[2048]" = split_with_sizes[0] getitem_1: "f32[64]" = split_with_sizes[1] getitem_2: "f32[2048]" = split_with_sizes[2] getitem_3: "f32[16]" = split_with_sizes[3]; split_with_sizes = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs) _foreach_copy = torch.ops.aten._foreach_copy.default([getitem, getitem_1, getitem_2, getitem_3], [primals_2, primals_3, primals_4, primals_5]); primals_2 = primals_3 = primals_4 = primals_5 = None getitem_4: "f32[2048]" = _foreach_copy[0] getitem_5: "f32[64]" = _foreach_copy[1] getitem_6: "f32[2048]" = _foreach_copy[2] getitem_7: "f32[16]" = _foreach_copy[3]; _foreach_copy = None # No stacktrace found for following nodes slice_tensor: "f32[4176]" = torch.ops.aten.slice.Tensor(empty, 0, 4176, 8352) slice_scatter_default: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor, getitem_4, 0, 0, 2048); slice_tensor = getitem_4 = None slice_scatter_default_1: "f32[8352]" = torch.ops.aten.slice_scatter.default(empty, slice_scatter_default, 0, 4176, 8352); slice_scatter_default = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs) slice_3: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_1, 0, 4176, 8352) # No stacktrace found for following nodes slice_tensor_1: "f32[4176]" = 
torch.ops.aten.slice.Tensor(slice_scatter_default_1, 0, 4176, 8352) slice_scatter_default_2: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_1, getitem_5, 0, 2048, 2112); slice_tensor_1 = getitem_5 = None slice_scatter_default_3: "f32[8352]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_1, slice_scatter_default_2, 0, 4176, 8352); slice_scatter_default_1 = slice_scatter_default_2 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs) slice_4: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_3, 0, 4176, 8352) # No stacktrace found for following nodes slice_tensor_2: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_3, 0, 4176, 8352) slice_scatter_default_4: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_2, getitem_6, 0, 2112, 4160); slice_tensor_2 = getitem_6 = None slice_scatter_default_5: "f32[8352]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_3, slice_scatter_default_4, 0, 4176, 8352); slice_scatter_default_3 = slice_scatter_default_4 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs) slice_5: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_5, 0, 4176, 8352) # No stacktrace found for following nodes slice_tensor_3: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_5, 0, 4176, 8352) slice_scatter_default_6: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_3, getitem_7, 0, 4160, 4176); slice_tensor_3 = getitem_7 = None slice_scatter_default_7: "f32[8352]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_5, slice_scatter_default_6, 0, 4176, 8352); slice_scatter_default_5 = slice_scatter_default_6 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor( slice_10: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_7, 0, 4176, 8352); slice_scatter_default_7 = None all_gather_into_tensor: "f32[8352]" = torch.ops._c10d_functional.all_gather_into_tensor.default(slice_10, 2, '0'); slice_10 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined] wait_tensor: "f32[8352]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor); all_gather_into_tensor = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()), view_1: "f32[2, 4176]" = torch.ops.aten.reshape.default(wait_tensor, [2, -1]); wait_tensor = None split_with_sizes_6 = torch.ops.aten.split_with_sizes.default(view_1, [2048, 64, 2048, 16], 1); view_1 = None getitem_28: "f32[2, 2048]" = split_with_sizes_6[0] clone: "f32[2, 2048]" = torch.ops.aten.clone.default(getitem_28, memory_format = torch.contiguous_format); getitem_28 = None view_2: "f32[4096]" = torch.ops.aten.reshape.default(clone, [4096]); clone = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided( 
as_strided: "f32[128, 32]" = torch.ops.aten.as_strided.default(view_2, [128, 32], [32, 1], 0); view_2 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()), getitem_33: "f32[2, 64]" = split_with_sizes_6[1] clone_1: "f32[2, 64]" = torch.ops.aten.clone.default(getitem_33, memory_format = torch.contiguous_format); getitem_33 = None view_4: "f32[128]" = torch.ops.aten.reshape.default(clone_1, [128]); clone_1 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided( as_strided_1: "f32[128]" = torch.ops.aten.as_strided.default(view_4, [128], [1], 0); view_4 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()), getitem_38: "f32[2, 2048]" = split_with_sizes_6[2] clone_2: "f32[2, 2048]" = torch.ops.aten.clone.default(getitem_38, memory_format = torch.contiguous_format); getitem_38 = None view_6: "f32[4096]" = torch.ops.aten.reshape.default(clone_2, [4096]); clone_2 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided( as_strided_2: "f32[32, 128]" = torch.ops.aten.as_strided.default(view_6, [32, 128], [128, 1], 0); view_6 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()), getitem_43: "f32[2, 16]" = split_with_sizes_6[3]; split_with_sizes_6 = None clone_3: "f32[2, 16]" = torch.ops.aten.clone.default(getitem_43, memory_format = torch.contiguous_format); getitem_43 = None view_8: "f32[32]" = torch.ops.aten.reshape.default(clone_3, [32]); clone_3 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided( as_strided_3: "f32[32]" = torch.ops.aten.as_strided.default(view_8, [32], [1], 0); view_8 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:138 in foreach_all_gather_copy_out, code: torch._foreach_copy_(out, splits_unpadded) _foreach_copy_1 = torch.ops.aten._foreach_copy.default([primals_6, primals_7, primals_8, primals_9], [as_strided, as_strided_1, as_strided_2, as_strided_3]); primals_6 = primals_7 = primals_9 = as_strided = as_strided_1 = as_strided_2 = as_strided_3 = None getitem_44: "f32[128, 32]" = _foreach_copy_1[0] getitem_45: "f32[128]" = _foreach_copy_1[1] getitem_46: "f32[32, 128]" = _foreach_copy_1[2] getitem_47: "f32[32]" = _foreach_copy_1[3]; _foreach_copy_1 = None # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) permute_1: "f32[32, 128]" = torch.ops.aten.permute.default(getitem_44, [1, 0]); getitem_44 = None # No stacktrace found for following nodes mm_default_5: "f32[8, 128]" = torch.ops.aten.mm.default(primals_1, permute_1); permute_1 = None add_tensor_5: "f32[8, 128]" = torch.ops.aten.add.Tensor(mm_default_5, getitem_45); mm_default_5 = getitem_45 = None # File: /data/users/willfeng/pytorch_yf225/torch/testing/_internal/common_fsdp.py:855 in forward, code: z = F.relu(z) relu: "f32[8, 128]" = 
torch.ops.aten.relu.default(add_tensor_5); add_tensor_5 = None # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) permute_3: "f32[128, 32]" = torch.ops.aten.permute.default(getitem_46, [1, 0]); getitem_46 = None # No stacktrace found for following nodes mm_default_4: "f32[8, 32]" = torch.ops.aten.mm.default(relu, permute_3); permute_3 = None add_tensor_4: "f32[8, 32]" = torch.ops.aten.add.Tensor(mm_default_4, getitem_47); mm_default_4 = getitem_47 = None # File: /data/users/willfeng/pytorch_yf225/torch/testing/_internal/common_fsdp.py:857 in forward, code: z = F.relu(z) relu_1: "f32[8, 32]" = torch.ops.aten.relu.default(add_tensor_4); add_tensor_4 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs) _foreach_copy_2 = torch.ops.aten._foreach_copy.default([getitem, getitem_1, getitem_2, getitem_3], [primals_11, primals_12, primals_13, primals_14]); primals_11 = primals_12 = primals_13 = primals_14 = None getitem_52: "f32[2048]" = _foreach_copy_2[0] getitem_53: "f32[64]" = _foreach_copy_2[1] getitem_54: "f32[2048]" = _foreach_copy_2[2] getitem_55: "f32[16]" = _foreach_copy_2[3]; _foreach_copy_2 = None # No stacktrace found for following nodes slice_tensor_4: "f32[4176]" = torch.ops.aten.slice.Tensor(empty, 0, 4176, 8352) slice_scatter_default_8: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_4, getitem_52, 0, 0, 2048); slice_tensor_4 = getitem_52 = None slice_scatter_default_9: "f32[8352]" = torch.ops.aten.slice_scatter.default(empty, slice_scatter_default_8, 0, 4176, 8352); slice_scatter_default_8 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs) slice_13: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_9, 0, 4176, 8352) # No stacktrace found for following nodes slice_tensor_5: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_9, 0, 4176, 8352) slice_scatter_default_10: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_5, getitem_53, 0, 2048, 2112); slice_tensor_5 = getitem_53 = None slice_scatter_default_11: "f32[8352]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_9, slice_scatter_default_10, 0, 4176, 8352); slice_scatter_default_9 = slice_scatter_default_10 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs) slice_14: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_11, 0, 4176, 8352) # No stacktrace found for following nodes slice_tensor_6: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_11, 0, 4176, 8352) slice_scatter_default_12: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_6, getitem_54, 0, 2112, 4160); slice_tensor_6 = getitem_54 = None slice_scatter_default_13: "f32[8352]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_11, slice_scatter_default_12, 0, 4176, 8352); slice_scatter_default_11 = slice_scatter_default_12 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs) 
slice_15: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_13, 0, 4176, 8352) # No stacktrace found for following nodes slice_tensor_7: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_13, 0, 4176, 8352) slice_scatter_default_14: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_7, getitem_55, 0, 4160, 4176); slice_tensor_7 = getitem_55 = None slice_scatter_default_15: "f32[8352]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_13, slice_scatter_default_14, 0, 4176, 8352); slice_scatter_default_13 = slice_scatter_default_14 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor( slice_20: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_15, 0, 4176, 8352); slice_scatter_default_15 = None all_gather_into_tensor_1: "f32[8352]" = torch.ops._c10d_functional.all_gather_into_tensor.default(slice_20, 2, '0'); slice_20 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined] wait_tensor_1: "f32[8352]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_1); all_gather_into_tensor_1 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()), view_10: "f32[2, 4176]" = torch.ops.aten.reshape.default(wait_tensor_1, [2, -1]); wait_tensor_1 = None split_with_sizes_16 = torch.ops.aten.split_with_sizes.default(view_10, [2048, 64, 2048, 16], 1); view_10 = None getitem_76: "f32[2, 2048]" = split_with_sizes_16[0] clone_4: "f32[2, 2048]" = torch.ops.aten.clone.default(getitem_76, memory_format = torch.contiguous_format); getitem_76 = None view_11: "f32[4096]" = torch.ops.aten.reshape.default(clone_4, [4096]); clone_4 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided( as_strided_4: "f32[128, 32]" = torch.ops.aten.as_strided.default(view_11, [128, 32], [32, 1], 0); view_11 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()), getitem_81: "f32[2, 64]" = split_with_sizes_16[1] clone_5: "f32[2, 64]" = torch.ops.aten.clone.default(getitem_81, memory_format = torch.contiguous_format); getitem_81 = None view_13: "f32[128]" = torch.ops.aten.reshape.default(clone_5, [128]); clone_5 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided( as_strided_5: "f32[128]" = torch.ops.aten.as_strided.default(view_13, [128], [1], 0); view_13 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()), getitem_86: "f32[2, 2048]" = split_with_sizes_16[2] clone_6: "f32[2, 2048]" = torch.ops.aten.clone.default(getitem_86, memory_format = torch.contiguous_format); getitem_86 = None view_15: "f32[4096]" = torch.ops.aten.reshape.default(clone_6, [4096]); clone_6 = None # File: 
/data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided( as_strided_6: "f32[32, 128]" = torch.ops.aten.as_strided.default(view_15, [32, 128], [128, 1], 0); view_15 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()), getitem_91: "f32[2, 16]" = split_with_sizes_16[3]; split_with_sizes_16 = None clone_7: "f32[2, 16]" = torch.ops.aten.clone.default(getitem_91, memory_format = torch.contiguous_format); getitem_91 = None view_17: "f32[32]" = torch.ops.aten.reshape.default(clone_7, [32]); clone_7 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided( as_strided_7: "f32[32]" = torch.ops.aten.as_strided.default(view_17, [32], [1], 0); view_17 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:138 in foreach_all_gather_copy_out, code: torch._foreach_copy_(out, splits_unpadded) _foreach_copy_3 = torch.ops.aten._foreach_copy.default([primals_15, primals_16, primals_17, primals_18], [as_strided_4, as_strided_5, as_strided_6, as_strided_7]); primals_16 = primals_18 = as_strided_4 = as_strided_5 = as_strided_6 = as_strided_7 = None getitem_92: "f32[128, 32]" = _foreach_copy_3[0] getitem_93: "f32[128]" = _foreach_copy_3[1] getitem_94: "f32[32, 128]" = _foreach_copy_3[2] getitem_95: "f32[32]" = _foreach_copy_3[3]; _foreach_copy_3 = None # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) permute_5: "f32[32, 128]" = torch.ops.aten.permute.default(getitem_92, [1, 0]); getitem_92 = None # No stacktrace found for following nodes mm_default_3: "f32[8, 128]" = torch.ops.aten.mm.default(relu_1, permute_5); permute_5 = None add_tensor_3: "f32[8, 128]" = torch.ops.aten.add.Tensor(mm_default_3, getitem_93); mm_default_3 = getitem_93 = None # File: /data/users/willfeng/pytorch_yf225/torch/testing/_internal/common_fsdp.py:855 in forward, code: z = F.relu(z) relu_2: "f32[8, 128]" = torch.ops.aten.relu.default(add_tensor_3); add_tensor_3 = None # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) permute_7: "f32[128, 32]" = torch.ops.aten.permute.default(getitem_94, [1, 0]); getitem_94 = None # No stacktrace found for following nodes mm_default_2: "f32[8, 32]" = torch.ops.aten.mm.default(relu_2, permute_7); permute_7 = None add_tensor_2: "f32[8, 32]" = torch.ops.aten.add.Tensor(mm_default_2, getitem_95); mm_default_2 = getitem_95 = None # File: /data/users/willfeng/pytorch_yf225/torch/testing/_internal/common_fsdp.py:857 in forward, code: z = F.relu(z) relu_3: "f32[8, 32]" = torch.ops.aten.relu.default(add_tensor_2); add_tensor_2 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs) _foreach_copy_4 = torch.ops.aten._foreach_copy.default([getitem, getitem_1, getitem_2, getitem_3], [primals_19, primals_20, primals_21, primals_22]); getitem = getitem_1 = getitem_2 = getitem_3 = primals_19 = primals_20 = primals_21 = primals_22 = None getitem_100: "f32[2048]" = _foreach_copy_4[0] getitem_101: "f32[64]" = 
_foreach_copy_4[1] getitem_102: "f32[2048]" = _foreach_copy_4[2] getitem_103: "f32[16]" = _foreach_copy_4[3]; _foreach_copy_4 = None # No stacktrace found for following nodes slice_tensor_8: "f32[4176]" = torch.ops.aten.slice.Tensor(empty, 0, 4176, 8352) slice_scatter_default_16: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_8, getitem_100, 0, 0, 2048); slice_tensor_8 = getitem_100 = None slice_scatter_default_17: "f32[8352]" = torch.ops.aten.slice_scatter.default(empty, slice_scatter_default_16, 0, 4176, 8352); empty = slice_scatter_default_16 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs) slice_23: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_17, 0, 4176, 8352) # No stacktrace found for following nodes slice_tensor_9: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_17, 0, 4176, 8352) slice_scatter_default_18: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_9, getitem_101, 0, 2048, 2112); slice_tensor_9 = getitem_101 = None slice_scatter_default_19: "f32[8352]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_17, slice_scatter_default_18, 0, 4176, 8352); slice_scatter_default_17 = slice_scatter_default_18 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs) slice_24: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_19, 0, 4176, 8352) # No stacktrace found for following nodes slice_tensor_10: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_19, 0, 4176, 8352) slice_scatter_default_20: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_10, getitem_102, 0, 2112, 4160); slice_tensor_10 = getitem_102 = None slice_scatter_default_21: "f32[8352]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_19, slice_scatter_default_20, 0, 4176, 8352); slice_scatter_default_19 = slice_scatter_default_20 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs) slice_25: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_21, 0, 4176, 8352) # No stacktrace found for following nodes slice_tensor_11: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_21, 0, 4176, 8352) slice_scatter_default_22: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_11, getitem_103, 0, 4160, 4176); slice_tensor_11 = getitem_103 = None slice_scatter_default_23: "f32[8352]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_21, slice_scatter_default_22, 0, 4176, 8352); slice_scatter_default_21 = slice_scatter_default_22 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor( slice_30: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_23, 0, 4176, 8352); slice_scatter_default_23 = None all_gather_into_tensor_2: "f32[8352]" = torch.ops._c10d_functional.all_gather_into_tensor.default(slice_30, 2, '0'); slice_30 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return 
torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined] wait_tensor_2: "f32[8352]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_2); all_gather_into_tensor_2 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()), view_19: "f32[2, 4176]" = torch.ops.aten.reshape.default(wait_tensor_2, [2, -1]); wait_tensor_2 = None split_with_sizes_26 = torch.ops.aten.split_with_sizes.default(view_19, [2048, 64, 2048, 16], 1); view_19 = None getitem_124: "f32[2, 2048]" = split_with_sizes_26[0] clone_8: "f32[2, 2048]" = torch.ops.aten.clone.default(getitem_124, memory_format = torch.contiguous_format); getitem_124 = None view_20: "f32[4096]" = torch.ops.aten.reshape.default(clone_8, [4096]); clone_8 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided( as_strided_8: "f32[128, 32]" = torch.ops.aten.as_strided.default(view_20, [128, 32], [32, 1], 0); view_20 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()), getitem_129: "f32[2, 64]" = split_with_sizes_26[1] clone_9: "f32[2, 64]" = torch.ops.aten.clone.default(getitem_129, memory_format = torch.contiguous_format); getitem_129 = None view_22: "f32[128]" = torch.ops.aten.reshape.default(clone_9, [128]); clone_9 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided( as_strided_9: "f32[128]" = torch.ops.aten.as_strided.default(view_22, [128], [1], 0); view_22 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()), getitem_134: "f32[2, 2048]" = split_with_sizes_26[2] clone_10: "f32[2, 2048]" = torch.ops.aten.clone.default(getitem_134, memory_format = torch.contiguous_format); getitem_134 = None view_24: "f32[4096]" = torch.ops.aten.reshape.default(clone_10, [4096]); clone_10 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided( as_strided_10: "f32[32, 128]" = torch.ops.aten.as_strided.default(view_24, [32, 128], [128, 1], 0); view_24 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()), getitem_139: "f32[2, 16]" = split_with_sizes_26[3]; split_with_sizes_26 = None clone_11: "f32[2, 16]" = torch.ops.aten.clone.default(getitem_139, memory_format = torch.contiguous_format); getitem_139 = None view_26: "f32[32]" = torch.ops.aten.reshape.default(clone_11, [32]); clone_11 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided( as_strided_11: "f32[32]" = torch.ops.aten.as_strided.default(view_26, [32], [1], 0); view_26 = None # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:138 in foreach_all_gather_copy_out, code: torch._foreach_copy_(out, splits_unpadded) _foreach_copy_5 = 
torch.ops.aten._foreach_copy.default([primals_23, primals_24, primals_25, primals_26], [as_strided_8, as_strided_9, as_strided_10, as_strided_11]); primals_24 = primals_26 = as_strided_8 = as_strided_9 = as_strided_10 = as_strided_11 = None getitem_140: "f32[128, 32]" = _foreach_copy_5[0] getitem_141: "f32[128]" = _foreach_copy_5[1] getitem_142: "f32[32, 128]" = _foreach_copy_5[2] getitem_143: "f32[32]" = _foreach_copy_5[3]; _foreach_copy_5 = None # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) permute_9: "f32[32, 128]" = torch.ops.aten.permute.default(getitem_140, [1, 0]); getitem_140 = None # No stacktrace found for following nodes mm_default_1: "f32[8, 128]" = torch.ops.aten.mm.default(relu_3, permute_9); permute_9 = None add_tensor_1: "f32[8, 128]" = torch.ops.aten.add.Tensor(mm_default_1, getitem_141); mm_default_1 = getitem_141 = None # File: /data/users/willfeng/pytorch_yf225/torch/testing/_internal/common_fsdp.py:855 in forward, code: z = F.relu(z) relu_4: "f32[8, 128]" = torch.ops.aten.relu.default(add_tensor_1); add_tensor_1 = None # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias) permute_11: "f32[128, 32]" = torch.ops.aten.permute.default(getitem_142, [1, 0]); getitem_142 = None # No stacktrace found for following nodes mm_default: "f32[8, 32]" = torch.ops.aten.mm.default(relu_4, permute_11); permute_11 = None add_tensor: "f32[8, 32]" = torch.ops.aten.add.Tensor(mm_default, getitem_143); mm_default = getitem_143 = None # File: /data/users/willfeng/pytorch_yf225/torch/testing/_internal/common_fsdp.py:857 in forward, code: z = F.relu(z) relu_5: "f32[8, 32]" = torch.ops.aten.relu.default(add_tensor); add_tensor = None le: "b8[8, 32]" = torch.ops.aten.le.Scalar(relu_5, 0) return [relu_5, primals_1, primals_8, primals_15, primals_17, primals_23, primals_25, relu, relu_1, relu_2, relu_3, relu_4, le]
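(Reading aid, not part of the trace: in both graphs the compute between copy-outs is the stacked Linear + ReLU forward from torch/testing/_internal/common_fsdp.py, traced as mm/addmm + add + relu nodes, and the trailing aten.le.Scalar(relu_5, 0) appears to be the (z <= 0) mask kept for the last ReLU's backward. Below is a minimal sketch of one such block, with a hypothetical module name and sizes read off the [8, 32] -> [8, 128] -> [8, 32] activations above.)

import torch
import torch.nn as nn
import torch.nn.functional as F

class MLPBlock(nn.Module):
    # Hypothetical stand-in for the test module; only the Linear/ReLU pattern
    # matches the trace, not the real class in common_fsdp.py.
    def __init__(self, dim: int = 32, hidden: int = 128):
        super().__init__()
        self.lin1 = nn.Linear(dim, hidden)
        self.lin2 = nn.Linear(hidden, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        z = F.relu(self.lin1(x))  # traced as mm + add + relu
        z = F.relu(self.lin2(z))  # relu outputs are saved for the backward
        return z

# e.g. MLPBlock()(torch.randn(8, 32)) reproduces the activation shapes above.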