Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Select an option

  • Save pavp-git/ee56c35f4c32cfec74feedfbb99d8cd0 to your computer and use it in GitHub Desktop.

Select an option

Save pavp-git/ee56c35f4c32cfec74feedfbb99d8cd0 to your computer and use it in GitHub Desktop.
Merging QLoRA weights with quantized model
import gc
import json
import os
import shutil
import traceback

import bitsandbytes as bnb
import peft
import torch
from bitsandbytes.functional import dequantize_4bit
from peft import PeftModel
from peft.utils import _get_submodules
from transformers import AutoModelForCausalLM, LlamaForCausalLM, LlamaTokenizer, BitsAndBytesConfig, CodeLlamaTokenizer
def dequantize_model(model, tokenizer, to='./dequantized_model', dtype=torch.float16):
    """Replace every bitsandbytes 4-bit linear layer with a dense ``torch.nn.Linear``.

    The model is modified in place and then exported with ``save_pretrained``
    (together with ``tokenizer``) into ``to``. The ``quantization_config`` and
    ``pretraining_tp`` entries are removed from the exported ``config.json``.

    Args:
        model: Either a ``PeftModel`` loaded with QLoRA (its ``.model``
            attribute is dequantized) or a plain 4-bit ``transformers`` model
            such as ``LlamaForCausalLM``. A plain model is dequantized as-is:
            taking its ``.model`` attribute would keep only the decoder stack
            and drop the LM head.
        tokenizer: The Hugging Face tokenizer that belongs to ``model``.
        to: Output directory. Any existing directory at this path is deleted
            and recreated.
        dtype: Currently not applied. Dense weights keep the dtype returned by
            ``dequantize_4bit`` (presumably the dtype recorded in each layer's
            ``quant_state`` -- confirm). This keeps them consistent with the
            modules that were never quantized, which this function leaves
            untouched. The parameter is kept for backward compatibility.

    Returns:
        The dequantized transformers model (the same object that was modified).
    """
    # Remove a previous export so stale files are not mixed into the new one.
    if os.path.exists(to):
        shutil.rmtree(to)
    os.makedirs(to, exist_ok=True)

    # A PeftModel exposes the wrapped transformers model as `.model`; for a
    # plain causal-LM, `.model` is only the inner decoder, so use it directly.
    base_model = model.model if isinstance(model, PeftModel) else model

    with torch.no_grad():
        # Snapshot the targets first: the loop below swaps modules in their
        # parents, which must not happen while `named_modules()` is iterating.
        quantized_layers = [
            (name, module)
            for name, module in base_model.named_modules()
            if isinstance(module, bnb.nn.Linear4bit)
        ]
        for name, module in quantized_layers:
            print(f"Dequantizing `{name}`...")
            weights = dequantize_4bit(
                module.weight.data,
                quant_state=module.weight.quant_state,
                quant_type="nf4",
            )
            # `nn.Linear` expects a bool for `bias`. Only copy a bias that
            # exists: `nn.Parameter(None)` would create an empty bias tensor.
            has_bias = module.bias is not None
            # Create on the target device/dtype to avoid a throwaway float32
            # CPU allocation; the weight is replaced immediately anyway.
            new_module = torch.nn.Linear(
                module.in_features,
                module.out_features,
                bias=has_bias,
                device=weights.device,
                dtype=weights.dtype,
            )
            new_module.weight = torch.nn.Parameter(weights)
            if has_bias:
                new_module.bias = torch.nn.Parameter(module.bias.data)
            parent, _, target_name = _get_submodules(base_model, name)
            setattr(parent, target_name, new_module)

        # HACK: transformers refuses to save a model flagged as loaded in
        # 4-bit; every 4-bit layer has just been replaced, so clear the flag.
        base_model.is_loaded_in_4bit = False

        print("Saving dequantized model...")
        base_model.save_pretrained(to)
        tokenizer.save_pretrained(to)

        # Strip quantization settings so the export loads as a dense model.
        config_path = os.path.join(to, 'config.json')
        with open(config_path, 'r') as config_file:
            config_data = json.load(config_file)
        config_data.pop("quantization_config", None)
        # NOTE(review): `pretraining_tp` is dropped as well; reason not visible here.
        config_data.pop("pretraining_tp", None)
        with open(config_path, 'w') as config_file:
            json.dump(config_data, config_file, indent=2)

    return base_model
model_path = 'NousResearch/Llama-2-13b-hf'
adapter_path = 'Example/Adapter-Path'

# 4-bit NF4 quantization with double quantization; compute runs in bfloat16.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

try:
    print(f"Starting to load the model {model_path} into memory")
    # 4-bit loading is requested only via `quantization_config`. Passing
    # `load_in_4bit=True` alongside it is redundant, and recent transformers
    # versions reject the combination with a ValueError.
    model = LlamaForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        quantization_config=quantization_config,
        device_map={"": 0},  # put the whole model on GPU 0
    )
    print(model.model)
    tok = LlamaTokenizer.from_pretrained(model_path)
    # Swap the 4-bit layers for dense ones (the dense base is also exported to
    # dequantize_model's default directory).
    model = dequantize_model(model, tok)
    # Attach the LoRA adapter to the dense base, then merge it into the weights.
    model = PeftModel.from_pretrained(model=model, model_id=adapter_path)
    model = model.merge_and_unload()
    # NOTE(review): the merged model is not written to disk by this script;
    # call `model.save_pretrained(...)` here if it should persist.
    print(f"Successfully loaded the model {model_path} into memory")
except Exception as e:
    print(f"An error occurred: {e}")
    # Keep the full traceback; the message alone rarely identifies the failing step.
    traceback.print_exc()

    # Delete the model object if it exists
    if 'model' in locals():
        del model
    # Collect garbage first so tensors kept alive by reference cycles are
    # actually freed, then release the cached GPU blocks.
    gc.collect()
    torch.cuda.empty_cache()
    print("Model, GPU cache, and garbage have been cleared.")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment