Skip to content

Instantly share code, notes, and snippets.

@YangWang92
Last active October 29, 2024 13:11
Show Gist options
  • Select an option

  • Save YangWang92/ec98a86c3a33c573b601cf4348d0a0e7 to your computer and use it in GitHub Desktop.

Select an option

Save YangWang92/ec98a86c3a33c573b601cf4348d0a0e7 to your computer and use it in GitHub Desktop.
Generate inverse Hessian
import torch
from argparse import ArgumentParser
from vptq.utils.hessian import load_hessian
import os
def _invert_hessian(hessian, percdamp, dev):
    """Return the factorised inverse of a damped, permuted Hessian.

    Args:
        hessian: Square 2-D tensor as returned by ``load_hessian``.
            Assumed to be a floating-point, symmetric matrix (the Cholesky
            factorisation below requires it) -- TODO confirm against
            ``load_hessian``.
        percdamp: Fraction of the mean diagonal value that is added to every
            diagonal entry before factorisation.
        dev: Device on which the computation runs.

    Returns:
        A tuple ``(inv_hessian, perm, zero_idx)`` where ``inv_hessian`` is
        the upper-triangular Cholesky factor of the inverse of the damped,
        permuted Hessian, ``perm`` is the permutation that sorts the diagonal
        in descending order, and ``zero_idx`` is a boolean mask of the
        diagonal entries that were exactly zero. ``zero_idx`` is expressed in
        the original (pre-permutation) ordering.
    """
    hessian = hessian.to(dev)

    # Zero diagonal entries are replaced by 1; presumably this keeps the
    # matrix factorisable when a row/column carries no signal.
    zero_idx = torch.diag(hessian) == 0
    hessian[zero_idx, zero_idx] = 1

    # Reorder rows and columns so the diagonal is sorted in descending order.
    perm = torch.argsort(torch.diag(hessian), descending=True).to(dev)
    hessian = hessian[perm][:, perm]

    # Add damping proportional to the mean diagonal value.
    damp = percdamp * torch.mean(torch.diag(hessian))
    diag = torch.arange(hessian.shape[0], device=dev)
    hessian[diag, diag] += damp

    # Cholesky factor -> full inverse -> upper Cholesky factor of the inverse.
    lower = torch.linalg.cholesky(hessian)
    inverse = torch.cholesky_inverse(lower)
    inv_hessian = torch.linalg.cholesky(inverse, upper=True)
    return inv_hessian, perm, zero_idx


def main():
    """Invert every Hessian ``.pt`` file in a directory and save the results.

    Each output file keeps the name of its input file and holds a dict with
    the keys ``'invH'``, ``'perm'`` and ``'zero_idx'`` (all moved to CPU).
    """
    parser = ArgumentParser()
    # Both directories are mandatory: without them the script would fail
    # later inside os.makedirs / os.listdir with an unhelpful TypeError.
    parser.add_argument('--load_hessian_dir', type=str, required=True,
                        help='Directory containing Hessian .pt files')
    parser.add_argument('--store_inv_hessian_dir', type=str, required=True,
                        help='Directory to save inverted Hessian .pt files')
    args = parser.parse_args()

    # create folder
    os.makedirs(args.store_inv_hessian_dir, exist_ok=True)

    percdamp = 0.01
    # Pick the device once; fall back to CPU when CUDA is not available.
    dev = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Sorted so that the processing order is reproducible across runs.
    hessian_files = sorted(
        f for f in os.listdir(args.load_hessian_dir) if f.endswith('.pt')
    )

    for hessian_file in hessian_files:
        hessian_path = os.path.join(args.load_hessian_dir, hessian_file)
        # The second value returned by load_hessian is not needed here.
        hessian, _ = load_hessian(hessian_path)

        inv_hessian, perm, zero_idx = _invert_hessian(hessian, percdamp, dev)

        # Saving the inverted Hessian to the specified directory with the
        # same file name
        save_path = os.path.join(args.store_inv_hessian_dir, hessian_file)
        torch.save({'invH': inv_hessian.to('cpu'),
                    'perm': perm.to('cpu'),
                    'zero_idx': zero_idx.to('cpu')}, save_path)
        print(f'Saved inverted Hessian to {save_path}')


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment