Skip to content

Instantly share code, notes, and snippets.

@dursunkoc
Forked from thvasilo/IncrementalSGD.java
Created July 24, 2017 20:36
Show Gist options
  • Save dursunkoc/ec97b9f8f83db2d8bb802af1b2ca48b9 to your computer and use it in GitHub Desktop.
Save dursunkoc/ec97b9f8f83db2d8bb802af1b2ca48b9 to your computer and use it in GitHub Desktop.

Revisions

  1. Theodore Vasiloudis created this gist Nov 21, 2016.
    283 changes: 283 additions & 0 deletions IncrementalSGD.java
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,283 @@
    /*
    * Licensed to the Apache Software Foundation (ASF) under one or more
    * contributor license agreements. See the NOTICE file distributed with
    * this work for additional information regarding copyright ownership.
    * The ASF licenses this file to You under the Apache License, Version 2.0
    * (the "License"); you may not use this file except in compliance with
    * the License. You may obtain a copy of the License at
    *
    * http://www.apache.org/licenses/LICENSE-2.0
    *
    * Unless required by applicable law or agreed to in writing, software
    * distributed under the License is distributed on an "AS IS" BASIS,
    * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    * See the License for the specific language governing permissions and
    * limitations under the License.
    */

    package se.sics.quickstart;

    import org.apache.flink.api.common.functions.MapFunction;
    import org.apache.flink.api.common.state.ValueState;
    import org.apache.flink.api.common.state.ValueStateDescriptor;
    import org.apache.flink.api.common.typeinfo.TypeHint;
    import org.apache.flink.api.common.typeinfo.TypeInformation;
    import org.apache.flink.api.java.utils.ParameterTool;
    import org.apache.flink.configuration.Configuration;
    import org.apache.flink.streaming.api.TimeCharacteristic;
    import org.apache.flink.streaming.api.datastream.DataStream;
    import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
    import org.apache.flink.streaming.api.functions.co.CoFlatMapFunction;
    import org.apache.flink.streaming.api.functions.co.CoMapFunction;
    import org.apache.flink.streaming.api.functions.source.SourceFunction;
    import org.apache.flink.streaming.api.functions.windowing.RichAllWindowFunction;
    import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
    import org.apache.flink.util.Collector;

    import java.util.ArrayList;
    import java.util.Collections;

    /**
    * Skeleton for an incremental machine learning algorithm. It consists of a
    * pre-computed model, which is updated as new training inputs arrive, and of
    * new input data for which the job provides predictions.
    *
    * <p>
    * This may serve as a base of a number of algorithms, e.g. updating an
    * incremental Alternating Least Squares model while also providing the
    * predictions.
    *
    * <p>
    * This example shows how to use:
    * <ul>
    * <li>Connected streams
    * <li>CoFunctions
    * <li>Tuple data types
    * </ul>
    */
    public class IncrementalLearning {


    // *************************************************************************
    // PROGRAM
    // *************************************************************************

    /**
     * Builds and runs the streaming job.
     *
     * <p>Arguments, parsed with {@link ParameterTool}:
     * <ul>
     * <li>{@code --training <path>}: text file of comma-separated training samples (required)
     * <li>{@code --test <path>}: text file of comma-separated samples to evaluate (required)
     * <li>{@code --batchsize <n>}: number of training samples per model update (required, positive)
     * <li>{@code --dimensions <n>}: number of model weights (required, positive)
     * <li>{@code --learningRate <rate>}: base step size, defaults to 0.001
     * </ul>
     *
     * @param args command-line arguments as described above
     * @throws Exception if an argument is missing or malformed, or if job execution fails
     */
    public static void main(String[] args) throws Exception {
        final ParameterTool params = ParameterTool.fromArgs(args);

        // Resolve every argument before building the topology so that a missing
        // value is reported by its key name rather than surfacing later as a
        // failure on a null string.
        final String trainingPath = params.getRequired("training");
        final String testPath = params.getRequired("test");
        final int batchSize = params.getInt("batchsize");
        final int dimensions = params.getInt("dimensions");
        final double learningRate = params.getDouble("learningRate", 0.001);
        if (batchSize < 1 || dimensions < 1) {
            throw new IllegalArgumentException(
                "batchsize and dimensions must be positive, got batchsize=" + batchSize
                    + ", dimensions=" + dimensions);
        }

        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        env.setStreamTimeCharacteristic(TimeCharacteristic.ProcessingTime);

        // To simplify we make the assumption that the last element in each line
        // is the dependent variable.
        DataStream<ArrayList<Double>> trainingData = env.readTextFile(trainingPath)
            .map(new VectorExtractor());
        DataStream<ArrayList<Double>> newData = env.readTextFile(testPath)
            .map(new VectorExtractor());

        // Emit an updated model after every `batchsize` training samples.
        DataStream<ArrayList<Double>> model = trainingData
            .countWindowAll(batchSize)
            .apply(new PartialModelBuilder(learningRate, dimensions));

        // Evaluate each new sample against the most recently received model.
        DataStream<Double> errors = newData.connect(model).flatMap(new Evaluator());

        errors.print();

        // execute program
        env.execute("Streaming Incremental Learning");
    }

    // *************************************************************************
    // USER FUNCTIONS
    // *************************************************************************

    /**
     * Finite source for newData. After an initial 15 ms pause it emits the
     * Integer 1 once per 5 ms pause, incrementing an internal counter each time,
     * and stops when the counter reaches 50 or the source is cancelled.
     */
    public static class FiniteNewDataSource implements SourceFunction<Integer> {
        private static final long serialVersionUID = 1L;

        /** Counter value at which the source stops emitting. */
        private static final int COUNTER_LIMIT = 50;

        // Starting value is supplied by the caller, so the number of emitted
        // elements is COUNTER_LIMIT minus that starting value (or zero).
        private int counter;

        // Stored but not read anywhere in this class.
        private String filepath;

        // Cleared by cancel(); volatile because the SourceFunction contract lets
        // cancel() be invoked from a different thread than the one inside run().
        private volatile boolean running = true;

        /**
         * @param counter initial value of the emission counter
         * @param filepath path kept on the instance; currently unused
         */
        public FiniteNewDataSource(int counter, String filepath) {
            this.counter = counter;
            this.filepath = filepath;
        }

        /**
         * Emits elements until the counter limit is reached or {@link #cancel()} is called.
         *
         * @param ctx context used to emit elements
         * @throws Exception if the thread is interrupted while sleeping
         */
        @Override
        public void run(SourceContext<Integer> ctx) throws Exception {
            Thread.sleep(15);
            while (running && counter < COUNTER_LIMIT) {
                ctx.collect(getNewData());
            }
        }

        /** Requests that {@link #run} leave its emit loop before the next element. */
        @Override
        public void cancel() {
            running = false;
        }

        /** Sleeps 5 ms, advances the counter and returns the constant element 1. */
        private Integer getNewData() throws InterruptedException {
            Thread.sleep(5);
            counter++;
            return 1;
        }
    }

    /**
     * Computes the linear prediction: the dot product of the model weights with
     * the leading entries of an example.
     *
     * <p>Only the first {@code model.size()} entries of {@code example} are read;
     * any further entries are ignored. Callers in this file read the target value
     * from the last entry of the example.
     *
     * @param model weight vector
     * @param example sample holding at least {@code model.size()} entries
     * @return sum of {@code model.get(i) * example.get(i)} over all model indices
     * @throws IllegalArgumentException if the example has fewer entries than the model
     */
    private static Double predict(ArrayList<Double> model, ArrayList<Double> example) {
        final int weightCount = model.size();
        if (example.size() < weightCount) {
            // Fail with the two sizes spelled out instead of a bare index error.
            throw new IllegalArgumentException(
                "Example has " + example.size() + " entries but the model has "
                    + weightCount + " weights");
        }
        // Accumulate in a primitive so no Double is re-boxed on every iteration.
        double prediction = 0.0;
        for (int i = 0; i < weightCount; i++) {
            prediction += model.get(i) * example.get(i);
        }
        return prediction;
    }

    /**
     * Builds up-to-date partial models on new training data.
     *
     * <p>For every fired window the current weights are read from Flink
     * {@link ValueState} (all zeroes by default), moved one averaged gradient step
     * for the squared-error loss {@code 0.5 * (truth - prediction)^2}, written back
     * to the state and emitted. The step size is
     * {@code learningRate / sqrt(applyCount)}, where {@code applyCount} is the
     * number of windows processed so far by this instance.
     */
    public static class PartialModelBuilder extends RichAllWindowFunction<ArrayList<Double>, ArrayList<Double>, GlobalWindow> {

        private static final long serialVersionUID = 1L;

        private Double learningRate;
        private int dimensions;

        // NOTE(review): kept in a plain field rather than in Flink state, so it
        // presumably restarts from zero after a restore while the weights do not.
        // Confirm whether that is intended.
        private int applyCount = 0;

        private transient ValueState<ArrayList<Double>> modelState;

        /**
         * @param learningRate base step size, divided by {@code sqrt(applyCount)} on each update
         * @param dimensions number of weights in the initial all-zero model
         */
        public PartialModelBuilder(Double learningRate, int dimensions) {
            this.learningRate = learningRate;
            this.dimensions = dimensions;
        }

        /** Registers the model state with an all-zero default of {@code dimensions} weights. */
        @Override
        public void open(Configuration config) {
            ArrayList<Double> allZeroes = new ArrayList<>(Collections.nCopies(dimensions, 0.0));
            // obtain key-value state for prediction model
            // TODO: Do random assignment of weights instead of all zeros?
            ValueStateDescriptor<ArrayList<Double>> descriptor =
                new ValueStateDescriptor<>(
                    // state name
                    "modelState",
                    // type information of state
                    TypeInformation.of(new TypeHint<ArrayList<Double>>() {}),
                    // default value of state
                    allZeroes);
            modelState = getRuntimeContext().getState(descriptor);
        }

        /**
         * Applies one averaged gradient step over the batch to the stored model.
         *
         * @param trainingBatch samples whose last entry is the target value
         * @return the list obtained from the state, with its weights updated in place;
         *     returned unchanged when the batch is empty
         * @throws Exception if the state cannot be read
         */
        private ArrayList<Double> buildPartialModel(Iterable<ArrayList<Double>> trainingBatch) throws Exception {
            ArrayList<Double> regressionModel = modelState.value();
            final int weightCount = regressionModel.size();
            final double[] gradientSum = new double[weightCount];

            int batchSize = 0;
            for (ArrayList<Double> sample : trainingBatch) {
                batchSize++;
                double truth = sample.get(sample.size() - 1);
                // Derivative of 0.5 * (truth - prediction)^2 with respect to the prediction.
                double derivative = predict(regressionModel, sample) - truth;
                for (int i = 0; i < weightCount; i++) {
                    gradientSum[i] += derivative * sample.get(i);
                }
            }

            if (batchSize == 0) {
                // Averaging over zero samples would turn every weight into NaN.
                return regressionModel;
            }

            // The decayed step size is identical for all weights, so compute it once.
            final double currentLR = learningRate / Math.sqrt(applyCount);
            for (int i = 0; i < weightCount; i++) {
                double change = currentLR * (gradientSum[i] / batchSize);
                regressionModel.set(i, regressionModel.get(i) - change);
            }
            return regressionModel;
        }

        /**
         * Updates the model with the window's contents, stores it and emits it.
         *
         * @param window the window being evaluated (not inspected)
         * @param values training samples collected in the window
         * @param out collector receiving the updated model
         * @throws Exception if the state cannot be read or written
         */
        @Override
        public void apply(GlobalWindow window, Iterable<ArrayList<Double>> values, Collector<ArrayList<Double>> out) throws Exception {
            // Incremented before the update so the first step divides by sqrt(1).
            this.applyCount++;

            ArrayList<Double> updatedModel = buildPartialModel(values);
            modelState.update(updatedModel);
            out.collect(updatedModel);
        }
    }

    /**
     * Two-input function pairing examples (first input) with models (second input).
     *
     * <p>The most recently received model is remembered. For each example, once a
     * model is available, the difference between the example's last entry and the
     * model's prediction is emitted.
     */
    public static class Evaluator implements CoFlatMapFunction<ArrayList<Double>, ArrayList<Double>, Double> {

        ArrayList<Double> partialModel = null;

        /**
         * Emits {@code target - prediction} for the example, or nothing while no
         * model has been received yet. A status line is printed either way.
         */
        @Override
        public void flatMap1(ArrayList<Double> example, Collector<Double> out) throws Exception {
            if (partialModel == null) {
                System.out.println("Model was null!");
                return;
            }

            System.out.println("Model was not null!");
            final Double predicted = predict(partialModel, example);
            final Double target = example.get(example.size() - 1);
            out.collect(target - predicted);
        }

        /**
         * Replaces the remembered model and emits its first weight.
         *
         * <p>NOTE(review): the first weight goes to the same output as the errors
         * produced by {@link #flatMap1}; this looks like a debugging aid. Confirm
         * that downstream consumers expect it.
         */
        @Override
        public void flatMap2(ArrayList<Double> curModel, Collector<Double> out) throws Exception {
            this.partialModel = curModel;
            out.collect(curModel.get(0));
        }
    }

    /**
     * Two-input map pairing examples (first input) with models (second input),
     * producing exactly one Double per incoming element.
     */
    public static class Predictor implements CoMapFunction<ArrayList<Double>, ArrayList<Double>, Double> {
        ArrayList<Double> partialModel = null;

        /**
         * Returns the prediction of the remembered model for the example, or
         * {@code NaN} while no model has been received yet.
         */
        @Override
        public Double map1(ArrayList<Double> example) throws Exception {
            if (partialModel == null) {
                return Double.NaN;
            }
            System.out.println("Partial model ready!");
            return predict(partialModel, example);
        }

        /** Replaces the remembered model and returns the constant -1.0. */
        @Override
        public Double map2(ArrayList<Double> curModel) throws Exception {
            this.partialModel = curModel;
            return -1.0;
        }
    }


    /**
     * Parses one comma-separated line of text into a list of doubles, keeping the
     * order of the fields.
     */
    public static class VectorExtractor implements MapFunction<String, ArrayList<Double>> {
        private static final long serialVersionUID = 1L;

        /**
         * @param s line of comma-separated numeric fields
         * @return the parsed values, one per field
         * @throws NumberFormatException if a field is not a parsable double
         */
        @Override
        public ArrayList<Double> map(String s) throws Exception {
            String[] elements = s.split(",");
            ArrayList<Double> doubleElements = new ArrayList<>(elements.length);
            for (String element : elements) {
                // Double.valueOf replaces the deprecated Double(String) constructor.
                doubleElements.add(Double.valueOf(element));
            }
            return doubleElements;
        }
    }

    }