@NaxAlpha
Created November 23, 2022 01:18
c4x.py
# Stream the C4 dataset from Hugging Face with the GPT-2 tokenizer for PyTorch language-model training.
import random

import torch
from datasets import load_dataset
from transformers import GPT2Tokenizer
from torch.utils.data import Dataset, get_worker_info


def cycled(itr):
    # Loop over the iterable forever, restarting it each time it is exhausted.
    while True:
        for itm in itr:
            yield itm


class C4X(Dataset):

    def __init__(self, seq_len=512, split='train'):
        self.seq = seq_len
        # Stream C4 (English) so the full dataset never has to be downloaded.
        self.ds = load_dataset(
            'c4',
            name='en',
            split=split,
            streaming=True,
        )
        self.tok = GPT2Tokenizer.from_pretrained('gpt2')
        self.init = False

    def __len__(self):
        # The stream is effectively endless; report an arbitrarily large size.
        return 1_000_000_000

    def _init(self):
        # Lazily shuffle the stream once per DataLoader worker so each worker
        # sees a differently ordered stream.
        if self.init:
            return
        wi = get_worker_info()
        seed = wi.seed if wi is not None else None  # None in the main process (num_workers=0)
        self.ds = cycled(
            self.ds.shuffle(
                seed=seed,
                buffer_size=10_000,
            )
        )
        self.init = True

    def _get_next(self):
        # Fetch and tokenize the next document from the stream.
        self._init()
        obj = next(self.ds)['text']
        tkn = self.tok.encode(obj)
        return tkn

    def _get_full(self):
        # Concatenate documents (separated by EOS tokens) until at least
        # seq_len tokens are collected, then return a random seq_len window.
        obj = []
        while len(obj) < self.seq:
            obj += self._get_next()
            obj.append(self.tok.eos_token_id)
        s = random.randint(0, len(obj) - self.seq)
        return obj[s:s + self.seq]

    def __getitem__(self, _):
        return torch.tensor(self._get_full())

    def decode(self, tkns):
        return self.tok.decode(tkns)
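
A minimal usage sketch (not part of the original gist): C4X plugs into a standard PyTorch DataLoader; the batch size and worker count below are arbitrary choices for illustration.

# Usage sketch (illustrative assumption, not from the original gist).
from torch.utils.data import DataLoader

ds = C4X(seq_len=512, split='train')
dl = DataLoader(ds, batch_size=8, num_workers=2)

batch = next(iter(dl))      # LongTensor of shape (8, 512)
print(batch.shape)
print(ds.decode(batch[0]))  # round-trip the first sequence back to text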