Training script for LongGPT; fine-tunes GPT-2 (124M) on The Pile dataset with a context size of 2**12 = 4096 tokens, as configured below. (Requires > 16 GB RAM.)
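To run it as-is: save nanoGPT's model.py (linked in the comment below) next to this file as ngpt.py, install torch, transformers, datasets, wandb and tqdm, put your own wandb entity into the wandb.init call in Trainer.train, and launch long_gpt.py on a machine with a CUDA GPU.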
import time
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, IterableDataset

import wandb
from tqdm import tqdm
from datasets import load_dataset
from transformers import GPT2TokenizerFast

# GPT model definition copied from nanoGPT:
# https://github.com/karpathy/nanoGPT/blob/master/model.py
from ngpt import GPT
WANDB_STYLE = """
<style>
html, body {
    padding: 0;
    margin: 0;
    width: 100%;
    height: 100%;
}
p {
    font-family: 'Verdana', sans-serif;
}
</style>
"""


def closest_power_of_2(x):
    # Smallest power of two that is >= x (used for the gradient-accumulation schedule).
    return 2 ** (x - 1).bit_length()
class DatasetWrapper(IterableDataset):
    """Streams The Pile and yields fixed-length blocks of token ids."""

    def __init__(self, max_tokens=2**12):
        self.ds = load_dataset(
            "the_pile",
            name="all",
            split="train",
            streaming=True,
        ).shuffle(buffer_size=100_000)
        self.tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
        self.max_tokens = max_tokens

    def __iter__(self):
        # Concatenate documents (separated by EOS) into one token stream and
        # cut it into contiguous chunks of exactly max_tokens tokens.
        buffer = []
        for sample in self.ds:
            buffer += self.tokenizer(sample["text"])["input_ids"]
            buffer += [self.tokenizer.eos_token_id]
            while len(buffer) > self.max_tokens:
                yield torch.tensor(buffer[: self.max_tokens])
                buffer = buffer[self.max_tokens :]
class Trainer:
    def __init__(self):
        self.tokenizer: GPT2TokenizerFast = GPT2TokenizerFast.from_pretrained("gpt2")
        self.max_tokens = 2**12
        self.grad = 1  # gradient-accumulation steps (grows during training)
        self.step = 0
        self.dataset = DatasetWrapper(self.max_tokens)
        self.loader = DataLoader(
            self.dataset,
            batch_size=1,
            num_workers=8,
        )
        self.model = model = GPT.from_pretrained("gpt2").cuda()
        # model.load_state_dict(torch.load("v2.pt"))
        self.opt = model.configure_optimizers(
            weight_decay=1e-1,
            learning_rate=1e-6,
            betas=(0.9, 0.95),
            device_type="cuda",
        )

        # Patch the positional embeddings: tile the 1024 pretrained position
        # vectors to cover the new, longer context window.
        emb = model.transformer.wpe.weight.data
        wpe = nn.Embedding(self.max_tokens, emb.shape[1])
        wpe.weight.data = emb.repeat(self.max_tokens // emb.shape[0], 1)
        model.transformer.wpe = wpe
        model.config.block_size = self.max_tokens
        print("Patched model embeddings:", wpe.weight.shape)

        self.model = torch.compile(model)
    def train_step(self, batch):
        batch = batch.cuda()
        # Next-token prediction: inputs are all tokens but the last,
        # targets are the same sequence shifted by one.
        x, y = batch[:, :-1], batch[:, 1:].contiguous()
        _, loss = self.model(x, targets=y)
        (loss / self.grad).backward()  # scale for gradient accumulation
        return loss

    def generate_samples(self, n_samples=8):
        # Sample from the model (starting from EOS) and log the text to wandb.
        x = torch.tensor([[self.tokenizer.eos_token_id]] * n_samples).cuda()
        t0 = time.time()
        self.model.eval()
        y = self.model.generate(x, max_new_tokens=1100).tolist()
        self.model.train()
        t1 = time.time()
        t = [self.tokenizer.decode(z) for z in y]
        t = "<hr>".join(f"<p>{c}</p>" for c in t)
        html = WANDB_STYLE + t
        wandb.log({"samples": wandb.Html(html)}, step=self.step)
        print(f"Generated in {t1 - t0:.3f}s")
    def train(self):
        wandb.init(
            project="long-gpt",
            entity="...",  # fill in your wandb entity
        )
        prog = tqdm(self.loader)
        self.opt.zero_grad()

        for i, batch in enumerate(prog):
            self.step = i + 1
            loss = self.train_step(batch)
            prog.set_description(f"loss: {loss.item():.3f}")
            wandb.log({"loss": loss.item(), "grad": self.grad}, step=i)

            # Step the optimizer every self.grad batches, then raise the
            # accumulation window to the smallest power of two >= i + 1.
            if i % self.grad == 0:
                self.opt.step()
                self.opt.zero_grad()
                self.grad = closest_power_of_2(i + 1)
            if i % 100 == 0:
                torch.save(self.model.state_dict(), "model.pt")
            if i % 1000 == 0:
                self.generate_samples(8)


if __name__ == "__main__":
    trainer = Trainer()
    trainer.train()
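For intuition, here is a small standalone sketch of the schedule produced by the `i % self.grad` logic above (pure index arithmetic, no model needed): the optimizer steps at batch indices 0, 1, 2, 4, 8, 16, ..., so the number of single-sequence batches accumulated between steps roughly doubles as training progresses.

def closest_power_of_2(x):
    return 2 ** (x - 1).bit_length()

grad, schedule = 1, []
for i in range(100):
    if i % grad == 0:               # same condition as in Trainer.train
        schedule.append((i, grad))  # (index of this optimizer step, self.grad at that point)
        grad = closest_power_of_2(i + 1)

print(schedule)
# [(0, 1), (1, 1), (2, 2), (4, 4), (8, 8), (16, 16), (32, 32), (64, 64)]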