# Source: GitHub Gist danbri/edd5dc328a05e946f87ac17b59172f3c
# (created by @danbri, September 24, 2025 19:17).
# === COLAB / GIST NOTEBOOK ===
# You can run this whole script as-is in a single Google Colab cell (or split it at the CELL markers below).
# It will:
# 1) Clone nanoGPT
# 2) Add a Complex-LoRA module (block-real implementation)
# 3) Provide a runtime patcher that wraps attention/MLP linears
# 4) Smoke-test a forward pass; optional tiny training for a few steps
# --------
# CELL 1: Environment setup
# --------
import os, sys, textwrap, subprocess, json, shutil, pathlib, re
IN_COLAB = "google.colab" in sys.modules
print("Colab:", IN_COLAB)
# Install minimal deps (Colab usually has torch). tiktoken is used by nanoGPT.
!pip -q install --upgrade tiktoken
# --------
# CELL 2: Clone nanoGPT and git checkout
# --------
REPO_URL = "https://github.com/karpathy/nanoGPT.git"
REPO_DIR = "/content/nanoGPT"
if os.path.exists(REPO_DIR):
print("Repo already exists, removing to reclone...")
shutil.rmtree(REPO_DIR)
!git clone {REPO_URL} {REPO_DIR}
%cd {REPO_DIR}
# Pin to a branch or commit if you want reproducibility.
# Replace "main" with a specific commit hash for a stricter pin.
PIN = "main"
!git checkout {PIN}
# --------
# CELL 3: Write Complex-LoRA module
# --------
# The module source lives in a string and is written to complex_lora.py so the
# smoke-test / training scripts (run as separate processes) can import it.
# r''' is used as the delimiter so the docstrings inside can use """.
COMPLEX_LORA_PY = r'''
import torch
import torch.nn as nn
import torch.nn.functional as F


class ComplexLoRA(nn.Module):
    """
    Complex-valued LoRA implemented via block-real factors.

    Delta W = Re((B_r + i B_i) (A_r + i A_i)) = B_r @ A_r - B_i @ A_i

    This keeps host layers real while giving a phase-aware parameterization.
    Note that the real part is itself a real low-rank product,
    [B_r, -B_i] @ [A_r; A_i], so Delta W has rank at most 2 * r.

    Args:
        d_in, d_out: input / output features of the host linear layer.
        r: complex rank (must be > 0).
        alpha: scaling numerator; the delta is multiplied by alpha / r.
        init_scale: multiplier on the standard-normal init of B_r and B_i.
            A_r and A_i start at zero, so Delta W == 0 at initialization.
        device, dtype: where / how to allocate the factors (torch defaults
            when None).
    """

    def __init__(self, d_in, d_out, r, alpha=16, init_scale=1e-2, device=None, dtype=None):
        super().__init__()
        assert r > 0
        self.r = r
        self.scale = alpha / r
        factory = {"device": device, "dtype": dtype}
        self.B_r = nn.Parameter(torch.randn(d_out, r, **factory) * init_scale)
        self.B_i = nn.Parameter(torch.randn(d_out, r, **factory) * init_scale)
        self.A_r = nn.Parameter(torch.zeros(r, d_in, **factory))
        self.A_i = nn.Parameter(torch.zeros(r, d_in, **factory))

    def deltaW_real(self):
        """Return the dense (d_out, d_in) weight delta: scale * Re(B A)."""
        return self.scale * (self.B_r @ self.A_r - self.B_i @ self.A_i)

    def forward(self, x):
        """Return x @ deltaW_real().T without materializing the dense delta."""
        # Go through the rank-r bottleneck (x -> r -> d_out) instead of building
        # a d_out x d_in matrix (and its gradient) on every call.
        real = F.linear(F.linear(x, self.A_r), self.B_r)
        imag = F.linear(F.linear(x, self.A_i), self.B_i)
        return self.scale * (real - imag)


class LinearWithComplexLoRA(nn.Module):
    """
    Wrap an existing nn.Linear and add a complex-LoRA delta to it.

    forward(x) equals F.linear(x, base.weight + lora.deltaW_real(), base.bias)
    up to floating-point rounding. By default the base linear is frozen
    (LoRA-style), so only the LoRA factors train.

    The LoRA factors are allocated on the base weight's device and dtype, so a
    model can be wrapped after it was moved to GPU or cast to half/bf16.
    Wrapping moves the base parameters under a "base." prefix in state_dict
    (e.g. "c_attn.weight" -> "c_attn.base.weight"); load checkpoints first.
    """

    def __init__(self, base_linear: nn.Linear, r=8, alpha=16, freeze_base=True, init_scale=1e-2):
        super().__init__()
        assert isinstance(base_linear, nn.Linear)
        self.base = base_linear
        if freeze_base:
            for p in self.base.parameters():
                p.requires_grad_(False)
        self.lora = ComplexLoRA(
            self.base.in_features,
            self.base.out_features,
            r=r,
            alpha=alpha,
            init_scale=init_scale,
            device=self.base.weight.device,
            dtype=self.base.weight.dtype,
        )

    def forward(self, x):
        return self.base(x) + self.lora(x)
'''
with open("complex_lora.py", "w", encoding="utf-8") as f:
    f.write(textwrap.dedent(COMPLEX_LORA_PY).lstrip())
