Last active
October 29, 2024 13:11
-
-
Save YangWang92/ec98a86c3a33c573b601cf4348d0a0e7 to your computer and use it in GitHub Desktop.
Revisions
-
YangWang92 revised this gist
Oct 29, 2024 . 1 changed file with 1 addition and 1 deletion.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -1,6 +1,6 @@ import torch from argparse import ArgumentParser from vptq.utils.hessian import load_hessian import os if __name__ == "__main__": -
YangWang92 revised this gist
Oct 29, 2024 . 1 changed file with 1 addition and 1 deletion.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -1,6 +1,6 @@ import torch from argparse import ArgumentParser from .utils.hessian import load_hessian import os if __name__ == "__main__": -
YangWang92 created this gist
Oct 29, 2024 .There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,51 @@ import torch from argparse import ArgumentParser from LUTobq.utils.hessian import load_hessian import os if __name__ == "__main__": parser = ArgumentParser() parser.add_argument('--load_hessian_dir', type=str, default=None, help='Directory containing Hessian .pt files') parser.add_argument('--store_inv_hessian_dir', type=str, default=None, help='Directory to save inverted Hessian .pt files') args = parser.parse_args() # create folder os.makedirs(args.store_inv_hessian_dir, exist_ok=True) percdamp = 0.01 hessian_files = [f for f in os.listdir( args.load_hessian_dir) if f.endswith('.pt')] for hessian_file in hessian_files: hessian_path = os.path.join(args.load_hessian_dir, hessian_file) hessian, mu = load_hessian(hessian_path) dev = 'cuda' hessian = hessian.to(dev) zero_idx = torch.diag(hessian) == 0 hessian[zero_idx, zero_idx] = 1 # get permutation perm = torch.argsort(torch.diag(hessian), descending=True).to(dev) hessian = hessian[perm][:, perm] # add damping damp = percdamp * torch.mean(torch.diag(hessian)) diag = torch.arange(hessian.shape[0], device=dev) hessian[diag, diag] += damp # inverse Hessian hessian = torch.linalg.cholesky(hessian) hessian = torch.cholesky_inverse(hessian) hessian = torch.linalg.cholesky(hessian, upper=True) inv_hessian = hessian # Saving the inverted Hessian to the specified directory with the same file name save_path = os.path.join(args.store_inv_hessian_dir, hessian_file) torch.save({'invH': inv_hessian.to('cpu'), 'perm': perm.to('cpu'), 'zero_idx': zero_idx.to('cpu')}, save_path) print(f'Saved inverted Hessian to {save_path}')