@@ -1,19 +1,21 @@
import time
import random
from contextlib import suppress

import torch
import torch.nn as nn
+import torch.optim as optim
+import torch.nn.functional as F
+import torch.backends.cuda as cuda
from torch.utils.data import DataLoader, IterableDataset

import wandb
from tqdm import tqdm

from datasets import load_dataset
from transformers import GPT2TokenizerFast
+from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel, GPT2Attention

-# download from here: https://github.com/karpathy/nanoGPT/blob/master/model.py
-from ngpt import GPT
+_attn_orig = GPT2Attention._attn
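# _attn_orig keeps a handle on the stock GPT2 attention so generate_samples()
# can temporarily swap the flash kernel back out during sampling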

WANDB_STYLE = """
<style>
@@ -30,24 +32,72 @@
"""


+# patch GPT2Attention to use flash_sdp, disable it when doing the inference
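# note: the flash kernel only runs on fp16/bf16 inputs, hence the .half()/.float()
# casts below; the second return value (attention weights) is None because flash
# attention never materializes the full attention matrix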
+def _attn_wrapper(self, query, key, value, attention_mask=None, head_mask=None):
+    if head_mask is not None:
+        raise NotImplementedError("head_mask is not implemented for flash_sdp")
+    is_causal = attention_mask is None
+    with cuda.sdp_kernel(
+        enable_flash=True,
+        enable_math=False,
+        enable_mem_efficient=False,
+    ):
+        attn_out = F.scaled_dot_product_attention(
+            query=query.half(),
+            key=key.half(),
+            value=value.half(),
+            is_causal=is_causal,
+            attn_mask=attention_mask,
+            dropout_p=self.attn_dropout.p,
+        ).float()
+    return attn_out, None


def closest_power_of_2(x):
    return 2 ** (x - 1).bit_length()
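# e.g. closest_power_of_2(100) == 128; despite the name this rounds *up* to the
# next power of two (used in train() to grow the gradient-accumulation window)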


+def make_model(pretrained_name, max_tokens):
+    model = GPT2LMHeadModel.from_pretrained(pretrained_name).cuda()
+    GPT2Attention._attn = _attn_wrapper

+    model.config.update(
+        dict(
+            n_ctx=max_tokens,
+            n_positions=max_tokens,
+        )
+    )

+    # patch model embeddings
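    # (GPT-2 ships with 1024 learned position embeddings; tiling them out to
    #  max_tokens gives the longer context a reasonable starting point)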
+    emb = model.transformer.wpe.weight.data
+    wpe = nn.Embedding(max_tokens, emb.shape[1])
+    wpe.weight.data = emb.repeat(max_tokens // emb.shape[0], 1)
+    model.transformer.wpe = wpe

+    # also increase mask size
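    # (HF's GPT2Attention registers "bias" as a causal tril buffer sized to
    #  n_positions at load time, so it must be rebuilt for the longer context)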
+    for block in model.transformer.h:
+        block.attn.bias.data = (
+            torch.tril(torch.ones((max_tokens, max_tokens), dtype=torch.bool))
+            .view(1, 1, max_tokens, max_tokens)
+            .cuda()
+        )

+    return model


class DatasetWrapper(IterableDataset):
    def __init__(self, max_tokens=2**12):
-        self.ds = load_dataset(
-            "the_pile",
-            name="all",
-            split="train",
-            streaming=True,
-        ).shuffle(buffer_size=100_000)
        self.tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
        self.max_tokens = max_tokens

    def __iter__(self):
        buffer = []
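        # pack the stream: tokenize each document, append an EOS separator, and
        # accumulate ids until at least max_tokens are buffered for one block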
-        for sample in self.ds:
+        for sample in load_dataset(
+            "the_pile",
+            name="all",
+            split="train",
+            streaming=True,
+        ).shuffle(buffer_size=10_000):
            buffer += self.tokenizer(sample["text"])["input_ids"]
            buffer += [self.tokenizer.eos_token_id]
            while len(buffer) > self.max_tokens:
@@ -57,61 +107,60 @@ def __iter__(self):

class Trainer:
    def __init__(self):
-        self.tokenizer: GPT2TokenizerFast = GPT2TokenizerFast.from_pretrained("gpt2")
-        self.max_tokens = 2**12
+        self.max_tokens = 2**13
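        # 2**13 = 8192-token context, double the previous 4096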
        self.grad = 1
        self.step = 0

        self.dataset = DatasetWrapper(self.max_tokens)
+        self.tokenizer = self.dataset.tokenizer
        self.loader = DataLoader(
            self.dataset,
            batch_size=1,
            num_workers=8,
        )
-        self.model = model = GPT.from_pretrained("gpt2").cuda()
-        # model.load_state_dict(torch.load("v2.pt"))

-        self.opt = model.configure_optimizers(
+        self.scaler = torch.cuda.amp.GradScaler()
+        self.model = model = make_model("gpt2-medium", self.max_tokens)
+        self.opt = optim.Adam(
+            params=model.parameters(),
+            lr=5e-6,
            weight_decay=1e-1,
-            learning_rate=1e-6,
            betas=(0.9, 0.95),
-            device_type="cuda",
+            fused=True,
        )
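        # the GradScaler pairs with the autocast block in train_step: the forward
        # pass runs in mixed precision and gradients are unscaled before each step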

-        # patch model embeddings
-        emb = model.transformer.wpe.weight.data
-        wpe = nn.Embedding(self.max_tokens, emb.shape[1])
-        wpe.weight.data = emb.repeat(self.max_tokens // emb.shape[0], 1)
-        model.transformer.wpe = wpe
-        model.config.block_size = self.max_tokens
-        print("Patched model embeddings:", wpe.weight.shape)

-        self.model = torch.compile(model)

    def train_step(self, batch):
        batch = batch.cuda()
-        x, y = batch[:, :-1], batch[:, 1:].contiguous()
-        _, loss = self.model(x, targets=y)
-        (loss / self.grad).backward()
+        with torch.autocast(device_type="cuda", enabled=True):
+            loss = self.model(batch, labels=batch).loss
+        loss = loss / self.grad
+        self.scaler.scale(loss).backward()
        return loss

    def generate_samples(self, n_samples=8):
+        GPT2Attention._attn = _attn_orig  # back to the stock attention: faster for generation, but more memory-hungry
+        model = self.model
        x = torch.tensor([[self.tokenizer.eos_token_id]] * n_samples).cuda()
        t0 = time.time()
-        self.model.eval()
-        y = self.model.generate(x, max_new_tokens=1100).tolist()
-        self.model.train()
+        model.eval()
+        y = model.generate(
+            inputs=x,
+            max_length=self.max_tokens,
+            do_sample=True,
+        ).tolist()
+        model.train()
        t1 = time.time()
        t = [self.tokenizer.decode(z) for z in y]
        t = "<hr>".join(f"<p>{c}</p>" for c in t)
        html = WANDB_STYLE + t
        wandb.log({"samples": wandb.Html(html)}, step=self.step)
        print(f"Generated in {t1-t0:.3f}s")
+        GPT2Attention._attn = _attn_wrapper  # re-enable flash_sdp for training

    def train(self):
        wandb.init(
-            project="long-gpt",
-            entity="...",
+            project="long-gptx",
+            entity="_",
        )

        prog = tqdm(self.loader)
@@ -122,18 +171,30 @@ def train(self):

            loss = self.train_step(batch)
            prog.set_description(f"loss: {loss.item():.3f}")
-            wandb.log({"loss": loss.item(), "grad": self.grad}, step=i)
+            wandb.log(
+                {
+                    "loss": loss.item(),
+                    "grad": self.grad,
+                },
+                step=i,
+            )

            if i % self.grad == 0:
-                self.opt.step()
+                self.scaler.step(self.opt)
+                self.scaler.update()
                self.opt.zero_grad()
-                self.grad = max(1, closest_power_of_2(i + 1) // 64)
+                self.grad = max(1, closest_power_of_2(i + 1) // 32)
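                # grows the effective batch size as training progresses,
                # e.g. 2 blocks per update around step 60 and 32 around step 1000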

            if i % 100 == 0:
                torch.save(self.model.state_dict(), "model.pt")
+            # if i % 1000 == 0:
+            #     with suppress(Exception):
+            #         self.model.save_pretrained(
+            #             "_",
+            #             push_to_hub=True,
+            #             max_shard_size="500MB",
+            #         )

            if i % 1000 == 0:
-                self.generate_samples(8)
+                self.generate_samples(16)


if __name__ == "__main__":