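# Sketch of an "optimizer in backward": hook each parameter's gradient
# accumulator so the SGD step runs during backward() and every .grad can be
# freed immediately, lowering peak memory. No optimizer.step() is ever called.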
import torch
from torch import nn
from torch.optim.sgd import sgd

import gc
import weakref


def main():
    # Only a subset of the args the full SGD optimizer accepts
    def set_sgd_hook(mod, p, lr, weight_decay, momentum):
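        # Momentum buffer slot; the functional sgd() below initializes it
        # lazily on the first step (it writes the buffer back into this list)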
        buff_list = [None]
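
        # p is a leaf Tensor and has no grad_fn; a trivial view of it has a
        # grad_fn whose next_functions[0][0] is p's AccumulateGrad node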
        acc_grad = p.view_as(p).grad_fn.next_functions[0][0]
        # The grad accumulator is only held via a weak ref, so we need to
        # keep it alive for as long as the Tensor is alive.
        # Store it on the module to avoid an uncollectable ref-cycle.
        if not hasattr(mod, "_acc_grads"):
            mod._acc_grads = []
        mod._acc_grads.append(acc_grad)

        def sgd_hook(*_unused):
            # Update the param
            sgd([p], [p.grad], buff_list, has_sparse_grad=False, foreach=False,
                weight_decay=weight_decay, momentum=momentum, lr=lr, dampening=0,
                nesterov=False, maximize=False)
            # Free up the grad memory
            p.grad = None

        # We should have an API for post hooks... but we don't have one right now
        acc_grad.register_hook(sgd_hook)
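        # (Note: PyTorch 2.1+ does ship one, Tensor.register_post_accumulate_grad_hook,
        # which makes the AccumulateGrad plumbing above unnecessary.)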

    print("Startup", torch.cuda.memory_allocated())

    mod = torch.nn.Linear(4, 1).cuda()
    crit = nn.MSELoss()

    for p in mod.parameters():
        set_sgd_hook(mod, p, lr=0.1, weight_decay=0.0, momentum=0.9)
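
    # From here on there is no optimizer object at all: every parameter
    # updates itself from its hook while backward() runs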

    # Make sure the keepalive works: a full gc pass must not drop the accumulators
    gc.collect()

    inp = torch.rand(10, 4, device="cuda")
    target = torch.rand(10, 1, device="cuda")
    for i in range(11):
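        # One step per closure, presumably so pred/loss go out of scope
        # before the end-of-iteration memory reading below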
        def eval_one():
            print(f"It {i}, {torch.cuda.memory_allocated()}")
            pred = mod(inp)
            loss = crit(pred, target)
            print("Before backward", torch.cuda.memory_allocated())
            loss.backward()
            print(f"Loss: {loss.item()}")

        eval_one()
        if i == 0:
            print("No memory decrease yet: the momentum buffers are initialized lazily on the first step")
        print("End of iteration", torch.cuda.memory_allocated())
    return weakref.ref(mod.weight)


w = main()
print("Done, final memory", torch.cuda.memory_allocated())