Skip to content

Instantly share code, notes, and snippets.

@enhuiz
Last active November 16, 2020 12:10
Show Gist options
  • Select an option

  • Save enhuiz/2c9773b76488e2eb527acd92e9ba947c to your computer and use it in GitHub Desktop.

Select an option

Save enhuiz/2c9773b76488e2eb527acd92e9ba947c to your computer and use it in GitHub Desktop.
sgs.py
import torch
import torch.nn as nn
from torchvision.models import resnet18
def sample_mask(xl, p):
    """Draw a boolean mask selecting a random ``p`` fraction of each segment.

    For every length ``l`` in ``xl``, exactly ``int(l * p)`` positions are set
    to ``True``. The positions are chosen uniformly at random via
    ``torch.randperm``. The per-segment masks are concatenated in order.

    Args:
        xl: Iterable of non-negative segment lengths.
        p: Fraction in ``[0, 1]`` of each segment to mark ``True``.

    Returns:
        1-D ``torch.bool`` tensor of length ``sum(xl)``. It is empty when
        ``xl`` is empty.

    Raises:
        ValueError: If ``p`` is outside ``[0, 1]``.
    """
    if not 0.0 <= p <= 1.0:
        # A negative p makes int(l * p) negative. The slice idx[:k] with k < 0
        # would then silently select all but the last |k| indices instead of none.
        raise ValueError(f"p must be in [0, 1], got {p}")
    masks = []
    for length in xl:
        mask = torch.zeros(length, dtype=torch.bool)
        mask[torch.randperm(length)[: int(length * p)]] = True
        masks.append(mask)
    if not masks:
        # torch.cat raises on an empty list; an empty mask is the natural result.
        return torch.zeros(0, dtype=torch.bool)
    return torch.cat(masks, dim=0)
def create_sgs_applier(p_detach, lengths):
    """Build a closure that runs a module with stochastic gradient skipping.

    A boolean mask over ``sum(lengths)`` samples is drawn once with
    ``sample_mask(lengths, p_detach)``. The marked ("detached") samples are
    forwarded under ``torch.no_grad()``. The remaining ("attached") samples
    are forwarded normally, so gradients flow through them. The same mask is
    reused on every call of the returned function.

    Args:
        p_detach: Fraction of each segment to forward without gradients.
        lengths: Iterable of segment lengths. Their sum is expected to equal
            the first-dimension size of every tensor later passed to
            ``sgs_apply``.

    Returns:
        ``sgs_apply(module, *data)``. It indexes each tensor in ``data``
        along dim 0 with the two masks and calls ``module`` on each non-empty
        subset. It then scatters the results back into one tensor of length
        ``len(data[0])``, in the original sample order.

    NOTE: when both subsets are non-empty, ``module`` is invoked twice per
    call on smaller batches. Layers that update state in train mode (e.g.
    BatchNorm running statistics) update on each invocation, including the
    no-grad one.
    """
    detached = sample_mask(lengths, p_detach)
    attached = ~detached
    # Convert the 0-dim results to Python bools once, instead of re-evaluating
    # tensor truthiness on every call of ``sgs_apply``.
    attaching = bool(attached.any())
    detaching = bool(detached.any())

    def sgs_apply(module, *data):
        if not (attaching or detaching):
            # Zero-length mask: there is no subset to run, so neither output
            # below would be bound. Forward the (empty) batch directly instead.
            return module(*data)

        n = len(data[0])
        device = data[0].device
        # Put the masks on the inputs' device so that indexing and the scatter
        # below do not mix devices.
        attached_mask = attached.to(device)
        detached_mask = detached.to(device)

        if attaching:
            attached_output = module(*[d[attached_mask] for d in data])
        if detaching:
            # Forward-only pass: these samples contribute no gradient.
            with torch.no_grad():
                detached_output = module(*[d[detached_mask] for d in data])

        reference = attached_output if attaching else detached_output
        # Allocate directly on the target device (no CPU buffer + copy).
        slot = torch.empty(
            (n, *reference.shape[1:]), dtype=reference.dtype, device=device
        )
        if attaching:
            # Autograd records this in-place scatter, so gradients flow back
            # to the attached samples through ``slot``.
            slot[attached_mask] = attached_output
        if detaching:
            slot[detached_mask] = detached_output
        return slot

    return sgs_apply
if __name__ == "__main__":
    # Smoke test: one segment of 512 samples. int(512 * 0.9) = 460 samples are
    # forwarded without gradients; the remaining 52 are backpropagated.
    sgs_apply = create_sgs_applier(0.9, [512])
    # Fall back to CPU so the demo does not crash on machines without CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # A no-argument call gives random init on both the legacy torchvision API
    # (pretrained=False default) and the current one (weights=None default).
    # The positional ``resnet18(False)`` form is deprecated in newer torchvision.
    model = resnet18()
    model.fc = nn.Identity()  # drop the final linear classifier layer
    model = model.to(device)
    x = torch.randn(512, 3, 224, 224, device=device)
    output = sgs_apply(model, x)
    output.sum().backward()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment