Created
November 26, 2024 12:57
-
-
Save JIElite/111d9d2aae5ec2f4800dfd43752a697d to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "id": "24a6fa13-36d2-46a9-a4f0-30387ab5fe04", | |
| "metadata": {}, | |
| "source": [ | |
| "# Explaining the arithmetic intensity of the Conv1d layers in the Whisper encoder\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "id": "14f1b230-1db9-4da5-b5a5-b20718ac7518", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import torch\n", | |
| "from transformers import AutoModelForSpeechSeq2Seq" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 2, | |
| "id": "e8a0650d-6396-4e52-9233-345d77376f86", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "MODEL_ID = \"openai/whisper-large-v3\"\n", | |
| "model = AutoModelForSpeechSeq2Seq.from_pretrained(MODEL_ID)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "85d8d04f-734e-4d97-aa08-c65347a1eb21", | |
| "metadata": {}, | |
| "source": [ | |
| "## Conv1d in Whisper Encoder\n", | |
| "\n", | |
| "The first Conv1d (`conv1`):\n", | |
| "- in_channels = 128\n", | |
| "- out_channels = 1280\n", | |
| "- kernel_size = 3\n", | |
| "- stride = 1\n", | |
| "- padding = 1\n", | |
| "\n", | |
| "The second Conv1d (`conv2`):\n", | |
| "- in_channels = 1280\n", | |
| "- out_channels = 1280\n", | |
| "- kernel_size = 3\n", | |
| "- stride = 2\n", | |
| "- padding = 1" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 3, | |
| "id": "c4228110-c757-4239-bf55-2e30490884d6", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "WhisperEncoder(\n", | |
| " (conv1): Conv1d(128, 1280, kernel_size=(3,), stride=(1,), padding=(1,))\n", | |
| " (conv2): Conv1d(1280, 1280, kernel_size=(3,), stride=(2,), padding=(1,))\n", | |
| " (embed_positions): Embedding(1500, 1280)\n", | |
| " (layers): ModuleList(\n", | |
| " (0-31): 32 x WhisperEncoderLayer(\n", | |
| " (self_attn): WhisperSdpaAttention(\n", | |
| " (k_proj): Linear(in_features=1280, out_features=1280, bias=False)\n", | |
| " (v_proj): Linear(in_features=1280, out_features=1280, bias=True)\n", | |
| " (q_proj): Linear(in_features=1280, out_features=1280, bias=True)\n", | |
| " (out_proj): Linear(in_features=1280, out_features=1280, bias=True)\n", | |
| " )\n", | |
| " (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n", | |
| " (activation_fn): GELUActivation()\n", | |
| " (fc1): Linear(in_features=1280, out_features=5120, bias=True)\n", | |
| " (fc2): Linear(in_features=5120, out_features=1280, bias=True)\n", | |
| " (final_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n", | |
| " )\n", | |
| " )\n", | |
| " (layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)\n", | |
| ")" | |
| ] | |
| }, | |
| "execution_count": 3, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "encoder = model.model.encoder\n", | |
| "encoder" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 4, | |
| "id": "3a559ee0-de9f-4ffb-a752-baf6a3359466", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Conv1d(128, 1280, kernel_size=(3,), stride=(1,), padding=(1,))\n", | |
| "The shape of bias: torch.Size([1280])\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# The first convolution layer (conv1) of the Whisper encoder\n", | |
| "print(model.model.encoder.conv1)\n", | |
| "\n", | |
| "# conv1 has a bias term with one value per output channel\n", | |
| "print(\"The shape of bias:\", model.model.encoder.conv1.bias.shape)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 5, | |
| "id": "9a79d19b-b776-4764-b812-16e413956919", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Conv1d(1280, 1280, kernel_size=(3,), stride=(2,), padding=(1,))\n", | |
| "The shape of bias: torch.Size([1280])\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# The second convolution layer (conv2) of the Whisper encoder\n", | |
| "print(model.model.encoder.conv2)\n", | |
| "\n", | |
| "# conv2 also has a bias term with one value per output channel\n", | |
| "print(\"The shape of bias:\", model.model.encoder.conv2.bias.shape)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 6, | |
| "id": "c9cff408-3f90-4df2-82cc-4f158b8ba150", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Random stand-in for a Mel spectrogram laid out as (batch, channels, frames);\n", | |
| "# 128 channels to match conv1's in_channels\n", | |
| "batch_size, n_channels, n_frames = 1, 128, 3000\n", | |
| "x = torch.randn(batch_size, n_channels, n_frames)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 7, | |
| "id": "089d3bcd-59d6-4b2c-a1fb-5fffa09c1100", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "torch.Size([1, 1280, 3000])\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# kernel 3 with stride 1 and padding 1 keeps the time axis, so conv1 maps [1, 128, 3000] -> [1, 1280, 3000]\n", | |
| "conv1_out = model.model.encoder.conv1(x)\n", | |
| "print(conv1_out.shape)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "6baf95db-0db3-4008-b378-78b492600905", | |
| "metadata": {}, | |
| "source": [ | |
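| "## The MOPs of `model.model.encoder.conv1`\n", | |
| "\n", | |
| "MOPs (memory operations) count the elements read from and written to memory: `mops = input + weights + output`. The counts below are in elements, not bytes, and the bias (1280 values) is left out because it is negligible next to the weights." | |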
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 8, | |
| "id": "a0235787-4c1b-423f-b798-626259b15a68", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "4_715_520\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Element counts (not bytes); conv1's bias (1280 values) is deliberately left out as negligible\n", | |
| "input_elems = 1 * 128 * 3000     # conv1 input: (1, 128, 3000)\n", | |
| "weight_elems = 3 * 128 * 1280    # kernel: (out, in, k) = (1280, 128, 3)\n", | |
| "output_elems = 1 * 1280 * 3000   # conv1 output: (1, 1280, 3000)\n", | |
| "mops = input_elems + weight_elems + output_elems\n", | |
| "print(f\"{mops:_}\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "fb6815cd-b06d-4a8b-9af4-2374dc94f9e4", | |
| "metadata": {}, | |
| "source": [ | |
| "## The FLOPs of `model.model.encoder.conv1`\n", | |
| "\n", | |
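| "Each kernel has size 3 * 128, so computing one output element takes 3 * 128 multiplications.\n", | |
| "\n", | |
| "Summing those products takes 3 * 128 - 1 additions, and adding the bias takes one more, so each output element costs 3 * 128 + (3 * 128 - 1) + 1 = 2 * 3 * 128 = 768 FLOPs.\n", | |
| "\n", | |
| "Because there are 1280 kernels (one per output channel) and 3000 output positions, the total is 768 * 1280 * 3000 = 2,949,120,000 FLOPs." | |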
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 9, | |
| "id": "1c4e52de-110b-4439-89d1-91b3419d275f", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "2_949_120_000\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Per output element: 3*128 multiplications, 3*128 - 1 additions, and 1 bias addition\n", | |
| "flops_per_output = 3 * 128 + (3 * 128 - 1) + 1  # = 2 * 3 * 128 = 768\n", | |
| "n_outputs = 1280 * 3000                         # conv1 output: (1, 1280, 3000)\n", | |
| "flops = flops_per_output * n_outputs\n", | |
| "print(f\"{flops:_}\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "b344fa5b-714c-40d4-b726-0844a3fef4b0", | |
| "metadata": {}, | |
| "source": [ | |
| "## Arithmetic intensity of conv1 (FLOPs per element moved)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 10, | |
| "id": "0e8f3002-f7a2-4874-b9fa-1bc8bc85ed4b", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "625.4071661237786\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# FLOPs per element moved to/from memory\n", | |
| "arithmetic_intensity = flops / mops\n", | |
| "print(f\"{arithmetic_intensity}\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "1194b6f5-161e-4e80-a88a-da2647d6b3e3", | |
| "metadata": {}, | |
| "source": [ | |
| "## Verify the FLOPs with PyTorch's built-in FLOP counter" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "9a239c3b-8043-4b1b-b09f-191b08196571", | |
| "metadata": {}, | |
| "source": [ | |
| "**The FLOPs of the first Conv1d (`conv1`)**" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 11, | |
| "id": "355a9ed5-3378-478d-a102-e42dfdcf3717", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "FLOPs: 2_949_120_000\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "from torch.utils.flop_counter import FlopCounterMode\n", | |
| "\n", | |
| "# Count the FLOPs of one conv1 forward pass on a Mel-shaped random input\n", | |
| "conv_layer = model.model.encoder.conv1\n", | |
| "x = torch.randn(1, 128, 3000)\n", | |
| "\n", | |
| "with FlopCounterMode(display=False, depth=None) as flop_counter:\n", | |
| "    conv_layer(x)\n", | |
| "\n", | |
| "total_flops = flop_counter.get_total_flops()\n", | |
| "print(f\"FLOPs: {total_flops:_}\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "366e480d-f7e6-492c-9343-a5622b66ddb5", | |
| "metadata": {}, | |
| "source": [ | |
| "**The FLOPs of the second Conv1d (`conv2`)**" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 12, | |
| "id": "6efc785a-5422-4e2c-8907-8e23fb1428e4", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "FLOPs: 14_745_600_000\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Count the FLOPs of one conv2 forward pass; its input has conv1's output shape [1, 1280, 3000]\n", | |
| "conv_layer = model.model.encoder.conv2\n", | |
| "x = torch.randn(1, 1280, 3000)\n", | |
| "\n", | |
| "with FlopCounterMode(display=False, depth=None) as flop_counter:\n", | |
| "    out = conv_layer(x)\n", | |
| "\n", | |
| "total_flops = flop_counter.get_total_flops()\n", | |
| "print(f\"FLOPs: {total_flops:_}\")" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 13, | |
| "id": "f5b4e894-9ecb-43eb-b8a9-958a28e8db9f", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "torch.Size([1, 1280, 1500])\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# stride 2 halves the time axis of conv2's output: 3000 -> 1500 positions\n", | |
| "print(out.size())" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 14, | |
| "id": "9c02b580-9f0d-45f6-bb25-5f08cd3c39b6", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "1381.294964028777\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Element counts (not bytes); conv2's bias (1280 values) is deliberately left out as negligible\n", | |
| "input_elems = 1 * 1280 * 3000    # conv2 input: (1, 1280, 3000)\n", | |
| "weight_elems = 3 * 1280 * 1280   # kernel: (out, in, k) = (1280, 1280, 3)\n", | |
| "output_elems = 1 * 1280 * 1500   # conv2 output: (1, 1280, 1500)\n", | |
| "mops = input_elems + weight_elems + output_elems\n", | |
| "\n", | |
| "# FLOPs per element moved, using total_flops measured by FlopCounterMode for conv2\n", | |
| "arithmetic_intensity = total_flops / mops\n", | |
| "print(arithmetic_intensity)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "8d681a8f-94ff-4ee3-80c7-edaf8126243d", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python 3 (ipykernel)", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.12.7" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 5 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment