@JIElite
Created November 26, 2024 12:57
test.ipynb
    {
    "cells": [
    {
    "cell_type": "markdown",
    "id": "24a6fa13-36d2-46a9-a4f0-30387ab5fe04",
    "metadata": {},
    "source": [
    "# Explain the arithmetic intensity of conv1 in Whisper\n"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 1,
    "id": "14f1b230-1db9-4da5-b5a5-b20718ac7518",
    "metadata": {},
    "outputs": [],
    "source": [
    "import torch\n",
    "from transformers import AutoModelForSpeechSeq2Seq"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 2,
    "id": "e8a0650d-6396-4e52-9233-345d77376f86",
    "metadata": {},
    "outputs": [],
    "source": [
    "model = AutoModelForSpeechSeq2Seq.from_pretrained(\"openai/whisper-large-v3\")"
    ]
    },
    {
    "cell_type": "markdown",
    "id": "85d8d04f-734e-4d97-aa08-c65347a1eb21",
    "metadata": {},
    "source": [
    "## Conv1d in Whisper Encoder\n",
    "\n",
    "The first Conv1d\n",
    "- in_channels = 128\n",
    "- out_channels = 1280\n",
    "- kernel_size = 3\n",
    "- stride = 1\n",
    "- padding = 1\n",
    "\n",
    "The second Conv1d\n",
    "- in_channels = 1280\n",
    "- out_channels = 1280\n",
    "- kernel_size = 3\n",
    "- stride = 2\n",
    "- padding = 1"
    ]
    },
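{
"cell_type": "markdown",
"id": "3f2b9c1a-0001-4a6e-9c2d-0a1b2c3d4e01",
"metadata": {},
"source": [
"As a quick sanity check, the standard Conv1d output-length formula `L_out = floor((L_in + 2 * padding - kernel_size) / stride) + 1` reproduces the lengths used throughout this notebook (`conv1d_out_len` below is just a local helper, not a library function):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3f2b9c1a-0002-4a6e-9c2d-0a1b2c3d4e02",
"metadata": {},
"outputs": [],
"source": [
"def conv1d_out_len(l_in, kernel_size, stride, padding):\n",
"    # Standard Conv1d output-length formula (dilation = 1)\n",
"    return (l_in + 2 * padding - kernel_size) // stride + 1\n",
"\n",
"# conv1 (stride 1) keeps the 3000 frames; conv2 (stride 2) halves them\n",
"print(conv1d_out_len(3000, kernel_size=3, stride=1, padding=1))  # 3000\n",
"print(conv1d_out_len(3000, kernel_size=3, stride=2, padding=1))  # 1500"
]
},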
    {
    "cell_type": "code",
    "execution_count": 3,
    "id": "c4228110-c757-4239-bf55-2e30490884d6",
    "metadata": {},
    "outputs": [
    {
    "data": {
    "text/plain": [
    "WhisperEncoder(\n",
    " (conv1): Conv1d(128, 1280, kernel_size=(3,), stride=(1,), padding=(1,))\n",
    " (conv2): Conv1d(1280, 1280, kernel_size=(3,), stride=(2,), padding=(1,))\n",
    " (embed_positions): Embedding(1500, 1280)\n",
    " (layers): ModuleList(\n",
    " (0-31): 32 x WhisperEncoderLayer(\n",
    " (self_attn): WhisperSdpaAttention(\n",
    " (k_proj): Linear(in_features=1280, out_features=1280, bias=False)\n",
    " (v_proj): Linear(in_features=1280, out_features=1280, bias=True)\n",
    " (q_proj): Linear(in_features=1280, out_features=1280, bias=True)\n",
    " (out_proj): Linear(in_features=1280, out_features=1280, bias=True)\n",
    " )\n",
    " (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n",
    " (activation_fn): GELUActivation()\n",
    " (fc1): Linear(in_features=1280, out_features=5120, bias=True)\n",
    " (fc2): Linear(in_features=5120, out_features=1280, bias=True)\n",
    " (final_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n",
    " )\n",
    " )\n",
    " (layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n",
    ")"
    ]
    },
    "execution_count": 3,
    "metadata": {},
    "output_type": "execute_result"
    }
    ],
    "source": [
    "model.model.encoder"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 4,
    "id": "3a559ee0-de9f-4ffb-a752-baf6a3359466",
    "metadata": {},
    "outputs": [
    {
    "name": "stdout",
    "output_type": "stream",
    "text": [
    "Conv1d(128, 1280, kernel_size=(3,), stride=(1,), padding=(1,))\n",
    "The shape of bias: torch.Size([1280])\n"
    ]
    }
    ],
    "source": [
    "# The first conv1 layer in whisper\n",
    "print(model.model.encoder.conv1)\n",
    "\n",
    "# Exist bias\n",
    "model.model.encoder.conv1.bias\n",
    "print(\"The shape of bias:\", model.model.encoder.conv1.bias.shape)"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 5,
    "id": "9a79d19b-b776-4764-b812-16e413956919",
    "metadata": {},
    "outputs": [
    {
    "name": "stdout",
    "output_type": "stream",
    "text": [
    "Conv1d(1280, 1280, kernel_size=(3,), stride=(2,), padding=(1,))\n",
    "The shape of bias: torch.Size([1280])\n"
    ]
    }
    ],
    "source": [
    "# The second conv1 layer in whisper\n",
    "print(model.model.encoder.conv2)\n",
    "model.model.encoder.conv2.bias\n",
    "print(\"The shape of bias:\", model.model.encoder.conv2.bias.shape)"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 6,
    "id": "c9cff408-3f90-4df2-82cc-4f158b8ba150",
    "metadata": {},
    "outputs": [],
    "source": [
    "# Create an example input of Whisper, the shape of input tensor is\n",
    "# equivalent to Mel-Spectrogram\n",
    "x = torch.randn(1, 128, 3000)"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 7,
    "id": "089d3bcd-59d6-4b2c-a1fb-5fffa09c1100",
    "metadata": {},
    "outputs": [
    {
    "name": "stdout",
    "output_type": "stream",
    "text": [
    "torch.Size([1, 1280, 3000])\n"
    ]
    }
    ],
    "source": [
    "# The output shape of first conv1d = [1, 1280, 3000]\n",
    "output = model.model.encoder.conv1(x)\n",
    "print(output.shape)"
    ]
    },
    {
    "cell_type": "markdown",
    "id": "6baf95db-0db3-4008-b378-78b492600905",
    "metadata": {},
    "source": [
    "## The Mops of `model.model.encoder.conv1`\n",
    "\n",
    "mops = input + weights + output"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 8,
    "id": "a0235787-4c1b-423f-b798-626259b15a68",
    "metadata": {},
    "outputs": [
    {
    "name": "stdout",
    "output_type": "stream",
    "text": [
    "4_715_520\n"
    ]
    }
    ],
    "source": [
    "mops = 1 * 128 * 3000 + 3 * 128 * 1280 + 1 * 1280 * 3000\n",
    "print(f\"{mops:_}\")"
    ]
    },
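{
"cell_type": "markdown",
"id": "3f2b9c1a-0003-4a6e-9c2d-0a1b2c3d4e03",
"metadata": {},
"source": [
"The same count can be cross-checked directly from the tensors with `.numel()`. Note this version also counts the 1_280 bias elements, which the hand count above leaves out, so it comes to 4_716_800 instead of 4_715_520:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3f2b9c1a-0004-4a6e-9c2d-0a1b2c3d4e04",
"metadata": {},
"outputs": [],
"source": [
"conv1 = model.model.encoder.conv1\n",
"# elements moved: input + weights + bias + output\n",
"mops_check = x.numel() + conv1.weight.numel() + conv1.bias.numel() + output.numel()\n",
"print(f\"{mops_check:_}\")  # 4_716_800 = 4_715_520 + 1_280 bias elements"
]
},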
    {
    "cell_type": "markdown",
    "id": "fb6815cd-b06d-4a8b-9af4-2374dc94f9e4",
    "metadata": {},
    "source": [
    "## The FLOPs of `model.model.encoder.conv1`\n",
    "\n",
    "each kernel size = 3 * 128, so there are 3 * 128 multiplications\n",
    "\n",
    "after convolution, we need to add up the value from dot product, there are 3 * 128 - 1 add operations\n",
    "\n",
    "because there are 1280 kernels, "
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 9,
    "id": "1c4e52de-110b-4439-89d1-91b3419d275f",
    "metadata": {},
    "outputs": [
    {
    "name": "stdout",
    "output_type": "stream",
    "text": [
    "2_949_120_000\n"
    ]
    }
    ],
    "source": [
    "flops = ((3 + 2) * 128 + 127) * 1280 * 3000 + 1280 * 3000\n",
    "print(f\"{flops:_}\")"
    ]
    },
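{
"cell_type": "markdown",
"id": "3f2b9c1a-0005-4a6e-9c2d-0a1b2c3d4e05",
"metadata": {},
"source": [
"Equivalently, the per-output count of 3 * 128 multiplications, 3 * 128 - 1 additions, and 1 bias addition is exactly 2 * 3 * 128, so the total collapses to the closed form `2 * kernel_size * in_channels * out_channels * L_out`. A small helper, assuming this 2-FLOPs-per-multiply-accumulate convention (`conv1d_flops` is ours, not a library function):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3f2b9c1a-0006-4a6e-9c2d-0a1b2c3d4e06",
"metadata": {},
"outputs": [],
"source": [
"def conv1d_flops(kernel_size, in_channels, out_channels, l_out):\n",
"    # 2 FLOPs (one multiply + one add) per weight per output element;\n",
"    # the bias add supplies the final add of each dot product\n",
"    return 2 * kernel_size * in_channels * out_channels * l_out\n",
"\n",
"print(f\"{conv1d_flops(3, 128, 1280, 3000):_}\")  # 2_949_120_000"
]
},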
    {
    "cell_type": "markdown",
    "id": "b344fa5b-714c-40d4-b726-0844a3fef4b0",
    "metadata": {},
    "source": [
    "## Arithmetic Intensity of Conv1"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 10,
    "id": "0e8f3002-f7a2-4874-b9fa-1bc8bc85ed4b",
    "metadata": {},
    "outputs": [
    {
    "name": "stdout",
    "output_type": "stream",
    "text": [
    "625.4071661237786\n"
    ]
    }
    ],
    "source": [
    "arithmetic_intensity = flops / mops\n",
    "print(arithmetic_intensity)"
    ]
    },
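{
"cell_type": "markdown",
"id": "3f2b9c1a-0007-4a6e-9c2d-0a1b2c3d4e07",
"metadata": {},
"source": [
"This ratio is FLOPs per *element* moved. To compare against a hardware roofline you usually want FLOPs per *byte*; a minimal conversion, assuming fp32 (4 bytes per element):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3f2b9c1a-0008-4a6e-9c2d-0a1b2c3d4e08",
"metadata": {},
"outputs": [],
"source": [
"bytes_per_element = 4  # assuming fp32; use 2 for fp16/bf16\n",
"print(flops / (mops * bytes_per_element))  # ~156.35 FLOPs per byte"
]
},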
    {
    "cell_type": "markdown",
    "id": "1194b6f5-161e-4e80-a88a-da2647d6b3e3",
    "metadata": {},
    "source": [
    "## Verify flops with pytorch built-in flop counter"
    ]
    },
    {
    "cell_type": "markdown",
    "id": "9a239c3b-8043-4b1b-b09f-191b08196571",
    "metadata": {},
    "source": [
    "**The FLOPs in first conv1d**"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 11,
    "id": "355a9ed5-3378-478d-a102-e42dfdcf3717",
    "metadata": {},
    "outputs": [
    {
    "name": "stdout",
    "output_type": "stream",
    "text": [
    "FLOPs: 2_949_120_000\n"
    ]
    }
    ],
    "source": [
    "from torch.utils.flop_counter import FlopCounterMode\n",
    "\n",
    "flop_counter = FlopCounterMode(display=False, depth=None)\n",
    "conv_layer = model.model.encoder.conv1\n",
    "x = torch.randn(1, 128, 3000)\n",
    "\n",
    "with flop_counter:\n",
    " conv_layer(x)\n",
    "\n",
    "total_flops = flop_counter.get_total_flops()\n",
    "print(f\"FLOPs: {total_flops:_}\")"
    ]
    },
    {
    "cell_type": "markdown",
    "id": "366e480d-f7e6-492c-9343-a5622b66ddb5",
    "metadata": {},
    "source": [
    "**The FLOPs in the second conv1d**"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 12,
    "id": "6efc785a-5422-4e2c-8907-8e23fb1428e4",
    "metadata": {},
    "outputs": [
    {
    "name": "stdout",
    "output_type": "stream",
    "text": [
    "FLOPs: 14_745_600_000\n"
    ]
    }
    ],
    "source": [
    "flop_counter = FlopCounterMode(display=False, depth=None)\n",
    "conv_layer = model.model.encoder.conv2\n",
    "x = torch.randn(1, 1280, 3000)\n",
    "\n",
    "with flop_counter:\n",
    " out = conv_layer(x)\n",
    "\n",
    "total_flops = flop_counter.get_total_flops()\n",
    "print(f\"FLOPs: {total_flops:_}\")"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 13,
    "id": "f5b4e894-9ecb-43eb-b8a9-958a28e8db9f",
    "metadata": {},
    "outputs": [
    {
    "name": "stdout",
    "output_type": "stream",
    "text": [
    "torch.Size([1, 1280, 1500])\n"
    ]
    }
    ],
    "source": [
    "print(out.shape)"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 14,
    "id": "9c02b580-9f0d-45f6-bb25-5f08cd3c39b6",
    "metadata": {},
    "outputs": [
    {
    "name": "stdout",
    "output_type": "stream",
    "text": [
    "1381.294964028777\n"
    ]
    }
    ],
    "source": [
    "mops = 1 * 1280 * 3000 + 3 * 1280 * 1280 + 1 * 1280 * 1500\n",
    "arithmetic_intensity = total_flops / mops\n",
    "print(arithmetic_intensity)"
    ]
    },
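{
"cell_type": "markdown",
"id": "3f2b9c1a-0009-4a6e-9c2d-0a1b2c3d4e09",
"metadata": {},
"source": [
"The counter's 14_745_600_000 FLOPs for conv2 also matches the same closed form, with the stride of 2 halving the output length to 1500:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3f2b9c1a-0010-4a6e-9c2d-0a1b2c3d4e10",
"metadata": {},
"outputs": [],
"source": [
"# same 2 * kernel_size * in_channels * out_channels * l_out convention as above\n",
"flops_conv2 = 2 * 3 * 1280 * 1280 * 1500\n",
"print(f\"{flops_conv2:_}\")  # 14_745_600_000, matching FlopCounterMode"
]
}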
    ],
    "metadata": {
    "kernelspec": {
    "display_name": "Python 3 (ipykernel)",
    "language": "python",
    "name": "python3"
    },
    "language_info": {
    "codemirror_mode": {
    "name": "ipython",
    "version": 3
    },
    "file_extension": ".py",
    "mimetype": "text/x-python",
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
    "version": "3.12.7"
    }
    },
    "nbformat": 4,
    "nbformat_minor": 5
    }