print("Wrote complex_lora.py")
# --------
# CELL 4: Runtime patcher to wrap selected linears
# --------
# r''' (not r""") delimits the source: the docstrings inside use """, which
# would otherwise terminate the string early.
PATCHER_PY = r'''
import torch.nn as nn
from complex_lora import LinearWithComplexLoRA


def wrap_named_linears(model: nn.Module, name_predicate, r=8, alpha=16, freeze_base=True, verbose=True):
    """
    Replace selected nn.Linear submodules of `model` in place with
    LinearWithComplexLoRA(module, r=r, alpha=alpha, freeze_base=freeze_base).

    A linear is wrapped when name_predicate(full_name, module) is True, where
    full_name is its dotted name as reported by model.named_modules().

    Returns:
        {"replaced": <linears wrapped>, "kept_linear": <linears left as-is>}

    Raises:
        ValueError: if the predicate selects `model` itself (the root module
            cannot be replaced in place).
    """
    replaced = 0
    kept = 0
    # Snapshot the module list: the tree is mutated while we walk it.
    for full_name, module in list(model.named_modules()):
        if not isinstance(module, nn.Linear):
            continue
        if not name_predicate(full_name, module):
            kept += 1
            continue
        if not full_name:
            raise ValueError("cannot wrap the root module in place; wrap it directly instead")
        parent_name, _, child_name = full_name.rpartition(".")
        parent = model.get_submodule(parent_name) if parent_name else model
        new_mod = LinearWithComplexLoRA(module, r=r, alpha=alpha, freeze_base=freeze_base)
        setattr(parent, child_name, new_mod)
        replaced += 1
        if verbose:
            print(f"[ComplexLoRA] Wrapped: {full_name} (in={module.in_features}, out={module.out_features})")
    return {"replaced": replaced, "kept_linear": kept}


# Leaf module names to wrap. nanoGPT uses c_attn / c_proj in attention and
# c_fc / c_proj in the MLP; w1 / w2 / w3 cover MLPs in some other variants.
_TARGET_LEAF_NAMES = frozenset({"c_attn", "c_proj", "c_fc", "w1", "w2", "w3"})


def default_transformer_predicate(full_name, module):
    """
    Heuristic: wrap attention and MLP projections in nanoGPT-style models.

    Matches on the last component of the dotted module name:
      - c_attn, c_proj  (attention projections; in nanoGPT c_proj also names
                         the MLP output projection)
      - c_fc            (nanoGPT MLP input projection)
      - w1, w2, w3      (MLP in/out in some variants)
    Adjust this predicate if your fork/names differ.
    """
    return full_name.rpartition(".")[2] in _TARGET_LEAF_NAMES
'''
with open("patcher.py", "w", encoding="utf-8") as f:
    f.write(textwrap.dedent(PATCHER_PY).lstrip())
