# === COLAB / GIST NOTEBOOK ===
# You can run this whole cell as-is in Google Colab.
# It will:
#   1) Clone nanoGPT
#   2) Add a Complex-LoRA module (block-real implementation)
#   3) Provide a runtime patcher that wraps attention/MLP linears
#   4) Smoke-test a forward pass; optional tiny training for a few steps

# --------
# CELL 1: Environment setup
# --------
import os, sys, textwrap, subprocess, json, shutil, pathlib, re

IN_COLAB = "google.colab" in sys.modules
print("Colab:", IN_COLAB)

# Install minimal deps (Colab usually has torch). tiktoken is used by nanoGPT.
!pip -q install --upgrade tiktoken

# --------
# CELL 2: Clone nanoGPT and git checkout
# --------
REPO_URL = "https://github.com/karpathy/nanoGPT.git"
REPO_DIR = "/content/nanoGPT"

if os.path.exists(REPO_DIR):
    print("Repo already exists, removing to reclone...")
    shutil.rmtree(REPO_DIR)

!git clone {REPO_URL} {REPO_DIR}
%cd {REPO_DIR}

# Pin to a branch or commit if you want reproducibility.
# Replace "main" with a specific commit hash for a stricter pin.
PIN = "main"
!git checkout {PIN}
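
# Optional: record the exact commit that was checked out, so a run can be
# reproduced later even when PIN is a moving branch like "main". This is an
# illustrative extra step, not required by anything below.
COMMIT_HASH = subprocess.check_output(
    ["git", "rev-parse", "HEAD"], cwd=REPO_DIR
).decode().strip()
print("Checked out commit:", COMMIT_HASH)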

# --------
# CELL 3: Write Complex-LoRA module
# --------
COMPLEX_LORA_PY = """
import torch
import torch.nn as nn
import torch.nn.functional as F


class ComplexLoRA(nn.Module):
    \"""
    Complex-valued LoRA implemented via block-real factors.
    Delta W = Re((B_r + i B_i) (A_r + i A_i)) = B_r @ A_r - B_i @ A_i
    This keeps host layers real while giving a phase-aware parameterization.
    \"""
    def __init__(self, d_in, d_out, r, alpha=16, init_scale=1e-2):
        super().__init__()
        assert r > 0
        self.r = r
        self.scale = alpha / r
        self.B_r = nn.Parameter(torch.randn(d_out, r) * init_scale)
        self.B_i = nn.Parameter(torch.randn(d_out, r) * init_scale)
        self.A_r = nn.Parameter(torch.zeros(r, d_in))
        self.A_i = nn.Parameter(torch.zeros(r, d_in))

    def deltaW_real(self):
        # Real part of (B A)
        Wr = self.B_r @ self.A_r - self.B_i @ self.A_i
        return self.scale * Wr


class LinearWithComplexLoRA(nn.Module):
    \"""
    Wraps an existing nn.Linear and adds a complex-LoRA delta to its weight.
    By default, freezes the base linear (LoRA-style).
    \"""
    def __init__(self, base_linear: nn.Linear, r=8, alpha=16, freeze_base=True):
        super().__init__()
        assert isinstance(base_linear, nn.Linear)
        self.base = base_linear
        if freeze_base:
            for p in self.base.parameters():
                p.requires_grad_(False)
        self.lora = ComplexLoRA(self.base.in_features, self.base.out_features, r=r, alpha=alpha)

    def forward(self, x):
        W_eff = self.base.weight + self.lora.deltaW_real()
        return F.linear(x, W_eff, self.base.bias)
"""

with open("complex_lora.py", "w") as f:
    f.write(textwrap.dedent(COMPLEX_LORA_PY).lstrip())
print("Wrote complex_lora.py")

# --------
# CELL 4: Runtime patcher to wrap selected linears
# --------
PATCHER_PY = r"""
import torch
import torch.nn as nn

from complex_lora import LinearWithComplexLoRA


def wrap_named_linears(model: nn.Module, name_predicate, r=8, alpha=16, freeze_base=True, verbose=True):
    '''
    Traverse modules; if name_predicate(full_name, module) is True and module is nn.Linear,
    replace it with LinearWithComplexLoRA(base_linear=module).
    Returns a report dict with counts.
    '''
    replaced = 0
    kept = 0
    parent_map = {}
    for name, module in model.named_modules():
        for child_name, child in module.named_children():
            parent_map[f"{name}.{child_name}" if name else child_name] = (module, child_name)
    for full_name, module in list(model.named_modules()):
        if isinstance(module, nn.Linear) and name_predicate(full_name, module):
            parent, child_name = parent_map[full_name]
            new_mod = LinearWithComplexLoRA(module, r=r, alpha=alpha, freeze_base=freeze_base)
            setattr(parent, child_name, new_mod)
            replaced += 1
            if verbose:
                print(f"[ComplexLoRA] Wrapped: {full_name} (in={module.in_features}, out={module.out_features})")
        elif isinstance(module, nn.Linear):
            kept += 1
    return {"replaced": replaced, "kept_linear": kept}


def default_transformer_predicate(full_name, module):
    '''
    Heuristic: wrap the attention and MLP linears in nanoGPT:
      - names ending with .c_attn, .c_proj (fused qkv and output projections)
      - names ending with .c_fc (nanoGPT's MLP up-projection)
      - names ending with .w1, .w2, .w3 (MLP in/out in some variants)
    Adjust this predicate if your fork/names differ.
    '''
    targets = (".c_attn", ".c_proj", ".c_fc", ".w1", ".w2", ".w3")
    return any(full_name.endswith(t) for t in targets)
"""

with open("patcher.py", "w") as f:
    f.write(textwrap.dedent(PATCHER_PY).lstrip())
print("Wrote patcher.py")

# --------
# CELL 5: Quick forward-pass smoke test
# --------
SMOKE_TEST_PY = r"""
import torch
import torch.nn as nn

from model import GPTConfig, GPT
from patcher import wrap_named_linears, default_transformer_predicate

# Tiny config for speed
conf = GPTConfig(
    block_size=128,
    vocab_size=50304,  # GPT-2 BPE default-ish
    n_layer=2,
    n_head=2,
    n_embd=128
)
model = GPT(conf)
print("Before patch, params:", sum(p.numel() for p in model.parameters()))

# Apply complex-LoRA wrapping
report = wrap_named_linears(model, default_transformer_predicate, r=8, alpha=16, freeze_base=True, verbose=True)
print("Patch report:", report)
print("After patch, params (trainable):", sum(p.numel() for p in model.parameters() if p.requires_grad))

# Smoke forward
B, T = 4, 64
x = torch.randint(0, conf.vocab_size, (B, T))
with torch.no_grad():
    logits, loss = model(x, targets=x)
print("Forward ok. Logits:", logits.shape, "Loss:", float(loss))
print("SUCCESS: complex-LoRA forward pass completed.")
"""

with open("smoke_test.py", "w") as f:
    f.write(textwrap.dedent(SMOKE_TEST_PY).lstrip())
print("Wrote smoke_test.py")

# Run the smoke test
!python smoke_test.py
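
# Optional: once fine-tuning is done, the complex-LoRA delta can be folded back
# into the frozen base weight so inference needs no wrapper. A minimal sketch on
# a toy layer, assuming the LinearWithComplexLoRA class written above.
import torch
import torch.nn as nn
from complex_lora import LinearWithComplexLoRA

base = nn.Linear(16, 8)
wrapped = LinearWithComplexLoRA(base, r=4, alpha=8)
with torch.no_grad():
    # A_r/A_i start at zero, so perturb them to make the check non-trivial.
    wrapped.lora.A_r.normal_()
    wrapped.lora.A_i.normal_()

x = torch.randn(2, 16)
y_wrapped = wrapped(x)

with torch.no_grad():
    merged = nn.Linear(16, 8)
    merged.weight.copy_(wrapped.base.weight + wrapped.lora.deltaW_real())
    merged.bias.copy_(wrapped.base.bias)

assert torch.allclose(y_wrapped, merged(x), atol=1e-5)
print("Merged weights reproduce the wrapped forward pass.")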

# --------
# CELL 6 (optional): Tiny training smoke test (random tokens; swap in tiny Shakespeare for real data)
# --------
# WARNING: This is just a few steps to verify training isn't broken.
# Comment out or reduce steps if you're on very small/slow runtimes.
TRAIN_SNIPPET = r"""
import os, torch
from pathlib import Path

from model import GPTConfig, GPT
from patcher import wrap_named_linears, default_transformer_predicate
from torch.optim import AdamW
import torch.nn.functional as F

# Make a tiny toy dataset (random tokens) to prove the loop runs.
# Replace with real tiny-shakespeare for meaningful loss curves.
vocab_size = 50304
B, T = 8, 64
steps = 20

conf = GPTConfig(block_size=T, vocab_size=vocab_size, n_layer=2, n_head=2, n_embd=128)
model = GPT(conf)
wrap_named_linears(model, default_transformer_predicate, r=8, alpha=16, freeze_base=True, verbose=False)

opt = AdamW([p for p in model.parameters() if p.requires_grad], lr=3e-3)

model.train()
for it in range(steps):
    x = torch.randint(0, vocab_size, (B, T))
    logits, loss = model(x, targets=x)
    opt.zero_grad(set_to_none=True)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    opt.step()
    if (it + 1) % 5 == 0:
        print(f"step {it+1}/{steps} loss={float(loss):.4f}")
print("TRAIN LOOP OK")
"""

with open("tiny_train.py", "w") as f:
    f.write(textwrap.dedent(TRAIN_SNIPPET).lstrip())
print("Wrote tiny_train.py")

# Uncomment to run a tiny dummy train loop:
# !python tiny_train.py

print("\nAll set. To train on real data, run nanoGPT's data prepare scripts, then wrap the model with `patcher.wrap_named_linears(...)` before the training loop to enable Complex-LoRA.\n")