import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

import torchao
from torchao.quantization.autoquant import (
    DEFAULT_AUTOQUANT_CLASS_LIST,
    DEFAULT_INT4_AUTOQUANT_CLASS_LIST,
    OTHER_AUTOQUANT_CLASS_LIST,
)

# Manual quantization entry points; imported as alternatives but unused in this run.
from torchao.quantization.quant_api import (
    int4_weight_only,
    int8_dynamic_activation_int8_weight,
    int8_weight_only,
    quantize_,
)

model_id = "/mnt/workdisk/chengli/models/meta-llama/Llama-3.2-1B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)

prompt = "The president of the United States"
inputs = tokenizer.encode(prompt, return_tensors="pt").to("cuda")

model = AutoModelForCausalLM.from_pretrained(model_id)
model = model.to("cuda")

# Compile first so autoquant benchmarks the compiled kernels.
model = torch.compile(model, mode="max-autotune")
print("compiled", model)

# Wrap the linear weights with autoquant: the wrapped weights record
# activation shapes during generate(), and finalize_autoquant() benchmarks
# the candidate qtensor classes and picks the best one per layer.
model = torchao.autoquant(
    model, qtensor_class_list=DEFAULT_AUTOQUANT_CLASS_LIST + OTHER_AUTOQUANT_CLASS_LIST
)
print("autoquantized", model)

generate_ids = model.generate(inputs, max_length=30)
model.finalize_autoquant()

response = tokenizer.batch_decode(
    generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)[0]
print(response)

# ---------------------------------------------------------------------------
# OptimizedModule(
#   (_orig_mod): LlamaForCausalLM(
#     (model): LlamaModel(
#       (embed_tokens): Embedding(128256, 2048)
#       (layers): ModuleList(
#         (0-15): 16 x LlamaDecoderLayer(
#           (self_attn): LlamaSdpaAttention(
#             (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
#             (k_proj): Linear(in_features=2048, out_features=512, bias=False)
#             (v_proj): Linear(in_features=2048, out_features=512, bias=False)
#             (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
#             (rotary_emb): LlamaRotaryEmbedding()
#           )
#           (mlp): LlamaMLP(
#             (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)
#             (up_proj): Linear(in_features=2048, out_features=8192, bias=False)
#             (down_proj): Linear(in_features=8192, out_features=2048, bias=False)
#             (act_fn): SiLU()
#           )
#           (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
#           (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
#         )
#       )
#       (norm): LlamaRMSNorm((2048,), eps=1e-05)
#       (rotary_emb): LlamaRotaryEmbedding()
#     )
#     (lm_head): Linear(in_features=2048, out_features=128256, bias=False)
#   )
# )
#
# The attention mask and the pad token id were not set. As a consequence, you
# may observe unexpected behavior. Please pass your input's `attention_mask`
# to obtain reliable results.
# Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
# The attention mask is not set and cannot be inferred from input because pad
# token is same as eos token. As a consequence, you may observe unexpected
# behavior. Please pass your input's `attention_mask` to obtain reliable results.
# Starting from v4.46, the `logits` model output will have the same type as
# the model (except at train time, where it will always be FP32)
#
# /mnt/workdisk/chengli/miniconda3/envs/vllm/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py:663:
# UserWarning: Graph break due to unsupported builtin
# None.TensorBase._make_wrapper_subclass. This function is either a Python
# builtin (e.g. _warnings.warn) or a third-party C/C++ Python extension
# (perhaps created with pybind).
# If it is a Python builtin, please file an issue on GitHub so the PyTorch
# team can add support for it and see the next case for a workaround. If it
# is a third-party C/C++ Python extension, please either wrap it into a
# PyTorch-understood custom operator (see
# https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html for
# more details) or, if it is traceable, use torch.compiler.allow_in_graph.
#   torch._dynamo.utils.warn_once(msg)
# (the same UserWarning and warn_once line are emitted twice more during tracing)
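#
# Note on the attention-mask warnings above: they come from passing a bare
# input_ids tensor to generate(). A minimal fix, as a hedged sketch (the
# `enc` variable is hypothetical, not part of the original run), is to let
# the tokenizer build the mask and pass it through explicitly:
#
#     enc = tokenizer(prompt, return_tensors="pt").to("cuda")
#     generate_ids = model.generate(
#         enc["input_ids"],
#         attention_mask=enc["attention_mask"],
#         pad_token_id=tokenizer.eos_token_id,
#         max_length=30,
#     )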
#
# compiled OptimizedModule(
#   (_orig_mod): LlamaForCausalLM( ...same module structure as printed above... )
# )
# autoquantized OptimizedModule(
#   (_orig_mod): LlamaForCausalLM( ...same module structure as printed above... )
# )
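#
# Reading the autoquant log below: for each recorded (activation_shape,
# weight_shape) pair, ">>time" lines are per-shape kernel timings for one
# candidate qtensor class, ">time (all shapes)" aggregates them across the
# recorded shapes, and "best_cls" is the class finalize_autoquant() installs
# for that layer. A hedged sketch (an alternative call, not part of this
# run): the search could instead be restricted to the int4 candidates via
# the class list already imported at the top of this file:
#
#     model = torchao.autoquant(
#         model, qtensor_class_list=DEFAULT_INT4_AUTOQUANT_CLASS_LIST
#     )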
#
# activation_shapes: torch.Size([7, 2048]), times_seen: 1
# activation_shapes: torch.Size([1, 2048]), times_seen: 22
# weight_shape: torch.Size([2048, 2048]), dtype: torch.float32, bias_shape: None
# >>time: 0.020ms for <class 'torchao.quantization.autoquant.AQFloatLinearWeight'>, to_beat: infms
# >>time: 0.018ms for <class 'torchao.quantization.autoquant.AQFloatLinearWeight'>, to_beat: infms
# >time (all shapes): 0.0181ms for <class 'torchao.quantization.autoquant.AQFloatLinearWeight'>, prev_best: infms
# >>time: 0.028ms for <class 'torchao.quantization.autoquant.AQInt8WeightOnlyQuantizedLinearWeight'>, to_beat: 0.020ms
# >>time: 0.014ms for <class 'torchao.quantization.autoquant.AQInt8WeightOnlyQuantizedLinearWeight'>, to_beat: 0.018ms
# >time (all shapes): 0.0142ms for <class 'torchao.quantization.autoquant.AQInt8WeightOnlyQuantizedLinearWeight'>, prev_best: 0.0181ms
# >>time: 0.020ms for <class 'torchao.quantization.autoquant.AQInt8WeightOnlyQuantizedLinearWeight2'>, to_beat: 0.028ms
# >>time: 0.013ms for <class 'torchao.quantization.autoquant.AQInt8WeightOnlyQuantizedLinearWeight2'>, to_beat: 0.014ms
# >time (all shapes): 0.0137ms for <class 'torchao.quantization.autoquant.AQInt8WeightOnlyQuantizedLinearWeight2'>, prev_best: 0.0142ms
# >>time: 0.019ms for <class 'torchao.quantization.autoquant.AQInt8DynamicallyQuantizedLinearWeight'> matmul, to_beat: 0.020ms
# >>time: 0.027ms for <class 'torchao.quantization.autoquant.AQInt8DynamicallyQuantizedLinearWeight'>, to_beat: 0.024ms
# >>time: 0.025ms for <class 'torchao.quantization.autoquant.AQInt8DynamicallyQuantizedLinearWeight'> interpolated, breakeven constant: 0.09
# >>time: 0.021ms for <class 'torchao.quantization.autoquant.AQInt8DynamicallyQuantizedLinearWeight'> matmul, to_beat: 0.013ms
# >time (all shapes): 0.0210ms for <class 'torchao.quantization.autoquant.AQInt8DynamicallyQuantizedLinearWeight'>, prev_best: 0.0137ms
# >>time: 0.028ms for <class 'torchao.quantization.autoquant.AQFloat8WeightOnlyQuantizedLinearWeight'>, to_beat: 0.020ms
# >>time: 0.024ms for <class 'torchao.quantization.autoquant.AQFloat8WeightOnlyQuantizedLinearWeight'>, to_beat: 0.013ms
# >time (all shapes): 0.0239ms for <class 'torchao.quantization.autoquant.AQFloat8WeightOnlyQuantizedLinearWeight'>, prev_best: 0.0137ms
# warning: failed to autoquant AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight for shape: (torch.Size([7, 2048]), torch.Size([2048, 2048]), None, torch.float32) due to PerRow quantization only works for bfloat16 precision input activation
# warning: failed to autoquant AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight for shape: (torch.Size([1, 2048]), torch.Size([2048, 2048]), None, torch.float32) due to PerRow quantization only works for bfloat16 precision input activation
# >time (all shapes): infms for <class 'torchao.quantization.autoquant.AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight'>, prev_best: 0.0137ms
# best_cls=<class 'torchao.quantization.autoquant.AQInt8WeightOnlyQuantizedLinearWeight2'>
#
# activation_shapes: torch.Size([7, 2048]), times_seen: 1
# activation_shapes: torch.Size([1, 2048]), times_seen: 22
# weight_shape: torch.Size([512, 2048]), dtype: torch.float32, bias_shape: None
# >>time: 0.017ms for <class 'torchao.quantization.autoquant.AQFloatLinearWeight'>, to_beat: infms
# >>time: 0.013ms for <class 'torchao.quantization.autoquant.AQFloatLinearWeight'>, to_beat: infms
# >time (all shapes): 0.0134ms for <class 'torchao.quantization.autoquant.AQFloatLinearWeight'>, prev_best: infms
# >>time: 0.022ms for <class 'torchao.quantization.autoquant.AQInt8WeightOnlyQuantizedLinearWeight'>, to_beat: 0.017ms
# >>time: 0.012ms for <class 'torchao.quantization.autoquant.AQInt8WeightOnlyQuantizedLinearWeight'>, to_beat: 0.013ms
# >time (all shapes): 0.0125ms for <class 'torchao.quantization.autoquant.AQInt8WeightOnlyQuantizedLinearWeight'>, prev_best: 0.0134ms
# >>time: 0.014ms for <class 'torchao.quantization.autoquant.AQInt8WeightOnlyQuantizedLinearWeight2'>, to_beat: 0.022ms
# >>time: 0.012ms for <class 'torchao.quantization.autoquant.AQInt8WeightOnlyQuantizedLinearWeight2'>, to_beat: 0.012ms
# >time (all shapes): 0.0122ms for <class 'torchao.quantization.autoquant.AQInt8WeightOnlyQuantizedLinearWeight2'>, prev_best: 0.0125ms
# >>time: 0.018ms for <class 'torchao.quantization.autoquant.AQInt8DynamicallyQuantizedLinearWeight'> matmul, to_beat: 0.014ms
# >>time: 0.020ms for <class 'torchao.quantization.autoquant.AQInt8DynamicallyQuantizedLinearWeight'> matmul, to_beat: 0.012ms
# >time (all shapes): 0.0196ms for <class 'torchao.quantization.autoquant.AQInt8DynamicallyQuantizedLinearWeight'>, prev_best: 0.0122ms
# >>time: 0.020ms for <class 'torchao.quantization.autoquant.AQFloat8WeightOnlyQuantizedLinearWeight'>, to_beat: 0.014ms
# >>time: 0.016ms for <class 'torchao.quantization.autoquant.AQFloat8WeightOnlyQuantizedLinearWeight'>, to_beat: 0.012ms
# >time (all shapes): 0.0159ms for <class 'torchao.quantization.autoquant.AQFloat8WeightOnlyQuantizedLinearWeight'>, prev_best: 0.0122ms
# warning: failed to autoquant AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight for shape: (torch.Size([7, 2048]), torch.Size([512, 2048]), None, torch.float32) due to PerRow quantization only works for bfloat16 precision input activation
# warning: failed to autoquant AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight for shape: (torch.Size([1, 2048]), torch.Size([512, 2048]), None, torch.float32) due to PerRow quantization only works for bfloat16 precision input activation
# >time (all shapes): infms for <class 'torchao.quantization.autoquant.AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight'>, prev_best: 0.0122ms
# best_cls=<class 'torchao.quantization.autoquant.AQInt8WeightOnlyQuantizedLinearWeight2'>
#
# activation_shapes: torch.Size([7, 2048]), times_seen: 1
# activation_shapes: torch.Size([1, 2048]), times_seen: 22
# weight_shape: torch.Size([8192, 2048]), dtype: torch.float32, bias_shape: None
# >>time: 0.041ms for <class 'torchao.quantization.autoquant.AQFloatLinearWeight'>, to_beat: infms
# >>time: 0.034ms for <class 'torchao.quantization.autoquant.AQFloatLinearWeight'>, to_beat: infms
# >time (all shapes): 0.0339ms for <class 'torchao.quantization.autoquant.AQFloatLinearWeight'>, prev_best: infms
# >>time: 0.071ms for <class 'torchao.quantization.autoquant.AQInt8WeightOnlyQuantizedLinearWeight'>, to_beat: 0.041ms
# >>time: 0.018ms for <class 'torchao.quantization.autoquant.AQInt8WeightOnlyQuantizedLinearWeight'>, to_beat: 0.034ms
# >time (all shapes): 0.0204ms for <class 'torchao.quantization.autoquant.AQInt8WeightOnlyQuantizedLinearWeight'>, prev_best: 0.0339ms
# >>time: 0.044ms for <class 'torchao.quantization.autoquant.AQInt8WeightOnlyQuantizedLinearWeight2'>, to_beat: 0.071ms
# >>time: 0.018ms for <class 'torchao.quantization.autoquant.AQInt8WeightOnlyQuantizedLinearWeight2'>, to_beat: 0.018ms
# >time (all shapes): 0.0192ms for <class 'torchao.quantization.autoquant.AQInt8WeightOnlyQuantizedLinearWeight2'>, prev_best: 0.0204ms
# >>time: 0.023ms for <class 'torchao.quantization.autoquant.AQInt8DynamicallyQuantizedLinearWeight'> matmul, to_beat: 0.044ms
# >>time: 0.030ms for <class 'torchao.quantization.autoquant.AQInt8DynamicallyQuantizedLinearWeight'>, to_beat: 0.168ms
# >>time: 0.029ms for <class 'torchao.quantization.autoquant.AQInt8DynamicallyQuantizedLinearWeight'> interpolated, breakeven constant: 2.93
# >>time: 0.024ms for <class 'torchao.quantization.autoquant.AQInt8DynamicallyQuantizedLinearWeight'> matmul, to_beat: 0.018ms
# >time (all shapes): 0.0244ms for <class 'torchao.quantization.autoquant.AQInt8DynamicallyQuantizedLinearWeight'>, prev_best: 0.0192ms
# >>time: 0.075ms for <class 'torchao.quantization.autoquant.AQFloat8WeightOnlyQuantizedLinearWeight'>, to_beat: 0.044ms
# >>time: 0.069ms for <class 'torchao.quantization.autoquant.AQFloat8WeightOnlyQuantizedLinearWeight'>, to_beat: 0.018ms
# >time (all shapes): 0.0691ms for <class 'torchao.quantization.autoquant.AQFloat8WeightOnlyQuantizedLinearWeight'>, prev_best: 0.0192ms
# warning: failed to autoquant AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight for shape: (torch.Size([7, 2048]), torch.Size([8192, 2048]), None, torch.float32) due to PerRow quantization only works for bfloat16 precision input activation
# warning: failed to autoquant AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight for shape: (torch.Size([1, 2048]), torch.Size([8192, 2048]), None, torch.float32) due to PerRow quantization only works for bfloat16 precision input activation
# >time (all shapes): infms for <class 'torchao.quantization.autoquant.AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight'>, prev_best: 0.0192ms
# best_cls=<class 'torchao.quantization.autoquant.AQInt8WeightOnlyQuantizedLinearWeight2'>
#
# activation_shapes: torch.Size([7, 8192]), times_seen: 1
# activation_shapes: torch.Size([1, 8192]), times_seen: 22
# weight_shape: torch.Size([2048, 8192]), dtype: torch.float32, bias_shape: None
# >>time: 0.046ms for <class 'torchao.quantization.autoquant.AQFloatLinearWeight'>, to_beat: infms
# >>time: 0.035ms for <class 'torchao.quantization.autoquant.AQFloatLinearWeight'>, to_beat: infms
# >time (all shapes): 0.0358ms for <class 'torchao.quantization.autoquant.AQFloatLinearWeight'>, prev_best: infms
# >>time: 0.078ms for <class 'torchao.quantization.autoquant.AQInt8WeightOnlyQuantizedLinearWeight'>, to_beat: 0.046ms
# >>time: 0.020ms for <class 'torchao.quantization.autoquant.AQInt8WeightOnlyQuantizedLinearWeight'>, to_beat: 0.035ms
# >time (all shapes): 0.0222ms for <class 'torchao.quantization.autoquant.AQInt8WeightOnlyQuantizedLinearWeight'>, prev_best: 0.0358ms
# >>time: 0.045ms for <class 'torchao.quantization.autoquant.AQInt8WeightOnlyQuantizedLinearWeight2'>, to_beat: 0.078ms
# >>time: 0.020ms for <class 'torchao.quantization.autoquant.AQInt8WeightOnlyQuantizedLinearWeight2'>, to_beat: 0.020ms
# >time (all shapes): 0.0210ms for <class 'torchao.quantization.autoquant.AQInt8WeightOnlyQuantizedLinearWeight2'>, prev_best: 0.0222ms
# >>time: 0.043ms for <class 'torchao.quantization.autoquant.AQInt8DynamicallyQuantizedLinearWeight'> matmul, to_beat: 0.045ms
# >>time: 0.051ms for <class 'torchao.quantization.autoquant.AQInt8DynamicallyQuantizedLinearWeight'>, to_beat: 0.058ms
# >>time: 0.050ms for <class 'torchao.quantization.autoquant.AQInt8DynamicallyQuantizedLinearWeight'> interpolated, breakeven constant: 0.27
# >>time: 0.045ms for <class 'torchao.quantization.autoquant.AQInt8DynamicallyQuantizedLinearWeight'> matmul, to_beat: 0.020ms
# >time (all shapes): 0.0450ms for <class 'torchao.quantization.autoquant.AQInt8DynamicallyQuantizedLinearWeight'>, prev_best: 0.0210ms
# >>time: 0.072ms for <class 'torchao.quantization.autoquant.AQFloat8WeightOnlyQuantizedLinearWeight'>, to_beat: 0.045ms
# >>time: 0.061ms for <class 'torchao.quantization.autoquant.AQFloat8WeightOnlyQuantizedLinearWeight'>, to_beat: 0.020ms
# >time (all shapes): 0.0616ms for <class 'torchao.quantization.autoquant.AQFloat8WeightOnlyQuantizedLinearWeight'>, prev_best: 0.0210ms
# warning: failed to autoquant AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight for shape: (torch.Size([7, 8192]), torch.Size([2048, 8192]), None, torch.float32) due to PerRow quantization only works for bfloat16 precision input activation
# warning: failed to autoquant AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight for shape: (torch.Size([1, 8192]), torch.Size([2048, 8192]), None, torch.float32) due to PerRow quantization only works for bfloat16 precision input activation
# >time (all shapes): infms for <class 'torchao.quantization.autoquant.AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight'>, prev_best: 0.0210ms
# best_cls=<class 'torchao.quantization.autoquant.AQInt8WeightOnlyQuantizedLinearWeight2'>
#
# activation_shapes: torch.Size([1, 2048]), times_seen: 23
# weight_shape: torch.Size([128256, 2048]), dtype: torch.float32, bias_shape: None
# >>time: 0.345ms for <class 'torchao.quantization.autoquant.AQFloatLinearWeight'>, to_beat: infms
# >>time: 0.105ms for <class 'torchao.quantization.autoquant.AQInt8WeightOnlyQuantizedLinearWeight'>, to_beat: 0.345ms
# >>time: 0.105ms for <class 'torchao.quantization.autoquant.AQInt8WeightOnlyQuantizedLinearWeight2'>, to_beat: 0.105ms
# >>time: 0.110ms for <class 'torchao.quantization.autoquant.AQInt8DynamicallyQuantizedLinearWeight'> matmul, to_beat: 0.105ms
# >>time: 0.806ms for <class 'torchao.quantization.autoquant.AQFloat8WeightOnlyQuantizedLinearWeight'>, to_beat: 0.105ms
# warning: failed to autoquant AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight for shape: (torch.Size([1, 2048]), torch.Size([128256, 2048]), None, torch.float32) due to PerRow quantization only works for bfloat16 precision input activation
# best_cls=<class 'torchao.quantization.autoquant.AQInt8WeightOnlyQuantizedLinearWeight2'>
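#
# Every layer lands on AQInt8WeightOnlyQuantizedLinearWeight2, i.e. int8
# weight-only quantization wins across all five linear shapes. Two hedged
# follow-up sketches (assumptions, not part of the original run): skip the
# search and apply int8 weight-only quantization directly via the quantize_
# API imported above, and load the model in bfloat16 so the float8 per-row
# candidate no longer fails its "only works for bfloat16 precision input
# activation" check:
#
#     model = AutoModelForCausalLM.from_pretrained(
#         model_id, torch_dtype=torch.bfloat16
#     ).to("cuda")
#     quantize_(model, int8_weight_only())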