print("Wrote patcher.py")
# --------
# CELL 5: Quick forward-pass smoke test
# --------
SMOKE_TEST_PY = r"""
import torch
import torch.nn as nn
from model import GPTConfig, GPT
from patcher import wrap_named_linears, default_transformer_predicate
# Tiny config for speed
conf = GPTConfig(
block_size=128,
vocab_size=50304, # GPT-2 BPE default-ish
n_layer=2,
n_head=2,
n_embd=128
)
model = GPT(conf)
print("Before patch, params:", sum(p.numel() for p in model.parameters()))
# Apply complex-LoRA wrapping
report = wrap_named_linears(model, default_transformer_predicate, r=8, alpha=16, freeze_base=True, verbose=True)
print("Patch report:", report)
print("After patch, params (trainable):", sum(p.numel() for p in model.parameters() if p.requires_grad))
# Smoke forward
B, T = 4, 64
x = torch.randint(0, conf.vocab_size, (B, T))
with torch.no_grad():
logits, loss = model(x, targets=x)
print("Forward ok. Logits:", logits.shape, "Loss:", float(loss))
print("SUCCESS: complex-LoRA forward pass completed.")
"""
with open("smoke_test.py", "w") as f:
f.write(textwrap.dedent(SMOKE_TEST_PY).lstrip())
print("Wrote smoke_test.py")
# Run the smoke test
!python smoke_test.py
# --------
# CELL 6 (optional): Tiny training smoke test on random tokens
# --------
# WARNING: This is just a few steps to verify training isn't broken.
# Comment out or reduce steps if you're on very small/slow runtimes.
TRAIN_SNIPPET = r"""
import torch
from torch.optim import AdamW
from model import GPTConfig, GPT
from patcher import wrap_named_linears, default_transformer_predicate

# Make a tiny toy dataset (random tokens) to prove the loop runs.
# Replace with real tiny-shakespeare for meaningful loss curves.
vocab_size = 50304
B, T = 8, 64
steps = 20
torch.manual_seed(0)
conf = GPTConfig(block_size=T, vocab_size=vocab_size, n_layer=2, n_head=2, n_embd=128)
model = GPT(conf)
wrap_named_linears(model, default_transformer_predicate, r=8, alpha=16, freeze_base=True, verbose=False)
# Only the wrapped base linears are frozen; every other parameter (including
# the LoRA factors) still trains. Optimize and clip exactly that set.
trainable = [p for p in model.parameters() if p.requires_grad]
frozen = {n: p.detach().clone() for n, p in model.named_parameters() if not p.requires_grad}
opt = AdamW(trainable, lr=3e-3)
model.train()
for it in range(steps):
    x = torch.randint(0, vocab_size, (B, T))
    logits, loss = model(x, targets=x)
    opt.zero_grad(set_to_none=True)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(trainable, 1.0)
    opt.step()
    if (it + 1) % 5 == 0:
        print(f"step {it+1}/{steps} loss={float(loss):.4f}")
# Frozen base weights must be untouched by training.
assert all(torch.equal(p, frozen[n]) for n, p in model.named_parameters() if n in frozen)
print("TRAIN LOOP OK")
"""
with open("tiny_train.py", "w", encoding="utf-8") as f:
    f.write(textwrap.dedent(TRAIN_SNIPPET).lstrip())
print("Wrote tiny_train.py")
# Uncomment to run a tiny dummy train loop (with this kernel's interpreter):
# !"{sys.executable}" tiny_train.py
print("\nAll set. To train on real data, run nanoGPT's prepare scripts and training script; in the training script call `patcher.wrap_named_linears(model, ...)` after the model is built (and any checkpoint is loaded) but before the optimizer is created, to enable Complex-LoRA.\n")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment