_This Gist records optimization effort of **DLRM** on PyTorch CPU path._

Branch on track: [dlrm](https://github.com/mingfeima/pytorch/commits/dlrm)

Task list:
- [x] LAMB fused optimizer (fp32)
- [x] Adagrad fused optimizer (fp32)
- [x] Split-SGD (bf16)
- [x] Bucketize (bf16)
- [x] Sum (bf16)
- [x] LayerNorm (bf16)
- [x] Softmax (bf16)
- [x] cumsum (int64_t)
- [ ] transposed copy (fp32/bf16)
- [x] offset range (int64_t)
- [x] sigmoid/sigmoid_backward (bf16)

## LAMB optimizer

**LAMB optimizer** - proposed in the paper [Large Batch Optimization for Deep Learning: Training BERT in 76 minutes](https://arxiv.org/pdf/1904.00962.pdf). This implementation refers to fbgemm's GPU code at [gpu_ref](https://github.com/pytorch/FBGEMM/blob/29fa98eb8ad521169366ead870a8b16f6b907b70/fbgemm_gpu/codegen/embedding_backward_code_generator.py#L373).

To use this CPU fused LAMB kernel, cherry-pick [cf5e826b](https://github.com/mingfeima/pytorch/commit/cf5e826b185c83bc88fb57f8f26f29fac927379b) and build from source.

#### Usage
```python
### fused=True will use the native C++ fused kernel from ATen
### fused=False will fall back to the imperative torch impl, used for validation purposes
optimizer = optim.Lamb(model.parameters(), lr=0.01, fused=True)
```

#### Testing
The test case is posted below as `test_fused_lamb.py`; both contiguous and non-contiguous cases are tested. The weight tensor can be non-contiguous when multiple `nn.Linear` modules are explicitly fused.

The mnist example from pytorch/examples converges as:
```bash
Test set: Average loss: 0.0297, Accuracy: 9934/10000 (99%)
```

#### Performance
I tested on Intel(R) Xeon(R) Gold 6248 CPU @ 2.50GHz, 20 cores per socket, dual sockets.

For a **single socket** run (with jemalloc), the update step of a [1024, 1024] weight tensor achieves a **4.9x** speedup:
```bash
### LAMB optimizer bench: unfused: 0.4495 ms; fused: 0.0923 ms
```

To reproduce the result (note that jemalloc is applied):
```bash
./run.sh test_fused_lamb.py
```

[Notes]
- The perf speedup primarily comes from: a) reduced memory bandwidth for intermediate tensors; b) no additional memory allocation inside the kernel. The temporary result of `adam_step` reuses the memory of `grad`, so the kernel overwrites the gradient tensor; this is safe since the gradient is no longer needed after the weight update.
- The 4.9x speedup is measured on the weight of `nn.Linear(1024, 1024)`; the speedup ratio would be greater for larger weight tensors.
- Thread synchronization: the algorithm itself requires thread sync (e.g. for the norms of the weight and of `adam_step`). Ideally we could do this with `#pragma omp barrier` and finish the whole computation within a single omp session. However, this triggers a bug: the PyTorch omp wrapper `at::parallel_for` does not guarantee that all omp threads in the same team are used (e.g. N=64 launches only 16 threads even though there are 20 cores), so the unused threads never reach the barrier and keep waiting. I therefore break the code into 2 omp sessions.

## Adagrad Fusion

#### Usage
```python
### fused=True will use the native C++ fused kernel from ATen
### fused=False will fall back to the imperative torch impl, used for validation purposes
optimizer = optim.Adagrad(model.parameters(), lr=0.01, fused=True)
```

#### Testing
The mnist example from pytorch/examples converges as:
```bash
Test set: Average loss: 0.0363, Accuracy: 9881/10000 (99%)
```
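Beyond the mnist run, the fused path can also be cross-checked directly against the imperative fallback, since `fused=False` keeps the reference implementation around. Below is a minimal sketch of such a check, assuming the patched build where the optimizer accepts `fused=`; the toy training loop and the tolerances are illustrative choices of mine, not the actual unit test:

```python
import torch
import torch.nn as nn
import torch.optim as optim

def train(fused):
    torch.manual_seed(0)                      # identical init and data for both runs
    model = nn.Linear(1024, 1024)
    # fused= is only available on the patched dlrm branch
    opt = optim.Adagrad(model.parameters(), lr=0.01, fused=fused)
    for _ in range(5):
        opt.zero_grad()
        model(torch.randn(32, 1024)).sum().backward()
        opt.step()
    return [p.detach().clone() for p in model.parameters()]

ref = train(fused=False)   # imperative reference path
fus = train(fused=True)    # native C++ fused kernel
for p_ref, p_fus in zip(ref, fus):
    # fp32 fused kernel should track the reference closely (tolerances are illustrative)
    torch.testing.assert_close(p_fus, p_ref, rtol=1e-6, atol=1e-6)
```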
#### Performance
I tested on Intel(R) Xeon(R) Gold 6248 CPU @ 2.50GHz, 20 cores per socket, dual sockets.

For a **single socket** run (with jemalloc), the update step of a [1024, 1024] weight tensor achieves a **3.2x** speedup:
```bash
### ADAGRAD optimizer bench: unfused: 0.1022 ms; fused: 0.0321 ms
```

To reproduce the result (note that jemalloc is applied):
```bash
./run.sh test_fused_adagrad.py
```

## Split SGD (BFloat16)

The basic idea of the algorithm is to keep a master copy of the weight in fp32, split into its upper 16 bits and lower 16 bits: the upper half is the bf16 weight itself, and the lower half is stored in the optimizer as a state. The weight can therefore be updated in fp32 by packing and unpacking the two halves.

![split_sgd_bf16](https://user-images.githubusercontent.com/20233731/114827043-cbb5b480-9dfa-11eb-8180-91eb8b5897cd.png)

#### Usage
The usage is identical to the normal fp32 fused kernel: with `fused=True`, a parameter with data type `torch.bfloat16` automatically uses the split SGD algorithm:
```python
### fused=True will use the native C++ fused kernel from ATen
### fused=False will fall back to the imperative torch impl, used for validation purposes
optimizer = optim.Lamb(model.parameters(), lr=0.01, fused=True)
```

#### Performance
```bash
### LAMB unfused (fp32): 0.4526 ms; fused (fp32): 0.0940 ms; split fused (bf16): 0.0879 ms
```

#### Testing
```bash
python test_optim.py TestSplitSGD.test_lamb_bfloat16_cpu
python test_optim.py TestSplitSGD.test_adagrad_bfloat16_cpu
```

[Notes]: Known issue: this impl is expected to hit a runtime error on AVX-only machines, since I did not register the AVX kernels; make sure you have an AVX2+ CPU.

## Generic BF16 Operator Optimization

#### Principle
BFloat16 is a storage type rather than a real compute type here, so a BFloat16 operator needs to be handled in the following manner:
- input/output: load: bf16->fp32; store: fp32->bf16
- intermediate operations (including accumulation): use fp32

#### Implementation Details
We have multiple ways to enable a BFloat16 OP on PyTorch, namely:
1. **Naive Impl**: add `kBFloat16` to the `AT_DISPATCH_FLOATING_TYPES` macro. Since both the scalar and `Vec256<>` logic in PyTorch have specializations for `BFloat16`, this runs smoothly, but as shown below it is suboptimal in both performance and rounding error.
2. **Functional Specialization**: specialize `vec256::map<>` from functional.cpp for `BFloat16`, similar to the oneDNN implementation.
3. **Cache FP32 Data**: convert the bf16 data to fp32 per input row and cache it (possibly in L1), similar to the CUDA counterpart implementation.

Consider the following example:
```C++
using Vec = Vec256<BFloat16>;
Vec one = Vec(BFloat16(1));
vec256::map([one](Vec x) { return one / (one + x.exp()); }, y_ptr, x_ptr, N);
```
Impl-1 ends up with 3 pairs of dtype conversions, one each for `.exp()`, `+` and `/`. Both Impl-2 and Impl-3 only need a dtype conversion for the input and the output. Benefits:
1. better performance, since there are fewer dtype conversions;
2. less rounding error, since intermediate results are kept in fp32;
3. accumulation is done in fp32.

Comparing Impl-2 and Impl-3: with emulated dtype conversion, Impl-3 is faster in most cases; with native conversion assembly, Impl-2 is faster. So I follow Impl-2 in these patches.
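As a rough Python-level analogy (not the actual kernel code): arithmetic done directly on bf16 tensors rounds every intermediate result to bf16, which mirrors Impl-1's conversion pair per op, while converting to fp32 once and rounding only at the final store mirrors Impl-2/Impl-3. The sizes below are arbitrary and the exact error values will vary:

```python
import torch

torch.manual_seed(0)
x = torch.randn(128, 1024)
ref = 1.0 / (1.0 + x.exp())          # fp32 reference of the sigmoid-like map above

xb = x.bfloat16()

# Impl-1 analogy: each op (exp, add, div) stores a bf16 result,
# i.e. one rounding step per operation.
per_op_round = 1.0 / (1.0 + xb.exp())

# Impl-2/3 analogy: convert bf16 -> fp32 once, keep all intermediates
# in fp32, round to bf16 only at the final store.
single_round = (1.0 / (1.0 + xb.float().exp())).bfloat16()

print("per-op rounding max abs diff:", (per_op_round.float() - ref).abs().max().item())
print("single rounding max abs diff:", (single_round.float() - ref).abs().max().item())
```

The per-op-rounded variant typically shows a larger maximum deviation from the fp32 reference, which is the rounding-error benefit listed above.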
#### Softmax

Naive Impl:
```bash
Softmax: 128x1024: fp32: 150.324 us; bf16: 356.587 us
tensor max (abs) diff: 2.9515125788748264e-05
```

Functional Specialization:
```bash
log_softmax: 128x1024: fp32: 150.132 us; bf16: 194.974 us
tensor max (abs) diff: 1.509662251919508e-05
```

Test:
```
cd pytorch/build/bin/
./vec256_test_all_types_AVX
./vec256_test_all_types_AVX2
./vec256_test_all_types_DEFAULT
```
```bash
python test_nn.py TestNN.test_log_softmax_cpu
python test_nn.py TestNN.test_softmax_cpu
```

#### Sum

Naive Impl:
```bash
sum size: 128x30678, fp32: 0.588 ms; bf16: 0.899 ms
```

Functional Specialization:
```bash
sum size: 128x30678, fp32: 0.590 ms; bf16: 0.335 ms
```

#### LayerNorm

Naive Impl:
```bash
LayerNorm((1024,), eps=1e-05, elementwise_affine=True): 32x128x1024: fp32: 2.806 ms; bf16: 9.901 ms
tensor max (abs) diff: 0.1355377435684204
```

Functional Specialization:
```bash
LayerNorm((1024,), eps=1e-05, elementwise_affine=True): 32x128x1024: fp32: 2.813 ms; bf16: 2.306 ms
tensor max (abs) diff: 0.04277598857879639
```

Test:
```bash
python test_nn.py TestNNDeviceTypeCPU.test_LayerNorm_general_cpu
```
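The timings above come from simple wall-clock micro-benchmarks; a minimal sketch of that style of measurement is shown below. The iteration counts, the `copy.deepcopy` trick, and the formatting are my own assumptions rather than the exact harness used for the numbers above, and the bf16 path requires a build with the bf16 LayerNorm kernel:

```python
import copy
import time
import torch
import torch.nn as nn

def bench_ms(fn, x, iters=1000, warmup=100):
    """Average wall-clock time per call, in milliseconds."""
    with torch.no_grad():
        for _ in range(warmup):
            fn(x)
        t0 = time.perf_counter()
        for _ in range(iters):
            fn(x)
        t1 = time.perf_counter()
    return (t1 - t0) / iters * 1e3

m_fp32 = nn.LayerNorm(1024)
m_bf16 = copy.deepcopy(m_fp32).bfloat16()   # same affine params, bf16 dtype
x = torch.randn(32, 128, 1024)

t32 = bench_ms(m_fp32, x)
t16 = bench_ms(m_bf16, x.bfloat16())
diff = (m_bf16(x.bfloat16()).float() - m_fp32(x)).abs().max().item()
print(f"LayerNorm 32x128x1024: fp32: {t32:.3f} ms; bf16: {t16:.3f} ms; max abs diff: {diff:.6f}")
```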