Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Select an option

  • Save MervinPraison/78c4c727ad77986f03d6b115cd1d68d4 to your computer and use it in GitHub Desktop.

Select an option

Save MervinPraison/78c4c727ad77986f03d6b115cd1d68d4 to your computer and use it in GitHub Desktop.

Revisions

  1. @yufengg yufengg revised this gist Sep 14, 2017. 1 changed file with 154 additions and 45 deletions.
    199 changes: 154 additions & 45 deletions AIA003_tf_estimators.ipynb
    Original file line number Diff line number Diff line change
    @@ -23,11 +23,18 @@
    "# limitations under the License."
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
    "## Overview\n",
    "All the code we'll look at is in the next cell. We will step through each step after."
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
    "collapsed": true,
    "scrolled": true
    },
    "outputs": [],
    @@ -76,7 +83,21 @@
    "accuracy_score = classifier.evaluate(input_fn=input_fn(test_set), \n",
    " steps=100)[\"accuracy\"]\n",
    "print('\\nAccuracy: {0:f}'.format(accuracy_score))\n",
    "\n"
    "\n",
    "# Export the model for serving\n",
    "feature_spec = {'flower_features': tf.FixedLenFeature(shape=[4], dtype=np.float32)}\n",
    "\n",
    "serving_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(feature_spec)\n",
    "\n",
    "classifier.export_savedmodel(export_dir_base='/tmp/iris_model' + '/export', \n",
    " serving_input_receiver_fn=serving_fn)\n"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
    "## Imports"
    ]
    },
    {
    @@ -132,6 +153,13 @@
    "<img src=\"https://storage.googleapis.com/image-uploader/AIA_images/data_table.png\" style=\"width: 450px\"/>"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
    "## Load data in"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 2,
    @@ -290,6 +318,13 @@
    "print(training_set.target)"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
    "## Feature columns and model creation"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 3,
    @@ -316,6 +351,13 @@
    " model_dir=\"/tmp/iris_model\")\n"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
    "## Input function"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 4,
    @@ -342,6 +384,13 @@
    "# raw data -> input function -> feature columns -> model"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
    "## Training"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 5,
    @@ -354,28 +403,29 @@
    "output_type": "stream",
    "text": [
    "INFO:tensorflow:Create CheckpointSaverHook.\n",
    "INFO:tensorflow:Saving checkpoints for 1 into /tmp/iris_model/model.ckpt.\n",
    "INFO:tensorflow:loss = 131.833, step = 1\n",
    "INFO:tensorflow:global_step/sec: 832.73\n",
    "INFO:tensorflow:loss = 37.1391, step = 101 (0.121 sec)\n",
    "INFO:tensorflow:global_step/sec: 656.81\n",
    "INFO:tensorflow:loss = 27.8594, step = 201 (0.154 sec)\n",
    "INFO:tensorflow:global_step/sec: 837.738\n",
    "INFO:tensorflow:loss = 23.0449, step = 301 (0.118 sec)\n",
    "INFO:tensorflow:global_step/sec: 832.487\n",
    "INFO:tensorflow:loss = 20.058, step = 401 (0.121 sec)\n",
    "INFO:tensorflow:global_step/sec: 705.118\n",
    "INFO:tensorflow:loss = 18.0083, step = 501 (0.142 sec)\n",
    "INFO:tensorflow:global_step/sec: 650.441\n",
    "INFO:tensorflow:loss = 16.505, step = 601 (0.153 sec)\n",
    "INFO:tensorflow:global_step/sec: 619.556\n",
    "INFO:tensorflow:loss = 15.3496, step = 701 (0.162 sec)\n",
    "INFO:tensorflow:global_step/sec: 657.613\n",
    "INFO:tensorflow:loss = 14.43, step = 801 (0.152 sec)\n",
    "INFO:tensorflow:global_step/sec: 763.202\n",
    "INFO:tensorflow:loss = 13.6782, step = 901 (0.131 sec)\n",
    "INFO:tensorflow:Saving checkpoints for 1000 into /tmp/iris_model/model.ckpt.\n",
    "INFO:tensorflow:Loss for final step: 13.0562.\n",
    "INFO:tensorflow:Restoring parameters from /tmp/iris_model/model.ckpt-3000\n",
    "INFO:tensorflow:Saving checkpoints for 3001 into /tmp/iris_model/model.ckpt.\n",
    "INFO:tensorflow:loss = 8.50027, step = 3001\n",
    "INFO:tensorflow:global_step/sec: 827.439\n",
    "INFO:tensorflow:loss = 8.40806, step = 3101 (0.123 sec)\n",
    "INFO:tensorflow:global_step/sec: 868.825\n",
    "INFO:tensorflow:loss = 8.32063, step = 3201 (0.116 sec)\n",
    "INFO:tensorflow:global_step/sec: 959.112\n",
    "INFO:tensorflow:loss = 8.23757, step = 3301 (0.104 sec)\n",
    "INFO:tensorflow:global_step/sec: 844.444\n",
    "INFO:tensorflow:loss = 8.15855, step = 3401 (0.118 sec)\n",
    "INFO:tensorflow:global_step/sec: 847.278\n",
    "INFO:tensorflow:loss = 8.08324, step = 3501 (0.118 sec)\n",
    "INFO:tensorflow:global_step/sec: 825.594\n",
    "INFO:tensorflow:loss = 8.01139, step = 3601 (0.120 sec)\n",
    "INFO:tensorflow:global_step/sec: 882.98\n",
    "INFO:tensorflow:loss = 7.94273, step = 3701 (0.114 sec)\n",
    "INFO:tensorflow:global_step/sec: 941.876\n",
    "INFO:tensorflow:loss = 7.87704, step = 3801 (0.106 sec)\n",
    "INFO:tensorflow:global_step/sec: 889.862\n",
    "INFO:tensorflow:loss = 7.81412, step = 3901 (0.112 sec)\n",
    "INFO:tensorflow:Saving checkpoints for 4000 into /tmp/iris_model/model.ckpt.\n",
    "INFO:tensorflow:Loss for final step: 7.75437.\n",
    "fit done\n"
    ]
    }
    @@ -390,14 +440,25 @@
    {
    "cell_type": "code",
    "execution_count": 6,
    "metadata": {
    "collapsed": true
    },
    "outputs": [],
    "source": [
    "## Evaluation"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 7,
    "metadata": {},
    "outputs": [
    {
    "name": "stdout",
    "output_type": "stream",
    "text": [
    "INFO:tensorflow:Starting evaluation at 2017-09-05-04:32:30\n",
    "INFO:tensorflow:Restoring parameters from /tmp/iris_model/model.ckpt-1000\n",
    "INFO:tensorflow:Starting evaluation at 2017-09-14-15:16:40\n",
    "INFO:tensorflow:Restoring parameters from /tmp/iris_model/model.ckpt-4000\n",
    "INFO:tensorflow:Evaluation [1/100]\n",
    "INFO:tensorflow:Evaluation [2/100]\n",
    "INFO:tensorflow:Evaluation [3/100]\n",
    @@ -498,8 +559,8 @@
    "INFO:tensorflow:Evaluation [98/100]\n",
    "INFO:tensorflow:Evaluation [99/100]\n",
    "INFO:tensorflow:Evaluation [100/100]\n",
    "INFO:tensorflow:Finished evaluation at 2017-09-05-04:32:31\n",
    "INFO:tensorflow:Saving dict for global step 1000: accuracy = 0.966667, average_loss = 0.120964, global_step = 1000, loss = 3.62893\n",
    "INFO:tensorflow:Finished evaluation at 2017-09-14-15:16:41\n",
    "INFO:tensorflow:Saving dict for global step 4000: accuracy = 0.966667, average_loss = 0.0706094, global_step = 4000, loss = 2.11828\n",
    "\n",
    "Accuracy: 0.966667\n"
    ]
    @@ -513,33 +574,81 @@
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "cell_type": "markdown",
    "metadata": {
    "collapsed": true
    },
    "outputs": [],
    "source": [
    "# Estimators review\n",
    "\n",
    "# Load datasets.\n",
    "training_data = load_csv_with_header()\n",
    "### Load datasets.\n",
    "\n",
    "# define input functions\n",
    "def input_fn(dataset)\n",
    " \n",
    "# Define feature columns\n",
    "feature_columns = [tf.feature_column.numeric_column(feature_name, \n",
    " shape=[4])]\n",
    " training_data = load_csv_with_header()\n",
    "\n",
    "### define input functions\n",
    "\n",
    " def input_fn(dataset)\n",
    " \n",
    "### Define feature columns\n",
    "\n",
    " feature_columns = [tf.feature_column.numeric_column(feature_name, shape=[4])]\n",
    "\n",
    "### Create model\n",
    "\n",
    " classifier = tf.estimator.LinearClassifier()\n",
    "\n",
    "### Train\n",
    "\n",
    " classifier.train()\n",
    "\n",
    "### Evaluate\n",
    "\n",
    " classifier.evaluate()"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
    "## Exporting a model for serving predictions\n"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 8,
    "metadata": {},
    "outputs": [
    {
    "name": "stdout",
    "output_type": "stream",
    "text": [
    "INFO:tensorflow:Restoring parameters from /tmp/iris_model/model.ckpt-4000\n",
    "INFO:tensorflow:Assets added to graph.\n",
    "INFO:tensorflow:No assets to write.\n",
    "INFO:tensorflow:SavedModel written to: /tmp/iris_model/export/1505402201/saved_model.pb\n"
    ]
    },
    {
    "data": {
    "text/plain": [
    "'/tmp/iris_model/export/1505402201'"
    ]
    },
    "execution_count": 8,
    "metadata": {},
    "output_type": "execute_result"
    }
    ],
    "source": [
    "feature_spec = {'flower_features': tf.FixedLenFeature(shape=[4], dtype=np.float32)}\n",
    "\n",
    "# Create model\n",
    "classifier = tf.estimator.LinearClassifier()\n",
    "serving_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(feature_spec)\n",
    "\n",
    "# Train\n",
    "classifier.train()\n",
    "classifier.export_savedmodel(export_dir_base='/tmp/iris_model' + '/export', \n",
    " serving_input_receiver_fn=serving_fn)\n",
    "\n",
    "# Evaluate\n",
    "classifier.evaluate()"
    "\n",
    "\n"
    ]
    }
    ],
  2. @yufengg yufengg revised this gist Sep 7, 2017. 1 changed file with 24 additions and 0 deletions.
    24 changes: 24 additions & 0 deletions AIA003_tf_estimators.ipynb
    Original file line number Diff line number Diff line change
    @@ -4,6 +4,30 @@
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
    "collapsed": true
    },
    "outputs": [],
    "source": [
    "# Copyright 2017 Google Inc.\n",
    "#\n",
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "# http://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License."
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
    "collapsed": true,
    "scrolled": true
    },
    "outputs": [],
  3. @yufengg yufengg revised this gist Sep 6, 2017. 1 changed file with 4 additions and 4 deletions.
    8 changes: 4 additions & 4 deletions AIA003_tf_estimators.ipynb
    Original file line number Diff line number Diff line change
    @@ -84,9 +84,9 @@
    "\n",
    "3 types of Iris Flowers: \n",
    "\n",
    "<img src=\"Iris_setosa.jpg\" style=\"width: 100px; display:inline\"/>\n",
    "<img src=\"Iris_versicolor.jpg\" style=\"width: 150px;display:inline\"/>\n",
    "<img src=\"Iris_virginica.jpg\" style=\"width: 150px;display:inline\"/>\n",
    "<img src=\"https://upload.wikimedia.org/wikipedia/commons/thumb/5/56/Kosaciec_szczecinkowaty_Iris_setosa.jpg/450px-Kosaciec_szczecinkowaty_Iris_setosa.jpg\" style=\"width: 100px; display:inline\"/>\n",
    "<img src=\"https://upload.wikimedia.org/wikipedia/commons/thumb/4/41/Iris_versicolor_3.jpg/800px-Iris_versicolor_3.jpg\" style=\"width: 150px;display:inline\"/>\n",
    "<img src=\"https://upload.wikimedia.org/wikipedia/commons/thumb/9/9f/Iris_virginica.jpg/736px-Iris_virginica.jpg\" style=\"width: 150px;display:inline\"/>\n",
    "* Iris Setosa\n",
    "* Iris Versicolour\n",
    "* Iris Virginica\n",
    @@ -105,7 +105,7 @@
    " 4. petal width in cm\n",
    "\n",
    "<img src=\"petal_sepal.png\" style=\"width: 200px;\"/>\n",
    "<img src=\"data_table.png\"/>"
    "<img src=\"https://storage.googleapis.com/image-uploader/AIA_images/data_table.png\" style=\"width: 450px\"/>"
    ]
    },
    {
  4. @yufengg yufengg created this gist Sep 6, 2017.
    543 changes: 543 additions & 0 deletions AIA003_tf_estimators.ipynb
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,543 @@
    {
    "cells": [
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
    "scrolled": true
    },
    "outputs": [],
    "source": [
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "\n",
    "print(tf.__version__)\n",
    "\n",
    "from tensorflow.contrib.learn.python.learn.datasets import base\n",
    "\n",
    "# Data files\n",
    "IRIS_TRAINING = \"iris_training.csv\"\n",
    "IRIS_TEST = \"iris_test.csv\"\n",
    "\n",
    "# Load datasets.\n",
    "training_set = base.load_csv_with_header(filename=IRIS_TRAINING,\n",
    " features_dtype=np.float32,\n",
    " target_dtype=np.int)\n",
    "test_set = base.load_csv_with_header(filename=IRIS_TEST,\n",
    " features_dtype=np.float32,\n",
    " target_dtype=np.int)\n",
    "\n",
    "# Specify that all features have real-value data\n",
    "feature_name = \"flower_features\"\n",
    "feature_columns = [tf.feature_column.numeric_column(feature_name, \n",
    " shape=[4])]\n",
    "classifier = tf.estimator.LinearClassifier(\n",
    " feature_columns=feature_columns,\n",
    " n_classes=3,\n",
    " model_dir=\"/tmp/iris_model\")\n",
    "\n",
    "def input_fn(dataset):\n",
    " def _fn():\n",
    " features = {feature_name: tf.constant(dataset.data)}\n",
    " label = tf.constant(dataset.target)\n",
    " return features, label\n",
    " return _fn\n",
    "\n",
    "# Fit model.\n",
    "classifier.train(input_fn=input_fn(training_set),\n",
    " steps=1000)\n",
    "print('fit done')\n",
    "\n",
    "# Evaluate accuracy.\n",
    "accuracy_score = classifier.evaluate(input_fn=input_fn(test_set), \n",
    " steps=100)[\"accuracy\"]\n",
    "print('\\nAccuracy: {0:f}'.format(accuracy_score))\n",
    "\n"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 1,
    "metadata": {},
    "outputs": [
    {
    "name": "stdout",
    "output_type": "stream",
    "text": [
    "1.3.0\n"
    ]
    }
    ],
    "source": [
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "\n",
    "print(tf.__version__)"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
    "## Data set\n",
    "From https://en.wikipedia.org/wiki/Iris_flower_data_set\n",
    "\n",
    "3 types of Iris Flowers: \n",
    "\n",
    "<img src=\"Iris_setosa.jpg\" style=\"width: 100px; display:inline\"/>\n",
    "<img src=\"Iris_versicolor.jpg\" style=\"width: 150px;display:inline\"/>\n",
    "<img src=\"Iris_virginica.jpg\" style=\"width: 150px;display:inline\"/>\n",
    "* Iris Setosa\n",
    "* Iris Versicolour\n",
    "* Iris Virginica\n",
    "\n",
    "\n"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
    "## Data Columns:\n",
    " 1. sepal length in cm \n",
    " 2. sepal width in cm \n",
    " 3. petal length in cm \n",
    " 4. petal width in cm\n",
    "\n",
    "<img src=\"petal_sepal.png\" style=\"width: 200px;\"/>\n",
    "<img src=\"data_table.png\"/>"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 2,
    "metadata": {
    "scrolled": true
    },
    "outputs": [
    {
    "name": "stdout",
    "output_type": "stream",
    "text": [
    "[[ 6.4000001 2.79999995 5.5999999 2.20000005]\n",
    " [ 5. 2.29999995 3.29999995 1. ]\n",
    " [ 4.9000001 2.5 4.5 1.70000005]\n",
    " [ 4.9000001 3.0999999 1.5 0.1 ]\n",
    " [ 5.69999981 3.79999995 1.70000005 0.30000001]\n",
    " [ 4.4000001 3.20000005 1.29999995 0.2 ]\n",
    " [ 5.4000001 3.4000001 1.5 0.40000001]\n",
    " [ 6.9000001 3.0999999 5.0999999 2.29999995]\n",
    " [ 6.69999981 3.0999999 4.4000001 1.39999998]\n",
    " [ 5.0999999 3.70000005 1.5 0.40000001]\n",
    " [ 5.19999981 2.70000005 3.9000001 1.39999998]\n",
    " [ 6.9000001 3.0999999 4.9000001 1.5 ]\n",
    " [ 5.80000019 4. 1.20000005 0.2 ]\n",
    " [ 5.4000001 3.9000001 1.70000005 0.40000001]\n",
    " [ 7.69999981 3.79999995 6.69999981 2.20000005]\n",
    " [ 6.30000019 3.29999995 4.69999981 1.60000002]\n",
    " [ 6.80000019 3.20000005 5.9000001 2.29999995]\n",
    " [ 7.5999999 3. 6.5999999 2.0999999 ]\n",
    " [ 6.4000001 3.20000005 5.30000019 2.29999995]\n",
    " [ 5.69999981 4.4000001 1.5 0.40000001]\n",
    " [ 6.69999981 3.29999995 5.69999981 2.0999999 ]\n",
    " [ 6.4000001 2.79999995 5.5999999 2.0999999 ]\n",
    " [ 5.4000001 3.9000001 1.29999995 0.40000001]\n",
    " [ 6.0999999 2.5999999 5.5999999 1.39999998]\n",
    " [ 7.19999981 3. 5.80000019 1.60000002]\n",
    " [ 5.19999981 3.5 1.5 0.2 ]\n",
    " [ 5.80000019 2.5999999 4. 1.20000005]\n",
    " [ 5.9000001 3. 5.0999999 1.79999995]\n",
    " [ 5.4000001 3. 4.5 1.5 ]\n",
    " [ 6.69999981 3. 5. 1.70000005]\n",
    " [ 6.30000019 2.29999995 4.4000001 1.29999995]\n",
    " [ 5.0999999 2.5 3. 1.10000002]\n",
    " [ 6.4000001 3.20000005 4.5 1.5 ]\n",
    " [ 6.80000019 3. 5.5 2.0999999 ]\n",
    " [ 6.19999981 2.79999995 4.80000019 1.79999995]\n",
    " [ 6.9000001 3.20000005 5.69999981 2.29999995]\n",
    " [ 6.5 3.20000005 5.0999999 2. ]\n",
    " [ 5.80000019 2.79999995 5.0999999 2.4000001 ]\n",
    " [ 5.0999999 3.79999995 1.5 0.30000001]\n",
    " [ 4.80000019 3. 1.39999998 0.30000001]\n",
    " [ 7.9000001 3.79999995 6.4000001 2. ]\n",
    " [ 5.80000019 2.70000005 5.0999999 1.89999998]\n",
    " [ 6.69999981 3. 5.19999981 2.29999995]\n",
    " [ 5.0999999 3.79999995 1.89999998 0.40000001]\n",
    " [ 4.69999981 3.20000005 1.60000002 0.2 ]\n",
    " [ 6. 2.20000005 5. 1.5 ]\n",
    " [ 4.80000019 3.4000001 1.60000002 0.2 ]\n",
    " [ 7.69999981 2.5999999 6.9000001 2.29999995]\n",
    " [ 4.5999999 3.5999999 1. 0.2 ]\n",
    " [ 7.19999981 3.20000005 6. 1.79999995]\n",
    " [ 5. 3.29999995 1.39999998 0.2 ]\n",
    " [ 6.5999999 3. 4.4000001 1.39999998]\n",
    " [ 6.0999999 2.79999995 4. 1.29999995]\n",
    " [ 5. 3.20000005 1.20000005 0.2 ]\n",
    " [ 7. 3.20000005 4.69999981 1.39999998]\n",
    " [ 6. 3. 4.80000019 1.79999995]\n",
    " [ 7.4000001 2.79999995 6.0999999 1.89999998]\n",
    " [ 5.80000019 2.70000005 5.0999999 1.89999998]\n",
    " [ 6.19999981 3.4000001 5.4000001 2.29999995]\n",
    " [ 5. 2. 3.5 1. ]\n",
    " [ 5.5999999 2.5 3.9000001 1.10000002]\n",
    " [ 6.69999981 3.0999999 5.5999999 2.4000001 ]\n",
    " [ 6.30000019 2.5 5. 1.89999998]\n",
    " [ 6.4000001 3.0999999 5.5 1.79999995]\n",
    " [ 6.19999981 2.20000005 4.5 1.5 ]\n",
    " [ 7.30000019 2.9000001 6.30000019 1.79999995]\n",
    " [ 4.4000001 3. 1.29999995 0.2 ]\n",
    " [ 7.19999981 3.5999999 6.0999999 2.5 ]\n",
    " [ 6.5 3. 5.5 1.79999995]\n",
    " [ 5. 3.4000001 1.5 0.2 ]\n",
    " [ 4.69999981 3.20000005 1.29999995 0.2 ]\n",
    " [ 6.5999999 2.9000001 4.5999999 1.29999995]\n",
    " [ 5.5 3.5 1.29999995 0.2 ]\n",
    " [ 7.69999981 3. 6.0999999 2.29999995]\n",
    " [ 6.0999999 3. 4.9000001 1.79999995]\n",
    " [ 4.9000001 3.0999999 1.5 0.1 ]\n",
    " [ 5.5 2.4000001 3.79999995 1.10000002]\n",
    " [ 5.69999981 2.9000001 4.19999981 1.29999995]\n",
    " [ 6. 2.9000001 4.5 1.5 ]\n",
    " [ 6.4000001 2.70000005 5.30000019 1.89999998]\n",
    " [ 5.4000001 3.70000005 1.5 0.2 ]\n",
    " [ 6.0999999 2.9000001 4.69999981 1.39999998]\n",
    " [ 6.5 2.79999995 4.5999999 1.5 ]\n",
    " [ 5.5999999 2.70000005 4.19999981 1.29999995]\n",
    " [ 6.30000019 3.4000001 5.5999999 2.4000001 ]\n",
    " [ 4.9000001 3.0999999 1.5 0.1 ]\n",
    " [ 6.80000019 2.79999995 4.80000019 1.39999998]\n",
    " [ 5.69999981 2.79999995 4.5 1.29999995]\n",
    " [ 6. 2.70000005 5.0999999 1.60000002]\n",
    " [ 5. 3.5 1.29999995 0.30000001]\n",
    " [ 6.5 3. 5.19999981 2. ]\n",
    " [ 6.0999999 2.79999995 4.69999981 1.20000005]\n",
    " [ 5.0999999 3.5 1.39999998 0.30000001]\n",
    " [ 4.5999999 3.0999999 1.5 0.2 ]\n",
    " [ 6.5 3. 5.80000019 2.20000005]\n",
    " [ 4.5999999 3.4000001 1.39999998 0.30000001]\n",
    " [ 4.5999999 3.20000005 1.39999998 0.2 ]\n",
    " [ 7.69999981 2.79999995 6.69999981 2. ]\n",
    " [ 5.9000001 3.20000005 4.80000019 1.79999995]\n",
    " [ 5.0999999 3.79999995 1.60000002 0.2 ]\n",
    " [ 4.9000001 3. 1.39999998 0.2 ]\n",
    " [ 4.9000001 2.4000001 3.29999995 1. ]\n",
    " [ 4.5 2.29999995 1.29999995 0.30000001]\n",
    " [ 5.80000019 2.70000005 4.0999999 1. ]\n",
    " [ 5. 3.4000001 1.60000002 0.40000001]\n",
    " [ 5.19999981 3.4000001 1.39999998 0.2 ]\n",
    " [ 5.30000019 3.70000005 1.5 0.2 ]\n",
    " [ 5. 3.5999999 1.39999998 0.2 ]\n",
    " [ 5.5999999 2.9000001 3.5999999 1.29999995]\n",
    " [ 4.80000019 3.0999999 1.60000002 0.2 ]\n",
    " [ 6.30000019 2.70000005 4.9000001 1.79999995]\n",
    " [ 5.69999981 2.79999995 4.0999999 1.29999995]\n",
    " [ 5. 3. 1.60000002 0.2 ]\n",
    " [ 6.30000019 3.29999995 6. 2.5 ]\n",
    " [ 5. 3.5 1.60000002 0.60000002]\n",
    " [ 5.5 2.5999999 4.4000001 1.20000005]\n",
    " [ 5.69999981 3. 4.19999981 1.20000005]\n",
    " [ 4.4000001 2.9000001 1.39999998 0.2 ]\n",
    " [ 4.80000019 3. 1.39999998 0.1 ]\n",
    " [ 5.5 2.4000001 3.70000005 1. ]]\n",
    "[2 1 2 0 0 0 0 2 1 0 1 1 0 0 2 1 2 2 2 0 2 2 0 2 2 0 1 2 1 1 1 1 1 2 2 2 2\n",
    " 2 0 0 2 2 2 0 0 2 0 2 0 2 0 1 1 0 1 2 2 2 2 1 1 2 2 2 1 2 0 2 2 0 0 1 0 2\n",
    " 2 0 1 1 1 2 0 1 1 1 2 0 1 1 1 0 2 1 0 0 2 0 0 2 1 0 0 1 0 1 0 0 0 0 1 0 2\n",
    " 1 0 2 0 1 1 0 0 1]\n"
    ]
    }
    ],
    "source": [
    "from tensorflow.contrib.learn.python.learn.datasets import base\n",
    "\n",
    "# Data files\n",
    "IRIS_TRAINING = \"iris_training.csv\"\n",
    "IRIS_TEST = \"iris_test.csv\"\n",
    "\n",
    "# Load datasets.\n",
    "training_set = base.load_csv_with_header(filename=IRIS_TRAINING,\n",
    " features_dtype=np.float32,\n",
    " target_dtype=np.int)\n",
    "test_set = base.load_csv_with_header(filename=IRIS_TEST,\n",
    " features_dtype=np.float32,\n",
    " target_dtype=np.int)\n",
    "\n",
    "print(training_set.data)\n",
    "\n",
    "print(training_set.target)"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 3,
    "metadata": {},
    "outputs": [
    {
    "name": "stdout",
    "output_type": "stream",
    "text": [
    "INFO:tensorflow:Using default config.\n",
    "INFO:tensorflow:Using config: {'_save_checkpoints_secs': 600, '_session_config': None, '_keep_checkpoint_max': 5, '_tf_random_seed': 1, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_save_checkpoints_steps': None, '_model_dir': '/tmp/iris_model', '_save_summary_steps': 100}\n"
    ]
    }
    ],
    "source": [
    "# Specify that all features have real-value data\n",
    "feature_name = \"flower_features\"\n",
    "feature_columns = [tf.feature_column.numeric_column(feature_name, \n",
    " shape=[4])]\n",
    "\n",
    "classifier = tf.estimator.LinearClassifier(\n",
    " feature_columns=feature_columns,\n",
    " n_classes=3,\n",
    " model_dir=\"/tmp/iris_model\")\n"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 4,
    "metadata": {},
    "outputs": [
    {
    "name": "stdout",
    "output_type": "stream",
    "text": [
    "({'flower_features': <tf.Tensor 'Const:0' shape=(120, 4) dtype=float32>}, <tf.Tensor 'Const_1:0' shape=(120,) dtype=int64>)\n"
    ]
    }
    ],
    "source": [
    "def input_fn(dataset):\n",
    " def _fn():\n",
    " features = {feature_name: tf.constant(dataset.data)}\n",
    " label = tf.constant(dataset.target)\n",
    " return features, label\n",
    " return _fn\n",
    "\n",
    "print(input_fn(training_set)())\n",
    "\n",
    "# raw data -> input function -> feature columns -> model"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 5,
    "metadata": {
    "scrolled": true
    },
    "outputs": [
    {
    "name": "stdout",
    "output_type": "stream",
    "text": [
    "INFO:tensorflow:Create CheckpointSaverHook.\n",
    "INFO:tensorflow:Saving checkpoints for 1 into /tmp/iris_model/model.ckpt.\n",
    "INFO:tensorflow:loss = 131.833, step = 1\n",
    "INFO:tensorflow:global_step/sec: 832.73\n",
    "INFO:tensorflow:loss = 37.1391, step = 101 (0.121 sec)\n",
    "INFO:tensorflow:global_step/sec: 656.81\n",
    "INFO:tensorflow:loss = 27.8594, step = 201 (0.154 sec)\n",
    "INFO:tensorflow:global_step/sec: 837.738\n",
    "INFO:tensorflow:loss = 23.0449, step = 301 (0.118 sec)\n",
    "INFO:tensorflow:global_step/sec: 832.487\n",
    "INFO:tensorflow:loss = 20.058, step = 401 (0.121 sec)\n",
    "INFO:tensorflow:global_step/sec: 705.118\n",
    "INFO:tensorflow:loss = 18.0083, step = 501 (0.142 sec)\n",
    "INFO:tensorflow:global_step/sec: 650.441\n",
    "INFO:tensorflow:loss = 16.505, step = 601 (0.153 sec)\n",
    "INFO:tensorflow:global_step/sec: 619.556\n",
    "INFO:tensorflow:loss = 15.3496, step = 701 (0.162 sec)\n",
    "INFO:tensorflow:global_step/sec: 657.613\n",
    "INFO:tensorflow:loss = 14.43, step = 801 (0.152 sec)\n",
    "INFO:tensorflow:global_step/sec: 763.202\n",
    "INFO:tensorflow:loss = 13.6782, step = 901 (0.131 sec)\n",
    "INFO:tensorflow:Saving checkpoints for 1000 into /tmp/iris_model/model.ckpt.\n",
    "INFO:tensorflow:Loss for final step: 13.0562.\n",
    "fit done\n"
    ]
    }
    ],
    "source": [
    "# Fit model.\n",
    "classifier.train(input_fn=input_fn(training_set),\n",
    " steps=1000)\n",
    "print('fit done')\n"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": 6,
    "metadata": {},
    "outputs": [
    {
    "name": "stdout",
    "output_type": "stream",
    "text": [
    "INFO:tensorflow:Starting evaluation at 2017-09-05-04:32:30\n",
    "INFO:tensorflow:Restoring parameters from /tmp/iris_model/model.ckpt-1000\n",
    "INFO:tensorflow:Evaluation [1/100]\n",
    "INFO:tensorflow:Evaluation [2/100]\n",
    "INFO:tensorflow:Evaluation [3/100]\n",
    "INFO:tensorflow:Evaluation [4/100]\n",
    "INFO:tensorflow:Evaluation [5/100]\n",
    "INFO:tensorflow:Evaluation [6/100]\n",
    "INFO:tensorflow:Evaluation [7/100]\n",
    "INFO:tensorflow:Evaluation [8/100]\n",
    "INFO:tensorflow:Evaluation [9/100]\n",
    "INFO:tensorflow:Evaluation [10/100]\n",
    "INFO:tensorflow:Evaluation [11/100]\n",
    "INFO:tensorflow:Evaluation [12/100]\n",
    "INFO:tensorflow:Evaluation [13/100]\n",
    "INFO:tensorflow:Evaluation [14/100]\n",
    "INFO:tensorflow:Evaluation [15/100]\n",
    "INFO:tensorflow:Evaluation [16/100]\n",
    "INFO:tensorflow:Evaluation [17/100]\n",
    "INFO:tensorflow:Evaluation [18/100]\n",
    "INFO:tensorflow:Evaluation [19/100]\n",
    "INFO:tensorflow:Evaluation [20/100]\n",
    "INFO:tensorflow:Evaluation [21/100]\n",
    "INFO:tensorflow:Evaluation [22/100]\n",
    "INFO:tensorflow:Evaluation [23/100]\n",
    "INFO:tensorflow:Evaluation [24/100]\n",
    "INFO:tensorflow:Evaluation [25/100]\n",
    "INFO:tensorflow:Evaluation [26/100]\n",
    "INFO:tensorflow:Evaluation [27/100]\n",
    "INFO:tensorflow:Evaluation [28/100]\n",
    "INFO:tensorflow:Evaluation [29/100]\n",
    "INFO:tensorflow:Evaluation [30/100]\n",
    "INFO:tensorflow:Evaluation [31/100]\n",
    "INFO:tensorflow:Evaluation [32/100]\n",
    "INFO:tensorflow:Evaluation [33/100]\n",
    "INFO:tensorflow:Evaluation [34/100]\n",
    "INFO:tensorflow:Evaluation [35/100]\n",
    "INFO:tensorflow:Evaluation [36/100]\n",
    "INFO:tensorflow:Evaluation [37/100]\n",
    "INFO:tensorflow:Evaluation [38/100]\n",
    "INFO:tensorflow:Evaluation [39/100]\n",
    "INFO:tensorflow:Evaluation [40/100]\n",
    "INFO:tensorflow:Evaluation [41/100]\n",
    "INFO:tensorflow:Evaluation [42/100]\n",
    "INFO:tensorflow:Evaluation [43/100]\n",
    "INFO:tensorflow:Evaluation [44/100]\n",
    "INFO:tensorflow:Evaluation [45/100]\n",
    "INFO:tensorflow:Evaluation [46/100]\n",
    "INFO:tensorflow:Evaluation [47/100]\n",
    "INFO:tensorflow:Evaluation [48/100]\n",
    "INFO:tensorflow:Evaluation [49/100]\n",
    "INFO:tensorflow:Evaluation [50/100]\n",
    "INFO:tensorflow:Evaluation [51/100]\n",
    "INFO:tensorflow:Evaluation [52/100]\n",
    "INFO:tensorflow:Evaluation [53/100]\n",
    "INFO:tensorflow:Evaluation [54/100]\n",
    "INFO:tensorflow:Evaluation [55/100]\n",
    "INFO:tensorflow:Evaluation [56/100]\n",
    "INFO:tensorflow:Evaluation [57/100]\n",
    "INFO:tensorflow:Evaluation [58/100]\n",
    "INFO:tensorflow:Evaluation [59/100]\n",
    "INFO:tensorflow:Evaluation [60/100]\n",
    "INFO:tensorflow:Evaluation [61/100]\n",
    "INFO:tensorflow:Evaluation [62/100]\n",
    "INFO:tensorflow:Evaluation [63/100]\n",
    "INFO:tensorflow:Evaluation [64/100]\n",
    "INFO:tensorflow:Evaluation [65/100]\n",
    "INFO:tensorflow:Evaluation [66/100]\n",
    "INFO:tensorflow:Evaluation [67/100]\n",
    "INFO:tensorflow:Evaluation [68/100]\n",
    "INFO:tensorflow:Evaluation [69/100]\n",
    "INFO:tensorflow:Evaluation [70/100]\n",
    "INFO:tensorflow:Evaluation [71/100]\n",
    "INFO:tensorflow:Evaluation [72/100]\n",
    "INFO:tensorflow:Evaluation [73/100]\n",
    "INFO:tensorflow:Evaluation [74/100]\n",
    "INFO:tensorflow:Evaluation [75/100]\n",
    "INFO:tensorflow:Evaluation [76/100]\n",
    "INFO:tensorflow:Evaluation [77/100]\n",
    "INFO:tensorflow:Evaluation [78/100]\n",
    "INFO:tensorflow:Evaluation [79/100]\n",
    "INFO:tensorflow:Evaluation [80/100]\n",
    "INFO:tensorflow:Evaluation [81/100]\n",
    "INFO:tensorflow:Evaluation [82/100]\n",
    "INFO:tensorflow:Evaluation [83/100]\n",
    "INFO:tensorflow:Evaluation [84/100]\n",
    "INFO:tensorflow:Evaluation [85/100]\n",
    "INFO:tensorflow:Evaluation [86/100]\n",
    "INFO:tensorflow:Evaluation [87/100]\n",
    "INFO:tensorflow:Evaluation [88/100]\n",
    "INFO:tensorflow:Evaluation [89/100]\n",
    "INFO:tensorflow:Evaluation [90/100]\n",
    "INFO:tensorflow:Evaluation [91/100]\n",
    "INFO:tensorflow:Evaluation [92/100]\n",
    "INFO:tensorflow:Evaluation [93/100]\n",
    "INFO:tensorflow:Evaluation [94/100]\n",
    "INFO:tensorflow:Evaluation [95/100]\n",
    "INFO:tensorflow:Evaluation [96/100]\n",
    "INFO:tensorflow:Evaluation [97/100]\n",
    "INFO:tensorflow:Evaluation [98/100]\n",
    "INFO:tensorflow:Evaluation [99/100]\n",
    "INFO:tensorflow:Evaluation [100/100]\n",
    "INFO:tensorflow:Finished evaluation at 2017-09-05-04:32:31\n",
    "INFO:tensorflow:Saving dict for global step 1000: accuracy = 0.966667, average_loss = 0.120964, global_step = 1000, loss = 3.62893\n",
    "\n",
    "Accuracy: 0.966667\n"
    ]
    }
    ],
    "source": [
    "# Evaluate accuracy.\n",
    "accuracy_score = classifier.evaluate(input_fn=input_fn(test_set), \n",
    " steps=100)[\"accuracy\"]\n",
    "print('\\nAccuracy: {0:f}'.format(accuracy_score))"
    ]
    },
    {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
    "collapsed": true
    },
    "outputs": [],
    "source": [
    "# Estimators review\n",
    "\n",
    "# Load datasets.\n",
    "training_data = load_csv_with_header()\n",
    "\n",
    "# define input functions\n",
    "def input_fn(dataset)\n",
    " \n",
    "# Define feature columns\n",
    "feature_columns = [tf.feature_column.numeric_column(feature_name, \n",
    " shape=[4])]\n",
    "\n",
    "# Create model\n",
    "classifier = tf.estimator.LinearClassifier()\n",
    "\n",
    "# Train\n",
    "classifier.train()\n",
    "\n",
    "# Evaluate\n",
    "classifier.evaluate()"
    ]
    }
    ],
    "metadata": {
    "kernelspec": {
    "display_name": "Python 2",
    "language": "python",
    "name": "python2"
    },
    "language_info": {
    "codemirror_mode": {
    "name": "ipython",
    "version": 2
    },
    "file_extension": ".py",
    "mimetype": "text/x-python",
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython2",
    "version": "2.7.13"
    }
    },
    "nbformat": 4,
    "nbformat_minor": 2
    }