{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Weight initialization example: Tanh\n", "A simple example demonstrating naive initialization vs. [Xavier initialization](http://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf) with the Tanh activation function.\n", "\n", "### Step 1: Build a plain deep Tanh network (the dummy input batch is generated in each experiment below)" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": false }, "outputs": [], "source": [ "require 'torch'\n", "require 'nn'\n", "Plot = require 'itorch.Plot'\n", "\n", "-- Seed the RNG so the network's initial parameters are reproducible.\n", "torch.manualSeed(1234)\n", "\n", "-- Network shape: inputSize -> hiddenSize, then (nLayers - 1) x (hiddenSize -> hiddenSize),\n", "-- with a Tanh after every nn.Linear (nLayers Linear+Tanh pairs in total).\n", "local inputSize, hiddenSize, nLayers = 50, 500, 20\n", "\n", "model = nn.Sequential()\n", "    :add(nn.Linear(inputSize, hiddenSize))\n", "    :add(nn.Tanh())\n", "\n", "for i = 2, nLayers do\n", "    model:add(nn.Linear(hiddenSize, hiddenSize))\n", "         :add(nn.Tanh())\n", "end" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Step 2: Initialize the weights naively (Gaussian with std 0.01) and plot the activations of each Tanh layer." ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": false, "scrolled": false }, "outputs": [ { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "-- Seed the RNG so this cell is reproducible on its own.\n", "torch.manualSeed(1234)\n", "\n", "local linears = model:findModules('nn.Linear')\n", "local tanhs = model:findModules('nn.Tanh')\n", "\n", "-- Dummy input batch: 10 samples ~ N(0, 1), sized to the first layer's input.\n", "-- Note: torch.Tensor(10, n) would only allocate uninitialized memory.\n", "local x = torch.randn(10, linears[1].weight:size(2))\n", "\n", "-- Naive initialization: every weight ~ N(0, 0.01^2). Biases are zeroed so the\n", "-- activation statistics depend on the weight scale alone.\n", "for _, linear in ipairs(linears) do\n", "    linear.weight:normal(0, 0.01)\n", "    linear.bias:zero()\n", "end\n", "\n", "model:forward(x)\n", "\n", "-- Histogram of the activations of every Tanh layer, collecting their statistics.\n", "local nLayers = #tanhs\n", "local std = torch.Tensor(nLayers)\n", "local mean = torch.Tensor(nLayers)\n", "for k, layer in ipairs(tanhs) do\n", "    local out = layer.output\n", "    local hist = Plot():histogram(out, 100, -1, 1):draw()\n", "    hist:title(('Activation histogram: Layer %d'):format(k)):redraw()\n", "    hist:xaxis('Value'):yaxis('Number'):redraw()\n", "    std[k] = out:std()\n", "    mean[k] = out:mean()\n", "end\n", "\n", "local layerIds = torch.range(1, nLayers)\n", "\n", "-- Plot the standard deviation of the activations per layer.\n", "local plot = Plot():line(layerIds, std):draw()\n", "plot:title('Standard deviation'):redraw()\n", "plot:xaxis('Layer'):redraw()\n", "\n", "-- Plot the mean of the activations per layer.\n", "plot = Plot():line(layerIds, mean):draw()\n", "plot:title('Mean'):redraw()\n", "plot:xaxis('Layer'):redraw()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Step 3: Initialize the weights with the Xavier method and plot the activations of each Tanh layer." ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": false, "scrolled": true }, "outputs": [ { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "-- Seed the RNG so this cell is reproducible on its own.\n", "torch.manualSeed(1234)\n", "\n", "local linears = model:findModules('nn.Linear')\n", "local tanhs = model:findModules('nn.Tanh')\n", "\n", "-- Dummy input batch: 10 samples ~ N(0, 1), sized to the first layer's input.\n", "-- Note: torch.Tensor(10, n) would only allocate uninitialized memory.\n", "local x = torch.randn(10, linears[1].weight:size(2))\n", "\n", "-- Xavier initialization (fan-in variant): every weight ~ N(0, 1/fan_in).\n", "-- nn.Linear stores its weight as outputSize x inputSize, so fan_in is size(2);\n", "-- size(1) would be fan_out. (Glorot & Bengio's paper uses Var(w) = 2 / (fan_in + fan_out).)\n", "-- Biases are zeroed so the activation statistics depend on the weight scale alone.\n", "for _, linear in ipairs(linears) do\n", "    local fanIn = linear.weight:size(2)\n", "    linear.weight:normal(0, 1 / math.sqrt(fanIn))\n", "    linear.bias:zero()\n", "end\n", "\n", "model:forward(x)\n", "\n", "-- Histogram of the activations of every Tanh layer, collecting their statistics.\n", "local nLayers = #tanhs\n", "local std = torch.Tensor(nLayers)\n", "local mean = torch.Tensor(nLayers)\n", "for k, layer in ipairs(tanhs) do\n", "    local out = layer.output\n", "    local hist = Plot():histogram(out, 100, -1, 1):draw()\n", "    hist:title(('Activation histogram: Layer %d'):format(k)):redraw()\n", "    hist:xaxis('Value'):yaxis('Number'):redraw()\n", "    std[k] = out:std()\n", "    mean[k] = out:mean()\n", "end\n", "\n", "local layerIds = torch.range(1, nLayers)\n", "\n", "-- Plot the standard deviation of the activations per layer.\n", "local plot = Plot():line(layerIds, std):draw()\n", "plot:title('Standard deviation'):redraw()\n", "plot:xaxis('Layer'):redraw()\n", "\n", "-- Plot the mean of the activations per layer.\n", "plot = Plot():line(layerIds, mean):draw()\n", "plot:title('Mean'):redraw()\n", "plot:xaxis('Layer'):redraw()" ] } ], "metadata": { "kernelspec": { "display_name": "iTorch", "language": "lua", "name": "itorch" }, "language_info": { "name": "lua", "version": "5.1" } }, "nbformat": 4, "nbformat_minor": 2 }