Skip to content

Instantly share code, notes, and snippets.

@NaxAlpha
Last active September 2, 2024 01:39
Show Gist options
  • Select an option

  • Save NaxAlpha/c6a7c65f40c05af0907b25fd742a8df0 to your computer and use it in GitHub Desktop.

Select an option

Save NaxAlpha/c6a7c65f40c05af0907b25fd742a8df0 to your computer and use it in GitHub Desktop.

Revisions

  1. NaxAlpha revised this gist Jan 12, 2023. 1 changed file with 2 additions and 1 deletion.
    3 changes: 2 additions & 1 deletion hf_pile.py
    Original file line number Diff line number Diff line change
    @@ -1,3 +1,4 @@
    import torch
    from torch.utils.data import IterableDataset

    from transformers import PreTrainedTokenizerBase
    @@ -26,7 +27,7 @@ def __iter__(self):
    tokens = self.tokenizer.encode(next(ds)["text"])
    buffer += [self.tokenizer.eos_token_id] + tokens
    while len(buffer) > self.max_length:
    yield buffer[: self.max_length]
    yield torch.tensor(buffer[: self.max_length])
    buffer = buffer[self.max_length // self.repeat_factor :]


  2. NaxAlpha revised this gist Jan 12, 2023. 1 changed file with 5 additions and 1 deletion.
    6 changes: 5 additions & 1 deletion hf_pile.py
    Original file line number Diff line number Diff line change
    @@ -11,10 +11,13 @@ def __init__(
    base_dataset: ThePile,
    tokenizer: PreTrainedTokenizerBase,
    max_length: int = 1024,
    repeat_factor: int = 1,
    ):
    assert repeat_factor >= 1 # but can be a float
    self.pile = base_dataset
    self.tokenizer = tokenizer
    self.max_length = max_length
    self.repeat_factor = repeat_factor

    def __iter__(self):
    ds = iter(self.pile)
    @@ -24,7 +27,7 @@ def __iter__(self):
    buffer += [self.tokenizer.eos_token_id] + tokens
    while len(buffer) > self.max_length:
    yield buffer[: self.max_length]
    buffer = buffer[self.max_length :]
    buffer = buffer[self.max_length // self.repeat_factor :]


    if __name__ == "__main__":
    @@ -36,6 +39,7 @@ def __iter__(self):
    ThePile("train"),
    GPT2Tokenizer.from_pretrained("gpt2"),
    max_length=1024,
    repeat_factor=2,
    )
    dataloader = DataLoader(
    dataset,
  3. NaxAlpha revised this gist Jan 12, 2023. 1 changed file with 46 additions and 0 deletions.
    46 changes: 46 additions & 0 deletions hf_pile.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,46 @@
    from torch.utils.data import IterableDataset

    from transformers import PreTrainedTokenizerBase

    from pile import ThePile


    class ThePileTokenized(IterableDataset):
    def __init__(
    self,
    base_dataset: ThePile,
    tokenizer: PreTrainedTokenizerBase,
    max_length: int = 1024,
    ):
    self.pile = base_dataset
    self.tokenizer = tokenizer
    self.max_length = max_length

    def __iter__(self):
    ds = iter(self.pile)
    buffer = []
    while True:
    tokens = self.tokenizer.encode(next(ds)["text"])
    buffer += [self.tokenizer.eos_token_id] + tokens
    while len(buffer) > self.max_length:
    yield buffer[: self.max_length]
    buffer = buffer[self.max_length :]


    if __name__ == "__main__":
    from tqdm import tqdm
    from torch.utils.data import DataLoader
    from transformers import GPT2Tokenizer

    dataset = ThePileTokenized(
    ThePile("train"),
    GPT2Tokenizer.from_pretrained("gpt2"),
    max_length=1024,
    )
    dataloader = DataLoader(
    dataset,
    batch_size=64,
    )
    for batch in tqdm(dataloader, smoothing=0.01):
    pass
    # ~6 iters/s for 1 worker
  4. NaxAlpha revised this gist Jan 12, 2023. 1 changed file with 48 additions and 74 deletions.
    122 changes: 48 additions & 74 deletions pile.py
    Original file line number Diff line number Diff line change
    @@ -2,53 +2,51 @@
    import time
    import random
    from typing import Literal
    from threading import Lock
    from concurrent.futures import ThreadPoolExecutor

    import requests
    import zstandard as zstd
    from torch.utils.data import IterableDataset
    from torch.utils.data import IterableDataset, get_worker_info


    Subset = Literal["train", "val", "test"]
    URLs = {
    "val": [
    "https://mystic.the-eye.eu/public/AI/pile/val.jsonl.zst",
    "https://the-eye.eu/public/AI/pile/val.jsonl.zst",
    ],
    "test": [
    "https://mystic.the-eye.eu/public/AI/pile/test.jsonl.zst",
    "https://the-eye.eu/public/AI/pile/test.jsonl.zst",
    ],
    "train": [
    "https://mystic.the-eye.eu/public/AI/pile/train/00.jsonl.zst",
    "https://mystic.the-eye.eu/public/AI/pile/train/01.jsonl.zst",
    "https://mystic.the-eye.eu/public/AI/pile/train/02.jsonl.zst",
    "https://mystic.the-eye.eu/public/AI/pile/train/03.jsonl.zst",
    "https://mystic.the-eye.eu/public/AI/pile/train/04.jsonl.zst",
    "https://mystic.the-eye.eu/public/AI/pile/train/05.jsonl.zst",
    "https://mystic.the-eye.eu/public/AI/pile/train/06.jsonl.zst",
    "https://mystic.the-eye.eu/public/AI/pile/train/07.jsonl.zst",
    "https://mystic.the-eye.eu/public/AI/pile/train/08.jsonl.zst",
    "https://mystic.the-eye.eu/public/AI/pile/train/09.jsonl.zst",
    "https://mystic.the-eye.eu/public/AI/pile/train/10.jsonl.zst",
    "https://mystic.the-eye.eu/public/AI/pile/train/11.jsonl.zst",
    "https://mystic.the-eye.eu/public/AI/pile/train/12.jsonl.zst",
    "https://mystic.the-eye.eu/public/AI/pile/train/13.jsonl.zst",
    "https://mystic.the-eye.eu/public/AI/pile/train/14.jsonl.zst",
    "https://mystic.the-eye.eu/public/AI/pile/train/15.jsonl.zst",
    "https://mystic.the-eye.eu/public/AI/pile/train/16.jsonl.zst",
    "https://mystic.the-eye.eu/public/AI/pile/train/17.jsonl.zst",
    "https://mystic.the-eye.eu/public/AI/pile/train/18.jsonl.zst",
    "https://mystic.the-eye.eu/public/AI/pile/train/19.jsonl.zst",
    "https://mystic.the-eye.eu/public/AI/pile/train/20.jsonl.zst",
    "https://mystic.the-eye.eu/public/AI/pile/train/21.jsonl.zst",
    "https://mystic.the-eye.eu/public/AI/pile/train/22.jsonl.zst",
    "https://mystic.the-eye.eu/public/AI/pile/train/23.jsonl.zst",
    "https://mystic.the-eye.eu/public/AI/pile/train/24.jsonl.zst",
    "https://mystic.the-eye.eu/public/AI/pile/train/25.jsonl.zst",
    "https://mystic.the-eye.eu/public/AI/pile/train/26.jsonl.zst",
    "https://mystic.the-eye.eu/public/AI/pile/train/27.jsonl.zst",
    "https://mystic.the-eye.eu/public/AI/pile/train/28.jsonl.zst",
    "https://mystic.the-eye.eu/public/AI/pile/train/29.jsonl.zst",
    "https://the-eye.eu/public/AI/pile/train/00.jsonl.zst",
    "https://the-eye.eu/public/AI/pile/train/01.jsonl.zst",
    "https://the-eye.eu/public/AI/pile/train/02.jsonl.zst",
    "https://the-eye.eu/public/AI/pile/train/03.jsonl.zst",
    "https://the-eye.eu/public/AI/pile/train/04.jsonl.zst",
    "https://the-eye.eu/public/AI/pile/train/05.jsonl.zst",
    "https://the-eye.eu/public/AI/pile/train/06.jsonl.zst",
    "https://the-eye.eu/public/AI/pile/train/07.jsonl.zst",
    "https://the-eye.eu/public/AI/pile/train/08.jsonl.zst",
    "https://the-eye.eu/public/AI/pile/train/09.jsonl.zst",
    "https://the-eye.eu/public/AI/pile/train/10.jsonl.zst",
    "https://the-eye.eu/public/AI/pile/train/11.jsonl.zst",
    "https://the-eye.eu/public/AI/pile/train/12.jsonl.zst",
    "https://the-eye.eu/public/AI/pile/train/13.jsonl.zst",
    "https://the-eye.eu/public/AI/pile/train/14.jsonl.zst",
    "https://the-eye.eu/public/AI/pile/train/15.jsonl.zst",
    "https://the-eye.eu/public/AI/pile/train/16.jsonl.zst",
    "https://the-eye.eu/public/AI/pile/train/17.jsonl.zst",
    "https://the-eye.eu/public/AI/pile/train/18.jsonl.zst",
    "https://the-eye.eu/public/AI/pile/train/19.jsonl.zst",
    "https://the-eye.eu/public/AI/pile/train/20.jsonl.zst",
    "https://the-eye.eu/public/AI/pile/train/21.jsonl.zst",
    "https://the-eye.eu/public/AI/pile/train/22.jsonl.zst",
    "https://the-eye.eu/public/AI/pile/train/23.jsonl.zst",
    "https://the-eye.eu/public/AI/pile/train/24.jsonl.zst",
    "https://the-eye.eu/public/AI/pile/train/25.jsonl.zst",
    "https://the-eye.eu/public/AI/pile/train/26.jsonl.zst",
    "https://the-eye.eu/public/AI/pile/train/27.jsonl.zst",
    "https://the-eye.eu/public/AI/pile/train/28.jsonl.zst",
    "https://the-eye.eu/public/AI/pile/train/29.jsonl.zst",
    ],
    }

    @@ -82,52 +80,28 @@ def _line_streamer(reader, buffer_size=4096):
    class ThePile(IterableDataset):
    TEXT_BUFFER_SIZE = 4096

    def __init__(
    self,
    subset: Subset,
    min_buffer: int = 1_000,
    max_buffer: int = 10_000,
    num_threads: int = 1,
    random_seed: int = 42,
    ):
    def __init__(self, subset: Subset):
    self.subset = subset
    self.min_buffer = min_buffer
    self.max_buffer = max_buffer
    self.num_threads = num_threads
    self.random_seed = random_seed

    self.lock = Lock()
    self.buffer = []

    def downloader(self):
    for url in URLs[self.subset]:
    r = requests.get(url, stream=True)
    with zstd.ZstdDecompressor().stream_reader(r.raw) as reader:
    for line in _line_streamer(reader):
    data = json.loads(line)
    with self.lock:
    self.buffer.append(data)
    while len(self.buffer) > self.max_buffer:
    time.sleep(0.01)

    def __iter__(self):
    rnd = random.Random(self.random_seed)
    with ThreadPoolExecutor(max_workers=self.num_threads) as executor:
    for _ in range(self.num_threads):
    executor.submit(self.downloader)
    while True:
    while len(self.buffer) < self.min_buffer:
    time.sleep(0.01)
    with self.lock:
    loc = rnd.randint(0, len(self.buffer) - 1)
    out = self.buffer.pop(loc)
    yield out
    urls = URLs[self.subset].copy()
    while True:
    wi = get_worker_info()
    seed = wi.id if wi is not None else None
    rnd = random.Random(seed)
    rnd.shuffle(urls)
    for url in urls:
    r = requests.get(url, stream=True)
    with zstd.ZstdDecompressor().stream_reader(r.raw) as reader:
    for line in _line_streamer(reader, self.TEXT_BUFFER_SIZE):
    data = json.loads(line)
    yield data


    if __name__ == "__main__":
    from tqdm import tqdm

    dataset = ThePile("train", num_threads=4)
    dataset = ThePile("train")
    for data in tqdm(dataset, smoothing=0.01):
    pass
    # Average: 1500-2000 samples/sec/worker
    # Average: ~2000 samples/sec/worker
  5. NaxAlpha created this gist Dec 23, 2022.
    133 changes: 133 additions & 0 deletions pile.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,133 @@
    import json
    import time
    import random
    from typing import Literal
    from threading import Lock
    from concurrent.futures import ThreadPoolExecutor

    import requests
    import zstandard as zstd
    from torch.utils.data import IterableDataset


    Subset = Literal["train", "val", "test"]
    URLs = {
    "val": [
    "https://mystic.the-eye.eu/public/AI/pile/val.jsonl.zst",
    ],
    "test": [
    "https://mystic.the-eye.eu/public/AI/pile/test.jsonl.zst",
    ],
    "train": [
    "https://mystic.the-eye.eu/public/AI/pile/train/00.jsonl.zst",
    "https://mystic.the-eye.eu/public/AI/pile/train/01.jsonl.zst",
    "https://mystic.the-eye.eu/public/AI/pile/train/02.jsonl.zst",
    "https://mystic.the-eye.eu/public/AI/pile/train/03.jsonl.zst",
    "https://mystic.the-eye.eu/public/AI/pile/train/04.jsonl.zst",
    "https://mystic.the-eye.eu/public/AI/pile/train/05.jsonl.zst",
    "https://mystic.the-eye.eu/public/AI/pile/train/06.jsonl.zst",
    "https://mystic.the-eye.eu/public/AI/pile/train/07.jsonl.zst",
    "https://mystic.the-eye.eu/public/AI/pile/train/08.jsonl.zst",
    "https://mystic.the-eye.eu/public/AI/pile/train/09.jsonl.zst",
    "https://mystic.the-eye.eu/public/AI/pile/train/10.jsonl.zst",
    "https://mystic.the-eye.eu/public/AI/pile/train/11.jsonl.zst",
    "https://mystic.the-eye.eu/public/AI/pile/train/12.jsonl.zst",
    "https://mystic.the-eye.eu/public/AI/pile/train/13.jsonl.zst",
    "https://mystic.the-eye.eu/public/AI/pile/train/14.jsonl.zst",
    "https://mystic.the-eye.eu/public/AI/pile/train/15.jsonl.zst",
    "https://mystic.the-eye.eu/public/AI/pile/train/16.jsonl.zst",
    "https://mystic.the-eye.eu/public/AI/pile/train/17.jsonl.zst",
    "https://mystic.the-eye.eu/public/AI/pile/train/18.jsonl.zst",
    "https://mystic.the-eye.eu/public/AI/pile/train/19.jsonl.zst",
    "https://mystic.the-eye.eu/public/AI/pile/train/20.jsonl.zst",
    "https://mystic.the-eye.eu/public/AI/pile/train/21.jsonl.zst",
    "https://mystic.the-eye.eu/public/AI/pile/train/22.jsonl.zst",
    "https://mystic.the-eye.eu/public/AI/pile/train/23.jsonl.zst",
    "https://mystic.the-eye.eu/public/AI/pile/train/24.jsonl.zst",
    "https://mystic.the-eye.eu/public/AI/pile/train/25.jsonl.zst",
    "https://mystic.the-eye.eu/public/AI/pile/train/26.jsonl.zst",
    "https://mystic.the-eye.eu/public/AI/pile/train/27.jsonl.zst",
    "https://mystic.the-eye.eu/public/AI/pile/train/28.jsonl.zst",
    "https://mystic.the-eye.eu/public/AI/pile/train/29.jsonl.zst",
    ],
    }


    def _read_line_from_stream(reader, initial_line="", buffer_size=4096):
    line = initial_line
    while True:
    c = reader.read(buffer_size)
    if not c:
    raise StopIteration
    line += c.decode("utf-8")
    if "\n" in line:
    break
    return line.split("\n", 1)


    def _line_streamer(reader, buffer_size=4096):
    rest = ""
    while True:
    try:
    line, rest = _read_line_from_stream(
    reader,
    rest,
    buffer_size,
    )
    yield line
    except StopIteration:
    break


    class ThePile(IterableDataset):
    TEXT_BUFFER_SIZE = 4096

    def __init__(
    self,
    subset: Subset,
    min_buffer: int = 1_000,
    max_buffer: int = 10_000,
    num_threads: int = 1,
    random_seed: int = 42,
    ):
    self.subset = subset
    self.min_buffer = min_buffer
    self.max_buffer = max_buffer
    self.num_threads = num_threads
    self.random_seed = random_seed

    self.lock = Lock()
    self.buffer = []

    def downloader(self):
    for url in URLs[self.subset]:
    r = requests.get(url, stream=True)
    with zstd.ZstdDecompressor().stream_reader(r.raw) as reader:
    for line in _line_streamer(reader):
    data = json.loads(line)
    with self.lock:
    self.buffer.append(data)
    while len(self.buffer) > self.max_buffer:
    time.sleep(0.01)

    def __iter__(self):
    rnd = random.Random(self.random_seed)
    with ThreadPoolExecutor(max_workers=self.num_threads) as executor:
    for _ in range(self.num_threads):
    executor.submit(self.downloader)
    while True:
    while len(self.buffer) < self.min_buffer:
    time.sleep(0.01)
    with self.lock:
    loc = rnd.randint(0, len(self.buffer) - 1)
    out = self.buffer.pop(loc)
    yield out


    if __name__ == "__main__":
    from tqdm import tqdm

    dataset = ThePile("train", num_threads=4)
    for data in tqdm(dataset, smoothing=0.01):
    pass
    # Average: 1500-2000 samples/sec/worker