Create a huggingface pipeline with a lora-trained alpaca
from typing import Optional

import torch
from transformers.utils import is_accelerate_available, is_bitsandbytes_available
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    GenerationConfig,
    pipeline,
)
from peft import PeftModel

# Prompt template used by the Alpaca instruction-tuning data.
ALPACA_TEMPLATE = (
    "Below is an instruction that describes a task, paired with an input that provides "
    "further context. Write a response that appropriately completes the request.\n\n"
    "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
)


def load_adapted_hf_generation_pipeline(
    base_model_name,
    lora_model_name,
    temperature: float = 0,
    top_p: float = 1.,
    max_tokens: int = 50,
    batch_size: int = 16,
    device: str = "cpu",
    load_in_8bit: bool = True,
    generation_kwargs: Optional[dict] = None,
):
    """
    Load a huggingface model & adapt with PEFT.
    Borrowed from https://github.com/tloen/alpaca-lora/blob/main/generate.py
    """
    if device == "cuda":
        if not is_accelerate_available():
            raise ValueError("Install `accelerate`")
        if load_in_8bit and not is_bitsandbytes_available():
            raise ValueError("Install `bitsandbytes`")

    tokenizer = AutoTokenizer.from_pretrained(base_model_name)
    task = "text-generation"

    if device == "cuda":
        # Load the base model (optionally in 8-bit) on GPU, then attach the LoRA adapter.
        model = AutoModelForCausalLM.from_pretrained(
            base_model_name,
            load_in_8bit=load_in_8bit,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        model = PeftModel.from_pretrained(
            model,
            lora_model_name,
            torch_dtype=torch.float16,
        )
    elif device == "mps":
        # Apple Silicon: keep weights in fp16 on the MPS device.
        model = AutoModelForCausalLM.from_pretrained(
            base_model_name,
            device_map={"": device},
            torch_dtype=torch.float16,
        )
        model = PeftModel.from_pretrained(
            model,
            lora_model_name,
            device_map={"": device},
            torch_dtype=torch.float16,
        )
    else:
        # CPU fallback: full-precision weights.
        model = AutoModelForCausalLM.from_pretrained(
            base_model_name, device_map={"": device}, low_cpu_mem_usage=True
        )
        model = PeftModel.from_pretrained(
            model,
            lora_model_name,
            device_map={"": device},
        )

    # unwind broken decapoda-research config
    model.config.pad_token_id = tokenizer.pad_token_id = 0  # unk
    model.config.bos_token_id = 1
    model.config.eos_token_id = 2

    if not load_in_8bit:
        model.half()  # seems to fix bugs for some users.

    model.eval()

    generation_kwargs = generation_kwargs if generation_kwargs is not None else {}
    config = GenerationConfig(
        # Sampling with temperature == 0 is undefined, so fall back to greedy decoding.
        do_sample=temperature > 0,
        temperature=temperature,
        max_new_tokens=max_tokens,
        top_p=top_p,
        **generation_kwargs,
    )
    pipe = pipeline(
        task,
        model=model,
        tokenizer=tokenizer,
        batch_size=batch_size,
        generation_config=config,
        framework="pt",
    )

    return pipe


if __name__ == "__main__":
    pipe = load_adapted_hf_generation_pipeline(
        base_model_name="decapoda-research/llama-7b-hf",
        lora_model_name="tloen/alpaca-lora-7b",
    )
    prompt = ALPACA_TEMPLATE.format(
        instruction="Paraphrase the sentence.",
        input="The quick brown fox jumped over the lazy dog.",
    )
    print(pipe(prompt))
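For reference, the text-generation pipeline returns a list of dicts whose "generated_text" field includes the prompt by default. Here is a minimal usage sketch (assuming the `pipe` and `ALPACA_TEMPLATE` built above, and a transformers version that supports `return_full_text`) for getting only the completion back:

# Usage sketch: strip the echoed prompt from the output
# (assumes `pipe` and ALPACA_TEMPLATE from the gist above).
prompt = ALPACA_TEMPLATE.format(
    instruction="Paraphrase the sentence.",
    input="The quick brown fox jumped over the lazy dog.",
)
# The pipeline returns [{"generated_text": ...}] per prompt;
# return_full_text=False drops the prompt so only the completion remains.
outputs = pipe(prompt, return_full_text=False)
print(outputs[0]["generated_text"])

A list of prompts can also be passed in place of a single string, in which case the pipeline batches them according to the `batch_size` argument above.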
Thank you so much! I'm gonna test it and let you know.
One more time thank you so much! It works well :)
Heads up, I believe you're missing a comma at the end of line 116.
Nice catch, thanks!
  
I believe that is just a warning that you can safely ignore. For the versions of transformers & PEFT I was using (4.28.1 and 0.3.0.dev0, respectively), `PeftModelForCausalLM` had not been added to the `text-generation` pipeline's list of supported models (but, as you can see, the underlying `LlamaForCausalLM` on which the PEFT adapter sits is supported, i.e., the warning is spurious). It's possible I'm wrong here, but I did get this to work.
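As a quick sanity check on that claim, here is a minimal sketch (assuming the `pipe` returned by the gist above, and that your peft version provides `PeftModel.get_base_model()`) showing that the wrapper still exposes a supported `LlamaForCausalLM` underneath:

# Sanity-check sketch (assumes `pipe` from the gist above;
# get_base_model availability depends on your peft version).
from peft import PeftModel

assert isinstance(pipe.model, PeftModel)
# The pipeline's supported-model check looks at the wrapper class, but generation
# is ultimately handled by the underlying transformers model, e.g. LlamaForCausalLM.
print(type(pipe.model.get_base_model()).__name__)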