
@dvruette
Last active May 3, 2025 07:14
Revisions

  1. dvruette revised this gist Apr 30, 2025. 1 changed file with 3 additions and 3 deletions.
    6 changes: 3 additions & 3 deletions min_mup.py
    @@ -21,6 +21,8 @@
    test_ds = datasets.CIFAR10(root=data_dir, train=False, download=True, transform=transform)
    test_dl = torch.utils.data.DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=2)


    # muP implementation of MLP, following Tab. 3, Tab. 8, and Tab. 9 from https://arxiv.org/pdf/2203.03466

    class muMLPTab3(nn.Module):
    def __init__(self, width=128, num_classes=10):
    @@ -54,9 +56,7 @@ def get_param_groups(self, base_lr):
            dict(params=self.fc_2.parameters(), lr=base_lr),
            dict(params=self.fc_3.parameters(), lr=base_lr/self.width),
        ]


    # muP implementation of MLP, following Tab. 3, Tab. 8, and Tab. 9 from https://arxiv.org/pdf/2203.03466


    class muMLPTab8(nn.Module):
    def __init__(self, width=128, num_classes=10):
  2. dvruette revised this gist Apr 30, 2025. No changes.
  3. dvruette revised this gist Apr 30, 2025. 1 changed file with 3 additions and 1 deletion.
    4 changes: 3 additions & 1 deletion min_mup.py
    @@ -54,7 +54,9 @@ def get_param_groups(self, base_lr):
            dict(params=self.fc_2.parameters(), lr=base_lr),
            dict(params=self.fc_3.parameters(), lr=base_lr/self.width),
        ]



    # muP implementation of MLP, following Tab. 3, Tab. 8, and Tab. 9 from https://arxiv.org/pdf/2203.03466

    class muMLPTab8(nn.Module):
    def __init__(self, width=128, num_classes=10):
  4. dvruette created this gist Apr 30, 2025.
    177 changes: 177 additions & 0 deletions min_mup.py
    @@ -0,0 +1,177 @@
    import numpy as np
    import torch.nn.functional as F
    from torchvision import datasets, transforms
    import torch
    from torch import nn
    from torch.optim import SGD
    import matplotlib.pyplot as plt


    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    batch_size = 128
    data_dir = '/tmp'

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
    train_ds = datasets.CIFAR10(root=data_dir, train=True, download=True, transform=transform)
    train_dl = torch.utils.data.DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=2)
    test_ds = datasets.CIFAR10(root=data_dir, train=False, download=True, transform=transform)
    test_dl = torch.utils.data.DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=2)


class muMLPTab3(nn.Module):
    def __init__(self, width=128, num_classes=10):
        super().__init__()
        self.width = width
        self.input_mult = 1.0
        self.output_mult = 1.0
        self.fc_1 = nn.Linear(3072, width, bias=False)
        self.fc_2 = nn.Linear(width, width, bias=False)
        self.fc_3 = nn.Linear(width, num_classes, bias=False)
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.normal_(self.fc_1.weight, std=3072**-0.5)
        nn.init.normal_(self.fc_2.weight, std=self.width**-0.5)
        nn.init.normal_(self.fc_3.weight, std=1/self.width)

    def forward(self, x):
        activations = []
        h = self.input_mult * self.fc_1(x)
        activations.append(h)
        h = self.fc_2(F.relu(h))
        activations.append(h)
        h = self.output_mult * self.fc_3(F.relu(h))
        activations.append(h)
        return h, activations

    def get_param_groups(self, base_lr):
        return [
            dict(params=self.fc_1.parameters(), lr=base_lr*self.width),
            dict(params=self.fc_2.parameters(), lr=base_lr),
            dict(params=self.fc_3.parameters(), lr=base_lr/self.width),
        ]


class muMLPTab8(nn.Module):
    def __init__(self, width=128, num_classes=10):
        super().__init__()
        self.width = width
        self.input_mult = 1.0
        self.output_mult = 1.0 / self.width
        self.fc_1 = nn.Linear(3072, width, bias=False)
        self.fc_2 = nn.Linear(width, width, bias=False)
        self.fc_3 = nn.Linear(width, num_classes, bias=False)
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.normal_(self.fc_1.weight, std=3072**-0.5)
        nn.init.normal_(self.fc_2.weight, std=self.width**-0.5)
        nn.init.normal_(self.fc_3.weight, std=1.0)

    def forward(self, x, return_activations=False):
        activations = []
        h = self.input_mult * self.fc_1(x)
        activations.append(h)
        h = self.fc_2(F.relu(h))
        activations.append(h)
        h = self.output_mult * self.fc_3(F.relu(h))
        activations.append(h)
        return h, activations

    def get_param_groups(self, base_lr):
        return [
            dict(params=self.fc_1.parameters(), lr=base_lr*self.width),
            dict(params=self.fc_2.parameters(), lr=base_lr),
            dict(params=self.fc_3.parameters(), lr=base_lr*self.width),
        ]
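
# Note: muMLPTab8 (above) expresses the same parameterization as muMLPTab3, but moves the
# 1/width factor of the readout layer out of its init std (now 1.0) and into output_mult,
# with the readout learning rate scaled up by width to compensate.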


class muMLPTab9(nn.Module):
    def __init__(self, width=128, num_classes=10):
        super().__init__()
        self.width = width
        self.input_mult = self.width**0.5
        self.output_mult = self.width**-0.5
        self.fc_1 = nn.Linear(3072, width, bias=False)
        self.fc_2 = nn.Linear(width, width, bias=False)
        self.fc_3 = nn.Linear(width, num_classes, bias=False)
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.normal_(self.fc_1.weight, std=(self.width*3072)**-0.5)
        nn.init.normal_(self.fc_2.weight, std=self.width**-0.5)
        nn.init.normal_(self.fc_3.weight, std=self.width**-0.5)

    def forward(self, x):
        activations = []
        h = self.input_mult * self.fc_1(x)
        activations.append(h)
        h = self.fc_2(F.relu(h))
        activations.append(h)
        h = self.output_mult * self.fc_3(F.relu(h))
        activations.append(h)
        return h, activations

    def get_param_groups(self, base_lr):
        return [
            dict(params=self.fc_1.parameters(), lr=base_lr),
            dict(params=self.fc_2.parameters(), lr=base_lr),
            dict(params=self.fc_3.parameters(), lr=base_lr),
        ]
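
# Note: muMLPTab9 (above) again matches muMLPTab3, but absorbs the width factors into the
# input/output multipliers and init stds so that every layer can share the same learning rate.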


    # run coordinate check to test correctness
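# Under muP, the per-layer activation scales should stay roughly O(1) as the width grows;
# curves that drift up or down systematically with width indicate a mis-scaled
# parameterization. All three classes above should pass this check.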

    widths = [64, 128, 256, 512, 1024, 2048, 4096]
    max_t = 5
    num_seeds = 5
    base_lr = 0.1

    dataset = [next(iter(train_dl))] * max_t

    all_metrics = []
for width in widths:
    metrics = []
    for seed in range(num_seeds):
        torch.manual_seed(seed)
        # model = muMLPTab3(width=width).to(device)
        # model = muMLPTab8(width=width).to(device)
        model = muMLPTab9(width=width).to(device)
        optimizer = SGD(model.get_param_groups(base_lr))
        acts_t = []
        for batch_idx, (data, target) in enumerate(dataset):
            if batch_idx >= max_t:
                break
            data, target = data.to(device), target.to(device)
            output, acts = model(data.view(data.size(0), -1))
            acts_t.append([a.detach().cpu() for a in acts])
            loss = F.cross_entropy(output, target)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        del model

        act_scales = []
        for acts in acts_t:
            # mean absolute activation per layer ("coordinate scale") at this step
            scales = [ht.abs().mean().item() for ht in acts]
            act_scales.append(scales)
        metrics.append(act_scales)
    all_metrics.append(np.stack(metrics, axis=0).mean(axis=0))
all_metrics = np.array(all_metrics)


for layer_idx in range(3):
    fig, ax = plt.subplots()
    ax.set_title(f"layer_idx={layer_idx+1}")
    for t in range(max_t):
        ax.plot(widths, all_metrics[:, t, layer_idx], label=f"t={t+1}")
    ax.set_xlabel("width")
    ax.set_ylabel("activation scale")
    ax.legend()
    ax.set_xscale("log")
    ax.set_yscale("log")
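
A possible follow-up, not part of the original script: beyond eyeballing the plots, one can quantify how flat the coordinate-check curves are by fitting the slope of log(activation scale) against log(width); under muP the fitted slopes should be close to zero. A minimal sketch, reusing the `widths` and `all_metrics` arrays produced above:

# display the coordinate-check figures (needed when running as a plain script)
plt.show()

# fit the slope of log(activation scale) vs. log(width) for each layer at the last step
log_w = np.log(np.asarray(widths, dtype=float))
for layer_idx in range(3):
    log_scale = np.log(all_metrics[:, -1, layer_idx])
    slope = np.polyfit(log_w, log_scale, 1)[0]
    print(f"layer {layer_idx + 1}: slope = {slope:+.3f} (close to 0 under muP)")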