package org.deeplearning4j.examples.feedforward.xor; import org.deeplearning4j.eval.Evaluation; import org.deeplearning4j.nn.api.Model; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.Updater; import org.deeplearning4j.nn.conf.distribution.NormalDistribution; import org.deeplearning4j.nn.conf.distribution.UniformDistribution; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.optimize.listeners.ScoreIterationListener; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.lossfunctions.LossFunctions; //import org.deeplearning4j.nn.conf.distribution.UniformDistribution; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator; //import org.deeplearning4j.eval.Evaluation; //import org.deeplearning4j.nn.api.Model; //import org.deeplearning4j.nn.api.OptimizationAlgorithm; //import org.deeplearning4j.nn.conf.Updater; //import org.deeplearning4j.nn.conf.MultiLayerConfiguration; //import org.deeplearning4j.nn.conf.NeuralNetConfiguration; //import org.deeplearning4j.nn.conf.layers.DenseLayer; //import org.deeplearning4j.nn.conf.layers.OutputLayer; //import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; //import org.deeplearning4j.nn.weights.WeightInit; //import org.deeplearning4j.optimize.listeners.ScoreIterationListener; //import org.nd4j.linalg.api.ndarray.INDArray; //import org.nd4j.linalg.dataset.DataSet; //import org.nd4j.linalg.factory.Nd4j; //import org.nd4j.linalg.lossfunctions.LossFunctions; // ////import org.nd4j.jita.conf.CudaEnvironment; import 
java.util.Arrays;
import java.util.Random;

/**
 * Created by osipov on 6/28/16.
 *
 * "FizzBuzz as a neural network": trains a small DL4J feed-forward classifier
 * that maps the little-endian binary encoding of an integer to one of four
 * one-hot classes (index 0 = plain number, 1 = fizz i.e. %3, 2 = buzz i.e. %5,
 * 3 = fizzbuzz i.e. %15). Tests on the classic range 1..100; trains on a
 * class-rebalanced subsample of 101..NUM_UPPER-1.
 *
 * NOTE(review): the class lives in the ...feedforward.xor package but is a
 * FizzBuzz example, not an XOR one — consider moving/renaming the package.
 */
public class FizzBuzz {

    /**
     * Decodes a little-endian bit vector back into an int: element j
     * contributes 2^j when it is 1.
     */
    public static int decodeBinary(INDArray arr) {
        int i = 0;
        for (int j = 0; j < arr.length(); j++) {
            i += Math.pow(2, j) * arr.getInt(j);
        }
        return i;
    }

    /**
     * Same little-endian decode as above, for a plain float[] whose entries
     * are expected to be 0.0f/1.0f.
     */
    public static int decodeBinary(float[] b) {
        int i = 0;
        for (int j = 0; j < b.length; j++) {
            i += Math.pow(2, j)*b[j];
        }
        return i;
    }

//    Earlier float[]-returning variant, kept for reference:
//    public static float[] encodeBinary(int val, int numDigits) {
//        float[] result = new float[numDigits];
//        for (int i = 0; i < numDigits; i++) {
//            result[i] = (val >> i) & 1;
//        }
//        return result;
//    }

    /**
     * Encodes val as a little-endian bit row vector of length numDigits
     * (bit i of val goes to element i). Bits at positions >= numDigits are
     * silently dropped.
     */
    public static INDArray encodeBinary(int val, int numDigits) {
        INDArray encoded = Nd4j.zeros(numDigits);
        for (int i = 0; i < numDigits; i++)
            encoded.putScalar(i, (val >> i) & 1);
        return encoded;
    }

//    Earlier float[]-returning variant, kept for reference:
//    public static float[] encodeFizzBuzz(int i) {
//        if (i % 15 == 0) return new float[]{0.0f, 0.0f, 0.0f, 1.0f};
//        else
//        if (i % 5 == 0) return new float[]{0.0f, 0.0f, 1.0f, 0.0f};
//        else
//        if (i % 3 == 0) return new float[]{0.0f, 1.0f, 0.0f, 0.0f};
//
//        else return new float[]{1.0f, 0.0f, 0.0f, 0.0f};
//    }

