Skip to content

Instantly share code, notes, and snippets.

@jimexist
Last active August 16, 2021 14:40
Show Gist options
  • Select an option

  • Save jimexist/9b5adca9cc7ad4b8e7c0b02b9a341af1 to your computer and use it in GitHub Desktop.

Select an option

Save jimexist/9b5adca9cc7ad4b8e7c0b02b9a341af1 to your computer and use it in GitHub Desktop.

Revisions

  1. jimexist revised this gist Aug 16, 2021. 1 changed file with 13 additions and 34 deletions.
    47 changes: 13 additions & 34 deletions lstm_output.ipynb
    Original file line number Diff line number Diff line change
    @@ -2,7 +2,7 @@
    "cells": [
    {
    "cell_type": "code",
    "execution_count": 2,
    "execution_count": 1,
    "id": "c1333286",
    "metadata": {},
    "outputs": [],
    @@ -13,81 +13,60 @@
    },
    {
    "cell_type": "code",
    "execution_count": 7,
    "execution_count": 2,
    "id": "4bd8e564",
    "metadata": {},
    "outputs": [],
    "source": [
    "batch_size = 4\n",
    "input_size = 5\n",
    "hidden_size = 8\n",
    "seq_length = 11\n",
    "seq_length = 31\n",
    "rnn = nn.LSTM(input_size=input_size, hidden_size=hidden_size, batch_first=True)"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 8,
    "execution_count": 3,
    "id": "e4c9ff01",
    "metadata": {},
    "outputs": [
    {
    "data": {
    "text/plain": [
    "torch.Size([3, 11, 5])"
    "torch.Size([4, 31, 5])"
    ]
    },
    "execution_count": 8,
    "execution_count": 3,
    "metadata": {},
    "output_type": "execute_result"
    }
    ],
    "source": [
    "input_tensor = torch.randn(3, seq_length, input_size)\n",
    "input_tensor = torch.randn(batch_size, seq_length, input_size)\n",
    "input_tensor.shape"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 11,
    "execution_count": 4,
    "id": "d3ad7340",
    "metadata": {},
    "outputs": [
    {
    "ename": "NameError",
    "evalue": "name 'states_state' is not defined",
    "output_type": "error",
    "traceback": [
    "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
    "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
    "\u001b[0;32m/var/folders/mn/3932bp3n1hs4cs2nqxj89x_h0000gn/T/ipykernel_99465/1253711468.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0moutput_tensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mhidden_state\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcell_state\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrnn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput_tensor\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0moutput_tensor\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstates_state\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcell_state\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
    "\u001b[0;31mNameError\u001b[0m: name 'states_state' is not defined"
    ]
    }
    ],
    "source": [
    "output_tensor, (hidden_state, cell_state) = rnn(input_tensor)\n",
    "output_tensor.shape, states_state.shape, cell_state.shape"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 5,
    "id": "b7b5e77f",
    "metadata": {},
    "outputs": [
    {
    "data": {
    "text/plain": [
    "(torch.Size([3, 8]), torch.Size([3, 8]))"
    "(torch.Size([4, 31, 8]), torch.Size([1, 4, 8]), torch.Size([1, 4, 8]))"
    ]
    },
    "execution_count": 5,
    "execution_count": 4,
    "metadata": {},
    "output_type": "execute_result"
    }
    ],
    "source": [
    "soutput_tensor.shape, state.shape"
    "output_tensor, (hidden_state, cell_state) = rnn(input_tensor)\n",
    "output_tensor.shape, hidden_state.shape, cell_state.shape"
    ]
    },
    {
  2. jimexist created this gist Aug 16, 2021.
    123 changes: 123 additions & 0 deletions lstm_output.ipynb
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,123 @@
    {
    "cells": [
    {
    "cell_type": "code",
    "execution_count": 2,
    "id": "c1333286",
    "metadata": {},
    "outputs": [],
    "source": [
    "import torch\n",
    "from torch import nn"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 7,
    "id": "4bd8e564",
    "metadata": {},
    "outputs": [],
    "source": [
    "input_size = 5\n",
    "hidden_size = 8\n",
    "seq_length = 11\n",
    "rnn = nn.LSTM(input_size=input_size, hidden_size=hidden_size, batch_first=True)"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 8,
    "id": "e4c9ff01",
    "metadata": {},
    "outputs": [
    {
    "data": {
    "text/plain": [
    "torch.Size([3, 11, 5])"
    ]
    },
    "execution_count": 8,
    "metadata": {},
    "output_type": "execute_result"
    }
    ],
    "source": [
    "input_tensor = torch.randn(3, seq_length, input_size)\n",
    "input_tensor.shape"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 11,
    "id": "d3ad7340",
    "metadata": {},
    "outputs": [
    {
    "ename": "NameError",
    "evalue": "name 'states_state' is not defined",
    "output_type": "error",
    "traceback": [
    "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
    "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
    "\u001b[0;32m/var/folders/mn/3932bp3n1hs4cs2nqxj89x_h0000gn/T/ipykernel_99465/1253711468.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0moutput_tensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mhidden_state\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcell_state\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mrnn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput_tensor\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0moutput_tensor\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstates_state\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcell_state\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
    "\u001b[0;31mNameError\u001b[0m: name 'states_state' is not defined"
    ]
    }
    ],
    "source": [
    "output_tensor, (hidden_state, cell_state) = rnn(input_tensor)\n",
    "output_tensor.shape, states_state.shape, cell_state.shape"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 5,
    "id": "b7b5e77f",
    "metadata": {},
    "outputs": [
    {
    "data": {
    "text/plain": [
    "(torch.Size([3, 8]), torch.Size([3, 8]))"
    ]
    },
    "execution_count": 5,
    "metadata": {},
    "output_type": "execute_result"
    }
    ],
    "source": [
    "soutput_tensor.shape, state.shape"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "id": "7bd296a8",
    "metadata": {},
    "outputs": [],
    "source": []
    }
    ],
    "metadata": {
    "kernelspec": {
    "display_name": "Python 3 (ipykernel)",
    "language": "python",
    "name": "python3"
    },
    "language_info": {
    "codemirror_mode": {
    "name": "ipython",
    "version": 3
    },
    "file_extension": ".py",
    "mimetype": "text/x-python",
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
    "version": "3.8.11"
    }
    },
    "nbformat": 4,
    "nbformat_minor": 5
    }