# Early Stopping Experiment with MNIST
# http://fouryears.eu/2017/12/05/the-mystery-of-early-stopping/
#
# Code adapted from: https://github.com/fchollet/keras/blob/master/examples/mnist_cnn.py
# By: Konstantin Tretyakov
# License: MIT
import os
import pickle

import numpy as np
import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras.regularizers import l2
from keras import backend as K
from sklearn.model_selection import train_test_split
from sklearn.metrics import log_loss

img_rows, img_cols = 28, 28
num_classes = 10

# The data, shuffled and split between train and test sets.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

if K.image_data_format() == 'channels_first':
    x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
    x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
    input_shape = (1, img_rows, img_cols)
else:
    x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
    x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
    input_shape = (img_rows, img_cols, 1)

# Scale pixel values to [0, 1].
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# Convert class vectors to binary class matrices (one-hot encoding).
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)


def model_bare():
    """Convolutional network with no explicit regularization."""
    m = Sequential([Conv2D(32, kernel_size=(3, 3), activation='relu',
                           input_shape=input_shape),
                    Conv2D(64, (3, 3), activation='relu'),
                    MaxPooling2D(pool_size=(2, 2)),
                    Flatten(),
                    Dense(128, activation='relu'),
                    Dense(num_classes, activation='softmax')])
    m.compile(loss=keras.losses.categorical_crossentropy,
              optimizer=keras.optimizers.SGD(lr=0.05),
              metrics=['accuracy'])
    return m


def model_l2():
    """Same architecture, with L2 weight decay on the hidden dense layer."""
    m = Sequential([Conv2D(32, kernel_size=(3, 3), activation='relu',
                           input_shape=input_shape),
                    Conv2D(64, (3, 3), activation='relu'),
                    MaxPooling2D(pool_size=(2, 2)),
                    Flatten(),
                    Dense(128, activation='relu', kernel_regularizer=l2(0.001)),
                    Dense(num_classes, activation='softmax')])
    m.compile(loss=keras.losses.categorical_crossentropy,
              optimizer=keras.optimizers.SGD(lr=0.05),
              metrics=['accuracy'])
    return m


def model_dropout():
    """Same architecture, with dropout after pooling and the hidden dense layer."""
    m = Sequential([Conv2D(32, kernel_size=(3, 3), activation='relu',
                           input_shape=input_shape),
                    Conv2D(64, (3, 3), activation='relu'),
                    MaxPooling2D(pool_size=(2, 2)),
                    Dropout(0.25),
                    Flatten(),
                    Dense(128, activation='relu'),
                    Dropout(0.5),
                    Dense(num_classes, activation='softmax')])
    m.compile(loss=keras.losses.categorical_crossentropy,
              optimizer=keras.optimizers.SGD(lr=0.05),
              metrics=['accuracy'])
    return m


def fit_partial(model, test_size, batch_size=512, epochs_max=100):
    """Train with early stopping, holding out a `test_size` fraction of the
    training set as the stopping (validation) set. Returns the test-set
    log-loss of the best checkpointed model."""
    m = model()
    if os.path.exists('best.weights'):
        os.unlink('best.weights')
    x_fit, x_stop, y_fit, y_stop = train_test_split(x_train, y_train,
                                                    test_size=test_size)
    save_best = keras.callbacks.ModelCheckpoint('best.weights',
                                                monitor='val_loss', verbose=1,
                                                save_best_only=True)
    early_stop = keras.callbacks.EarlyStopping(monitor='val_loss',
                                               min_delta=0, patience=10,
                                               verbose=1)
    m.fit(x_fit, y_fit,
          batch_size=batch_size,
          epochs=epochs_max,
          verbose=1,
          validation_data=(x_stop, y_stop),
          callbacks=[early_stop, save_best])
    m.load_weights('best.weights')
    p = m.predict(x_test, batch_size=batch_size, verbose=0)
    return log_loss(y_test, p)
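
# Usage note (illustrative, not part of the experiment driver below):
# fit_partial(model_bare, test_size=0.1) trains on 90% of the training data
# and uses the remaining 10% only to checkpoint the best weights and decide
# when to stop; fit_full below is the no-early-stopping baseline.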

def fit_full(model, batch_size=512, epochs_max=100):
    """Train on the full training set for a fixed number of epochs, with no
    early stopping. Returns the test-set log-loss."""
    m = model()
    m.fit(x_train, y_train, batch_size=batch_size, epochs=epochs_max, verbose=1)
    # return m.evaluate(x_test, y_test, batch_size=batch_size, verbose=0)
    p = m.predict(x_test, batch_size=batch_size, verbose=0)
    return log_loss(y_test, p)


all_results = {}
for m, title in [(model_bare, "Bare"), (model_l2, "L2"), (model_dropout, "Dropout")]:
    # Entry 0: no early stopping; entries 1..: early stopping with the
    # stopping set sized at 5%, 10%, ..., 95% of the training data.
    res = [fit_full(m, epochs_max=200)]
    stops = np.arange(0.05, 1.0, 0.05)
    for s in stops:
        res.append(fit_partial(m, s, epochs_max=200))
        print(res)
    all_results[title] = res
print(all_results)

with open('results.pkl', 'wb') as f:
    pickle.dump(all_results, f)
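

# Not part of the original experiment: a minimal sketch of how the pickled
# results could be visualized afterwards. `plot_results` is a hypothetical
# helper (defined but never called here) and assumes matplotlib is available.
# For each model, entry 0 of its result list is the no-early-stopping
# baseline and entries 1.. correspond to the stopping-set fractions
# 0.05, 0.10, ..., 0.95 used above.
def plot_results(path='results.pkl'):
    import matplotlib.pyplot as plt
    with open(path, 'rb') as f:
        results = pickle.load(f)
    fractions = np.arange(0.05, 1.0, 0.05)
    for title, res in results.items():
        # Test log-loss as a function of the stopping-set fraction.
        line, = plt.plot(fractions, res[1:], marker='o', label=title)
        # Dashed baseline: training on all data with no early stopping.
        plt.axhline(res[0], linestyle='--', color=line.get_color())
    plt.xlabel('Fraction of training data held out for early stopping')
    plt.ylabel('Test log-loss')
    plt.legend()
    plt.show()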