Gist created by yf225 on Mar 23, 2024.
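Context (not part of the original log, added for orientation): the shapes in the traces below are consistent with a tiny two-layer MLP, Linear(16, 15) -> ReLU -> Linear(15, 8), run on a batch of 4, with each parameter sharded across 2 ranks by the per-parameter-sharding ("FSDP2") fully_shard API and compiled with torch.compile. The sketch below is an assumption inferred from those shapes and from the _fsdp_collectives.py file paths in the traces, not a script taken from the gist; the module layout, dimensions, and launch details are guesses.

# Minimal sketch (assumption) of a setup that would produce graphs like the ones below.
# Shapes inferred from the traces: input f32[4, 16], Linear(16, 15) -> ReLU -> Linear(15, 8),
# world size 2 (the all_gather_into_tensor call below uses group size 2).
import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard  # per-parameter FSDP ("FSDP2")

def main() -> None:
    torch.distributed.init_process_group(backend="nccl")
    rank = torch.distributed.get_rank()
    torch.cuda.set_device(rank)

    model = nn.Sequential(nn.Linear(16, 15), nn.ReLU(), nn.Linear(15, 8)).cuda()
    fully_shard(model)               # shard each parameter across the 2 ranks
    compiled = torch.compile(model)  # trace the all-gather plus the forward into one graph

    x = torch.randn(4, 16, device="cuda")
    compiled(x).sum().backward()

if __name__ == "__main__":
    main()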
===== Joint graph 0 =====
/data/users/willfeng/pytorch_yf225/torch/fx/_lazy_graph_module.py class joint_helper(torch.nn.Module):
    def forward(self, primals, tangents):
        primals_1: "f32[4, 16]"; primals_2: "f32[128]"; primals_3: "f32[8]"; primals_4: "f32[60]"; primals_5: "f32[4]"; primals_6: "f32[15, 16]"; primals_7: "f32[15]"; primals_8: "f32[8, 15]"; primals_9: "f32[8]"; tangents_1: "f32[4, 8]";

        primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, tangents_1, = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:47 in foreach_all_gather, code: all_gather_output = torch.empty(
        empty: "f32[400]" = torch.ops.aten.empty.memory_format([400], dtype = torch.float32, device = device(type='cuda', index=1), pin_memory = False)

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:50 in foreach_all_gather, code: all_gather_input = all_gather_output.narrow(
        slice_1: "f32[200]" = torch.ops.aten.slice.Tensor(empty, 0, 200, 400)

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:53 in foreach_all_gather, code: foreach_copy_dsts = torch.split(all_gather_input, inp_split_sizes)
        split_with_sizes = torch.ops.aten.split_with_sizes.default(slice_1, [128, 8, 60, 4]); slice_1 = None
        getitem: "f32[128]" = split_with_sizes[0]
        getitem_1: "f32[8]" = split_with_sizes[1]
        getitem_2: "f32[60]" = split_with_sizes[2]
        getitem_3: "f32[4]" = split_with_sizes[3]; split_with_sizes = None

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
        _foreach_copy = torch.ops.aten._foreach_copy.default([getitem, getitem_1, getitem_2, getitem_3], [primals_2, primals_3, primals_4, primals_5]); getitem = getitem_1 = getitem_2 = getitem_3 = primals_2 = primals_3 = primals_4 = primals_5 = None
        getitem_4: "f32[128]" = _foreach_copy[0]
        getitem_5: "f32[8]" = _foreach_copy[1]
        getitem_6: "f32[60]" = _foreach_copy[2]
        getitem_7: "f32[4]" = _foreach_copy[3]; _foreach_copy = None
        slice_2: "f32[200]" = torch.ops.aten.slice.Tensor(empty, 0, 200, 400)
        slice_scatter: "f32[200]" = torch.ops.aten.slice_scatter.default(slice_2, getitem_4, 0, 0, 128); slice_2 = getitem_4 = None
        slice_scatter_1: "f32[400]" = torch.ops.aten.slice_scatter.default(empty, slice_scatter, 0, 200, 400); empty = slice_scatter = None
        slice_3: "f32[200]" = torch.ops.aten.slice.Tensor(slice_scatter_1, 0, 200, 400)
        slice_scatter_2: "f32[200]" = torch.ops.aten.slice_scatter.default(slice_3, getitem_5, 0, 128, 136); slice_3 = getitem_5 = None
        slice_scatter_3: "f32[400]" = torch.ops.aten.slice_scatter.default(slice_scatter_1, slice_scatter_2, 0, 200, 400); slice_scatter_1 = slice_scatter_2 = None
        slice_4: "f32[200]" = torch.ops.aten.slice.Tensor(slice_scatter_3, 0, 200, 400)
        slice_scatter_4: "f32[200]" = torch.ops.aten.slice_scatter.default(slice_4, getitem_6, 0, 136, 196); slice_4 = getitem_6 = None
        slice_scatter_5: "f32[400]" = torch.ops.aten.slice_scatter.default(slice_scatter_3, slice_scatter_4, 0, 200, 400); slice_scatter_3 = slice_scatter_4 = None
        slice_5: "f32[200]" = torch.ops.aten.slice.Tensor(slice_scatter_5, 0, 200, 400)
        slice_scatter_6: "f32[200]" = torch.ops.aten.slice_scatter.default(slice_5, getitem_7, 0, 196, 200); slice_5 = getitem_7 = None
        slice_scatter_7: "f32[400]" = torch.ops.aten.slice_scatter.default(slice_scatter_5, slice_scatter_6, 0, 200, 400); slice_scatter_5 = slice_scatter_6 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor(
        slice_10: "f32[200]" = torch.ops.aten.slice.Tensor(slice_scatter_7, 0, 200, 400)
        all_gather_into_tensor: "f32[400]" = torch.ops._c10d_functional.all_gather_into_tensor.default(slice_10, 2, '0'); slice_10 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor)  # type: ignore[attr-defined]
        wait_tensor: "f32[400]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor); all_gather_into_tensor = None

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:975 in all_gather_tensor_inplace, code: return output_tensor.copy_(all_gather_tensor(input_tensor, gather_dim, group, tag))
        copy: "f32[400]" = torch.ops.aten.copy.default(slice_scatter_7, wait_tensor); slice_scatter_7 = wait_tensor = None

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
        view_1: "f32[2, 200]" = torch.ops.aten.view.default(copy, [2, -1])
        split_with_sizes_6 = torch.ops.aten.split_with_sizes.default(view_1, [128, 8, 60, 4], 1); view_1 = None
        getitem_28: "f32[2, 128]" = split_with_sizes_6[0]; split_with_sizes_6 = None
        clone: "f32[2, 128]" = torch.ops.aten.clone.default(getitem_28, memory_format = torch.contiguous_format); getitem_28 = None
        view_2: "f32[256]" = torch.ops.aten.view.default(clone, [256]); clone = None

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
        as_strided: "f32[15, 16]" = torch.ops.aten.as_strided.default(view_2, [15, 16], [16, 1], 0); view_2 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
        view_3: "f32[2, 200]" = torch.ops.aten.view.default(copy, [2, -1])
        split_with_sizes_7 = torch.ops.aten.split_with_sizes.default(view_3, [128, 8, 60, 4], 1); view_3 = None
        getitem_33: "f32[2, 8]" = split_with_sizes_7[1]; split_with_sizes_7 = None
        clone_1: "f32[2, 8]" = torch.ops.aten.clone.default(getitem_33, memory_format = torch.contiguous_format); getitem_33 = None
        view_4: "f32[16]" = torch.ops.aten.view.default(clone_1, [16]); clone_1 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
        as_strided_1: "f32[15]" = torch.ops.aten.as_strided.default(view_4, [15], [1], 0); view_4 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
        view_5: "f32[2, 200]" = torch.ops.aten.view.default(copy, [2, -1])
        split_with_sizes_8 = torch.ops.aten.split_with_sizes.default(view_5, [128, 8, 60, 4], 1); view_5 = None
        getitem_38: "f32[2, 60]" = split_with_sizes_8[2]; split_with_sizes_8 = None
        clone_2: "f32[2, 60]" = torch.ops.aten.clone.default(getitem_38, memory_format = torch.contiguous_format); getitem_38 = None
        view_6: "f32[120]" = torch.ops.aten.view.default(clone_2, [120]); clone_2 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
        as_strided_2: "f32[8, 15]" = torch.ops.aten.as_strided.default(view_6, [8, 15], [15, 1], 0); view_6 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
        view_7: "f32[2, 200]" = torch.ops.aten.view.default(copy, [2, -1]); copy = None
        split_with_sizes_9 = torch.ops.aten.split_with_sizes.default(view_7, [128, 8, 60, 4], 1); view_7 = None
        getitem_43: "f32[2, 4]" = split_with_sizes_9[3]; split_with_sizes_9 = None
        clone_3: "f32[2, 4]" = torch.ops.aten.clone.default(getitem_43, memory_format = torch.contiguous_format); getitem_43 = None
        view_8: "f32[8]" = torch.ops.aten.view.default(clone_3, [8]); clone_3 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
        as_strided_3: "f32[8]" = torch.ops.aten.as_strided.default(view_8, [8], [1], 0); view_8 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:138 in foreach_all_gather_copy_out, code: torch._foreach_copy_(out, splits_unpadded)
        _foreach_copy_1 = torch.ops.aten._foreach_copy.default([primals_6, primals_7, primals_8, primals_9], [as_strided, as_strided_1, as_strided_2, as_strided_3]); primals_6 = primals_7 = primals_9 = as_strided = as_strided_1 = as_strided_2 = as_strided_3 = None
        getitem_44: "f32[15, 16]" = _foreach_copy_1[0]
        getitem_45: "f32[15]" = _foreach_copy_1[1]
        getitem_46: "f32[8, 15]" = _foreach_copy_1[2]
        getitem_47: "f32[8]" = _foreach_copy_1[3]; _foreach_copy_1 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
        permute_1: "f32[16, 15]" = torch.ops.aten.permute.default(getitem_44, [1, 0]); getitem_44 = None
        addmm: "f32[4, 15]" = torch.ops.aten.addmm.default(getitem_45, primals_1, permute_1); getitem_45 = permute_1 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/activation.py:103 in forward, code: return F.relu(input, inplace=self.inplace)
        relu: "f32[4, 15]" = torch.ops.aten.relu.default(addmm); addmm = None
        alias: "f32[4, 15]" = torch.ops.aten.alias.default(relu)
        alias_1: "f32[4, 15]" = torch.ops.aten.alias.default(alias); alias = None

        # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
        permute_3: "f32[15, 8]" = torch.ops.aten.permute.default(getitem_46, [1, 0]); getitem_46 = None
        addmm_1: "f32[4, 8]" = torch.ops.aten.addmm.default(getitem_47, relu, permute_3); getitem_47 = permute_3 = None

        # No stacktrace found for following nodes
        trace_wrapped: "f32[4, 8]" = torch__dynamo__trace_wrapped_higher_order_op_self_invoke(tangents_1, bw_state = primals_10); tangents_1 = primals_10 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
        permute_5: "f32[15, 8]" = torch.ops.aten.permute.default(primals_8, [1, 0]); primals_8 = None
        permute_6: "f32[8, 15]" = torch.ops.aten.permute.default(permute_5, [1, 0]); permute_5 = None
        mm: "f32[4, 15]" = torch.ops.aten.mm.default(trace_wrapped, permute_6); permute_6 = None
        permute_7: "f32[8, 4]" = torch.ops.aten.permute.default(trace_wrapped, [1, 0])
        mm_1: "f32[8, 15]" = torch.ops.aten.mm.default(permute_7, relu); permute_7 = relu = None
        permute_8: "f32[15, 8]" = torch.ops.aten.permute.default(mm_1, [1, 0]); mm_1 = None
        sum_1: "f32[1, 8]" = torch.ops.aten.sum.dim_IntList(trace_wrapped, [0], True); trace_wrapped = None
        view_9: "f32[8]" = torch.ops.aten.view.default(sum_1, [8]); sum_1 = None
        permute_9: "f32[8, 15]" = torch.ops.aten.permute.default(permute_8, [1, 0]); permute_8 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/activation.py:103 in forward, code: return F.relu(input, inplace=self.inplace)
        alias_2: "f32[4, 15]" = torch.ops.aten.alias.default(alias_1); alias_1 = None
        alias_3: "f32[4, 15]" = torch.ops.aten.alias.default(alias_2); alias_2 = None
        le: "b8[4, 15]" = torch.ops.aten.le.Scalar(alias_3, 0); alias_3 = None
        scalar_tensor: "f32[]" = torch.ops.aten.scalar_tensor.default(0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=1))
        where: "f32[4, 15]" = torch.ops.aten.where.self(le, scalar_tensor, mm); le = scalar_tensor = mm = None

        # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
        permute_10: "f32[15, 4]" = torch.ops.aten.permute.default(where, [1, 0])
        mm_2: "f32[15, 16]" = torch.ops.aten.mm.default(permute_10, primals_1); permute_10 = primals_1 = None
        permute_11: "f32[16, 15]" = torch.ops.aten.permute.default(mm_2, [1, 0]); mm_2 = None
        sum_2: "f32[1, 15]" = torch.ops.aten.sum.dim_IntList(where, [0], True); where = None
        view_10: "f32[15]" = torch.ops.aten.view.default(sum_2, [15]); sum_2 = None
        permute_12: "f32[15, 16]" = torch.ops.aten.permute.default(permute_11, [1, 0]); permute_11 = None
        return pytree.tree_unflatten([addmm_1, None, None, None, None, None, permute_12, view_10, permute_9, view_9, None], self._out_spec)
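To make the aten ops above easier to read: aside from the FSDP all-gather plumbing (empty / slice_scatter / all_gather_into_tensor / _foreach_copy), the joint graph computes the forward and backward of the same two-layer MLP. Below is a rough eager-mode equivalent; the names w1/b1/w2/b2 are hypothetical, and the shapes are copied from the graph.

# Eager-mode equivalent of the compute in the joint graph above (illustrative sketch;
# w1/b1/w2/b2 are hypothetical names, shapes taken from the trace).
import torch
import torch.nn.functional as F

x  = torch.randn(4, 16)                       # primals_1
w1 = torch.randn(15, 16, requires_grad=True)  # unsharded weight 1 (primals_6 / getitem_44)
b1 = torch.randn(15, requires_grad=True)      # unsharded bias 1   (primals_7 / getitem_45)
w2 = torch.randn(8, 15, requires_grad=True)   # unsharded weight 2 (primals_8 / getitem_46)
b2 = torch.randn(8, requires_grad=True)       # unsharded bias 2   (primals_9 / getitem_47)

h   = F.relu(F.linear(x, w1, b1))             # addmm + relu in the graph
out = F.linear(h, w2, b2)                     # addmm_1 in the graph
out.backward(torch.ones(4, 8))                # tangents_1
# w1.grad, b1.grad, w2.grad, b2.grad correspond by shape to permute_12, view_10, permute_9, view_9.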
/data/users/willfeng/pytorch_yf225/torch/_inductor/compile_fx.py:133: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
  warnings.warn(

TRACED GRAPH
===== Forward graph 0 =====
/data/users/willfeng/pytorch_yf225/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[4, 16]", primals_2: "f32[128]", primals_3: "f32[8]", primals_4: "f32[60]", primals_5: "f32[4]", primals_6: "f32[15, 16]", primals_7: "f32[15]", primals_8: "f32[8, 15]", primals_9: "f32[8]", primals_10):
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:47 in foreach_all_gather, code: all_gather_output = torch.empty(
        empty: "f32[400]" = torch.ops.aten.empty.memory_format([400], dtype = torch.float32, device = device(type='cuda', index=1), pin_memory = False)

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:50 in foreach_all_gather, code: all_gather_input = all_gather_output.narrow(
        slice_1: "f32[200]" = torch.ops.aten.slice.Tensor(empty, 0, 200, 400)

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:53 in foreach_all_gather, code: foreach_copy_dsts = torch.split(all_gather_input, inp_split_sizes)
        split_with_sizes = torch.ops.aten.split_with_sizes.default(slice_1, [128, 8, 60, 4])
        getitem: "f32[128]" = split_with_sizes[0]
        full_default: "f32[8]" = torch.ops.aten.full.default([8], 0.0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=1), pin_memory = False)
        getitem_2: "f32[60]" = split_with_sizes[2]
        getitem_3: "f32[4]" = split_with_sizes[3]; split_with_sizes = None

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:55 in foreach_all_gather, code: torch._foreach_copy_(foreach_copy_dsts, param_all_gather_inputs)
        _foreach_copy = torch.ops.aten._foreach_copy_.default([getitem, full_default, getitem_2, getitem_3], [primals_2, primals_3, primals_4, primals_5]); primals_2 = primals_3 = primals_4 = primals_5 = None
        slice_scatter: "f32[200]" = torch.ops.aten.slice_scatter.default(slice_1, getitem, 0, 0, 128); slice_1 = getitem = None
        slice_scatter_1: "f32[400]" = torch.ops.aten.slice_scatter.default(empty, slice_scatter, 0, 200, 400); empty = slice_scatter = None
        slice_3: "f32[200]" = torch.ops.aten.slice.Tensor(slice_scatter_1, 0, 200, 400)
        slice_scatter_2: "f32[200]" = torch.ops.aten.slice_scatter.default(slice_3, full_default, 0, 128, 136); slice_3 = full_default = None
        slice_scatter_3: "f32[400]" = torch.ops.aten.slice_scatter.default(slice_scatter_1, slice_scatter_2, 0, 200, 400); slice_scatter_1 = slice_scatter_2 = None
        slice_4: "f32[200]" = torch.ops.aten.slice.Tensor(slice_scatter_3, 0, 200, 400)
        slice_scatter_4: "f32[200]" = torch.ops.aten.slice_scatter.default(slice_4, getitem_2, 0, 136, 196); slice_4 = getitem_2 = None
        slice_scatter_5: "f32[400]" = torch.ops.aten.slice_scatter.default(slice_scatter_3, slice_scatter_4, 0, 200, 400); slice_scatter_3 = slice_scatter_4 = None
        slice_5: "f32[200]" = torch.ops.aten.slice.Tensor(slice_scatter_5, 0, 200, 400)
        slice_scatter_6: "f32[200]" = torch.ops.aten.slice_scatter.default(slice_5, getitem_3, 0, 196, 200); slice_5 = getitem_3 = None
        slice_scatter_7: "f32[400]" = torch.ops.aten.slice_scatter.default(slice_scatter_5, slice_scatter_6, 0, 200, 400); slice_scatter_5 = slice_scatter_6 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:229 in all_gather_tensor, code: tensor = torch.ops._c10d_functional.all_gather_into_tensor(
        slice_10: "f32[200]" = torch.ops.aten.slice.Tensor(slice_scatter_7, 0, 200, 400)
        all_gather_into_tensor: "f32[400]" = torch.ops._c10d_functional.all_gather_into_tensor.default(slice_10, 2, '0'); slice_10 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:144 in wait_tensor, code: return torch.ops._c10d_functional.wait_tensor(tensor)  # type: ignore[attr-defined]
        wait_tensor: "f32[400]" = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor); all_gather_into_tensor = None

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_functional_collectives.py:975 in all_gather_tensor_inplace, code: return output_tensor.copy_(all_gather_tensor(input_tensor, gather_dim, group, tag))
        copy: "f32[400]" = torch.ops.aten.copy.default(slice_scatter_7, wait_tensor); slice_scatter_7 = wait_tensor = None

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
        view_1: "f32[2, 200]" = torch.ops.aten.view.default(copy, [2, -1]); copy = None
        split_with_sizes_6 = torch.ops.aten.split_with_sizes.default(view_1, [128, 8, 60, 4], 1); view_1 = None
        getitem_28: "f32[2, 128]" = split_with_sizes_6[0]
        clone: "f32[2, 128]" = torch.ops.aten.clone.default(getitem_28, memory_format = torch.contiguous_format); getitem_28 = None
        view_2: "f32[256]" = torch.ops.aten.view.default(clone, [256]); clone = None

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
        as_strided: "f32[15, 16]" = torch.ops.aten.as_strided.default(view_2, [15, 16], [16, 1], 0); view_2 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
        getitem_33: "f32[2, 8]" = split_with_sizes_6[1]
        clone_1: "f32[2, 8]" = torch.ops.aten.clone.default(getitem_33, memory_format = torch.contiguous_format); getitem_33 = None
        view_4: "f32[16]" = torch.ops.aten.view.default(clone_1, [16]); clone_1 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
        as_strided_1: "f32[15]" = torch.ops.aten.as_strided.default(view_4, [15], [1], 0); view_4 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
        getitem_38: "f32[2, 60]" = split_with_sizes_6[2]
        clone_2: "f32[2, 60]" = torch.ops.aten.clone.default(getitem_38, memory_format = torch.contiguous_format); getitem_38 = None
        view_6: "f32[120]" = torch.ops.aten.view.default(clone_2, [120]); clone_2 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
        as_strided_2: "f32[8, 15]" = torch.ops.aten.as_strided.default(view_6, [8, 15], [15, 1], 0); view_6 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:128 in foreach_all_gather_copy_out, code: splits[i].contiguous().view(splits[i].numel()),
        getitem_43: "f32[2, 4]" = split_with_sizes_6[3]; split_with_sizes_6 = None
        clone_3: "f32[2, 4]" = torch.ops.aten.clone.default(getitem_43, memory_format = torch.contiguous_format); getitem_43 = None
        view_8: "f32[8]" = torch.ops.aten.view.default(clone_3, [8]); clone_3 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:127 in foreach_all_gather_copy_out, code: torch.as_strided(
        as_strided_3: "f32[8]" = torch.ops.aten.as_strided.default(view_8, [8], [1], 0); view_8 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:138 in foreach_all_gather_copy_out, code: torch._foreach_copy_(out, splits_unpadded)
        _foreach_copy_1 = torch.ops.aten._foreach_copy_.default([primals_6, primals_7, primals_8, primals_9], [as_strided, as_strided_1, as_strided_2, as_strided_3]); as_strided = as_strided_1 = as_strided_2 = as_strided_3 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
        permute_1: "f32[16, 15]" = torch.ops.aten.permute.default(primals_6, [1, 0]); primals_6 = None
        addmm: "f32[4, 15]" = torch.ops.aten.addmm.default(primals_7, primals_1, permute_1); primals_7 = permute_1 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/activation.py:103 in forward, code: return F.relu(input, inplace=self.inplace)
        relu: "f32[4, 15]" = torch.ops.aten.relu.default(addmm); addmm = None

        # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
        permute_3: "f32[15, 8]" = torch.ops.aten.permute.default(primals_8, [1, 0])
        addmm_1: "f32[4, 8]" = torch.ops.aten.addmm.default(primals_9, relu, permute_3); primals_9 = permute_3 = None
        return [addmm_1, primals_1, primals_8, relu]
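The joint graph above is partitioned into the forward graph just shown and the backward graph that follows. The correspondence between their signatures, read directly off the traces, is summarized below.

# How the partitioned graphs fit together (read off the traces above; not additional log output):
#   forward:  (primals_1 ... primals_10) -> [addmm_1, primals_1, primals_8, relu]
#             addmm_1 is the user-visible output; primals_1, primals_8, and relu are the
#             tensors saved for the backward pass.
#   backward: (primals_1, primals_8, relu, tangents_1, primals_10)
#             -> [None, None, None, None, None, permute_12, view_10, permute_9, view_9, None]
#             i.e. gradients in input order: None for the activation input and the four local
#             shards, then grad(weight1), grad(bias1), grad(weight2), grad(bias2) (matching
#             primals_6..primals_9 by shape), and None for the backward-state object primals_10.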
TRACED GRAPH
===== Backward graph 0 =====
/data/users/willfeng/pytorch_yf225/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[4, 16]", primals_8: "f32[8, 15]", relu: "f32[4, 15]", tangents_1: "f32[4, 8]", primals_10):
        # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/activation.py:103 in forward, code: return F.relu(input, inplace=self.inplace)
        alias: "f32[4, 15]" = torch.ops.aten.alias.default(relu)
        alias_1: "f32[4, 15]" = torch.ops.aten.alias.default(alias); alias = None

        # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
        permute_3: "f32[15, 8]" = torch.ops.aten.permute.default(primals_8, [1, 0]); primals_8 = None

        # No stacktrace found for following nodes
        trace_wrapped: "f32[4, 8]" = torch__dynamo__trace_wrapped_higher_order_op_self_invoke(tangents_1, bw_state = primals_10); tangents_1 = primals_10 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
        permute_6: "f32[8, 15]" = torch.ops.aten.permute.default(permute_3, [1, 0]); permute_3 = None
        mm: "f32[4, 15]" = torch.ops.aten.mm.default(trace_wrapped, permute_6); permute_6 = None
        permute_7: "f32[8, 4]" = torch.ops.aten.permute.default(trace_wrapped, [1, 0])
        mm_1: "f32[8, 15]" = torch.ops.aten.mm.default(permute_7, relu); permute_7 = relu = None
        permute_8: "f32[15, 8]" = torch.ops.aten.permute.default(mm_1, [1, 0]); mm_1 = None
        sum_1: "f32[1, 8]" = torch.ops.aten.sum.dim_IntList(trace_wrapped, [0], True); trace_wrapped = None
        view_9: "f32[8]" = torch.ops.aten.view.default(sum_1, [8]); sum_1 = None
        permute_9: "f32[8, 15]" = torch.ops.aten.permute.default(permute_8, [1, 0]); permute_8 = None

        # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/activation.py:103 in forward, code: return F.relu(input, inplace=self.inplace)
        alias_2: "f32[4, 15]" = torch.ops.aten.alias.default(alias_1); alias_1 = None
        alias_3: "f32[4, 15]" = torch.ops.aten.alias.default(alias_2); alias_2 = None
        le: "b8[4, 15]" = torch.ops.aten.le.Scalar(alias_3, 0); alias_3 = None
        full_default_1: "f32[]" = torch.ops.aten.full.default([], 0.0, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=1), pin_memory = False)
        where: "f32[4, 15]" = torch.ops.aten.where.self(le, full_default_1, mm); le = full_default_1 = mm = None

        # File: /data/users/willfeng/pytorch_yf225/torch/nn/modules/linear.py:116 in forward, code: return F.linear(input, self.weight, self.bias)
        permute_10: "f32[15, 4]" = torch.ops.aten.permute.default(where, [1, 0])
        mm_2: "f32[15, 16]" = torch.ops.aten.mm.default(permute_10, primals_1); permute_10 = primals_1 = None
        permute_11: "f32[16, 15]" = torch.ops.aten.permute.default(mm_2, [1, 0]); mm_2 = None
        sum_2: "f32[1, 15]" = torch.ops.aten.sum.dim_IntList(where, [0], True); where = None
        view_10: "f32[15]" = torch.ops.aten.view.default(sum_2, [15]); sum_2 = None
        permute_12: "f32[15, 16]" = torch.ops.aten.permute.default(permute_11, [1, 0]); permute_11 = None
        return [None, None, None, None, None, permute_12, view_10, permute_9, view_9, None]
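Not part of the dump: one way to produce this kind of output (an assumption about how the gist was generated, not something it states) is to enable AOTAutograd's graph-logging artifacts before running the compiled model.

# Sketch (assumption): turn on AOTAutograd graph logging so the joint/forward/backward graphs are printed.
import torch
torch._logging.set_logs(aot_graphs=True, aot_joint_graph=True)
# Equivalent environment variable: TORCH_LOGS="aot_graphs,aot_joint_graph"
# Then run the torch.compile'd, fully_shard'ed model (see the sketch at the top of this file);
# each compiled region prints its traced graphs, similar to the output captured above.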