# testing PT Lightning + FSDP2 + CPU offloading + compile + validation
# toy training loop
# 1. copy-paste from https://github.com/Lightning-AI/pytorch-lightning?tab=readme-ov-file#pytorch-lightning-example
# 2. modify to include conversion to torchao.float8 and compiling encoder/decoder
# main.py
# ! pip install torchvision
import torch, torch.nn as nn, torch.utils.data as data, torchvision as tv, torch.nn.functional as F
from torch.utils.data import Subset
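
The gist cuts off after the imports. A minimal sketch of what step 2 could look like (assumptions: the Lightning README autoencoder as the base model, torchao's convert_to_float8_training as the float8 entry point, and Lightning's configure_model hook as the place to apply both conversions; this is not the gist's actual code):

import lightning as L
from torchao.float8 import convert_to_float8_training

class LitAutoEncoder(L.LightningModule):
    def __init__(self):
        super().__init__()
        # Same toy autoencoder as the Lightning README example.
        self.encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3))
        self.decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28))

    def configure_model(self):
        # Step 2 from the header comments: convert linears to float8 training,
        # then compile encoder and decoder separately.
        convert_to_float8_training(self.encoder)
        convert_to_float8_training(self.decoder)
        self.encoder = torch.compile(self.encoder)
        self.decoder = torch.compile(self.decoder)

    def training_step(self, batch, batch_idx):
        x, _ = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        return F.mse_loss(x_hat, x)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

Note that float8 matmul kernels generally want dimensions divisible by 16, so the README's toy sizes (784/128/3) may need padding in practice; the FSDP2 + CPU-offload wiring lives in the Trainer strategy and is not shown here.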
| File "/data/users/willfeng/pytorch_yf225/torch/_dynamo/variables/builtin.py", line 939, in call_function | |
| return handler(tx, args, kwargs) | |
| ^^^^^^^^^^^^^^^^^^^^^^^^^ | |
| File "/data/users/willfeng/pytorch_yf225/torch/_dynamo/variables/builtin.py", line 814, in builtin_dipatch | |
| rv = handler(tx, args, kwargs) | |
| ^^^^^^^^^^^^^^^^^^^^^^^^^ | |
| File "/data/users/willfeng/pytorch_yf225/torch/_dynamo/variables/builtin.py", line 743, in call_self_handler | |
| result = self_handler(tx, *args, **kwargs) | |
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | |
| File "/data/users/willfeng/pytorch_yf225/torch/_dynamo/variables/builtin.py", line 1621, in call_setattr |
... (stack trace truncated)

+ export USE_LIBUV=1
+ USE_LIBUV=1
+ TRAINER_DIR=/home/willfeng/local/torchtrain
+ NGPU=8
+ LOG_RANK=0
+ CONFIG_FILE=./train_configs/llama_1b_full_graph_fsdp.toml
+ torchrun --nproc_per_node=8 --rdzv_endpoint=localhost:5972 --local-ranks-filter 0 --role rank --tee 3 train.py --job.config_file ./train_configs/llama_1b_full_graph_fsdp.toml
W2024-03-28 18:03:14,934.934000 140450963445568 torch/distributed/run.py:757]
W2024-03-28 18:03:14,934.934000 140450963445568 torch/distributed/run.py:757] *****************************************
... (log truncated)

+ export USE_LIBUV=1
+ USE_LIBUV=1
+ TRAINER_DIR=/home/willfeng/local/torchtrain
+ NGPU=8
+ LOG_RANK=0
+ CONFIG_FILE=./train_configs/toy_model_full_graph_fsdp.toml
+ torchrun --nproc_per_node=8 --rdzv_endpoint=localhost:5972 --local-ranks-filter 0 --role rank --tee 3 train.py --job.config_file ./train_configs/toy_model_full_graph_fsdp.toml
W2024-03-27 14:57:24,673.673000 140480499401728 torch/distributed/run.py:757]
W2024-03-27 14:57:24,673.673000 140480499401728 torch/distributed/run.py:757] *****************************************
W2024-03-27 14:57:24,673.673000 140480499401728 torch/distributed/run.py:757] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
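
The "TRACED GRAPH" / "===== AFTER POST GRAD =====" sections that follow are Inductor's post-grad graph dumps. The gist does not show how logging was enabled for these runs; one plausible way (an assumption, using the standard torch logging artifacts) is TORCH_LOGS="post_grad_graphs" on the torchrun command line, or programmatically:

import torch._logging

# Assumption: enable Inductor's post-grad graph dumps via the logging API.
torch._logging.set_logs(post_grad_graphs=True)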
TRACED GRAPH
 ===== AFTER POST GRAD =====
 /data/users/willfeng/pytorch_yf225/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[8, 32]", primals_2: "f32[1024]", primals_3: "f32[32]", primals_4: "f32[64, 32]", primals_5: "f32[64]", primals_6, primals_7: "f32[4096]", primals_8: "f32[64]", primals_9: "f32[128, 64]", primals_10: "f32[128]", primals_11: "f32[16384]", primals_12: "f32[128]", primals_13: "f32[256, 128]", primals_14: "f32[256]"):
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:47 in foreach_all_gather, code: all_gather_output = torch.empty(
        empty: "f32[2112]" = torch.ops.aten.empty.memory_format([2112], dtype = torch.float32, device = device(type='cuda', index=0), pin_memory = False)

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:50 in foreach_all_gather, code: all_gather_input = all_gather_output.narrow(
... (graph dump truncated)
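
The empty/narrow pair that recurs in all of these graphs comes from foreach_all_gather in _fsdp_collectives.py (lines 47 and 50, per the inlined comments). A simplified sketch of that allocation pattern, with a hypothetical helper name and signature rather than the real FSDP2 internals:

import torch
import torch.distributed as dist

def foreach_all_gather_sketch(param_shards, group):
    # One flat fp32 buffer holds the entire all-gather output for every
    # parameter in the group (e.g. the `empty: "f32[2112]"` node above).
    world_size = dist.get_world_size(group)
    rank = dist.get_rank(group)
    shard_numel = sum(p.numel() for p in param_shards)
    all_gather_output = torch.empty(
        shard_numel * world_size, dtype=torch.float32, device="cuda"
    )
    # The all-gather *input* is a narrow() view into this rank's slice of
    # the same buffer, so no separate input allocation is needed.
    all_gather_input = all_gather_output.narrow(0, shard_numel * rank, shard_numel)
    # Copy the flattened local shards into the view, then all-gather in place.
    torch.cat([p.detach().reshape(-1) for p in param_shards], out=all_gather_input)
    dist.all_gather_into_tensor(all_gather_output, all_gather_input, group=group)
    return all_gather_output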
 ===== AFTER POST GRAD =====
 /data/users/willfeng/pytorch_yf225/torch/fx/_lazy_graph_module.py class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: "f32[]", arg1_1: "f32[8, 32]", arg2_1: "f32[32, 128]", arg3_1: "f32[128, 32]", arg4_1: "f32[32, 128]", arg5_1: "f32[128, 32]", arg6_1: "f32[32, 128]", arg7_1: "f32[8, 128]", arg8_1: "f32[8, 32]", arg9_1: "f32[8, 128]", arg10_1: "f32[8, 32]", arg11_1: "f32[8, 128]", arg12_1: "b8[8, 32]", arg13_1: "f32[32]", arg14_1: "f32[128]", arg15_1: "f32[32]", arg16_1: "f32[128]", arg17_1: "f32[32]", arg18_1: "f32[128]", arg19_1: "f32[128, 32]", arg20_1: "f32[2048]", arg21_1: "f32[64]", arg22_1: "f32[2048]", arg23_1: "f32[16]", arg24_1: "f32[2048]", arg25_1: "f32[64]", arg26_1: "f32[2048]", arg27_1: "f32[16]", arg28_1: "f32[2048]", arg29_1: "f32[64]", arg30_1: "f32[2048]", arg31_1: "f32[16]"):
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:47 in foreach_all_gather, code: all_gather_output = torch.empty(
... (graph dump truncated)
TRACED GRAPH
 ===== AFTER POST GRAD =====
 /data/users/willfeng/pytorch_yf225/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[8, 32]", primals_2: "f32[2048]", primals_3: "f32[64]", primals_4: "f32[2048]", primals_5: "f32[16]", primals_6: "f32[128, 32]", primals_7: "f32[128]", primals_8: "f32[32, 128]", primals_9: "f32[32]", primals_10, primals_11: "f32[2048]", primals_12: "f32[64]", primals_13: "f32[2048]", primals_14: "f32[16]", primals_15: "f32[128, 32]", primals_16: "f32[128]", primals_17: "f32[32, 128]", primals_18: "f32[32]", primals_19: "f32[2048]", primals_20: "f32[64]", primals_21: "f32[2048]", primals_22: "f32[16]", primals_23: "f32[128, 32]", primals_24: "f32[128]", primals_25: "f32[32, 128]", primals_26: "f32[32]"):
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:47 in foreach_all_gather, code: all_gather_output = torch.empty(
        empty: "f32[8352]" = torch.ops.aten.empty.memory_f
... (graph dump truncated)
BWD graph
 ===== AFTER POST GRAD =====
 /data/users/willfeng/pytorch_yf225/torch/fx/_lazy_graph_module.py class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: "f32[]", arg1_1: "f32[8, 32]", arg2_1: "f32[32, 128]", arg3_1: "f32[128, 32]", arg4_1: "f32[32, 128]", arg5_1: "f32[128, 32]", arg6_1: "f32[32, 128]", arg7_1: "f32[8, 128]", arg8_1: "f32[8, 32]", arg9_1: "f32[8, 128]", arg10_1: "f32[8, 32]", arg11_1: "f32[8, 128]", arg12_1: "b8[8, 32]", arg13_1: "f32[32]", arg14_1: "f32[128]", arg15_1: "f32[32]", arg16_1: "f32[128]", arg17_1: "f32[32]", arg18_1: "f32[128]", arg19_1: "f32[128, 32]", arg20_1: "f32[2048]", arg21_1: "f32[64]", arg22_1: "f32[2048]", arg23_1: "f32[16]", arg24_1: "f32[2048]", arg25_1: "f32[64]", arg26_1: "f32[2048]", arg27_1: "f32[16]", arg28_1: "f32[2048]", arg29_1: "f32[64]", arg30_1: "f32[2048]", arg31_1: "f32[16]"):
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:47 in foreach_all_gather, code: all_gather_output = torch.
... (graph dump truncated)
FWD graph
 ===== AFTER POST GRAD =====
 /data/users/willfeng/pytorch_yf225/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[8, 32]", primals_2: "f32[2048]", primals_3: "f32[64]", primals_4: "f32[2048]", primals_5: "f32[16]", primals_6: "f32[128, 32]", primals_7: "f32[128]", primals_8: "f32[32, 128]", primals_9: "f32[32]", primals_10, primals_11: "f32[2048]", primals_12: "f32[64]", primals_13: "f32[2048]", primals_14: "f32[16]", primals_15: "f32[128, 32]", primals_16: "f32[128]", primals_17: "f32[32, 128]", primals_18: "f32[32]", primals_19: "f32[2048]", primals_20: "f32[64]", primals_21: "f32[2048]", primals_22: "f32[16]", primals_23: "f32[128, 32]", primals_24: "f32[128]", primals_25: "f32[32, 128]", primals_26: "f32[32]"):
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:47 in foreach_all_gather, code: all_gather_output = torch.empty(
        empty: "f32[8352]" = torch.ops.aten.empty.memory_for
... (graph dump truncated)
 ===== Joint graph 0 =====
 /data/users/willfeng/pytorch_yf225/torch/fx/_lazy_graph_module.py class joint_helper(torch.nn.Module):
    def forward(self, primals, tangents):
        primals_1: "f32[4, 16]"; primals_2: "f32[128]"; primals_3: "f32[8]"; primals_4: "f32[60]"; primals_5: "f32[4]"; primals_6: "f32[15, 16]"; primals_7: "f32[15]"; primals_8: "f32[8, 15]"; primals_9: "f32[8]"; tangents_1: "f32[4, 8]";

        primals_1, primals_2, primals_3, primals_4, primals_5, primals_6, primals_7, primals_8, primals_9, primals_10, tangents_1, = fx_pytree.tree_flatten_spec([primals, tangents], self._in_spec)
        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_composable/fsdp/_fsdp_collectives.py:47 in foreach_all_gather, code: all_gather_output = torch.empty(
        empty: "f32[400]" = torch.ops.aten.empty.memory_format([400], dtype = torch.float32, device = device(type='cuda', index=1), pin_memory = False)

        # File: /data/users/willfeng/pytorch_yf225/torch/distributed/_compos
... (graph dump truncated)