This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,177 @@ import numpy as np import torch.nn.functional as F from torchvision import datasets, transforms import torch from torch import nn from torch.optim import SGD import matplotlib.pyplot as plt device = torch.device("cuda" if torch.cuda.is_available() else "cpu") batch_size = 128 data_dir = '/tmp' transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ]) train_ds = datasets.CIFAR10(root=data_dir, train=True, download=True, transform=transform) train_dl = torch.utils.data.DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=2) test_ds = datasets.CIFAR10(root=data_dir, train=False, download=True, transform=transform) test_dl = torch.utils.data.DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=2) class muMLPTab3(nn.Module): def __init__(self, width=128, num_classes=10): super().__init__() self.width = width self.input_mult = 1.0 self.output_mult = 1.0 self.fc_1 = nn.Linear(3072, width, bias=False) self.fc_2 = nn.Linear(width, width, bias=False) self.fc_3 = nn.Linear(width, num_classes, bias=False) self.reset_parameters() def reset_parameters(self): nn.init.normal_(self.fc_1.weight, std=3072**-0.5) nn.init.normal_(self.fc_2.weight, std=self.width**-0.5) nn.init.normal_(self.fc_3.weight, std=1/self.width) def forward(self, x): activations = [] h = self.input_mult * self.fc_1(x) activations.append(h) h = self.fc_2(F.relu(h)) activations.append(h) h = self.output_mult * self.fc_3(F.relu(h)) activations.append(h) return h, activations def get_param_groups(self, base_lr): return [ dict(params=self.fc_1.parameters(), lr=base_lr*self.width), dict(params=self.fc_2.parameters(), lr=base_lr), dict(params=self.fc_3.parameters(), lr=base_lr/self.width), ] class muMLPTab8(nn.Module): def __init__(self, width=128, num_classes=10): super().__init__() self.width = width self.input_mult = 1.0 self.output_mult = 1.0 / self.width self.fc_1 = nn.Linear(3072, width, bias=False) self.fc_2 = nn.Linear(width, width, bias=False) self.fc_3 = nn.Linear(width, num_classes, bias=False) self.reset_parameters() def reset_parameters(self): nn.init.normal_(self.fc_1.weight, std=3072**-0.5) nn.init.normal_(self.fc_2.weight, std=self.width**-0.5) nn.init.normal_(self.fc_3.weight, std=1.0) def forward(self, x, return_activations=False): activations = [] h = self.input_mult * self.fc_1(x) activations.append(h) h = self.fc_2(F.relu(h)) activations.append(h) h = self.output_mult * self.fc_3(F.relu(h)) activations.append(h) return h, activations def get_param_groups(self, base_lr): return [ dict(params=self.fc_1.parameters(), lr=base_lr*self.width), dict(params=self.fc_2.parameters(), lr=base_lr), dict(params=self.fc_3.parameters(), lr=base_lr*self.width), ] class muMLPTab9(nn.Module): def __init__(self, width=128, num_classes=10): super().__init__() self.width = width self.input_mult = self.width**0.5 self.output_mult = self.width**-0.5 self.fc_1 = nn.Linear(3072, width, bias=False) self.fc_2 = nn.Linear(width, width, bias=False) self.fc_3 = nn.Linear(width, num_classes, bias=False) self.reset_parameters() def reset_parameters(self): nn.init.normal_(self.fc_1.weight, std=(self.width*3072)**-0.5) 
nn.init.normal_(self.fc_2.weight, std=self.width**-0.5) nn.init.normal_(self.fc_3.weight, std=self.width**-0.5) def forward(self, x): activations = [] h = self.input_mult * self.fc_1(x) activations.append(h) h = self.fc_2(F.relu(h)) activations.append(h) h = self.output_mult * self.fc_3(F.relu(h)) activations.append(h) return h, activations def get_param_groups(self, base_lr): return [ dict(params=self.fc_1.parameters(), lr=base_lr), dict(params=self.fc_2.parameters(), lr=base_lr), dict(params=self.fc_3.parameters(), lr=base_lr), ] # run coordinate check to test correctness widths = [64, 128, 256, 512, 1024, 2048, 4096] max_t = 5 num_seeds = 5 base_lr = 0.1 dataset = [next(iter(train_dl))] * max_t all_metrics = [] for width in widths: metrics = [] for seed in range(num_seeds): torch.manual_seed(seed) # model = muMLPTab3(width=width).to(device) # model = muMLPTab8(width=width).to(device) model = muMLPTab9(width=width).to(device) optimizer = SGD(model.get_param_groups(base_lr)) acts_t = [] for batch_idx, (data, target) in enumerate(dataset): if batch_idx >= max_t: break data, target = data.to(device), target.to(device) output, acts = model(data.view(data.size(0), -1)) acts_t.append([a.detach().cpu() for a in acts]) loss = F.cross_entropy(output, target) loss.backward() optimizer.step() optimizer.zero_grad() del model act_diff_std = [] acts_0 = acts_t[0] for acts in acts_t[:]: diffs = [(ht).abs().mean().item() for h0, ht in zip(acts_0, acts)] act_diff_std.append(diffs) metrics.append(act_diff_std) all_metrics.append(np.stack(metrics, axis=0).mean(axis=0)) all_metrics = np.array(all_metrics) for layer_idx in range(3): fig, ax = plt.subplots() ax.set_title(f"layer_idx={layer_idx+1}") for t in range(max_t): ax.plot(widths, all_metrics[:, t, layer_idx], label=f"t={t+1}") ax.set_xlabel("width") ax.set_ylabel("activation scale") ax.legend() ax.set_xscale("log") ax.set_yscale("log")
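

# A rough, optional read-out of the plots above (a sketch, not part of the original gist):
# a correct muP parametrization should keep the per-layer activation scale roughly O(1)
# as width grows, at every step, rather than blowing up or vanishing. Assuming the run
# above and all_metrics of shape (num_widths, max_t, num_layers), the widest-to-narrowest
# ratio gives a quick numerical version of that eyeball test.
ratios = all_metrics[-1] / all_metrics[0]  # shape (max_t, num_layers)
print(f"activation-scale ratio, width {widths[-1]} vs. {widths[0]}:")
print(np.round(ratios, 3))
# Ratios that stay within a small constant factor of 1 are what a passing coordinate
# check looks like; ratios that keep growing as width doubles point at a mis-scaled layer.

plt.show()  # display the coordinate-check figures when running this file as a plain script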