Skip to content

Instantly share code, notes, and snippets.

@osipov
Created October 12, 2016 19:08
Show Gist options
  • Select an option

  • Save osipov/11bcc59c14b1a140d4f67ca865d56648 to your computer and use it in GitHub Desktop.

Select an option

Save osipov/11bcc59c14b1a140d4f67ca865d56648 to your computer and use it in GitHub Desktop.

Revisions

  1. osipov created this gist Oct 12, 2016.
    195 changes: 195 additions & 0 deletions fizzbuzz.dl4j.java
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,195 @@
    package org.deeplearning4j.examples.feedforward.xor;

    import org.deeplearning4j.eval.Evaluation;
    import org.deeplearning4j.nn.api.OptimizationAlgorithm;
    import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
    import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
    import org.deeplearning4j.nn.conf.Updater;
    import org.deeplearning4j.nn.conf.layers.DenseLayer;
    import org.deeplearning4j.nn.conf.layers.OutputLayer;
    import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
    import org.deeplearning4j.nn.weights.WeightInit;
    import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.api.ops.impl.indexaccum.IAMax;
    import org.nd4j.linalg.dataset.DataSet;
    import org.nd4j.linalg.factory.Nd4j;
    import org.nd4j.linalg.lossfunctions.LossFunctions;

    import java.util.Random;

    /**
     * Trains a small feed-forward network (one hidden ReLU layer, softmax output) to classify
     * integers into the four FizzBuzz classes from their binary representation.
     *
     * <p>The integers 1..100 are held out as the test set; every larger integer that fits into
     * the chosen bit width is used for training.
     *
     * Created by @osipov on 10/12/16.
     */
    public class FizzBuzz {

        /** Number of output classes: plain number, "fizz", "buzz" and "fizzbuzz". */
        private static final int NUM_CLASSES = 4;

        /** The integers 1..TEST_SET_SIZE form the test set; training starts right after them. */
        private static final int TEST_SET_SIZE = 100;

        /**
         * Encode a positive integer in binary using little endian style with a numDigits width.
         * Bits above {@code numDigits} are silently dropped.
         *
         * @param val positive integer
         * @param numDigits width of the binary number
         * @return INDArray of length {@code numDigits} holding one bit (0 or 1) per element,
         *         least significant bit first
         */
        public static INDArray encodeBinary(int val, int numDigits) {
            INDArray encoded = Nd4j.zeros(numDigits);
            for (int bit = 0; bit < numDigits; bit++) {
                encoded.putScalar(bit, (val >> bit) & 1);
            }
            return encoded;
        }

        /**
         * Decode a binary number into a positive integer.
         *
         * <p>Uses integer shifts rather than floating-point powers; the result is meaningful for
         * arrays of at most 31 elements (the range of a positive {@code int}).
         *
         * @param arr binary number stored little endian style in INDArray
         * @return decoding of the binary number back into an integer
         */
        public static int decodeBinary(INDArray arr) {
            int value = 0;
            for (int bit = 0; bit < arr.length(); bit++) {
                // element `bit` carries weight 2^bit
                value += arr.getInt(bit) << bit;
            }
            return value;
        }

        /**
         * One-hot encode a positive integer to one of the 4 "fizzbuzz" classes as follows
         * <pre>
         * [0.0 1.0 0.0 0.0] if the number is divisible by 3 (but not by 5)
         * [0.0 0.0 1.0 0.0] if the number is divisible by 5 (but not by 3)
         * [0.0 0.0 0.0 1.0] if the number is divisible by 3 and 5
         * [1.0 0.0 0.0 0.0] otherwise
         * </pre>
         *
         * @param i the integer to classify
         * @return INDArray containing the encoding
         */
        public static INDArray encodeFizzBuzz(int i) {
            final int hotIndex;
            // divisibility by 15 must be tested first because it implies divisibility by 3 and 5
            if (i % 15 == 0) {
                hotIndex = 3;
            } else if (i % 5 == 0) {
                hotIndex = 2;
            } else if (i % 3 == 0) {
                hotIndex = 1;
            } else {
                hotIndex = 0;
            }
            return Nd4j.zeros(NUM_CLASSES).putScalar(hotIndex, 1);
        }

        /**
         * Decode a one-hot (or softmax) class vector using the following rules:
         * <ul>
         *   <li>class 1 (divisible by 3): return "fizz"</li>
         *   <li>class 2 (divisible by 5): return "buzz"</li>
         *   <li>class 3 (divisible by 3 and 5): return "fizzbuzz"</li>
         *   <li>class 0 (none of the above): return null</li>
         * </ul>
         * The class is the index reported by the {@code IAMax} op for {@code arr}; any index
         * other than 0, 1 or 2 is treated as "fizzbuzz".
         *
         * @param arr INDArray holding a class vector laid out as in {@link #encodeFizzBuzz(int)}
         * @return String which can be null, "fizz", "buzz", or "fizzbuzz"
         */
        public static String decodeFizzBuzz(INDArray arr) {
            int idx = Nd4j.getExecutioner().execAndReturn(new IAMax(arr)).getFinalResult();
            switch (idx) {
                case 0:
                    return null;
                case 1:
                    return "fizz";
                case 2:
                    return "buzz";
                default:
                    return "fizzbuzz";
            }
        }

        /**
         * Evaluates the model on a data set and prints the statistics framed by a banner line.
         *
         * @param banner line printed before and after the statistics
         * @param model trained network to evaluate
         * @param data data set whose features are fed to the model and whose labels are the truth
         */
        private static void printEvaluation(String banner, MultiLayerNetwork model, DataSet data) {
            System.out.println(banner);
            Evaluation eval = new Evaluation(NUM_CLASSES);
            eval.eval(data.getLabels(), model.output(data.getFeatures()));
            System.out.println(eval.stats());
            System.out.println(banner);
        }

        /**
         * Builds the train/test sets, trains the network, prints evaluation statistics for both
         * sets and finally prints the per-number predictions for 1..100.
         *
         * @param args unused
         */
        public static void main(String[] args) {
            Nd4j.ENFORCE_NUMERICAL_STABILITY = true;

            //seed used for the data set shuffle and the network configuration
            int rngSeed = 12345;

            //width of the binary number used to encode the input
            final int NUM_DIGITS = 10;

            //exclusive upper bound of the numbers representable with NUM_DIGITS bits, i.e. 2^NUM_DIGITS
            final int NUM_UPPER = 1 << NUM_DIGITS;

            int numEpochs = 10000;

            double learningRate = 0.5;
            double regularizationRate = 0.75;

            double nesterovsMomentum = 0.95;

            //populate the train set with the numbers in the range [101, NUM_UPPER - 1],
            //i.e. everything representable that is not part of the test set
            final int firstTrainValue = TEST_SET_SIZE + 1;
            final int numTrainExamples = NUM_UPPER - firstTrainValue;
            INDArray trainFeatures = Nd4j.zeros(numTrainExamples, NUM_DIGITS);
            INDArray trainLabels = Nd4j.zeros(numTrainExamples, NUM_CLASSES);
            for (int row = 0; row < numTrainExamples; row++) {
                int number = row + firstTrainValue;
                trainFeatures.putRow(row, encodeBinary(number, NUM_DIGITS));
                trainLabels.putRow(row, encodeFizzBuzz(number));
            }

            //populate the test set with the numbers in the range [1, 100]
            INDArray testFeatures = Nd4j.zeros(TEST_SET_SIZE, NUM_DIGITS);
            INDArray testLabels = Nd4j.zeros(TEST_SET_SIZE, NUM_CLASSES);
            for (int number = 1; number <= TEST_SET_SIZE; number++) {
                testFeatures.putRow(number - 1, encodeBinary(number, NUM_DIGITS));
                testLabels.putRow(number - 1, encodeFizzBuzz(number));
            }

            final DataSet trainDataset = new DataSet(trainFeatures, trainLabels);
            final DataSet testDataset = new DataSet(testFeatures, testLabels);

            trainDataset.shuffle(rngSeed);

            System.out.println("Build model....");
            MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(rngSeed)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .iterations(1)
                .learningRate(learningRate)
                .activation("relu")
                .weightInit(WeightInit.XAVIER)
                .updater(Updater.NESTEROVS).momentum(nesterovsMomentum)
                .regularization(regularizationRate > 0.0).l2(regularizationRate)
                .list()
                .layer(0, new DenseLayer.Builder()
                    .nIn(NUM_DIGITS)
                    .nOut(100)
                    .build())
                .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                    .nIn(100)
                    .nOut(NUM_CLASSES)
                    .activation("softmax")
                    .build())
                .pretrain(false).backprop(true)
                .build();

            final MultiLayerNetwork model = new MultiLayerNetwork(conf);
            model.init();

            model.setListeners(new ScoreIterationListener(100));

            System.out.println("Train model....");
            for (int epoch = 0; epoch < numEpochs; epoch++) {
                //full-batch training: every epoch fits the whole train set once
                model.fit(trainDataset);
            }

            System.out.println("Evaluate model....");
            printEvaluation("****************Train eval********************", model, trainDataset);
            printEvaluation("****************Test eval********************", model, testDataset);

            System.out.println("****************Example finished********************");

            //print, for every test number: the number, its binary features, the expected label,
            //the network output and the decoded prediction (the number itself when no class applies)
            for (int row = 0; row < TEST_SET_SIZE; row++) {
                int number = row + 1;
                INDArray features = testFeatures.getRow(row);
                //run inference once per row and reuse the result for decoding and printing
                INDArray prediction = model.output(features);
                String decoded = decodeFizzBuzz(prediction);
                System.out.println(number + " " + features.toString() + " " + encodeFizzBuzz(number) + " " + prediction.toString() + " " + (decoded == null ? String.valueOf(number) : decoded));
            }
        }
    }