Last active
September 2, 2024 01:39
-
-
Save NaxAlpha/c6a7c65f40c05af0907b25fd742a8df0 to your computer and use it in GitHub Desktop.
Revisions
-
NaxAlpha revised this gist
Jan 12, 2023 . 1 changed file with 2 additions and 1 deletion. There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -1,3 +1,4 @@ import torch from torch.utils.data import IterableDataset from transformers import PreTrainedTokenizerBase @@ -26,7 +27,7 @@ def __iter__(self): tokens = self.tokenizer.encode(next(ds)["text"]) buffer += [self.tokenizer.eos_token_id] + tokens while len(buffer) > self.max_length: yield torch.tensor(buffer[: self.max_length]) buffer = buffer[self.max_length // self.repeat_factor :] -
NaxAlpha revised this gist
Jan 12, 2023 . 1 changed file with 5 additions and 1 deletion. There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -11,10 +11,13 @@ def __init__( base_dataset: ThePile, tokenizer: PreTrainedTokenizerBase, max_length: int = 1024, repeat_factor: int = 1, ): assert repeat_factor >= 1 # but can be a float self.pile = base_dataset self.tokenizer = tokenizer self.max_length = max_length self.repeat_factor = repeat_factor def __iter__(self): ds = iter(self.pile) @@ -24,7 +27,7 @@ def __iter__(self): buffer += [self.tokenizer.eos_token_id] + tokens while len(buffer) > self.max_length: yield buffer[: self.max_length] buffer = buffer[self.max_length // self.repeat_factor :] if __name__ == "__main__": @@ -36,6 +39,7 @@ def __iter__(self): ThePile("train"), GPT2Tokenizer.from_pretrained("gpt2"), max_length=1024, repeat_factor=2, ) dataloader = DataLoader( dataset, -
NaxAlpha revised this gist
Jan 12, 2023 . 1 changed file with 46 additions and 0 deletions. There are no files selected for viewing
"""Pack tokenized documents from ThePile into fixed-length token windows."""

from torch.utils.data import IterableDataset
from transformers import PreTrainedTokenizerBase

from pile import ThePile


class ThePileTokenized(IterableDataset):
    """Iterable dataset yielding fixed-size windows of token ids.

    Each document's ``"text"`` field is tokenized, prefixed with the
    tokenizer's EOS token id, and appended to a running buffer. Whenever the
    buffer holds more than ``max_length`` tokens, the first ``max_length``
    tokens are yielded as a list and removed from the buffer.

    Args:
        base_dataset: Iterable of mappings that each provide a ``"text"`` key.
        tokenizer: Object exposing ``encode(text)`` and ``eos_token_id``.
            ``encode`` is assumed to return a sequence of token ids.
        max_length: Number of tokens in every yielded window.

    Raises:
        ValueError: If ``max_length`` is smaller than 1.
    """

    def __init__(
        self,
        base_dataset: ThePile,
        tokenizer: PreTrainedTokenizerBase,
        max_length: int = 1024,
    ):
        # A non-positive window size would make __iter__ emit empty windows
        # forever without ever consuming the buffer.
        if max_length < 1:
            raise ValueError(f"max_length must be >= 1, got {max_length}")
        super().__init__()
        self.pile = base_dataset
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __iter__(self):
        """Yield lists of exactly ``max_length`` token ids.

        Iteration ends when the base dataset is exhausted; a trailing
        remainder of at most ``max_length`` tokens is not emitted.
        """
        size = self.max_length
        eos_id = self.tokenizer.eos_token_id
        buffer = []
        # A ``for`` loop (rather than ``next()`` on a manual iterator) lets a
        # finite base dataset end this generator cleanly; a StopIteration
        # escaping a generator body is converted to RuntimeError (PEP 479).
        for sample in self.pile:
            buffer.append(eos_id)
            buffer.extend(self.tokenizer.encode(sample["text"]))

            # Emit every full window using a moving offset and trim the buffer
            # once, instead of re-copying the remainder after each window.
            start = 0
            while len(buffer) - start > size:
                yield buffer[start : start + size]
                start += size
            if start:
                del buffer[:start]


if __name__ == "__main__":
    from tqdm import tqdm
    from torch.utils.data import DataLoader
    from transformers import GPT2Tokenizer

    dataset = ThePileTokenized(
        ThePile("train"),
        GPT2Tokenizer.from_pretrained("gpt2"),
        max_length=1024,
    )
    dataloader = DataLoader(
        dataset,
        batch_size=64,
    )
    for batch in tqdm(dataloader, smoothing=0.01):
        pass
    # ~6 iters/s for 1 worker
NaxAlpha revised this gist
Jan 12, 2023 . 1 changed file with 48 additions and 74 deletions. There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -2,53 +2,51 @@ import time import random from typing import Literal import requests import zstandard as zstd from torch.utils.data import IterableDataset, get_worker_info Subset = Literal["train", "val", "test"] URLs = { "val": [ "https://the-eye.eu/public/AI/pile/val.jsonl.zst", ], "test": [ "https://the-eye.eu/public/AI/pile/test.jsonl.zst", ], "train": [ "https://the-eye.eu/public/AI/pile/train/00.jsonl.zst", "https://the-eye.eu/public/AI/pile/train/01.jsonl.zst", "https://the-eye.eu/public/AI/pile/train/02.jsonl.zst", "https://the-eye.eu/public/AI/pile/train/03.jsonl.zst", "https://the-eye.eu/public/AI/pile/train/04.jsonl.zst", "https://the-eye.eu/public/AI/pile/train/05.jsonl.zst", "https://the-eye.eu/public/AI/pile/train/06.jsonl.zst", "https://the-eye.eu/public/AI/pile/train/07.jsonl.zst", "https://the-eye.eu/public/AI/pile/train/08.jsonl.zst", "https://the-eye.eu/public/AI/pile/train/09.jsonl.zst", "https://the-eye.eu/public/AI/pile/train/10.jsonl.zst", "https://the-eye.eu/public/AI/pile/train/11.jsonl.zst", "https://the-eye.eu/public/AI/pile/train/12.jsonl.zst", "https://the-eye.eu/public/AI/pile/train/13.jsonl.zst", "https://the-eye.eu/public/AI/pile/train/14.jsonl.zst", "https://the-eye.eu/public/AI/pile/train/15.jsonl.zst", "https://the-eye.eu/public/AI/pile/train/16.jsonl.zst", "https://the-eye.eu/public/AI/pile/train/17.jsonl.zst", "https://the-eye.eu/public/AI/pile/train/18.jsonl.zst", "https://the-eye.eu/public/AI/pile/train/19.jsonl.zst", "https://the-eye.eu/public/AI/pile/train/20.jsonl.zst", "https://the-eye.eu/public/AI/pile/train/21.jsonl.zst", "https://the-eye.eu/public/AI/pile/train/22.jsonl.zst", 
"https://the-eye.eu/public/AI/pile/train/23.jsonl.zst", "https://the-eye.eu/public/AI/pile/train/24.jsonl.zst", "https://the-eye.eu/public/AI/pile/train/25.jsonl.zst", "https://the-eye.eu/public/AI/pile/train/26.jsonl.zst", "https://the-eye.eu/public/AI/pile/train/27.jsonl.zst", "https://the-eye.eu/public/AI/pile/train/28.jsonl.zst", "https://the-eye.eu/public/AI/pile/train/29.jsonl.zst", ], } @@ -82,52 +80,28 @@ def _line_streamer(reader, buffer_size=4096): class ThePile(IterableDataset): TEXT_BUFFER_SIZE = 4096 def __init__(self, subset: Subset): self.subset = subset def __iter__(self): urls = URLs[self.subset].copy() while True: wi = get_worker_info() seed = wi.id if wi is not None else None rnd = random.Random(seed) rnd.shuffle(urls) for url in urls: r = requests.get(url, stream=True) with zstd.ZstdDecompressor().stream_reader(r.raw) as reader: for line in _line_streamer(reader, self.TEXT_BUFFER_SIZE): data = json.loads(line) yield data if __name__ == "__main__": from tqdm import tqdm dataset = ThePile("train") for data in tqdm(dataset, smoothing=0.01): pass # Average: ~2000 samples/sec/worker -
NaxAlpha created this gist
Dec 23, 2022 . There are no files selected for viewing
"""Stream The Pile over HTTP and serve its documents in shuffled order."""

import codecs
import json
import random
import time
from concurrent.futures import ThreadPoolExecutor
from threading import Event, Lock
from typing import Literal

import requests
import zstandard as zstd
from torch.utils.data import IterableDataset

Subset = Literal["train", "val", "test"]

_BASE_URL = "https://mystic.the-eye.eu/public/AI/pile"

# Download locations per subset; "train" is split into shards 00..29.
URLs = {
    "val": [f"{_BASE_URL}/val.jsonl.zst"],
    "test": [f"{_BASE_URL}/test.jsonl.zst"],
    "train": [f"{_BASE_URL}/train/{shard:02d}.jsonl.zst" for shard in range(30)],
}


def _read_line_from_stream(reader, initial_line="", buffer_size=4096, decoder=None):
    """Return ``[line, rest]`` for the next newline-terminated line.

    Args:
        reader: Object whose ``read(n)`` returns bytes (empty at end of stream).
        initial_line: Text left over from a previous call; consumed first.
        buffer_size: Number of bytes requested per ``read`` call.
        decoder: Optional incremental decoder (an object with ``decode(bytes)``)
            that keeps state between chunks. Without it each chunk is decoded
            on its own, which fails when a multi-byte UTF-8 sequence straddles
            two chunks.

    Returns:
        A two-element list: the line without its newline, and the remaining
        already-decoded text.

    Raises:
        StopIteration: When the stream ends before another newline is seen.
            Callers inside a generator must catch it themselves, otherwise it
            surfaces as ``RuntimeError`` (PEP 479).
    """
    line = initial_line
    # Only touch the stream when the carried-over text holds no complete line;
    # reading unconditionally would let the carry-over grow without bound and
    # lose every buffered line once the stream ends.
    while "\n" not in line:
        chunk = reader.read(buffer_size)
        if not chunk:
            raise StopIteration
        line += decoder.decode(chunk) if decoder is not None else chunk.decode("utf-8")
    return line.split("\n", 1)


def _line_streamer(reader, buffer_size=4096):
    """Yield the text lines of a UTF-8 byte stream, split on ``"\\n"`` only.

    Args:
        reader: Object whose ``read(n)`` returns bytes (empty at end of stream).
        buffer_size: Number of bytes requested per ``read`` call.

    Yields:
        Each line without its trailing newline. A non-empty final line that
        lacks a trailing newline is yielded as well.
    """
    # The incremental decoder holds back incomplete multi-byte sequences until
    # the next chunk supplies the missing bytes.
    decoder = codecs.getincrementaldecoder("utf-8")()
    pending = ""
    while True:
        chunk = reader.read(buffer_size)
        if not chunk:
            break
        pending += decoder.decode(chunk)
        if "\n" not in pending:
            continue
        # split("\n") rather than splitlines(): only "\n" separates records,
        # other Unicode line boundaries must stay inside the line.
        *lines, pending = pending.split("\n")
        yield from lines
    # final=True raises if the stream stopped in the middle of a character.
    pending += decoder.decode(b"", final=True)
    if pending:
        yield pending


class ThePile(IterableDataset):
    """Iterable dataset that streams a subset of The Pile and shuffles it.

    Background threads download and decompress the ``.jsonl.zst`` files into a
    shared in-memory buffer; iteration pops uniformly random entries from that
    buffer.

    Args:
        subset: Key into ``URLs``: ``"train"``, ``"val"`` or ``"test"``.
        min_buffer: While downloads are running, no sample is served until the
            buffer holds at least this many documents.
        max_buffer: Downloaders pause while the buffer holds more than this
            many documents.
        num_threads: Number of downloader threads; the subset's URLs are
            divided between them.
        random_seed: Seed of the generator that picks which buffered document
            to serve next.
    """

    TEXT_BUFFER_SIZE = 4096

    def __init__(
        self,
        subset: Subset,
        min_buffer: int = 1_000,
        max_buffer: int = 10_000,
        num_threads: int = 1,
        random_seed: int = 42,
    ):
        super().__init__()
        self.subset = subset
        self.min_buffer = min_buffer
        self.max_buffer = max_buffer
        self.num_threads = num_threads
        self.random_seed = random_seed
        self.lock = Lock()
        self.buffer = []
        # Set when the consumer stops iterating so downloaders can exit.
        self._stop = Event()

    def downloader(self, urls=None):
        """Download ``urls`` and append every parsed JSON line to the buffer.

        Args:
            urls: URLs to fetch, in order. Defaults to every URL of the subset.

        Raises:
            requests.HTTPError: If a server answers with an error status.
        """
        if urls is None:
            urls = URLs[self.subset]
        for url in urls:
            # The response is a context manager so the connection is released
            # even when decoding fails or the download is abandoned.
            with requests.get(url, stream=True) as response:
                # Fail on the HTTP status instead of feeding an error page to
                # the zstd decoder.
                response.raise_for_status()
                with zstd.ZstdDecompressor().stream_reader(response.raw) as reader:
                    for line in _line_streamer(reader, self.TEXT_BUFFER_SIZE):
                        if not line:
                            continue
                        data = json.loads(line)
                        with self.lock:
                            self.buffer.append(data)
                        # Back-pressure: wait for the consumer to drain the
                        # buffer, but give up once the consumer is gone.
                        while len(self.buffer) > self.max_buffer:
                            if self._stop.is_set():
                                return
                            time.sleep(0.01)
                        if self._stop.is_set():
                            return

    def __iter__(self):
        """Yield documents (parsed JSON objects) in shuffled order.

        Iteration ends once every downloader has finished and the buffer is
        empty. An exception raised inside a downloader thread is re-raised
        after the remaining buffered documents have been served.
        """
        rnd = random.Random(self.random_seed)
        urls = URLs[self.subset]
        self._stop.clear()
        with ThreadPoolExecutor(max_workers=self.num_threads) as executor:
            # Give each thread its own slice of the URL list; submitting the
            # full list to every thread would serve each document once per
            # thread.
            futures = [
                executor.submit(self.downloader, urls[i :: self.num_threads])
                for i in range(self.num_threads)
            ]
            try:
                while True:
                    # Read the completion state before looking at the buffer:
                    # once all downloaders are done nothing is appended any
                    # more, so an empty buffer then really means the end.
                    finished = all(future.done() for future in futures)
                    with self.lock:
                        available = len(self.buffer)
                        ready = available > 0 and (
                            finished or available >= self.min_buffer
                        )
                        if ready:
                            out = self.buffer.pop(rnd.randint(0, available - 1))
                    if ready:
                        yield out
                    elif finished:
                        break
                    else:
                        time.sleep(0.01)
            finally:
                # Also reached when the consumer closes the generator early;
                # without it the executor shutdown would wait on downloaders
                # that are blocked by back-pressure.
                self._stop.set()
        for future in futures:
            future.result()


if __name__ == "__main__":
    from tqdm import tqdm

    dataset = ThePile("train", num_threads=4)
    for data in tqdm(dataset, smoothing=0.01):
        pass
    # Average: 1500-2000 samples/sec/worker