@gau-nernst
Created December 12, 2024 12:44

offload.py

import torch
from torch import Tensor, nn
from tqdm import tqdm


class PerLayerOffloadWithBackwardGradient:
    "This version also offloads gradients. To ensure proper synchronization, it will take control over the optimizer."

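    # Overall scheme (summary of the class below): each homogeneous nn.Sequential / nn.ModuleList
    # keeps its per-layer params flattened in pinned CPU memory, plus two flat GPU buffers: one
    # holds the current layer for compute while the other receives the next layer on a side CUDA
    # stream; the two are swapped by pre-forward / pre-backward hooks. Gradients are copied to
    # pinned CPU memory by post-accumulate-grad hooks, and optim_step() runs the per-param CPU
    # optimizers in the order the grads arrived, then prefetches the first layer again. Params
    # outside such containers ("manual" params) and buffers stay on GPU with a regular optimizer.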
    def __init__(
        self,
        model: nn.Module,
        optim_cls: type[torch.optim.Optimizer],
        optim_kwargs: dict | None = None,
        enable: bool = True,
    ):
        self.model = model
        self.enable = enable
        if not enable:
            return

        self.optim_cls = optim_cls
        self.optim_kwargs = optim_kwargs

        self.stream = torch.cuda.Stream()
        self.disable_forward_hook = False

        self.key2flat_gpu_buffer = dict()
        self.key2flat_cpu_params = dict()

        self.param2cpu_view = dict()
        self.param2optim = dict()
        self.param_queue = []  # we will run optimizer in this order

        manual_params = set()

        def traverse(module: nn.Module, key: tuple[str, ...] = ()):
            if (
                isinstance(module, (nn.ModuleList, nn.Sequential))
                and len(module) > 1
                and all(type(layer) == type(module[0]) for layer in module)
            ):
                self._register_sequential(module, key)

            else:
                for p in module.parameters(recurse=False):
                    manual_params.add(p)
                for name, child in module.named_children():
                    traverse(child, key + (name,))

        traverse(model)
        self.manual_tensors = list(manual_params) + list(self.model.buffers())
        self.manual_optim = optim_cls(manual_params, **(optim_kwargs or dict()))

    def cuda(self):
        if not self.enable:
            self.model.cuda()

        else:
            for p in self.manual_tensors:
                p.data = p.data.cuda(non_blocking=True)

        return self

    def cpu(self):
        if not self.enable:
            self.model.cpu()

        else:
            for p in self.manual_tensors:
                # manual tensors are not tracked in param2cpu_view, so fall back to a plain CPU copy
                p.data = self.param2cpu_view.get(p, p.data.cpu())

        return self

    @staticmethod
    def _get_flat_param(module: nn.Module):
        return torch.cat([x.detach().view(-1) for x in module.parameters()], dim=0)

    @staticmethod
    @torch.compiler.disable()
    def _view_into_flat_param(module: nn.Module, flat_param: Tensor):
        offset = 0
        for p in module.parameters():
            p.data = flat_param[offset : offset + p.numel()].view(p.shape)
            offset += p.numel()

    def _register_sequential(self, module_list: nn.Sequential | nn.ModuleList, key: tuple[str, ...]):
        self.key2flat_gpu_buffer[key] = [
            self._get_flat_param(module_list[0]).cuda(),
            self._get_flat_param(module_list[-1]).cuda(),
        ]
        self.key2flat_cpu_params[key] = []

        def create_pre_forward_hook(idx: int):
            def pre_forward_hook(module: nn.Module, inputs: tuple):
                # when there is activation checkpointing, .forward() is re-run in backward pass.
                # we use this flag to disable forward hooks in this case. set it to True before
                # calling loss.backward() and set it back to False after that.
                if self.disable_forward_hook:
                    return

                compute_buffer, transfer_buffer = self.key2flat_gpu_buffer[key]
                self._view_into_flat_param(module, compute_buffer)

                current_stream = torch.cuda.current_stream()
                current_stream.wait_stream(self.stream)
                self.stream.wait_stream(current_stream)

                with torch.cuda.stream(self.stream):
                    next_layer_cpu = self.key2flat_cpu_params[key][(idx + 1) % len(module_list)]
                    transfer_buffer.copy_(next_layer_cpu, non_blocking=True)

                self.key2flat_gpu_buffer[key] = [transfer_buffer, compute_buffer]

            return pre_forward_hook

        def create_pre_backward_hook(idx: int):
            def pre_backward_hook(module, grad_output):
                transfer_buffer, compute_buffer = self.key2flat_gpu_buffer[key]
                self._view_into_flat_param(module, compute_buffer)

                current_stream = torch.cuda.current_stream()
                current_stream.wait_stream(self.stream)
                self.stream.wait_stream(current_stream)

                with torch.cuda.stream(self.stream):
                    next_layer_cpu = self.key2flat_cpu_params[key][(idx - 1) % len(module_list)]
                    transfer_buffer.copy_(next_layer_cpu, non_blocking=True)

                self.key2flat_gpu_buffer[key] = [compute_buffer, transfer_buffer]

            return pre_backward_hook

        # NOTE: apparently when nn.Module.register_full_backward_hook() fires, param.grad
        # is not guaranteed to be computed https://github.com/pytorch/pytorch/issues/86051
        # hence, we have to use Tensor.register_post_accumulate_grad_hook() to offload grads.
        def post_grad_hook(p: Tensor):
            # make sure p.grad finished being computed
            self.stream.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(self.stream):
                self.param2cpu_view[p].grad.copy_(p.grad, non_blocking=True)

            # we will execute optim step in this order
            self.param_queue.append((p, self.stream.record_event()))

            # free grad memory
            p.grad.record_stream(self.stream)
            p.grad = None

        desc = f"Copying params to pinned memory {key}"
        for i, curr_layer in enumerate(tqdm(module_list, desc=desc, dynamic_ncols=True)):
            flat_param = self._get_flat_param(curr_layer).cpu().pin_memory()
            self.key2flat_cpu_params[key].append(flat_param)

            offset = 0
            for p in curr_layer.parameters():
                cpu_param = flat_param[offset : offset + p.numel()].view(p.shape)
                offset += p.numel()
                self.param2cpu_view[p] = cpu_param

                # pre-allocate pinned memory for gradients, and install hooks to offload grads
                if p.requires_grad:
                    cpu_param.grad = torch.empty(p.shape, dtype=p.dtype, device="cpu", pin_memory=True)
                    self.param2optim[p] = self.optim_cls([cpu_param], **(self.optim_kwargs or dict()))
                    p.register_post_accumulate_grad_hook(post_grad_hook)

            curr_layer.register_forward_pre_hook(create_pre_forward_hook(i))
            curr_layer.register_full_backward_pre_hook(create_pre_backward_hook(i))

    @torch.no_grad()
    def optim_step(self):
        after_bwd_event = torch.cuda.current_stream().record_event()
        self.manual_optim.step()

        for p, sync_event in self.param_queue:
            sync_event.synchronize()  # wait for grad offload to finish
            self.param2optim[p].step()

        # manually prefetch 1st layer, since it won't be prefetched in pre-forward hook
        # make sure backward finishes
        self.stream.wait_event(after_bwd_event)
        with torch.cuda.stream(self.stream):
            for key in self.key2flat_cpu_params.keys():
                self.key2flat_gpu_buffer[key][0].copy_(self.key2flat_cpu_params[key][0], non_blocking=True)

    def optim_zero_grad(self):
        self.manual_optim.zero_grad()
        self.param_queue = []
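
For reference, a minimal usage sketch of the class above (not part of the gist; the toy module, sizes, and hyperparameters are invented for illustration, and a CUDA device is required). It mirrors the call pattern used by train_llm.py below:

import torch
from torch import nn
from offload import PerLayerOffloadWithBackwardGradient


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # homogeneous nn.Sequential -> offloaded layer by layer
        self.layers = nn.Sequential(*[nn.Linear(256, 256) for _ in range(4)])
        # lone layer -> stays on GPU as a "manual" param handled by manual_optim
        self.head = nn.Linear(256, 8)

    def forward(self, x):
        return self.head(self.layers(x))


model = ToyModel()
offloader = PerLayerOffloadWithBackwardGradient(model, torch.optim.AdamW, dict(lr=1e-4))
offloader.cuda()  # moves only the non-offloaded tensors; layer params stream in via hooks

x = torch.randn(2, 256, device="cuda")
loss = model(x).square().mean()

offloader.disable_forward_hook = True  # only matters when activation checkpointing re-runs forward
loss.backward()
offloader.disable_forward_hook = False

offloader.optim_step()  # CPU optimizer steps in grad-arrival order, then prefetch of the first layer
offloader.optim_zero_grad()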

train_llm.py

import os
import time

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import datasets
import torch
import torch.nn.functional as F
import wandb
from torch.utils.data import DataLoader, IterableDataset
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

from offload import PerLayerOffloadWithBackwardGradient


class TokenDataset(IterableDataset):
    def __init__(self, dataset_id: str, model_id: str, seq_len: int):
        self.ds = datasets.load_dataset(dataset_id, split="train", streaming=True)
        self.model_id = model_id
        self.seq_len = seq_len

    def __iter__(self):
        tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        tokens = []
        for sample in self.ds:
            tokens.extend(tokenizer(sample["text"])["input_ids"])
            while len(tokens) >= self.seq_len + 1:
                yield torch.tensor(tokens[: self.seq_len + 1])
                tokens = tokens[self.seq_len + 1 :]


def get_loss(model, tokens):
    logits = model(tokens[:, :-1])[0]
    return cross_entropy(logits, tokens[:, 1:])


# wrap logits.float() and F.cross_entropy() in a compiled function to reduce memory
@torch.compile
def cross_entropy(logits, labels):
    return F.cross_entropy(logits.float().view(-1, logits.shape[-1]), labels.flatten())


if __name__ == "__main__":
    model_id = "meta-llama/Llama-3.2-1B"
    dtype = torch.bfloat16
    bsize = 4
    seq_len = 2048
    num_steps = 200
    use_compile = True
    offload = False
    profile = False

    torch.manual_seed(2024)
    cfg = AutoConfig.from_pretrained(
        model_id,
        max_position_embeddings=seq_len,
        use_cache=False,
    )
    model = AutoModelForCausalLM.from_config(cfg, torch_dtype=dtype)
    model.gradient_checkpointing_enable()

    # currently there is a bug with model.compile() + module hooks
    # hence, we will manually compile .forward() instead
    # https://github.com/pytorch/pytorch/issues/142358
    if use_compile:
        for layer in model.model.layers:
            layer.forward = torch.compile(layer.forward)

    optim_cls = torch.optim.AdamW
    optim_kwargs = dict(lr=3e-4, weight_decay=0.0, fused=True)

    if offload:
        offloader = PerLayerOffloadWithBackwardGradient(model, optim_cls, optim_kwargs)
        offloader.cuda()
    else:
        model.cuda()
        optim = optim_cls(model.parameters(), **optim_kwargs)

    ds = TokenDataset("HuggingFaceFW/fineweb-edu", model_id, seq_len)
    dloader = DataLoader(ds, bsize, num_workers=1, pin_memory=True)
    dloader_iter = iter(dloader)

    if profile:
        torch._inductor.config.triton.unique_kernel_names = True
        prof = torch.profiler.profile()

    log_interval = 10
    pbar = tqdm(total=num_steps, dynamic_ncols=True)
    model.train()
    step = 0
    wandb.init(project="CPU offload", dir="/tmp", mode="disabled" if profile else None)

    torch.cuda.reset_peak_memory_stats()
    time0 = time.time()
    while step < num_steps:
        tokens = next(dloader_iter).cuda()
        # torch.compile(get_loss)(model, tokens) is faster for baseline,
        # but does not work for CPU offload (due to module hooks)
        loss = get_loss(model, tokens)
        if offload:
            offloader.disable_forward_hook = True
        loss.backward()
        if offload:
            offloader.disable_forward_hook = False

        if step % log_interval == 0:
            wandb.log(dict(loss=loss.item()), step=step)

        if offload:
            offloader.optim_step()
            offloader.optim_zero_grad()
        else:
            optim.step()
            optim.zero_grad()
        step += 1
        pbar.update()

        if profile:
            if step == 1:
                prof.start()
            elif step == 3:
                break

        if step % log_interval == 0:
            time1 = time.time()
            log_dict = dict(
                max_memory_allocated=torch.cuda.max_memory_allocated() / 1e9,
                tokens_per_second=bsize * seq_len * log_interval / (time1 - time0),
            )
            time0 = time1
            wandb.log(log_dict, step=step)

    wandb.finish()
    if profile:
        prof.stop()
        prof.export_chrome_trace("trace.json.gz")
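
To exercise the offload path, set offload = True at the top of the main block; with offload = False the script runs the plain on-GPU baseline for comparison. Setting profile = True captures a couple of steps and writes trace.json.gz, which can be inspected in chrome://tracing or the Perfetto UI.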