Last active
November 16, 2020 12:10
-
-
Save enhuiz/2c9773b76488e2eb527acd92e9ba947c to your computer and use it in GitHub Desktop.
sgs.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
| import torch.nn as nn | |
| from torchvision.models import resnet18 | |
def sample_mask(xl, p):
    """Sample a random boolean mask over a batch of concatenated sequences.

    For each length ``l`` in ``xl``, ``int(l * p)`` distinct positions
    (clamped to the range ``[0, l]``) are picked with ``torch.randperm`` and
    set to True. The per-sequence masks are concatenated in input order.

    Args:
        xl: Iterable of non-negative sequence lengths.
        p: Fraction of each sequence whose positions are set to True.

    Returns:
        A 1-D ``torch.bool`` tensor with ``sum(xl)`` elements; an empty
        boolean tensor when ``xl`` is empty.
    """
    masks = []
    for length in xl:
        size = int(length)
        # Clamp the count: a negative value would otherwise turn the slice
        # below into ``perm[:-k]`` and select almost every position.
        count = min(max(int(length * p), 0), size)
        chosen = torch.randperm(size)[:count]
        mask = torch.zeros(size, dtype=torch.bool)
        mask[chosen] = True
        masks.append(mask)
    if not masks:
        # torch.cat rejects an empty list, so return an empty mask directly.
        return torch.zeros(0, dtype=torch.bool)
    return torch.cat(masks, dim=0)
def create_sgs_applier(p_detach, lengths):
    """Build a function that forwards a random subset of a batch without gradients.

    A boolean mask is sampled once with ``sample_mask(lengths, p_detach)``.
    Entries marked True are "detached": they are forwarded under
    ``torch.no_grad()``. The remaining "attached" entries are forwarded
    normally. The same mask is reused on every call of the returned function.

    Args:
        p_detach: Fraction passed to ``sample_mask``; the share of each
            sequence that is forwarded without gradient tracking.
        lengths: Sequence lengths passed to ``sample_mask``.

    Returns:
        A callable ``sgs_apply(module, *data)``. It indexes every item of
        ``data`` along dim 0 with the mask, calls ``module`` on each part and
        writes both outputs back into a single tensor in the original order.
        The result lives on ``data[0].device``.
    """
    detached = sample_mask(lengths, p_detach)
    attached = ~detached
    # Resolve the flags to plain bools once instead of re-evaluating 0-dim
    # tensors on every call.
    attaching = bool(attached.any())
    detaching = bool(detached.any())

    def sgs_apply(module, *data):
        if not (attaching or detaching):
            # Empty mask: there is nothing to split or scatter, so run the
            # module on the data unchanged instead of failing on an unbound
            # output variable further down.
            return module(*data)

        n = len(data[0])
        attached_output = None
        detached_output = None
        if attaching:
            attached_output = module(*[d[attached] for d in data])
        if detaching:
            with torch.no_grad():
                detached_output = module(*[d[detached] for d in data])

        # Either output can serve as the template for trailing shape and dtype.
        template = attached_output if attaching else detached_output
        # Allocate on the input's device directly rather than on CPU followed
        # by a device copy.
        slot = torch.empty(
            n, *template.shape[1:], dtype=template.dtype, device=data[0].device
        )
        if attaching:
            slot[attached] = attached_output
        if detaching:
            slot[detached] = detached_output
        return slot

    return sgs_apply
if __name__ == "__main__":
    # Smoke test: forward a 512-image batch through a ResNet-18 backbone with
    # 90% of the samples run without gradients, then backpropagate.
    # Fall back to CPU so the demo does not crash on machines without CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    sgs_apply = create_sgs_applier(0.9, [512])
    # Default arguments give randomly initialised weights on both the legacy
    # ``pretrained`` and the newer ``weights`` torchvision APIs; the positional
    # ``False`` is deprecated in recent releases.
    model = resnet18()
    model.fc = nn.Identity()
    model = model.to(device)
    x = torch.randn(512, 3, 224, 224, device=device)
    output = sgs_apply(model, x)
    output.sum().backward()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment