gau-nernst created this gist on Dec 12, 2024. It contains two files: the CPU-offload utility (offload.py) and a training script that benchmarks it on Llama-3.2-1B.
offload.py:

import torch
from torch import Tensor, nn
from tqdm import tqdm


class PerLayerOffloadWithBackwardGradient:
    "This version also offloads gradients. To ensure proper synchronization, it will take control over the optimizer."

    def __init__(
        self,
        model: nn.Module,
        optim_cls: type[torch.optim.Optimizer],
        optim_kwargs: dict | None = None,
        enable: bool = True,
    ):
        self.model = model
        self.enable = enable
        if not enable:
            return

        self.optim_cls = optim_cls
        self.optim_kwargs = optim_kwargs
        self.stream = torch.cuda.Stream()
        self.disable_forward_hook = False

        self.key2flat_gpu_buffer = dict()
        self.key2flat_cpu_params = dict()
        self.param2cpu_view = dict()
        self.param2optim = dict()
        self.param_queue = []  # we will run optimizer in this order

        manual_params = set()

        def traverse(module: nn.Module, key: tuple[str, ...] = ()):
            if (
                isinstance(module, (nn.ModuleList, nn.Sequential))
                and len(module) > 1
                and all(type(layer) == type(module[0]) for layer in module)
            ):
                self._register_sequential(module, key)
            else:
                for p in module.parameters(recurse=False):
                    manual_params.add(p)
                for name, child in module.named_children():
                    traverse(child, key + (name,))

        traverse(model)
        self.manual_tensors = list(manual_params) + list(self.model.buffers())
        self.manual_optim = optim_cls(manual_params, **(optim_kwargs or dict()))

    def cuda(self):
        if not self.enable:
            self.model.cuda()
        else:
            for p in self.manual_tensors:
                p.data = p.data.cuda(non_blocking=True)
        return self

    def cpu(self):
        if not self.enable:
            self.model.cpu()
        else:
            for p in self.manual_tensors:
                # offloaded params already have a pinned CPU copy in param2cpu_view
                p.data = self.param2cpu_view.get(p, p.data.cpu())
        return self

    @staticmethod
    def _get_flat_param(module: nn.Module):
        return torch.cat([x.detach().view(-1) for x in module.parameters()], dim=0)

    @staticmethod
    @torch.compiler.disable()
    def _view_into_flat_param(module: nn.Module, flat_param: Tensor):
        offset = 0
        for p in module.parameters():
            p.data = flat_param[offset : offset + p.numel()].view(p.shape)
            offset += p.numel()

    def _register_sequential(self, module_list: nn.Sequential | nn.ModuleList, key: tuple[str, ...]):
        self.key2flat_gpu_buffer[key] = [
            self._get_flat_param(module_list[0]).cuda(),
            self._get_flat_param(module_list[-1]).cuda(),
        ]
        self.key2flat_cpu_params[key] = []

        def create_pre_forward_hook(idx: int):
            def pre_forward_hook(module: nn.Module, inputs: tuple):
                # when there is activation checkpointing, .forward() is re-run in backward pass.
                # we use this flag to disable forward hooks in this case. set it to True before
                # calling loss.backward() and set it back to False after that.
                if self.disable_forward_hook:
                    return

                compute_buffer, transfer_buffer = self.key2flat_gpu_buffer[key]
                self._view_into_flat_param(module, compute_buffer)

                current_stream = torch.cuda.current_stream()
                current_stream.wait_stream(self.stream)
                self.stream.wait_stream(current_stream)
                with torch.cuda.stream(self.stream):
                    next_layer_cpu = self.key2flat_cpu_params[key][(idx + 1) % len(module_list)]
                    transfer_buffer.copy_(next_layer_cpu, non_blocking=True)

                self.key2flat_gpu_buffer[key] = [transfer_buffer, compute_buffer]

            return pre_forward_hook

        def create_pre_backward_hook(idx: int):
            def pre_backward_hook(module, grad_output):
                transfer_buffer, compute_buffer = self.key2flat_gpu_buffer[key]
                self._view_into_flat_param(module, compute_buffer)

                current_stream = torch.cuda.current_stream()
                current_stream.wait_stream(self.stream)
                self.stream.wait_stream(current_stream)
                with torch.cuda.stream(self.stream):
                    next_layer_cpu = self.key2flat_cpu_params[key][(idx - 1) % len(module_list)]
                    transfer_buffer.copy_(next_layer_cpu, non_blocking=True)

                self.key2flat_gpu_buffer[key] = [compute_buffer, transfer_buffer]

            return pre_backward_hook

        # NOTE: apparently when nn.Module.register_full_backward_hook() fires, param.grad
        # is not guaranteed to be computed https://github.com/pytorch/pytorch/issues/86051
        # hence, we have to use Tensor.register_post_accumulate_grad_hook() to offload grads.
        def post_grad_hook(p: Tensor):
            # make sure p.grad finished being computed
            self.stream.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(self.stream):
                self.param2cpu_view[p].grad.copy_(p.grad, non_blocking=True)

            # we will execute optim step in this order
            self.param_queue.append((p, self.stream.record_event()))

            # free grad memory
            p.grad.record_stream(self.stream)
            p.grad = None

        desc = f"Copying params to pinned memory {key}"
        for i, curr_layer in enumerate(tqdm(module_list, desc=desc, dynamic_ncols=True)):
            flat_param = self._get_flat_param(curr_layer).cpu().pin_memory()
            self.key2flat_cpu_params[key].append(flat_param)

            offset = 0
            for p in curr_layer.parameters():
                cpu_param = flat_param[offset : offset + p.numel()].view(p.shape)
                offset += p.numel()
                self.param2cpu_view[p] = cpu_param

                # pre-allocate pinned memory for gradients, and install hooks to offload grads
                if p.requires_grad:
                    cpu_param.grad = torch.empty(p.shape, dtype=p.dtype, device="cpu", pin_memory=True)
                    self.param2optim[p] = self.optim_cls([cpu_param], **(self.optim_kwargs or dict()))
                    p.register_post_accumulate_grad_hook(post_grad_hook)

            curr_layer.register_forward_pre_hook(create_pre_forward_hook(i))
            curr_layer.register_full_backward_pre_hook(create_pre_backward_hook(i))

    @torch.no_grad()
    def optim_step(self):
        after_bwd_event = torch.cuda.current_stream().record_event()
        self.manual_optim.step()
        for p, sync_event in self.param_queue:
            sync_event.synchronize()  # wait for grad offload to finish
            self.param2optim[p].step()

        # manually prefetch 1st layer, since it won't be prefetched in pre-forward hook
        # make sure backward finishes
        self.stream.wait_event(after_bwd_event)
        with torch.cuda.stream(self.stream):
            for key in self.key2flat_cpu_params.keys():
                self.key2flat_gpu_buffer[key][0].copy_(self.key2flat_cpu_params[key][0], non_blocking=True)

    def optim_zero_grad(self):
        self.manual_optim.zero_grad()
        self.param_queue = []
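Because the class takes control of the optimizer, it has a fixed call sequence: construct the offloader before moving anything to GPU, call .cuda(), toggle disable_forward_hook around loss.backward() (needed when activation checkpointing re-runs forward), and call optim_step() / optim_zero_grad() instead of a regular optimizer. The sketch below illustrates that sequence with a made-up ToyModel (its sizes and loss are assumptions, not from the gist); only the call order mirrors the training script further down, and it requires a CUDA GPU.

# Minimal usage sketch. ToyModel and its hyperparameters are illustrative;
# only the call order follows the training script in this gist.
import torch
from torch import nn

from offload import PerLayerOffloadWithBackwardGradient


class ToyModel(nn.Module):
    def __init__(self, dim: int = 1024, depth: int = 8):
        super().__init__()
        # repeated identical blocks must live in an nn.Sequential / nn.ModuleList
        # for the offloader to register them; other params use the "manual" optimizer
        self.layers = nn.Sequential(*[nn.Linear(dim, dim) for _ in range(depth)])
        self.head = nn.Linear(dim, 1)

    def forward(self, x):
        return self.head(self.layers(x))


model = ToyModel()
offloader = PerLayerOffloadWithBackwardGradient(model, torch.optim.AdamW, dict(lr=1e-4))
offloader.cuda()  # only the head moves to GPU; layer weights are streamed from pinned CPU memory

x = torch.randn(4, 1024, device="cuda")
loss = model(x).square().mean()

offloader.disable_forward_hook = True  # required when activation checkpointing re-runs forward
loss.backward()
offloader.disable_forward_hook = False

offloader.optim_step()       # runs the GPU "manual" optimizer and the per-param CPU optimizers, then prefetches the first layer
offloader.optim_zero_grad()  # also clears the internal parameter queue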
The second file is the training / benchmark script:

import os
import time

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import datasets
import torch
import torch.nn.functional as F
import wandb
from torch.utils.data import DataLoader, IterableDataset
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

from offload import PerLayerOffloadWithBackwardGradient


class TokenDataset(IterableDataset):
    def __init__(self, dataset_id: str, model_id: str, seq_len: int):
        self.ds = datasets.load_dataset(dataset_id, split="train", streaming=True)
        self.model_id = model_id
        self.seq_len = seq_len

    def __iter__(self):
        tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        tokens = []
        for sample in self.ds:
            tokens.extend(tokenizer(sample["text"])["input_ids"])
            while len(tokens) >= self.seq_len + 1:
                yield torch.tensor(tokens[: self.seq_len + 1])
                tokens = tokens[self.seq_len + 1 :]


def get_loss(model, tokens):
    logits = model(tokens[:, :-1])[0]
    return cross_entropy(logits, tokens[:, 1:])


# wrap logits.float() and F.cross_entropy() in a compiled function to reduce memory
@torch.compile
def cross_entropy(logits, labels):
    return F.cross_entropy(logits.float().view(-1, logits.shape[-1]), labels.flatten())


if __name__ == "__main__":
    model_id = "meta-llama/Llama-3.2-1B"
    dtype = torch.bfloat16
    bsize = 4
    seq_len = 2048
    num_steps = 200
    use_compile = True
    offload = False
    profile = False

    torch.manual_seed(2024)
    cfg = AutoConfig.from_pretrained(
        model_id,
        max_position_embeddings=seq_len,
        use_cache=False,
    )
    model = AutoModelForCausalLM.from_config(cfg, torch_dtype=dtype)
    model.gradient_checkpointing_enable()

    # currently there is a bug with model.compile() + module hooks
    # hence, we will manually compile .forward() instead
    # https://github.com/pytorch/pytorch/issues/142358
    if use_compile:
        for layer in model.model.layers:
            layer.forward = torch.compile(layer.forward)

    optim_cls = torch.optim.AdamW
    optim_kwargs = dict(lr=3e-4, weight_decay=0.0, fused=True)

    if offload:
        offloader = PerLayerOffloadWithBackwardGradient(model, optim_cls, optim_kwargs)
        offloader.cuda()
    else:
        model.cuda()
        optim = optim_cls(model.parameters(), **optim_kwargs)

    ds = TokenDataset("HuggingFaceFW/fineweb-edu", model_id, seq_len)
    dloader = DataLoader(ds, bsize, num_workers=1, pin_memory=True)
    dloader_iter = iter(dloader)

    if profile:
        torch._inductor.config.triton.unique_kernel_names = True
        prof = torch.profiler.profile()

    log_interval = 10
    pbar = tqdm(total=num_steps, dynamic_ncols=True)
    model.train()
    step = 0

    wandb.init(project="CPU offload", dir="/tmp", mode="disabled" if profile else None)

    torch.cuda.reset_peak_memory_stats()
    time0 = time.time()

    while step < num_steps:
        tokens = next(dloader_iter).cuda()

        # torch.compile(get_loss)(model, tokens) is faster for baseline,
        # but does not work for CPU offload (due to module hooks)
        loss = get_loss(model, tokens)

        if offload:
            offloader.disable_forward_hook = True
        loss.backward()
        if offload:
            offloader.disable_forward_hook = False

        if step % log_interval == 0:
            wandb.log(dict(loss=loss.item()), step=step)

        if offload:
            offloader.optim_step()
            offloader.optim_zero_grad()
        else:
            optim.step()
            optim.zero_grad()

        step += 1
        pbar.update()

        if profile:
            if step == 1:
                prof.start()
            elif step == 3:
                break

        if step % log_interval == 0:
            time1 = time.time()
            log_dict = dict(
                max_memory_allocated=torch.cuda.max_memory_allocated() / 1e9,
                tokens_per_second=bsize * seq_len * log_interval / (time1 - time0),
            )
            time0 = time1
            wandb.log(log_dict, step=step)

    wandb.finish()

    if profile:
        prof.stop()
        prof.export_chrome_trace("trace.json.gz")
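As written, the script measures the GPU-only baseline (offload = False, profile = False). To measure the offloader instead, the flags near the top of __main__ would be flipped; the values below are just one possible configuration, not part of the gist:

# illustrative configuration for benchmarking the offloader
use_compile = True   # each decoder layer's forward is still compiled individually
offload = True       # params, grads, and optimizer steps go through PerLayerOffloadWithBackwardGradient
profile = False      # set to True to capture a trace of a couple of steps instead of a full run

With profile = True, the exported trace.json.gz is a gzipped Chrome trace; loading it in Perfetto (ui.perfetto.dev) or chrome://tracing should show the offloader's host-device copies on a separate CUDA stream from compute.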