Revisions

  1. @parulnith parulnith revised this gist Dec 14, 2018. No changes.
  2. @parulnith parulnith renamed this gist Dec 13, 2018. 1 changed file with 0 additions and 0 deletions.
    File renamed without changes.
  3. @parulnith parulnith revised this gist Dec 13, 2018. 1 changed file with 12 additions and 1 deletion.
    13 changes: 12 additions & 1 deletion untitled9.ipynb
    Original file line number Diff line number Diff line change
    @@ -5,7 +5,8 @@
    "colab": {
    "name": "Untitled9.ipynb",
    "version": "0.3.2",
    "provenance": []
    "provenance": [],
    "include_colab_link": true
    },
    "kernelspec": {
    "name": "python3",
    @@ -14,6 +15,16 @@
    "accelerator": "TPU"
    },
    "cells": [
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "view-in-github",
    "colab_type": "text"
    },
    "source": [
    "<a href=\"https://colab.research.google.com/gist/parulnith/7f8c174e6ac099e86f0495d3d9a4c01e/untitled9.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
    ]
    },
    {
    "metadata": {
    "id": "cNnM2w-HCeb1",
  4. @parulnith parulnith created this gist Dec 13, 2018.
    1,200 changes: 1,200 additions & 0 deletions untitled9.ipynb
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,1200 @@
    {
    "nbformat": 4,
    "nbformat_minor": 0,
    "metadata": {
    "colab": {
    "name": "Untitled9.ipynb",
    "version": "0.3.2",
    "provenance": []
    },
    "kernelspec": {
    "name": "python3",
    "display_name": "Python 3"
    },
    "accelerator": "TPU"
    },
    "cells": [
    {
    "metadata": {
    "id": "cNnM2w-HCeb1",
    "colab_type": "text"
    },
    "cell_type": "markdown",
    "source": [
    "# Music genre classification notebook"
    ]
    },
    {
    "metadata": {
    "id": "2l3sppZMCydR",
    "colab_type": "text"
    },
    "cell_type": "markdown",
    "source": [
    "## Importing Libraries"
    ]
    },
    {
    "metadata": {
    "id": "Gt3fyg6dCNvX",
    "colab_type": "code",
    "colab": {}
    },
    "cell_type": "code",
    "source": [
    "# feature extractoring and preprocessing data\n",
    "import librosa\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "import os\n",
    "from PIL import Image\n",
    "import pathlib\n",
    "import csv\n",
    "\n",
    "# Preprocessing\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import LabelEncoder, StandardScaler\n",
    "\n",
    "#Keras\n",
    "import keras\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')"
    ],
    "execution_count": 0,
    "outputs": []
    },
    {
    "metadata": {
    "id": "DPe_ebYuDqr5",
    "colab_type": "text"
    },
    "cell_type": "markdown",
    "source": [
    "## Extracting music and features\n",
    "\n",
    "### Dataset\n",
    "\n",
    "We use [GTZAN genre collection](http://marsyasweb.appspot.com/download/data_sets/) dataset for classification. \n",
    "<br>\n",
    "<br>\n",
    "The dataset consists of 10 genres i.e\n",
    " * Blues\n",
    " * Classical\n",
    " * Country\n",
    " * Disco\n",
    " * Hiphop\n",
    " * Jazz\n",
    " * Metal\n",
    " * Pop\n",
    " * Reggae\n",
    " * Rock\n",
    " \n",
    "Each genre contains 100 songs. Total dataset: 1000 songs"
    ]
    },
    {
    "metadata": {
    "id": "neqMS0VoDpN5",
    "colab_type": "text"
    },
    "cell_type": "markdown",
    "source": [
    ""
    ]
    },
    {
    "metadata": {
    "id": "AfBSVfRCD3PE",
    "colab_type": "text"
    },
    "cell_type": "markdown",
    "source": [
    "## Extracting the Spectrogram for every Audio"
    ]
    },
    {
    "metadata": {
    "id": "BHh3pTEVDdrT",
    "colab_type": "code",
    "colab": {}
    },
    "cell_type": "code",
    "source": [
    "cmap = plt.get_cmap('inferno')\n",
    "\n",
    "plt.figure(figsize=(10,10))\n",
    "genres = 'blues classical country disco hiphop jazz metal pop reggae rock'.split()\n",
    "for g in genres:\n",
    " pathlib.Path(f'img_data/{g}').mkdir(parents=True, exist_ok=True) \n",
    " for filename in os.listdir(f'./MIR/genres/{g}'):\n",
    " songname = f'./MIR/genres/{g}/{filename}'\n",
    " y, sr = librosa.load(songname, mono=True, duration=5)\n",
    " plt.specgram(y, NFFT=2048, Fs=2, Fc=0, noverlap=128, cmap=cmap, sides='default', mode='default', scale='dB');\n",
    " plt.axis('off');\n",
    " plt.savefig(f'img_data/{g}/{filename[:-3].replace(\".\", \"\")}.png')\n",
    " plt.clf()\n",
    " "
    ],
    "execution_count": 0,
    "outputs": []
    },
    {
    "metadata": {
    "id": "SszVgjYnFNX9",
    "colab_type": "text"
    },
    "cell_type": "markdown",
    "source": [
    "All the audio files get converted into their respective spectrograms .WE can noe easily extract features from them."
    ]
    },
    {
    "metadata": {
    "id": "3Nw9HpSdFRsW",
    "colab_type": "text"
    },
    "cell_type": "markdown",
    "source": [
    ""
    ]
    },
    {
    "metadata": {
    "id": "piwUwgP5Eef9",
    "colab_type": "text"
    },
    "cell_type": "markdown",
    "source": [
    "## Extracting features from Spectrogram\n",
    "\n",
    "\n",
    "We will extract\n",
    "\n",
    "* Mel-frequency cepstral coefficients (MFCC)(20 in number)\n",
    "* Spectral Centroid,\n",
    "* Zero Crossing Rate\n",
    "* Chroma Frequencies\n",
    "* Spectral Roll-off."
    ]
    },
    {
    "metadata": {
    "id": "__g8tX8pDeIL",
    "colab_type": "code",
    "colab": {}
    },
    "cell_type": "code",
    "source": [
    "header = 'filename chroma_stft rmse spectral_centroid spectral_bandwidth rolloff zero_crossing_rate'\n",
    "for i in range(1, 21):\n",
    " header += f' mfcc{i}'\n",
    "header += ' label'\n",
    "header = header.split()"
    ],
    "execution_count": 0,
    "outputs": []
    },
    {
    "metadata": {
    "id": "TBlT448pEqR9",
    "colab_type": "text"
    },
    "cell_type": "markdown",
    "source": [
    "## Writing data to csv file\n",
    "\n",
    "We write the data to a csv file "
    ]
    },
    {
    "metadata": {
    "id": "ZsSQmB0PE3Iu",
    "colab_type": "code",
    "colab": {}
    },
    "cell_type": "code",
    "source": [
    "file = open('data.csv', 'w', newline='')\n",
    "with file:\n",
    " writer = csv.writer(file)\n",
    " writer.writerow(header)\n",
    "genres = 'blues classical country disco hiphop jazz metal pop reggae rock'.split()\n",
    "for g in genres:\n",
    " for filename in os.listdir(f'./MIR/genres/{g}'):\n",
    " songname = f'./MIR/genres/{g}/{filename}'\n",
    " y, sr = librosa.load(songname, mono=True, duration=30)\n",
    " chroma_stft = librosa.feature.chroma_stft(y=y, sr=sr)\n",
    " spec_cent = librosa.feature.spectral_centroid(y=y, sr=sr)\n",
    " spec_bw = librosa.feature.spectral_bandwidth(y=y, sr=sr)\n",
    " rolloff = librosa.feature.spectral_rolloff(y=y, sr=sr)\n",
    " zcr = librosa.feature.zero_crossing_rate(y)\n",
    " mfcc = librosa.feature.mfcc(y=y, sr=sr)\n",
    " to_append = f'{filename} {np.mean(chroma_stft)} {np.mean(rmse)} {np.mean(spec_cent)} {np.mean(spec_bw)} {np.mean(rolloff)} {np.mean(zcr)}' \n",
    " for e in mfcc:\n",
    " to_append += f' {np.mean(e)}'\n",
    " to_append += f' {g}'\n",
    " file = open('data.csv', 'a', newline='')\n",
    " with file:\n",
    " writer = csv.writer(file)\n",
    " writer.writerow(to_append.split())"
    ],
    "execution_count": 0,
    "outputs": []
    },
    {
    "metadata": {
    "id": "0yfdo1cj6V7d",
    "colab_type": "text"
    },
    "cell_type": "markdown",
    "source": [
    "The data has been extracted into a [data.csv](https://github.com/parulnith/Music-Genre-Classification-with-Python/blob/master/data.csv) file."
    ]
    },
    {
    "metadata": {
    "id": "fgeCZSKQEp1A",
    "colab_type": "text"
    },
    "cell_type": "markdown",
    "source": [
    "# Analysing the Data in Pandas"
    ]
    },
    {
    "metadata": {
    "id": "Kr5_EdpD9dyh",
    "colab_type": "code",
    "colab": {
    "base_uri": "https://localhost:8080/",
    "height": 253
    },
    "outputId": "81fd4a29-93fa-44f8-bf90-2f99981f761a"
    },
    "cell_type": "code",
    "source": [
    "data = pd.read_csv('data.csv')\n",
    "data.head()"
    ],
    "execution_count": 6,
    "outputs": [
    {
    "output_type": "execute_result",
    "data": {
    "text/html": [
    "<div>\n",
    "<style scoped>\n",
    " .dataframe tbody tr th:only-of-type {\n",
    " vertical-align: middle;\n",
    " }\n",
    "\n",
    " .dataframe tbody tr th {\n",
    " vertical-align: top;\n",
    " }\n",
    "\n",
    " .dataframe thead th {\n",
    " text-align: right;\n",
    " }\n",
    "</style>\n",
    "<table border=\"1\" class=\"dataframe\">\n",
    " <thead>\n",
    " <tr style=\"text-align: right;\">\n",
    " <th></th>\n",
    " <th>filename</th>\n",
    " <th>chroma_stft</th>\n",
    " <th>rmse</th>\n",
    " <th>spectral_centroid</th>\n",
    " <th>spectral_bandwidth</th>\n",
    " <th>rolloff</th>\n",
    " <th>zero_crossing_rate</th>\n",
    " <th>mfcc1</th>\n",
    " <th>mfcc2</th>\n",
    " <th>mfcc3</th>\n",
    " <th>...</th>\n",
    " <th>mfcc12</th>\n",
    " <th>mfcc13</th>\n",
    " <th>mfcc14</th>\n",
    " <th>mfcc15</th>\n",
    " <th>mfcc16</th>\n",
    " <th>mfcc17</th>\n",
    " <th>mfcc18</th>\n",
    " <th>mfcc19</th>\n",
    " <th>mfcc20</th>\n",
    " <th>label</th>\n",
    " </tr>\n",
    " </thead>\n",
    " <tbody>\n",
    " <tr>\n",
    " <th>0</th>\n",
    " <td>blues.00081.au</td>\n",
    " <td>0.380260</td>\n",
    " <td>0.248262</td>\n",
    " <td>2116.942959</td>\n",
    " <td>1956.611056</td>\n",
    " <td>4196.107960</td>\n",
    " <td>0.127272</td>\n",
    " <td>-26.929785</td>\n",
    " <td>107.334008</td>\n",
    " <td>-46.809993</td>\n",
    " <td>...</td>\n",
    " <td>14.336612</td>\n",
    " <td>-13.821769</td>\n",
    " <td>7.562789</td>\n",
    " <td>-6.181372</td>\n",
    " <td>0.330165</td>\n",
    " <td>-6.829571</td>\n",
    " <td>0.965922</td>\n",
    " <td>-7.570825</td>\n",
    " <td>2.918987</td>\n",
    " <td>blues</td>\n",
    " </tr>\n",
    " <tr>\n",
    " <th>1</th>\n",
    " <td>blues.00022.au</td>\n",
    " <td>0.306451</td>\n",
    " <td>0.113475</td>\n",
    " <td>1156.070496</td>\n",
    " <td>1497.668176</td>\n",
    " <td>2170.053545</td>\n",
    " <td>0.058613</td>\n",
    " <td>-233.860772</td>\n",
    " <td>136.170239</td>\n",
    " <td>3.289490</td>\n",
    " <td>...</td>\n",
    " <td>-2.250578</td>\n",
    " <td>3.959198</td>\n",
    " <td>5.322555</td>\n",
    " <td>0.812028</td>\n",
    " <td>-1.107202</td>\n",
    " <td>-4.556555</td>\n",
    " <td>-2.436490</td>\n",
    " <td>3.316913</td>\n",
    " <td>-0.608485</td>\n",
    " <td>blues</td>\n",
    " </tr>\n",
    " <tr>\n",
    " <th>2</th>\n",
    " <td>blues.00031.au</td>\n",
    " <td>0.253487</td>\n",
    " <td>0.151571</td>\n",
    " <td>1331.073970</td>\n",
    " <td>1973.643437</td>\n",
    " <td>2900.174130</td>\n",
    " <td>0.042967</td>\n",
    " <td>-221.802549</td>\n",
    " <td>110.843071</td>\n",
    " <td>18.620984</td>\n",
    " <td>...</td>\n",
    " <td>-13.037723</td>\n",
    " <td>-12.652228</td>\n",
    " <td>-1.821905</td>\n",
    " <td>-7.260097</td>\n",
    " <td>-6.660252</td>\n",
    " <td>-14.682694</td>\n",
    " <td>-11.719264</td>\n",
    " <td>-11.025216</td>\n",
    " <td>-13.387260</td>\n",
    " <td>blues</td>\n",
    " </tr>\n",
    " <tr>\n",
    " <th>3</th>\n",
    " <td>blues.00012.au</td>\n",
    " <td>0.269320</td>\n",
    " <td>0.119072</td>\n",
    " <td>1361.045467</td>\n",
    " <td>1567.804596</td>\n",
    " <td>2739.625101</td>\n",
    " <td>0.069124</td>\n",
    " <td>-207.208080</td>\n",
    " <td>132.799175</td>\n",
    " <td>-15.438986</td>\n",
    " <td>...</td>\n",
    " <td>-0.613248</td>\n",
    " <td>0.384877</td>\n",
    " <td>2.605128</td>\n",
    " <td>-5.188924</td>\n",
    " <td>-9.527455</td>\n",
    " <td>-9.244394</td>\n",
    " <td>-2.848274</td>\n",
    " <td>-1.418707</td>\n",
    " <td>-5.932607</td>\n",
    " <td>blues</td>\n",
    " </tr>\n",
    " <tr>\n",
    " <th>4</th>\n",
    " <td>blues.00056.au</td>\n",
    " <td>0.391059</td>\n",
    " <td>0.137728</td>\n",
    " <td>1811.076084</td>\n",
    " <td>2052.332563</td>\n",
    " <td>3927.809582</td>\n",
    " <td>0.075480</td>\n",
    " <td>-145.434568</td>\n",
    " <td>102.829023</td>\n",
    " <td>-12.517677</td>\n",
    " <td>...</td>\n",
    " <td>7.457218</td>\n",
    " <td>-10.470444</td>\n",
    " <td>-2.360483</td>\n",
    " <td>-6.783624</td>\n",
    " <td>2.671134</td>\n",
    " <td>-4.760879</td>\n",
    " <td>-0.949005</td>\n",
    " <td>0.024832</td>\n",
    " <td>-2.005315</td>\n",
    " <td>blues</td>\n",
    " </tr>\n",
    " </tbody>\n",
    "</table>\n",
    "<p>5 rows × 28 columns</p>\n",
    "</div>"
    ],
    "text/plain": [
    " filename chroma_stft rmse spectral_centroid \\\n",
    "0 blues.00081.au 0.380260 0.248262 2116.942959 \n",
    "1 blues.00022.au 0.306451 0.113475 1156.070496 \n",
    "2 blues.00031.au 0.253487 0.151571 1331.073970 \n",
    "3 blues.00012.au 0.269320 0.119072 1361.045467 \n",
    "4 blues.00056.au 0.391059 0.137728 1811.076084 \n",
    "\n",
    " spectral_bandwidth rolloff zero_crossing_rate mfcc1 \\\n",
    "0 1956.611056 4196.107960 0.127272 -26.929785 \n",
    "1 1497.668176 2170.053545 0.058613 -233.860772 \n",
    "2 1973.643437 2900.174130 0.042967 -221.802549 \n",
    "3 1567.804596 2739.625101 0.069124 -207.208080 \n",
    "4 2052.332563 3927.809582 0.075480 -145.434568 \n",
    "\n",
    " mfcc2 mfcc3 ... mfcc12 mfcc13 mfcc14 mfcc15 \\\n",
    "0 107.334008 -46.809993 ... 14.336612 -13.821769 7.562789 -6.181372 \n",
    "1 136.170239 3.289490 ... -2.250578 3.959198 5.322555 0.812028 \n",
    "2 110.843071 18.620984 ... -13.037723 -12.652228 -1.821905 -7.260097 \n",
    "3 132.799175 -15.438986 ... -0.613248 0.384877 2.605128 -5.188924 \n",
    "4 102.829023 -12.517677 ... 7.457218 -10.470444 -2.360483 -6.783624 \n",
    "\n",
    " mfcc16 mfcc17 mfcc18 mfcc19 mfcc20 label \n",
    "0 0.330165 -6.829571 0.965922 -7.570825 2.918987 blues \n",
    "1 -1.107202 -4.556555 -2.436490 3.316913 -0.608485 blues \n",
    "2 -6.660252 -14.682694 -11.719264 -11.025216 -13.387260 blues \n",
    "3 -9.527455 -9.244394 -2.848274 -1.418707 -5.932607 blues \n",
    "4 2.671134 -4.760879 -0.949005 0.024832 -2.005315 blues \n",
    "\n",
    "[5 rows x 28 columns]"
    ]
    },
    "metadata": {
    "tags": []
    },
    "execution_count": 6
    }
    ]
    },
    {
    "metadata": {
    "id": "iHrDHCaR9gKR",
    "colab_type": "code",
    "colab": {
    "base_uri": "https://localhost:8080/",
    "height": 34
    },
    "outputId": "7d32943a-1ad5-4a59-c13a-beebeb36e4c2"
    },
    "cell_type": "code",
    "source": [
    "data.shape"
    ],
    "execution_count": 7,
    "outputs": [
    {
    "output_type": "execute_result",
    "data": {
    "text/plain": [
    "(1000, 28)"
    ]
    },
    "metadata": {
    "tags": []
    },
    "execution_count": 7
    }
    ]
    },
    {
    "metadata": {
    "id": "veD5BgX49hZa",
    "colab_type": "code",
    "colab": {}
    },
    "cell_type": "code",
    "source": [
    "# Dropping unneccesary columns\n",
    "data = data.drop(['filename'],axis=1)"
    ],
    "execution_count": 0,
    "outputs": []
    },
    {
    "metadata": {
    "id": "Nyr0aAAsGXjZ",
    "colab_type": "text"
    },
    "cell_type": "markdown",
    "source": [
    "## Encoding the Labels"
    ]
    },
    {
    "metadata": {
    "id": "frI5HH4q-1HS",
    "colab_type": "code",
    "colab": {}
    },
    "cell_type": "code",
    "source": [
    "genre_list = data.iloc[:, -1]\n",
    "encoder = LabelEncoder()\n",
    "y = encoder.fit_transform(genre_list)"
    ],
    "execution_count": 0,
    "outputs": []
    },
    {
    "metadata": {
    "id": "Slm8W0-iGVhI",
    "colab_type": "text"
    },
    "cell_type": "markdown",
    "source": [
    ""
    ]
    },
    {
    "metadata": {
    "id": "_2n8a02zGfvP",
    "colab_type": "text"
    },
    "cell_type": "markdown",
    "source": [
    "## Scaling the Feature columns"
    ]
    },
    {
    "metadata": {
    "id": "uqcqn-nyAofk",
    "colab_type": "code",
    "colab": {}
    },
    "cell_type": "code",
    "source": [
    "scaler = StandardScaler()\n",
    "X = scaler.fit_transform(np.array(data.iloc[:, :-1], dtype = float))"
    ],
    "execution_count": 0,
    "outputs": []
    },
    {
    "metadata": {
    "id": "e3VZvbwpGo9R",
    "colab_type": "text"
    },
    "cell_type": "markdown",
    "source": [
    "## Dividing data into training and Testing set"
    ]
    },
    {
    "metadata": {
    "id": "F1GW3VvQA7Rj",
    "colab_type": "code",
    "colab": {}
    },
    "cell_type": "code",
    "source": [
    "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)"
    ],
    "execution_count": 0,
    "outputs": []
    },
    {
    "metadata": {
    "id": "upuczQ-KBHJ5",
    "colab_type": "code",
    "colab": {
    "base_uri": "https://localhost:8080/",
    "height": 34
    },
    "outputId": "1431a28b-e8b6-4db2-e505-7e149e37c0d7"
    },
    "cell_type": "code",
    "source": [
    "len(y_train)"
    ],
    "execution_count": 12,
    "outputs": [
    {
    "output_type": "execute_result",
    "data": {
    "text/plain": [
    "800"
    ]
    },
    "metadata": {
    "tags": []
    },
    "execution_count": 12
    }
    ]
    },
    {
    "metadata": {
    "id": "LtoE_FqqBzM8",
    "colab_type": "code",
    "colab": {
    "base_uri": "https://localhost:8080/",
    "height": 34
    },
    "outputId": "76555a2b-2030-48e1-b52d-d71b4ebae38e"
    },
    "cell_type": "code",
    "source": [
    "len(y_test)"
    ],
    "execution_count": 13,
    "outputs": [
    {
    "output_type": "execute_result",
    "data": {
    "text/plain": [
    "200"
    ]
    },
    "metadata": {
    "tags": []
    },
    "execution_count": 13
    }
    ]
    },
    {
    "metadata": {
    "id": "ir9XaWgQB0lq",
    "colab_type": "code",
    "colab": {
    "base_uri": "https://localhost:8080/",
    "height": 119
    },
    "outputId": "2ec90814-19d8-4f27-934a-1ce54406d4ea"
    },
    "cell_type": "code",
    "source": [
    "X_train[10]"
    ],
    "execution_count": 14,
    "outputs": [
    {
    "output_type": "execute_result",
    "data": {
    "text/plain": [
    "array([-0.9149113 , 0.18294103, -1.10587131, -1.3875197 , -1.14640873,\n",
    " -0.97232926, -0.29174214, 1.20078936, -0.68458101, -0.55849017,\n",
    " -1.27056582, -0.88176926, -0.74844069, -0.40970382, 0.49685952,\n",
    " -1.12666045, 0.59501437, -0.39783853, 0.29327275, -0.72916871,\n",
    " 0.63015786, -0.91149976, 0.7743942 , -0.64790051, 0.42229852,\n",
    " -1.01449461])"
    ]
    },
    "metadata": {
    "tags": []
    },
    "execution_count": 14
    }
    ]
    },
    {
    "metadata": {
    "id": "Vp2yc5FWG04e",
    "colab_type": "text"
    },
    "cell_type": "markdown",
    "source": [
    "# Classification with Keras\n",
    "\n",
    "## Building our Network"
    ]
    },
    {
    "metadata": {
    "id": "Qj3sc2uFEUMt",
    "colab_type": "code",
    "colab": {}
    },
    "cell_type": "code",
    "source": [
    "from keras import models\n",
    "from keras import layers\n",
    "\n",
    "model = models.Sequential()\n",
    "model.add(layers.Dense(256, activation='relu', input_shape=(X_train.shape[1],)))\n",
    "\n",
    "model.add(layers.Dense(128, activation='relu'))\n",
    "\n",
    "model.add(layers.Dense(64, activation='relu'))\n",
    "\n",
    "model.add(layers.Dense(10, activation='softmax'))"
    ],
    "execution_count": 0,
    "outputs": []
    },
    {
    "metadata": {
    "id": "7yrsmpI6EjJ2",
    "colab_type": "code",
    "colab": {}
    },
    "cell_type": "code",
    "source": [
    "model.compile(optimizer='adam',\n",
    " loss='sparse_categorical_crossentropy',\n",
    " metrics=['accuracy'])"
    ],
    "execution_count": 0,
    "outputs": []
    },
    {
    "metadata": {
    "id": "bP0hVm4aElS7",
    "colab_type": "code",
    "colab": {
    "base_uri": "https://localhost:8080/",
    "height": 697
    },
    "outputId": "aacf234d-d0a9-4de4-91be-5fd45a33b279"
    },
    "cell_type": "code",
    "source": [
    "history = model.fit(X_train,\n",
    " y_train,\n",
    " epochs=20,\n",
    " batch_size=128)\n",
    " "
    ],
    "execution_count": 19,
    "outputs": [
    {
    "output_type": "stream",
    "text": [
    "Epoch 1/20\n",
    "800/800 [==============================] - 1s 811us/step - loss: 2.1289 - acc: 0.2400\n",
    "Epoch 2/20\n",
    "800/800 [==============================] - 0s 39us/step - loss: 1.7940 - acc: 0.4088\n",
    "Epoch 3/20\n",
    "800/800 [==============================] - 0s 37us/step - loss: 1.5437 - acc: 0.4450\n",
    "Epoch 4/20\n",
    "800/800 [==============================] - 0s 38us/step - loss: 1.3584 - acc: 0.5413\n",
    "Epoch 5/20\n",
    "800/800 [==============================] - 0s 38us/step - loss: 1.2220 - acc: 0.5750\n",
    "Epoch 6/20\n",
    "800/800 [==============================] - 0s 41us/step - loss: 1.1187 - acc: 0.6288\n",
    "Epoch 7/20\n",
    "800/800 [==============================] - 0s 37us/step - loss: 1.0326 - acc: 0.6550\n",
    "Epoch 8/20\n",
    "800/800 [==============================] - 0s 44us/step - loss: 0.9631 - acc: 0.6713\n",
    "Epoch 9/20\n",
    "800/800 [==============================] - 0s 47us/step - loss: 0.9143 - acc: 0.6913\n",
    "Epoch 10/20\n",
    "800/800 [==============================] - 0s 37us/step - loss: 0.8630 - acc: 0.7125\n",
    "Epoch 11/20\n",
    "800/800 [==============================] - 0s 36us/step - loss: 0.8095 - acc: 0.7263\n",
    "Epoch 12/20\n",
    "800/800 [==============================] - 0s 37us/step - loss: 0.7728 - acc: 0.7700\n",
    "Epoch 13/20\n",
    "800/800 [==============================] - 0s 36us/step - loss: 0.7433 - acc: 0.7563\n",
    "Epoch 14/20\n",
    "800/800 [==============================] - 0s 45us/step - loss: 0.7066 - acc: 0.7825\n",
    "Epoch 15/20\n",
    "800/800 [==============================] - 0s 43us/step - loss: 0.6718 - acc: 0.7787\n",
    "Epoch 16/20\n",
    "800/800 [==============================] - 0s 36us/step - loss: 0.6601 - acc: 0.7913\n",
    "Epoch 17/20\n",
    "800/800 [==============================] - 0s 36us/step - loss: 0.6242 - acc: 0.7963\n",
    "Epoch 18/20\n",
    "800/800 [==============================] - 0s 44us/step - loss: 0.5994 - acc: 0.8038\n",
    "Epoch 19/20\n",
    "800/800 [==============================] - 0s 42us/step - loss: 0.5715 - acc: 0.8125\n",
    "Epoch 20/20\n",
    "800/800 [==============================] - 0s 39us/step - loss: 0.5437 - acc: 0.8250\n"
    ],
    "name": "stdout"
    }
    ]
    },
    {
    "metadata": {
    "id": "0m1J0_wUFK4C",
    "colab_type": "code",
    "colab": {
    "base_uri": "https://localhost:8080/",
    "height": 34
    },
    "outputId": "ffd3bf36-29ea-437a-987c-9aa600b9dae6"
    },
    "cell_type": "code",
    "source": [
    "test_loss, test_acc = model.evaluate(X_test,y_test)"
    ],
    "execution_count": 20,
    "outputs": [
    {
    "output_type": "stream",
    "text": [
    "200/200 [==============================] - 0s 244us/step\n"
    ],
    "name": "stdout"
    }
    ]
    },
    {
    "metadata": {
    "id": "f6HrjXeUF0Ko",
    "colab_type": "code",
    "colab": {
    "base_uri": "https://localhost:8080/",
    "height": 34
    },
    "outputId": "ea282dbd-6f9e-48c7-de2d-dc9afde8949e"
    },
    "cell_type": "code",
    "source": [
    "print('test_acc: ',test_acc)"
    ],
    "execution_count": 21,
    "outputs": [
    {
    "output_type": "stream",
    "text": [
    "test_acc: 0.68\n"
    ],
    "name": "stdout"
    }
    ]
    },
    {
    "metadata": {
    "id": "3yQmP_f5Kq0w",
    "colab_type": "text"
    },
    "cell_type": "markdown",
    "source": [
    "Tes accuracy is less than training dataa accuracy. This hints at Overfitting"
    ]
    },
    {
    "metadata": {
    "id": "-U2qzRJoHV9O",
    "colab_type": "text"
    },
    "cell_type": "markdown",
    "source": [
    "## Validating our approach\n",
    "Let's set apart 200 samples in our training data to use as a validation set:"
    ]
    },
    {
    "metadata": {
    "id": "xJNbvYZoF7ZT",
    "colab_type": "code",
    "colab": {}
    },
    "cell_type": "code",
    "source": [
    "x_val = X_train[:200]\n",
    "partial_x_train = X_train[200:]\n",
    "\n",
    "y_val = y_train[:200]\n",
    "partial_y_train = y_train[200:]"
    ],
    "execution_count": 0,
    "outputs": []
    },
    {
    "metadata": {
    "id": "L1EkG59EHeEV",
    "colab_type": "text"
    },
    "cell_type": "markdown",
    "source": [
    "Now let's train our network for 20 epochs:"
    ]
    },
    {
    "metadata": {
    "id": "Dp3G4P3aP4k2",
    "colab_type": "code",
    "colab": {
    "base_uri": "https://localhost:8080/",
    "height": 1071
    },
    "outputId": "25e1a389-1ac2-425b-bd5f-05736b6e9b96"
    },
    "cell_type": "code",
    "source": [
    "model = models.Sequential()\n",
    "model.add(layers.Dense(512, activation='relu', input_shape=(X_train.shape[1],)))\n",
    "model.add(layers.Dense(256, activation='relu'))\n",
    "model.add(layers.Dense(128, activation='relu'))\n",
    "model.add(layers.Dense(64, activation='relu'))\n",
    "model.add(layers.Dense(10, activation='softmax'))\n",
    "\n",
    "model.compile(optimizer='adam',\n",
    " loss='sparse_categorical_crossentropy',\n",
    " metrics=['accuracy'])\n",
    "\n",
    "model.fit(partial_x_train,\n",
    " partial_y_train,\n",
    " epochs=30,\n",
    " batch_size=512,\n",
    " validation_data=(x_val, y_val))\n",
    "results = model.evaluate(X_test, y_test)"
    ],
    "execution_count": 37,
    "outputs": [
    {
    "output_type": "stream",
    "text": [
    "Train on 600 samples, validate on 200 samples\n",
    "Epoch 1/30\n",
    "600/600 [==============================] - 1s 1ms/step - loss: 2.3074 - acc: 0.0950 - val_loss: 2.1857 - val_acc: 0.2850\n",
    "Epoch 2/30\n",
    "600/600 [==============================] - 0s 65us/step - loss: 2.1126 - acc: 0.3783 - val_loss: 2.0936 - val_acc: 0.2400\n",
    "Epoch 3/30\n",
    "600/600 [==============================] - 0s 59us/step - loss: 1.9535 - acc: 0.3633 - val_loss: 1.9966 - val_acc: 0.2600\n",
    "Epoch 4/30\n",
    "600/600 [==============================] - 0s 58us/step - loss: 1.8082 - acc: 0.3833 - val_loss: 1.8713 - val_acc: 0.3250\n",
    "Epoch 5/30\n",
    "600/600 [==============================] - 0s 59us/step - loss: 1.6663 - acc: 0.4083 - val_loss: 1.7302 - val_acc: 0.3450\n",
    "Epoch 6/30\n",
    "600/600 [==============================] - 0s 52us/step - loss: 1.5329 - acc: 0.4550 - val_loss: 1.6233 - val_acc: 0.3700\n",
    "Epoch 7/30\n",
    "600/600 [==============================] - 0s 62us/step - loss: 1.4236 - acc: 0.4850 - val_loss: 1.5402 - val_acc: 0.3950\n",
    "Epoch 8/30\n",
    "600/600 [==============================] - 0s 57us/step - loss: 1.3250 - acc: 0.5117 - val_loss: 1.4655 - val_acc: 0.3800\n",
    "Epoch 9/30\n",
    "600/600 [==============================] - 0s 52us/step - loss: 1.2338 - acc: 0.5633 - val_loss: 1.3927 - val_acc: 0.4650\n",
    "Epoch 10/30\n",
    "600/600 [==============================] - 0s 61us/step - loss: 1.1577 - acc: 0.5983 - val_loss: 1.3338 - val_acc: 0.5500\n",
    "Epoch 11/30\n",
    "600/600 [==============================] - 0s 64us/step - loss: 1.0981 - acc: 0.6317 - val_loss: 1.3111 - val_acc: 0.5550\n",
    "Epoch 12/30\n",
    "600/600 [==============================] - 0s 52us/step - loss: 1.0529 - acc: 0.6517 - val_loss: 1.2696 - val_acc: 0.5400\n",
    "Epoch 13/30\n",
    "600/600 [==============================] - 0s 52us/step - loss: 0.9994 - acc: 0.6567 - val_loss: 1.2480 - val_acc: 0.5400\n",
    "Epoch 14/30\n",
    "600/600 [==============================] - 0s 65us/step - loss: 0.9673 - acc: 0.6633 - val_loss: 1.2384 - val_acc: 0.5700\n",
    "Epoch 15/30\n",
    "600/600 [==============================] - 0s 58us/step - loss: 0.9286 - acc: 0.6633 - val_loss: 1.1953 - val_acc: 0.5800\n",
    "Epoch 16/30\n",
    "600/600 [==============================] - 0s 59us/step - loss: 0.8849 - acc: 0.6783 - val_loss: 1.2000 - val_acc: 0.5550\n",
    "Epoch 17/30\n",
    "600/600 [==============================] - 0s 61us/step - loss: 0.8621 - acc: 0.6850 - val_loss: 1.1743 - val_acc: 0.5850\n",
    "Epoch 18/30\n",
    "600/600 [==============================] - 0s 61us/step - loss: 0.8195 - acc: 0.7150 - val_loss: 1.1609 - val_acc: 0.5750\n",
    "Epoch 19/30\n",
    "600/600 [==============================] - 0s 62us/step - loss: 0.7976 - acc: 0.7283 - val_loss: 1.1238 - val_acc: 0.6150\n",
    "Epoch 20/30\n",
    "600/600 [==============================] - 0s 63us/step - loss: 0.7660 - acc: 0.7650 - val_loss: 1.1604 - val_acc: 0.5850\n",
    "Epoch 21/30\n",
    "600/600 [==============================] - 0s 65us/step - loss: 0.7465 - acc: 0.7650 - val_loss: 1.1888 - val_acc: 0.5700\n",
    "Epoch 22/30\n",
    "600/600 [==============================] - 0s 65us/step - loss: 0.7099 - acc: 0.7517 - val_loss: 1.1563 - val_acc: 0.6050\n",
    "Epoch 23/30\n",
    "600/600 [==============================] - 0s 68us/step - loss: 0.6857 - acc: 0.7683 - val_loss: 1.0900 - val_acc: 0.6200\n",
    "Epoch 24/30\n",
    "600/600 [==============================] - 0s 67us/step - loss: 0.6597 - acc: 0.7850 - val_loss: 1.0872 - val_acc: 0.6300\n",
    "Epoch 25/30\n",
    "600/600 [==============================] - 0s 67us/step - loss: 0.6377 - acc: 0.7967 - val_loss: 1.1148 - val_acc: 0.6200\n",
    "Epoch 26/30\n",
    "600/600 [==============================] - 0s 64us/step - loss: 0.6070 - acc: 0.8200 - val_loss: 1.1397 - val_acc: 0.6150\n",
    "Epoch 27/30\n",
    "600/600 [==============================] - 0s 66us/step - loss: 0.5991 - acc: 0.8167 - val_loss: 1.1255 - val_acc: 0.6300\n",
    "Epoch 28/30\n",
    "600/600 [==============================] - 0s 62us/step - loss: 0.5656 - acc: 0.8333 - val_loss: 1.0955 - val_acc: 0.6350\n",
    "Epoch 29/30\n",
    "600/600 [==============================] - 0s 66us/step - loss: 0.5513 - acc: 0.8300 - val_loss: 1.1030 - val_acc: 0.6050\n",
    "Epoch 30/30\n",
    "600/600 [==============================] - 0s 56us/step - loss: 0.5498 - acc: 0.8233 - val_loss: 1.0869 - val_acc: 0.6250\n",
    "200/200 [==============================] - 0s 65us/step\n"
    ],
    "name": "stdout"
    }
    ]
    },
    {
    "metadata": {
    "id": "dljqHfDPI6lH",
    "colab_type": "text"
    },
    "cell_type": "markdown",
    "source": [
    ""
    ]
    },
    {
    "metadata": {
    "id": "Mvi9it1SI4aR",
    "colab_type": "code",
    "colab": {
    "base_uri": "https://localhost:8080/",
    "height": 34
    },
    "outputId": "98b01ef2-3935-442b-82d6-45f56e036d39"
    },
    "cell_type": "code",
    "source": [
    "results"
    ],
    "execution_count": 38,
    "outputs": [
    {
    "output_type": "execute_result",
    "data": {
    "text/plain": [
    "[1.2261371064186095, 0.65]"
    ]
    },
    "metadata": {
    "tags": []
    },
    "execution_count": 38
    }
    ]
    },
    {
    "metadata": {
    "id": "r3hb8s1l4rBA",
    "colab_type": "text"
    },
    "cell_type": "markdown",
    "source": [
    "## Predictions on Test Data"
    ]
    },
    {
    "metadata": {
    "id": "gudBAhIXJIi2",
    "colab_type": "code",
    "colab": {}
    },
    "cell_type": "code",
    "source": [
    "predictions = model.predict(X_test)"
    ],
    "execution_count": 0,
    "outputs": []
    },
    {
    "metadata": {
    "id": "Xb7bVPSwJQF0",
    "colab_type": "code",
    "colab": {
    "base_uri": "https://localhost:8080/",
    "height": 34
    },
    "outputId": "aca09c75-1d21-4847-bdd9-a0521dc8d948"
    },
    "cell_type": "code",
    "source": [
    "predictions[0].shape"
    ],
    "execution_count": 26,
    "outputs": [
    {
    "output_type": "execute_result",
    "data": {
    "text/plain": [
    "(10,)"
    ]
    },
    "metadata": {
    "tags": []
    },
    "execution_count": 26
    }
    ]
    },
    {
    "metadata": {
    "id": "llusRQV0JRy9",
    "colab_type": "code",
    "colab": {
    "base_uri": "https://localhost:8080/",
    "height": 34
    },
    "outputId": "a856289d-883a-47cb-c0fb-ec148330a60a"
    },
    "cell_type": "code",
    "source": [
    "np.sum(predictions[0])"
    ],
    "execution_count": 27,
    "outputs": [
    {
    "output_type": "execute_result",
    "data": {
    "text/plain": [
    "1.0"
    ]
    },
    "metadata": {
    "tags": []
    },
    "execution_count": 27
    }
    ]
    },
    {
    "metadata": {
    "id": "0eoEuSZqJTdU",
    "colab_type": "code",
    "colab": {
    "base_uri": "https://localhost:8080/",
    "height": 34
    },
    "outputId": "94c17d00-dd7f-40a1-84d2-78d1ebde6103"
    },
    "cell_type": "code",
    "source": [
    "np.argmax(predictions[0])"
    ],
    "execution_count": 28,
    "outputs": [
    {
    "output_type": "execute_result",
    "data": {
    "text/plain": [
    "8"
    ]
    },
    "metadata": {
    "tags": []
    },
    "execution_count": 28
    }
    ]
    },
    {
    "metadata": {
    "id": "Utgt1bXfJVRN",
    "colab_type": "code",
    "colab": {}
    },
    "cell_type": "code",
    "source": [
    ""
    ],
    "execution_count": 0,
    "outputs": []
    }
    ]
    }
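
The notebook's last cells stop at `np.argmax(predictions[0])`, which returns a class index (8) rather than a genre name. The snippet below is a minimal sketch of the missing lookup, written in plain Python so it runs without the trained model. It assumes the ten genre labels used in the notebook and relies on `LabelEncoder` numbering classes in sorted order. The helper `decode_prediction` and the vector `example_probs` are hypothetical and made up for illustration; they are not output from the notebook's model.

```python
# Genre labels as used in the notebook.
genres = 'blues classical country disco hiphop jazz metal pop reggae rock'.split()

# LabelEncoder assigns integer codes in sorted order of the distinct labels,
# so this list mirrors what encoder.classes_ would hold after fitting.
classes = sorted(set(genres))


def decode_prediction(probs, classes):
    """Return the label with the highest probability (a pure-Python argmax)."""
    best_index = max(range(len(probs)), key=lambda i: probs[i])
    return classes[best_index]


# Hypothetical softmax output for one test sample (illustrative values only).
example_probs = [0.02, 0.01, 0.05, 0.03, 0.04, 0.02, 0.01, 0.06, 0.70, 0.06]

predicted_genre = decode_prediction(example_probs, classes)
print(predicted_genre)
```

With the fitted `encoder` from the notebook, `encoder.inverse_transform([np.argmax(predictions[0])])` should perform the same lookup on a real prediction.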