from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
import numpy as np
import timeit
import tensorflow as tf
from pprint import pformat

# Labels are kept as integer class ids (one_hot=False); encode_label turns them
# into per-class +1/-1 targets for the one-vs-rest SVMs below.
mnist = read_data_sets("data", one_hot=False)

NUM_CLASS = 10  # one binary classifier is trained per class
STEP = 200  # training steps per classifier

# hypers
C = 1.  # weight of the hinge-loss term relative to the L2 penalty
BATCH_SIZE = 128
LEARNING_RATE = 1e-2


def encode_label(labels, target):
    """Encode labels as +1 where they equal ``target`` and -1 elsewhere.

    Args:
        labels: Array of integer class labels (assumed 1-D -- TODO confirm).
        target: The class treated as the positive class.

    Returns:
        An integer numpy column vector of shape ``(len(labels), 1)`` holding
        +1/-1 values.
    """
    return np.where(np.asarray(labels) == target, 1, -1).reshape([-1, 1])


def svm(x_test, mnist=None):
    """Train and evaluate one-vs-rest linear SVMs, one per class.

    For each class ``i`` a linear model ``X @ W + b`` is trained with
    mini-batch gradient descent on an L2-regularised hinge loss, with labels
    encoded as +1 for class ``i`` and -1 otherwise. The training loss is
    printed every 10 steps, and accuracy on a 1000-example validation batch is
    printed after each class finishes training.

    Args:
        x_test: Currently unused; kept for the ``run`` calling convention.
        mnist: Dataset object with ``train`` and ``validation`` splits that
            provide ``next_batch``, as returned by ``read_data_sets``.

    Raises:
        ValueError: If ``mnist`` is not provided.
    """
    if mnist is None:
        raise ValueError('svm() needs the mnist dataset to train on')
    print('Enter run...')
    reg_term = tf.constant(0.05)  # L2 penalty strength on W
    X = tf.placeholder(tf.float32, [None, 784], name='x')  # 784 features/image
    # Zero init is fine here: the hinge + L2 objective is convex in (W, b).
    W = tf.Variable(tf.zeros([784, 1]), name='weight')
    b = tf.Variable(tf.zeros([1]), name='b')
    Y = tf.placeholder(tf.float32, [None, 1], name='y')  # +1/-1 targets
    y_predict = tf.add(tf.matmul(X, W), b)
    reg_loss = reg_term * tf.reduce_sum(tf.square(W))
    # Hinge loss: penalise every example whose margin Y * f(x) is below 1.
    hinge_loss = tf.reduce_sum(tf.maximum(0., 1 - Y * y_predict))
    svm_loss = reg_loss + C * hinge_loss
    optimizer = tf.train.GradientDescentOptimizer(LEARNING_RATE)
    goal = optimizer.minimize(svm_loss)
    predicted_class = tf.sign(y_predict)
    correct_prediction = tf.equal(Y, predicted_class)
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
    # Build the initializer op once, after all variables exist; creating it
    # inside a loop keeps adding new ops to the graph.
    init_op = tf.global_variables_initializer()

    for i in range(NUM_CLASS):
        x_val, y_val = mnist.validation.next_batch(1000)
        y_val = encode_label(y_val, i)
        with tf.Session() as sess:
            print('Enter {} session...'.format(i))
            # Reset W and b once per class, BEFORE the step loop. Running the
            # initializer inside the loop would wipe out every update, so the
            # model would never get past a single gradient step.
            sess.run(init_op)
            for stp in range(STEP):
                batch_xs, batch_ys = mnist.train.next_batch(BATCH_SIZE)
                batch_ys = encode_label(batch_ys, i)
                feed = {X: batch_xs, Y: batch_ys}
                sess.run(goal, feed_dict=feed)
                if stp % 10 == 0:
                    # Loss after this step's update, on the same batch.
                    print('loss: ', sess.run(svm_loss, feed_dict=feed))
            print("Class", i, "Accuracy on validation:",
                  sess.run(accuracy, feed_dict={X: x_val, Y: y_val}))


def run(algorithm, x_test, y_test, mnist, algorithm_name='Algorithm'):
    """Seed numpy, run ``algorithm`` and report how long it took.

    Args:
        algorithm: Callable invoked as ``algorithm(x_test, mnist=mnist)``.
        x_test: Forwarded to ``algorithm``.
        y_test: Currently unused.
        mnist: Dataset forwarded to ``algorithm``.
        algorithm_name: Label used in the progress messages.

    Returns:
        Wall-clock run time of ``algorithm`` in seconds.
    """
    print('Running {}...'.format(algorithm_name))
    start = timeit.default_timer()
    # NOTE(review): presumably makes the dataset's batch shuffling
    # reproducible; confirm that it draws from numpy's global RNG.
    np.random.seed(0)
    algorithm(x_test, mnist=mnist)
    run_time = timeit.default_timer() - start
    print('{} finished in {:.2f}s'.format(algorithm_name, run_time))
    return run_time


for algorithm in [svm]:
    x_valid, y_valid = mnist.validation.images, mnist.validation.labels
    run(algorithm, x_valid, y_valid, mnist, algorithm_name=algorithm.__name__)