# test_autoquant.py — gist by cli99, created October 16, 2024
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    import torchao
from torchao.quantization.autoquant import (
    DEFAULT_AUTOQUANT_CLASS_LIST,
    DEFAULT_INT4_AUTOQUANT_CLASS_LIST,
    OTHER_AUTOQUANT_CLASS_LIST,
)
from torchao.quantization.quant_api import (
    int4_weight_only,
    int8_dynamic_activation_int8_weight,
    int8_weight_only,
    quantize_,
)

    model_id = "/mnt/workdisk/chengli/models/meta-llama/Llama-3.2-1B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    prompt = "The president of the United States"
    inputs = tokenizer.encode(prompt, return_tensors="pt").to("cuda")
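# Hedged alternative (not part of the original run): calling the tokenizer directly
# also returns an attention_mask, which silences the mask/pad warnings that
# generate() emits in the captured output further below:
#
#   enc = tokenizer(prompt, return_tensors="pt").to("cuda")
#   generate_ids = model.generate(enc.input_ids, attention_mask=enc.attention_mask, max_length=30)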

    model = AutoModelForCausalLM.from_pretrained(model_id)
    model = model.to("cuda")
    model = torch.compile(model, mode="max-autotune")
    print("compiled", model)

# autoquant wraps each nn.Linear weight so that activation shapes can be recorded
# and candidate quantized layouts benchmarked against the float baseline
model = torchao.autoquant(
    model, qtensor_class_list=DEFAULT_AUTOQUANT_CLASS_LIST + OTHER_AUTOQUANT_CLASS_LIST
)

    print("autoquantized", model)

    generate_ids = model.generate(inputs, max_length=30)
    model.finalize_autoquant()
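# generate() doubles as the calibration pass: the wrapped weights record each
# activation shape they see (the "activation_shapes ... times_seen" lines below),
# and finalize_autoquant() then times every candidate class per shape and swaps
# in the winner reported as best_cls.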

response = tokenizer.batch_decode(
    generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)[0]
    print(response)
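
# Every best_cls in the log below is AQInt8WeightOnlyQuantizedLinearWeight2, so a
# hedged equivalent (an assumption, not part of the original run) is to apply int8
# weight-only quantization statically with the quantize_ API imported above:
#
#   model = AutoModelForCausalLM.from_pretrained(model_id).to("cuda")
#   quantize_(model, int8_weight_only())
#   model = torch.compile(model, mode="max-autotune")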


    # ---------------------------------------------------------------------------

    # The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
    # Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
    # The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
    # Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)
    # /mnt/workdisk/chengli/miniconda3/envs/vllm/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py:663: UserWarning: Graph break due to unsupported builtin None.TensorBase._make_wrapper_subclass. This function is either a Python builtin (e.g. _warnings.warn) or a third-party C/C++ Python extension (perhaps created with pybind). If it is a Python builtin, please file an issue on GitHub so the PyTorch team can add support for it and see the next case for a workaround. If it is a third-party C/C++ Python extension, please either wrap it into a PyTorch-understood custom operator (see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html for more details) or, if it is traceable, use torch.compiler.allow_in_graph.
    # torch._dynamo.utils.warn_once(msg)
    # /mnt/workdisk/chengli/miniconda3/envs/vllm/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py:663: UserWarning: Graph break due to unsupported builtin None.TensorBase._make_wrapper_subclass. This function is either a Python builtin (e.g. _warnings.warn) or a third-party C/C++ Python extension (perhaps created with pybind). If it is a Python builtin, please file an issue on GitHub so the PyTorch team can add support for it and see the next case for a workaround. If it is a third-party C/C++ Python extension, please either wrap it into a PyTorch-understood custom operator (see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html for more details) or, if it is traceable, use torch.compiler.allow_in_graph.
    # torch._dynamo.utils.warn_once(msg)
    # /mnt/workdisk/chengli/miniconda3/envs/vllm/lib/python3.11/site-packages/torch/_dynamo/variables/functions.py:663: UserWarning: Graph break due to unsupported builtin None.TensorBase._make_wrapper_subclass. This function is either a Python builtin (e.g. _warnings.warn) or a third-party C/C++ Python extension (perhaps created with pybind). If it is a Python builtin, please file an issue on GitHub so the PyTorch team can add support for it and see the next case for a workaround. If it is a third-party C/C++ Python extension, please either wrap it into a PyTorch-understood custom operator (see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html for more details) or, if it is traceable, use torch.compiler.allow_in_graph.
    # torch._dynamo.utils.warn_once(msg)
# compiled OptimizedModule(
#   (_orig_mod): LlamaForCausalLM(
#     (model): LlamaModel(
#       (embed_tokens): Embedding(128256, 2048)
#       (layers): ModuleList(
#         (0-15): 16 x LlamaDecoderLayer(
#           (self_attn): LlamaSdpaAttention(
#             (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
#             (k_proj): Linear(in_features=2048, out_features=512, bias=False)
#             (v_proj): Linear(in_features=2048, out_features=512, bias=False)
#             (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
#             (rotary_emb): LlamaRotaryEmbedding()
#           )
#           (mlp): LlamaMLP(
#             (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)
#             (up_proj): Linear(in_features=2048, out_features=8192, bias=False)
#             (down_proj): Linear(in_features=8192, out_features=2048, bias=False)
#             (act_fn): SiLU()
#           )
#           (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
#           (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
#         )
#       )
#       (norm): LlamaRMSNorm((2048,), eps=1e-05)
#       (rotary_emb): LlamaRotaryEmbedding()
#     )
#     (lm_head): Linear(in_features=2048, out_features=128256, bias=False)
#   )
# )
# autoquantized OptimizedModule(
#   (_orig_mod): LlamaForCausalLM(
#     (model): LlamaModel(
#       (embed_tokens): Embedding(128256, 2048)
#       (layers): ModuleList(
#         (0-15): 16 x LlamaDecoderLayer(
#           (self_attn): LlamaSdpaAttention(
#             (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
#             (k_proj): Linear(in_features=2048, out_features=512, bias=False)
#             (v_proj): Linear(in_features=2048, out_features=512, bias=False)
#             (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
#             (rotary_emb): LlamaRotaryEmbedding()
#           )
#           (mlp): LlamaMLP(
#             (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)
#             (up_proj): Linear(in_features=2048, out_features=8192, bias=False)
#             (down_proj): Linear(in_features=8192, out_features=2048, bias=False)
#             (act_fn): SiLU()
#           )
#           (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
#           (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
#         )
#       )
#       (norm): LlamaRMSNorm((2048,), eps=1e-05)
#       (rotary_emb): LlamaRotaryEmbedding()
#     )
#     (lm_head): Linear(in_features=2048, out_features=128256, bias=False)
#   )
# )
    # activation_shapes: torch.Size([7, 2048]), times_seen: 1
    # activation_shapes: torch.Size([1, 2048]), times_seen: 22
    # weight_shape: torch.Size([2048, 2048]), dtype: torch.float32, bias_shape: None
    # >>time: 0.020ms for <class 'torchao.quantization.autoquant.AQFloatLinearWeight'>, to_beat: infms
    # >>time: 0.018ms for <class 'torchao.quantization.autoquant.AQFloatLinearWeight'>, to_beat: infms
    # >time (all shapes): 0.0181ms for <class 'torchao.quantization.autoquant.AQFloatLinearWeight'>, prev_best: infms
    # >>time: 0.028ms for <class 'torchao.quantization.autoquant.AQInt8WeightOnlyQuantizedLinearWeight'>, to_beat: 0.020ms
    # >>time: 0.014ms for <class 'torchao.quantization.autoquant.AQInt8WeightOnlyQuantizedLinearWeight'>, to_beat: 0.018ms
    # >time (all shapes): 0.0142ms for <class 'torchao.quantization.autoquant.AQInt8WeightOnlyQuantizedLinearWeight'>, prev_best: 0.0181ms
    # >>time: 0.020ms for <class 'torchao.quantization.autoquant.AQInt8WeightOnlyQuantizedLinearWeight2'>, to_beat: 0.028ms
    # >>time: 0.013ms for <class 'torchao.quantization.autoquant.AQInt8WeightOnlyQuantizedLinearWeight2'>, to_beat: 0.014ms
    # >time (all shapes): 0.0137ms for <class 'torchao.quantization.autoquant.AQInt8WeightOnlyQuantizedLinearWeight2'>, prev_best: 0.0142ms
    # >>time: 0.019ms for <class 'torchao.quantization.autoquant.AQInt8DynamicallyQuantizedLinearWeight'> matmul, to_beat: 0.020ms
    # >>time: 0.027ms for <class 'torchao.quantization.autoquant.AQInt8DynamicallyQuantizedLinearWeight'>, to_beat: 0.024ms
    # >>time: 0.025ms for <class 'torchao.quantization.autoquant.AQInt8DynamicallyQuantizedLinearWeight'> interpolated, breakeven constant: 0.09
    # >>time: 0.021ms for <class 'torchao.quantization.autoquant.AQInt8DynamicallyQuantizedLinearWeight'> matmul, to_beat: 0.013ms
    # >time (all shapes): 0.0210ms for <class 'torchao.quantization.autoquant.AQInt8DynamicallyQuantizedLinearWeight'>, prev_best: 0.0137ms
    # >>time: 0.028ms for <class 'torchao.quantization.autoquant.AQFloat8WeightOnlyQuantizedLinearWeight'>, to_beat: 0.020ms
    # >>time: 0.024ms for <class 'torchao.quantization.autoquant.AQFloat8WeightOnlyQuantizedLinearWeight'>, to_beat: 0.013ms
    # >time (all shapes): 0.0239ms for <class 'torchao.quantization.autoquant.AQFloat8WeightOnlyQuantizedLinearWeight'>, prev_best: 0.0137ms
    # warning: failed to autoquant AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight for shape: (torch.Size([7, 2048]), torch.Size([2048, 2048]), None, torch.float32) due to PerRow quantization only works for bfloat16 precision input activation
    # warning: failed to autoquant AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight for shape: (torch.Size([1, 2048]), torch.Size([2048, 2048]), None, torch.float32) due to PerRow quantization only works for bfloat16 precision input activation
    # >time (all shapes): infms for <class 'torchao.quantization.autoquant.AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight'>, prev_best: 0.0137ms
    # best_cls=<class 'torchao.quantization.autoquant.AQInt8WeightOnlyQuantizedLinearWeight2'>

    # activation_shapes: torch.Size([7, 2048]), times_seen: 1
    # activation_shapes: torch.Size([1, 2048]), times_seen: 22
    # weight_shape: torch.Size([512, 2048]), dtype: torch.float32, bias_shape: None
    # >>time: 0.017ms for <class 'torchao.quantization.autoquant.AQFloatLinearWeight'>, to_beat: infms
    # >>time: 0.013ms for <class 'torchao.quantization.autoquant.AQFloatLinearWeight'>, to_beat: infms
    # >time (all shapes): 0.0134ms for <class 'torchao.quantization.autoquant.AQFloatLinearWeight'>, prev_best: infms
    # >>time: 0.022ms for <class 'torchao.quantization.autoquant.AQInt8WeightOnlyQuantizedLinearWeight'>, to_beat: 0.017ms
    # >>time: 0.012ms for <class 'torchao.quantization.autoquant.AQInt8WeightOnlyQuantizedLinearWeight'>, to_beat: 0.013ms
    # >time (all shapes): 0.0125ms for <class 'torchao.quantization.autoquant.AQInt8WeightOnlyQuantizedLinearWeight'>, prev_best: 0.0134ms
    # >>time: 0.014ms for <class 'torchao.quantization.autoquant.AQInt8WeightOnlyQuantizedLinearWeight2'>, to_beat: 0.022ms
    # >>time: 0.012ms for <class 'torchao.quantization.autoquant.AQInt8WeightOnlyQuantizedLinearWeight2'>, to_beat: 0.012ms
    # >time (all shapes): 0.0122ms for <class 'torchao.quantization.autoquant.AQInt8WeightOnlyQuantizedLinearWeight2'>, prev_best: 0.0125ms
    # >>time: 0.018ms for <class 'torchao.quantization.autoquant.AQInt8DynamicallyQuantizedLinearWeight'> matmul, to_beat: 0.014ms
    # >>time: 0.020ms for <class 'torchao.quantization.autoquant.AQInt8DynamicallyQuantizedLinearWeight'> matmul, to_beat: 0.012ms
    # >time (all shapes): 0.0196ms for <class 'torchao.quantization.autoquant.AQInt8DynamicallyQuantizedLinearWeight'>, prev_best: 0.0122ms
    # >>time: 0.020ms for <class 'torchao.quantization.autoquant.AQFloat8WeightOnlyQuantizedLinearWeight'>, to_beat: 0.014ms
    # >>time: 0.016ms for <class 'torchao.quantization.autoquant.AQFloat8WeightOnlyQuantizedLinearWeight'>, to_beat: 0.012ms
    # >time (all shapes): 0.0159ms for <class 'torchao.quantization.autoquant.AQFloat8WeightOnlyQuantizedLinearWeight'>, prev_best: 0.0122ms
    # warning: failed to autoquant AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight for shape: (torch.Size([7, 2048]), torch.Size([512, 2048]), None, torch.float32) due to PerRow quantization only works for bfloat16 precision input activation
    # warning: failed to autoquant AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight for shape: (torch.Size([1, 2048]), torch.Size([512, 2048]), None, torch.float32) due to PerRow quantization only works for bfloat16 precision input activation
    # >time (all shapes): infms for <class 'torchao.quantization.autoquant.AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight'>, prev_best: 0.0122ms
    # best_cls=<class 'torchao.quantization.autoquant.AQInt8WeightOnlyQuantizedLinearWeight2'>

    # activation_shapes: torch.Size([7, 2048]), times_seen: 1
    # activation_shapes: torch.Size([1, 2048]), times_seen: 22
    # weight_shape: torch.Size([8192, 2048]), dtype: torch.float32, bias_shape: None
    # >>time: 0.041ms for <class 'torchao.quantization.autoquant.AQFloatLinearWeight'>, to_beat: infms
    # >>time: 0.034ms for <class 'torchao.quantization.autoquant.AQFloatLinearWeight'>, to_beat: infms
    # >time (all shapes): 0.0339ms for <class 'torchao.quantization.autoquant.AQFloatLinearWeight'>, prev_best: infms
    # >>time: 0.071ms for <class 'torchao.quantization.autoquant.AQInt8WeightOnlyQuantizedLinearWeight'>, to_beat: 0.041ms
    # >>time: 0.018ms for <class 'torchao.quantization.autoquant.AQInt8WeightOnlyQuantizedLinearWeight'>, to_beat: 0.034ms
    # >time (all shapes): 0.0204ms for <class 'torchao.quantization.autoquant.AQInt8WeightOnlyQuantizedLinearWeight'>, prev_best: 0.0339ms
    # >>time: 0.044ms for <class 'torchao.quantization.autoquant.AQInt8WeightOnlyQuantizedLinearWeight2'>, to_beat: 0.071ms
    # >>time: 0.018ms for <class 'torchao.quantization.autoquant.AQInt8WeightOnlyQuantizedLinearWeight2'>, to_beat: 0.018ms
    # >time (all shapes): 0.0192ms for <class 'torchao.quantization.autoquant.AQInt8WeightOnlyQuantizedLinearWeight2'>, prev_best: 0.0204ms
    # >>time: 0.023ms for <class 'torchao.quantization.autoquant.AQInt8DynamicallyQuantizedLinearWeight'> matmul, to_beat: 0.044ms
    # >>time: 0.030ms for <class 'torchao.quantization.autoquant.AQInt8DynamicallyQuantizedLinearWeight'>, to_beat: 0.168ms
    # >>time: 0.029ms for <class 'torchao.quantization.autoquant.AQInt8DynamicallyQuantizedLinearWeight'> interpolated, breakeven constant: 2.93
    # >>time: 0.024ms for <class 'torchao.quantization.autoquant.AQInt8DynamicallyQuantizedLinearWeight'> matmul, to_beat: 0.018ms
    # >time (all shapes): 0.0244ms for <class 'torchao.quantization.autoquant.AQInt8DynamicallyQuantizedLinearWeight'>, prev_best: 0.0192ms
    # >>time: 0.075ms for <class 'torchao.quantization.autoquant.AQFloat8WeightOnlyQuantizedLinearWeight'>, to_beat: 0.044ms
    # >>time: 0.069ms for <class 'torchao.quantization.autoquant.AQFloat8WeightOnlyQuantizedLinearWeight'>, to_beat: 0.018ms
    # >time (all shapes): 0.0691ms for <class 'torchao.quantization.autoquant.AQFloat8WeightOnlyQuantizedLinearWeight'>, prev_best: 0.0192ms
    # warning: failed to autoquant AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight for shape: (torch.Size([7, 2048]), torch.Size([8192, 2048]), None, torch.float32) due to PerRow quantization only works for bfloat16 precision input activation
    # warning: failed to autoquant AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight for shape: (torch.Size([1, 2048]), torch.Size([8192, 2048]), None, torch.float32) due to PerRow quantization only works for bfloat16 precision input activation
    # >time (all shapes): infms for <class 'torchao.quantization.autoquant.AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight'>, prev_best: 0.0192ms
    # best_cls=<class 'torchao.quantization.autoquant.AQInt8WeightOnlyQuantizedLinearWeight2'>

    # activation_shapes: torch.Size([7, 8192]), times_seen: 1
    # activation_shapes: torch.Size([1, 8192]), times_seen: 22
    # weight_shape: torch.Size([2048, 8192]), dtype: torch.float32, bias_shape: None
    # >>time: 0.046ms for <class 'torchao.quantization.autoquant.AQFloatLinearWeight'>, to_beat: infms
    # >>time: 0.035ms for <class 'torchao.quantization.autoquant.AQFloatLinearWeight'>, to_beat: infms
    # >time (all shapes): 0.0358ms for <class 'torchao.quantization.autoquant.AQFloatLinearWeight'>, prev_best: infms
    # >>time: 0.078ms for <class 'torchao.quantization.autoquant.AQInt8WeightOnlyQuantizedLinearWeight'>, to_beat: 0.046ms
    # >>time: 0.020ms for <class 'torchao.quantization.autoquant.AQInt8WeightOnlyQuantizedLinearWeight'>, to_beat: 0.035ms
    # >time (all shapes): 0.0222ms for <class 'torchao.quantization.autoquant.AQInt8WeightOnlyQuantizedLinearWeight'>, prev_best: 0.0358ms
    # >>time: 0.045ms for <class 'torchao.quantization.autoquant.AQInt8WeightOnlyQuantizedLinearWeight2'>, to_beat: 0.078ms
    # >>time: 0.020ms for <class 'torchao.quantization.autoquant.AQInt8WeightOnlyQuantizedLinearWeight2'>, to_beat: 0.020ms
    # >time (all shapes): 0.0210ms for <class 'torchao.quantization.autoquant.AQInt8WeightOnlyQuantizedLinearWeight2'>, prev_best: 0.0222ms
    # >>time: 0.043ms for <class 'torchao.quantization.autoquant.AQInt8DynamicallyQuantizedLinearWeight'> matmul, to_beat: 0.045ms
    # >>time: 0.051ms for <class 'torchao.quantization.autoquant.AQInt8DynamicallyQuantizedLinearWeight'>, to_beat: 0.058ms
    # >>time: 0.050ms for <class 'torchao.quantization.autoquant.AQInt8DynamicallyQuantizedLinearWeight'> interpolated, breakeven constant: 0.27
    # >>time: 0.045ms for <class 'torchao.quantization.autoquant.AQInt8DynamicallyQuantizedLinearWeight'> matmul, to_beat: 0.020ms
    # >time (all shapes): 0.0450ms for <class 'torchao.quantization.autoquant.AQInt8DynamicallyQuantizedLinearWeight'>, prev_best: 0.0210ms
    # >>time: 0.072ms for <class 'torchao.quantization.autoquant.AQFloat8WeightOnlyQuantizedLinearWeight'>, to_beat: 0.045ms
    # >>time: 0.061ms for <class 'torchao.quantization.autoquant.AQFloat8WeightOnlyQuantizedLinearWeight'>, to_beat: 0.020ms
    # >time (all shapes): 0.0616ms for <class 'torchao.quantization.autoquant.AQFloat8WeightOnlyQuantizedLinearWeight'>, prev_best: 0.0210ms
    # warning: failed to autoquant AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight for shape: (torch.Size([7, 8192]), torch.Size([2048, 8192]), None, torch.float32) due to PerRow quantization only works for bfloat16 precision input activation
    # warning: failed to autoquant AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight for shape: (torch.Size([1, 8192]), torch.Size([2048, 8192]), None, torch.float32) due to PerRow quantization only works for bfloat16 precision input activation
    # >time (all shapes): infms for <class 'torchao.quantization.autoquant.AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight'>, prev_best: 0.0210ms
    # best_cls=<class 'torchao.quantization.autoquant.AQInt8WeightOnlyQuantizedLinearWeight2'>

    # activation_shapes: torch.Size([1, 2048]), times_seen: 23
    # weight_shape: torch.Size([128256, 2048]), dtype: torch.float32, bias_shape: None
    # >>time: 0.345ms for <class 'torchao.quantization.autoquant.AQFloatLinearWeight'>, to_beat: infms
    # >>time: 0.105ms for <class 'torchao.quantization.autoquant.AQInt8WeightOnlyQuantizedLinearWeight'>, to_beat: 0.345ms
    # >>time: 0.105ms for <class 'torchao.quantization.autoquant.AQInt8WeightOnlyQuantizedLinearWeight2'>, to_beat: 0.105ms
    # >>time: 0.110ms for <class 'torchao.quantization.autoquant.AQInt8DynamicallyQuantizedLinearWeight'> matmul, to_beat: 0.105ms
    # >>time: 0.806ms for <class 'torchao.quantization.autoquant.AQFloat8WeightOnlyQuantizedLinearWeight'>, to_beat: 0.105ms
    # warning: failed to autoquant AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight for shape: (torch.Size([1, 2048]), torch.Size([128256, 2048]), None, torch.float32) due to PerRow quantization only works for bfloat16 precision input activation
    # best_cls=<class 'torchao.quantization.autoquant.AQInt8WeightOnlyQuantizedLinearWeight2'>