Skip to content

Instantly share code, notes, and snippets.

@YangWang92
Last active October 29, 2024 13:11
Show Gist options
  • Select an option

  • Save YangWang92/ec98a86c3a33c573b601cf4348d0a0e7 to your computer and use it in GitHub Desktop.

Select an option

Save YangWang92/ec98a86c3a33c573b601cf4348d0a0e7 to your computer and use it in GitHub Desktop.

Revisions

  1. YangWang92 revised this gist Oct 29, 2024. 1 changed file with 1 addition and 1 deletion.
    2 changes: 1 addition & 1 deletion invhessian.py
    Original file line number Diff line number Diff line change
    @@ -1,6 +1,6 @@
    import torch
    from argparse import ArgumentParser
    from .utils.hessian import load_hessian
    from vptq.utils.hessian import load_hessian
    import os

    if __name__ == "__main__":
  2. YangWang92 revised this gist Oct 29, 2024. 1 changed file with 1 addition and 1 deletion.
    2 changes: 1 addition & 1 deletion invhessian.py
    Original file line number Diff line number Diff line change
    @@ -1,6 +1,6 @@
    import torch
    from argparse import ArgumentParser
    from LUTobq.utils.hessian import load_hessian
    from .utils.hessian import load_hessian
    import os

    if __name__ == "__main__":
  3. YangWang92 created this gist Oct 29, 2024.
    51 changes: 51 additions & 0 deletions invhessian.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,51 @@
    import torch
    from argparse import ArgumentParser
    from LUTobq.utils.hessian import load_hessian
    import os

    if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument('--load_hessian_dir', type=str, default=None,
    help='Directory containing Hessian .pt files')
    parser.add_argument('--store_inv_hessian_dir', type=str, default=None,
    help='Directory to save inverted Hessian .pt files')

    args = parser.parse_args()

    # create folder
    os.makedirs(args.store_inv_hessian_dir, exist_ok=True)

    percdamp = 0.01
    hessian_files = [f for f in os.listdir(
    args.load_hessian_dir) if f.endswith('.pt')]

    for hessian_file in hessian_files:
    hessian_path = os.path.join(args.load_hessian_dir, hessian_file)
    hessian, mu = load_hessian(hessian_path)
    dev = 'cuda'
    hessian = hessian.to(dev)

    zero_idx = torch.diag(hessian) == 0
    hessian[zero_idx, zero_idx] = 1

    # get permutation
    perm = torch.argsort(torch.diag(hessian), descending=True).to(dev)
    hessian = hessian[perm][:, perm]

    # add damping
    damp = percdamp * torch.mean(torch.diag(hessian))
    diag = torch.arange(hessian.shape[0], device=dev)
    hessian[diag, diag] += damp

    # inverse Hessian
    hessian = torch.linalg.cholesky(hessian)
    hessian = torch.cholesky_inverse(hessian)
    hessian = torch.linalg.cholesky(hessian, upper=True)
    inv_hessian = hessian

    # Saving the inverted Hessian to the specified directory with the same file name
    save_path = os.path.join(args.store_inv_hessian_dir, hessian_file)
    torch.save({'invH': inv_hessian.to('cpu'),
    'perm': perm.to('cpu'),
    'zero_idx': zero_idx.to('cpu')}, save_path)
    print(f'Saved inverted Hessian to {save_path}')