Revisions

  1. yf225 revised this gist Mar 27, 2024. 1 changed file with 275 additions and 334 deletions.
    609 changes: 275 additions & 334 deletions ppfsdp_multi_group_fwd_graph_no_fsdp_fx_passes.txt
    @@ -1,389 +1,330 @@
    TRACED GRAPH
    ===== AFTER POST GRAD =====
    /data/users/willfeng/pytorch_yf225/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[8, 32]", primals_2: "f32[1024]", primals_3: "f32[32]", primals_4: "f32[64, 32]", primals_5: "f32[64]", primals_6, primals_7: "f32[4096]", primals_8: "f32[64]", primals_9: "f32[128, 64]", primals_10: "f32[128]", primals_11: "f32[16384]", primals_12: "f32[128]", primals_13: "f32[256, 128]", primals_14: "f32[256]"):
    def forward(self, primals_1: "f32[8, 32]", primals_2: "f32[2048]", primals_3: "f32[64]", primals_4: "f32[2048]", primals_5: "f32[16]", primals_6: "f32[128, 32]", primals_7: "f32[128]", primals_8: "f32[32, 128]", primals_9: "f32[32]", primals_10, primals_11: "f32[2048]", primals_12: "f32[64]", primals_13: "f32[2048]", primals_14: "f32[16]", primals_15: "f32[128, 32]", primals_16: "f32[128]", primals_17: "f32[32, 128]", primals_18: "f32[32]", primals_19: "f32[2048]", primals_20: "f32[64]", primals_21: "f32[2048]", primals_22: "f32[16]", primals_23: "f32[128, 32]", primals_24: "f32[128]", primals_25: "f32[32, 128]", primals_26: "f32[32]"):
    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:47 in foreach_all_gather, code: all_gather_output = torch.empty(
    empty: "f32[2112]" = torch.ops.aten.empty.memory_format([2112], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False)
    empty: "f32[8352]" = torch.ops.aten.empty.memory_format([8352], dtype = torch.float32, device = device(type='cuda', index=1), pin_memory = False)

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:50 in foreach_all_gather, code: all_gather_input = all_gather_output.narrow(
    slice_1: "f32[1056]" = torch.ops.aten.slice.Tensor(empty, 0, 0, 1056)
    slice_1: "f32[4176]" = torch.ops.aten.slice.Tensor(empty, 0, 4176, 8352)

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:53 in foreach_all_gather, code: foreach_copy_dsts = torch.split(all_gather_input, inp_split_sizes)
    split_with_sizes = torch.ops.aten.split_with_sizes.default(slice_1, [1024, 32]); slice_1 = None
    getitem: "f32[1024]" = split_with_sizes[0]
    getitem_1: "f32[32]" = split_with_sizes[1]; split_with_sizes = None

    split_with_sizes = torch.ops.aten.split_with_sizes.default(slice_1, [2048, 64, 2048, 16]); slice_1 = None
    getitem: "f32[2048]" = split_with_sizes[0]
    getitem_1: "f32[64]" = split_with_sizes[1]
    getitem_2: "f32[2048]" = split_with_sizes[2]
    getitem_3: "f32[16]" = split_with_sizes[3]; split_with_sizes = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
    _foreach_copy = torch.ops.aten._foreach_copy.default([getitem, getitem_1], [primals_2, primals_3]); getitem = getitem_1 = primals_2 = primals_3 = None
    getitem_2: "f32[1024]" = _foreach_copy[0]
    getitem_3: "f32[32]" = _foreach_copy[1]; _foreach_copy = None

    _foreach_copy = torch.ops.aten._foreach_copy.default([getitem, getitem_1, getitem_2, getitem_3], [primals_2, primals_3, primals_4, primals_5]); primals_2 = primals_3 = primals_4 = primals_5 = None
    getitem_4: "f32[2048]" = _foreach_copy[0]
    getitem_5: "f32[64]" = _foreach_copy[1]
    getitem_6: "f32[2048]" = _foreach_copy[2]
    getitem_7: "f32[16]" = _foreach_copy[3]; _foreach_copy = None

    # No stacktrace found for following nodes
    slice_tensor: "f32[1056]" = torch.ops.aten.slice.Tensor(empty, 0, 0, 1056)
    slice_scatter_default: "f32[1056]" = torch.ops.aten.slice_scatter.default(slice_tensor, getitem_2, 0, 0, 1024); slice_tensor = getitem_2 = None
    slice_scatter_default_1: "f32[2112]" = torch.ops.aten.slice_scatter.default(empty, slice_scatter_default, 0, 0, 1056); empty = slice_scatter_default = None
    slice_tensor: "f32[4176]" = torch.ops.aten.slice.Tensor(empty, 0, 4176, 8352)
    slice_scatter_default: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor, getitem_4, 0, 0, 2048); slice_tensor = getitem_4 = None
    slice_scatter_default_1: "f32[8352]" = torch.ops.aten.slice_scatter.default(empty, slice_scatter_default, 0, 4176, 8352); slice_scatter_default = None
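    The slice / slice_scatter pairs above are the functionalized form of an in-place copy into a narrowed view of the flat all-gather buffer. A minimal sketch of the equivalence, with illustrative sizes (the real buffers here are f32[2112] / f32[8352]):

    import torch

    buffer = torch.zeros(8)    # stands in for the flat all_gather_output
    shard = torch.arange(4.)   # stands in for one parameter's flat shard

    # Eager, in-place form: buffer.narrow(0, 0, 6)[:4].copy_(shard)
    # Functionalized form, as it appears in the post-grad graph:
    region = torch.ops.aten.slice.Tensor(buffer, 0, 0, 6)
    region_updated = torch.ops.aten.slice_scatter.default(region, shard, 0, 0, 4)
    buffer_updated = torch.ops.aten.slice_scatter.default(buffer, region_updated, 0, 0, 6)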

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
    slice_3: "f32[1056]" = torch.ops.aten.slice.Tensor(slice_scatter_default_1, 0, 0, 1056)
    slice_3: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_1, 0, 4176, 8352)

    # No stacktrace found for following nodes
    slice_tensor_1: "f32[1056]" = torch.ops.aten.slice.Tensor(slice_scatter_default_1, 0, 0, 1056)
    slice_scatter_default_2: "f32[1056]" = torch.ops.aten.slice_scatter.default(slice_tensor_1, getitem_3, 0, 1024, 1056); slice_tensor_1 = getitem_3 = None
    slice_scatter_default_3: "f32[2112]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_1, slice_scatter_default_2, 0, 0, 1056); slice_scatter_default_1 = slice_scatter_default_2 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor(
    slice_6: "f32[1056]" = torch.ops.aten.slice.Tensor(slice_scatter_default_3, 0, 0, 1056); slice_scatter_default_3 = None
    all_gather_into_tensor: "f32[2112]" = torch.ops._c10d_functional.all_gather_into_tensor.default(slice_6, 2, '0'); slice_6 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined]
    wait_tensor: "f32[2112]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor); all_gather_into_tensor = None
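    Up to this point the graph traces the staging side of foreach_all_gather: allocate the flat world-size output buffer, narrow out this rank's input region, split it per parameter, copy the flat shards in, then issue the functional all-gather and wait on it. A minimal eager sketch of that pattern, assuming an initialized process group; the sizes and the helper name staged_all_gather are illustrative, not taken from the FSDP source:

    import torch
    import torch.distributed as dist
    import torch.distributed._functional_collectives as funcol

    def staged_all_gather(param_shards, group):
        world_size = dist.get_world_size(group)
        rank = dist.get_rank(group)
        inp_split_sizes = [s.numel() for s in param_shards]
        shard_numel = sum(inp_split_sizes)

        # Flat buffer that will hold world_size copies of the shard region.
        all_gather_output = torch.empty(world_size * shard_numel,
                                        dtype=param_shards[0].dtype,
                                        device=param_shards[0].device)
        # This rank's input region (the trace narrows at rank * shard_numel,
        # e.g. offset 4176 for rank 1 in the f32[8352] buffer).
        all_gather_input = all_gather_output.narrow(0, rank * shard_numel, shard_numel)
        # One destination view per parameter, then one fused copy for all shards.
        foreach_copy_dsts = list(torch.split(all_gather_input, inp_split_sizes))
        torch._foreach_copy_(foreach_copy_dsts, [s.reshape(-1) for s in param_shards])

        # Functional collective plus explicit wait, as in _functional_collectives.
        gathered = funcol.all_gather_tensor(all_gather_input, gather_dim=0, group=group)
        return funcol.wait_tensor(gathered)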

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
    view_1: "f32[2, 1056]" = torch.ops.aten.reshape.default(wait_tensor, [2, -1]); wait_tensor = None
    split_with_sizes_4 = torch.ops.aten.split_with_sizes.default(view_1, [1024, 32], 1); view_1 = None
    getitem_10: "f32[2, 1024]" = split_with_sizes_4[0]
    clone: "f32[2, 1024]" = torch.ops.aten.clone.default(getitem_10, memory_format = torch.contiguous_format); getitem_10 = None
    view_2: "f32[2048]" = torch.ops.aten.reshape.default(clone, [2048]); clone = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
    as_strided: "f32[64, 32]" = torch.ops.aten.as_strided.default(view_2, [64, 32], [32, 1], 0); view_2 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
    getitem_13: "f32[2, 32]" = split_with_sizes_4[1]; split_with_sizes_4 = None
    clone_1: "f32[2, 32]" = torch.ops.aten.clone.default(getitem_13, memory_format = torch.contiguous_format); getitem_13 = None
    view_4: "f32[64]" = torch.ops.aten.reshape.default(clone_1, [64]); clone_1 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
    as_strided_1: "f32[64]" = torch.ops.aten.as_strided.default(view_4, [64], [1], 0); view_4 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:138 in foreach_all_gather_copy_out, code: torch._foreach_copy_(out, splits_unpadded)
    _foreach_copy_1 = torch.ops.aten._foreach_copy.default([primals_4, primals_5], [as_strided, as_strided_1]); primals_4 = primals_5 = as_strided = as_strided_1 = None
    getitem_14: "f32[64, 32]" = _foreach_copy_1[0]
    getitem_15: "f32[64]" = _foreach_copy_1[1]; _foreach_copy_1 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
    permute_1: "f32[32, 64]" = torch.ops.aten.permute.default(getitem_14, [1, 0]); getitem_14 = None
    addmm: "f32[8, 64]" = torch.ops.aten.addmm.default(getitem_15, primals_1, permute_1); getitem_15 = permute_1 = None
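    The copy-out nodes above (reshape the waited buffer to [world_size, -1], split per parameter, make each split contiguous, as_strided it back to the unsharded parameter layout, then one fused copy) correspond to foreach_all_gather_copy_out. A rough eager sketch with illustrative names, not the actual FSDP implementation:

    import torch

    def copy_out(all_gather_output, inp_split_sizes, unsharded_params, world_size):
        # [world_size * shard_numel] -> [world_size, shard_numel]
        per_rank = all_gather_output.reshape(world_size, -1)
        # One column block per parameter, matching the all-gather input split sizes.
        splits = torch.split(per_rank, inp_split_sizes, dim=1)
        splits_unpadded = []
        for split, param in zip(splits, unsharded_params):
            flat = split.contiguous().view(split.numel())
            # Reinterpret the flat gathered data with the unsharded parameter's layout.
            splits_unpadded.append(torch.as_strided(flat, param.shape, param.stride(), 0))
        # Fused copy into the unsharded parameter storage.
        torch._foreach_copy_(list(unsharded_params), splits_unpadded)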

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:47 in foreach_all_gather, code: all_gather_output = torch.empty(
    empty_1: "f32[8320]" = torch.ops.aten.empty.memory_format([8320], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False)

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:50 in foreach_all_gather, code: all_gather_input = all_gather_output.narrow(
    slice_7: "f32[4160]" = torch.ops.aten.slice.Tensor(empty_1, 0, 0, 4160)

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:53 in foreach_all_gather, code: foreach_copy_dsts = torch.split(all_gather_input, inp_split_sizes)
    split_with_sizes_6 = torch.ops.aten.split_with_sizes.default(slice_7, [4096, 64]); slice_7 = None
    getitem_16: "f32[4096]" = split_with_sizes_6[0]
    getitem_17: "f32[64]" = split_with_sizes_6[1]; split_with_sizes_6 = None

    slice_tensor_1: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_1, 0, 4176, 8352)
    slice_scatter_default_2: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_1, getitem_5, 0, 2048, 2112); slice_tensor_1 = getitem_5 = None
    slice_scatter_default_3: "f32[8352]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_1, slice_scatter_default_2, 0, 4176, 8352); slice_scatter_default_1 = slice_scatter_default_2 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
    _foreach_copy_2 = torch.ops.aten._foreach_copy.default([getitem_16, getitem_17], [primals_7, primals_8]); getitem_16 = getitem_17 = primals_7 = primals_8 = None
    getitem_18: "f32[4096]" = _foreach_copy_2[0]
    getitem_19: "f32[64]" = _foreach_copy_2[1]; _foreach_copy_2 = None

    slice_4: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_3, 0, 4176, 8352)

    # No stacktrace found for following nodes
    slice_tensor_2: "f32[4160]" = torch.ops.aten.slice.Tensor(empty_1, 0, 0, 4160)
    slice_scatter_default_4: "f32[4160]" = torch.ops.aten.slice_scatter.default(slice_tensor_2, getitem_18, 0, 0, 4096); slice_tensor_2 = getitem_18 = None
    slice_scatter_default_5: "f32[8320]" = torch.ops.aten.slice_scatter.default(empty_1, slice_scatter_default_4, 0, 0, 4160); empty_1 = slice_scatter_default_4 = None
    slice_tensor_2: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_3, 0, 4176, 8352)
    slice_scatter_default_4: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_2, getitem_6, 0, 2112, 4160); slice_tensor_2 = getitem_6 = None
    slice_scatter_default_5: "f32[8352]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_3, slice_scatter_default_4, 0, 4176, 8352); slice_scatter_default_3 = slice_scatter_default_4 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
    slice_9: "f32[4160]" = torch.ops.aten.slice.Tensor(slice_scatter_default_5, 0, 0, 4160)
    slice_5: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_5, 0, 4176, 8352)

    # No stacktrace found for following nodes
    slice_tensor_3: "f32[4160]" = torch.ops.aten.slice.Tensor(slice_scatter_default_5, 0, 0, 4160)
    slice_scatter_default_6: "f32[4160]" = torch.ops.aten.slice_scatter.default(slice_tensor_3, getitem_19, 0, 4096, 4160); slice_tensor_3 = getitem_19 = None
    slice_scatter_default_7: "f32[8320]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_5, slice_scatter_default_6, 0, 0, 4160); slice_scatter_default_5 = slice_scatter_default_6 = None
    slice_tensor_3: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_5, 0, 4176, 8352)
    slice_scatter_default_6: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_3, getitem_7, 0, 4160, 4176); slice_tensor_3 = getitem_7 = None
    slice_scatter_default_7: "f32[8352]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_5, slice_scatter_default_6, 0, 4176, 8352); slice_scatter_default_5 = slice_scatter_default_6 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor(
    slice_12: "f32[4160]" = torch.ops.aten.slice.Tensor(slice_scatter_default_7, 0, 0, 4160); slice_scatter_default_7 = None
    all_gather_into_tensor_1: "f32[8320]" = torch.ops._c10d_functional.all_gather_into_tensor.default(slice_12, 2, '0'); slice_12 = None
    slice_10: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_7, 0, 4176, 8352); slice_scatter_default_7 = None
    all_gather_into_tensor: "f32[8352]" = torch.ops._c10d_functional.all_gather_into_tensor.default(slice_10, 2, '0'); slice_10 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined]
    wait_tensor_1: "f32[8320]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_1); all_gather_into_tensor_1 = None
    wait_tensor: "f32[8352]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor); all_gather_into_tensor = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
    view_6: "f32[2, 4160]" = torch.ops.aten.reshape.default(wait_tensor_1, [2, -1]); wait_tensor_1 = None
    split_with_sizes_10 = torch.ops.aten.split_with_sizes.default(view_6, [4096, 64], 1); view_6 = None
    getitem_26: "f32[2, 4096]" = split_with_sizes_10[0]
    clone_2: "f32[2, 4096]" = torch.ops.aten.clone.default(getitem_26, memory_format = torch.contiguous_format); getitem_26 = None
    view_7: "f32[8192]" = torch.ops.aten.reshape.default(clone_2, [8192]); clone_2 = None
    view_1: "f32[2, 4176]" = torch.ops.aten.reshape.default(wait_tensor, [2, -1]); wait_tensor = None
    split_with_sizes_6 = torch.ops.aten.split_with_sizes.default(view_1, [2048, 64, 2048, 16], 1); view_1 = None
    getitem_28: "f32[2, 2048]" = split_with_sizes_6[0]
    clone: "f32[2, 2048]" = torch.ops.aten.clone.default(getitem_28, memory_format = torch.contiguous_format); getitem_28 = None
    view_2: "f32[4096]" = torch.ops.aten.reshape.default(clone, [4096]); clone = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
    as_strided_2: "f32[128, 64]" = torch.ops.aten.as_strided.default(view_7, [128, 64], [64, 1], 0); view_7 = None
    as_strided: "f32[128, 32]" = torch.ops.aten.as_strided.default(view_2, [128, 32], [32, 1], 0); view_2 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
    getitem_29: "f32[2, 64]" = split_with_sizes_10[1]; split_with_sizes_10 = None
    clone_3: "f32[2, 64]" = torch.ops.aten.clone.default(getitem_29, memory_format = torch.contiguous_format); getitem_29 = None
    view_9: "f32[128]" = torch.ops.aten.reshape.default(clone_3, [128]); clone_3 = None
    getitem_33: "f32[2, 64]" = split_with_sizes_6[1]
    clone_1: "f32[2, 64]" = torch.ops.aten.clone.default(getitem_33, memory_format = torch.contiguous_format); getitem_33 = None
    view_4: "f32[128]" = torch.ops.aten.reshape.default(clone_1, [128]); clone_1 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
    as_strided_3: "f32[128]" = torch.ops.aten.as_strided.default(view_9, [128], [1], 0); view_9 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:138 in foreach_all_gather_copy_out, code: torch._foreach_copy_(out, splits_unpadded)
    _foreach_copy_3 = torch.ops.aten._foreach_copy.default([primals_9, primals_10], [as_strided_2, as_strided_3]); primals_10 = as_strided_2 = as_strided_3 = None
    getitem_30: "f32[128, 64]" = _foreach_copy_3[0]
    getitem_31: "f32[128]" = _foreach_copy_3[1]; _foreach_copy_3 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
    permute_3: "f32[64, 128]" = torch.ops.aten.permute.default(getitem_30, [1, 0]); getitem_30 = None
    addmm_1: "f32[8, 128]" = torch.ops.aten.addmm.default(getitem_31, addmm, permute_3); getitem_31 = permute_3 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:47 in foreach_all_gather, code: all_gather_output = torch.empty(
    empty_2: "f32[33024]" = torch.ops.aten.empty.memory_format([33024], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False)

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:50 in foreach_all_gather, code: all_gather_input = all_gather_output.narrow(
    slice_13: "f32[16512]" = torch.ops.aten.slice.Tensor(empty_2, 0, 0, 16512)

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:53 in foreach_all_gather, code: foreach_copy_dsts = torch.split(all_gather_input, inp_split_sizes)
    split_with_sizes_12 = torch.ops.aten.split_with_sizes.default(slice_13, [16384, 128]); slice_13 = None
    getitem_32: "f32[16384]" = split_with_sizes_12[0]
    getitem_33: "f32[128]" = split_with_sizes_12[1]; split_with_sizes_12 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
    _foreach_copy_4 = torch.ops.aten._foreach_copy.default([getitem_32, getitem_33], [primals_11, primals_12]); getitem_32 = getitem_33 = primals_11 = primals_12 = None
    getitem_34: "f32[16384]" = _foreach_copy_4[0]
    getitem_35: "f32[128]" = _foreach_copy_4[1]; _foreach_copy_4 = None

    # No stacktrace found for following nodes
    slice_tensor_4: "f32[16512]" = torch.ops.aten.slice.Tensor(empty_2, 0, 0, 16512)
    slice_scatter_default_8: "f32[16512]" = torch.ops.aten.slice_scatter.default(slice_tensor_4, getitem_34, 0, 0, 16384); slice_tensor_4 = getitem_34 = None
    slice_scatter_default_9: "f32[33024]" = torch.ops.aten.slice_scatter.default(empty_2, slice_scatter_default_8, 0, 0, 16512); empty_2 = slice_scatter_default_8 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
    slice_15: "f32[16512]" = torch.ops.aten.slice.Tensor(slice_scatter_default_9, 0, 0, 16512)

    # No stacktrace found for following nodes
    slice_tensor_5: "f32[16512]" = torch.ops.aten.slice.Tensor(slice_scatter_default_9, 0, 0, 16512)
    slice_scatter_default_10: "f32[16512]" = torch.ops.aten.slice_scatter.default(slice_tensor_5, getitem_35, 0, 16384, 16512); slice_tensor_5 = getitem_35 = None
    slice_scatter_default_11: "f32[33024]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_9, slice_scatter_default_10, 0, 0, 16512); slice_scatter_default_9 = slice_scatter_default_10 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor(
    slice_18: "f32[16512]" = torch.ops.aten.slice.Tensor(slice_scatter_default_11, 0, 0, 16512); slice_scatter_default_11 = None
    all_gather_into_tensor_2: "f32[33024]" = torch.ops._c10d_functional.all_gather_into_tensor.default(slice_18, 2, '0'); slice_18 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined]
    wait_tensor_2: "f32[33024]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_2); all_gather_into_tensor_2 = None

    as_strided_1: "f32[128]" = torch.ops.aten.as_strided.default(view_4, [128], [1], 0); view_4 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
    view_11: "f32[2, 16512]" = torch.ops.aten.reshape.default(wait_tensor_2, [2, -1]); wait_tensor_2 = None
    split_with_sizes_16 = torch.ops.aten.split_with_sizes.default(view_11, [16384, 128], 1); view_11 = None
    getitem_42: "f32[2, 16384]" = split_with_sizes_16[0]
    clone_4: "f32[2, 16384]" = torch.ops.aten.clone.default(getitem_42, memory_format = torch.contiguous_format); getitem_42 = None
    view_12: "f32[32768]" = torch.ops.aten.reshape.default(clone_4, [32768]); clone_4 = None

    getitem_38: "f32[2, 2048]" = split_with_sizes_6[2]
    clone_2: "f32[2, 2048]" = torch.ops.aten.clone.default(getitem_38, memory_format = torch.contiguous_format); getitem_38 = None
    view_6: "f32[4096]" = torch.ops.aten.reshape.default(clone_2, [4096]); clone_2 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
    as_strided_4: "f32[256, 128]" = torch.ops.aten.as_strided.default(view_12, [256, 128], [128, 1], 0); view_12 = None
    as_strided_2: "f32[32, 128]" = torch.ops.aten.as_strided.default(view_6, [32, 128], [128, 1], 0); view_6 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
    getitem_45: "f32[2, 128]" = split_with_sizes_16[1]; split_with_sizes_16 = None
    clone_5: "f32[2, 128]" = torch.ops.aten.clone.default(getitem_45, memory_format = torch.contiguous_format); getitem_45 = None
    view_14: "f32[256]" = torch.ops.aten.reshape.default(clone_5, [256]); clone_5 = None
    getitem_43: "f32[2, 16]" = split_with_sizes_6[3]; split_with_sizes_6 = None
    clone_3: "f32[2, 16]" = torch.ops.aten.clone.default(getitem_43, memory_format = torch.contiguous_format); getitem_43 = None
    view_8: "f32[32]" = torch.ops.aten.reshape.default(clone_3, [32]); clone_3 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
    as_strided_5: "f32[256]" = torch.ops.aten.as_strided.default(view_14, [256], [1], 0); view_14 = None
    as_strided_3: "f32[32]" = torch.ops.aten.as_strided.default(view_8, [32], [1], 0); view_8 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:138 in foreach_all_gather_copy_out, code: torch._foreach_copy_(out, splits_unpadded)
    _foreach_copy_5 = torch.ops.aten._foreach_copy.default([primals_13, primals_14], [as_strided_4, as_strided_5]); primals_14 = as_strided_4 = as_strided_5 = None
    getitem_46: "f32[256, 128]" = _foreach_copy_5[0]
    getitem_47: "f32[256]" = _foreach_copy_5[1]; _foreach_copy_5 = None

    _foreach_copy_1 = torch.ops.aten._foreach_copy.default([primals_6, primals_7, primals_8, primals_9], [as_strided, as_strided_1, as_strided_2, as_strided_3]); primals_6 = primals_7 = primals_9 = as_strided = as_strided_1 = as_strided_2 = as_strided_3 = None
    getitem_44: "f32[128, 32]" = _foreach_copy_1[0]
    getitem_45: "f32[128]" = _foreach_copy_1[1]
    getitem_46: "f32[32, 128]" = _foreach_copy_1[2]
    getitem_47: "f32[32]" = _foreach_copy_1[3]; _foreach_copy_1 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
    permute_5: "f32[128, 256]" = torch.ops.aten.permute.default(getitem_46, [1, 0]); getitem_46 = None
    addmm_2: "f32[8, 256]" = torch.ops.aten.addmm.default(getitem_47, addmm_1, permute_5); getitem_47 = permute_5 = None
    return [addmm_2, primals_1, primals_9, primals_13, addmm, addmm_1]

    permute_1: "f32[32, 128]" = torch.ops.aten.permute.default(getitem_44, [1, 0]); getitem_44 = None

    # No stacktrace found for following nodes
    mm_default_5: "f32[8, 128]" = torch.ops.aten.mm.default(primals_1, permute_1); permute_1 = None
    add_tensor_5: "f32[8, 128]" = torch.ops.aten.add.Tensor(mm_default_5, getitem_45); mm_default_5 = getitem_45 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/testing/_internal/common_fsdp.py:855 in forward, code: z = F.relu(z)
    relu: "f32[8, 128]" = torch.ops.aten.relu.default(add_tensor_5); add_tensor_5 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
    permute_3: "f32[128, 32]" = torch.ops.aten.permute.default(getitem_46, [1, 0]); getitem_46 = None

    # No stacktrace found for following nodes
    mm_default_4: "f32[8, 32]" = torch.ops.aten.mm.default(relu, permute_3); permute_3 = None
    add_tensor_4: "f32[8, 32]" = torch.ops.aten.add.Tensor(mm_default_4, getitem_47); mm_default_4 = getitem_47 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/testing/_internal/common_fsdp.py:857 in forward, code: z = F.relu(z)
    relu_1: "f32[8, 32]" = torch.ops.aten.relu.default(add_tensor_4); add_tensor_4 = None
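    Stripped of the collectives, the compute in the revised graph is a chain of Linear + ReLU blocks (Linear(32, 128) -> ReLU -> Linear(128, 32) -> ReLU per parameter group, per the traced shapes). A hedged sketch of a module with that structure; the class names and the three-group layout are inferred from the primals in the signature, not taken from common_fsdp.py:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class Block(nn.Module):
        # One parameter group: Linear(32 -> 128), Linear(128 -> 32), ReLU after each.
        def __init__(self, dim=32, hidden=128):
            super().__init__()
            self.lin1 = nn.Linear(dim, hidden)
            self.lin2 = nn.Linear(hidden, dim)

        def forward(self, x):
            z = F.relu(self.lin1(x))
            return F.relu(self.lin2(z))

    class MultiGroupModel(nn.Module):
        # Three blocks, matching the three all-gather groups traced above.
        def __init__(self, num_blocks=3):
            super().__init__()
            self.blocks = nn.ModuleList(Block() for _ in range(num_blocks))

        def forward(self, x):
            for block in self.blocks:
                x = block(x)
            return x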

    TRACED GRAPH
    ===== AFTER POST GRAD =====
    /data/users/willfeng/pytorch_yf225/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[8, 32]", primals_2: "f32[1024]", primals_3: "f32[32]", primals_4: "f32[64, 32]", primals_5: "f32[64]", primals_6, primals_7: "f32[4096]", primals_8: "f32[64]", primals_9: "f32[128, 64]", primals_10: "f32[128]", primals_11: "f32[16384]", primals_12: "f32[128]", primals_13: "f32[256, 128]", primals_14: "f32[256]"):
    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:47 in foreach_all_gather, code: all_gather_output = torch.empty(
    empty: "f32[2112]" = torch.ops.aten.empty.memory_format([2112], dtype = torch.float32, device = device(type='cuda', index=1), pin_memory = False)

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:50 in foreach_all_gather, code: all_gather_input = all_gather_output.narrow(
    slice_1: "f32[1056]" = torch.ops.aten.slice.Tensor(empty, 0, 1056, 2112)

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:53 in foreach_all_gather, code: foreach_copy_dsts = torch.split(all_gather_input, inp_split_sizes)
    split_with_sizes = torch.ops.aten.split_with_sizes.default(slice_1, [1024, 32]); slice_1 = None
    getitem: "f32[1024]" = split_with_sizes[0]
    getitem_1: "f32[32]" = split_with_sizes[1]; split_with_sizes = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
    _foreach_copy = torch.ops.aten._foreach_copy.default([getitem, getitem_1], [primals_2, primals_3]); getitem = getitem_1 = primals_2 = primals_3 = None
    getitem_2: "f32[1024]" = _foreach_copy[0]
    getitem_3: "f32[32]" = _foreach_copy[1]; _foreach_copy = None

    _foreach_copy_2 = torch.ops.aten._foreach_copy.default([getitem, getitem_1, getitem_2, getitem_3], [primals_11, primals_12, primals_13, primals_14]); primals_11 = primals_12 = primals_13 = primals_14 = None
    getitem_52: "f32[2048]" = _foreach_copy_2[0]
    getitem_53: "f32[64]" = _foreach_copy_2[1]
    getitem_54: "f32[2048]" = _foreach_copy_2[2]
    getitem_55: "f32[16]" = _foreach_copy_2[3]; _foreach_copy_2 = None

    # No stacktrace found for following nodes
    slice_tensor: "f32[1056]" = torch.ops.aten.slice.Tensor(empty, 0, 1056, 2112)
    slice_scatter_default: "f32[1056]" = torch.ops.aten.slice_scatter.default(slice_tensor, getitem_2, 0, 0, 1024); slice_tensor = getitem_2 = None
    slice_scatter_default_1: "f32[2112]" = torch.ops.aten.slice_scatter.default(empty, slice_scatter_default, 0, 1056, 2112); empty = slice_scatter_default = None
    slice_tensor_4: "f32[4176]" = torch.ops.aten.slice.Tensor(empty, 0, 4176, 8352)
    slice_scatter_default_8: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_4, getitem_52, 0, 0, 2048); slice_tensor_4 = getitem_52 = None
    slice_scatter_default_9: "f32[8352]" = torch.ops.aten.slice_scatter.default(empty, slice_scatter_default_8, 0, 4176, 8352); slice_scatter_default_8 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
    slice_3: "f32[1056]" = torch.ops.aten.slice.Tensor(slice_scatter_default_1, 0, 1056, 2112)
    slice_13: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_9, 0, 4176, 8352)

    # No stacktrace found for following nodes
    slice_tensor_1: "f32[1056]" = torch.ops.aten.slice.Tensor(slice_scatter_default_1, 0, 1056, 2112)
    slice_scatter_default_2: "f32[1056]" = torch.ops.aten.slice_scatter.default(slice_tensor_1, getitem_3, 0, 1024, 1056); slice_tensor_1 = getitem_3 = None
    slice_scatter_default_3: "f32[2112]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_1, slice_scatter_default_2, 0, 1056, 2112); slice_scatter_default_1 = slice_scatter_default_2 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor(
    slice_6: "f32[1056]" = torch.ops.aten.slice.Tensor(slice_scatter_default_3, 0, 1056, 2112); slice_scatter_default_3 = None
    all_gather_into_tensor: "f32[2112]" = torch.ops._c10d_functional.all_gather_into_tensor.default(slice_6, 2, '0'); slice_6 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined]
    wait_tensor: "f32[2112]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor); all_gather_into_tensor = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
    view_1: "f32[2, 1056]" = torch.ops.aten.reshape.default(wait_tensor, [2, -1]); wait_tensor = None
    split_with_sizes_4 = torch.ops.aten.split_with_sizes.default(view_1, [1024, 32], 1); view_1 = None
    getitem_10: "f32[2, 1024]" = split_with_sizes_4[0]
    clone: "f32[2, 1024]" = torch.ops.aten.clone.default(getitem_10, memory_format = torch.contiguous_format); getitem_10 = None
    view_2: "f32[2048]" = torch.ops.aten.reshape.default(clone, [2048]); clone = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
    as_strided: "f32[64, 32]" = torch.ops.aten.as_strided.default(view_2, [64, 32], [32, 1], 0); view_2 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
    getitem_13: "f32[2, 32]" = split_with_sizes_4[1]; split_with_sizes_4 = None
    clone_1: "f32[2, 32]" = torch.ops.aten.clone.default(getitem_13, memory_format = torch.contiguous_format); getitem_13 = None
    view_4: "f32[64]" = torch.ops.aten.reshape.default(clone_1, [64]); clone_1 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
    as_strided_1: "f32[64]" = torch.ops.aten.as_strided.default(view_4, [64], [1], 0); view_4 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:138 in foreach_all_gather_copy_out, code: torch._foreach_copy_(out, splits_unpadded)
    _foreach_copy_1 = torch.ops.aten._foreach_copy.default([primals_4, primals_5], [as_strided, as_strided_1]); primals_4 = primals_5 = as_strided = as_strided_1 = None
    getitem_14: "f32[64, 32]" = _foreach_copy_1[0]
    getitem_15: "f32[64]" = _foreach_copy_1[1]; _foreach_copy_1 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
    permute_1: "f32[32, 64]" = torch.ops.aten.permute.default(getitem_14, [1, 0]); getitem_14 = None
    addmm: "f32[8, 64]" = torch.ops.aten.addmm.default(getitem_15, primals_1, permute_1); getitem_15 = permute_1 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:47 in foreach_all_gather, code: all_gather_output = torch.empty(
    empty_1: "f32[8320]" = torch.ops.aten.empty.memory_format([8320], dtype = torch.float32, device = device(type='cuda', index=1), pin_memory = False)

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:50 in foreach_all_gather, code: all_gather_input = all_gather_output.narrow(
    slice_7: "f32[4160]" = torch.ops.aten.slice.Tensor(empty_1, 0, 4160, 8320)

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:53 in foreach_all_gather, code: foreach_copy_dsts = torch.split(all_gather_input, inp_split_sizes)
    split_with_sizes_6 = torch.ops.aten.split_with_sizes.default(slice_7, [4096, 64]); slice_7 = None
    getitem_16: "f32[4096]" = split_with_sizes_6[0]
    getitem_17: "f32[64]" = split_with_sizes_6[1]; split_with_sizes_6 = None

    slice_tensor_5: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_9, 0, 4176, 8352)
    slice_scatter_default_10: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_5, getitem_53, 0, 2048, 2112); slice_tensor_5 = getitem_53 = None
    slice_scatter_default_11: "f32[8352]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_9, slice_scatter_default_10, 0, 4176, 8352); slice_scatter_default_9 = slice_scatter_default_10 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
    _foreach_copy_2 = torch.ops.aten._foreach_copy.default([getitem_16, getitem_17], [primals_7, primals_8]); getitem_16 = getitem_17 = primals_7 = primals_8 = None
    getitem_18: "f32[4096]" = _foreach_copy_2[0]
    getitem_19: "f32[64]" = _foreach_copy_2[1]; _foreach_copy_2 = None

    slice_14: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_11, 0, 4176, 8352)

    # No stacktrace found for following nodes
    slice_tensor_2: "f32[4160]" = torch.ops.aten.slice.Tensor(empty_1, 0, 4160, 8320)
    slice_scatter_default_4: "f32[4160]" = torch.ops.aten.slice_scatter.default(slice_tensor_2, getitem_18, 0, 0, 4096); slice_tensor_2 = getitem_18 = None
    slice_scatter_default_5: "f32[8320]" = torch.ops.aten.slice_scatter.default(empty_1, slice_scatter_default_4, 0, 4160, 8320); empty_1 = slice_scatter_default_4 = None
    slice_tensor_6: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_11, 0, 4176, 8352)
    slice_scatter_default_12: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_6, getitem_54, 0, 2112, 4160); slice_tensor_6 = getitem_54 = None
    slice_scatter_default_13: "f32[8352]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_11, slice_scatter_default_12, 0, 4176, 8352); slice_scatter_default_11 = slice_scatter_default_12 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
    slice_9: "f32[4160]" = torch.ops.aten.slice.Tensor(slice_scatter_default_5, 0, 4160, 8320)
    slice_15: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_13, 0, 4176, 8352)

    # No stacktrace found for following nodes
    slice_tensor_3: "f32[4160]" = torch.ops.aten.slice.Tensor(slice_scatter_default_5, 0, 4160, 8320)
    slice_scatter_default_6: "f32[4160]" = torch.ops.aten.slice_scatter.default(slice_tensor_3, getitem_19, 0, 4096, 4160); slice_tensor_3 = getitem_19 = None
    slice_scatter_default_7: "f32[8320]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_5, slice_scatter_default_6, 0, 4160, 8320); slice_scatter_default_5 = slice_scatter_default_6 = None
    slice_tensor_7: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_13, 0, 4176, 8352)
    slice_scatter_default_14: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_7, getitem_55, 0, 4160, 4176); slice_tensor_7 = getitem_55 = None
    slice_scatter_default_15: "f32[8352]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_13, slice_scatter_default_14, 0, 4176, 8352); slice_scatter_default_13 = slice_scatter_default_14 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor(
    slice_12: "f32[4160]" = torch.ops.aten.slice.Tensor(slice_scatter_default_7, 0, 4160, 8320); slice_scatter_default_7 = None
    all_gather_into_tensor_1: "f32[8320]" = torch.ops._c10d_functional.all_gather_into_tensor.default(slice_12, 2, '0'); slice_12 = None
    slice_20: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_15, 0, 4176, 8352); slice_scatter_default_15 = None
    all_gather_into_tensor_1: "f32[8352]" = torch.ops._c10d_functional.all_gather_into_tensor.default(slice_20, 2, '0'); slice_20 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined]
    wait_tensor_1: "f32[8320]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_1); all_gather_into_tensor_1 = None
    wait_tensor_1: "f32[8352]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_1); all_gather_into_tensor_1 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
    view_6: "f32[2, 4160]" = torch.ops.aten.reshape.default(wait_tensor_1, [2, -1]); wait_tensor_1 = None
    split_with_sizes_10 = torch.ops.aten.split_with_sizes.default(view_6, [4096, 64], 1); view_6 = None
    getitem_26: "f32[2, 4096]" = split_with_sizes_10[0]
    clone_2: "f32[2, 4096]" = torch.ops.aten.clone.default(getitem_26, memory_format = torch.contiguous_format); getitem_26 = None
    view_7: "f32[8192]" = torch.ops.aten.reshape.default(clone_2, [8192]); clone_2 = None
    view_10: "f32[2, 4176]" = torch.ops.aten.reshape.default(wait_tensor_1, [2, -1]); wait_tensor_1 = None
    split_with_sizes_16 = torch.ops.aten.split_with_sizes.default(view_10, [2048, 64, 2048, 16], 1); view_10 = None
    getitem_76: "f32[2, 2048]" = split_with_sizes_16[0]
    clone_4: "f32[2, 2048]" = torch.ops.aten.clone.default(getitem_76, memory_format = torch.contiguous_format); getitem_76 = None
    view_11: "f32[4096]" = torch.ops.aten.reshape.default(clone_4, [4096]); clone_4 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
    as_strided_2: "f32[128, 64]" = torch.ops.aten.as_strided.default(view_7, [128, 64], [64, 1], 0); view_7 = None
    as_strided_4: "f32[128, 32]" = torch.ops.aten.as_strided.default(view_11, [128, 32], [32, 1], 0); view_11 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
    getitem_29: "f32[2, 64]" = split_with_sizes_10[1]; split_with_sizes_10 = None
    clone_3: "f32[2, 64]" = torch.ops.aten.clone.default(getitem_29, memory_format = torch.contiguous_format); getitem_29 = None
    view_9: "f32[128]" = torch.ops.aten.reshape.default(clone_3, [128]); clone_3 = None
    getitem_81: "f32[2, 64]" = split_with_sizes_16[1]
    clone_5: "f32[2, 64]" = torch.ops.aten.clone.default(getitem_81, memory_format = torch.contiguous_format); getitem_81 = None
    view_13: "f32[128]" = torch.ops.aten.reshape.default(clone_5, [128]); clone_5 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
    as_strided_3: "f32[128]" = torch.ops.aten.as_strided.default(view_9, [128], [1], 0); view_9 = None

    as_strided_5: "f32[128]" = torch.ops.aten.as_strided.default(view_13, [128], [1], 0); view_13 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
    getitem_86: "f32[2, 2048]" = split_with_sizes_16[2]
    clone_6: "f32[2, 2048]" = torch.ops.aten.clone.default(getitem_86, memory_format = torch.contiguous_format); getitem_86 = None
    view_15: "f32[4096]" = torch.ops.aten.reshape.default(clone_6, [4096]); clone_6 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
    as_strided_6: "f32[32, 128]" = torch.ops.aten.as_strided.default(view_15, [32, 128], [128, 1], 0); view_15 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
    getitem_91: "f32[2, 16]" = split_with_sizes_16[3]; split_with_sizes_16 = None
    clone_7: "f32[2, 16]" = torch.ops.aten.clone.default(getitem_91, memory_format = torch.contiguous_format); getitem_91 = None
    view_17: "f32[32]" = torch.ops.aten.reshape.default(clone_7, [32]); clone_7 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
    as_strided_7: "f32[32]" = torch.ops.aten.as_strided.default(view_17, [32], [1], 0); view_17 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:138 in foreach_all_gather_copy_out, code: torch._foreach_copy_(out, splits_unpadded)
    _foreach_copy_3 = torch.ops.aten._foreach_copy.default([primals_9, primals_10], [as_strided_2, as_strided_3]); primals_10 = as_strided_2 = as_strided_3 = None
    getitem_30: "f32[128, 64]" = _foreach_copy_3[0]
    getitem_31: "f32[128]" = _foreach_copy_3[1]; _foreach_copy_3 = None

    _foreach_copy_3 = torch.ops.aten._foreach_copy.default([primals_15, primals_16, primals_17, primals_18], [as_strided_4, as_strided_5, as_strided_6, as_strided_7]); primals_16 = primals_18 = as_strided_4 = as_strided_5 = as_strided_6 = as_strided_7 = None
    getitem_92: "f32[128, 32]" = _foreach_copy_3[0]
    getitem_93: "f32[128]" = _foreach_copy_3[1]
    getitem_94: "f32[32, 128]" = _foreach_copy_3[2]
    getitem_95: "f32[32]" = _foreach_copy_3[3]; _foreach_copy_3 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
    permute_3: "f32[64, 128]" = torch.ops.aten.permute.default(getitem_30, [1, 0]); getitem_30 = None
    addmm_1: "f32[8, 128]" = torch.ops.aten.addmm.default(getitem_31, addmm, permute_3); getitem_31 = permute_3 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:47 in foreach_all_gather, code: all_gather_output = torch.empty(
    empty_2: "f32[33024]" = torch.ops.aten.empty.memory_format([33024], dtype = torch.float32, device = device(type='cuda', index=1), pin_memory = False)

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:50 in foreach_all_gather, code: all_gather_input = all_gather_output.narrow(
    slice_13: "f32[16512]" = torch.ops.aten.slice.Tensor(empty_2, 0, 16512, 33024)

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:53 in foreach_all_gather, code: foreach_copy_dsts = torch.split(all_gather_input, inp_split_sizes)
    split_with_sizes_12 = torch.ops.aten.split_with_sizes.default(slice_13, [16384, 128]); slice_13 = None
    getitem_32: "f32[16384]" = split_with_sizes_12[0]
    getitem_33: "f32[128]" = split_with_sizes_12[1]; split_with_sizes_12 = None

    permute_5: "f32[32, 128]" = torch.ops.aten.permute.default(getitem_92, [1, 0]); getitem_92 = None

    # No stacktrace found for following nodes
    mm_default_3: "f32[8, 128]" = torch.ops.aten.mm.default(relu_1, permute_5); permute_5 = None
    add_tensor_3: "f32[8, 128]" = torch.ops.aten.add.Tensor(mm_default_3, getitem_93); mm_default_3 = getitem_93 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/testing/_internal/common_fsdp.py:855 in forward, code: z = F.relu(z)
    relu_2: "f32[8, 128]" = torch.ops.aten.relu.default(add_tensor_3); add_tensor_3 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
    permute_7: "f32[128, 32]" = torch.ops.aten.permute.default(getitem_94, [1, 0]); getitem_94 = None

    # No stacktrace found for following nodes
    mm_default_2: "f32[8, 32]" = torch.ops.aten.mm.default(relu_2, permute_7); permute_7 = None
    add_tensor_2: "f32[8, 32]" = torch.ops.aten.add.Tensor(mm_default_2, getitem_95); mm_default_2 = getitem_95 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/testing/_internal/common_fsdp.py:857 in forward, code: z = F.relu(z)
    relu_3: "f32[8, 32]" = torch.ops.aten.relu.default(add_tensor_2); add_tensor_2 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
    _foreach_copy_4 = torch.ops.aten._foreach_copy.default([getitem_32, getitem_33], [primals_11, primals_12]); getitem_32 = getitem_33 = primals_11 = primals_12 = None
    getitem_34: "f32[16384]" = _foreach_copy_4[0]
    getitem_35: "f32[128]" = _foreach_copy_4[1]; _foreach_copy_4 = None

    _foreach_copy_4 = torch.ops.aten._foreach_copy.default([getitem, getitem_1, getitem_2, getitem_3], [primals_19, primals_20, primals_21, primals_22]); getitem = getitem_1 = getitem_2 = getitem_3 = primals_19 = primals_20 = primals_21 = primals_22 = None
    getitem_100: "f32[2048]" = _foreach_copy_4[0]
    getitem_101: "f32[64]" = _foreach_copy_4[1]
    getitem_102: "f32[2048]" = _foreach_copy_4[2]
    getitem_103: "f32[16]" = _foreach_copy_4[3]; _foreach_copy_4 = None

    # No stacktrace found for following nodes
    slice_tensor_4: "f32[16512]" = torch.ops.aten.slice.Tensor(empty_2, 0, 16512, 33024)
    slice_scatter_default_8: "f32[16512]" = torch.ops.aten.slice_scatter.default(slice_tensor_4, getitem_34, 0, 0, 16384); slice_tensor_4 = getitem_34 = None
    slice_scatter_default_9: "f32[33024]" = torch.ops.aten.slice_scatter.default(empty_2, slice_scatter_default_8, 0, 16512, 33024); empty_2 = slice_scatter_default_8 = None
    slice_tensor_8: "f32[4176]" = torch.ops.aten.slice.Tensor(empty, 0, 4176, 8352)
    slice_scatter_default_16: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_8, getitem_100, 0, 0, 2048); slice_tensor_8 = getitem_100 = None
    slice_scatter_default_17: "f32[8352]" = torch.ops.aten.slice_scatter.default(empty, slice_scatter_default_16, 0, 4176, 8352); empty = slice_scatter_default_16 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
    slice_15: "f32[16512]" = torch.ops.aten.slice.Tensor(slice_scatter_default_9, 0, 16512, 33024)
    slice_23: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_17, 0, 4176, 8352)

    # No stacktrace found for following nodes
    slice_tensor_5: "f32[16512]" = torch.ops.aten.slice.Tensor(slice_scatter_default_9, 0, 16512, 33024)
    slice_scatter_default_10: "f32[16512]" = torch.ops.aten.slice_scatter.default(slice_tensor_5, getitem_35, 0, 16384, 16512); slice_tensor_5 = getitem_35 = None
    slice_scatter_default_11: "f32[33024]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_9, slice_scatter_default_10, 0, 16512, 33024); slice_scatter_default_9 = slice_scatter_default_10 = None

    slice_tensor_9: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_17, 0, 4176, 8352)
    slice_scatter_default_18: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_9, getitem_101, 0, 2048, 2112); slice_tensor_9 = getitem_101 = None
    slice_scatter_default_19: "f32[8352]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_17, slice_scatter_default_18, 0, 4176, 8352); slice_scatter_default_17 = slice_scatter_default_18 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
    slice_24: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_19, 0, 4176, 8352)

    # No stacktrace found for following nodes
    slice_tensor_10: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_19, 0, 4176, 8352)
    slice_scatter_default_20: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_10, getitem_102, 0, 2112, 4160); slice_tensor_10 = getitem_102 = None
    slice_scatter_default_21: "f32[8352]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_19, slice_scatter_default_20, 0, 4176, 8352); slice_scatter_default_19 = slice_scatter_default_20 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
    slice_25: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_21, 0, 4176, 8352)

    # No stacktrace found for following nodes
    slice_tensor_11: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_21, 0, 4176, 8352)
    slice_scatter_default_22: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_11, getitem_103, 0, 4160, 4176); slice_tensor_11 = getitem_103 = None
    slice_scatter_default_23: "f32[8352]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_21, slice_scatter_default_22, 0, 4176, 8352); slice_scatter_default_21 = slice_scatter_default_22 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor(
    slice_18: "f32[16512]" = torch.ops.aten.slice.Tensor(slice_scatter_default_11, 0, 16512, 33024); slice_scatter_default_11 = None
    all_gather_into_tensor_2: "f32[33024]" = torch.ops._c10d_functional.all_gather_into_tensor.default(slice_18, 2, '0'); slice_18 = None
    slice_30: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_23, 0, 4176, 8352); slice_scatter_default_23 = None
    all_gather_into_tensor_2: "f32[8352]" = torch.ops._c10d_functional.all_gather_into_tensor.default(slice_30, 2, '0'); slice_30 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined]
    wait_tensor_2: "f32[33024]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_2); all_gather_into_tensor_2 = None

    wait_tensor_2: "f32[8352]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_2); all_gather_into_tensor_2 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
    view_19: "f32[2, 4176]" = torch.ops.aten.reshape.default(wait_tensor_2, [2, -1]); wait_tensor_2 = None
    split_with_sizes_26 = torch.ops.aten.split_with_sizes.default(view_19, [2048, 64, 2048, 16], 1); view_19 = None
    getitem_124: "f32[2, 2048]" = split_with_sizes_26[0]
    clone_8: "f32[2, 2048]" = torch.ops.aten.clone.default(getitem_124, memory_format = torch.contiguous_format); getitem_124 = None
    view_20: "f32[4096]" = torch.ops.aten.reshape.default(clone_8, [4096]); clone_8 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
    as_strided_8: "f32[128, 32]" = torch.ops.aten.as_strided.default(view_20, [128, 32], [32, 1], 0); view_20 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
    view_11: "f32[2, 16512]" = torch.ops.aten.reshape.default(wait_tensor_2, [2, -1]); wait_tensor_2 = None
    split_with_sizes_16 = torch.ops.aten.split_with_sizes.default(view_11, [16384, 128], 1); view_11 = None
    getitem_42: "f32[2, 16384]" = split_with_sizes_16[0]
    clone_4: "f32[2, 16384]" = torch.ops.aten.clone.default(getitem_42, memory_format = torch.contiguous_format); getitem_42 = None
    view_12: "f32[32768]" = torch.ops.aten.reshape.default(clone_4, [32768]); clone_4 = None

    getitem_129: "f32[2, 64]" = split_with_sizes_26[1]
    clone_9: "f32[2, 64]" = torch.ops.aten.clone.default(getitem_129, memory_format = torch.contiguous_format); getitem_129 = None
    view_22: "f32[128]" = torch.ops.aten.reshape.default(clone_9, [128]); clone_9 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
    as_strided_4: "f32[256, 128]" = torch.ops.aten.as_strided.default(view_12, [256, 128], [128, 1], 0); view_12 = None
    as_strided_9: "f32[128]" = torch.ops.aten.as_strided.default(view_22, [128], [1], 0); view_22 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
    getitem_45: "f32[2, 128]" = split_with_sizes_16[1]; split_with_sizes_16 = None
    clone_5: "f32[2, 128]" = torch.ops.aten.clone.default(getitem_45, memory_format = torch.contiguous_format); getitem_45 = None
    view_14: "f32[256]" = torch.ops.aten.reshape.default(clone_5, [256]); clone_5 = None
    getitem_134: "f32[2, 2048]" = split_with_sizes_26[2]
    clone_10: "f32[2, 2048]" = torch.ops.aten.clone.default(getitem_134, memory_format = torch.contiguous_format); getitem_134 = None
    view_24: "f32[4096]" = torch.ops.aten.reshape.default(clone_10, [4096]); clone_10 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
    as_strided_5: "f32[256]" = torch.ops.aten.as_strided.default(view_14, [256], [1], 0); view_14 = None

    as_strided_10: "f32[32, 128]" = torch.ops.aten.as_strided.default(view_24, [32, 128], [128, 1], 0); view_24 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
    getitem_139: "f32[2, 16]" = split_with_sizes_26[3]; split_with_sizes_26 = None
    clone_11: "f32[2, 16]" = torch.ops.aten.clone.default(getitem_139, memory_format = torch.contiguous_format); getitem_139 = None
    view_26: "f32[32]" = torch.ops.aten.reshape.default(clone_11, [32]); clone_11 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
    as_strided_11: "f32[32]" = torch.ops.aten.as_strided.default(view_26, [32], [1], 0); view_26 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:138 in foreach_all_gather_copy_out, code: torch._foreach_copy_(out, splits_unpadded)
    _foreach_copy_5 = torch.ops.aten._foreach_copy.default([primals_13, primals_14], [as_strided_4, as_strided_5]); primals_14 = as_strided_4 = as_strided_5 = None
    getitem_46: "f32[256, 128]" = _foreach_copy_5[0]
    getitem_47: "f32[256]" = _foreach_copy_5[1]; _foreach_copy_5 = None

    _foreach_copy_5 = torch.ops.aten._foreach_copy.default([primals_23, primals_24, primals_25, primals_26], [as_strided_8, as_strided_9, as_strided_10, as_strided_11]); primals_24 = primals_26 = as_strided_8 = as_strided_9 = as_strided_10 = as_strided_11 = None
    getitem_140: "f32[128, 32]" = _foreach_copy_5[0]
    getitem_141: "f32[128]" = _foreach_copy_5[1]
    getitem_142: "f32[32, 128]" = _foreach_copy_5[2]
    getitem_143: "f32[32]" = _foreach_copy_5[3]; _foreach_copy_5 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
    permute_5: "f32[128, 256]" = torch.ops.aten.permute.default(getitem_46, [1, 0]); getitem_46 = None
    addmm_2: "f32[8, 256]" = torch.ops.aten.addmm.default(getitem_47, addmm_1, permute_5); getitem_47 = permute_5 = None
    return [addmm_2, primals_1, primals_9, primals_13, addmm, addmm_1]

    permute_9: "f32[32, 128]" = torch.ops.aten.permute.default(getitem_140, [1, 0]); getitem_140 = None

    # No stacktrace found for following nodes
    mm_default_1: "f32[8, 128]" = torch.ops.aten.mm.default(relu_3, permute_9); permute_9 = None
    add_tensor_1: "f32[8, 128]" = torch.ops.aten.add.Tensor(mm_default_1, getitem_141); mm_default_1 = getitem_141 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/testing/_internal/common_fsdp.py:855 in forward, code: z = F.relu(z)
    relu_4: "f32[8, 128]" = torch.ops.aten.relu.default(add_tensor_1); add_tensor_1 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
    permute_11: "f32[128, 32]" = torch.ops.aten.permute.default(getitem_142, [1, 0]); getitem_142 = None

    # No stacktrace found for following nodes
    mm_default: "f32[8, 32]" = torch.ops.aten.mm.default(relu_4, permute_11); permute_11 = None
    add_tensor: "f32[8, 32]" = torch.ops.aten.add.Tensor(mm_default, getitem_143); mm_default = getitem_143 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/testing/_internal/common_fsdp.py:857 in forward, code: z = F.relu(z)
    relu_5: "f32[8, 32]" = torch.ops.aten.relu.default(add_tensor); add_tensor = None
    le: "b8[8, 32]" = torch.ops.aten.le.Scalar(relu_5, 0)
    return [relu_5, primals_1, primals_8, primals_15, primals_17, primals_23, primals_25, relu, relu_1, relu_2, relu_3, relu_4, le]

  2. yf225 revised this gist Mar 27, 2024. 1 changed file with 277 additions and 218 deletions.
    495 changes: 277 additions & 218 deletions ppfsdp_multi_group_fwd_graph_no_fsdp_fx_passes.txt
    @@ -1,330 +1,389 @@
    TRACED GRAPH
    ===== AFTER POST GRAD =====
    /data/users/willfeng/pytorch_yf225/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[8, 32]", primals_2: "f32[2048]", primals_3: "f32[64]", primals_4: "f32[2048]", primals_5: "f32[16]", primals_6: "f32[128, 32]", primals_7: "f32[128]", primals_8: "f32[32, 128]", primals_9: "f32[32]", primals_10, primals_11: "f32[2048]", primals_12: "f32[64]", primals_13: "f32[2048]", primals_14: "f32[16]", primals_15: "f32[128, 32]", primals_16: "f32[128]", primals_17: "f32[32, 128]", primals_18: "f32[32]", primals_19: "f32[2048]", primals_20: "f32[64]", primals_21: "f32[2048]", primals_22: "f32[16]", primals_23: "f32[128, 32]", primals_24: "f32[128]", primals_25: "f32[32, 128]", primals_26: "f32[32]"):
    def forward(self, primals_1: "f32[8, 32]", primals_2: "f32[1024]", primals_3: "f32[32]", primals_4: "f32[64, 32]", primals_5: "f32[64]", primals_6, primals_7: "f32[4096]", primals_8: "f32[64]", primals_9: "f32[128, 64]", primals_10: "f32[128]", primals_11: "f32[16384]", primals_12: "f32[128]", primals_13: "f32[256, 128]", primals_14: "f32[256]"):
    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:47 in foreach_all_gather, code: all_gather_output = torch.empty(
    empty: "f32[8352]" = torch.ops.aten.empty.memory_format([8352], dtype = torch.float32, device = device(type='cuda', index=1), pin_memory = False)
    empty: "f32[2112]" = torch.ops.aten.empty.memory_format([2112], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False)

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:50 in foreach_all_gather, code: all_gather_input = all_gather_output.narrow(
    slice_1: "f32[4176]" = torch.ops.aten.slice.Tensor(empty, 0, 4176, 8352)
    slice_1: "f32[1056]" = torch.ops.aten.slice.Tensor(empty, 0, 0, 1056)

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:53 in foreach_all_gather, code: foreach_copy_dsts = torch.split(all_gather_input, inp_split_sizes)
    split_with_sizes = torch.ops.aten.split_with_sizes.default(slice_1, [2048, 64, 2048, 16]); slice_1 = None
    getitem: "f32[2048]" = split_with_sizes[0]
    getitem_1: "f32[64]" = split_with_sizes[1]
    getitem_2: "f32[2048]" = split_with_sizes[2]
    getitem_3: "f32[16]" = split_with_sizes[3]; split_with_sizes = None
    split_with_sizes = torch.ops.aten.split_with_sizes.default(slice_1, [1024, 32]); slice_1 = None
    getitem: "f32[1024]" = split_with_sizes[0]
    getitem_1: "f32[32]" = split_with_sizes[1]; split_with_sizes = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
    _foreach_copy = torch.ops.aten._foreach_copy.default([getitem, getitem_1, getitem_2, getitem_3], [primals_2, primals_3, primals_4, primals_5]); primals_2 = primals_3 = primals_4 = primals_5 = None
    getitem_4: "f32[2048]" = _foreach_copy[0]
    getitem_5: "f32[64]" = _foreach_copy[1]
    getitem_6: "f32[2048]" = _foreach_copy[2]
    getitem_7: "f32[16]" = _foreach_copy[3]; _foreach_copy = None
    _foreach_copy = torch.ops.aten._foreach_copy.default([getitem, getitem_1], [primals_2, primals_3]); getitem = getitem_1 = primals_2 = primals_3 = None
    getitem_2: "f32[1024]" = _foreach_copy[0]
    getitem_3: "f32[32]" = _foreach_copy[1]; _foreach_copy = None

    # No stacktrace found for following nodes
    slice_tensor: "f32[4176]" = torch.ops.aten.slice.Tensor(empty, 0, 4176, 8352)
    slice_scatter_default: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor, getitem_4, 0, 0, 2048); slice_tensor = getitem_4 = None
    slice_scatter_default_1: "f32[8352]" = torch.ops.aten.slice_scatter.default(empty, slice_scatter_default, 0, 4176, 8352); slice_scatter_default = None
    slice_tensor: "f32[1056]" = torch.ops.aten.slice.Tensor(empty, 0, 0, 1056)
    slice_scatter_default: "f32[1056]" = torch.ops.aten.slice_scatter.default(slice_tensor, getitem_2, 0, 0, 1024); slice_tensor = getitem_2 = None
    slice_scatter_default_1: "f32[2112]" = torch.ops.aten.slice_scatter.default(empty, slice_scatter_default, 0, 0, 1056); empty = slice_scatter_default = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
    slice_3: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_1, 0, 4176, 8352)
    slice_3: "f32[1056]" = torch.ops.aten.slice.Tensor(slice_scatter_default_1, 0, 0, 1056)

    # No stacktrace found for following nodes
    slice_tensor_1: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_1, 0, 4176, 8352)
    slice_scatter_default_2: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_1, getitem_5, 0, 2048, 2112); slice_tensor_1 = getitem_5 = None
    slice_scatter_default_3: "f32[8352]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_1, slice_scatter_default_2, 0, 4176, 8352); slice_scatter_default_1 = slice_scatter_default_2 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
    slice_4: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_3, 0, 4176, 8352)

    # No stacktrace found for following nodes
    slice_tensor_2: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_3, 0, 4176, 8352)
    slice_scatter_default_4: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_2, getitem_6, 0, 2112, 4160); slice_tensor_2 = getitem_6 = None
    slice_scatter_default_5: "f32[8352]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_3, slice_scatter_default_4, 0, 4176, 8352); slice_scatter_default_3 = slice_scatter_default_4 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
    slice_5: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_5, 0, 4176, 8352)

    # No stacktrace found for following nodes
    slice_tensor_3: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_5, 0, 4176, 8352)
    slice_scatter_default_6: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_3, getitem_7, 0, 4160, 4176); slice_tensor_3 = getitem_7 = None
    slice_scatter_default_7: "f32[8352]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_5, slice_scatter_default_6, 0, 4176, 8352); slice_scatter_default_5 = slice_scatter_default_6 = None
    slice_tensor_1: "f32[1056]" = torch.ops.aten.slice.Tensor(slice_scatter_default_1, 0, 0, 1056)
    slice_scatter_default_2: "f32[1056]" = torch.ops.aten.slice_scatter.default(slice_tensor_1, getitem_3, 0, 1024, 1056); slice_tensor_1 = getitem_3 = None
    slice_scatter_default_3: "f32[2112]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_1, slice_scatter_default_2, 0, 0, 1056); slice_scatter_default_1 = slice_scatter_default_2 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor(
    slice_10: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_7, 0, 4176, 8352); slice_scatter_default_7 = None
    all_gather_into_tensor: "f32[8352]" = torch.ops._c10d_functional.all_gather_into_tensor.default(slice_10, 2, '0'); slice_10 = None
    slice_6: "f32[1056]" = torch.ops.aten.slice.Tensor(slice_scatter_default_3, 0, 0, 1056); slice_scatter_default_3 = None
    all_gather_into_tensor: "f32[2112]" = torch.ops._c10d_functional.all_gather_into_tensor.default(slice_6, 2, '0'); slice_6 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined]
    wait_tensor: "f32[8352]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor); all_gather_into_tensor = None
    wait_tensor: "f32[2112]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor); all_gather_into_tensor = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
    view_1: "f32[2, 4176]" = torch.ops.aten.reshape.default(wait_tensor, [2, -1]); wait_tensor = None
    split_with_sizes_6 = torch.ops.aten.split_with_sizes.default(view_1, [2048, 64, 2048, 16], 1); view_1 = None
    getitem_28: "f32[2, 2048]" = split_with_sizes_6[0]
    clone: "f32[2, 2048]" = torch.ops.aten.clone.default(getitem_28, memory_format = torch.contiguous_format); getitem_28 = None
    view_2: "f32[4096]" = torch.ops.aten.reshape.default(clone, [4096]); clone = None
    view_1: "f32[2, 1056]" = torch.ops.aten.reshape.default(wait_tensor, [2, -1]); wait_tensor = None
    split_with_sizes_4 = torch.ops.aten.split_with_sizes.default(view_1, [1024, 32], 1); view_1 = None
    getitem_10: "f32[2, 1024]" = split_with_sizes_4[0]
    clone: "f32[2, 1024]" = torch.ops.aten.clone.default(getitem_10, memory_format = torch.contiguous_format); getitem_10 = None
    view_2: "f32[2048]" = torch.ops.aten.reshape.default(clone, [2048]); clone = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
    as_strided: "f32[128, 32]" = torch.ops.aten.as_strided.default(view_2, [128, 32], [32, 1], 0); view_2 = None
    as_strided: "f32[64, 32]" = torch.ops.aten.as_strided.default(view_2, [64, 32], [32, 1], 0); view_2 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
    getitem_33: "f32[2, 64]" = split_with_sizes_6[1]
    clone_1: "f32[2, 64]" = torch.ops.aten.clone.default(getitem_33, memory_format = torch.contiguous_format); getitem_33 = None
    view_4: "f32[128]" = torch.ops.aten.reshape.default(clone_1, [128]); clone_1 = None
    getitem_13: "f32[2, 32]" = split_with_sizes_4[1]; split_with_sizes_4 = None
    clone_1: "f32[2, 32]" = torch.ops.aten.clone.default(getitem_13, memory_format = torch.contiguous_format); getitem_13 = None
    view_4: "f32[64]" = torch.ops.aten.reshape.default(clone_1, [64]); clone_1 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
    as_strided_1: "f32[128]" = torch.ops.aten.as_strided.default(view_4, [128], [1], 0); view_4 = None
    as_strided_1: "f32[64]" = torch.ops.aten.as_strided.default(view_4, [64], [1], 0); view_4 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
    getitem_38: "f32[2, 2048]" = split_with_sizes_6[2]
    clone_2: "f32[2, 2048]" = torch.ops.aten.clone.default(getitem_38, memory_format = torch.contiguous_format); getitem_38 = None
    view_6: "f32[4096]" = torch.ops.aten.reshape.default(clone_2, [4096]); clone_2 = None
    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:138 in foreach_all_gather_copy_out, code: torch._foreach_copy_(out, splits_unpadded)
    _foreach_copy_1 = torch.ops.aten._foreach_copy.default([primals_4, primals_5], [as_strided, as_strided_1]); primals_4 = primals_5 = as_strided = as_strided_1 = None
    getitem_14: "f32[64, 32]" = _foreach_copy_1[0]
    getitem_15: "f32[64]" = _foreach_copy_1[1]; _foreach_copy_1 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
    as_strided_2: "f32[32, 128]" = torch.ops.aten.as_strided.default(view_6, [32, 128], [128, 1], 0); view_6 = None
    # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
    permute_1: "f32[32, 64]" = torch.ops.aten.permute.default(getitem_14, [1, 0]); getitem_14 = None
    addmm: "f32[8, 64]" = torch.ops.aten.addmm.default(getitem_15, primals_1, permute_1); getitem_15 = permute_1 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
    getitem_43: "f32[2, 16]" = split_with_sizes_6[3]; split_with_sizes_6 = None
    clone_3: "f32[2, 16]" = torch.ops.aten.clone.default(getitem_43, memory_format = torch.contiguous_format); getitem_43 = None
    view_8: "f32[32]" = torch.ops.aten.reshape.default(clone_3, [32]); clone_3 = None
    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:47 in foreach_all_gather, code: all_gather_output = torch.empty(
    empty_1: "f32[8320]" = torch.ops.aten.empty.memory_format([8320], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False)

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
    as_strided_3: "f32[32]" = torch.ops.aten.as_strided.default(view_8, [32], [1], 0); view_8 = None
    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:50 in foreach_all_gather, code: all_gather_input = all_gather_output.narrow(
    slice_7: "f32[4160]" = torch.ops.aten.slice.Tensor(empty_1, 0, 0, 4160)

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:138 in foreach_all_gather_copy_out, code: torch._foreach_copy_(out, splits_unpadded)
    _foreach_copy_1 = torch.ops.aten._foreach_copy.default([primals_6, primals_7, primals_8, primals_9], [as_strided, as_strided_1, as_strided_2, as_strided_3]); primals_6 = primals_7 = primals_9 = as_strided = as_strided_1 = as_strided_2 = as_strided_3 = None
    getitem_44: "f32[128, 32]" = _foreach_copy_1[0]
    getitem_45: "f32[128]" = _foreach_copy_1[1]
    getitem_46: "f32[32, 128]" = _foreach_copy_1[2]
    getitem_47: "f32[32]" = _foreach_copy_1[3]; _foreach_copy_1 = None
    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:53 in foreach_all_gather, code: foreach_copy_dsts = torch.split(all_gather_input, inp_split_sizes)
    split_with_sizes_6 = torch.ops.aten.split_with_sizes.default(slice_7, [4096, 64]); slice_7 = None
    getitem_16: "f32[4096]" = split_with_sizes_6[0]
    getitem_17: "f32[64]" = split_with_sizes_6[1]; split_with_sizes_6 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
    permute_1: "f32[32, 128]" = torch.ops.aten.permute.default(getitem_44, [1, 0]); getitem_44 = None
    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
    _foreach_copy_2 = torch.ops.aten._foreach_copy.default([getitem_16, getitem_17], [primals_7, primals_8]); getitem_16 = getitem_17 = primals_7 = primals_8 = None
    getitem_18: "f32[4096]" = _foreach_copy_2[0]
    getitem_19: "f32[64]" = _foreach_copy_2[1]; _foreach_copy_2 = None

    # No stacktrace found for following nodes
    mm_default_5: "f32[8, 128]" = torch.ops.aten.mm.default(primals_1, permute_1); permute_1 = None
    add_tensor_5: "f32[8, 128]" = torch.ops.aten.add.Tensor(mm_default_5, getitem_45); mm_default_5 = getitem_45 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/testing/_internal/common_fsdp.py:855 in forward, code: z = F.relu(z)
    relu: "f32[8, 128]" = torch.ops.aten.relu.default(add_tensor_5); add_tensor_5 = None
    slice_tensor_2: "f32[4160]" = torch.ops.aten.slice.Tensor(empty_1, 0, 0, 4160)
    slice_scatter_default_4: "f32[4160]" = torch.ops.aten.slice_scatter.default(slice_tensor_2, getitem_18, 0, 0, 4096); slice_tensor_2 = getitem_18 = None
    slice_scatter_default_5: "f32[8320]" = torch.ops.aten.slice_scatter.default(empty_1, slice_scatter_default_4, 0, 0, 4160); empty_1 = slice_scatter_default_4 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
    permute_3: "f32[128, 32]" = torch.ops.aten.permute.default(getitem_46, [1, 0]); getitem_46 = None
    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
    slice_9: "f32[4160]" = torch.ops.aten.slice.Tensor(slice_scatter_default_5, 0, 0, 4160)

    # No stacktrace found for following nodes
    mm_default_4: "f32[8, 32]" = torch.ops.aten.mm.default(relu, permute_3); permute_3 = None
    add_tensor_4: "f32[8, 32]" = torch.ops.aten.add.Tensor(mm_default_4, getitem_47); mm_default_4 = getitem_47 = None
    slice_tensor_3: "f32[4160]" = torch.ops.aten.slice.Tensor(slice_scatter_default_5, 0, 0, 4160)
    slice_scatter_default_6: "f32[4160]" = torch.ops.aten.slice_scatter.default(slice_tensor_3, getitem_19, 0, 4096, 4160); slice_tensor_3 = getitem_19 = None
    slice_scatter_default_7: "f32[8320]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_5, slice_scatter_default_6, 0, 0, 4160); slice_scatter_default_5 = slice_scatter_default_6 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/testing/_internal/common_fsdp.py:857 in forward, code: z = F.relu(z)
    relu_1: "f32[8, 32]" = torch.ops.aten.relu.default(add_tensor_4); add_tensor_4 = None
    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor(
    slice_12: "f32[4160]" = torch.ops.aten.slice.Tensor(slice_scatter_default_7, 0, 0, 4160); slice_scatter_default_7 = None
    all_gather_into_tensor_1: "f32[8320]" = torch.ops._c10d_functional.all_gather_into_tensor.default(slice_12, 2, '0'); slice_12 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
    _foreach_copy_2 = torch.ops.aten._foreach_copy.default([getitem, getitem_1, getitem_2, getitem_3], [primals_11, primals_12, primals_13, primals_14]); primals_11 = primals_12 = primals_13 = primals_14 = None
    getitem_52: "f32[2048]" = _foreach_copy_2[0]
    getitem_53: "f32[64]" = _foreach_copy_2[1]
    getitem_54: "f32[2048]" = _foreach_copy_2[2]
    getitem_55: "f32[16]" = _foreach_copy_2[3]; _foreach_copy_2 = None
    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined]
    wait_tensor_1: "f32[8320]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_1); all_gather_into_tensor_1 = None

    # No stacktrace found for following nodes
    slice_tensor_4: "f32[4176]" = torch.ops.aten.slice.Tensor(empty, 0, 4176, 8352)
    slice_scatter_default_8: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_4, getitem_52, 0, 0, 2048); slice_tensor_4 = getitem_52 = None
    slice_scatter_default_9: "f32[8352]" = torch.ops.aten.slice_scatter.default(empty, slice_scatter_default_8, 0, 4176, 8352); slice_scatter_default_8 = None
    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
    view_6: "f32[2, 4160]" = torch.ops.aten.reshape.default(wait_tensor_1, [2, -1]); wait_tensor_1 = None
    split_with_sizes_10 = torch.ops.aten.split_with_sizes.default(view_6, [4096, 64], 1); view_6 = None
    getitem_26: "f32[2, 4096]" = split_with_sizes_10[0]
    clone_2: "f32[2, 4096]" = torch.ops.aten.clone.default(getitem_26, memory_format = torch.contiguous_format); getitem_26 = None
    view_7: "f32[8192]" = torch.ops.aten.reshape.default(clone_2, [8192]); clone_2 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
    slice_13: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_9, 0, 4176, 8352)
    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
    as_strided_2: "f32[128, 64]" = torch.ops.aten.as_strided.default(view_7, [128, 64], [64, 1], 0); view_7 = None

    # No stacktrace found for following nodes
    slice_tensor_5: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_9, 0, 4176, 8352)
    slice_scatter_default_10: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_5, getitem_53, 0, 2048, 2112); slice_tensor_5 = getitem_53 = None
    slice_scatter_default_11: "f32[8352]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_9, slice_scatter_default_10, 0, 4176, 8352); slice_scatter_default_9 = slice_scatter_default_10 = None
    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
    getitem_29: "f32[2, 64]" = split_with_sizes_10[1]; split_with_sizes_10 = None
    clone_3: "f32[2, 64]" = torch.ops.aten.clone.default(getitem_29, memory_format = torch.contiguous_format); getitem_29 = None
    view_9: "f32[128]" = torch.ops.aten.reshape.default(clone_3, [128]); clone_3 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
    as_strided_3: "f32[128]" = torch.ops.aten.as_strided.default(view_9, [128], [1], 0); view_9 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:138 in foreach_all_gather_copy_out, code: torch._foreach_copy_(out, splits_unpadded)
    _foreach_copy_3 = torch.ops.aten._foreach_copy.default([primals_9, primals_10], [as_strided_2, as_strided_3]); primals_10 = as_strided_2 = as_strided_3 = None
    getitem_30: "f32[128, 64]" = _foreach_copy_3[0]
    getitem_31: "f32[128]" = _foreach_copy_3[1]; _foreach_copy_3 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
    permute_3: "f32[64, 128]" = torch.ops.aten.permute.default(getitem_30, [1, 0]); getitem_30 = None
    addmm_1: "f32[8, 128]" = torch.ops.aten.addmm.default(getitem_31, addmm, permute_3); getitem_31 = permute_3 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:47 in foreach_all_gather, code: all_gather_output = torch.empty(
    empty_2: "f32[33024]" = torch.ops.aten.empty.memory_format([33024], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False)

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:50 in foreach_all_gather, code: all_gather_input = all_gather_output.narrow(
    slice_13: "f32[16512]" = torch.ops.aten.slice.Tensor(empty_2, 0, 0, 16512)

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:53 in foreach_all_gather, code: foreach_copy_dsts = torch.split(all_gather_input, inp_split_sizes)
    split_with_sizes_12 = torch.ops.aten.split_with_sizes.default(slice_13, [16384, 128]); slice_13 = None
    getitem_32: "f32[16384]" = split_with_sizes_12[0]
    getitem_33: "f32[128]" = split_with_sizes_12[1]; split_with_sizes_12 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
    slice_14: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_11, 0, 4176, 8352)
    _foreach_copy_4 = torch.ops.aten._foreach_copy.default([getitem_32, getitem_33], [primals_11, primals_12]); getitem_32 = getitem_33 = primals_11 = primals_12 = None
    getitem_34: "f32[16384]" = _foreach_copy_4[0]
    getitem_35: "f32[128]" = _foreach_copy_4[1]; _foreach_copy_4 = None

    # No stacktrace found for following nodes
    slice_tensor_6: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_11, 0, 4176, 8352)
    slice_scatter_default_12: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_6, getitem_54, 0, 2112, 4160); slice_tensor_6 = getitem_54 = None
    slice_scatter_default_13: "f32[8352]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_11, slice_scatter_default_12, 0, 4176, 8352); slice_scatter_default_11 = slice_scatter_default_12 = None
    slice_tensor_4: "f32[16512]" = torch.ops.aten.slice.Tensor(empty_2, 0, 0, 16512)
    slice_scatter_default_8: "f32[16512]" = torch.ops.aten.slice_scatter.default(slice_tensor_4, getitem_34, 0, 0, 16384); slice_tensor_4 = getitem_34 = None
    slice_scatter_default_9: "f32[33024]" = torch.ops.aten.slice_scatter.default(empty_2, slice_scatter_default_8, 0, 0, 16512); empty_2 = slice_scatter_default_8 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
    slice_15: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_13, 0, 4176, 8352)
    slice_15: "f32[16512]" = torch.ops.aten.slice.Tensor(slice_scatter_default_9, 0, 0, 16512)

    # No stacktrace found for following nodes
    slice_tensor_7: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_13, 0, 4176, 8352)
    slice_scatter_default_14: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_7, getitem_55, 0, 4160, 4176); slice_tensor_7 = getitem_55 = None
    slice_scatter_default_15: "f32[8352]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_13, slice_scatter_default_14, 0, 4176, 8352); slice_scatter_default_13 = slice_scatter_default_14 = None
    slice_tensor_5: "f32[16512]" = torch.ops.aten.slice.Tensor(slice_scatter_default_9, 0, 0, 16512)
    slice_scatter_default_10: "f32[16512]" = torch.ops.aten.slice_scatter.default(slice_tensor_5, getitem_35, 0, 16384, 16512); slice_tensor_5 = getitem_35 = None
    slice_scatter_default_11: "f32[33024]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_9, slice_scatter_default_10, 0, 0, 16512); slice_scatter_default_9 = slice_scatter_default_10 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor(
    slice_20: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_15, 0, 4176, 8352); slice_scatter_default_15 = None
    all_gather_into_tensor_1: "f32[8352]" = torch.ops._c10d_functional.all_gather_into_tensor.default(slice_20, 2, '0'); slice_20 = None
    slice_18: "f32[16512]" = torch.ops.aten.slice.Tensor(slice_scatter_default_11, 0, 0, 16512); slice_scatter_default_11 = None
    all_gather_into_tensor_2: "f32[33024]" = torch.ops._c10d_functional.all_gather_into_tensor.default(slice_18, 2, '0'); slice_18 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined]
    wait_tensor_1: "f32[8352]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_1); all_gather_into_tensor_1 = None
    wait_tensor_2: "f32[33024]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_2); all_gather_into_tensor_2 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
    view_10: "f32[2, 4176]" = torch.ops.aten.reshape.default(wait_tensor_1, [2, -1]); wait_tensor_1 = None
    split_with_sizes_16 = torch.ops.aten.split_with_sizes.default(view_10, [2048, 64, 2048, 16], 1); view_10 = None
    getitem_76: "f32[2, 2048]" = split_with_sizes_16[0]
    clone_4: "f32[2, 2048]" = torch.ops.aten.clone.default(getitem_76, memory_format = torch.contiguous_format); getitem_76 = None
    view_11: "f32[4096]" = torch.ops.aten.reshape.default(clone_4, [4096]); clone_4 = None
    view_11: "f32[2, 16512]" = torch.ops.aten.reshape.default(wait_tensor_2, [2, -1]); wait_tensor_2 = None
    split_with_sizes_16 = torch.ops.aten.split_with_sizes.default(view_11, [16384, 128], 1); view_11 = None
    getitem_42: "f32[2, 16384]" = split_with_sizes_16[0]
    clone_4: "f32[2, 16384]" = torch.ops.aten.clone.default(getitem_42, memory_format = torch.contiguous_format); getitem_42 = None
    view_12: "f32[32768]" = torch.ops.aten.reshape.default(clone_4, [32768]); clone_4 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
    as_strided_4: "f32[128, 32]" = torch.ops.aten.as_strided.default(view_11, [128, 32], [32, 1], 0); view_11 = None
    as_strided_4: "f32[256, 128]" = torch.ops.aten.as_strided.default(view_12, [256, 128], [128, 1], 0); view_12 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
    getitem_81: "f32[2, 64]" = split_with_sizes_16[1]
    clone_5: "f32[2, 64]" = torch.ops.aten.clone.default(getitem_81, memory_format = torch.contiguous_format); getitem_81 = None
    view_13: "f32[128]" = torch.ops.aten.reshape.default(clone_5, [128]); clone_5 = None
    getitem_45: "f32[2, 128]" = split_with_sizes_16[1]; split_with_sizes_16 = None
    clone_5: "f32[2, 128]" = torch.ops.aten.clone.default(getitem_45, memory_format = torch.contiguous_format); getitem_45 = None
    view_14: "f32[256]" = torch.ops.aten.reshape.default(clone_5, [256]); clone_5 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
    as_strided_5: "f32[128]" = torch.ops.aten.as_strided.default(view_13, [128], [1], 0); view_13 = None
    as_strided_5: "f32[256]" = torch.ops.aten.as_strided.default(view_14, [256], [1], 0); view_14 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
    getitem_86: "f32[2, 2048]" = split_with_sizes_16[2]
    clone_6: "f32[2, 2048]" = torch.ops.aten.clone.default(getitem_86, memory_format = torch.contiguous_format); getitem_86 = None
    view_15: "f32[4096]" = torch.ops.aten.reshape.default(clone_6, [4096]); clone_6 = None
    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:138 in foreach_all_gather_copy_out, code: torch._foreach_copy_(out, splits_unpadded)
    _foreach_copy_5 = torch.ops.aten._foreach_copy.default([primals_13, primals_14], [as_strided_4, as_strided_5]); primals_14 = as_strided_4 = as_strided_5 = None
    getitem_46: "f32[256, 128]" = _foreach_copy_5[0]
    getitem_47: "f32[256]" = _foreach_copy_5[1]; _foreach_copy_5 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
    as_strided_6: "f32[32, 128]" = torch.ops.aten.as_strided.default(view_15, [32, 128], [128, 1], 0); view_15 = None
    # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
    permute_5: "f32[128, 256]" = torch.ops.aten.permute.default(getitem_46, [1, 0]); getitem_46 = None
    addmm_2: "f32[8, 256]" = torch.ops.aten.addmm.default(getitem_47, addmm_1, permute_5); getitem_47 = permute_5 = None
    return [addmm_2, primals_1, primals_9, primals_13, addmm, addmm_1]

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
    getitem_91: "f32[2, 16]" = split_with_sizes_16[3]; split_with_sizes_16 = None
    clone_7: "f32[2, 16]" = torch.ops.aten.clone.default(getitem_91, memory_format = torch.contiguous_format); getitem_91 = None
    view_17: "f32[32]" = torch.ops.aten.reshape.default(clone_7, [32]); clone_7 = None

    TRACED GRAPH
    ===== AFTER POST GRAD =====
    /data/users/willfeng/pytorch_yf225/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[8, 32]", primals_2: "f32[1024]", primals_3: "f32[32]", primals_4: "f32[64, 32]", primals_5: "f32[64]", primals_6, primals_7: "f32[4096]", primals_8: "f32[64]", primals_9: "f32[128, 64]", primals_10: "f32[128]", primals_11: "f32[16384]", primals_12: "f32[128]", primals_13: "f32[256, 128]", primals_14: "f32[256]"):
    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:47 in foreach_all_gather, code: all_gather_output = torch.empty(
    empty: "f32[2112]" = torch.ops.aten.empty.memory_format([2112], dtype = torch.float32, device = device(type='cuda', index=1), pin_memory = False)

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
    as_strided_7: "f32[32]" = torch.ops.aten.as_strided.default(view_17, [32], [1], 0); view_17 = None
    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:50 in foreach_all_gather, code: all_gather_input = all_gather_output.narrow(
    slice_1: "f32[1056]" = torch.ops.aten.slice.Tensor(empty, 0, 1056, 2112)

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:138 in foreach_all_gather_copy_out, code: torch._foreach_copy_(out, splits_unpadded)
    _foreach_copy_3 = torch.ops.aten._foreach_copy.default([primals_15, primals_16, primals_17, primals_18], [as_strided_4, as_strided_5, as_strided_6, as_strided_7]); primals_16 = primals_18 = as_strided_4 = as_strided_5 = as_strided_6 = as_strided_7 = None
    getitem_92: "f32[128, 32]" = _foreach_copy_3[0]
    getitem_93: "f32[128]" = _foreach_copy_3[1]
    getitem_94: "f32[32, 128]" = _foreach_copy_3[2]
    getitem_95: "f32[32]" = _foreach_copy_3[3]; _foreach_copy_3 = None
    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:53 in foreach_all_gather, code: foreach_copy_dsts = torch.split(all_gather_input, inp_split_sizes)
    split_with_sizes = torch.ops.aten.split_with_sizes.default(slice_1, [1024, 32]); slice_1 = None
    getitem: "f32[1024]" = split_with_sizes[0]
    getitem_1: "f32[32]" = split_with_sizes[1]; split_with_sizes = None

    # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
    permute_5: "f32[32, 128]" = torch.ops.aten.permute.default(getitem_92, [1, 0]); getitem_92 = None
    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
    _foreach_copy = torch.ops.aten._foreach_copy.default([getitem, getitem_1], [primals_2, primals_3]); getitem = getitem_1 = primals_2 = primals_3 = None
    getitem_2: "f32[1024]" = _foreach_copy[0]
    getitem_3: "f32[32]" = _foreach_copy[1]; _foreach_copy = None

    # No stacktrace found for following nodes
    mm_default_3: "f32[8, 128]" = torch.ops.aten.mm.default(relu_1, permute_5); permute_5 = None
    add_tensor_3: "f32[8, 128]" = torch.ops.aten.add.Tensor(mm_default_3, getitem_93); mm_default_3 = getitem_93 = None
    slice_tensor: "f32[1056]" = torch.ops.aten.slice.Tensor(empty, 0, 1056, 2112)
    slice_scatter_default: "f32[1056]" = torch.ops.aten.slice_scatter.default(slice_tensor, getitem_2, 0, 0, 1024); slice_tensor = getitem_2 = None
    slice_scatter_default_1: "f32[2112]" = torch.ops.aten.slice_scatter.default(empty, slice_scatter_default, 0, 1056, 2112); empty = slice_scatter_default = None

    # File: /data/users/willfeng/pytorch_yf225/torch/testing/_internal/common_fsdp.py:855 in forward, code: z = F.relu(z)
    relu_2: "f32[8, 128]" = torch.ops.aten.relu.default(add_tensor_3); add_tensor_3 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
    permute_7: "f32[128, 32]" = torch.ops.aten.permute.default(getitem_94, [1, 0]); getitem_94 = None
    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
    slice_3: "f32[1056]" = torch.ops.aten.slice.Tensor(slice_scatter_default_1, 0, 1056, 2112)

    # No stacktrace found for following nodes
    mm_default_2: "f32[8, 32]" = torch.ops.aten.mm.default(relu_2, permute_7); permute_7 = None
    add_tensor_2: "f32[8, 32]" = torch.ops.aten.add.Tensor(mm_default_2, getitem_95); mm_default_2 = getitem_95 = None
    slice_tensor_1: "f32[1056]" = torch.ops.aten.slice.Tensor(slice_scatter_default_1, 0, 1056, 2112)
    slice_scatter_default_2: "f32[1056]" = torch.ops.aten.slice_scatter.default(slice_tensor_1, getitem_3, 0, 1024, 1056); slice_tensor_1 = getitem_3 = None
    slice_scatter_default_3: "f32[2112]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_1, slice_scatter_default_2, 0, 1056, 2112); slice_scatter_default_1 = slice_scatter_default_2 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/testing/_internal/common_fsdp.py:857 in forward, code: z = F.relu(z)
    relu_3: "f32[8, 32]" = torch.ops.aten.relu.default(add_tensor_2); add_tensor_2 = None
    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor(
    slice_6: "f32[1056]" = torch.ops.aten.slice.Tensor(slice_scatter_default_3, 0, 1056, 2112); slice_scatter_default_3 = None
    all_gather_into_tensor: "f32[2112]" = torch.ops._c10d_functional.all_gather_into_tensor.default(slice_6, 2, '0'); slice_6 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
    _foreach_copy_4 = torch.ops.aten._foreach_copy.default([getitem, getitem_1, getitem_2, getitem_3], [primals_19, primals_20, primals_21, primals_22]); getitem = getitem_1 = getitem_2 = getitem_3 = primals_19 = primals_20 = primals_21 = primals_22 = None
    getitem_100: "f32[2048]" = _foreach_copy_4[0]
    getitem_101: "f32[64]" = _foreach_copy_4[1]
    getitem_102: "f32[2048]" = _foreach_copy_4[2]
    getitem_103: "f32[16]" = _foreach_copy_4[3]; _foreach_copy_4 = None
    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined]
    wait_tensor: "f32[2112]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor); all_gather_into_tensor = None

    # No stacktrace found for following nodes
    slice_tensor_8: "f32[4176]" = torch.ops.aten.slice.Tensor(empty, 0, 4176, 8352)
    slice_scatter_default_16: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_8, getitem_100, 0, 0, 2048); slice_tensor_8 = getitem_100 = None
    slice_scatter_default_17: "f32[8352]" = torch.ops.aten.slice_scatter.default(empty, slice_scatter_default_16, 0, 4176, 8352); empty = slice_scatter_default_16 = None
    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
    view_1: "f32[2, 1056]" = torch.ops.aten.reshape.default(wait_tensor, [2, -1]); wait_tensor = None
    split_with_sizes_4 = torch.ops.aten.split_with_sizes.default(view_1, [1024, 32], 1); view_1 = None
    getitem_10: "f32[2, 1024]" = split_with_sizes_4[0]
    clone: "f32[2, 1024]" = torch.ops.aten.clone.default(getitem_10, memory_format = torch.contiguous_format); getitem_10 = None
    view_2: "f32[2048]" = torch.ops.aten.reshape.default(clone, [2048]); clone = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
    slice_23: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_17, 0, 4176, 8352)
    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
    as_strided: "f32[64, 32]" = torch.ops.aten.as_strided.default(view_2, [64, 32], [32, 1], 0); view_2 = None

    # No stacktrace found for following nodes
    slice_tensor_9: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_17, 0, 4176, 8352)
    slice_scatter_default_18: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_9, getitem_101, 0, 2048, 2112); slice_tensor_9 = getitem_101 = None
    slice_scatter_default_19: "f32[8352]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_17, slice_scatter_default_18, 0, 4176, 8352); slice_scatter_default_17 = slice_scatter_default_18 = None
    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
    getitem_13: "f32[2, 32]" = split_with_sizes_4[1]; split_with_sizes_4 = None
    clone_1: "f32[2, 32]" = torch.ops.aten.clone.default(getitem_13, memory_format = torch.contiguous_format); getitem_13 = None
    view_4: "f32[64]" = torch.ops.aten.reshape.default(clone_1, [64]); clone_1 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
    as_strided_1: "f32[64]" = torch.ops.aten.as_strided.default(view_4, [64], [1], 0); view_4 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:138 in foreach_all_gather_copy_out, code: torch._foreach_copy_(out, splits_unpadded)
    _foreach_copy_1 = torch.ops.aten._foreach_copy.default([primals_4, primals_5], [as_strided, as_strided_1]); primals_4 = primals_5 = as_strided = as_strided_1 = None
    getitem_14: "f32[64, 32]" = _foreach_copy_1[0]
    getitem_15: "f32[64]" = _foreach_copy_1[1]; _foreach_copy_1 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
    permute_1: "f32[32, 64]" = torch.ops.aten.permute.default(getitem_14, [1, 0]); getitem_14 = None
    addmm: "f32[8, 64]" = torch.ops.aten.addmm.default(getitem_15, primals_1, permute_1); getitem_15 = permute_1 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:47 in foreach_all_gather, code: all_gather_output = torch.empty(
    empty_1: "f32[8320]" = torch.ops.aten.empty.memory_format([8320], dtype = torch.float32, device = device(type='cuda', index=1), pin_memory = False)

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:50 in foreach_all_gather, code: all_gather_input = all_gather_output.narrow(
    slice_7: "f32[4160]" = torch.ops.aten.slice.Tensor(empty_1, 0, 4160, 8320)

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:53 in foreach_all_gather, code: foreach_copy_dsts = torch.split(all_gather_input, inp_split_sizes)
    split_with_sizes_6 = torch.ops.aten.split_with_sizes.default(slice_7, [4096, 64]); slice_7 = None
    getitem_16: "f32[4096]" = split_with_sizes_6[0]
    getitem_17: "f32[64]" = split_with_sizes_6[1]; split_with_sizes_6 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
    slice_24: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_19, 0, 4176, 8352)
    _foreach_copy_2 = torch.ops.aten._foreach_copy.default([getitem_16, getitem_17], [primals_7, primals_8]); getitem_16 = getitem_17 = primals_7 = primals_8 = None
    getitem_18: "f32[4096]" = _foreach_copy_2[0]
    getitem_19: "f32[64]" = _foreach_copy_2[1]; _foreach_copy_2 = None

    # No stacktrace found for following nodes
    slice_tensor_10: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_19, 0, 4176, 8352)
    slice_scatter_default_20: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_10, getitem_102, 0, 2112, 4160); slice_tensor_10 = getitem_102 = None
    slice_scatter_default_21: "f32[8352]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_19, slice_scatter_default_20, 0, 4176, 8352); slice_scatter_default_19 = slice_scatter_default_20 = None
    slice_tensor_2: "f32[4160]" = torch.ops.aten.slice.Tensor(empty_1, 0, 4160, 8320)
    slice_scatter_default_4: "f32[4160]" = torch.ops.aten.slice_scatter.default(slice_tensor_2, getitem_18, 0, 0, 4096); slice_tensor_2 = getitem_18 = None
    slice_scatter_default_5: "f32[8320]" = torch.ops.aten.slice_scatter.default(empty_1, slice_scatter_default_4, 0, 4160, 8320); empty_1 = slice_scatter_default_4 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
    slice_25: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_21, 0, 4176, 8352)
    slice_9: "f32[4160]" = torch.ops.aten.slice.Tensor(slice_scatter_default_5, 0, 4160, 8320)

    # No stacktrace found for following nodes
    slice_tensor_11: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_21, 0, 4176, 8352)
    slice_scatter_default_22: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_11, getitem_103, 0, 4160, 4176); slice_tensor_11 = getitem_103 = None
    slice_scatter_default_23: "f32[8352]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_21, slice_scatter_default_22, 0, 4176, 8352); slice_scatter_default_21 = slice_scatter_default_22 = None
    slice_tensor_3: "f32[4160]" = torch.ops.aten.slice.Tensor(slice_scatter_default_5, 0, 4160, 8320)
    slice_scatter_default_6: "f32[4160]" = torch.ops.aten.slice_scatter.default(slice_tensor_3, getitem_19, 0, 4096, 4160); slice_tensor_3 = getitem_19 = None
    slice_scatter_default_7: "f32[8320]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_5, slice_scatter_default_6, 0, 4160, 8320); slice_scatter_default_5 = slice_scatter_default_6 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor(
    slice_30: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_23, 0, 4176, 8352); slice_scatter_default_23 = None
    all_gather_into_tensor_2: "f32[8352]" = torch.ops._c10d_functional.all_gather_into_tensor.default(slice_30, 2, '0'); slice_30 = None
    slice_12: "f32[4160]" = torch.ops.aten.slice.Tensor(slice_scatter_default_7, 0, 4160, 8320); slice_scatter_default_7 = None
    all_gather_into_tensor_1: "f32[8320]" = torch.ops._c10d_functional.all_gather_into_tensor.default(slice_12, 2, '0'); slice_12 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined]
    wait_tensor_2: "f32[8352]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_2); all_gather_into_tensor_2 = None
    wait_tensor_1: "f32[8320]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_1); all_gather_into_tensor_1 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
    view_19: "f32[2, 4176]" = torch.ops.aten.reshape.default(wait_tensor_2, [2, -1]); wait_tensor_2 = None
    split_with_sizes_26 = torch.ops.aten.split_with_sizes.default(view_19, [2048, 64, 2048, 16], 1); view_19 = None
    getitem_124: "f32[2, 2048]" = split_with_sizes_26[0]
    clone_8: "f32[2, 2048]" = torch.ops.aten.clone.default(getitem_124, memory_format = torch.contiguous_format); getitem_124 = None
    view_20: "f32[4096]" = torch.ops.aten.reshape.default(clone_8, [4096]); clone_8 = None
    view_6: "f32[2, 4160]" = torch.ops.aten.reshape.default(wait_tensor_1, [2, -1]); wait_tensor_1 = None
    split_with_sizes_10 = torch.ops.aten.split_with_sizes.default(view_6, [4096, 64], 1); view_6 = None
    getitem_26: "f32[2, 4096]" = split_with_sizes_10[0]
    clone_2: "f32[2, 4096]" = torch.ops.aten.clone.default(getitem_26, memory_format = torch.contiguous_format); getitem_26 = None
    view_7: "f32[8192]" = torch.ops.aten.reshape.default(clone_2, [8192]); clone_2 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
    as_strided_8: "f32[128, 32]" = torch.ops.aten.as_strided.default(view_20, [128, 32], [32, 1], 0); view_20 = None
    as_strided_2: "f32[128, 64]" = torch.ops.aten.as_strided.default(view_7, [128, 64], [64, 1], 0); view_7 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
    getitem_129: "f32[2, 64]" = split_with_sizes_26[1]
    clone_9: "f32[2, 64]" = torch.ops.aten.clone.default(getitem_129, memory_format = torch.contiguous_format); getitem_129 = None
    view_22: "f32[128]" = torch.ops.aten.reshape.default(clone_9, [128]); clone_9 = None
    getitem_29: "f32[2, 64]" = split_with_sizes_10[1]; split_with_sizes_10 = None
    clone_3: "f32[2, 64]" = torch.ops.aten.clone.default(getitem_29, memory_format = torch.contiguous_format); getitem_29 = None
    view_9: "f32[128]" = torch.ops.aten.reshape.default(clone_3, [128]); clone_3 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
    as_strided_9: "f32[128]" = torch.ops.aten.as_strided.default(view_22, [128], [1], 0); view_22 = None
    as_strided_3: "f32[128]" = torch.ops.aten.as_strided.default(view_9, [128], [1], 0); view_9 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
    getitem_134: "f32[2, 2048]" = split_with_sizes_26[2]
    clone_10: "f32[2, 2048]" = torch.ops.aten.clone.default(getitem_134, memory_format = torch.contiguous_format); getitem_134 = None
    view_24: "f32[4096]" = torch.ops.aten.reshape.default(clone_10, [4096]); clone_10 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:138 in foreach_all_gather_copy_out, code: torch._foreach_copy_(out, splits_unpadded)
    _foreach_copy_3 = torch.ops.aten._foreach_copy.default([primals_9, primals_10], [as_strided_2, as_strided_3]); primals_10 = as_strided_2 = as_strided_3 = None
    getitem_30: "f32[128, 64]" = _foreach_copy_3[0]
    getitem_31: "f32[128]" = _foreach_copy_3[1]; _foreach_copy_3 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
    as_strided_10: "f32[32, 128]" = torch.ops.aten.as_strided.default(view_24, [32, 128], [128, 1], 0); view_24 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
    permute_3: "f32[64, 128]" = torch.ops.aten.permute.default(getitem_30, [1, 0]); getitem_30 = None
    addmm_1: "f32[8, 128]" = torch.ops.aten.addmm.default(getitem_31, addmm, permute_3); getitem_31 = permute_3 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
    getitem_139: "f32[2, 16]" = split_with_sizes_26[3]; split_with_sizes_26 = None
    clone_11: "f32[2, 16]" = torch.ops.aten.clone.default(getitem_139, memory_format = torch.contiguous_format); getitem_139 = None
    view_26: "f32[32]" = torch.ops.aten.reshape.default(clone_11, [32]); clone_11 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:47 in foreach_all_gather, code: all_gather_output = torch.empty(
    empty_2: "f32[33024]" = torch.ops.aten.empty.memory_format([33024], dtype = torch.float32, device = device(type='cuda', index=1), pin_memory = False)

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
    as_strided_11: "f32[32]" = torch.ops.aten.as_strided.default(view_26, [32], [1], 0); view_26 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:50 in foreach_all_gather, code: all_gather_input = all_gather_output.narrow(
    slice_13: "f32[16512]" = torch.ops.aten.slice.Tensor(empty_2, 0, 16512, 33024)

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:138 in foreach_all_gather_copy_out, code: torch._foreach_copy_(out, splits_unpadded)
    _foreach_copy_5 = torch.ops.aten._foreach_copy.default([primals_23, primals_24, primals_25, primals_26], [as_strided_8, as_strided_9, as_strided_10, as_strided_11]); primals_24 = primals_26 = as_strided_8 = as_strided_9 = as_strided_10 = as_strided_11 = None
    getitem_140: "f32[128, 32]" = _foreach_copy_5[0]
    getitem_141: "f32[128]" = _foreach_copy_5[1]
    getitem_142: "f32[32, 128]" = _foreach_copy_5[2]
    getitem_143: "f32[32]" = _foreach_copy_5[3]; _foreach_copy_5 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:53 in foreach_all_gather, code: foreach_copy_dsts = torch.split(all_gather_input, inp_split_sizes)
    split_with_sizes_12 = torch.ops.aten.split_with_sizes.default(slice_13, [16384, 128]); slice_13 = None
    getitem_32: "f32[16384]" = split_with_sizes_12[0]
    getitem_33: "f32[128]" = split_with_sizes_12[1]; split_with_sizes_12 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
    permute_9: "f32[32, 128]" = torch.ops.aten.permute.default(getitem_140, [1, 0]); getitem_140 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
    _foreach_copy_4 = torch.ops.aten._foreach_copy.default([getitem_32, getitem_33], [primals_11, primals_12]); getitem_32 = getitem_33 = primals_11 = primals_12 = None
    getitem_34: "f32[16384]" = _foreach_copy_4[0]
    getitem_35: "f32[128]" = _foreach_copy_4[1]; _foreach_copy_4 = None

    # No stacktrace found for following nodes
    mm_default_1: "f32[8, 128]" = torch.ops.aten.mm.default(relu_3, permute_9); permute_9 = None
    add_tensor_1: "f32[8, 128]" = torch.ops.aten.add.Tensor(mm_default_1, getitem_141); mm_default_1 = getitem_141 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/testing/_internal/common_fsdp.py:855 in forward, code: z = F.relu(z)
    relu_4: "f32[8, 128]" = torch.ops.aten.relu.default(add_tensor_1); add_tensor_1 = None
    slice_tensor_4: "f32[16512]" = torch.ops.aten.slice.Tensor(empty_2, 0, 16512, 33024)
    slice_scatter_default_8: "f32[16512]" = torch.ops.aten.slice_scatter.default(slice_tensor_4, getitem_34, 0, 0, 16384); slice_tensor_4 = getitem_34 = None
    slice_scatter_default_9: "f32[33024]" = torch.ops.aten.slice_scatter.default(empty_2, slice_scatter_default_8, 0, 16512, 33024); empty_2 = slice_scatter_default_8 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
    permute_11: "f32[128, 32]" = torch.ops.aten.permute.default(getitem_142, [1, 0]); getitem_142 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
    slice_15: "f32[16512]" = torch.ops.aten.slice.Tensor(slice_scatter_default_9, 0, 16512, 33024)

    # No stacktrace found for following nodes
    mm_default: "f32[8, 32]" = torch.ops.aten.mm.default(relu_4, permute_11); permute_11 = None
    add_tensor: "f32[8, 32]" = torch.ops.aten.add.Tensor(mm_default, getitem_143); mm_default = getitem_143 = None
    slice_tensor_5: "f32[16512]" = torch.ops.aten.slice.Tensor(slice_scatter_default_9, 0, 16512, 33024)
    slice_scatter_default_10: "f32[16512]" = torch.ops.aten.slice_scatter.default(slice_tensor_5, getitem_35, 0, 16384, 16512); slice_tensor_5 = getitem_35 = None
    slice_scatter_default_11: "f32[33024]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_9, slice_scatter_default_10, 0, 16512, 33024); slice_scatter_default_9 = slice_scatter_default_10 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor(
    slice_18: "f32[16512]" = torch.ops.aten.slice.Tensor(slice_scatter_default_11, 0, 16512, 33024); slice_scatter_default_11 = None
    all_gather_into_tensor_2: "f32[33024]" = torch.ops._c10d_functional.all_gather_into_tensor.default(slice_18, 2, '0'); slice_18 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined]
    wait_tensor_2: "f32[33024]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_2); all_gather_into_tensor_2 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
    view_11: "f32[2, 16512]" = torch.ops.aten.reshape.default(wait_tensor_2, [2, -1]); wait_tensor_2 = None
    split_with_sizes_16 = torch.ops.aten.split_with_sizes.default(view_11, [16384, 128], 1); view_11 = None
    getitem_42: "f32[2, 16384]" = split_with_sizes_16[0]
    clone_4: "f32[2, 16384]" = torch.ops.aten.clone.default(getitem_42, memory_format = torch.contiguous_format); getitem_42 = None
    view_12: "f32[32768]" = torch.ops.aten.reshape.default(clone_4, [32768]); clone_4 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
    as_strided_4: "f32[256, 128]" = torch.ops.aten.as_strided.default(view_12, [256, 128], [128, 1], 0); view_12 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/testing/_internal/common_fsdp.py:857 in forward, code: z = F.relu(z)
    relu_5: "f32[8, 32]" = torch.ops.aten.relu.default(add_tensor); add_tensor = None
    le: "b8[8, 32]" = torch.ops.aten.le.Scalar(relu_5, 0)
    return [relu_5, primals_1, primals_8, primals_15, primals_17, primals_23, primals_25, relu, relu_1, relu_2, relu_3, relu_4, le]

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
    getitem_45: "f32[2, 128]" = split_with_sizes_16[1]; split_with_sizes_16 = None
    clone_5: "f32[2, 128]" = torch.ops.aten.clone.default(getitem_45, memory_format = torch.contiguous_format); getitem_45 = None
    view_14: "f32[256]" = torch.ops.aten.reshape.default(clone_5, [256]); clone_5 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
    as_strided_5: "f32[256]" = torch.ops.aten.as_strided.default(view_14, [256], [1], 0); view_14 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:138 in foreach_all_gather_copy_out, code: torch._foreach_copy_(out, splits_unpadded)
    _foreach_copy_5 = torch.ops.aten._foreach_copy.default([primals_13, primals_14], [as_strided_4, as_strided_5]); primals_14 = as_strided_4 = as_strided_5 = None
    getitem_46: "f32[256, 128]" = _foreach_copy_5[0]
    getitem_47: "f32[256]" = _foreach_copy_5[1]; _foreach_copy_5 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
    permute_5: "f32[128, 256]" = torch.ops.aten.permute.default(getitem_46, [1, 0]); getitem_46 = None
    addmm_2: "f32[8, 256]" = torch.ops.aten.addmm.default(getitem_47, addmm_1, permute_5); getitem_47 = permute_5 = None
    return [addmm_2, primals_1, primals_9, primals_13, addmm, addmm_1]
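
    # Size check (illustrative sketch, not part of the trace): with the 2-rank group implied by the
    # (2, '0') arguments to all_gather_into_tensor above, the two buffer sizes seen in this region
    # are consistent with sharding a [128, 64] weight plus [128] bias and a [256, 128] weight plus
    # [256] bias across 2 ranks.
    assert 2 * (128 * 64 // 2 + 128 // 2) == 8320    # per-rank 4096 + 64 = 4160 elements
    assert 2 * (256 * 128 // 2 + 256 // 2) == 33024  # per-rank 16384 + 128 = 16512 elements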

  3. yf225 created this gist Mar 27, 2024.
    330 changes: 330 additions & 0 deletions ppfsdp_multi_group_fwd_graph_no_fsdp_fx_passes.txt
    @@ -0,0 +1,330 @@
    TRACED GRAPH
    ===== AFTER POST GRAD =====
    /data/users/willfeng/pytorch_yf225/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[8, 32]", primals_2: "f32[2048]", primals_3: "f32[64]", primals_4: "f32[2048]", primals_5: "f32[16]", primals_6: "f32[128, 32]", primals_7: "f32[128]", primals_8: "f32[32, 128]", primals_9: "f32[32]", primals_10, primals_11: "f32[2048]", primals_12: "f32[64]", primals_13: "f32[2048]", primals_14: "f32[16]", primals_15: "f32[128, 32]", primals_16: "f32[128]", primals_17: "f32[32, 128]", primals_18: "f32[32]", primals_19: "f32[2048]", primals_20: "f32[64]", primals_21: "f32[2048]", primals_22: "f32[16]", primals_23: "f32[128, 32]", primals_24: "f32[128]", primals_25: "f32[32, 128]", primals_26: "f32[32]"):
    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:47 in foreach_all_gather, code: all_gather_output = torch.empty(
    empty: "f32[8352]" = torch.ops.aten.empty.memory_format([8352], dtype = torch.float32, device = device(type='cuda', index=1), pin_memory = False)

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:50 in foreach_all_gather, code: all_gather_input = all_gather_output.narrow(
    slice_1: "f32[4176]" = torch.ops.aten.slice.Tensor(empty, 0, 4176, 8352)

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:53 in foreach_all_gather, code: foreach_copy_dsts = torch.split(all_gather_input, inp_split_sizes)
    split_with_sizes = torch.ops.aten.split_with_sizes.default(slice_1, [2048, 64, 2048, 16]); slice_1 = None
    getitem: "f32[2048]" = split_with_sizes[0]
    getitem_1: "f32[64]" = split_with_sizes[1]
    getitem_2: "f32[2048]" = split_with_sizes[2]
    getitem_3: "f32[16]" = split_with_sizes[3]; split_with_sizes = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
    _foreach_copy = torch.ops.aten._foreach_copy.default([getitem, getitem_1, getitem_2, getitem_3], [primals_2, primals_3, primals_4, primals_5]); primals_2 = primals_3 = primals_4 = primals_5 = None
    getitem_4: "f32[2048]" = _foreach_copy[0]
    getitem_5: "f32[64]" = _foreach_copy[1]
    getitem_6: "f32[2048]" = _foreach_copy[2]
    getitem_7: "f32[16]" = _foreach_copy[3]; _foreach_copy = None
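
    # Eager-mode sketch of the staging step above (illustration, not part of the traced graph).
    # The local shards are 2048 + 64 + 2048 + 16 = 4176 elements; with a 2-rank group the flat
    # all-gather output is 2 * 4176 = 8352, matching empty([8352]), and the narrow at offset 4176
    # is consistent with this being rank 1's slice. Names (shards, rank, world_size) are assumed.
    import torch

    def stage_all_gather_input(shards, rank, world_size, device):
        shard_numel = sum(s.numel() for s in shards)                        # 4176 here
        all_gather_output = torch.empty(world_size * shard_numel, device=device)
        # this rank's slice of the output buffer doubles as the all-gather input
        all_gather_input = all_gather_output.narrow(0, rank * shard_numel, shard_numel)
        dsts = list(torch.split(all_gather_input, [s.numel() for s in shards]))
        torch._foreach_copy_(dsts, [s.reshape(-1) for s in shards])         # the _foreach_copy above
        return all_gather_output, all_gather_input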

    # No stacktrace found for following nodes
    slice_tensor: "f32[4176]" = torch.ops.aten.slice.Tensor(empty, 0, 4176, 8352)
    slice_scatter_default: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor, getitem_4, 0, 0, 2048); slice_tensor = getitem_4 = None
    slice_scatter_default_1: "f32[8352]" = torch.ops.aten.slice_scatter.default(empty, slice_scatter_default, 0, 4176, 8352); slice_scatter_default = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
    slice_3: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_1, 0, 4176, 8352)

    # No stacktrace found for following nodes
    slice_tensor_1: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_1, 0, 4176, 8352)
    slice_scatter_default_2: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_1, getitem_5, 0, 2048, 2112); slice_tensor_1 = getitem_5 = None
    slice_scatter_default_3: "f32[8352]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_1, slice_scatter_default_2, 0, 4176, 8352); slice_scatter_default_1 = slice_scatter_default_2 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
    slice_4: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_3, 0, 4176, 8352)

    # No stacktrace found for following nodes
    slice_tensor_2: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_3, 0, 4176, 8352)
    slice_scatter_default_4: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_2, getitem_6, 0, 2112, 4160); slice_tensor_2 = getitem_6 = None
    slice_scatter_default_5: "f32[8352]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_3, slice_scatter_default_4, 0, 4176, 8352); slice_scatter_default_3 = slice_scatter_default_4 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
    slice_5: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_5, 0, 4176, 8352)

    # No stacktrace found for following nodes
    slice_tensor_3: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_5, 0, 4176, 8352)
    slice_scatter_default_6: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_3, getitem_7, 0, 4160, 4176); slice_tensor_3 = getitem_7 = None
    slice_scatter_default_7: "f32[8352]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_5, slice_scatter_default_6, 0, 4176, 8352); slice_scatter_default_5 = slice_scatter_default_6 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor(
    slice_10: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_7, 0, 4176, 8352); slice_scatter_default_7 = None
    all_gather_into_tensor: "f32[8352]" = torch.ops._c10d_functional.all_gather_into_tensor.default(slice_10, 2, '0'); slice_10 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined]
    wait_tensor: "f32[8352]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor); all_gather_into_tensor = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
    view_1: "f32[2, 4176]" = torch.ops.aten.reshape.default(wait_tensor, [2, -1]); wait_tensor = None
    split_with_sizes_6 = torch.ops.aten.split_with_sizes.default(view_1, [2048, 64, 2048, 16], 1); view_1 = None
    getitem_28: "f32[2, 2048]" = split_with_sizes_6[0]
    clone: "f32[2, 2048]" = torch.ops.aten.clone.default(getitem_28, memory_format = torch.contiguous_format); getitem_28 = None
    view_2: "f32[4096]" = torch.ops.aten.reshape.default(clone, [4096]); clone = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
    as_strided: "f32[128, 32]" = torch.ops.aten.as_strided.default(view_2, [128, 32], [32, 1], 0); view_2 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
    getitem_33: "f32[2, 64]" = split_with_sizes_6[1]
    clone_1: "f32[2, 64]" = torch.ops.aten.clone.default(getitem_33, memory_format = torch.contiguous_format); getitem_33 = None
    view_4: "f32[128]" = torch.ops.aten.reshape.default(clone_1, [128]); clone_1 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
    as_strided_1: "f32[128]" = torch.ops.aten.as_strided.default(view_4, [128], [1], 0); view_4 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
    getitem_38: "f32[2, 2048]" = split_with_sizes_6[2]
    clone_2: "f32[2, 2048]" = torch.ops.aten.clone.default(getitem_38, memory_format = torch.contiguous_format); getitem_38 = None
    view_6: "f32[4096]" = torch.ops.aten.reshape.default(clone_2, [4096]); clone_2 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
    as_strided_2: "f32[32, 128]" = torch.ops.aten.as_strided.default(view_6, [32, 128], [128, 1], 0); view_6 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
    getitem_43: "f32[2, 16]" = split_with_sizes_6[3]; split_with_sizes_6 = None
    clone_3: "f32[2, 16]" = torch.ops.aten.clone.default(getitem_43, memory_format = torch.contiguous_format); getitem_43 = None
    view_8: "f32[32]" = torch.ops.aten.reshape.default(clone_3, [32]); clone_3 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
    as_strided_3: "f32[32]" = torch.ops.aten.as_strided.default(view_8, [32], [1], 0); view_8 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:138 in foreach_all_gather_copy_out, code: torch._foreach_copy_(out, splits_unpadded)
    _foreach_copy_1 = torch.ops.aten._foreach_copy.default([primals_6, primals_7, primals_8, primals_9], [as_strided, as_strided_1, as_strided_2, as_strided_3]); primals_6 = primals_7 = primals_9 = as_strided = as_strided_1 = as_strided_2 = as_strided_3 = None
    getitem_44: "f32[128, 32]" = _foreach_copy_1[0]
    getitem_45: "f32[128]" = _foreach_copy_1[1]
    getitem_46: "f32[32, 128]" = _foreach_copy_1[2]
    getitem_47: "f32[32]" = _foreach_copy_1[3]; _foreach_copy_1 = None
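
    # Eager-mode sketch (illustration only) of the copy-out above: each rank's 4176-element row is
    # split back into the four flat shards ([2048, 64, 2048, 16]), reassembled into the full
    # parameter shapes via as_strided, and copied into the unsharded parameters. Names are assumed.
    import torch

    def copy_out(all_gather_output, world_size, per_rank_split_sizes, params):
        rows = all_gather_output.reshape(world_size, -1)                    # [2, 4176]
        splits = torch.split(rows, per_rank_split_sizes, dim=1)             # one column block per param
        unsharded = [
            torch.as_strided(s.contiguous().view(-1), p.shape, p.stride(), 0)
            for s, p in zip(splits, params)
        ]
        torch._foreach_copy_(list(params), unsharded)                       # the _foreach_copy above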

    # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
    permute_1: "f32[32, 128]" = torch.ops.aten.permute.default(getitem_44, [1, 0]); getitem_44 = None

    # No stacktrace found for following nodes
    mm_default_5: "f32[8, 128]" = torch.ops.aten.mm.default(primals_1, permute_1); permute_1 = None
    add_tensor_5: "f32[8, 128]" = torch.ops.aten.add.Tensor(mm_default_5, getitem_45); mm_default_5 = getitem_45 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/testing/_internal/common_fsdp.py:855 in forward, code: z = F.relu(z)
    relu: "f32[8, 128]" = torch.ops.aten.relu.default(add_tensor_5); add_tensor_5 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
    permute_3: "f32[128, 32]" = torch.ops.aten.permute.default(getitem_46, [1, 0]); getitem_46 = None

    # No stacktrace found for following nodes
    mm_default_4: "f32[8, 32]" = torch.ops.aten.mm.default(relu, permute_3); permute_3 = None
    add_tensor_4: "f32[8, 32]" = torch.ops.aten.add.Tensor(mm_default_4, getitem_47); mm_default_4 = getitem_47 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/testing/_internal/common_fsdp.py:857 in forward, code: z = F.relu(z)
    relu_1: "f32[8, 32]" = torch.ops.aten.relu.default(add_tensor_4); add_tensor_4 = None
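
    # Eager-mode equivalent (sketch) of the compute block above: one unit of the test model is
    # Linear(32 -> 128) -> ReLU -> Linear(128 -> 32) -> ReLU per the common_fsdp.py stack traces;
    # note the trace shows each bias addmm split into mm + add. Argument names are assumed.
    import torch.nn.functional as F

    def unit_forward(x, w1, b1, w2, b2):
        z = F.relu(F.linear(x, w1, b1))   # mm_default_5 + add_tensor_5 + relu
        z = F.relu(F.linear(z, w2, b2))   # mm_default_4 + add_tensor_4 + relu_1
        return z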

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
    _foreach_copy_2 = torch.ops.aten._foreach_copy.default([getitem, getitem_1, getitem_2, getitem_3], [primals_11, primals_12, primals_13, primals_14]); primals_11 = primals_12 = primals_13 = primals_14 = None
    getitem_52: "f32[2048]" = _foreach_copy_2[0]
    getitem_53: "f32[64]" = _foreach_copy_2[1]
    getitem_54: "f32[2048]" = _foreach_copy_2[2]
    getitem_55: "f32[16]" = _foreach_copy_2[3]; _foreach_copy_2 = None

    # No stacktrace found for following nodes
    slice_tensor_4: "f32[4176]" = torch.ops.aten.slice.Tensor(empty, 0, 4176, 8352)
    slice_scatter_default_8: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_4, getitem_52, 0, 0, 2048); slice_tensor_4 = getitem_52 = None
    slice_scatter_default_9: "f32[8352]" = torch.ops.aten.slice_scatter.default(empty, slice_scatter_default_8, 0, 4176, 8352); slice_scatter_default_8 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
    slice_13: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_9, 0, 4176, 8352)

    # No stacktrace found for following nodes
    slice_tensor_5: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_9, 0, 4176, 8352)
    slice_scatter_default_10: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_5, getitem_53, 0, 2048, 2112); slice_tensor_5 = getitem_53 = None
    slice_scatter_default_11: "f32[8352]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_9, slice_scatter_default_10, 0, 4176, 8352); slice_scatter_default_9 = slice_scatter_default_10 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
    slice_14: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_11, 0, 4176, 8352)

    # No stacktrace found for following nodes
    slice_tensor_6: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_11, 0, 4176, 8352)
    slice_scatter_default_12: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_6, getitem_54, 0, 2112, 4160); slice_tensor_6 = getitem_54 = None
    slice_scatter_default_13: "f32[8352]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_11, slice_scatter_default_12, 0, 4176, 8352); slice_scatter_default_11 = slice_scatter_default_12 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
    slice_15: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_13, 0, 4176, 8352)

    # No stacktrace found for following nodes
    slice_tensor_7: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_13, 0, 4176, 8352)
    slice_scatter_default_14: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_7, getitem_55, 0, 4160, 4176); slice_tensor_7 = getitem_55 = None
    slice_scatter_default_15: "f32[8352]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_13, slice_scatter_default_14, 0, 4176, 8352); slice_scatter_default_13 = slice_scatter_default_14 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor(
    slice_20: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_15, 0, 4176, 8352); slice_scatter_default_15 = None
    all_gather_into_tensor_1: "f32[8352]" = torch.ops._c10d_functional.all_gather_into_tensor.default(slice_20, 2, '0'); slice_20 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined]
    wait_tensor_1: "f32[8352]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_1); all_gather_into_tensor_1 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
    view_10: "f32[2, 4176]" = torch.ops.aten.reshape.default(wait_tensor_1, [2, -1]); wait_tensor_1 = None
    split_with_sizes_16 = torch.ops.aten.split_with_sizes.default(view_10, [2048, 64, 2048, 16], 1); view_10 = None
    getitem_76: "f32[2, 2048]" = split_with_sizes_16[0]
    clone_4: "f32[2, 2048]" = torch.ops.aten.clone.default(getitem_76, memory_format = torch.contiguous_format); getitem_76 = None
    view_11: "f32[4096]" = torch.ops.aten.reshape.default(clone_4, [4096]); clone_4 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
    as_strided_4: "f32[128, 32]" = torch.ops.aten.as_strided.default(view_11, [128, 32], [32, 1], 0); view_11 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
    getitem_81: "f32[2, 64]" = split_with_sizes_16[1]
    clone_5: "f32[2, 64]" = torch.ops.aten.clone.default(getitem_81, memory_format = torch.contiguous_format); getitem_81 = None
    view_13: "f32[128]" = torch.ops.aten.reshape.default(clone_5, [128]); clone_5 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
    as_strided_5: "f32[128]" = torch.ops.aten.as_strided.default(view_13, [128], [1], 0); view_13 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
    getitem_86: "f32[2, 2048]" = split_with_sizes_16[2]
    clone_6: "f32[2, 2048]" = torch.ops.aten.clone.default(getitem_86, memory_format = torch.contiguous_format); getitem_86 = None
    view_15: "f32[4096]" = torch.ops.aten.reshape.default(clone_6, [4096]); clone_6 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
    as_strided_6: "f32[32, 128]" = torch.ops.aten.as_strided.default(view_15, [32, 128], [128, 1], 0); view_15 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
    getitem_91: "f32[2, 16]" = split_with_sizes_16[3]; split_with_sizes_16 = None
    clone_7: "f32[2, 16]" = torch.ops.aten.clone.default(getitem_91, memory_format = torch.contiguous_format); getitem_91 = None
    view_17: "f32[32]" = torch.ops.aten.reshape.default(clone_7, [32]); clone_7 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
    as_strided_7: "f32[32]" = torch.ops.aten.as_strided.default(view_17, [32], [1], 0); view_17 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:138 in foreach_all_gather_copy_out, code: torch._foreach_copy_(out, splits_unpadded)
    _foreach_copy_3 = torch.ops.aten._foreach_copy.default([primals_15, primals_16, primals_17, primals_18], [as_strided_4, as_strided_5, as_strided_6, as_strided_7]); primals_16 = primals_18 = as_strided_4 = as_strided_5 = as_strided_6 = as_strided_7 = None
    getitem_92: "f32[128, 32]" = _foreach_copy_3[0]
    getitem_93: "f32[128]" = _foreach_copy_3[1]
    getitem_94: "f32[32, 128]" = _foreach_copy_3[2]
    getitem_95: "f32[32]" = _foreach_copy_3[3]; _foreach_copy_3 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
    permute_5: "f32[32, 128]" = torch.ops.aten.permute.default(getitem_92, [1, 0]); getitem_92 = None

    # No stacktrace found for following nodes
    mm_default_3: "f32[8, 128]" = torch.ops.aten.mm.default(relu_1, permute_5); permute_5 = None
    add_tensor_3: "f32[8, 128]" = torch.ops.aten.add.Tensor(mm_default_3, getitem_93); mm_default_3 = getitem_93 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/testing/_internal/common_fsdp.py:855 in forward, code: z = F.relu(z)
    relu_2: "f32[8, 128]" = torch.ops.aten.relu.default(add_tensor_3); add_tensor_3 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
    permute_7: "f32[128, 32]" = torch.ops.aten.permute.default(getitem_94, [1, 0]); getitem_94 = None

    # No stacktrace found for following nodes
    mm_default_2: "f32[8, 32]" = torch.ops.aten.mm.default(relu_2, permute_7); permute_7 = None
    add_tensor_2: "f32[8, 32]" = torch.ops.aten.add.Tensor(mm_default_2, getitem_95); mm_default_2 = getitem_95 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/testing/_internal/common_fsdp.py:857 in forward, code: z = F.relu(z)
    relu_3: "f32[8, 32]" = torch.ops.aten.relu.default(add_tensor_2); add_tensor_2 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
    _foreach_copy_4 = torch.ops.aten._foreach_copy.default([getitem, getitem_1, getitem_2, getitem_3], [primals_19, primals_20, primals_21, primals_22]); getitem = getitem_1 = getitem_2 = getitem_3 = primals_19 = primals_20 = primals_21 = primals_22 = None
    getitem_100: "f32[2048]" = _foreach_copy_4[0]
    getitem_101: "f32[64]" = _foreach_copy_4[1]
    getitem_102: "f32[2048]" = _foreach_copy_4[2]
    getitem_103: "f32[16]" = _foreach_copy_4[3]; _foreach_copy_4 = None

    # No stacktrace found for following nodes
    slice_tensor_8: "f32[4176]" = torch.ops.aten.slice.Tensor(empty, 0, 4176, 8352)
    slice_scatter_default_16: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_8, getitem_100, 0, 0, 2048); slice_tensor_8 = getitem_100 = None
    slice_scatter_default_17: "f32[8352]" = torch.ops.aten.slice_scatter.default(empty, slice_scatter_default_16, 0, 4176, 8352); empty = slice_scatter_default_16 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
    slice_23: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_17, 0, 4176, 8352)

    # No stacktrace found for following nodes
    slice_tensor_9: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_17, 0, 4176, 8352)
    slice_scatter_default_18: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_9, getitem_101, 0, 2048, 2112); slice_tensor_9 = getitem_101 = None
    slice_scatter_default_19: "f32[8352]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_17, slice_scatter_default_18, 0, 4176, 8352); slice_scatter_default_17 = slice_scatter_default_18 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
    slice_24: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_19, 0, 4176, 8352)

    # No stacktrace found for following nodes
    slice_tensor_10: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_19, 0, 4176, 8352)
    slice_scatter_default_20: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_10, getitem_102, 0, 2112, 4160); slice_tensor_10 = getitem_102 = None
    slice_scatter_default_21: "f32[8352]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_19, slice_scatter_default_20, 0, 4176, 8352); slice_scatter_default_19 = slice_scatter_default_20 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
    slice_25: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_21, 0, 4176, 8352)

    # No stacktrace found for following nodes
    slice_tensor_11: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_21, 0, 4176, 8352)
    slice_scatter_default_22: "f32[4176]" = torch.ops.aten.slice_scatter.default(slice_tensor_11, getitem_103, 0, 4160, 4176); slice_tensor_11 = getitem_103 = None
    slice_scatter_default_23: "f32[8352]" = torch.ops.aten.slice_scatter.default(slice_scatter_default_21, slice_scatter_default_22, 0, 4176, 8352); slice_scatter_default_21 = slice_scatter_default_22 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor(
    slice_30: "f32[4176]" = torch.ops.aten.slice.Tensor(slice_scatter_default_23, 0, 4176, 8352); slice_scatter_default_23 = None
    all_gather_into_tensor_2: "f32[8352]" = torch.ops._c10d_functional.all_gather_into_tensor.default(slice_30, 2, '0'); slice_30 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined]
    wait_tensor_2: "f32[8352]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor_2); all_gather_into_tensor_2 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
    view_19: "f32[2, 4176]" = torch.ops.aten.reshape.default(wait_tensor_2, [2, -1]); wait_tensor_2 = None
    split_with_sizes_26 = torch.ops.aten.split_with_sizes.default(view_19, [2048, 64, 2048, 16], 1); view_19 = None
    getitem_124: "f32[2, 2048]" = split_with_sizes_26[0]
    clone_8: "f32[2, 2048]" = torch.ops.aten.clone.default(getitem_124, memory_format = torch.contiguous_format); getitem_124 = None
    view_20: "f32[4096]" = torch.ops.aten.reshape.default(clone_8, [4096]); clone_8 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
    as_strided_8: "f32[128, 32]" = torch.ops.aten.as_strided.default(view_20, [128, 32], [32, 1], 0); view_20 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
    getitem_129: "f32[2, 64]" = split_with_sizes_26[1]
    clone_9: "f32[2, 64]" = torch.ops.aten.clone.default(getitem_129, memory_format = torch.contiguous_format); getitem_129 = None
    view_22: "f32[128]" = torch.ops.aten.reshape.default(clone_9, [128]); clone_9 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
    as_strided_9: "f32[128]" = torch.ops.aten.as_strided.default(view_22, [128], [1], 0); view_22 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
    getitem_134: "f32[2, 2048]" = split_with_sizes_26[2]
    clone_10: "f32[2, 2048]" = torch.ops.aten.clone.default(getitem_134, memory_format = torch.contiguous_format); getitem_134 = None
    view_24: "f32[4096]" = torch.ops.aten.reshape.default(clone_10, [4096]); clone_10 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
    as_strided_10: "f32[32, 128]" = torch.ops.aten.as_strided.default(view_24, [32, 128], [128, 1], 0); view_24 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
    getitem_139: "f32[2, 16]" = split_with_sizes_26[3]; split_with_sizes_26 = None
    clone_11: "f32[2, 16]" = torch.ops.aten.clone.default(getitem_139, memory_format = torch.contiguous_format); getitem_139 = None
    view_26: "f32[32]" = torch.ops.aten.reshape.default(clone_11, [32]); clone_11 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
    as_strided_11: "f32[32]" = torch.ops.aten.as_strided.default(view_26, [32], [1], 0); view_26 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:138 in foreach_all_gather_copy_out, code: torch._foreach_copy_(out, splits_unpadded)
    _foreach_copy_5 = torch.ops.aten._foreach_copy.default([primals_23, primals_24, primals_25, primals_26], [as_strided_8, as_strided_9, as_strided_10, as_strided_11]); primals_24 = primals_26 = as_strided_8 = as_strided_9 = as_strided_10 = as_strided_11 = None
    getitem_140: "f32[128, 32]" = _foreach_copy_5[0]
    getitem_141: "f32[128]" = _foreach_copy_5[1]
    getitem_142: "f32[32, 128]" = _foreach_copy_5[2]
    getitem_143: "f32[32]" = _foreach_copy_5[3]; _foreach_copy_5 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
    permute_9: "f32[32, 128]" = torch.ops.aten.permute.default(getitem_140, [1, 0]); getitem_140 = None

    # No stacktrace found for following nodes
    mm_default_1: "f32[8, 128]" = torch.ops.aten.mm.default(relu_3, permute_9); permute_9 = None
    add_tensor_1: "f32[8, 128]" = torch.ops.aten.add.Tensor(mm_default_1, getitem_141); mm_default_1 = getitem_141 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/testing/_internal/common_fsdp.py:855 in forward, code: z = F.relu(z)
    relu_4: "f32[8, 128]" = torch.ops.aten.relu.default(add_tensor_1); add_tensor_1 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
    permute_11: "f32[128, 32]" = torch.ops.aten.permute.default(getitem_142, [1, 0]); getitem_142 = None

    # No stacktrace found for following nodes
    mm_default: "f32[8, 32]" = torch.ops.aten.mm.default(relu_4, permute_11); permute_11 = None
    add_tensor: "f32[8, 32]" = torch.ops.aten.add.Tensor(mm_default, getitem_143); mm_default = getitem_143 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/testing/_internal/common_fsdp.py:857 in forward, code: z = F.relu(z)
    relu_5: "f32[8, 32]" = torch.ops.aten.relu.default(add_tensor); add_tensor = None
    le: "b8[8, 32]" = torch.ops.aten.le.Scalar(relu_5, 0)
    return [relu_5, primals_1, primals_8, primals_15, primals_17, primals_23, primals_25, relu, relu_1, relu_2, relu_3, relu_4, le]
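
    # End-to-end eager sketch (illustration only) of what this forward graph computes. The helpers
    # stage_all_gather_input / all_gather_flat / copy_out / unit_forward are the assumed sketches
    # interleaved above, not real FSDP APIs: each of the three units all-gathers its shards, copies
    # them out into unsharded parameters, then runs Linear -> ReLU -> Linear -> ReLU; the final
    # `le` mask (relu_5 <= 0) is returned among the saved outputs, presumably for ReLU's backward.
    def forward_sketch(x, per_unit_shards, per_unit_params, rank, world_size, group, device):
        z = x
        for shards, params in zip(per_unit_shards, per_unit_params):
            out, inp = stage_all_gather_input(shards, rank, world_size, device)
            all_gather_flat(out, inp, group)
            copy_out(out, world_size, [s.numel() for s in shards], params)
            w1, b1, w2, b2 = params
            z = unit_forward(z, w1, b1, w2, b2)
        return z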