# === COLAB / GIST NOTEBOOK === # You can run this whole cell as-is in Google Colab. # It will: # 1) Clone nanoGPT # 2) Add a Complex-LoRA module (block-real implementation) # 3) Provide a runtime patcher that wraps attention/MLP linears # 4) Smoke-test a forward pass; optional tiny training for a few steps # -------- # CELL 1: Environment setup # -------- import os, sys, textwrap, subprocess, json, shutil, pathlib, re IN_COLAB = "google.colab" in sys.modules print("Colab:", IN_COLAB) # Install minimal deps (Colab usually has torch). tiktoken is used by nanoGPT. !pip -q install --upgrade tiktoken # -------- # CELL 2: Clone nanoGPT and git checkout # -------- REPO_URL = "https://github.com/karpathy/nanoGPT.git" REPO_DIR = "/content/nanoGPT" if os.path.exists(REPO_DIR): print("Repo already exists, removing to reclone...") shutil.rmtree(REPO_DIR) !git clone {REPO_URL} {REPO_DIR} %cd {REPO_DIR} # Pin to a branch or commit if you want reproducibility. # Replace "main" with a specific commit hash for a stricter pin. PIN = "main" !git checkout {PIN} # -------- # CELL 3: Write Complex-LoRA module # -------- COMPLEX_LORA_PY = """ import torch import torch.nn as nn import torch.nn.functional as F class ComplexLoRA(nn.Module): \""" Complex-valued LoRA implemented via block-real factors. Delta W = Re((B_r + i B_i) (A_r + i A_i)) = B_r@A_r - B_i@A_i This keeps host layers real while giving a phase-aware parameterization. \""" def __init__(self, d_in, d_out, r, alpha=16, init_scale=1e-2): super().__init__() assert r > 0 self.r = r self.scale = alpha / r self.B_r = nn.Parameter(torch.randn(d_out, r) * init_scale) self.B_i = nn.Parameter(torch.randn(d_out, r) * init_scale) self.A_r = nn.Parameter(torch.zeros(r, d_in)) self.A_i = nn.Parameter(torch.zeros(r, d_in)) def deltaW_real(self): # Real part of (B A) Wr = self.B_r @ self.A_r - self.B_i @ self.A_i return self.scale * Wr class LinearWithComplexLoRA(nn.Module): \""" Wraps an existing nn.Linear and adds a complex-LoRA delta to its weight. By default, freezes the base linear (LoRA-style). \""" def __init__(self, base_linear: nn.Linear, r=8, alpha=16, freeze_base=True): super().__init__() assert isinstance(base_linear, nn.Linear) self.base = base_linear if freeze_base: for p in self.base.parameters(): p.requires_grad_(False) self.lora = ComplexLoRA(self.base.in_features, self.base.out_features, r=r, alpha=alpha) def forward(self, x): W_eff = self.base.weight + self.lora.deltaW_real() return F.linear(x, W_eff, self.base.bias) """ with open("complex_lora.py", "w") as f: f.write(textwrap.dedent(COMPLEX_LORA_PY).lstrip()) print("Wrote complex_lora.py") # -------- # CELL 4: Runtime patcher to wrap selected linears # -------- PATCHER_PY = r""" import torch import torch.nn as nn from types import SimpleNamespace from complex_lora import LinearWithComplexLoRA def wrap_named_linears(model: nn.Module, name_predicate, r=8, alpha=16, freeze_base=True, verbose=True): """ Traverse modules; if name_predicate(full_name, module) is True and module is nn.Linear, replace it with LinearWithComplexLoRA(base_linear=module). Returns a report dict with counts. """ replaced = 0 kept = 0 parent_map = {} for name, module in model.named_modules(): for child_name, child in module.named_children(): parent_map[f"{name}.{child_name}" if name else child_name] = (module, child_name) for full_name, module in list(model.named_modules()): if isinstance(module, nn.Linear) and name_predicate(full_name, module): parent, child_name = parent_map[full_name] new_mod = LinearWithComplexLoRA(module, r=r, alpha=alpha, freeze_base=freeze_base) setattr(parent, child_name, new_mod) replaced += 1 if verbose: print(f"[ComplexLoRA] Wrapped: {full_name} (in={module.in_features}, out={module.out_features})") elif isinstance(module, nn.Linear): kept += 1 return {"replaced": replaced, "kept_linear": kept} def default_transformer_predicate(full_name, module): """ Heuristic: wrap attention q/k/v/o and MLP projectors in nanoGPT: - names ending with .c_attn, .c_proj (attention) - names containing .w1, .w2, .w3 (MLP in/out in some variants) Adjust this predicate if your fork/names differ. """ targets = (".c_attn", ".c_proj", ".w1", ".w2", ".w3") return any(full_name.endswith(t) for t in targets) """ with open("patcher.py", "w") as f: f.write(textwrap.dedent(PATCHER_PY).lstrip()) print("Wrote patcher.py") # -------- # CELL 5: Quick forward-pass smoke test # -------- SMOKE_TEST_PY = r""" import torch import torch.nn as nn from model import GPTConfig, GPT from patcher import wrap_named_linears, default_transformer_predicate # Tiny config for speed conf = GPTConfig( block_size=128, vocab_size=50304, # GPT-2 BPE default-ish n_layer=2, n_head=2, n_embd=128 ) model = GPT(conf) print("Before patch, params:", sum(p.numel() for p in model.parameters())) # Apply complex-LoRA wrapping report = wrap_named_linears(model, default_transformer_predicate, r=8, alpha=16, freeze_base=True, verbose=True) print("Patch report:", report) print("After patch, params (trainable):", sum(p.numel() for p in model.parameters() if p.requires_grad)) # Smoke forward B, T = 4, 64 x = torch.randint(0, conf.vocab_size, (B, T)) with torch.no_grad(): logits, loss = model(x, targets=x) print("Forward ok. Logits:", logits.shape, "Loss:", float(loss)) print("SUCCESS: complex-LoRA forward pass completed.") """ with open("smoke_test.py", "w") as f: f.write(textwrap.dedent(SMOKE_TEST_PY).lstrip()) print("Wrote smoke_test.py") # Run the smoke test !python smoke_test.py # -------- # CELL 6 (optional): Tiny training smoke test on tiny Shakespeare # -------- # WARNING: This is just a few steps to verify training isn't broken. # Comment out or reduce steps if you're on very small/slow runtimes. TRAIN_SNIPPET = r""" import os, torch from pathlib import Path from model import GPTConfig, GPT from patcher import wrap_named_linears, default_transformer_predicate from torch.optim import AdamW import torch.nn.functional as F # Make a tiny toy dataset (random tokens) to prove the loop runs. # Replace with real tiny-shakespeare for meaningful loss curves. vocab_size = 50304 B, T = 8, 64 steps = 20 conf = GPTConfig(block_size=T, vocab_size=vocab_size, n_layer=2, n_head=2, n_embd=128) model = GPT(conf) wrap_named_linears(model, default_transformer_predicate, r=8, alpha=16, freeze_base=True, verbose=False) opt = AdamW([p for p in model.parameters() if p.requires_grad], lr=3e-3) model.train() for it in range(steps): x = torch.randint(0, vocab_size, (B, T)) logits, loss = model(x, targets=x) opt.zero_grad(set_to_none=True) loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) opt.step() if (it+1) % 5 == 0: print(f"step {it+1}/{steps} loss={float(loss):.4f}") print("TRAIN LOOP OK") """ with open("tiny_train.py", "w") as f: f.write(textwrap.dedent(TRAIN_SNIPPET).lstrip()) print("Wrote tiny_train.py") # Uncomment to run a tiny dummy train loop: # !python tiny_train.py print("\nAll set. To train on real data, run nanoGPT's prepare scripts and training script, then import `patcher.wrap_named_linears(...)` before training to enable Complex-LoRA.\n")