@nyngwang
Forked from albanD/opt_as_hook.py
Created April 18, 2023 10:21
Revisions

  1. @albanD revised this gist Aug 8, 2022. 1 changed file with 1 addition and 1 deletion.
    opt_as_hook.py (2 changes: 1 addition & 1 deletion)
    @@ -37,7 +37,7 @@ def sgd_hook(*_unused):
         crit = nn.MSELoss()

         for p in mod.parameters():
    -        set_sgd_hook(mod, p, lr=.1, weight_decay=0., momentum=0.9)
    +        set_sgd_hook(mod, p, lr=.01, weight_decay=0., momentum=0.9)

         # Make sure the keepalive works well
         gc.collect()
  2. @albanD created this gist Aug 8, 2022.
    opt_as_hook.py (64 changes: 64 additions & 0 deletions)
    @@ -0,0 +1,64 @@
    import torch
    from torch import nn
    from torch.optim.sgd import sgd
    import gc
    import objgraph
    import weakref

    def all():
        # Only a subset of the args you could have
        def set_sgd_hook(mod, p, lr, weight_decay, momentum):
            buff_list = [None]

            acc_grad = p.view_as(p).grad_fn.next_functions[0][0]

            # The grad accumulator is a weak ref, so we need to keep it
            # alive until the Tensor is alive.
            # Store it on the module to avoid uncollectable ref-cycle
            if not hasattr(mod, "_acc_grads"):
                mod._acc_grads = []
            mod._acc_grads.append(acc_grad)

            def sgd_hook(*_unused):
                # Update the params
                sgd([p], [p.grad], buff_list, has_sparse_grad=False, foreach=False,
                    weight_decay=weight_decay, momentum=momentum, lr=lr, dampening=0,
                    nesterov=False, maximize=False)
                # Free up grad memory
                p.grad = None

            # We should have an API for post hooks... But we don't have one right now
            acc_grad.register_hook(sgd_hook)


        print("Startup", torch.cuda.memory_allocated())

        mod = torch.nn.Linear(4, 1).cuda()
        crit = nn.MSELoss()

        for p in mod.parameters():
            set_sgd_hook(mod, p, lr=.1, weight_decay=0., momentum=0.9)

        # Make sure the keepalive works well
        gc.collect()

        inp = torch.rand(10, 4, device="cuda")
        target = torch.rand(10, 1, device="cuda")
        for i in range(11):
            def eval_one():
                print(f"It {i}, {torch.cuda.memory_allocated()}")
                pred = mod(inp)
                loss = crit(pred, target)
                print("Before backward", torch.cuda.memory_allocated())
                loss.backward()
                print(f"Loss: {loss.item()}")

            eval_one()
            if i == 0:
                print("No memory decrease due to optimizer state lazy initialization")
            print("End of iteration", torch.cuda.memory_allocated())

        return weakref.ref(mod.weight)

    w = all()
    print("Done, final memory", torch.cuda.memory_allocated())