import torch
from torch import Tensor, nn
from tqdm import tqdm


class PerLayerOffloadWithBackwardGradient:
    """This version also offloads gradients. To ensure proper synchronization, it will take
    control over the optimizer."""

    def __init__(
        self,
        model: nn.Module,
        optim_cls: type[torch.optim.Optimizer],
        optim_kwargs: dict | None = None,
        enable: bool = True,
    ):
        self.model = model
        self.enable = enable
        if not enable:
            return

        self.optim_cls = optim_cls
        self.optim_kwargs = optim_kwargs
        self.stream = torch.cuda.Stream()
        self.disable_forward_hook = False

        # double-buffered GPU storage (current layer + prefetched layer) per layer group
        self.key2flat_gpu_buffer = dict()
        # pinned CPU flat params, one entry per layer in each layer group
        self.key2flat_cpu_params = dict()
        self.param2cpu_view = dict()
        self.param2optim = dict()
        self.param_queue = []  # we will run optimizer in this order

        manual_params = set()

        def traverse(module: nn.Module, key: tuple[str, ...] = ()):
            if (
                isinstance(module, (nn.ModuleList, nn.Sequential))
                and len(module) > 1
                and all(type(layer) == type(module[0]) for layer in module)
            ):
                self._register_sequential(module, key)
            else:
                for p in module.parameters(recurse=False):
                    manual_params.add(p)
                for name, child in module.named_children():
                    traverse(child, key + (name,))

        traverse(model)
        self.manual_tensors = list(manual_params) + list(self.model.buffers())
        self.manual_optim = optim_cls(manual_params, **(optim_kwargs or dict()))

    def cuda(self):
        if not self.enable:
            self.model.cuda()
        else:
            for p in self.manual_tensors:
                p.data = p.data.cuda(non_blocking=True)
        return self

    def cpu(self):
        if not self.enable:
            self.model.cpu()
        else:
            for p in self.manual_tensors:
                p.data = self.param2cpu_view.get(p, p.data.cpu())
        return self

    @staticmethod
    def _get_flat_param(module: nn.Module):
        return torch.cat([x.detach().view(-1) for x in module.parameters()], dim=0)

    @staticmethod
    @torch.compiler.disable()
    def _view_into_flat_param(module: nn.Module, flat_param: Tensor):
        offset = 0
        for p in module.parameters():
            p.data = flat_param[offset : offset + p.numel()].view(p.shape)
            offset += p.numel()

    def _register_sequential(self, module_list: nn.Sequential | nn.ModuleList, key: tuple[str, ...]):
        self.key2flat_gpu_buffer[key] = [
            self._get_flat_param(module_list[0]).cuda(),
            self._get_flat_param(module_list[-1]).cuda(),
        ]
        self.key2flat_cpu_params[key] = []

        def create_pre_forward_hook(idx: int):
            def pre_forward_hook(module: nn.Module, inputs: tuple):
                # when there is activation checkpointing, .forward() is re-run in backward pass.
                # we use this flag to disable forward hooks in this case. set it to True before
                # calling loss.backward() and set it back to False after that.
                if self.disable_forward_hook:
                    return

                compute_buffer, transfer_buffer = self.key2flat_gpu_buffer[key]
                self._view_into_flat_param(module, compute_buffer)

                current_stream = torch.cuda.current_stream()
                current_stream.wait_stream(self.stream)
                self.stream.wait_stream(current_stream)
                with torch.cuda.stream(self.stream):
                    next_layer_cpu = self.key2flat_cpu_params[key][(idx + 1) % len(module_list)]
                    transfer_buffer.copy_(next_layer_cpu, non_blocking=True)

                self.key2flat_gpu_buffer[key] = [transfer_buffer, compute_buffer]

            return pre_forward_hook

        def create_pre_backward_hook(idx: int):
            def pre_backward_hook(module, grad_output):
                transfer_buffer, compute_buffer = self.key2flat_gpu_buffer[key]
                self._view_into_flat_param(module, compute_buffer)

                current_stream = torch.cuda.current_stream()
                current_stream.wait_stream(self.stream)
                self.stream.wait_stream(current_stream)
                with torch.cuda.stream(self.stream):
                    next_layer_cpu = self.key2flat_cpu_params[key][(idx - 1) % len(module_list)]
                    transfer_buffer.copy_(next_layer_cpu, non_blocking=True)

                self.key2flat_gpu_buffer[key] = [compute_buffer, transfer_buffer]

            return pre_backward_hook

        # NOTE: apparently when nn.Module.register_full_backward_hook() fires, param.grad
        # is not guaranteed to be computed https://github.com/pytorch/pytorch/issues/86051
        # hence, we have to use Tensor.register_post_accumulate_grad_hook() to offload grads.
        def post_grad_hook(p: Tensor):
            # make sure p.grad finished being computed
            self.stream.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(self.stream):
                self.param2cpu_view[p].grad.copy_(p.grad, non_blocking=True)

            # we will execute optim step in this order
            self.param_queue.append((p, self.stream.record_event()))

            # free grad memory
            p.grad.record_stream(self.stream)
            p.grad = None

        desc = f"Copying params to pinned memory {key}"
        for i, curr_layer in enumerate(tqdm(module_list, desc=desc, dynamic_ncols=True)):
            flat_param = self._get_flat_param(curr_layer).cpu().pin_memory()
            self.key2flat_cpu_params[key].append(flat_param)

            offset = 0
            for p in curr_layer.parameters():
                cpu_param = flat_param[offset : offset + p.numel()].view(p.shape)
                offset += p.numel()
                self.param2cpu_view[p] = cpu_param

                # pre-allocate pinned memory for gradients, and install hooks to offload grads
                if p.requires_grad:
                    cpu_param.grad = torch.empty(p.shape, dtype=p.dtype, device="cpu", pin_memory=True)
                    self.param2optim[p] = self.optim_cls([cpu_param], **(self.optim_kwargs or dict()))
                    p.register_post_accumulate_grad_hook(post_grad_hook)

            curr_layer.register_forward_pre_hook(create_pre_forward_hook(i))
            curr_layer.register_full_backward_pre_hook(create_pre_backward_hook(i))

    @torch.no_grad()
    def optim_step(self):
        after_bwd_event = torch.cuda.current_stream().record_event()
        self.manual_optim.step()
        for p, sync_event in self.param_queue:
            sync_event.synchronize()  # wait for grad offload to finish
            self.param2optim[p].step()

        # manually prefetch 1st layer, since it won't be prefetched in pre-forward hook.
        # make sure backward finishes
        self.stream.wait_event(after_bwd_event)
        with torch.cuda.stream(self.stream):
            for key in self.key2flat_cpu_params.keys():
                self.key2flat_gpu_buffer[key][0].copy_(self.key2flat_cpu_params[key][0], non_blocking=True)

    def optim_zero_grad(self):
        self.manual_optim.zero_grad()
        self.param_queue = []
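
# Minimal usage sketch (illustrative, not part of the offloader itself): the toy model,
# optimizer settings, shapes, and training loop below are assumptions for demonstration;
# substitute your own model, data, and loss. Requires a CUDA device.
if __name__ == "__main__":

    class ToyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.embed = nn.Linear(64, 1024)  # handled by the "manual" optimizer on GPU
            self.layers = nn.Sequential(*[nn.Linear(1024, 1024) for _ in range(8)])  # offloaded per-layer
            self.head = nn.Linear(1024, 1)

        def forward(self, x):
            return self.head(self.layers(self.embed(x)))

    model = ToyModel()
    offloader = PerLayerOffloadWithBackwardGradient(model, torch.optim.AdamW, dict(lr=1e-4))
    offloader.cuda()  # non-offloaded params/buffers go to GPU; offloaded layers stay in pinned CPU memory

    for _ in range(3):
        x = torch.randn(4, 64, device="cuda")
        loss = model(x).mean()  # placeholder loss

        # only needed when activation checkpointing re-runs .forward() during backward,
        # harmless otherwise (see the comment in pre_forward_hook above)
        offloader.disable_forward_hook = True
        loss.backward()
        offloader.disable_forward_hook = False

        offloader.optim_step()  # steps the manual optimizer and the per-parameter CPU optimizers
        offloader.optim_zero_grad()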