@Laeeth
Forked from NaxAlpha/long_gpt.py
Created April 13, 2023 04:17

Revisions

  1. @NaxAlpha revised this gist Apr 6, 2023. 1 changed file with 102 additions and 41 deletions.
    143 changes: 102 additions & 41 deletions long_gpt.py
    @@ -1,19 +1,21 @@
    import time
    import random
    from contextlib import suppress

    import torch
    import torch.nn as nn
    import torch.optim as optim
    import torch.nn.functional as F
    import torch.backends.cuda as cuda
    from torch.utils.data import DataLoader, IterableDataset

    import wandb
    from tqdm import tqdm

    from datasets import load_dataset
    from transformers import GPT2TokenizerFast
    from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel, GPT2Attention

- # download from here: https://github.com/karpathy/nanoGPT/blob/master/model.py
- from ngpt import GPT
+ _attn_orig = GPT2Attention._attn

    WANDB_STYLE = """
    <style>
    @@ -30,24 +32,72 @@
    """


# patch GPT2Attention to use flash_sdp; it is switched back off during inference (see generate_samples)
    def _attn_wrapper(self, query, key, value, attention_mask=None, head_mask=None):
    if head_mask is not None:
    raise NotImplementedError("head_mask is not implemented for flash_sdp")
    is_causal = attention_mask is None
    with cuda.sdp_kernel(
    enable_flash=True,
    enable_math=False,
    enable_mem_efficient=False,
    ):
    attn_out = F.scaled_dot_product_attention(
    query=query.half(),
    key=key.half(),
    value=value.half(),
    is_causal=is_causal,
    attn_mask=attention_mask,
    dropout_p=self.attn_dropout.p,
    ).float()
    return attn_out, None
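# note: the wrapper above routes GPT-2 attention through torch's fused scaled_dot_product_attention
# with only the flash kernel enabled; q/k/v are cast to fp16 (the flash kernel requires half
# precision) and the result is cast back to fp32. It returns None in place of the attention
# weights, so attention probabilities are not available while the patch is active.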


    def closest_power_of_2(x):
    return 2 ** (x - 1).bit_length()
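# e.g. closest_power_of_2(5) == 8 and closest_power_of_2(64) == 64; the training loop below uses
# this to grow the gradient-accumulation window as the step count increases.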


    def make_model(pretrained_name, max_tokens):
    model = GPT2LMHeadModel.from_pretrained(pretrained_name).cuda()
    GPT2Attention._attn = _attn_wrapper

    model.config.update(
    dict(
    n_ctx=max_tokens,
    n_positions=max_tokens,
    )
    )

    # patch model embeddings
    emb = model.transformer.wpe.weight.data
    wpe = nn.Embedding(max_tokens, emb.shape[1])
    wpe.weight.data = emb.repeat(max_tokens // emb.shape[0], 1)
    model.transformer.wpe = wpe

    # also increase mask size
    for block in model.transformer.h:
    block.attn.bias.data = (
    torch.tril(torch.ones((max_tokens, max_tokens), dtype=torch.bool))
    .view(1, 1, max_tokens, max_tokens)
    .cuda()
    )

    return model
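# note: make_model stretches the pretrained 1024-entry position-embedding table by tiling it
# (emb.repeat(max_tokens // 1024, 1)), which assumes max_tokens is a multiple of 1024, and it
# rebuilds each block's causal mask at the new size. The Trainer below calls
# make_model("gpt2-medium", 8192), yielding a GPT-2 that accepts 8192-token inputs with the
# flash-attention patch installed.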


    class DatasetWrapper(IterableDataset):
    def __init__(self, max_tokens=2**12):
    self.ds = load_dataset(
    "the_pile",
    name="all",
    split="train",
    streaming=True,
    ).shuffle(buffer_size=100_000)
    self.tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
    self.max_tokens = max_tokens

    def __iter__(self):
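# note: stream The Pile, tokenize each document, append an EOS token, and slice the running
# buffer into fixed windows of max_tokens ids, so every yielded sample is a full-length
# sequence with no padding.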
    buffer = []
- for sample in self.ds:
+ for sample in load_dataset(
+ "the_pile",
+ name="all",
+ split="train",
+ streaming=True,
+ ).shuffle(buffer_size=10_000):
    buffer += self.tokenizer(sample["text"])["input_ids"]
    buffer += [self.tokenizer.eos_token_id]
    while len(buffer) > self.max_tokens:
    @@ -57,61 +107,60 @@ def __iter__(self):

    class Trainer:
    def __init__(self):
- self.tokenizer: GPT2TokenizerFast = GPT2TokenizerFast.from_pretrained("gpt2")
- self.max_tokens = 2**12
+ self.max_tokens = 2**13
    self.grad = 1
    self.step = 0

    self.dataset = DatasetWrapper(self.max_tokens)
+ self.tokenizer = self.dataset.tokenizer
    self.loader = DataLoader(
    self.dataset,
    batch_size=1,
    num_workers=8,
    )
- self.model = model = GPT.from_pretrained("gpt2").cuda()
- # model.load_state_dict(torch.load("v2.pt"))

- self.opt = model.configure_optimizers(
+ self.scaler = torch.cuda.amp.GradScaler()
+ self.model = model = make_model("gpt2-medium", self.max_tokens)
+ self.opt = optim.Adam(
+ params=model.parameters(),
+ lr=5e-6,
weight_decay=1e-1,
- learning_rate=1e-6,
betas=(0.9, 0.95),
- device_type="cuda",
+ fused=True,
)
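# note: the GradScaler above pairs with the autocast forward pass in train_step; optimizer steps
# go through scaler.step/scaler.update further down instead of a bare opt.step().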

- # patch model embeddings
- emb = model.transformer.wpe.weight.data
- wpe = nn.Embedding(self.max_tokens, emb.shape[1])
- wpe.weight.data = emb.repeat(self.max_tokens // emb.shape[0], 1)
- model.transformer.wpe = wpe
- model.config.block_size = self.max_tokens
- print("Patched model embeddings:", wpe.weight.shape)

    self.model = torch.compile(model)

    def train_step(self, batch):
    batch = batch.cuda()
- x, y = batch[:, :-1], batch[:, 1:].contiguous()
- _, loss = self.model(x, targets=y)
- (loss / self.grad).backward()
+ with torch.autocast(device_type="cuda", enabled=True):
+ loss = self.model(batch, labels=batch).loss
+ loss = loss / self.grad
+ self.scaler.scale(loss).backward()
    return loss
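# note: passing labels=batch lets GPT2LMHeadModel shift the labels internally for next-token
# prediction, and dividing by self.grad averages the loss over the current gradient-accumulation
# window before the scaled backward pass.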

    def generate_samples(self, n_samples=8):
    GPT2Attention._attn = _attn_orig # back to faster but more memory consuming
    model = self.model
    x = torch.tensor([[self.tokenizer.eos_token_id]] * n_samples).cuda()
    t0 = time.time()
- self.model.eval()
- y = self.model.generate(x, max_new_tokens=1100).tolist()
- self.model.train()
+ model.eval()
+ y = model.generate(
+ inputs=x,
+ max_length=self.max_tokens,
+ do_sample=True,
+ ).tolist()
+ model.train()
    t1 = time.time()
    t = [self.tokenizer.decode(z) for z in y]
    t = "<hr>".join(f"<p>{c}</p>" for c in t)
    html = WANDB_STYLE + t
    wandb.log({"samples": wandb.Html(html)}, step=self.step)
    print(f"Generated in {t1-t0:.3f}s")
    GPT2Attention._attn = _attn_wrapper
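# note: sampling swaps the stock _attn back in (per the comment above, faster but more
# memory-hungry) and re-installs the flash wrapper before training resumes.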

    def train(self):
    wandb.init(
- project="long-gpt",
- entity="...",
+ project="long-gptx",
+ entity="_",
    )

    prog = tqdm(self.loader)
    @@ -122,18 +171,30 @@ def train(self):

    loss = self.train_step(batch)
    prog.set_description(f"loss: {loss.item():.3f}")
- wandb.log({"loss": loss.item(), "grad": self.grad}, step=i)
+ wandb.log(
+ {
+ "loss": loss.item(),
+ "grad": self.grad,
+ },
+ step=i,
+ )

    if i % self.grad == 0:
- self.opt.step()
+ self.scaler.step(self.opt)
+ self.scaler.update()
    self.opt.zero_grad()
- self.grad = max(1, closest_power_of_2(i + 1) // 64)
+ self.grad = max(1, closest_power_of_2(i + 1) // 32)
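# note: with this schedule an optimizer step runs every self.grad micro-batches, and the
# accumulation window grows roughly as next_power_of_2(step) / 32, so the effective batch size
# keeps increasing over the run.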

    if i % 100 == 0:
    torch.save(self.model.state_dict(), "model.pt")
    # if i % 1000 == 0:
    # with suppress(Exception):
    # self.model.save_pretrained(
    # "_",
    # push_to_hub=True,
    # max_shard_size="500MB",
    # )

    if i % 1000 == 0:
- self.generate_samples(8)
+ self.generate_samples(16)


    if __name__ == "__main__":
  2. @NaxAlpha revised this gist Mar 17, 2023. 1 changed file with 1 addition and 1 deletion.
    2 changes: 1 addition & 1 deletion long_gpt.py
    @@ -111,7 +111,7 @@ def generate_samples(self, n_samples=8):
    def train(self):
    wandb.init(
    project="long-gpt",
- entity="nax-autify",
+ entity="...",
    )

    prog = tqdm(self.loader)
  3. @NaxAlpha revised this gist Mar 17, 2023. 1 changed file with 3 additions and 3 deletions.
    6 changes: 3 additions & 3 deletions long_gpt.py
    @@ -12,7 +12,7 @@
    from datasets import load_dataset
    from transformers import GPT2TokenizerFast

- # copy from here: https://github.com/karpathy/nanoGPT/blob/master/model.py
+ # download from here: https://github.com/karpathy/nanoGPT/blob/master/model.py
    from ngpt import GPT

    WANDB_STYLE = """
    @@ -111,7 +111,7 @@ def generate_samples(self, n_samples=8):
    def train(self):
    wandb.init(
    project="long-gpt",
- entity="...",
+ entity="nax-autify",
    )

    prog = tqdm(self.loader)
    @@ -127,7 +127,7 @@ def train(self):
    if i % self.grad == 0:
    self.opt.step()
    self.opt.zero_grad()
- self.grad = closest_power_of_2(i + 1)
+ self.grad = max(1, closest_power_of_2(i + 1) // 64)

    if i % 100 == 0:
    torch.save(self.model.state_dict(), "model.pt")
  4. @NaxAlpha created this gist Mar 17, 2023.
    141 changes: 141 additions & 0 deletions long_gpt.py
    @@ -0,0 +1,141 @@
    import time
    import random

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.utils.data import DataLoader, IterableDataset

    import wandb
    from tqdm import tqdm

    from datasets import load_dataset
    from transformers import GPT2TokenizerFast

    # copy from here: https://github.com/karpathy/nanoGPT/blob/master/model.py
    from ngpt import GPT

    WANDB_STYLE = """
    <style>
    html, body {
    padding: 0;
    margin: 0;
    width: 100%;
    height: 100%;
    }
    p {
    font-family: 'Verdana', sans-serif;
    }
    </style>
    """


    def closest_power_of_2(x):
    return 2 ** (x - 1).bit_length()


    class DatasetWrapper(IterableDataset):
    def __init__(self, max_tokens=2**12):
    self.ds = load_dataset(
    "the_pile",
    name="all",
    split="train",
    streaming=True,
    ).shuffle(buffer_size=100_000)
    self.tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
    self.max_tokens = max_tokens

    def __iter__(self):
    buffer = []
    for sample in self.ds:
    buffer += self.tokenizer(sample["text"])["input_ids"]
    buffer += [self.tokenizer.eos_token_id]
    while len(buffer) > self.max_tokens:
    yield torch.tensor(buffer[: self.max_tokens])
    buffer = buffer[self.max_tokens :]


    class Trainer:
    def __init__(self):
    self.tokenizer: GPT2TokenizerFast = GPT2TokenizerFast.from_pretrained("gpt2")
    self.max_tokens = 2**12
    self.grad = 1
    self.step = 0

    self.dataset = DatasetWrapper(self.max_tokens)
    self.loader = DataLoader(
    self.dataset,
    batch_size=1,
    num_workers=8,
    )
    self.model = model = GPT.from_pretrained("gpt2").cuda()
    # model.load_state_dict(torch.load("v2.pt"))

    self.opt = model.configure_optimizers(
    weight_decay=1e-1,
    learning_rate=1e-6,
    betas=(0.9, 0.95),
    device_type="cuda",
    )

    # patch model embeddings
    emb = model.transformer.wpe.weight.data
    wpe = nn.Embedding(self.max_tokens, emb.shape[1])
    wpe.weight.data = emb.repeat(self.max_tokens // emb.shape[0], 1)
    model.transformer.wpe = wpe
    model.config.block_size = self.max_tokens
    print("Patched model embeddings:", wpe.weight.shape)

    self.model = torch.compile(model)
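# note: in this first version the 1024-entry position table is tiled 4x to reach 2**12 = 4096
# positions and nanoGPT's block_size is raised to match; torch.compile is applied after the
# patch, so the enlarged embedding is what gets compiled.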

    def train_step(self, batch):
    batch = batch.cuda()
    x, y = batch[:, :-1], batch[:, 1:].contiguous()
    _, loss = self.model(x, targets=y)
    (loss / self.grad).backward()
    return loss
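# note: the nanoGPT model returns (logits, loss) when targets are given; x and y are the batch
# and the same batch shifted left by one token.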

    def generate_samples(self, n_samples=8):
    x = torch.tensor([[self.tokenizer.eos_token_id]] * n_samples).cuda()
    t0 = time.time()
    self.model.eval()
    y = self.model.generate(x, max_new_tokens=1100).tolist()
    self.model.train()
    t1 = time.time()
    t = [self.tokenizer.decode(z) for z in y]
    t = "<hr>".join(f"<p>{c}</p>" for c in t)
    html = WANDB_STYLE + t
    wandb.log({"samples": wandb.Html(html)}, step=self.step)
    print(f"Generated in {t1-t0:.3f}s")

    def train(self):
    wandb.init(
    project="long-gpt",
    entity="...",
    )

    prog = tqdm(self.loader)
    self.opt.zero_grad()

    for i, batch in enumerate(prog):
    self.step = i + 1

    loss = self.train_step(batch)
    prog.set_description(f"loss: {loss.item():.3f}")
    wandb.log({"loss": loss.item(), "grad": self.grad}, step=i)

    if i % self.grad == 0:
    self.opt.step()
    self.opt.zero_grad()
    self.grad = closest_power_of_2(i + 1)
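# note: here the accumulation window doubles with the step count itself (1, 2, 4, 8, ...);
# the later revisions cap it with // 64 and then // 32.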

    if i % 100 == 0:
    torch.save(self.model.state_dict(), "model.pt")

    if i % 1000 == 0:
    self.generate_samples(8)


    if __name__ == "__main__":
    trainer = Trainer()
    trainer.train()