"""
Advanced FP4 Performance Test for NVIDIA B200
This script attempts to access FP4 formats if available
"""
import time
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"CUDA version: {torch.version.cuda}")
    print(f"Compute capability: {torch.cuda.get_device_capability(0)}")
print("-" * 60)
# Check for Transformer Engine
try:
    import transformer_engine.pytorch as te
    import transformer_engine
    from transformer_engine.common.recipe import Format, DelayedScaling

    print(f"Transformer Engine version: {transformer_engine.__version__}")
    print("Transformer Engine imported successfully")
    # List all available formats
    print("\nAvailable FP formats in Transformer Engine:")
    for attr in dir(Format):
        if not attr.startswith("_"):
            print(f" - Format.{attr}")
    HAS_TE = True
except ImportError as e:
    print(f"ERROR: Transformer Engine not found: {e}")
    print("Please install with:")
    print(" pip install transformer-engine")
    HAS_TE = False
if not HAS_TE:
    exit(1)
print("-" * 60)
# Check all available torch float types
print("\nAvailable PyTorch float types:")
for attr in dir(torch):
    if "float" in attr.lower() and not attr.startswith("_"):
        try:
            dtype = getattr(torch, attr)
            if isinstance(dtype, torch.dtype):
                print(f" - torch.{attr}")
        except Exception:
            pass
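# Optional probe (added sketch): explicitly call out the narrow formats this build exposes.
# torch.float8_e4m3fn / torch.float8_e5m2 exist in recent PyTorch builds (>= 2.1); whether
# any float4-style dtype appears is entirely version dependent, so absence is not an error.
for name in ("float8_e4m3fn", "float8_e5m2"):
    print(f" - torch.{name} available: {hasattr(torch, name)}")
fp4_dtypes = [attr for attr in dir(torch) if "float4" in attr.lower()]
print(f" - float4-style dtypes exposed by this build: {fp4_dtypes or 'none'}")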
print("-" * 60)
# Configuration
d = 8192
num_iterations = 50
num_runs = 10
warmup_runs = 10
print(f"\nBenchmark Configuration:")
print(f" Matrix size: {d}x{d}")
print(f" Iterations per run: {num_iterations}")
print(f" Number of runs: {num_runs}")
print(f" Warmup runs: {warmup_runs}")
print("-" * 60)
# Prepare data
x = torch.randn(size=(d, d), dtype=torch.float32).cuda()
y = torch.randn(size=(d, d), dtype=torch.float32).cuda()
def benchmark_format(format_name, fp8_format):
    """Generic benchmark function for any FP format"""
    try:
        fp8_recipe = DelayedScaling(
            fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max"
        )

        def fun(x, y):
            result = x
            with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
                for _ in range(num_iterations):
                    result = result @ y.T
            return result

        # Warmup
        print(f" Warming up {format_name}...")
        for _ in range(warmup_runs):
            fun(x, y)
        torch.cuda.synchronize()
        # Benchmark
        print(f" Running benchmark for {format_name}...")
        tic = time.time()
        for _ in range(num_runs):
            fun(x, y)
        torch.cuda.synchronize()
        toc = time.time()
        return toc - tic
    except Exception as e:
        print(f" ERROR: {format_name} failed: {e}")
        return None
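# Added sketch, not part of the original benchmark: te.fp8_autocast only lowers GEMMs inside
# Transformer Engine modules (te.Linear, te.LayerNormLinear, ...) to FP8, so the plain
# `result @ y.T` above runs as a regular torch matmul regardless of the recipe. The helper
# below routes the chained matmuls through a te.Linear layer instead, so the FP8 recipe
# actually applies; it is optional and can be swapped in for benchmark_format() in the loop
# further down if a true FP8 GEMM measurement is wanted.
def benchmark_te_linear(format_name, fp8_format):
    """Benchmark chained te.Linear forwards under the given FP8 recipe."""
    try:
        fp8_recipe = DelayedScaling(
            fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max"
        )
        linear = te.Linear(d, d, bias=False).cuda()  # d x d weight, FP8-eligible shape

        def fun_te(inp):
            result = inp
            # Forward-only measurement (no autograd graph), matching the original benchmark.
            with torch.no_grad(), te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
                for _ in range(num_iterations):
                    result = linear(result)
            return result

        print(f" Warming up {format_name} (te.Linear)...")
        for _ in range(warmup_runs):
            fun_te(x)
        torch.cuda.synchronize()
        print(f" Running benchmark for {format_name} (te.Linear)...")
        tic = time.time()
        for _ in range(num_runs):
            fun_te(x)
        torch.cuda.synchronize()
        toc = time.time()
        return toc - tic
    except Exception as e:
        print(f" ERROR: {format_name} (te.Linear) failed: {e}")
        return None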
def print_results(name, elapsed_time, reference_time=None):
    """Print benchmark results with optional speedup comparison"""
    if elapsed_time is None:
        print(f"\n{name}: Not available")
        return
    msec = 1e3 * elapsed_time
    # Each d x d @ d x d matmul costs 2 * d**3 FLOPs; divide by 1e12 to report TFLOPS.
    tf = (d**3) * 2 * num_iterations * num_runs / 1e12
    tflops = tf / elapsed_time
    print(f"\n{name}:")
    print(f" Time: {msec:.3f} ms")
    print(f" Performance: {tflops:.3f} TFLOPS")
    if reference_time:
        speedup = reference_time / elapsed_time
        print(f" Speedup vs reference: {speedup:.2f}x")
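# Worked example of the FLOP count above: one d x d @ d x d matmul costs 2 * d**3 FLOPs,
# so with d = 8192, num_iterations = 50 and num_runs = 10 a single benchmark performs
# 2 * 8192**3 * 500 ≈ 5.5e14 FLOPs; an elapsed time of 1.0 s therefore corresponds to
# roughly 550 TFLOPS.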
# Run comprehensive benchmarks
print("\n" + "=" * 60)
print("COMPREHENSIVE FP BENCHMARK")
print("=" * 60)
results = {}
# Test all available formats
test_formats = [
    ("E4M3", Format.E4M3),
    ("HYBRID", Format.HYBRID),
]
# Try to add E5M2 if available
try:
    test_formats.append(("E5M2", Format.E5M2))
except AttributeError:
    pass
# Test each format
for i, (name, fmt) in enumerate(test_formats):
    print(f"\n[{i+1}/{len(test_formats)}] Testing {name}...")
    elapsed = benchmark_format(name, fmt)
    results[name] = elapsed
    print_results(f"FP8 {name}", elapsed)
# Test baseline BF16
print(f"\n[{len(test_formats)+1}/{len(test_formats)+1}] Testing BF16 (baseline)...")
x_bf16 = x.to(torch.bfloat16)
y_bf16 = y.to(torch.bfloat16)
def fun_bf16(x):
    result = x
    for _ in range(num_iterations):
        result = result @ y_bf16.T
    return result
# Warmup
print(" Warming up BF16...")
for _ in range(warmup_runs):
    fun_bf16(x_bf16)
torch.cuda.synchronize()
# Benchmark
print(" Running benchmark for BF16...")
tic = time.time()
for _ in range(num_runs):
    fun_bf16(x_bf16)
torch.cuda.synchronize()
toc = time.time()
bf16_time = toc - tic
results["BF16"] = bf16_time
print_results("BF16 (Baseline)", bf16_time)
# Summary comparison
print("\n" + "=" * 60)
print("PERFORMANCE SUMMARY (vs BF16 baseline)")
print("=" * 60)
if bf16_time:
    sorted_results = sorted(
        [(name, elapsed) for name, elapsed in results.items() if elapsed is not None],
        key=lambda item: item[1],
    )
    print(f"\n{'Format':<20} {'Time (ms)':<15} {'TFLOPS':<15} {'Speedup':>10}")
    print("-" * 60)
    # Same FLOP count as in print_results; divide by 1e12 to report TFLOPS.
    tf = (d**3) * 2 * num_iterations * num_runs / 1e12
    for name, elapsed in sorted_results:
        msec = 1e3 * elapsed
        tflops = tf / elapsed
        speedup = bf16_time / elapsed if name != "BF16" else 1.0
        print(f"{name:<20} {msec:<15.3f} {tflops:<15.3f} {speedup:>9.2f}x")
# Additional information
print("\n" + "=" * 60)
print("NOTES FOR B200 (Blackwell Architecture)")
print("=" * 60)
print(
    """
1. FP8 Support: Fully supported (E4M3, E5M2, Hybrid)
2. FP4 Support: Check latest Transformer Engine documentation
   - FP4 may be available in newer TE versions
   - May require specific flags or environment variables
3. To enable potential FP4 features:
   - Update Transformer Engine: pip install --upgrade transformer-engine
   - Update PyTorch: pip install --upgrade torch
   - Check NVIDIA NGC containers for latest optimizations
4. Performance Tips:
   - Ensure CUDA 12.8+ for full B200 (Blackwell) support
   - Use tensor cores via proper alignment
   - Consider using cuBLASLt for better FP8/FP4 performance
5. If FP4 is not showing in available formats:
   - It may not be exposed through standard APIs yet
     (see the optional recipe probe at the end of this script)
   - Check NVIDIA's TensorRT-LLM for FP4 support
   - Monitor Transformer Engine updates
"""
)
print("=" * 60)
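# Optional probe (added sketch): newer Transformer Engine releases may expose additional
# recipe classes for narrower formats on Blackwell (block-scaled FP8/FP4 variants). The
# exact class names are version dependent, so this only reports what the installed build
# actually ships instead of assuming any particular API.
import transformer_engine.common.recipe as te_recipe

print("\nRecipe classes exposed by this Transformer Engine build:")
for attr in dir(te_recipe):
    if attr.startswith("_"):
        continue
    obj = getattr(te_recipe, attr)
    if isinstance(obj, type):
        print(f" - transformer_engine.common.recipe.{attr}")
print("=" * 60)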