Skip to content

Instantly share code, notes, and snippets.

@osipov
Created October 11, 2016 14:54
Show Gist options
  • Select an option

  • Save osipov/2da9af5273dd2d169b9f04be503aebd1 to your computer and use it in GitHub Desktop.

Select an option

Save osipov/2da9af5273dd2d169b9f04be503aebd1 to your computer and use it in GitHub Desktop.

Revisions

  1. osipov created this gist Oct 11, 2016.
    300 changes: 300 additions & 0 deletions gistfile1.txt
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,300 @@
    package org.deeplearning4j.examples.feedforward.xor;

    import org.deeplearning4j.eval.Evaluation;
    import org.deeplearning4j.nn.api.Model;
    import org.deeplearning4j.nn.api.OptimizationAlgorithm;
    import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
    import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
    import org.deeplearning4j.nn.conf.Updater;
    import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
    import org.deeplearning4j.nn.conf.distribution.UniformDistribution;
    import org.deeplearning4j.nn.conf.layers.DenseLayer;
    import org.deeplearning4j.nn.conf.layers.OutputLayer;
    import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
    import org.deeplearning4j.nn.weights.WeightInit;
    import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.dataset.DataSet;
    import org.nd4j.linalg.factory.Nd4j;
    import org.nd4j.linalg.lossfunctions.LossFunctions;


    //import org.deeplearning4j.nn.conf.distribution.UniformDistribution;
    import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
    import org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator;
    //import org.deeplearning4j.eval.Evaluation;
    //import org.deeplearning4j.nn.api.Model;
    //import org.deeplearning4j.nn.api.OptimizationAlgorithm;
    //import org.deeplearning4j.nn.conf.Updater;
    //import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
    //import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
    //import org.deeplearning4j.nn.conf.layers.DenseLayer;
    //import org.deeplearning4j.nn.conf.layers.OutputLayer;
    //import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
    //import org.deeplearning4j.nn.weights.WeightInit;
    //import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
    //import org.nd4j.linalg.api.ndarray.INDArray;
    //import org.nd4j.linalg.dataset.DataSet;
    //import org.nd4j.linalg.factory.Nd4j;
    //import org.nd4j.linalg.lossfunctions.LossFunctions;
    //
    ////import org.nd4j.jita.conf.CudaEnvironment;

    import java.util.Arrays;
    import java.util.Random;

    /**
    * Created by osipov on 6/28/16.
    */
    public class FizzBuzz {

    /**
     * Decodes a little-endian binary vector into its integer value.
     * Element j of {@code arr} is treated as the 2^j bit.
     *
     * @param arr vector of 0/1 entries, least-significant bit first
     * @return the decoded integer (0 for an empty vector)
     */
    public static int decodeBinary(INDArray arr) {
        int value = 0;
        int bit = 0;
        while (bit < arr.length()) {
            value += Math.pow(2, bit) * arr.getInt(bit);
            bit++;
        }
        return value;
    }
    /**
     * Decodes a little-endian bit array into its integer value.
     * Element j holds the 2^j bit and is expected to be 0.0f or 1.0f.
     *
     * @param b bit array, least-significant bit first; may be empty
     * @return the decoded non-negative integer (0 for an empty array)
     */
    public static int decodeBinary(float[] b) {
        int i = 0;
        for (int j = 0; j < b.length; j++) {
            // Integer shift instead of the original Math.pow(2, j): the weight is
            // an exact power of two, so there is no need for double arithmetic.
            // The compound assignment still truncates toward zero, matching the
            // original behavior for non-integral inputs.
            i += (1 << j) * b[j];
        }
        return i;
    }

    // public static float[] encodeBinary(int val, int numDigits) {
    // float[] result = new float[numDigits];
    // for (int i = 0; i < numDigits; i++) {
    // result[i] = (val >> i) & 1;
    // }
    // return result;
    // }

    /**
     * Encodes {@code val} as a little-endian binary vector of {@code numDigits}
     * entries: slot d receives bit d of the value, i.e. (val >> d) & 1.
     *
     * @param val       the integer to encode (only its low numDigits bits are kept)
     * @param numDigits length of the resulting vector
     * @return a new vector of 0/1 entries, least-significant bit first
     */
    public static INDArray encodeBinary(int val, int numDigits) {
        INDArray bits = Nd4j.zeros(numDigits);
        int d = 0;
        while (d < numDigits) {
            bits.putScalar(d, (val >> d) & 1);
            d++;
        }
        return bits;
    }

    // public static float[] encodeFizzBuzz(int i) {
    // if (i % 15 == 0) return new float[]{0.0f, 0.0f, 0.0f, 1.0f};
    // else
    // if (i % 5 == 0) return new float[]{0.0f, 0.0f, 1.0f, 0.0f};
    // else
    // if (i % 3 == 0) return new float[]{0.0f, 1.0f, 0.0f, 0.0f};
    //
    // else return new float[]{1.0f, 0.0f, 0.0f, 0.0f};
    // }

    /**
     * One-hot encodes the fizz-buzz class of {@code i} into a length-4 vector:
     * index 0 = plain number, 1 = divisible by 3 (fizz), 2 = divisible by 5
     * (buzz), 3 = divisible by 15 (fizzbuzz).
     *
     * @param i the number to classify
     * @return a new length-4 vector with a single 1 in the class slot
     */
    public static INDArray encodeFizzBuzz(int i) {
        int slot;
        if (i % 15 == 0) {
            slot = 3;
        } else if (i % 5 == 0) {
            slot = 2;
        } else if (i % 3 == 0) {
            slot = 1;
        } else {
            slot = 0;
        }
        // putScalar returns the array it mutated, same as the original's
        // "return encoded.putScalar(...)" form.
        return Nd4j.zeros(4).putScalar(slot, 1);
    }
    //
    // public static int[] encodeFizzBuzz(int i) {
    // if (i % 15 == 0) return new int[]{0, 0, 0, 1};
    // else
    // if (i % 5 == 0) return new int[]{0, 0, 1, 0};
    // else
    // if (i % 3 == 0) return new int[]{0, 1, 0, 0};
    //
    // else return new int[]{1, 0, 0, 0};

    // }
    /**
     * Trains a feed-forward network to play fizz-buzz. Inputs are NUM_DIGITS-bit
     * little-endian binary encodings of integers; targets are 4-way one-hot
     * fizz-buzz classes. Training rows are sampled (with class rebalancing) from
     * 101..NUM_UPPER-1; evaluation runs on the held-out classic range 1..100.
     */
    public static void main(String[] args) {
    // org.nd4j.jita.conf.CudaEnvironment.getInstance().getConfiguration().allowMultiGPU(true);
    // Ask ND4J to check for numerical problems during operations.
    Nd4j.ENFORCE_NUMERICAL_STABILITY = true;

    // Hyperparameters.
    // NOTE(review): NUM_UPPER = 8192 needs 13 bits but NUM_DIGITS is 10, so
    // encodeBinary keeps only the low 10 bits — training inputs for
    // 1024..8191 alias onto 0..1023. Confirm this truncation is intentional.
    // final int NUM_UPPER = 32768;
    final int NUM_UPPER = 8192;
    final int NUM_DIGITS = 10;
    int rngSeed = 12345;
    int numEpochs = 5000;
    int batchSize = 128;
    double learningRate = 0.3;
    double regularizationRate = learningRate * 0.0005; // only referenced by the commented-out .l2(...) below
    double nesterovsMomentum = 0.9;


    Random rnd = new Random(rngSeed);
    // int numEpochs = 1000;

    // Oversized staging buffers; only the first trainCount rows get filled.
    INDArray trainFeaturesTmp = Nd4j.zeros(NUM_UPPER - 101, NUM_DIGITS);
    INDArray trainLabelsTmp = Nd4j.zeros(NUM_UPPER - 101, 4);

    // Subsample the candidate range to rebalance classes: plain numbers are
    // kept with probability 1/8, fizz 1/4, buzz 1/2, fizzbuzz always.
    int trainCount = 0;
    for (int i = 101; i < NUM_UPPER; i++) {
    INDArray features = encodeBinary(i, NUM_DIGITS);
    INDArray labels = encodeFizzBuzz(i);

    boolean lucky = false;
    if (labels.getInt(0) == 1) lucky = rnd.nextInt(8) == 0;
    else
    if (labels.getInt(1) == 1) lucky = rnd.nextInt(4) == 0;
    else
    if (labels.getInt(2) == 1) lucky = rnd.nextInt(2) == 0;
    else
    if (labels.getInt(3) == 1) lucky = true;

    if (lucky) {
    trainFeaturesTmp.putRow(trainCount, features);
    trainLabelsTmp.putRow(trainCount, labels);
    trainCount++;
    }
    }

    // Count how many sampled training rows landed in each of the 4 classes.
    int[] counts = new int[4];
    for (int i = 0; i < trainCount; i++) {
    if (trainLabelsTmp.getRow(i).getInt(0) == 1) counts[0] += 1;
    else
    if (trainLabelsTmp.getRow(i).getInt(1) == 1) counts[1] += 1;
    else
    if (trainLabelsTmp.getRow(i).getInt(2) == 1) counts[2] += 1;
    else
    if (trainLabelsTmp.getRow(i).getInt(3) == 1) counts[3] += 1;
    }


    System.out.println("Train count: " + Arrays.toString(counts));

    // Copy the filled prefix of the staging buffers into right-sized matrices.
    INDArray trainFeatures = Nd4j.zeros(trainCount, NUM_DIGITS);
    INDArray trainLabels = Nd4j.zeros(trainCount, 4);

    for (int i = 0; i < trainCount; i++) {
    trainFeatures.putRow(i, trainFeaturesTmp.getRow(i));
    trainLabels.putRow(i, trainLabelsTmp.getRow(i));
    }

    // Held-out test set: the classic fizz-buzz range 1..100 (row i-1 holds number i).
    INDArray testFeatures = Nd4j.zeros(100, NUM_DIGITS);
    for (int i = 1; i < 101; i++) testFeatures.putRow(i - 1, encodeBinary(i, NUM_DIGITS));

    INDArray testLabels = Nd4j.zeros(100, 4);
    for (int i = 1; i < 101; i++) testLabels.putRow(i - 1, encodeFizzBuzz(i));

    final DataSet trainDataset = new DataSet(trainFeatures, trainLabels);
    final DataSet testDataset = new DataSet(testFeatures, testLabels);

    // for (int i = 0; i < 100; i++) {
    // System.out.println(testFeatures.getRow(i).toString() + " " + testLabels.getRow(i).toString());
    // }
    // if (true) return;

    trainDataset.shuffle(rngSeed);

    // Mini-batch iterator over the in-memory training set.
    DataSetIterator trainDatasetBatches = new ListDataSetIterator(trainDataset.asList(), batchSize);
    // DataSetIterator testDatasetBatches = new ListDataSetIterator(testDataset.asList(), batchSize);

    System.out.println("Build model....");
    // 10 -> 100 -> 100 -> 4 fully-connected net: relu hidden activations,
    // softmax output trained with multi-class cross-entropy (MCXENT).
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
    .seed(rngSeed)
    .optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT)
    // .biasInit(0)
    .iterations(1)
    .learningRate(learningRate)
    .activation("relu")
    .weightInit(WeightInit.XAVIER)
    .miniBatch(true)
    .useDropConnect(false)
    .updater(Updater.NESTEROVS).momentum(nesterovsMomentum)
    // .regularization(true).l2(regularizationRate)
    .list()
    .layer(0, new DenseLayer.Builder()
    .nIn(10)
    .nOut(100)
    // .weightInit(WeightInit.DISTRIBUTION)
    // .dist(new NormalDistribution(0.0, 0.01))
    // .activation("relu")
    .build())
    .layer(1, new DenseLayer.Builder()
    .nIn(100)
    .nOut(100)
    // .weightInit(WeightInit.DISTRIBUTION)
    // .dist(new NormalDistribution(0.0, 0.01))
    // .activation("relu")
    .build())
    .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
    .nIn(100)
    .nOut(4)
    // .weightInit(WeightInit.DISTRIBUTION)
    // .dist(new UniformDistribution(0.1, 1))
    .activation("softmax")
    .build())
    .pretrain(false).backprop(true)
    .build();

    final MultiLayerNetwork model = new MultiLayerNetwork(conf);
    model.init();

    // add an listener which outputs the error every 100 parameter updates
    // model.setListeners(new ScoreIterationListener(100));


    // Log the training score every 100 iterations.
    model.setListeners(new ScoreIterationListener[]{
    new ScoreIterationListener(100),
    // new ScoreIterationListener(200) {
    // private int myCount = 0;
    //
    // @Override
    // public void iterationDone(Model m, int iter) {
    //// super.iterationDone(m, iter);
    // try {
    // if (myCount % 200 == 0 && myCount > 0) {
    // org.deeplearning4j.nn.multilayer.MultiLayerNetwork mod = (org.deeplearning4j.nn.multilayer.MultiLayerNetwork) m;
    // Evaluation eval = new Evaluation(4);
    // INDArray output = mod.output(testDataset.getFeatures());
    // eval.eval(testDataset.getLabels(), output);
    // System.out.println("Test Iteration " + myCount);
    // System.out.println(eval.stats(true));
    // }
    // myCount++;
    // } catch (Throwable t) {
    // System.out.println("caught throwable " + t);
    // }
    // }
    // }
    });

    System.out.println("Train model....");
    // One full pass over trainDatasetBatches per epoch.
    for( int i=0; i<numEpochs; i++ ) {
    model.fit(trainDatasetBatches);
    // model.fit(trainDataset);
    }

    System.out.println("Evaluate model....");
    // Accuracy on the (rebalanced) training sample itself.
    {
    System.out.println("****************Train eval********************");
    Evaluation eval = new Evaluation(4);
    eval.eval(trainDataset.getLabels(), model.output(trainDataset.getFeatures()));
    System.out.println(eval.stats());
    System.out.println("****************Train eval********************");
    }
    // Accuracy on the held-out 1..100 range.
    {
    System.out.println("****************Test eval********************");
    Evaluation eval = new Evaluation(4);
    eval.eval(testDataset.getLabels(), model.output(testDataset.getFeatures()));
    System.out.println(eval.stats());
    System.out.println("****************Test eval********************");
    }

    System.out.println("****************Example finished********************");


    // Print raw per-class output distributions for test numbers 1..16.
    for (int i = 0; i < 16; i++) {
    System.out.println((i + 1) + " " + testFeatures.getRow(i).toString() + " " + model.output(testFeatures.getRow(i)).toString());
    }


    // System.out.println(model.output(testDataset.getFeatures()));

    // for (int i = 1; i < 101; i++) {
    // INDArray o = model.output(encodeBinary(i, NUM_DIGITS));
    //
    // System.out.println(i + " " + o.toString() + " " + o.maxNumber() + " " + o.eps(o.maxNumber()).toString());
    // }
    }
    }