import os
import time

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import datasets
import torch
import torch.nn.functional as F
import wandb
from torch.utils.data import DataLoader, IterableDataset
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

from offload import PerLayerOffloadWithBackwardGradient


class TokenDataset(IterableDataset):
    def __init__(self, dataset_id: str, model_id: str, seq_len: int):
        self.ds = datasets.load_dataset(dataset_id, split="train", streaming=True)
        self.model_id = model_id
        self.seq_len = seq_len

    def __iter__(self):
        tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        tokens = []
        for sample in self.ds:
            tokens.extend(tokenizer(sample["text"])["input_ids"])
            # emit fixed-length chunks of (seq_len + 1) tokens: seq_len inputs + 1 shifted label
            while len(tokens) >= self.seq_len + 1:
                yield torch.tensor(tokens[: self.seq_len + 1])
                tokens = tokens[self.seq_len + 1 :]


def get_loss(model, tokens):
    logits = model(tokens[:, :-1])[0]
    return cross_entropy(logits, tokens[:, 1:])


# wrap logits.float() and F.cross_entropy() in a compiled function to reduce memory
@torch.compile
def cross_entropy(logits, labels):
    return F.cross_entropy(logits.float().view(-1, logits.shape[-1]), labels.flatten())


if __name__ == "__main__":
    model_id = "meta-llama/Llama-3.2-1B"
    dtype = torch.bfloat16
    bsize = 4
    seq_len = 2048
    num_steps = 200
    use_compile = True
    offload = False
    profile = False

    torch.manual_seed(2024)
    cfg = AutoConfig.from_pretrained(
        model_id,
        max_position_embeddings=seq_len,
        use_cache=False,
    )
    model = AutoModelForCausalLM.from_config(cfg, torch_dtype=dtype)
    model.gradient_checkpointing_enable()

    # currently there is a bug with model.compile() + module hooks,
    # hence we manually compile .forward() instead
    # https://github.com/pytorch/pytorch/issues/142358
    if use_compile:
        for layer in model.model.layers:
            layer.forward = torch.compile(layer.forward)

    optim_cls = torch.optim.AdamW
    optim_kwargs = dict(lr=3e-4, weight_decay=0.0, fused=True)

    if offload:
        # the offloader owns the optimizer states and moves params/grads between CPU and GPU
        offloader = PerLayerOffloadWithBackwardGradient(model, optim_cls, optim_kwargs)
        offloader.cuda()
    else:
        model.cuda()
        optim = optim_cls(model.parameters(), **optim_kwargs)

    ds = TokenDataset("HuggingFaceFW/fineweb-edu", model_id, seq_len)
    dloader = DataLoader(ds, bsize, num_workers=1, pin_memory=True)
    dloader_iter = iter(dloader)

    if profile:
        torch._inductor.config.triton.unique_kernel_names = True
        prof = torch.profiler.profile()

    log_interval = 10
    pbar = tqdm(total=num_steps, dynamic_ncols=True)
    model.train()
    step = 0

    wandb.init(project="CPU offload", dir="/tmp", mode="disabled" if profile else None)

    torch.cuda.reset_peak_memory_stats()
    time0 = time.time()

    while step < num_steps:
        tokens = next(dloader_iter).cuda()

        # torch.compile(get_loss)(model, tokens) is faster for the baseline,
        # but does not work with CPU offload (due to module hooks)
        loss = get_loss(model, tokens)

        if offload:
            offloader.disable_forward_hook = True
        loss.backward()
        if offload:
            offloader.disable_forward_hook = False

        if step % log_interval == 0:
            wandb.log(dict(loss=loss.item()), step=step)

        if offload:
            offloader.optim_step()
            offloader.optim_zero_grad()
        else:
            optim.step()
            optim.zero_grad()

        step += 1
        pbar.update()

        if profile:
            if step == 1:
                prof.start()
            elif step == 3:
                break

        if step % log_interval == 0:
            time1 = time.time()
            log_dict = dict(
                max_memory_allocated=torch.cuda.max_memory_allocated() / 1e9,
                tokens_per_second=bsize * seq_len * log_interval / (time1 - time0),
            )
            time0 = time1
            wandb.log(log_dict, step=step)

    wandb.finish()
    if profile:
        prof.stop()
        prof.export_chrome_trace("trace.json.gz")