
@mingfeima
Last active April 16, 2021 05:04
DLRM Task

Description

This Gist records the effort to optimize the LAMB optimizer on the PyTorch CPU path.

The LAMB optimizer was proposed in the paper "Large Batch Optimization for Deep Learning: Training BERT in 76 minutes".

This implementation is based on fbgemm's GPU code at gpu_ref.
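For reference, the unfused `lamb` function in test_fused_lamb.py computes the update below. The symbols map as θ = param, g = grad, m = exp_avg, v = exp_avg_sq, t = step, η = learning_rate, λ = weight_decay and ε = eps. This transcribes that reference function rather than restating the paper's algorithm. Norms are taken over the whole tensor, and the weight norm uses θ before the update.

```latex
\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1-\beta_1)\, g_t \\
v_t &= \beta_2 v_{t-1} + (1-\beta_2)\, g_t \odot g_t \\
r_t &= \frac{m_t / (1-\beta_1^{t})}{\sqrt{v_t / (1-\beta_2^{t})} + \epsilon} + \lambda\, \theta_{t-1} \\
\theta_t &= \theta_{t-1} - \eta\, \frac{\lVert \theta_{t-1} \rVert_2}{\lVert r_t \rVert_2}\, r_t
\end{aligned}
```

Here r_t is the `adam_step` tensor. The division and square root are elementwise.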

Usage

### optim.Lamb comes with the changes described in this Gist; it is not part of stock torch.optim
### fused=True uses the native fused C++ kernel in ATen
### fused=False falls back to the imperative torch implementation, used for validation
optimizer = optim.Lamb(model.parameters(), lr=0.01, fused=True)
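The `optim.Lamb` wrapper itself is not shown in this Gist. The sketch below shows one way such a wrapper could keep per-parameter state and dispatch on `fused`. `LambSketch` is a hypothetical class built on the stock `torch.optim.Optimizer` base class, and its default hyperparameters are placeholders.

The sketch makes three assumptions:
- `torch.lamb_fused_step` takes the positional argument order used in test_fused_lamb.py.
- That op is called only if it exists. Otherwise the sketch runs the same math as the reference `lamb` function.
- The zero-norm guard is an addition of this sketch and is not in the reference.

```python
import torch
from torch.optim import Optimizer


def _lamb_eager(param, exp_avg, exp_avg_sq, grad, step, beta1, beta2, lr, weight_decay, eps):
    # Same math as the reference `lamb` in test_fused_lamb.py
    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 - beta2 ** step
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    adam_step = (exp_avg / bias_correction1) / ((exp_avg_sq / bias_correction2).sqrt() + eps)
    if weight_decay != 0:
        adam_step.add_(param, alpha=weight_decay)
    weight_norm = param.norm(p=2).item()
    rtw_norm = adam_step.norm(p=2).item()
    # Guard added in this sketch (the reference has none): use ratio 1.0 for zero norms
    true_ratio = weight_norm / rtw_norm if weight_norm > 0 and rtw_norm > 0 else 1.0
    param.add_(adam_step, alpha=-lr * true_ratio)


class LambSketch(Optimizer):
    """Hypothetical LAMB optimizer for illustration; not the Gist's optim.Lamb."""

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6,
                 weight_decay=0.0, fused=False):  # placeholder defaults
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, fused=fused)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            beta1, beta2 = group["betas"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                state = self.state[p]
                if len(state) == 0:  # lazy state init on the first step
                    state["step"] = 0
                    state["exp_avg"] = torch.zeros_like(p)
                    state["exp_avg_sq"] = torch.zeros_like(p)
                state["step"] += 1
                args = (p, state["exp_avg"], state["exp_avg_sq"], p.grad,
                        state["step"], beta1, beta2, group["lr"],
                        group["weight_decay"], group["eps"])
                if group["fused"] and hasattr(torch, "lamb_fused_step"):
                    # Only present in a build with the Gist's kernel; may overwrite p.grad
                    torch.lamb_fused_step(*args)
                else:
                    _lamb_eager(*args)
        return loss
```

With this sketch, the Usage line above would read `optimizer = LambSketch(model.parameters(), lr=0.01, fused=True)`, followed by the usual `loss.backward()` and `optimizer.step()`.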

Testing

Test cases are posted below as test_fused_lamb.py; both contiguous and non-contiguous cases are tested. The weight tensor can be non-contiguous, for example when multiple nn.Linear modules are explicitly fused.
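The non-contiguous test case slices the last dimension with `narrow`. That is the layout a column slice of a larger, fused weight would have. The snippet below is a minimal illustration of such a view, using the shapes of the second test case. The "fused weight" framing is only an example, and nothing here depends on the fused kernel.

```python
import torch

# Stand-in for a larger buffer, e.g. one weight holding several fused nn.Linear layers
full = torch.randn(100, 200)

# Same slicing as the non-contiguous test: 100 columns starting at column 2
view = full.narrow(-1, 2, 100)

print(view.size(), view.stride(), view.is_contiguous())
# torch.Size([100, 100]) (200, 1) False

# The view shares storage with `full`; rows stay 200 elements apart, so a kernel
# must either honour the strides or work on a copy made by .contiguous()
print(view.data_ptr() - full.data_ptr() == 2 * full.element_size())  # True
```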

The MNIST example from pytorch/examples converges as follows:

Test set: Average loss: 0.0297, Accuracy: 9934/10000 (99%)

Performance

I tested on an Intel(R) Xeon(R) Gold 6248 CPU @ 2.50GHz with 20 cores per socket and 2 sockets. In a single-socket run (with jemalloc), the update step for a [1024, 1024] weight tensor is 4.9x faster with the fused kernel:

### LAMB optimizer bench:
unfused: 0.4495 ms; fused: 0.0923 ms
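The quoted 4.9x is the ratio of the two timings. The check below also computes the size of one fp32 [1024, 1024] tensor. That puts memory traffic in scale, since each unfused update streams several such tensors through memory.

```python
# Timings reported above for one update of a [1024, 1024] fp32 weight
unfused_ms = 0.4495
fused_ms = 0.0923

speedup = unfused_ms / fused_ms
tensor_mib = 1024 * 1024 * 4 / 2**20  # bytes per fp32 tensor, in MiB

print(f"speedup: {speedup:.2f}x")            # speedup: 4.87x (quoted as 4.9x)
print(f"one fp32 tensor: {tensor_mib} MiB")  # one fp32 tensor: 4.0 MiB
```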

To reproduce the result (note that run.sh preloads jemalloc):

./run.sh test_fused_lamb.py

[Notes]

  • The speedup comes primarily from two sources. First, there is less memory traffic for intermediate tensors. Second, the kernel makes no additional memory allocations: the temporary adam_step result reuses the memory of grad. This means the kernel overwrites the gradient tensor, which is safe because the gradient is no longer needed once the weight has been updated.
  • The 4.9x speedup was measured with the weight of nn.Linear(1024, 1024). The speedup ratio is expected to be larger for bigger weight tensors.
  • Thread synchronization: the algorithm itself requires synchronization between threads (e.g. for the norms of the weight and of adam_step). Ideally this would be done with #pragma omp barrier, so that the whole computation finishes within a single OpenMP parallel region. However, this triggers a bug. PyTorch's OpenMP wrapper at::parallel does not guarantee that every thread in the team gets work: for N=64, only 16 threads do work even though there are 20 cores. The unused threads never reach the barrier, so the working threads wait at it forever. To avoid this, I split the code into 2 OpenMP parallel regions.
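The synchronization problem in the last note can be modelled in plain Python. Following the note's N=64 / 20-core example, the sketch assumes two things about the OpenMP path of `at::parallel_for`:
- Each thread is handed a chunk of ceil(N / num_threads) iterations.
- A thread whose chunk would start at or past N skips the loop body.

`active_chunks`, `barrier_times_out` and `two_region_update` are hypothetical helpers:
- `barrier_times_out` shows why a barrier placed inside the body never releases.
- `two_region_update` shows the two-region workaround. The join after the first parallel map stands in for the barrier between the norm reduction and the weight update.

Python threads only illustrate the control flow here; they give no speedup under the GIL.

```python
import math
import threading
from concurrent.futures import ThreadPoolExecutor


def active_chunks(n, num_threads):
    """Model of the assumed split: chunk = ceil(n / num_threads);
    a thread whose chunk would start at or past n never runs the loop body."""
    chunk = math.ceil(n / num_threads)
    return [(tid * chunk, min(n, (tid + 1) * chunk))
            for tid in range(num_threads) if tid * chunk < n]


def barrier_times_out(n, num_threads, timeout=0.2):
    """A barrier sized for the whole team, placed inside the loop body."""
    barrier = threading.Barrier(num_threads)
    outcomes = []

    def body():
        try:
            barrier.wait(timeout=timeout)
            outcomes.append("passed")
        except threading.BrokenBarrierError:
            outcomes.append("broken")  # timed out: not every party arrived

    # Only threads that own a chunk execute the body and reach the barrier
    workers = [threading.Thread(target=body) for _ in active_chunks(n, num_threads)]
    for t in workers:
        t.start()
    for t in workers:
        t.join()
    return all(o == "broken" for o in outcomes)


def two_region_update(param, adam_step, lr, num_threads):
    """Hypothetical two-region LAMB tail on plain lists (updates param in place)."""
    chunks = active_chunks(len(param), num_threads)

    def partial_sums(bounds):  # region 1: per-chunk squared norms
        lo, hi = bounds
        return (sum(x * x for x in param[lo:hi]),
                sum(x * x for x in adam_step[lo:hi]))

    with ThreadPoolExecutor(max_workers=num_threads) as pool:
        partials = list(pool.map(partial_sums, chunks))  # join acts as the barrier
        ratio = (math.sqrt(sum(w for w, _ in partials))
                 / math.sqrt(sum(r for _, r in partials)))

        def apply(bounds):  # region 2: scaled weight update
            lo, hi = bounds
            for i in range(lo, hi):
                param[i] -= lr * ratio * adam_step[i]

        list(pool.map(apply, chunks))
    return ratio


print(len(active_chunks(64, 20)))  # 16: only 16 of the 20 threads get a chunk
print(barrier_times_out(64, 20))   # True: 16 arrivals never release a 20-party barrier
```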

mingfeima commented Feb 26, 2021

Benchmark script launcher - run.sh

#!/bin/bash
### run script for the operator benchmark

#source activate pytorch-mingfei
# jemalloc:
export MALLOC_CONF="oversize_threshold:1,background_thread:true,metadata_thp:auto,dirty_decay_ms:-1,muzzy_decay_ms:-1"
export LD_PRELOAD=/home/mingfeim/packages/jemalloc-5.2.1/lib/libjemalloc.so
#
# tcmalloc:
#export LD_PRELOAD=/home/mingfeim/packages/gperftools-2.8/install/lib/libtcmalloc.so

if [ $# -lt 1 ]; then
  echo "usage: ./run.sh [xxx.py]"
  exit 1
fi

INPUT_FILE=$1

# physical cores per socket and number of sockets, as reported by lscpu
CORES=$(lscpu | grep 'Core(s) per socket' | awk '{print $4}')
SOCKETS=$(lscpu | grep 'Socket(s)' | awk '{print $2}')
TOTAL_CORES=$((CORES * SOCKETS))
LAST_CORE=$((CORES - 1))

# Intel OpenMP runtime settings (KMP_* variables)
export KMP_AFFINITY=granularity=fine,compact,1,0
export KMP_BLOCKTIME=1

echo -e "\n### using KMP_AFFINITY=$KMP_AFFINITY"
echo -e "### using KMP_BLOCKTIME=$KMP_BLOCKTIME\n"

### single socket test: bind threads and memory to socket 0
PREFIX="numactl --physcpubind=0-$LAST_CORE --membind=0"
echo -e "\n### using OMP_NUM_THREADS=$CORES"
echo -e "### using $PREFIX\n"
OMP_NUM_THREADS=$CORES $PREFIX python -u "$INPUT_FILE"
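The core count in run.sh comes from two `lscpu` fields, "Core(s) per socket" and "Socket(s)". Below is a stdlib Python sketch of the same extraction, run on a hypothetical `lscpu` excerpt for the dual-socket Gold 6248 machine above. `lscpu_field` and `numactl_prefix` are illustrative helpers.

Matching the full field name matters: a bare `grep Core` would also match a "Model name" line such as "Intel(R) Core(TM) ..." on desktop parts. Like run.sh, the sketch assumes CPU ids 0 to cores-1 are the physical cores of socket 0.

```python
import re

# Hypothetical lscpu excerpt (field names follow lscpu's output format)
SAMPLE_LSCPU = """\
Architecture:        x86_64
CPU(s):              80
Thread(s) per core:  2
Core(s) per socket:  20
Socket(s):           2
Model name:          Intel(R) Xeon(R) Gold 6248 CPU @ 2.50GHz
"""


def lscpu_field(text, name):
    """Return the integer value of an lscpu line such as 'Socket(s): 2'."""
    match = re.search(rf"^{re.escape(name)}:\s*(\d+)", text, re.MULTILINE)
    if match is None:
        raise ValueError(f"field {name!r} not found")
    return int(match.group(1))


def numactl_prefix(text):
    """Same binding as run.sh: the first `cores` CPU ids, memory on node 0."""
    cores = lscpu_field(text, "Core(s) per socket")
    return f"numactl --physcpubind=0-{cores - 1} --membind=0"


cores = lscpu_field(SAMPLE_LSCPU, "Core(s) per socket")
sockets = lscpu_field(SAMPLE_LSCPU, "Socket(s)")
print(cores, sockets, cores * sockets)  # 20 2 40
print(numactl_prefix(SAMPLE_LSCPU))     # numactl --physcpubind=0-19 --membind=0
```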

Test case and benchmark - test_fused_lamb.py

import torch
from time import time

def cmp(t1, t2, msg, debug=False):
    if debug:
        print(t1.size(), 'sum: {:.6f}'.format(t1.sum().item()))
        print(t2.size(), 'sum: {:.6f}'.format(t2.sum().item()))
    res = torch.allclose(t1, t2, atol=1e-6)
    print(msg, res, "; size: ", t2.size(), "; stride: ", t2.stride())

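### torch.lamb_fused_step comes with the changes described in this Gist; it is not in stock PyTorch
fused = torch.lamb_fused_step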

def lamb(param, exp_avg, exp_avg_sq, grad, step, beta1, beta2, learning_rate, weight_decay, eps):
    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 - beta2 ** step

    # Decay the first and second moment running average coefficient
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

    adam_step = (exp_avg / bias_correction1) / ((exp_avg_sq / bias_correction2).sqrt() + eps)

    if weight_decay != 0:
        adam_step.add_(param, alpha=weight_decay)

    weight_norm = param.norm(p=2)
    rtw_norm = adam_step.norm(p=2)
    true_ratio = weight_norm / rtw_norm

    param.add_(adam_step, alpha=-learning_rate * true_ratio)


def test_fused_lamb_cpu(size, contig=True):
    print("\n### test_fused_lamb_cpu: ", ("contiguous" if contig else "non-contiguous"))
    s1, s2 = size[0], size[1]
    param = torch.randn(s1, s2)
    grad = torch.randn(s1, s2)
    exp_avg = torch.randn(s1, s2).abs()
    exp_avg_sq = torch.randn(s1, s2).abs()

    ### the optimizer step is in-place: it rewrites param, exp_avg, exp_avg_sq and (fused kernel only) grad
    param2 = param.clone()
    grad2 = grad.clone()
    exp_avg2 = exp_avg.clone()
    exp_avg_sq2 = exp_avg_sq.clone()

    if not contig:
        param = param.narrow(-1, 2, s2 // 2)
        grad = grad.narrow(-1, 2, s2 // 2)
        exp_avg = exp_avg.narrow(-1, 2, s2 // 2)
        exp_avg_sq = exp_avg_sq.narrow(-1, 2, s2 // 2)
        param2 = param2.narrow(-1, 2, s2 // 2)
        grad2 = grad2.narrow(-1, 2, s2 // 2)
        exp_avg2 = exp_avg2.narrow(-1, 2, s2 // 2)
        exp_avg_sq2 = exp_avg_sq2.narrow(-1, 2, s2 // 2)

    step = 10
    beta1 = 0.8
    beta2 = 0.9
    learning_rate = 0.1
    weight_decay = 0.3
    eps = 0.1
    lamb(param, exp_avg, exp_avg_sq, grad, step, beta1, beta2, learning_rate, weight_decay, eps)
    fused(param2, exp_avg2, exp_avg_sq2, grad2, step, beta1, beta2, learning_rate, weight_decay, eps)

    cmp(param, param2, "param: ")
    # expected to print False: the fused kernel reuses grad as scratch space for adam_step
    cmp(grad, grad2, "grad: ")
    cmp(exp_avg, exp_avg2, "exp_avg: ")
    cmp(exp_avg_sq, exp_avg_sq2, "exp_avg_sq: ")


test_fused_lamb_cpu([45, 63])
test_fused_lamb_cpu([100, 200], False)


niters = 1000
nwarmups = niters // 100

def bench_fused_lamb_cpu(size):
    s1, s2 = size[0], size[1]
    param = torch.randn(s1, s2)
    grad = torch.randn(s1, s2)
    exp_avg = torch.randn(s1, s2).abs()
    exp_avg_sq = torch.randn(s1, s2).abs()

    ### the optimizer step is in-place: it rewrites param, exp_avg, exp_avg_sq and (fused kernel only) grad
    param2 = param.clone()
    grad2 = grad.clone()
    exp_avg2 = exp_avg.clone()
    exp_avg_sq2 = exp_avg_sq.clone()

    step = 10
    beta1 = 0.8
    beta2 = 0.9
    learning_rate = 0.1
    weight_decay = 0.3
    eps = 0.1

    for _ in range(nwarmups):
        lamb(param, exp_avg, exp_avg_sq, grad, step, beta1, beta2, learning_rate, weight_decay, eps)

    t1 = time()
    for _ in range(niters):
        lamb(param, exp_avg, exp_avg_sq, grad, step, beta1, beta2, learning_rate, weight_decay, eps)
    t2 = time()

    for _ in range(nwarmups):
        fused(param2, exp_avg2, exp_avg_sq2, grad2, step, beta1, beta2, learning_rate, weight_decay, eps)

    t3 = time()
    for _ in range(niters):
        fused(param2, exp_avg2, exp_avg_sq2, grad2, step, beta1, beta2, learning_rate, weight_decay, eps)
    t4 = time()

    # ms
    time_per_iter1 = (t2 - t1) * 1000 / niters
    time_per_iter2 = (t4 - t3) * 1000 / niters
    print("\n### LAMB optimizer bench:\nunfused: {:.4f} ms; fused: {:.4f} ms".format(time_per_iter1, time_per_iter2))


bench_fused_lamb_cpu([1024, 1024])
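The benchmark above times a fixed number of iterations with `time()`. An alternative is `torch.utils.benchmark.Timer` (available since PyTorch 1.8); it was not used to produce the numbers in this Gist. The sketch below times only the reference `lamb` update, redefined here so the snippet runs on its own, because `torch.lamb_fused_step` exists only in a build with the Gist's kernel. `Timer` defaults to a single thread, so `num_threads` is passed explicitly.

```python
import torch
import torch.utils.benchmark as benchmark


def lamb(param, exp_avg, exp_avg_sq, grad, step, beta1, beta2, learning_rate, weight_decay, eps):
    # Reference (unfused) update, adapted from test_fused_lamb.py
    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 - beta2 ** step
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    adam_step = (exp_avg / bias_correction1) / ((exp_avg_sq / bias_correction2).sqrt() + eps)
    if weight_decay != 0:
        adam_step.add_(param, alpha=weight_decay)
    true_ratio = (param.norm(p=2) / adam_step.norm(p=2)).item()
    param.add_(adam_step, alpha=-learning_rate * true_ratio)


s1, s2 = 1024, 1024
env = {
    "lamb": lamb,
    "param": torch.randn(s1, s2),
    "grad": torch.randn(s1, s2),
    "exp_avg": torch.randn(s1, s2).abs(),
    "exp_avg_sq": torch.randn(s1, s2).abs(),
}

timer = benchmark.Timer(
    stmt="lamb(param, exp_avg, exp_avg_sq, grad, 10, 0.8, 0.9, 0.1, 0.3, 0.1)",
    globals=env,
    num_threads=torch.get_num_threads(),  # Timer would otherwise use 1 thread
)
m = timer.timeit(100)
print("unfused: {:.4f} ms".format(m.mean * 1000))  # Measurement.mean is in seconds
```

With a build that includes the fused op, a second `Timer` whose `stmt` calls `torch.lamb_fused_step` with the same arguments would give the matching fused number.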

Test case and benchmark - test_fused_adagrad.py

import torch
from time import time

def cmp(t1, t2, msg, debug=False):
    if debug:
        print(t1.size(), 'sum: {:.6f}'.format(t1.sum().item()))
        print(t2.size(), 'sum: {:.6f}'.format(t2.sum().item()))
    res = torch.allclose(t1, t2, atol=1e-6)
    print(msg, res, "; size: ", t2.size(), "; stride: ", t2.stride())

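### torch.adagrad_fused_step comes with the changes described in this Gist; it is not in stock PyTorch
fused = torch.adagrad_fused_step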

def adagrad(param, grad, state_sum, step, learning_rate, weight_decay, lr_decay, eps):
    clr = learning_rate / (1 + (step - 1) * lr_decay)

    if weight_decay != 0:
        grad = grad.add(param, alpha=weight_decay)

    state_sum.addcmul_(grad, grad, value=1)
    std = state_sum.sqrt().add_(eps)
    param.addcdiv_(grad, std, value=-clr)


def test_fused_adagrad_cpu(size, contig=True):
    print("\n### test_fused_adagrad_cpu: ", ("contiguous" if contig else "non-contiguous"))
    s1, s2 = size[0], size[1]
    param = torch.randn(s1, s2)
    grad = torch.randn(s1, s2)
    state_sum = torch.randn(s1, s2).abs()

    ### the optimizer step is in-place: it rewrites param and state_sum
    param2 = param.clone()
    grad2 = grad.clone()
    state_sum2 = state_sum.clone()

    if not contig:
        param = param.narrow(-1, 2, s2 // 2)
        grad = grad.narrow(-1, 2, s2 // 2)
        state_sum = state_sum.narrow(-1, 2, s2 // 2)
        param2 = param2.narrow(-1, 2, s2 // 2)
        grad2 = grad2.narrow(-1, 2, s2 // 2)
        state_sum2 = state_sum2.narrow(-1, 2, s2 // 2)

    step = 10
    learning_rate = 0.1
    weight_decay = 0.3
    lr_decay = 0.01
    eps = 0.1
    adagrad(param, grad, state_sum, step, learning_rate, weight_decay, lr_decay, eps)
    fused(param2, grad2, state_sum2, step, learning_rate, weight_decay, lr_decay, eps)

    cmp(param, param2, "param: ")
    cmp(grad, grad2, "grad: ")
    cmp(state_sum, state_sum2, "state_sums: ")


test_fused_adagrad_cpu([45, 63])
test_fused_adagrad_cpu([100, 200], False)


niters = 1000
nwarmups = niters // 100

def bench_fused_adagrad_cpu(size):
    s1, s2 = size[0], size[1]
    param = torch.randn(s1, s2)
    grad = torch.randn(s1, s2)
    state_sum = torch.randn(s1, s2).abs()

    ### the optimizer step is in-place: it rewrites param and state_sum
    param2 = param.clone()
    grad2 = grad.clone()
    state_sum2 = state_sum.clone()

    step = 10
    learning_rate = 0.1
    weight_decay = 0.3
    lr_decay = 0.01
    eps = 0.1

    for _ in range(nwarmups):
        adagrad(param, grad, state_sum, step, learning_rate, weight_decay, lr_decay, eps)

    t1 = time()
    for _ in range(niters):
        adagrad(param, grad, state_sum, step, learning_rate, weight_decay, lr_decay, eps)
    t2 = time()

    for _ in range(nwarmups):
        fused(param2, grad2, state_sum2, step, learning_rate, weight_decay, lr_decay, eps)

    t3 = time()
    for _ in range(niters):
        fused(param2, grad2, state_sum2, step, learning_rate, weight_decay, lr_decay, eps)
    t4 = time()

    # ms
    time_per_iter1 = (t2 - t1) * 1000 / niters
    time_per_iter2 = (t4 - t3) * 1000 / niters
    print("\n### ADAGRAD optimizer bench:\nunfused: {:.4f} ms; fused: {:.4f} ms".format(time_per_iter1, time_per_iter2))


bench_fused_adagrad_cpu([1024, 1024])
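The reference `adagrad` function uses the same update as stock `torch.optim.Adagrad`: a decayed learning rate, weight decay folded into the gradient, and accumulated squared gradients. Running the two side by side for a few steps is one way to sanity-check the reference itself, independently of the fused op. This sketch assumes a zero initial accumulator, which is the default `initial_accumulator_value` of `torch.optim.Adagrad`. It feeds both paths the same random gradients.

```python
import torch


def adagrad(param, grad, state_sum, step, learning_rate, weight_decay, lr_decay, eps):
    # Reference update, copied from test_fused_adagrad.py
    clr = learning_rate / (1 + (step - 1) * lr_decay)
    if weight_decay != 0:
        grad = grad.add(param, alpha=weight_decay)
    state_sum.addcmul_(grad, grad, value=1)
    std = state_sum.sqrt().add_(eps)
    param.addcdiv_(grad, std, value=-clr)


torch.manual_seed(0)
lr, weight_decay, lr_decay, eps = 0.1, 0.3, 0.01, 0.1

w_ref = torch.randn(45, 63)
state_sum = torch.zeros_like(w_ref)  # matches initial_accumulator_value=0
w_opt = torch.nn.Parameter(w_ref.clone())
opt = torch.optim.Adagrad([w_opt], lr=lr, lr_decay=lr_decay,
                          weight_decay=weight_decay, eps=eps)

for step in range(1, 4):  # stock Adagrad also counts steps from 1
    grad = torch.randn(45, 63)  # same gradient fed to both paths
    adagrad(w_ref, grad, state_sum, step, lr, weight_decay, lr_decay, eps)
    w_opt.grad = grad.clone()
    opt.step()

# True when the reference and torch.optim.Adagrad agree
print(torch.allclose(w_ref, w_opt.detach(), atol=1e-6))
```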
