Skip to content

Instantly share code, notes, and snippets.

@dnlcrl
Created April 30, 2018 23:20
Show Gist options
  • Select an option

  • Save dnlcrl/9e9ee03105a8dc60a5ae7dced0837ee4 to your computer and use it in GitHub Desktop.

Select an option

Save dnlcrl/9e9ee03105a8dc60a5ae7dced0837ee4 to your computer and use it in GitHub Desktop.

Revisions

  1. dnlcrl created this gist Apr 30, 2018.
    1,066 changes: 1,066 additions & 0 deletions einsum_vs_torch.ipynb
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,1066 @@
    {
    "cells": [
    {
    "cell_type": "code",
    "execution_count": 1,
    "metadata": {},
    "outputs": [
    {
    "data": {
    "text/plain": [
    "tensor([[ 0., 1., 2.],\n",
    " [ 3., 4., 5.]])"
    ]
    },
    "execution_count": 1,
    "metadata": {},
    "output_type": "execute_result"
    }
    ],
    "source": [
    "import torch\n",
    "# Request float32 explicitly: with integer arguments, current PyTorch's arange\n",
    "# infers int64, which would not reproduce the float outputs recorded below.\n",
    "a = torch.arange(6, dtype=torch.float32).reshape(2, 3)\n",
    "a"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
    "TRANSPOSE"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 2,
    "metadata": {},
    "outputs": [
    {
    "name": "stdout",
    "output_type": "stream",
    "text": [
    "2.73 µs ± 75.5 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n"
    ]
    },
    {
    "data": {
    "text/plain": [
    "tensor([[ 0., 3.],\n",
    " [ 1., 4.],\n",
    " [ 2., 5.]])"
    ]
    },
    "execution_count": 2,
    "metadata": {},
    "output_type": "execute_result"
    }
    ],
    "source": [
    "%timeit torch.einsum('ij->ji', [a])\n",
    "torch.einsum('ij->ji', [a])"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 3,
    "metadata": {},
    "outputs": [
    {
    "name": "stdout",
    "output_type": "stream",
    "text": [
    "1.22 µs ± 29.9 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)\n"
    ]
    },
    {
    "data": {
    "text/plain": [
    "tensor([[ 0., 3.],\n",
    " [ 1., 4.],\n",
    " [ 2., 5.]])"
    ]
    },
    "execution_count": 3,
    "metadata": {},
    "output_type": "execute_result"
    }
    ],
    "source": [
    "%timeit a.t()\n",
    "a.t()"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
    "SUM"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 4,
    "metadata": {},
    "outputs": [
    {
    "name": "stdout",
    "output_type": "stream",
    "text": [
    "7.31 µs ± 208 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n"
    ]
    },
    {
    "data": {
    "text/plain": [
    "tensor(15.)"
    ]
    },
    "execution_count": 4,
    "metadata": {},
    "output_type": "execute_result"
    }
    ],
    "source": [
    "%timeit torch.einsum('ij->', [a])\n",
    "torch.einsum('ij->', [a])"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 5,
    "metadata": {},
    "outputs": [
    {
    "name": "stdout",
    "output_type": "stream",
    "text": [
    "1.58 µs ± 44.7 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)\n"
    ]
    },
    {
    "data": {
    "text/plain": [
    "tensor(15.)"
    ]
    },
    "execution_count": 5,
    "metadata": {},
    "output_type": "execute_result"
    }
    ],
    "source": [
    "%timeit a.sum()\n",
    "a.sum()"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
    "COLUMN SUM"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 6,
    "metadata": {},
    "outputs": [
    {
    "name": "stdout",
    "output_type": "stream",
    "text": [
    "5.42 µs ± 49.1 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n"
    ]
    },
    {
    "data": {
    "text/plain": [
    "tensor([ 3., 5., 7.])"
    ]
    },
    "execution_count": 6,
    "metadata": {},
    "output_type": "execute_result"
    }
    ],
    "source": [
    "%timeit torch.einsum('ij->j', [a])\n",
    "torch.einsum('ij->j', [a])"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 7,
    "metadata": {},
    "outputs": [
    {
    "name": "stdout",
    "output_type": "stream",
    "text": [
    "2.42 µs ± 64.9 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n"
    ]
    },
    {
    "data": {
    "text/plain": [
    "tensor([ 3., 5., 7.])"
    ]
    },
    "execution_count": 7,
    "metadata": {},
    "output_type": "execute_result"
    }
    ],
    "source": [
    "%timeit torch.sum(a, 0)\n",
    "torch.sum(a, 0)"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
    "ROW SUM"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 8,
    "metadata": {},
    "outputs": [
    {
    "name": "stdout",
    "output_type": "stream",
    "text": [
    "4.81 µs ± 27.6 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n"
    ]
    },
    {
    "data": {
    "text/plain": [
    "tensor([ 3., 12.])"
    ]
    },
    "execution_count": 8,
    "metadata": {},
    "output_type": "execute_result"
    }
    ],
    "source": [
    "%timeit torch.einsum('ij->i', [a])\n",
    "torch.einsum('ij->i', [a])"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 9,
    "metadata": {},
    "outputs": [
    {
    "name": "stdout",
    "output_type": "stream",
    "text": [
    "1.92 µs ± 103 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)\n"
    ]
    },
    {
    "data": {
    "text/plain": [
    "tensor([ 3., 12.])"
    ]
    },
    "execution_count": 9,
    "metadata": {},
    "output_type": "execute_result"
    }
    ],
    "source": [
    "%timeit torch.sum(a, 1)\n",
    "torch.sum(a, 1)"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
    "MATRIX-MATRIX MULTIPLICATION"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 10,
    "metadata": {},
    "outputs": [],
    "source": [
    "# Explicit float32 operands: integer arange arguments yield int64 on current PyTorch,\n",
    "# which would not match the float results recorded for the cells below.\n",
    "a = torch.arange(6, dtype=torch.float32).reshape(2, 3)\n",
    "b = torch.arange(15, dtype=torch.float32).reshape(3, 5)"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 11,
    "metadata": {},
    "outputs": [
    {
    "name": "stdout",
    "output_type": "stream",
    "text": [
    "23.5 µs ± 715 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n"
    ]
    },
    {
    "data": {
    "text/plain": [
    "tensor([[ 25., 28., 31., 34., 37.],\n",
    " [ 70., 82., 94., 106., 118.]])"
    ]
    },
    "execution_count": 11,
    "metadata": {},
    "output_type": "execute_result"
    }
    ],
    "source": [
    "%timeit torch.einsum('ik,kj->ij', [a, b])\n",
    "torch.einsum('ik,kj->ij', [a, b])"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 12,
    "metadata": {},
    "outputs": [
    {
    "name": "stdout",
    "output_type": "stream",
    "text": [
    "1.77 µs ± 12.9 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)\n"
    ]
    },
    {
    "data": {
    "text/plain": [
    "tensor([[ 25., 28., 31., 34., 37.],\n",
    " [ 70., 82., 94., 106., 118.]])"
    ]
    },
    "execution_count": 12,
    "metadata": {},
    "output_type": "execute_result"
    }
    ],
    "source": [
    "%timeit torch.mm(a, b)\n",
    "torch.mm(a, b)"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
    "MATRIX-VECTOR MULTIPLICATION"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 13,
    "metadata": {},
    "outputs": [],
    "source": [
    "# Explicit float32 operands: integer arange arguments yield int64 on current PyTorch,\n",
    "# which would not match the float results recorded for the cells below.\n",
    "a = torch.arange(6, dtype=torch.float32).reshape(2, 3)\n",
    "b = torch.arange(3, dtype=torch.float32)"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 14,
    "metadata": {},
    "outputs": [
    {
    "name": "stdout",
    "output_type": "stream",
    "text": [
    "19.3 µs ± 799 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n"
    ]
    },
    {
    "data": {
    "text/plain": [
    "tensor([ 5., 14.])"
    ]
    },
    "execution_count": 14,
    "metadata": {},
    "output_type": "execute_result"
    }
    ],
    "source": [
    "%timeit torch.einsum('ik,k->i', [a, b])\n",
    "torch.einsum('ik,k->i', [a, b])"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 15,
    "metadata": {},
    "outputs": [
    {
    "name": "stdout",
    "output_type": "stream",
    "text": [
    "1.51 µs ± 20.2 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)\n"
    ]
    },
    {
    "data": {
    "text/plain": [
    "tensor([ 5., 14.])"
    ]
    },
    "execution_count": 15,
    "metadata": {},
    "output_type": "execute_result"
    }
    ],
    "source": [
    "%timeit torch.mv(a,b)\n",
    "torch.mv(a,b)"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
    "DOT PRODUCT"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
    "vector"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 16,
    "metadata": {},
    "outputs": [],
    "source": [
    "# Explicit float32 operands: integer arange arguments yield int64 on current PyTorch,\n",
    "# which would not match the float results recorded for the cells below.\n",
    "a = torch.arange(3, dtype=torch.float32)\n",
    "b = torch.arange(3, 6, dtype=torch.float32) # -- a vector of length 3 containing [3, 4, 5]"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 17,
    "metadata": {},
    "outputs": [
    {
    "name": "stdout",
    "output_type": "stream",
    "text": [
    "17.1 µs ± 913 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n"
    ]
    },
    {
    "data": {
    "text/plain": [
    "tensor(14.)"
    ]
    },
    "execution_count": 17,
    "metadata": {},
    "output_type": "execute_result"
    }
    ],
    "source": [
    "%timeit torch.einsum('i,i->', [a, b])\n",
    "torch.einsum('i,i->', [a, b])"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 18,
    "metadata": {},
    "outputs": [
    {
    "name": "stdout",
    "output_type": "stream",
    "text": [
    "1.92 µs ± 102 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)\n"
    ]
    },
    {
    "data": {
    "text/plain": [
    "tensor(14.)"
    ]
    },
    "execution_count": 18,
    "metadata": {},
    "output_type": "execute_result"
    }
    ],
    "source": [
    "%timeit torch.dot(a,b)\n",
    "torch.dot(a,b)"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
    "matrix"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 19,
    "metadata": {},
    "outputs": [],
    "source": [
    "# Explicit float32 operands: integer arange arguments yield int64 on current PyTorch,\n",
    "# which would not match the float results recorded for the cells below.\n",
    "a = torch.arange(6, dtype=torch.float32).reshape(2, 3)\n",
    "b = torch.arange(6, 12, dtype=torch.float32).reshape(2, 3)"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 20,
    "metadata": {},
    "outputs": [
    {
    "name": "stdout",
    "output_type": "stream",
    "text": [
    "18.6 µs ± 1.32 µs per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n"
    ]
    },
    {
    "data": {
    "text/plain": [
    "tensor(145.)"
    ]
    },
    "execution_count": 20,
    "metadata": {},
    "output_type": "execute_result"
    }
    ],
    "source": [
    "%timeit torch.einsum('ij,ij->', [a, b])\n",
    "torch.einsum('ij,ij->', [a, b])"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 21,
    "metadata": {},
    "outputs": [
    {
    "name": "stdout",
    "output_type": "stream",
    "text": [
    "6.13 µs ± 85.3 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n"
    ]
    },
    {
    "data": {
    "text/plain": [
    "tensor(145.)"
    ]
    },
    "execution_count": 21,
    "metadata": {},
    "output_type": "execute_result"
    }
    ],
    "source": [
    "%timeit torch.dot(a.view(-1),b.view(-1))\n",
    "torch.dot(a.view(-1),b.view(-1))"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
    "HADAMARD PRODUCT"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 22,
    "metadata": {},
    "outputs": [],
    "source": [
    "# Explicit float32 operands: integer arange arguments yield int64 on current PyTorch,\n",
    "# which would not match the float results recorded for the cells below.\n",
    "a = torch.arange(6, dtype=torch.float32).reshape(2, 3)\n",
    "b = torch.arange(6, 12, dtype=torch.float32).reshape(2, 3)"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 23,
    "metadata": {},
    "outputs": [
    {
    "name": "stdout",
    "output_type": "stream",
    "text": [
    "5.56 µs ± 72.6 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n"
    ]
    },
    {
    "data": {
    "text/plain": [
    "tensor([[ 0., 7., 16.],\n",
    " [ 27., 40., 55.]])"
    ]
    },
    "execution_count": 23,
    "metadata": {},
    "output_type": "execute_result"
    }
    ],
    "source": [
    "%timeit torch.einsum('ij,ij->ij', [a, b])\n",
    "torch.einsum('ij,ij->ij', [a, b])"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 24,
    "metadata": {},
    "outputs": [
    {
    "name": "stdout",
    "output_type": "stream",
    "text": [
    "1.31 µs ± 14.7 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)\n"
    ]
    },
    {
    "data": {
    "text/plain": [
    "tensor([[ 0., 7., 16.],\n",
    " [ 27., 40., 55.]])"
    ]
    },
    "execution_count": 24,
    "metadata": {},
    "output_type": "execute_result"
    }
    ],
    "source": [
    "%timeit a*b\n",
    "a*b"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
    "OUTER PRODUCT"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 25,
    "metadata": {},
    "outputs": [],
    "source": [
    "# Explicit float32 operands: integer arange arguments yield int64 on current PyTorch,\n",
    "# which would not match the float results recorded for the cells below.\n",
    "a = torch.arange(3, dtype=torch.float32)\n",
    "b = torch.arange(3, 7, dtype=torch.float32) # -- a vector of length 4 containing [3, 4, 5, 6]"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 26,
    "metadata": {},
    "outputs": [
    {
    "name": "stdout",
    "output_type": "stream",
    "text": [
    "10.8 µs ± 622 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n"
    ]
    },
    {
    "data": {
    "text/plain": [
    "tensor([[ 0., 0., 0., 0.],\n",
    " [ 3., 4., 5., 6.],\n",
    " [ 6., 8., 10., 12.]])"
    ]
    },
    "execution_count": 26,
    "metadata": {},
    "output_type": "execute_result"
    }
    ],
    "source": [
    "%timeit torch.einsum('i,j->ij', [a, b])\n",
    "torch.einsum('i,j->ij', [a, b])"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 27,
    "metadata": {},
    "outputs": [
    {
    "name": "stdout",
    "output_type": "stream",
    "text": [
    "1.64 µs ± 51 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)\n"
    ]
    },
    {
    "data": {
    "text/plain": [
    "tensor([[ 0., 0., 0., 0.],\n",
    " [ 3., 4., 5., 6.],\n",
    " [ 6., 8., 10., 12.]])"
    ]
    },
    "execution_count": 27,
    "metadata": {},
    "output_type": "execute_result"
    }
    ],
    "source": [
    "%timeit torch.ger(a, b)\n",
    "torch.ger(a, b)"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
    "BATCH MATRIX MULTIPLICATION"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 28,
    "metadata": {},
    "outputs": [],
    "source": [
    "a = torch.randn(3,2,5)\n",
    "b = torch.randn(3,5,3)"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 29,
    "metadata": {},
    "outputs": [
    {
    "name": "stdout",
    "output_type": "stream",
    "text": [
    "24 µs ± 422 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n"
    ]
    },
    {
    "data": {
    "text/plain": [
    "tensor([[[-3.6068, 3.6341, 3.4859],\n",
    " [ 2.3148, 2.5504, 3.8194]],\n",
    "\n",
    " [[ 2.3448, 2.5390, -0.1359],\n",
    " [ 3.4580, 3.4026, 0.0316]],\n",
    "\n",
    " [[-2.1875, -3.7540, 4.1446],\n",
    " [ 1.5737, -0.2249, -0.2547]]])"
    ]
    },
    "execution_count": 29,
    "metadata": {},
    "output_type": "execute_result"
    }
    ],
    "source": [
    "%timeit torch.einsum('ijk,ikl->ijl', [a, b])\n",
    "torch.einsum('ijk,ikl->ijl', [a, b])"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 30,
    "metadata": {},
    "outputs": [
    {
    "name": "stdout",
    "output_type": "stream",
    "text": [
    "4.81 µs ± 150 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n"
    ]
    },
    {
    "data": {
    "text/plain": [
    "tensor([[[-3.6068, 3.6341, 3.4859],\n",
    " [ 2.3148, 2.5504, 3.8194]],\n",
    "\n",
    " [[ 2.3448, 2.5390, -0.1359],\n",
    " [ 3.4580, 3.4026, 0.0316]],\n",
    "\n",
    " [[-2.1875, -3.7540, 4.1446],\n",
    " [ 1.5737, -0.2249, -0.2547]]])"
    ]
    },
    "execution_count": 30,
    "metadata": {},
    "output_type": "execute_result"
    }
    ],
    "source": [
    "%timeit a.bmm(b)\n",
    "a.bmm(b)"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
    "TENSOR MULTIPLICATION"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 31,
    "metadata": {},
    "outputs": [],
    "source": [
    "a = torch.randn(2,3,5,7)\n",
    "b = torch.randn(11,13,3,17,5)"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 32,
    "metadata": {},
    "outputs": [
    {
    "name": "stdout",
    "output_type": "stream",
    "text": [
    "210 µs ± 27.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n"
    ]
    },
    {
    "data": {
    "text/plain": [
    "torch.Size([2, 7, 11, 13, 17])"
    ]
    },
    "execution_count": 32,
    "metadata": {},
    "output_type": "execute_result"
    }
    ],
    "source": [
    "%timeit torch.einsum('pqrs,tuqvr->pstuv', [a, b]).shape\n",
    "torch.einsum('pqrs,tuqvr->pstuv', [a, b]).shape"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 33,
    "metadata": {},
    "outputs": [
    {
    "name": "stdout",
    "output_type": "stream",
    "text": [
    "83.7 µs ± 5.52 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n"
    ]
    },
    {
    "data": {
    "text/plain": [
    "torch.Size([2, 7, 11, 13, 17])"
    ]
    },
    "execution_count": 33,
    "metadata": {},
    "output_type": "execute_result"
    }
    ],
    "source": [
    "%timeit torch.mm(a.transpose(1,3).transpose(2,3).reshape(2*7, 3*5), \\\n",
    "b.transpose(4, 3).transpose(1, 3).transpose(0, 2).reshape(3*5, 11*13*17)).reshape(2, 7, 11, 13, 17).shape\n",
    "\n",
    "torch.mm(a.transpose(1,3).transpose(2,3).reshape(2*7, 3*5), \\\n",
    "b.transpose(4, 3).transpose(1, 3).transpose(0, 2).reshape(3*5, 11*13*17)).reshape(2, 7, 11, 13, 17).shape"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 34,
    "metadata": {},
    "outputs": [
    {
    "data": {
    "text/plain": [
    "True"
    ]
    },
    "execution_count": 34,
    "metadata": {},
    "output_type": "execute_result"
    }
    ],
    "source": [
    "# Check that the einsum and the hand-written transpose/reshape/mm pipeline agree.\n",
    "# Compare with a tolerance rather than ==: the two paths may accumulate the\n",
    "# float32 sums in a different order, so bit-exact equality is not guaranteed.\n",
    "einsum_out = torch.einsum('pqrs,tuqvr->pstuv', [a, b])\n",
    "mm_out = torch.mm(\n",
    "    a.transpose(1,3).transpose(2,3).reshape(2*7, 3*5),\n",
    "    b.transpose(4, 3).transpose(1, 3).transpose(0, 2).reshape(3*5, 11*13*17),\n",
    ").reshape(2, 7, 11, 13, 17)\n",
    "torch.allclose(einsum_out, mm_out, atol=1e-5)"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
    "BILINEAR TRANSFORMATION"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 35,
    "metadata": {},
    "outputs": [],
    "source": [
    "a = torch.randn(2,3)\n",
    "b = torch.randn(5,3,7)\n",
    "c = torch.randn(2,7)"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 36,
    "metadata": {},
    "outputs": [
    {
    "name": "stdout",
    "output_type": "stream",
    "text": [
    "51.3 µs ± 2.25 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n"
    ]
    },
    {
    "data": {
    "text/plain": [
    "tensor([[-7.8614, 1.6084, 1.8052, 2.3681, 1.1696],\n",
    " [ 5.7942, -1.5822, -4.0773, 1.1712, 0.1531]])"
    ]
    },
    "execution_count": 36,
    "metadata": {},
    "output_type": "execute_result"
    }
    ],
    "source": [
    "%timeit torch.einsum('ik,jkl,il->ij', [a, b, c])\n",
    "torch.einsum('ik,jkl,il->ij', [a, b, c])"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 37,
    "metadata": {},
    "outputs": [
    {
    "name": "stdout",
    "output_type": "stream",
    "text": [
    "37 µs ± 1.26 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n"
    ]
    },
    {
    "data": {
    "text/plain": [
    "tensor([[-7.8614, 1.6084, 1.8052, 2.3681, 1.1696],\n",
    " [ 5.7942, -1.5822, -4.0773, 1.1712, 0.1531]])"
    ]
    },
    "execution_count": 37,
    "metadata": {},
    "output_type": "execute_result"
    }
    ],
    "source": [
    "%timeit torch.mm(torch.mm(a, b.transpose(0, 1).reshape(3, -1)).reshape(-1, 7), c.transpose(1, 0))\\\n",
    ".view(-1).gather(0, torch.stack([torch.range(0, 9, 2), torch.range(11, 19, 2)]).view(-1).long()).reshape(2, 5)\n",
    "\n",
    "torch.mm(torch.mm(a, b.transpose(0, 1).reshape(3, -1)).reshape(-1, 7), c.transpose(1, 0))\\\n",
    ".view(-1).gather(0, torch.stack([torch.range(0, 9, 2), torch.range(11, 19, 2)]).view(-1).long()).reshape(2, 5)"
    ]
    },
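    {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
    "Same computation, building the gather indices with `torch.arange` (exclusive end) instead of the deprecated `torch.range` (inclusive end)"
    ]
    },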
    {
    "cell_type": "code",
    "execution_count": 44,
    "metadata": {},
    "outputs": [
    {
    "name": "stdout",
    "output_type": "stream",
    "text": [
    "34.8 µs ± 929 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n"
    ]
    },
    {
    "data": {
    "text/plain": [
    "tensor([[-7.8614, 1.6084, 1.8052, 2.3681, 1.1696],\n",
    " [ 5.7942, -1.5822, -4.0773, 1.1712, 0.1531]])"
    ]
    },
    "execution_count": 44,
    "metadata": {},
    "output_type": "execute_result"
    }
    ],
    "source": [
    "%timeit torch.mm(torch.mm(a, b.transpose(0, 1).reshape(3, -1)).reshape(-1, 7), c.transpose(1, 0))\\\n",
    ".view(-1).gather(0, torch.stack([torch.arange(0, 10, 2), torch.arange(11, 20, 2)]).view(-1).long()).reshape(2, 5)\n",
    "\n",
    "torch.mm(torch.mm(a, b.transpose(0, 1).reshape(3, -1)).reshape(-1, 7), c.transpose(1, 0))\\\n",
    ".view(-1).gather(0, torch.stack([torch.arange(0, 10, 2), torch.arange(11, 20, 2)]).view(-1).long()).reshape(2, 5)"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": []
    }
    ],
    "metadata": {
    "kernelspec": {
    "display_name": "Python 3",
    "language": "python",
    "name": "python3"
    },
    "language_info": {
    "codemirror_mode": {
    "name": "ipython",
    "version": 3
    },
    "file_extension": ".py",
    "mimetype": "text/x-python",
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
    "version": "3.6.3"
    }
    },
    "nbformat": 4,
    "nbformat_minor": 2
    }