import time
from contextlib import suppress

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cuda as cuda
from torch.utils.data import DataLoader, IterableDataset

import wandb
from tqdm import tqdm
from datasets import load_dataset
from transformers import GPT2TokenizerFast
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel, GPT2Attention

# keep a handle on the stock attention so it can be restored for generation
_attn_orig = GPT2Attention._attn

WANDB_STYLE = """
"""


# Patch GPT2Attention to use the flash SDP kernel; the patch is switched off
# again while running inference (see Trainer.generate_samples).
def _attn_wrapper(self, query, key, value, attention_mask=None, head_mask=None):
    if head_mask is not None:
        raise NotImplementedError("head_mask is not implemented for flash_sdp")
    is_causal = attention_mask is None
    with cuda.sdp_kernel(
        enable_flash=True,
        enable_math=False,
        enable_mem_efficient=False,
    ):
        attn_out = F.scaled_dot_product_attention(
            query=query.half(),
            key=key.half(),
            value=value.half(),
            is_causal=is_causal,
            attn_mask=attention_mask,
            dropout_p=self.attn_dropout.p,
        ).float()
    return attn_out, None


def closest_power_of_2(x):
    # smallest power of two that is >= x
    return 2 ** (x - 1).bit_length()


def make_model(pretrained_name, max_tokens):
    model = GPT2LMHeadModel.from_pretrained(pretrained_name).cuda()
    GPT2Attention._attn = _attn_wrapper
    model.config.update(
        dict(
            n_ctx=max_tokens,
            n_positions=max_tokens,
        )
    )
    # patch the model's position embeddings: tile the pretrained table until it
    # covers max_tokens positions
    emb = model.transformer.wpe.weight.data
    wpe = nn.Embedding(max_tokens, emb.shape[1])
    wpe.weight.data = emb.repeat(max_tokens // emb.shape[0], 1)
    model.transformer.wpe = wpe
    # also enlarge the causal mask to the new context length
    for block in model.transformer.h:
        block.attn.bias.data = (
            torch.tril(torch.ones((max_tokens, max_tokens), dtype=torch.bool))
            .view(1, 1, max_tokens, max_tokens)
            .cuda()
        )
    return model


class DatasetWrapper(IterableDataset):
    def __init__(self, max_tokens=2**12):
        self.tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
        self.max_tokens = max_tokens

    def __iter__(self):
        # stream The Pile, pack tokenized documents into a flat buffer, and emit
        # fixed-size chunks of max_tokens token ids
        buffer = []
        for sample in load_dataset(
            "the_pile",
            name="all",
            split="train",
            streaming=True,
        ).shuffle(buffer_size=10_000):
            buffer += self.tokenizer(sample["text"])["input_ids"]
            buffer += [self.tokenizer.eos_token_id]
            while len(buffer) > self.max_tokens:
                yield torch.tensor(buffer[: self.max_tokens])
                buffer = buffer[self.max_tokens :]


class Trainer:
    def __init__(self):
        self.max_tokens = 2**13  # 8192-token training context
        self.grad = 1  # current number of gradient-accumulation steps
        self.step = 0
        self.dataset = DatasetWrapper(self.max_tokens)
        self.tokenizer = self.dataset.tokenizer
        self.loader = DataLoader(
            self.dataset,
            batch_size=1,
            num_workers=8,
        )
        self.scaler = torch.cuda.amp.GradScaler()
        self.model = model = make_model("gpt2-medium", self.max_tokens)
        self.opt = optim.Adam(
            params=model.parameters(),
            lr=5e-6,
            weight_decay=1e-1,
            betas=(0.9, 0.95),
            fused=True,
        )
        self.model = torch.compile(model)

    def train_step(self, batch):
        batch = batch.cuda()
        with torch.autocast(device_type="cuda", enabled=True):
            loss = self.model(batch, labels=batch).loss
            loss = loss / self.grad
        self.scaler.scale(loss).backward()
        return loss

    def generate_samples(self, n_samples=8):
        # switch back to the stock attention: faster for generation but more
        # memory-hungry than the flash-SDP patch
        GPT2Attention._attn = _attn_orig
        model = self.model
        x = torch.tensor([[self.tokenizer.eos_token_id]] * n_samples).cuda()
        t0 = time.time()
        model.eval()
        y = model.generate(
            inputs=x,
            max_length=self.max_tokens,
            do_sample=True,
        ).tolist()
        model.train()
        t1 = time.time()
        t = [self.tokenizer.decode(z) for z in y]
        t = "\n".join(f"\n\n{c}\n\n" for c in t)
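        # WANDB_STYLE (left empty at the top of this file) is where CSS for the sample
        # HTML can go; it is prepended to the blank-line-separated samples and the
        # result is logged as a single W&B HTML panel at the current training step.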
        html = WANDB_STYLE + t
        wandb.log({"samples": wandb.Html(html)}, step=self.step)
        print(f"Generated in {t1 - t0:.3f}s")
        GPT2Attention._attn = _attn_wrapper

    def train(self):
        wandb.init(
            project="long-gptx",
            entity="_",
        )
        prog = tqdm(self.loader)
        self.opt.zero_grad()
        for i, batch in enumerate(prog):
            self.step = i + 1
            loss = self.train_step(batch)
            prog.set_description(f"loss: {loss.item():.3f}")
            wandb.log(
                {
                    "loss": loss.item(),
                    "grad": self.grad,
                },
                step=i,
            )
            if i % self.grad == 0:
                self.scaler.step(self.opt)
                self.scaler.update()
                self.opt.zero_grad()
            # grow gradient accumulation as training progresses
            self.grad = max(1, closest_power_of_2(i + 1) // 32)
            # if i % 1000 == 0:
            #     with suppress(Exception):
            #         self.model.save_pretrained(
            #             "_",
            #             push_to_hub=True,
            #             max_shard_size="500MB",
            #         )
            if i % 1000 == 0:
                self.generate_samples(16)


if __name__ == "__main__":
    trainer = Trainer()
    trainer.train()