{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Weight initialization example: ReLU\n", "A sample example demonstrate [Xavier initialization](http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf) vs. [MSR initialization](https://arxiv.org/abs/1502.01852) with ReLU activation function.\n", "\n", "### Step 1: Generate dummy dataset and build a plain model" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": false }, "outputs": [], "source": [ "require 'torch'\n", "require 'nn'\n", "Plot = require 'itorch.Plot'\n", "\n", "-- Create fake dataset\n", "x = torch.Tensor(10, 50):uniform(-1, 1)\n", "\n", "-- Build model\n", "model = nn.Sequential()\n", " :add(nn.Linear(50, 500))\n", " :add(nn.ReLU(true))\n", "\n", "for i = 2,20 do\n", " model:add(nn.Linear(500, 500))\n", " :add(nn.ReLU(true))\n", "end" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Step 2: Use Xavier method to initialize the weights. Plot the activation of each ReLU layer." ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "collapsed": false, "scrolled": false }, "outputs": [ { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "-- Xavier initialization\n", "for k,v in pairs(model:findModules('nn.Linear')) do\n", " local n = v.weight:size(1)\n", " v.weight:uniform(-math.sqrt(3/n),math.sqrt(3/n))\n", "end\n", "\n", "model:forward(x)\n", "\n", "-- Plot histogram of activations\n", "for i = 2, 20, 2 do\n", " out = model.modules[i].output\n", " plot = Plot():histogram(out, 100, 0, 1):draw()\n", " plot:title(('Activation histogram: Layer %d'):format(i/2)):redraw()\n", " plot:xaxis('Value'):yaxis('Number'):redraw()\n", "end" ] }, { "cell_type": "markdown", "metadata": { "collapsed": true }, "source": [ "### Step 3: Use MSR method to initialize the weights. Plot the activation of each ReLU layer." ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": false, "scrolled": false }, "outputs": [ { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "
\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "-- MSR initialization\n", "for k,v in pairs(model:findModules('nn.Linear')) do\n", " local n = v.weight:size(1)\n", " v.weight:normal(0,1/math.sqrt(n/2))\n", "end\n", "\n", "x = torch.Tensor(10, 50)\n", "model:forward(x)\n", "\n", "-- Plot histogram of activations\n", "for i = 2, 20, 2 do\n", " out = model.modules[i].output\n", " plot = Plot():histogram(out, 100, 0, 1):draw()\n", " plot:title(('Activation histogram: Layer %d'):format(i/2)):redraw()\n", " plot:xaxis('Value'):yaxis('Number'):redraw()\n", "end" ] } ], "metadata": { "kernelspec": { "display_name": "iTorch", "language": "lua", "name": "itorch" }, "language_info": { "name": "lua", "version": "5.1" } }, "nbformat": 4, "nbformat_minor": 2 }