constant_fold_getitem_becomes_full_issue.txt
(gist created by yf225, Mar 23, 2024)
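
The traced graphs below come from compiling FSDP's foreach_all_gather / foreach_all_gather_copy_out path in torch/distributed/_composable/fsdp/_fsdp_collectives.py. For orientation, the pattern being traced looks roughly like the following eager-mode sketch (a simplified, single-process approximation, not the original repro; it assumes world size 2, rank 1, and the per-parameter flat sizes [128, 8, 60, 4] visible in the graphs, and param_all_gather_inputs is stand-in random data):

import torch

world_size = 2
rank = 1                                  # the graphs below were traced on cuda:1
inp_split_sizes = [128, 8, 60, 4]         # flattened per-parameter shard sizes
shard_numel = sum(inp_split_sizes)        # 200

# _fsdp_collectives.py:47  all_gather_output = torch.empty(...)
all_gather_output = torch.empty(world_size * shard_numel)
# _fsdp_collectives.py:50  all_gather_input = all_gather_output.narrow(...)
all_gather_input = all_gather_output.narrow(0, rank * shard_numel, shard_numel)
# _fsdp_collectives.py:53  foreach_copy_dsts = torch.split(all_gather_input, inp_split_sizes)
foreach_copy_dsts = torch.split(all_gather_input, inp_split_sizes)
# _fsdp_collectives.py:55  torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
param_all_gather_inputs = [torch.randn(n) for n in inp_split_sizes]
torch._foreach_copy_(list(foreach_copy_dsts), param_all_gather_inputs)

Each entry of foreach_copy_dsts is a view of all_gather_output, so the copies land directly in the all-gather buffer. The issue captured in this trace is that after constant folding, the second split (getitem_1 in "Joint graph 0") is rewritten as a freshly allocated aten.full tensor in "Forward graph 0", which is not such a view.
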
    ===== Joint graph 0 =====
    /data/users/willfeng/pytorch_yf225/torch/fx/_lazy_graph_module.py class joint_helper(torch.nn.Module):
    def forward(self, primals, tangents):
    primals_1: "f32[4, 16]"; primals_2: "f32[128]"; primals_3: "f32[8]"; primals_4: "f32[60]"; primals_5: "f32[4]"; primals_6: "f32[15, 16]"; primals_7: "f32[15]"; primals_8: "f32[8, 15]"; primals_9: "f32[8]"; tangents_1: "f32[4, 8]";

    primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, tangents_1, = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:47 in foreach_all_gather, code: all_gather_output = torch.empty(
    empty: "f32[400]" = torch.ops.aten.empty.memory_format([400], dtype = torch.float32, device = device(type='cuda', index=1), pin_memory = False)

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:50 in foreach_all_gather, code: all_gather_input = all_gather_output.narrow(
    slice_1: "f32[200]" = torch.ops.aten.slice.Tensor(empty, 0, 200, 400)

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:53 in foreach_all_gather, code: foreach_copy_dsts = torch.split(all_gather_input, inp_split_sizes)
    split_with_sizes = torch.ops.aten.split_with_sizes.default(slice_1, [128, 8, 60, 4]); slice_1 = None
    getitem: "f32[128]" = split_with_sizes[0]
    getitem_1: "f32[8]" = split_with_sizes[1]
    getitem_2: "f32[60]" = split_with_sizes[2]
    getitem_3: "f32[4]" = split_with_sizes[3]; split_with_sizes = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
    _foreach_copy = torch.ops.aten._foreach_copy.default([getitem, getitem_1, getitem_2, getitem_3], [primals_2, primals_3, primals_4, primals_5]); getitem = getitem_1 = getitem_2 = getitem_3 = primals_2 = primals_3 = primals_4 = primals_5 = None
    getitem_4: "f32[128]" = _foreach_copy[0]
    getitem_5: "f32[8]" = _foreach_copy[1]
    getitem_6: "f32[60]" = _foreach_copy[2]
    getitem_7: "f32[4]" = _foreach_copy[3]; _foreach_copy = None
    slice_2: "f32[200]" = torch.ops.aten.slice.Tensor(empty, 0, 200, 400)
    slice_scatter: "f32[200]" = torch.ops.aten.slice_scatter.default(slice_2, getitem_4, 0, 0, 128); slice_2 = getitem_4 = None
    slice_scatter_1: "f32[400]" = torch.ops.aten.slice_scatter.default(empty, slice_scatter, 0, 200, 400); empty = slice_scatter = None
    slice_3: "f32[200]" = torch.ops.aten.slice.Tensor(slice_scatter_1, 0, 200, 400)
    slice_scatter_2: "f32[200]" = torch.ops.aten.slice_scatter.default(slice_3, getitem_5, 0, 128, 136); slice_3 = getitem_5 = None
    slice_scatter_3: "f32[400]" = torch.ops.aten.slice_scatter.default(slice_scatter_1, slice_scatter_2, 0, 200, 400); slice_scatter_1 = slice_scatter_2 = None
    slice_4: "f32[200]" = torch.ops.aten.slice.Tensor(slice_scatter_3, 0, 200, 400)
    slice_scatter_4: "f32[200]" = torch.ops.aten.slice_scatter.default(slice_4, getitem_6, 0, 136, 196); slice_4 = getitem_6 = None
    slice_scatter_5: "f32[400]" = torch.ops.aten.slice_scatter.default(slice_scatter_3, slice_scatter_4, 0, 200, 400); slice_scatter_3 = slice_scatter_4 = None
    slice_5: "f32[200]" = torch.ops.aten.slice.Tensor(slice_scatter_5, 0, 200, 400)
    slice_scatter_6: "f32[200]" = torch.ops.aten.slice_scatter.default(slice_5, getitem_7, 0, 196, 200); slice_5 = getitem_7 = None
    slice_scatter_7: "f32[400]" = torch.ops.aten.slice_scatter.default(slice_scatter_5, slice_scatter_6, 0, 200, 400); slice_scatter_5 = slice_scatter_6 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor(
    slice_10: "f32[200]" = torch.ops.aten.slice.Tensor(slice_scatter_7, 0, 200, 400)
    all_gather_into_tensor: "f32[400]" = torch.ops._c10d_functional.all_gather_into_tensor.default(slice_10, 2, '0'); slice_10 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined]
    wait_tensor: "f32[400]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor); all_gather_into_tensor = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:975 in all_gather_tensor_inplace, code: return output_tensor.copy_(all_gather_tensor(input_tensor, gather_dim, group, tag))
    copy: "f32[400]" = torch.ops.aten.copy.default(slice_scatter_7, wait_tensor); slice_scatter_7 = wait_tensor = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
    view_1: "f32[2, 200]" = torch.ops.aten.view.default(copy, [2, -1])
    split_with_sizes_6 = torch.ops.aten.split_with_sizes.default(view_1, [128, 8, 60, 4], 1); view_1 = None
    getitem_28: "f32[2, 128]" = split_with_sizes_6[0]; split_with_sizes_6 = None
    clone: "f32[2, 128]" = torch.ops.aten.clone.default(getitem_28, memory_format = torch.contiguous_format); getitem_28 = None
    view_2: "f32[256]" = torch.ops.aten.view.default(clone, [256]); clone = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
    as_strided: "f32[15, 16]" = torch.ops.aten.as_strided.default(view_2, [15, 16], [16, 1], 0); view_2 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
    view_3: "f32[2, 200]" = torch.ops.aten.view.default(copy, [2, -1])
    split_with_sizes_7 = torch.ops.aten.split_with_sizes.default(view_3, [128, 8, 60, 4], 1); view_3 = None
    getitem_33: "f32[2, 8]" = split_with_sizes_7[1]; split_with_sizes_7 = None
    clone_1: "f32[2, 8]" = torch.ops.aten.clone.default(getitem_33, memory_format = torch.contiguous_format); getitem_33 = None
    view_4: "f32[16]" = torch.ops.aten.view.default(clone_1, [16]); clone_1 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
    as_strided_1: "f32[15]" = torch.ops.aten.as_strided.default(view_4, [15], [1], 0); view_4 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
    view_5: "f32[2, 200]" = torch.ops.aten.view.default(copy, [2, -1])
    split_with_sizes_8 = torch.ops.aten.split_with_sizes.default(view_5, [128, 8, 60, 4], 1); view_5 = None
    getitem_38: "f32[2, 60]" = split_with_sizes_8[2]; split_with_sizes_8 = None
    clone_2: "f32[2, 60]" = torch.ops.aten.clone.default(getitem_38, memory_format = torch.contiguous_format); getitem_38 = None
    view_6: "f32[120]" = torch.ops.aten.view.default(clone_2, [120]); clone_2 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
    as_strided_2: "f32[8, 15]" = torch.ops.aten.as_strided.default(view_6, [8, 15], [15, 1], 0); view_6 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
    view_7: "f32[2, 200]" = torch.ops.aten.view.default(copy, [2, -1]); copy = None
    split_with_sizes_9 = torch.ops.aten.split_with_sizes.default(view_7, [128, 8, 60, 4], 1); view_7 = None
    getitem_43: "f32[2, 4]" = split_with_sizes_9[3]; split_with_sizes_9 = None
    clone_3: "f32[2, 4]" = torch.ops.aten.clone.default(getitem_43, memory_format = torch.contiguous_format); getitem_43 = None
    view_8: "f32[8]" = torch.ops.aten.view.default(clone_3, [8]); clone_3 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
    as_strided_3: "f32[8]" = torch.ops.aten.as_strided.default(view_8, [8], [1], 0); view_8 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:138 in foreach_all_gather_copy_out, code: torch._foreach_copy_(out, splits_unpadded)
    _foreach_copy_1 = torch.ops.aten._foreach_copy.default([primals_6, primals_7, primals_8, primals_9], [as_strided, as_strided_1, as_strided_2, as_strided_3]); primals_6 = primals_7 = primals_9 = as_strided = as_strided_1 = as_strided_2 = as_strided_3 = None
    getitem_44: "f32[15, 16]" = _foreach_copy_1[0]
    getitem_45: "f32[15]" = _foreach_copy_1[1]
    getitem_46: "f32[8, 15]" = _foreach_copy_1[2]
    getitem_47: "f32[8]" = _foreach_copy_1[3]; _foreach_copy_1 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
    permute_1: "f32[16, 15]" = torch.ops.aten.permute.default(getitem_44, [1, 0]); getitem_44 = None
    addmm: "f32[4, 15]" = torch.ops.aten.addmm.default(getitem_45, primals_1, permute_1); getitem_45 = permute_1 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/activation.py:103 in forward, code: return F.relu(input, inplace=self.inplace)
    relu: "f32[4, 15]" = torch.ops.aten.relu.default(addmm); addmm = None
    alias: "f32[4, 15]" = torch.ops.aten.alias.default(relu)
    alias_1: "f32[4, 15]" = torch.ops.aten.alias.default(alias); alias = None

    # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
    permute_3: "f32[15, 8]" = torch.ops.aten.permute.default(getitem_46, [1, 0]); getitem_46 = None
    addmm_1: "f32[4, 8]" = torch.ops.aten.addmm.default(getitem_47, relu, permute_3); getitem_47 = permute_3 = None

    # No stacktrace found for following nodes
    trace_wrapped: "f32[4, 8]" = torch__dynamo__trace_wrapped_higher_order_op_self_invoke(tangents_1, bw_state = primals_10); tangents_1 = primals_10 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
    permute_5: "f32[15, 8]" = torch.ops.aten.permute.default(primals_8, [1, 0]); primals_8 = None
    permute_6: "f32[8, 15]" = torch.ops.aten.permute.default(permute_5, [1, 0]); permute_5 = None
    mm: "f32[4, 15]" = torch.ops.aten.mm.default(trace_wrapped, permute_6); permute_6 = None
    permute_7: "f32[8, 4]" = torch.ops.aten.permute.default(trace_wrapped, [1, 0])
    mm_1: "f32[8, 15]" = torch.ops.aten.mm.default(permute_7, relu); permute_7 = relu = None
    permute_8: "f32[15, 8]" = torch.ops.aten.permute.default(mm_1, [1, 0]); mm_1 = None
    sum_1: "f32[1, 8]" = torch.ops.aten.sum.dim_IntList(trace_wrapped, [0], True); trace_wrapped = None
    view_9: "f32[8]" = torch.ops.aten.view.default(sum_1, [8]); sum_1 = None
    permute_9: "f32[8, 15]" = torch.ops.aten.permute.default(permute_8, [1, 0]); permute_8 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/activation.py:103 in forward, code: return F.relu(input, inplace=self.inplace)
    alias_2: "f32[4, 15]" = torch.ops.aten.alias.default(alias_1); alias_1 = None
    alias_3: "f32[4, 15]" = torch.ops.aten.alias.default(alias_2); alias_2 = None
    le: "b8[4, 15]" = torch.ops.aten.le.Scalar(alias_3, 0); alias_3 = None
    scalar_tensor: "f32[]" = torch.ops.aten.scalar_tensor.default(0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=1))
    where: "f32[4, 15]" = torch.ops.aten.where.self(le, scalar_tensor, mm); le = scalar_tensor = mm = None

    # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
    permute_10: "f32[15, 4]" = torch.ops.aten.permute.default(where, [1, 0])
    mm_2: "f32[15, 16]" = torch.ops.aten.mm.default(permute_10, primals_1); permute_10 = primals_1 = None
    permute_11: "f32[16, 15]" = torch.ops.aten.permute.default(mm_2, [1, 0]); mm_2 = None
    sum_2: "f32[1, 15]" = torch.ops.aten.sum.dim_IntList(where, [0], True); where = None
    view_10: "f32[15]" = torch.ops.aten.view.default(sum_2, [15]); sum_2 = None
    permute_12: "f32[15, 16]" = torch.ops.aten.permute.default(permute_11, [1, 0]); permute_11 = None
    return pytree.tree_unflatten([addmm_1, None, None, None, None, None, permute_12, view_10, permute_9, view_9, None], self._out_spec)


    /data/users/willfeng/pytorch_yf225/torch/_inductor/compile_fx.py:133: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
    warnings.warn(
    TRACED GRAPH
    ===== Forward graph 0 =====
    /data/users/willfeng/pytorch_yf225/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[4, 16]", primals_2: "f32[128]", primals_3: "f32[8]", primals_4: "f32[60]", primals_5: "f32[4]", primals_6: "f32[15, 16]", primals_7: "f32[15]", primals_8: "f32[8, 15]", primals_9: "f32[8]", primals_10):
    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:47 in foreach_all_gather, code: all_gather_output = torch.empty(
    empty: "f32[400]" = torch.ops.aten.empty.memory_format([400], dtype = torch.float32, device = device(type='cuda', index=1), pin_memory = False)

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:50 in foreach_all_gather, code: all_gather_input = all_gather_output.narrow(
    slice_1: "f32[200]" = torch.ops.aten.slice.Tensor(empty, 0, 200, 400)

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:53 in foreach_all_gather, code: foreach_copy_dsts = torch.split(all_gather_input, inp_split_sizes)
    split_with_sizes = torch.ops.aten.split_with_sizes.default(slice_1, [128, 8, 60, 4])
    getitem: "f32[128]" = split_with_sizes[0]
    full_default: "f32[8]" = torch.ops.aten.full.default([8], 0.0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=1), pin_memory = False)
    getitem_2: "f32[60]" = split_with_sizes[2]
    getitem_3: "f32[4]" = split_with_sizes[3]; split_with_sizes = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
    _foreach_copy = torch.ops.aten._foreach_copy_.default([getitem, full_default, getitem_2, getitem_3], [primals_2, primals_3, primals_4, primals_5]); primals_2 = primals_3 = primals_4 = primals_5 = None
    slice_scatter: "f32[200]" = torch.ops.aten.slice_scatter.default(slice_1, getitem, 0, 0, 128); slice_1 = getitem = None
    slice_scatter_1: "f32[400]" = torch.ops.aten.slice_scatter.default(empty, slice_scatter, 0, 200, 400); empty = slice_scatter = None
    slice_3: "f32[200]" = torch.ops.aten.slice.Tensor(slice_scatter_1, 0, 200, 400)
    slice_scatter_2: "f32[200]" = torch.ops.aten.slice_scatter.default(slice_3, full_default, 0, 128, 136); slice_3 = full_default = None
    slice_scatter_3: "f32[400]" = torch.ops.aten.slice_scatter.default(slice_scatter_1, slice_scatter_2, 0, 200, 400); slice_scatter_1 = slice_scatter_2 = None
    slice_4: "f32[200]" = torch.ops.aten.slice.Tensor(slice_scatter_3, 0, 200, 400)
    slice_scatter_4: "f32[200]" = torch.ops.aten.slice_scatter.default(slice_4, getitem_2, 0, 136, 196); slice_4 = getitem_2 = None
    slice_scatter_5: "f32[400]" = torch.ops.aten.slice_scatter.default(slice_scatter_3, slice_scatter_4, 0, 200, 400); slice_scatter_3 = slice_scatter_4 = None
    slice_5: "f32[200]" = torch.ops.aten.slice.Tensor(slice_scatter_5, 0, 200, 400)
    slice_scatter_6: "f32[200]" = torch.ops.aten.slice_scatter.default(slice_5, getitem_3, 0, 196, 200); slice_5 = getitem_3 = None
    slice_scatter_7: "f32[400]" = torch.ops.aten.slice_scatter.default(slice_scatter_5, slice_scatter_6, 0, 200, 400); slice_scatter_5 = slice_scatter_6 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor(
    slice_10: "f32[200]" = torch.ops.aten.slice.Tensor(slice_scatter_7, 0, 200, 400)
    all_gather_into_tensor: "f32[400]" = torch.ops._c10d_functional.all_gather_into_tensor.default(slice_10, 2, '0'); slice_10 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined]
    wait_tensor: "f32[400]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor); all_gather_into_tensor = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:975 in all_gather_tensor_inplace, code: return output_tensor.copy_(all_gather_tensor(input_tensor, gather_dim, group, tag))
    copy: "f32[400]" = torch.ops.aten.copy.default(slice_scatter_7, wait_tensor); slice_scatter_7 = wait_tensor = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
    view_1: "f32[2, 200]" = torch.ops.aten.view.default(copy, [2, -1]); copy = None
    split_with_sizes_6 = torch.ops.aten.split_with_sizes.default(view_1, [128, 8, 60, 4], 1); view_1 = None
    getitem_28: "f32[2, 128]" = split_with_sizes_6[0]
    clone: "f32[2, 128]" = torch.ops.aten.clone.default(getitem_28, memory_format = torch.contiguous_format); getitem_28 = None
    view_2: "f32[256]" = torch.ops.aten.view.default(clone, [256]); clone = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
    as_strided: "f32[15, 16]" = torch.ops.aten.as_strided.default(view_2, [15, 16], [16, 1], 0); view_2 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
    getitem_33: "f32[2, 8]" = split_with_sizes_6[1]
    clone_1: "f32[2, 8]" = torch.ops.aten.clone.default(getitem_33, memory_format = torch.contiguous_format); getitem_33 = None
    view_4: "f32[16]" = torch.ops.aten.view.default(clone_1, [16]); clone_1 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
    as_strided_1: "f32[15]" = torch.ops.aten.as_strided.default(view_4, [15], [1], 0); view_4 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
    getitem_38: "f32[2, 60]" = split_with_sizes_6[2]
    clone_2: "f32[2, 60]" = torch.ops.aten.clone.default(getitem_38, memory_format = torch.contiguous_format); getitem_38 = None
    view_6: "f32[120]" = torch.ops.aten.view.default(clone_2, [120]); clone_2 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
    as_strided_2: "f32[8, 15]" = torch.ops.aten.as_strided.default(view_6, [8, 15], [15, 1], 0); view_6 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
    getitem_43: "f32[2, 4]" = split_with_sizes_6[3]; split_with_sizes_6 = None
    clone_3: "f32[2, 4]" = torch.ops.aten.clone.default(getitem_43, memory_format = torch.contiguous_format); getitem_43 = None
    view_8: "f32[8]" = torch.ops.aten.view.default(clone_3, [8]); clone_3 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
    as_strided_3: "f32[8]" = torch.ops.aten.as_strided.default(view_8, [8], [1], 0); view_8 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:138 in foreach_all_gather_copy_out, code: torch._foreach_copy_(out, splits_unpadded)
    _foreach_copy_1 = torch.ops.aten._foreach_copy_.default([primals_6, primals_7, primals_8, primals_9], [as_strided, as_strided_1, as_strided_2, as_strided_3]); as_strided = as_strided_1 = as_strided_2 = as_strided_3 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
    permute_1: "f32[16, 15]" = torch.ops.aten.permute.default(primals_6, [1, 0]); primals_6 = None
    addmm: "f32[4, 15]" = torch.ops.aten.addmm.default(primals_7, primals_1, permute_1); primals_7 = permute_1 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/activation.py:103 in forward, code: return F.relu(input, inplace=self.inplace)
    relu: "f32[4, 15]" = torch.ops.aten.relu.default(addmm); addmm = None

    # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
    permute_3: "f32[15, 8]" = torch.ops.aten.permute.default(primals_8, [1, 0])
    addmm_1: "f32[4, 8]" = torch.ops.aten.addmm.default(primals_9, relu, permute_3); primals_9 = permute_3 = None
    return [addmm_1, primals_1, primals_8, relu]


    TRACED GRAPH
    ===== Backward graph 0 =====
    /data/users/willfeng/pytorch_yf225/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[4, 16]", primals_8: "f32[8, 15]", relu: "f32[4, 15]", tangents_1: "f32[4, 8]", primals_10):
    # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/activation.py:103 in forward, code: return F.relu(input, inplace=self.inplace)
    alias: "f32[4, 15]" = torch.ops.aten.alias.default(relu)
    alias_1: "f32[4, 15]" = torch.ops.aten.alias.default(alias); alias = None

    # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
    permute_3: "f32[15, 8]" = torch.ops.aten.permute.default(primals_8, [1, 0]); primals_8 = None

    # No stacktrace found for following nodes
    trace_wrapped: "f32[4, 8]" = torch__dynamo__trace_wrapped_higher_order_op_self_invoke(tangents_1, bw_state = primals_10); tangents_1 = primals_10 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
    permute_6: "f32[8, 15]" = torch.ops.aten.permute.default(permute_3, [1, 0]); permute_3 = None
    mm: "f32[4, 15]" = torch.ops.aten.mm.default(trace_wrapped, permute_6); permute_6 = None
    permute_7: "f32[8, 4]" = torch.ops.aten.permute.default(trace_wrapped, [1, 0])
    mm_1: "f32[8, 15]" = torch.ops.aten.mm.default(permute_7, relu); permute_7 = relu = None
    permute_8: "f32[15, 8]" = torch.ops.aten.permute.default(mm_1, [1, 0]); mm_1 = None
    sum_1: "f32[1, 8]" = torch.ops.aten.sum.dim_IntList(trace_wrapped, [0], True); trace_wrapped = None
    view_9: "f32[8]" = torch.ops.aten.view.default(sum_1, [8]); sum_1 = None
    permute_9: "f32[8, 15]" = torch.ops.aten.permute.default(permute_8, [1, 0]); permute_8 = None

    # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/activation.py:103 in forward, code: return F.relu(input, inplace=self.inplace)
    alias_2: "f32[4, 15]" = torch.ops.aten.alias.default(alias_1); alias_1 = None
    alias_3: "f32[4, 15]" = torch.ops.aten.alias.default(alias_2); alias_2 = None
    le: "b8[4, 15]" = torch.ops.aten.le.Scalar(alias_3, 0); alias_3 = None
    full_default_1: "f32[]" = torch.ops.aten.full.default([], 0.0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=1), pin_memory = False)
    where: "f32[4, 15]" = torch.ops.aten.where.self(le, full_default_1, mm); le = full_default_1 = mm = None

    # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
    permute_10: "f32[15, 4]" = torch.ops.aten.permute.default(where, [1, 0])
    mm_2: "f32[15, 16]" = torch.ops.aten.mm.default(permute_10, primals_1); permute_10 = primals_1 = None
    permute_11: "f32[16, 15]" = torch.ops.aten.permute.default(mm_2, [1, 0]); mm_2 = None
    sum_2: "f32[1, 15]" = torch.ops.aten.sum.dim_IntList(where, [0], True); where = None
    view_10: "f32[15]" = torch.ops.aten.view.default(sum_2, [15]); sum_2 = None
    permute_12: "f32[15, 16]" = torch.ops.aten.permute.default(permute_11, [1, 0]); permute_11 = None
    return [None, None, None, None, None, permute_12, view_10, permute_9, view_9, None]