Skip to content

Instantly share code, notes, and snippets.

@danbri
Created September 24, 2025 19:17
Show Gist options
  • Select an option

  • Save danbri/56226fba0ff50f0c22798ec32e6facd2 to your computer and use it in GitHub Desktop.

Select an option

Save danbri/56226fba0ff50f0c22798ec32e6facd2 to your computer and use it in GitHub Desktop.

Revisions

  1. danbri created this gist Sep 24, 2025.
    229 changes: 229 additions & 0 deletions nanog_cmplx.ipynb
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,229 @@
    # === COLAB / GIST NOTEBOOK ===
    # You can run this whole cell as-is in Google Colab.
    # It will:
    # 1) Clone nanoGPT
    # 2) Add a Complex-LoRA module (block-real implementation)
    # 3) Provide a runtime patcher that wraps attention/MLP linears
    # 4) Smoke-test a forward pass; optional tiny training for a few steps

    # --------
    # CELL 1: Environment setup
    # --------
    import os, sys, textwrap, subprocess, json, shutil, pathlib, re

    IN_COLAB = "google.colab" in sys.modules
    print("Colab:", IN_COLAB)

    # Install minimal deps (Colab usually has torch). tiktoken is used by nanoGPT.
    !pip -q install --upgrade tiktoken

    # --------
    # CELL 2: Clone nanoGPT and git checkout
    # --------
    REPO_URL = "https://github.com/karpathy/nanoGPT.git"
    REPO_DIR = "/content/nanoGPT"

    if os.path.exists(REPO_DIR):
    print("Repo already exists, removing to reclone...")
    shutil.rmtree(REPO_DIR)

    !git clone {REPO_URL} {REPO_DIR}
    %cd {REPO_DIR}

    # Pin to a branch or commit if you want reproducibility.
    # Replace "main" with a specific commit hash for a stricter pin.
    PIN = "main"
    !git checkout {PIN}

    # --------
    # CELL 3: Write Complex-LoRA module
    # --------
    COMPLEX_LORA_PY = """
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class ComplexLoRA(nn.Module):
    \"""
    Complex-valued LoRA implemented via block-real factors.
    Delta W = Re((B_r + i B_i) (A_r + i A_i)) = B_r@A_r - B_i@A_i
    This keeps host layers real while giving a phase-aware parameterization.
    \"""
    def __init__(self, d_in, d_out, r, alpha=16, init_scale=1e-2):
    super().__init__()
    assert r > 0
    self.r = r
    self.scale = alpha / r

    self.B_r = nn.Parameter(torch.randn(d_out, r) * init_scale)
    self.B_i = nn.Parameter(torch.randn(d_out, r) * init_scale)
    self.A_r = nn.Parameter(torch.zeros(r, d_in))
    self.A_i = nn.Parameter(torch.zeros(r, d_in))

    def deltaW_real(self):
    # Real part of (B A)
    Wr = self.B_r @ self.A_r - self.B_i @ self.A_i
    return self.scale * Wr

    class LinearWithComplexLoRA(nn.Module):
    \"""
    Wraps an existing nn.Linear and adds a complex-LoRA delta to its weight.
    By default, freezes the base linear (LoRA-style).
    \"""
    def __init__(self, base_linear: nn.Linear, r=8, alpha=16, freeze_base=True):
    super().__init__()
    assert isinstance(base_linear, nn.Linear)
    self.base = base_linear
    if freeze_base:
    for p in self.base.parameters():
    p.requires_grad_(False)
    self.lora = ComplexLoRA(self.base.in_features, self.base.out_features, r=r, alpha=alpha)

    def forward(self, x):
    W_eff = self.base.weight + self.lora.deltaW_real()
    return F.linear(x, W_eff, self.base.bias)
    """

    with open("complex_lora.py", "w") as f:
    f.write(textwrap.dedent(COMPLEX_LORA_PY).lstrip())

    print("Wrote complex_lora.py")

    # --------
    # CELL 4: Runtime patcher to wrap selected linears
    # --------
    PATCHER_PY = r"""
    import torch
    import torch.nn as nn
    from types import SimpleNamespace
    from complex_lora import LinearWithComplexLoRA

    def wrap_named_linears(model: nn.Module, name_predicate, r=8, alpha=16, freeze_base=True, verbose=True):
    """
    Traverse modules; if name_predicate(full_name, module) is True and module is nn.Linear,
    replace it with LinearWithComplexLoRA(base_linear=module).
    Returns a report dict with counts.
    """
    replaced = 0
    kept = 0
    parent_map = {}
    for name, module in model.named_modules():
    for child_name, child in module.named_children():
    parent_map[f"{name}.{child_name}" if name else child_name] = (module, child_name)

    for full_name, module in list(model.named_modules()):
    if isinstance(module, nn.Linear) and name_predicate(full_name, module):
    parent, child_name = parent_map[full_name]
    new_mod = LinearWithComplexLoRA(module, r=r, alpha=alpha, freeze_base=freeze_base)
    setattr(parent, child_name, new_mod)
    replaced += 1
    if verbose:
    print(f"[ComplexLoRA] Wrapped: {full_name} (in={module.in_features}, out={module.out_features})")
    elif isinstance(module, nn.Linear):
    kept += 1
    return {"replaced": replaced, "kept_linear": kept}

    def default_transformer_predicate(full_name, module):
    """
    Heuristic: wrap attention q/k/v/o and MLP projectors in nanoGPT:
    - names ending with .c_attn, .c_proj (attention)
    - names containing .w1, .w2, .w3 (MLP in/out in some variants)
    Adjust this predicate if your fork/names differ.
    """
    targets = (".c_attn", ".c_proj", ".w1", ".w2", ".w3")
    return any(full_name.endswith(t) for t in targets)
    """
    with open("patcher.py", "w") as f:
    f.write(textwrap.dedent(PATCHER_PY).lstrip())

    print("Wrote patcher.py")

    # --------
    # CELL 5: Quick forward-pass smoke test
    # --------
    SMOKE_TEST_PY = r"""
    import torch
    import torch.nn as nn
    from model import GPTConfig, GPT
    from patcher import wrap_named_linears, default_transformer_predicate

    # Tiny config for speed
    conf = GPTConfig(
    block_size=128,
    vocab_size=50304, # GPT-2 BPE default-ish
    n_layer=2,
    n_head=2,
    n_embd=128
    )

    model = GPT(conf)
    print("Before patch, params:", sum(p.numel() for p in model.parameters()))

    # Apply complex-LoRA wrapping
    report = wrap_named_linears(model, default_transformer_predicate, r=8, alpha=16, freeze_base=True, verbose=True)
    print("Patch report:", report)
    print("After patch, params (trainable):", sum(p.numel() for p in model.parameters() if p.requires_grad))

    # Smoke forward
    B, T = 4, 64
    x = torch.randint(0, conf.vocab_size, (B, T))
    with torch.no_grad():
    logits, loss = model(x, targets=x)
    print("Forward ok. Logits:", logits.shape, "Loss:", float(loss))
    print("SUCCESS: complex-LoRA forward pass completed.")
    """
    with open("smoke_test.py", "w") as f:
    f.write(textwrap.dedent(SMOKE_TEST_PY).lstrip())

    print("Wrote smoke_test.py")

    # Run the smoke test
    !python smoke_test.py

    # --------
    # CELL 6 (optional): Tiny training smoke test on tiny Shakespeare
    # --------
    # WARNING: This is just a few steps to verify training isn't broken.
    # Comment out or reduce steps if you're on very small/slow runtimes.

    TRAIN_SNIPPET = r"""
    import os, torch
    from pathlib import Path
    from model import GPTConfig, GPT
    from patcher import wrap_named_linears, default_transformer_predicate
    from torch.optim import AdamW
    import torch.nn.functional as F

    # Make a tiny toy dataset (random tokens) to prove the loop runs.
    # Replace with real tiny-shakespeare for meaningful loss curves.
    vocab_size = 50304
    B, T = 8, 64
    steps = 20

    conf = GPTConfig(block_size=T, vocab_size=vocab_size, n_layer=2, n_head=2, n_embd=128)
    model = GPT(conf)

    wrap_named_linears(model, default_transformer_predicate, r=8, alpha=16, freeze_base=True, verbose=False)
    opt = AdamW([p for p in model.parameters() if p.requires_grad], lr=3e-3)

    model.train()
    for it in range(steps):
    x = torch.randint(0, vocab_size, (B, T))
    logits, loss = model(x, targets=x)
    opt.zero_grad(set_to_none=True)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    opt.step()
    if (it+1) % 5 == 0:
    print(f"step {it+1}/{steps} loss={float(loss):.4f}")
    print("TRAIN LOOP OK")
    """
    with open("tiny_train.py", "w") as f:
    f.write(textwrap.dedent(TRAIN_SNIPPET).lstrip())

    print("Wrote tiny_train.py")

    # Uncomment to run a tiny dummy train loop:
    # !python tiny_train.py

    print("\nAll set. To train on real data, run nanoGPT's prepare scripts and training script, then import `patcher.wrap_named_linears(...)` before training to enable Complex-LoRA.\n")