hpu_graph_demo.py · @yiliu30 · created May 22, 2025
"""
------------------------------------------------------------------------------
out shape: torch.Size([4096, 7168])
out shape: torch.Size([4096, 7168])
out shape: torch.Size([8192, 7168])
out shape: torch.Size([8192, 7168])
out shape: torch.Size([16384, 7168])
out shape: torch.Size([16384, 7168])
out shape: torch.Size([4096, 7168])
out shape: torch.Size([4096, 7168])
[rank: 0][test 0] out max: 14637248544768.0
[rank: 1][test 0] out max: 14637248544768.0
[rank: 1][test 0] out max: 14637248544768.0
[rank: 0][test 0] out max: 14637248544768.0
[rank: 0][test 0] out max: 14637248544768.0
[rank: 1][test 0] out max: 14637248544768.0
Test 0 RANK 1 HPU Graph 0 Used memory: 0 B
Test 0 RANK 0 HPU Graph 0 Used memory: 0 B
out shape: torch.Size([8192, 7168])
out shape: torch.Size([8192, 7168])
[rank: 1][test 1] out max: 14637248544768.0
[rank: 0][test 1] out max: 14637248544768.0
[rank: 0][test 1] out max: 14637248544768.0
[rank: 1][test 1] out max: 14637248544768.0
[rank: 1][test 1] out max: 14637248544768.0
[rank: 0][test 1] out max: 14637248544768.0
Test 1 RANK 0 HPU Graph 0 Used memory: 0 B
Test 1 RANK 1 HPU Graph 0 Used memory: 0 B
out shape: torch.Size([16384, 7168])
out shape: torch.Size([16384, 7168])
[rank: 1][test 2] out max: 14637248544768.0
[rank: 0][test 2] out max: 14637248544768.0
[rank: 0][test 2] out max: 14637248544768.0
[rank: 1][test 2] out max: 14637248544768.0
[rank: 1][test 2] out max: 14637248544768.0
[rank: 0][test 2] out max: 14637248544768.0
Test 2 RANK 1 HPU Graph 0 Used memory: 0 B
Test 2 RANK 0 HPU Graph 0 Used memory: 0 B
RANK 0 HPU Graph Total memory: 402 MiB
RANK 1 HPU Graph Total memory: 402 MiB
Test 0 RANK 0 HPU Graph 0 Used memory: 56 MiB
Test 0 RANK 1 HPU Graph 0 Used memory: 56 MiB
out shape: torch.Size([8192, 7168])
out shape: torch.Size([8192, 7168])
Test 1 RANK 0 HPU Graph 0 Used memory: 112 MiB
Test 1 RANK 1 HPU Graph 0 Used memory: 112 MiB
out shape: torch.Size([16384, 7168])
out shape: torch.Size([16384, 7168])
Test 2 RANK 0 HPU Graph 0 Used memory: 224 MiB
Test 2 RANK 1 HPU Graph 0 Used memory: 224 MiB
RANK 0 HPU Graph Total memory: 182 MiB
RANK 1 HPU Graph Total memory: 182 MiB
"""
import os
import gc
import functools
import habana_frameworks.torch.internal.bridge_config as bc
os.environ['PT_HPU_LAZY_MODE'] = '1'
os.environ['PT_HPU_ENABLE_LAZY_COLLECTIVES'] = '1'
# os.environ['HABANA_PROFILE'] = '1'
# os.environ['GRAPH_VISUALIZATION'] = '1'
import torch.distributed as dist
import torch
import habana_frameworks.torch as ht
import habana_frameworks.torch.core as htcore
from vllm_hpu_extension.profiler import (HabanaMemoryProfiler, format_bytes)
from habana_frameworks.torch.distributed.hccl import initialize_distributed_hpu
device = torch.device('hpu')
activities = [torch.profiler.ProfilerActivity.CPU]
activities.append(torch.profiler.ProfilerActivity.HPU)
out_buffer = torch.zeros([16384, 7168], dtype=torch.bfloat16).to('hpu')
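
# update_tensor() stages results through the module-level out_buffer
# (preallocated at the largest output shape, [16384, 7168]): it zeroes the
# buffer, copies the result into the leading slice, and returns a view of
# that slice, so graph replays reuse one persistent device allocation.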
def update_tensor(tensor):
    seq_len, hidden_dim = tensor.shape
    assert len(tensor.shape) == len(out_buffer.shape), f"tensor shape {tensor.shape} does not match out_buffer shape {out_buffer.shape}"
    out_buffer.zero_()
    out_buffer[:seq_len, :hidden_dim] = tensor
    htcore.mark_step()
    del tensor
    new_tensor = out_buffer[:seq_len, :hidden_dim]
    return new_tensor
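
# The workload under test: a scaled matmul plus a shared tensor, staged
# through the global out_buffer via update_tensor(). When a per-shape
# out_buffer is passed in, the result is copied into it and None is
# returned; otherwise the staged view is returned directly.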
def fn(data, weight, shared_output, out_buffer=None):
    batch_size, seq_len, hidden_dim = data.shape
    data = data.view(batch_size * seq_len, hidden_dim)
    out = torch.matmul(data, weight) * 3.14
    out = out + shared_output
    htcore.mark_step()
    # dist.all_reduce(out, op=dist.ReduceOp.SUM)
    print(f"out shape: {out.shape}")
    out = update_tensor(out)
    if out_buffer is not None:
        out_buffer.copy_(out)
        return None
    return out
    # out = out.view(batch_size, seq_len, hidden_dim)
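
# main() allocates inputs for three flattened shapes (batch 2/4/8 with
# seq_len 2048 -> 4096/8192/16384 rows of width 7168), warms each shape up
# once, then replays each wrapped graph three times under
# HabanaMemoryProfiler to measure per-graph device memory.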
def main():
    rank: int = dist.get_rank()
    wrapped_fn = torch.hpu.wrap_in_hpu_graph_func(fn, disable_tensor_cache=True)
    data = torch.randn([2, 2048, 7168], dtype=torch.bfloat16).to('hpu')
    weight = torch.randn([7168, 7168], dtype=torch.bfloat16).to('hpu')
    shared_output = torch.randn([4096, 7168], dtype=torch.bfloat16).to('hpu')
    out_buffer = torch.randn([4096, 7168], dtype=torch.bfloat16).to('hpu')
    data1 = torch.randn([4, 2048, 7168], dtype=torch.bfloat16).to('hpu')
    weight1 = torch.randn([7168, 7168], dtype=torch.bfloat16).to('hpu')
    shared_output1 = torch.randn([8192, 7168], dtype=torch.bfloat16).to('hpu')
    out_buffer1 = torch.randn([8192, 7168], dtype=torch.bfloat16).to('hpu')
    data2 = torch.randn([8, 2048, 7168], dtype=torch.bfloat16).to('hpu')
    weight2 = torch.randn([7168, 7168], dtype=torch.bfloat16).to('hpu')
    shared_output2 = torch.randn([16384, 7168], dtype=torch.bfloat16).to('hpu')
    out_buffer2 = torch.randn([16384, 7168], dtype=torch.bfloat16).to('hpu')
    htcore.mark_step()
    start_mem = HabanaMemoryProfiler.current_device_memory_usage()
    compile_only_mode_context = functools.partial(bc.env_setting,
                                                  "PT_COMPILE_ONLY_MODE",
                                                  True)
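    # Run each shape once under PT_COMPILE_ONLY_MODE (presumably so the
    # recipes are compiled before the measured replays below); the cloned
    # buffers serve as references for the commented-out allclose check.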
    with compile_only_mode_context():
        out = fn(data, weight, shared_output, out_buffer)
        out_ref = out_buffer.clone()
        htcore.mark_step()
        torch.hpu.synchronize()
        torch.distributed.barrier()
        out1 = fn(data1, weight1, shared_output1, out_buffer1)
        out1_ref = out_buffer1.clone()
        htcore.mark_step()
        torch.hpu.synchronize()
        torch.distributed.barrier()
        out2 = fn(data2, weight2, shared_output2, out_buffer2)
        out2_ref = out_buffer2.clone()
        htcore.mark_step()
        torch.hpu.synchronize()
        torch.distributed.barrier()
    test_pairs = [
        (data, weight, shared_output, out_ref, out_buffer),
        (data1, weight1, shared_output1, out1_ref, out_buffer1),
        (data2, weight2, shared_output2, out2_ref, out_buffer2),
    ]
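    # Replay each captured graph three times per rank and report the device
    # memory it consumed; these prints correspond to the "out max" and
    # "Used memory" lines in the log above.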
    for i, (data, weight, shared_output, out, out_buffer) in enumerate(test_pairs):
        with HabanaMemoryProfiler() as mem_prof:
            for _ in range(3):
                res = wrapped_fn(data, weight, shared_output, out_buffer)
                htcore.mark_step()
                torch.hpu.synchronize()
                print(f"[rank: {rank}][test {i}] out max: {out_buffer.max()}")
                # if not torch.allclose(out, out_buffer):
                #     print(f"RANK {rank} HPU Graph {i} out not equal")
                #     print(f"out: {out}")
                #     print(f"res: {res}")
                # else:
                #     print(f"RANK {rank} HPU Graph {i} out equal")
            torch.distributed.barrier()
            gc.collect()
        used_mem = mem_prof.consumed_device_memory
        print(f"Test {i} RANK {rank} HPU Graph 0 Used memory: {format_bytes(used_mem)}")
    htcore.mark_step()
    torch.hpu.synchronize()
    torch.distributed.barrier()
    end_mem = HabanaMemoryProfiler.current_device_memory_usage()
    print(f"RANK {rank} HPU Graph Total memory: {format_bytes(end_mem - start_mem)}")
if __name__ == "__main__":
    world_size, rank, local_rank = initialize_distributed_hpu()
    if local_rank != -1:
        dist.init_process_group('hccl', rank=rank, world_size=world_size)
    main()
# PT_HPU_LAZY_MODE=1 python -m torch.distributed.run --nproc-per-node 2 hpu_graph_demo.py