This Gist records the optimization effort on the LAMB optimizer for the PyTorch CPU path.

The LAMB optimizer was proposed in the paper Large Batch Optimization for Deep Learning: Training BERT in 76 minutes.

This implementation refers to fbgemm's GPU code at gpu_ref.
```python
### fused=True will use the native C++ fused kernel from ATen
### fused=False will fall back to the imperative torch impl, used for validation purposes
optimizer = optim.Lamb(model.parameters(), lr=0.01, fused=True)
```

The test case is posted below as test_fused_lamb.py; both contiguous and non-contiguous cases are tested. The weight tensor can be non-contiguous when multiple nn.Linear modules are explicitly fused.
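As a quick illustration (hypothetical tensors, not taken from the test script): a weight carved out of a larger fused buffer along the last dimension is a strided view and therefore non-contiguous, which is the layout the fused kernel has to handle:

```python
import torch

# A plain nn.Linear(1024, 1024) weight is a contiguous [1024, 1024] tensor.
w_contig = torch.randn(1024, 1024)
print(w_contig.is_contiguous())      # True

# Hypothetical fused buffer holding several logical weights side by side;
# slicing one logical weight out along the last dimension leaves a stride
# gap, so the resulting view is non-contiguous.
fused_buffer = torch.randn(1024, 3 * 1024)
w_noncontig = fused_buffer[:, :1024]
print(w_noncontig.is_contiguous())   # False
```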
The mnist example from pytorch/examples converges as:

```
Test set: Average loss: 0.0297, Accuracy: 9934/10000 (99%)
```

I tested on Intel(R) Xeon(R) Gold 6248 CPU @ 2.50GHz, 20 cores per socket, dual sockets. For a single-socket run (with jemalloc), the update step of a [1024, 1024] weight tensor achieves a 4.9x speedup:
```
### LAMB optimizer bench:
unfused: 0.4495 ms; fused: 0.0923 ms
```

To reproduce the result (note that jemalloc is applied):

```
./run.sh test_fused_lamb.py
```
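For reference, here is a minimal sketch of what such a micro-benchmark can look like (the reported numbers come from test_fused_lamb.py; this sketch assumes the patched build where `optim.Lamb` exposes the `fused` flag as shown above):

```python
import time
import torch
from torch import optim  # assumes the patched build providing optim.Lamb(..., fused=...)

def bench_lamb(fused, iters=1000):
    # Single [1024, 1024] weight with a pre-filled gradient, updated repeatedly.
    w = torch.nn.Parameter(torch.randn(1024, 1024))
    w.grad = torch.randn(1024, 1024)
    opt = optim.Lamb([w], lr=0.01, fused=fused)
    for _ in range(10):                      # warm-up
        opt.step()
    t0 = time.time()
    for _ in range(iters):
        opt.step()
    return (time.time() - t0) / iters * 1e3  # ms per update step

print(f"unfused: {bench_lamb(False):.4f} ms; fused: {bench_lamb(True):.4f} ms")
```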
[Notes]
- The perf speedup primarily comes from: a) reduced memory bandwidth for intermediate tensors; b) the kernel does no additional memory allocation. For the temporary result `adam_step`, it reuses the memory of `grad`, so the kernel overwrites the gradient tensor; this is safe because the gradient is no longer needed after the weight update.
- The 4.9x speedup is measured on a weight of size nn.Linear(1024, 1024); the speedup ratio would be larger for bigger weight tensors.
- Thread synchronization: the algorithm itself requires thread sync (e.g. for the norms of the weight and of `adam_step`; see the reference sketch after these notes). Ideally we could do this with `#pragma omp barrier` and finish the whole computation within a single omp session. But that would trigger a bug: the PyTorch omp wrapper `at::parallel_for` does not guarantee that all omp threads in the same TEAM are used (N=64 launches only 16 threads even though there are 20 cores), so the unused threads would never reach the barrier and keep waiting. So I break the code into 2 omp sessions.
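For completeness, here is a sketch of the imperative (fused=False) update for a single parameter, written against the LAMB paper; variable and argument names are assumptions rather than the actual implementation, but it shows both the `adam_step` intermediate whose memory the fused kernel reuses and the two norm reductions that require the synchronization point:

```python
import torch

@torch.no_grad()
def lamb_step_reference(p, grad, exp_avg, exp_avg_sq, step,
                        lr=0.01, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.0):
    beta1, beta2 = betas
    # Adam-style moment updates.
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    bias_corr1 = 1 - beta1 ** step
    bias_corr2 = 1 - beta2 ** step
    # "adam_step" is the intermediate that the fused kernel writes into grad's storage.
    adam_step = (exp_avg / bias_corr1) / ((exp_avg_sq / bias_corr2).sqrt() + eps)
    if weight_decay != 0.0:
        adam_step = adam_step + weight_decay * p
    # Two global reductions: these are the sync points that force the fused
    # kernel to be split into 2 omp sessions.
    w_norm = p.norm()
    step_norm = adam_step.norm()
    trust_ratio = w_norm / step_norm if w_norm > 0 and step_norm > 0 else 1.0
    p.add_(adam_step, alpha=-lr * float(trust_ratio))
```

Roughly, the first omp session can compute `adam_step` (reusing `grad`'s memory) and accumulate the two norms, and the second session applies the trust-ratio-scaled update once the reduction is done.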
Benchmark script launcher - run.sh
Testing case and benchmark - test_fused_lamb.py
Testing case and benchmark - test_fused_adagrad.py