    /**
     * One-hot label for i: index 3 = divisible by 15 (fizzbuzz), index 2 =
     * divisible by 5 (buzz), index 1 = divisible by 3 (fizz), index 0 =
     * none of these. putScalar returns the mutated array, hence the direct
     * returns.
     */
    public static INDArray encodeFizzBuzz(int i) {
        INDArray encoded = Nd4j.zeros(4);
        if (i % 15 == 0) return encoded.putScalar(3, 1);
        else if (i % 5 == 0) return encoded.putScalar(2, 1);
        else if (i % 3 == 0) return encoded.putScalar(1, 1);
        else return encoded.putScalar(0, 1);
    }

//
//    int[]-returning variant, kept for reference:
//    public static int[] encodeFizzBuzz(int i) {
//        if (i % 15 == 0) return new int[]{0, 0, 0, 1};
//        else
//        if (i % 5 == 0) return new int[]{0, 0, 1, 0};
//        else
//        if (i % 3 == 0) return new int[]{0, 1, 0, 0};
//
//        else return new int[]{1, 0, 0, 0};
//    }

    /**
     * Builds the train/test sets, configures a 10-100-100-4 feed-forward
     * network, and trains it. The test set is 1..100; the training set is a
     * per-class subsample of 101..NUM_UPPER-1.
     */
    public static void main(String[] args) {
//        org.nd4j.jita.conf.CudaEnvironment.getInstance().getConfiguration().allowMultiGPU(true);
        Nd4j.ENFORCE_NUMERICAL_STABILITY = true;
//        final int NUM_UPPER = 32768;
        final int NUM_UPPER = 8192;
        // NOTE(review): 10 bits encode only 0..1023, but training uses i up to
        // NUM_UPPER-1 = 8191, so encodeBinary drops the high bits and distinct
        // integers alias onto the same feature vector — confirm this is intended.
        final int NUM_DIGITS = 10;
        int rngSeed = 12345;
        int numEpochs = 5000;  // presumably consumed by the (truncated) training loop at the bottom
        int batchSize = 128;
        double learningRate = 0.3;
        double regularizationRate = learningRate * 0.0005;  // only referenced by the commented-out .l2(...) below
        double nesterovsMomentum = 0.9;
        Random rnd = new Random(rngSeed);
//        int numEpochs = 1000;

        // Oversized candidate buffers; trainCount tracks how many rows are kept.
        INDArray trainFeaturesTmp = Nd4j.zeros(NUM_UPPER - 101, NUM_DIGITS);
        INDArray trainLabelsTmp = Nd4j.zeros(NUM_UPPER - 101, 4);
        int trainCount = 0;
        // Subsample to rebalance the classes: keep plain numbers with p=1/8,
        // fizz with p=1/4, buzz with p=1/2, fizzbuzz always.
        for (int i = 101; i < NUM_UPPER; i++) {
            INDArray features = encodeBinary(i, NUM_DIGITS);
            INDArray labels = encodeFizzBuzz(i);
            boolean lucky = false;
            if (labels.getInt(0) == 1) lucky = rnd.nextInt(8) == 0;
            else if (labels.getInt(1) == 1) lucky = rnd.nextInt(4) == 0;
            else if (labels.getInt(2) == 1) lucky = rnd.nextInt(2) == 0;
            else if (labels.getInt(3) == 1) lucky = true;
            if (lucky) {
                trainFeaturesTmp.putRow(trainCount, features);
                trainLabelsTmp.putRow(trainCount, labels);
                trainCount++;
            }
        }

        // Per-class counts of the rows actually kept, for a sanity printout.
        int[] counts = new int[4];
        for (int i = 0; i < trainCount; i++) {
            if (trainLabelsTmp.getRow(i).getInt(0) == 1) counts[0] += 1;
            else if (trainLabelsTmp.getRow(i).getInt(1) == 1) counts[1] += 1;
            else if (trainLabelsTmp.getRow(i).getInt(2) == 1) counts[2] += 1;
            else if (trainLabelsTmp.getRow(i).getInt(3) == 1) counts[3] += 1;
        }
        System.out.println("Train count: " + Arrays.toString(counts));

        // Compact the kept rows into exactly-sized matrices.
        INDArray trainFeatures = Nd4j.zeros(trainCount, NUM_DIGITS);
        INDArray trainLabels = Nd4j.zeros(trainCount, 4);
        for (int i = 0; i < trainCount; i++) {
            trainFeatures.putRow(i, trainFeaturesTmp.getRow(i));
            trainLabels.putRow(i, trainLabelsTmp.getRow(i));
        }

        // Test set: numbers 1..100 (row r holds the encoding/label of r+1).
        INDArray testFeatures = Nd4j.zeros(100, NUM_DIGITS);
        for (int i = 1; i < 101; i++) testFeatures.putRow(i - 1, encodeBinary(i, NUM_DIGITS));
        INDArray testLabels = Nd4j.zeros(100, 4);
        for (int i = 1; i < 101; i++) testLabels.putRow(i - 1, encodeFizzBuzz(i));
        final DataSet trainDataset = new DataSet(trainFeatures, trainLabels);
        final DataSet testDataset = new DataSet(testFeatures, testLabels);
//        for (int i = 0; i < 100; i++) {
//            System.out.println(testFeatures.getRow(i).toString() + " " + testLabels.getRow(i).toString());
//        }
//        if (true) return;

        trainDataset.shuffle(rngSeed);
        DataSetIterator trainDatasetBatches = new ListDataSetIterator(trainDataset.asList(), batchSize);
//        DataSetIterator testDatasetBatches = new ListDataSetIterator(testDataset.asList(), batchSize);

        System.out.println("Build model....");
        // 10 -> 100 -> 100 -> 4 feed-forward net. The builder-level "relu"
        // applies to the dense layers; the output layer overrides it with
        // softmax + multi-class cross-entropy (MCXENT).
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(rngSeed)
            .optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT)
//            .biasInit(0)
            .iterations(1)
            .learningRate(learningRate)
            .activation("relu")
            .weightInit(WeightInit.XAVIER)
            .miniBatch(true)
            .useDropConnect(false)
            .updater(Updater.NESTEROVS).momentum(nesterovsMomentum)
//            .regularization(true).l2(regularizationRate)
            .list()
            .layer(0, new DenseLayer.Builder()
                .nIn(10)
                .nOut(100)
//                .weightInit(WeightInit.DISTRIBUTION)
//                .dist(new NormalDistribution(0.0, 0.01))
//                .activation("relu")
                .build())
            .layer(1, new DenseLayer.Builder()
                .nIn(100)
                .nOut(100)
//                .weightInit(WeightInit.DISTRIBUTION)
//                .dist(new NormalDistribution(0.0, 0.01))
//                .activation("relu")
                .build())
            .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                .nIn(100)
                .nOut(4)
//                .weightInit(WeightInit.DISTRIBUTION)
//                .dist(new UniformDistribution(0.1, 1))
                .activation("softmax")
                .build())
            .pretrain(false).backprop(true)
            .build();

        final MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();

        // Add a listener which outputs the score every 100 parameter updates.
//        model.setListeners(new ScoreIterationListener(100));
        model.setListeners(new ScoreIterationListener[]{
            new ScoreIterationListener(100),
//            new ScoreIterationListener(200) {
//                private int myCount = 0;
//
//                @Override
//                public void iterationDone(Model m, int iter) {
////                    super.iterationDone(m, iter);
//                    try {
//                        if (myCount % 200 == 0 && myCount > 0) {
//                            org.deeplearning4j.nn.multilayer.MultiLayerNetwork mod = (org.deeplearning4j.nn.multilayer.MultiLayerNetwork) m;
//                            Evaluation eval = new Evaluation(4);
//                            INDArray output = mod.output(testDataset.getFeatures());
//                            eval.eval(testDataset.getLabels(), output);
//                            System.out.println("Test Iteration " + myCount);
//                            System.out.println(eval.stats(true));
//                        }
//                        myCount++;
//                    } catch (Throwable t) {
//                        System.out.println("caught throwable " + t);
//                    }
//                }
//            }
        });

        System.out.println("Train model....");
        // NOTE(review): the visible source is truncated mid-statement here —
        // presumably an epoch loop over trainDatasetBatches; confirm against
        // the full file before editing below this point.
        for( int i=0; i