Forked from ChrisHayduk/merge_qlora_with_quantized_model.py
Created October 25, 2023
Merging QLoRA weights with quantized model
import gc
import json
import os
import shutil

import torch
import bitsandbytes as bnb
from bitsandbytes.functional import dequantize_4bit
from peft import PeftModel
from peft.utils import _get_submodules
from transformers import (
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    LlamaForCausalLM,
    LlamaTokenizer,
)
def dequantize_model(model, tokenizer, to='./dequantized_model', dtype=torch.bfloat16):
    """
    Replace every bitsandbytes Linear4bit module with a plain nn.Linear that
    holds the dequantized weights, then save the result as a regular checkpoint.

    'model': the 4-bit model loaded with a BitsAndBytesConfig (either the
             causal LM itself or a PeftModel wrapping it).
    'tokenizer': the model's corresponding Hugging Face tokenizer.
    """
    # Start from a clean output directory.
    if os.path.exists(to):
        shutil.rmtree(to)
    os.makedirs(to, exist_ok=True)

    cls = bnb.nn.Linear4bit
    # Unwrap the base model if a PeftModel was passed in.
    base_model = model.model if isinstance(model, PeftModel) else model

    with torch.no_grad():
        for name, module in base_model.named_modules():
            if isinstance(module, cls):
                print(f"Dequantizing `{name}`...")
                weights = dequantize_4bit(
                    module.weight.data,
                    quant_state=module.weight.quant_state,
                    quant_type="nf4",
                ).to(dtype)
                new_module = torch.nn.Linear(
                    module.in_features,
                    module.out_features,
                    bias=module.bias is not None,
                    dtype=dtype,
                )
                new_module.weight = torch.nn.Parameter(weights)
                if module.bias is not None:
                    new_module.bias = torch.nn.Parameter(module.bias.data.to(dtype))
                parent, target, target_name = _get_submodules(base_model, name)
                setattr(parent, target_name, new_module)

        # A hack: clear this flag to avoid HF's saving error, since HF itself
        # does not support saving a model registered as loaded in 4-bit.
        base_model.is_loaded_in_4bit = False

        print("Saving dequantized model...")
        base_model.save_pretrained(to)
        tokenizer.save_pretrained(to)

        # Strip quantization metadata from the saved config so the checkpoint
        # reloads as a plain full-precision model.
        config_path = os.path.join(to, 'config.json')
        with open(config_path, 'r') as config:
            config_data = json.load(config)
        config_data.pop("quantization_config", None)
        config_data.pop("pretraining_tp", None)
        with open(config_path, 'w') as config:
            json.dump(config_data, config, indent=2)

        return base_model
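
# Optional sanity check (a sketch, not part of the original gist): compare one
# Linear4bit layer against its dequantized nn.Linear replacement on a random
# input. The printed max difference should be small (bf16 rounding noise).
def compare_dequantized(linear4bit, dtype=torch.bfloat16):
    weights = dequantize_4bit(
        linear4bit.weight.data,
        quant_state=linear4bit.weight.quant_state,
        quant_type="nf4",
    ).to(dtype)
    x = torch.randn(1, linear4bit.in_features, dtype=dtype, device=weights.device)
    with torch.no_grad():
        reference = linear4bit(x)  # bitsandbytes 4-bit matmul
        dequantized = torch.nn.functional.linear(x, weights, linear4bit.bias)
    print(f"max |difference| = {(reference - dequantized).abs().max().item():.6f}")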

model_path = 'NousResearch/Llama-2-13b-hf'
adapter_path = 'Example/Adapter-Path'

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)
try:
    print(f"Starting to load the model {model_path} into memory")

    # load_in_4bit is already set inside quantization_config, so it is not
    # passed again here (newer transformers reject passing both).
    model = LlamaForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        quantization_config=quantization_config,
        device_map={"": 0},
    )
    print(model.model)

    tok = LlamaTokenizer.from_pretrained(model_path)

    # Dequantize the base model, then merge the LoRA adapter into it.
    model = dequantize_model(model, tok)
    model = PeftModel.from_pretrained(model=model, model_id=adapter_path)
    model = model.merge_and_unload()

    print(f"Successfully merged the adapter from {adapter_path} into {model_path}")

except Exception as e:
    print(f"An error occurred: {e}")

    # Delete the model object if it exists
    if 'model' in locals():
        del model

    # Clear the GPU cache
    torch.cuda.empty_cache()

    # Run the garbage collection
    gc.collect()

    print("Model, GPU cache, and garbage have been cleared.")