Skip to content

Instantly share code, notes, and snippets.

@cli99
Last active September 20, 2024 06:08
Show Gist options
  • Save cli99/ceaacf96c811f9189fc88aa258abd2a5 to your computer and use it in GitHub Desktop.
Save cli99/ceaacf96c811f9189fc88aa258abd2a5 to your computer and use it in GitHub Desktop.
# speechmatics.com/company/articles-and-news/timing-operations-in-pytorch
import time
import torch
# Benchmark input: 1000 x 1000 float32 values * 4 B = 4,000,000 B = 4 MB.
# (A 10000 x 10000 float32 tensor would be 400,000,000 B = 400 MB.)
a = torch.randn(1000, 1000, device="cuda")
# Warm up softmax once so first-call overheads are not part of any measurement,
# and wait for the GPU to finish before timing starts.
torch.softmax(a, dim=1)
torch.cuda.synchronize()
# Size of the scratch buffer used to evict the GPU L2 cache. It must exceed the
# L2 size of the GPU under test (tens of MB on current data-center GPUs, e.g.
# 40 MB on A100); 256 MB leaves ample headroom.
L2_FLUSH_BYTES = 256 * 1024 * 1024
_flush_buffer = None  # allocated lazily on the first flush_cache() call


def flush_cache():
    """Evict the GPU L2 cache before a timed run.

    Overwrites a dedicated scratch buffer that is larger than L2, so data left
    in the cache by earlier iterations (in particular the benchmark input
    ``a``) is evicted and the next timed kernel has to read its input from
    device memory.

    Do not zero ``a`` itself here: writing the benchmark input pulls it into
    L2 (warming rather than flushing the cache for it) and also replaces the
    benchmark data with zeros.

    The buffer is allocated on the first call rather than at import time. The
    zero-fill is launched like any other CUDA op; callers synchronize or record
    events afterwards as their timing method requires.
    """
    global _flush_buffer
    if _flush_buffer is None:
        _flush_buffer = torch.empty(L2_FLUSH_BYTES, dtype=torch.int8, device="cuda")
    _flush_buffer.zero_()
# Method 1: host timer around softmax with NO synchronization.
# CUDA launches return before the kernel runs (compare this result with the
# synchronized ones below), so this mostly measures CPU-side launch cost.
times = []
for _ in range(1000):
    begin_s = time.perf_counter()
    torch.softmax(a, dim=1)
    finish_s = time.perf_counter()
    times.append(finish_s - begin_s)
# times holds seconds; the mean per call in microseconds is
# sum(times) / 1000 * 1e6 == 1000 * sum(times).
print(f"perf_counter no sync Time: {1000*sum(times):.4f} us")
# Method 2: host timer bracketed by torch.cuda.synchronize().
# Warm up and drain whatever the previous (unsynchronized) run left queued.
torch.softmax(a, dim=1)
torch.cuda.synchronize()
times = []
for _ in range(1000):
    flush_cache()
    # The flush must have completed before the clock starts, so that only
    # softmax falls inside the timed window.
    torch.cuda.synchronize()
    begin_s = time.perf_counter()
    torch.softmax(a, dim=1)
    torch.cuda.synchronize()  # wait for the kernel itself, not just its launch
    finish_s = time.perf_counter()
    times.append(finish_s - begin_s)
# Mean per call in microseconds: total seconds / 1000 calls * 1e6.
print(f"perf_counter Time: {1000*sum(times):.4f} us")
# Method 3: like Method 2, but using the integer nanosecond clock, which avoids
# the float precision loss of perf_counter().
# No a.zero_() here: it is not a cache flush (flush_cache() runs inside the
# loop), and it would make this method time softmax on different data than the
# methods above.
torch.softmax(a, dim=1)
torch.cuda.synchronize()
times = []
for _ in range(1000):
    flush_cache()
    torch.cuda.synchronize()  # flush done before the clock starts
    t0 = time.perf_counter_ns()
    torch.softmax(a, dim=1)
    torch.cuda.synchronize()  # softmax done before the clock stops
    t1 = time.perf_counter_ns()
    times.append(t1 - t0)
# times holds nanoseconds: total_ns / 1000 calls / 1000 ns-per-us = mean us per call.
print(f"perf_counter_ns Time: {sum(times)/1000/1000:.4f} us")
# Method 4: one reusable pair of CUDA timing events.
# The events are timestamped on the GPU stream, so no host sync is needed
# between the flush and softmax. We synchronize only before reading the events.
torch.softmax(a, dim=1)
torch.cuda.synchronize()
times = []
ev_begin = torch.cuda.Event(enable_timing=True)
ev_finish = torch.cuda.Event(enable_timing=True)
for _ in range(1000):
    flush_cache()
    ev_begin.record()
    torch.softmax(a, dim=1)
    ev_finish.record()
    # Both events must have completed before elapsed_time() can be read, and
    # before they are re-recorded in the next iteration.
    torch.cuda.synchronize()
    times.append(ev_begin.elapsed_time(ev_finish))
# elapsed_time() returns milliseconds, so the sum over 1000 calls equals the
# mean per call in microseconds (total_ms / 1000 * 1000).
print(f"cuda.Event Time: {sum(times):.4f} us")
# Method 5: a separate event pair per iteration, all read after one final sync.
# torch.cuda._sleep() (a private PyTorch helper that keeps the GPU busy for
# roughly the given number of clock cycles) is enqueued ahead of each
# measurement. The CPU can then queue record/softmax/record before the GPU
# reaches them, so the events bracket softmax execution rather than CPU launch
# gaps.
# No a.zero_() here: flush_cache() runs inside the loop, and zeroing the input
# would time softmax on different data than the earlier methods.
torch.softmax(a, dim=1)
torch.cuda.synchronize()
starts = [torch.cuda.Event(enable_timing=True) for _ in range(1000)]
ends = [torch.cuda.Event(enable_timing=True) for _ in range(1000)]
for start_ev, end_ev in zip(starts, ends):
    flush_cache()
    torch.cuda._sleep(1_000_000)
    start_ev.record()
    torch.softmax(a, dim=1)
    end_ev.record()
# Synchronize once, after every event has been recorded.
torch.cuda.synchronize()
# elapsed_time() is in milliseconds; the sum over 1000 calls is the mean in us.
times = [start_ev.elapsed_time(end_ev) for start_ev, end_ev in zip(starts, ends)]
print(f"cuda.Event list Time: {sum(times):.4f} us")
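# Results: mean time per softmax call, in microseconds, over 1000 iterations.
# The first two sets presumably used a larger 10000 x 10000 float32 (400 MB)
# input, since only the last set is labelled with the 1000 x 1000 size.
# Without flush_cache and without torch.cuda._sleep: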
# perf_counter no sync Time: 4.2106 us
# perf_counter Time: 950.8353 us
# perf_counter_ns Time: 950.6415 us
# cuda.Event Time: 948.8796 us
# cuda.Event list Time: 945.8083 us
# with flush_cache and torch.cuda._sleep
# perf_counter no sync Time: 4.2853 us
# perf_counter Time: 958.5552 us
# perf_counter_ns Time: 958.4630 us
# cuda.Event Time: 953.4228 us
# cuda.Event list Time: 952.6513 us
# a = torch.randn(1000, 1000, device="cuda") with flush_cache and torch.cuda._sleep
# perf_counter no sync Time: 4.5707 us
# perf_counter Time: 11.7443 us
# perf_counter_ns Time: 11.7657 us
# cuda.Event Time: 13.3076 us
# cuda.Event list Time: 5.8498 us
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment