| """ | |
| ------------------------------------------------------------------------------ | |
| out shape: torch.Size([4096, 7168]) | |
| out shape: torch.Size([4096, 7168]) | |
| out shape: torch.Size([8192, 7168]) | |
| out shape: torch.Size([8192, 7168]) | |
| out shape: torch.Size([16384, 7168]) | |
| out shape: torch.Size([16384, 7168]) | |
| out shape: torch.Size([4096, 7168]) | |
| out shape: torch.Size([4096, 7168]) | |
| [rank: 0][test 0] out max: 14637248544768.0 | |
| [rank: 1][test 0] out max: 14637248544768.0 | |
| [rank: 1][test 0] out max: 14637248544768.0 | |
| [rank: 0][test 0] out max: 14637248544768.0 | |
| [rank: 0][test 0] out max: 14637248544768.0 | |
| [rank: 1][test 0] out max: 14637248544768.0 | |
| Test 0 RANK 1 HPU Graph 0 Used memory: 0 B | |
| Test 0 RANK 0 HPU Graph 0 Used memory: 0 B | |
| out shape: torch.Size([8192, 7168]) | |
| out shape: torch.Size([8192, 7168]) | |
| [rank: 1][test 1] out max: 14637248544768.0 | |
| [rank: 0][test 1] out max: 14637248544768.0 | |
| [rank: 0][test 1] out max: 14637248544768.0 | |
| [rank: 1][test 1] out max: 14637248544768.0 | |
| [rank: 1][test 1] out max: 14637248544768.0 | |
| [rank: 0][test 1] out max: 14637248544768.0 | |
| Test 1 RANK 0 HPU Graph 0 Used memory: 0 B | |
| Test 1 RANK 1 HPU Graph 0 Used memory: 0 B | |
| out shape: torch.Size([16384, 7168]) | |
| out shape: torch.Size([16384, 7168]) | |
| [rank: 1][test 2] out max: 14637248544768.0 | |
| [rank: 0][test 2] out max: 14637248544768.0 | |
| [rank: 0][test 2] out max: 14637248544768.0 | |
| [rank: 1][test 2] out max: 14637248544768.0 | |
| [rank: 1][test 2] out max: 14637248544768.0 | |
| [rank: 0][test 2] out max: 14637248544768.0 | |
| Test 2 RANK 1 HPU Graph 0 Used memory: 0 B | |
| Test 2 RANK 0 HPU Graph 0 Used memory: 0 B | |
| RANK 0 HPU Graph Total memory: 402 MiB | |
| RANK 1 HPU Graph Total memory: 402 MiB | |
| Test 0 RANK 0 HPU Graph 0 Used memory: 56 MiB | |
| Test 0 RANK 1 HPU Graph 0 Used memory: 56 MiB | |
| out shape: torch.Size([8192, 7168]) | |
| out shape: torch.Size([8192, 7168]) | |
| Test 1 RANK 0 HPU Graph 0 Used memory: 112 MiB | |
| Test 1 RANK 1 HPU Graph 0 Used memory: 112 MiB | |
| out shape: torch.Size([16384, 7168]) | |
| out shape: torch.Size([16384, 7168]) | |
| Test 2 RANK 0 HPU Graph 0 Used memory: 224 MiB | |
| Test 2 RANK 1 HPU Graph 0 Used memory: 224 MiB | |
| RANK 0 HPU Graph Total memory: 182 MiB | |
| RANK 1 HPU Graph Total memory: 182 MiB | |
| """ | |
import os
import gc
import functools
import habana_frameworks.torch.internal.bridge_config as bc

os.environ['PT_HPU_LAZY_MODE'] = '1'
os.environ['PT_HPU_ENABLE_LAZY_COLLECTIVES'] = '1'
# os.environ['HABANA_PROFILE'] = '1'
# os.environ['GRAPH_VISUALIZATION'] = '1'

import torch.distributed as dist
import torch
import habana_frameworks.torch as ht
import habana_frameworks.torch.core as htcore
from vllm_hpu_extension.profiler import HabanaMemoryProfiler, format_bytes
from habana_frameworks.torch.distributed.hccl import initialize_distributed_hpu

device = torch.device('hpu')
activities = [torch.profiler.ProfilerActivity.CPU]
activities.append(torch.profiler.ProfilerActivity.HPU)
out_buffer = torch.zeros([16384, 7168], dtype=torch.bfloat16).to('hpu')


def update_tensor(tensor):
    # Copy `tensor` into the module-level out_buffer and return a view of the
    # copied region, so every call writes into the same persistent storage.
    seq_len, hidden_dim = tensor.shape
    assert len(tensor.shape) == len(out_buffer.shape), \
        f"tensor shape {tensor.shape} does not match out_buffer shape {out_buffer.shape}"
    out_buffer.zero_()
    out_buffer[:seq_len, :hidden_dim] = tensor
    htcore.mark_step()
    del tensor
    new_tensor = out_buffer[:seq_len, :hidden_dim]
    return new_tensor
def fn(data, weight, shared_output, out_buffer=None):
    # Flatten the batch, run a bf16 matmul, add shared_output, then stage the
    # result in the persistent buffer via update_tensor(). If a per-call
    # out_buffer is given, copy the result into it and return None.
    batch_size, seq_len, hidden_dim = data.shape
    data = data.view(batch_size * seq_len, hidden_dim)
    out = torch.matmul(data, weight) * 3.14
    out = out + shared_output
    htcore.mark_step()
    # dist.all_reduce(out, op=dist.ReduceOp.SUM)
    print(f"out shape: {out.shape}")
    out = update_tensor(out)
    if out_buffer is not None:
        out_buffer.copy_(out)
        return None
    return out
    # out = out.view(batch_size, seq_len, hidden_dim)
def main():
    rank: int = dist.get_rank()
    wrapped_fn = torch.hpu.wrap_in_hpu_graph_func(fn, disable_tensor_cache=True)

    data = torch.randn([2, 2048, 7168], dtype=torch.bfloat16).to('hpu')
    weight = torch.randn([7168, 7168], dtype=torch.bfloat16).to('hpu')
    shared_output = torch.randn([4096, 7168], dtype=torch.bfloat16).to('hpu')
    out_buffer = torch.randn([4096, 7168], dtype=torch.bfloat16).to('hpu')

    data1 = torch.randn([4, 2048, 7168], dtype=torch.bfloat16).to('hpu')
    weight1 = torch.randn([7168, 7168], dtype=torch.bfloat16).to('hpu')
    shared_output1 = torch.randn([8192, 7168], dtype=torch.bfloat16).to('hpu')
    out_buffer1 = torch.randn([8192, 7168], dtype=torch.bfloat16).to('hpu')

    data2 = torch.randn([8, 2048, 7168], dtype=torch.bfloat16).to('hpu')
    weight2 = torch.randn([7168, 7168], dtype=torch.bfloat16).to('hpu')
    shared_output2 = torch.randn([16384, 7168], dtype=torch.bfloat16).to('hpu')
    out_buffer2 = torch.randn([16384, 7168], dtype=torch.bfloat16).to('hpu')

    htcore.mark_step()
    start_mem = HabanaMemoryProfiler.current_device_memory_usage()

    # Warm up: call fn() once per shape under PT_COMPILE_ONLY_MODE and snapshot
    # the per-shape output buffers as references.
    compile_only_mode_context = functools.partial(bc.env_setting,
                                                  "PT_COMPILE_ONLY_MODE",
                                                  True)
    with compile_only_mode_context():
        out = fn(data, weight, shared_output, out_buffer)
        out_ref = out_buffer.clone()
        htcore.mark_step()
        torch.hpu.synchronize()
        torch.distributed.barrier()

        out1 = fn(data1, weight1, shared_output1, out_buffer1)
        out1_ref = out_buffer1.clone()
        htcore.mark_step()
        torch.hpu.synchronize()
        torch.distributed.barrier()

        out2 = fn(data2, weight2, shared_output2, out_buffer2)
        out2_ref = out_buffer2.clone()
        htcore.mark_step()
        torch.hpu.synchronize()
        torch.distributed.barrier()

    test_pairs = [
        (data, weight, shared_output, out_ref, out_buffer),
        (data1, weight1, shared_output1, out1_ref, out_buffer1),
        (data2, weight2, shared_output2, out2_ref, out_buffer2),
    ]

    # Capture and replay each shape as an HPU graph three times while profiling
    # the device memory consumed per test.
    for i, (data, weight, shared_output, out, out_buffer) in enumerate(test_pairs):
        with HabanaMemoryProfiler() as mem_prof:
            for _ in range(3):
                res = wrapped_fn(data, weight, shared_output, out_buffer)
                htcore.mark_step()
                torch.hpu.synchronize()
                print(f"[rank: {rank}][test {i}] out max: {out_buffer.max()}")
                # if not torch.allclose(out, out_buffer):
                #     print(f"RANK {rank} HPU Graph {i} out not equal")
                #     print(f"out: {out}")
                #     print(f"res: {res}")
                # else:
                #     print(f"RANK {rank} HPU Graph {i} out equal")
                torch.distributed.barrier()
            gc.collect()
        used_mem = mem_prof.consumed_device_memory
        print(f"Test {i} RANK {rank} HPU Graph 0 Used memory: {format_bytes(used_mem)}")

    htcore.mark_step()
    torch.hpu.synchronize()
    torch.distributed.barrier()
    end_mem = HabanaMemoryProfiler.current_device_memory_usage()
    print(f"RANK {rank} HPU Graph Total memory: {format_bytes(end_mem - start_mem)}")
if __name__ == "__main__":
    world_size, rank, local_rank = initialize_distributed_hpu()
    if local_rank != -1:
        dist.init_process_group('hccl', rank=rank, world_size=world_size)
    main()

# PT_HPU_LAZY_MODE=1 python -m torch.distributed.run --nproc-per-node 2 hpu_graph_demo.py
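# Optional correctness check, a minimal sketch mirroring the commented-out
# allclose check inside main(); it assumes the eager references
# out_ref/out1_ref/out2_ref hold valid data, which may not be true for calls
# made while PT_COMPILE_ONLY_MODE is set:
#
#     for ref, buf in [(out_ref, out_buffer), (out1_ref, out_buffer1), (out2_ref, out_buffer2)]:
#         if not torch.allclose(ref, buf, rtol=1e-2, atol=1e-2):
#             print(f"graph output mismatch, max abs diff: {(ref - buf).abs().max()}")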