Skip to content

Instantly share code, notes, and snippets.

@dursunkoc
Forked from thvasilo/IncrementalSGD.java
Created July 24, 2017 20:36
Show Gist options
  • Save dursunkoc/ec97b9f8f83db2d8bb802af1b2ca48b9 to your computer and use it in GitHub Desktop.
Save dursunkoc/ec97b9f8f83db2d8bb802af1b2ca48b9 to your computer and use it in GitHub Desktop.
A basic online SGD using the Flink stream API.
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package se.sics.quickstart;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.common.typeinfo.TypeHint;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.utils.ParameterTool;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.streaming.api.TimeCharacteristic;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.co.CoFlatMapFunction;
import org.apache.flink.streaming.api.functions.co.CoMapFunction;
import org.apache.flink.streaming.api.functions.source.SourceFunction;
import org.apache.flink.streaming.api.functions.windowing.RichAllWindowFunction;
import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
import org.apache.flink.util.Collector;
import java.util.ArrayList;
import java.util.Collections;
/**
* Skeleton for an incremental machine learning job built from two input
* streams: training data, which incrementally updates a linear regression
* model (starting from all-zero weights) with mini-batch stochastic gradient
* descent, and new data, on which the latest model is evaluated.
*
* <p>
* This may serve as a base of a number of algorithms, e.g. updating an
* incremental Alternating Least Squares model while also providing the
* predictions.
*
* <p>
* This example shows how to use:
* <ul>
* <li>Connected streams
* <li>CoFunctions
* <li>Count windows on non-keyed streams with a rich window function holding value state
* </ul>
*/
public class IncrementalLearning {
// *************************************************************************
// PROGRAM
// *************************************************************************
/**
 * Builds and runs the streaming job: trains a linear model on the training file in
 * count-based mini-batches and prints the residuals of the latest model on the test file.
 *
 * <p>Arguments (Flink {@link ParameterTool} format):
 * <ul>
 * <li>{@code --training <path>} required; comma-separated samples, label last
 * <li>{@code --test <path>} required; comma-separated samples to evaluate, label last
 * <li>{@code --batchsize <n>} required; training samples per model update
 * <li>{@code --dimensions <n>} required; number of model weights
 * <li>{@code --learningRate <x>} optional, default {@code 0.001}; base SGD step size
 * </ul>
 *
 * @param args command-line arguments
 * @throws Exception if a required argument is missing or malformed, or the job fails
 */
public static void main(String[] args) throws Exception {
    final ParameterTool params = ParameterTool.fromArgs(args);
    // Read and validate every argument up front so a missing one fails with a clear
    // "required key" message instead of an NPE/NumberFormatException during job setup.
    final String trainingPath = params.getRequired("training");
    final String testPath = params.getRequired("test");
    final int batchSize = params.getInt("batchsize");
    final int dimensions = params.getInt("dimensions");
    final double learningRate = params.getDouble("learningRate", 0.001);

    StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
    env.setStreamTimeCharacteristic(TimeCharacteristic.ProcessingTime);

    // Each input line is a comma-separated vector whose last element is the dependent variable.
    DataStream<ArrayList<Double>> trainingData = env.readTextFile(trainingPath)
            .map(new VectorExtractor());
    DataStream<ArrayList<Double>> newData = env.readTextFile(testPath)
            .map(new VectorExtractor());

    // Emit an updated model after every `batchSize` training samples (a count window, not time).
    DataStream<ArrayList<Double>> model = trainingData
            .countWindowAll(batchSize)
            .apply(new PartialModelBuilder(learningRate, dimensions));

    // Broadcast the model so every parallel Evaluator instance receives every update.
    // Without it, updates from the single window instance are spread across the
    // Evaluator instances, and each instance sees only some (or none) of the models.
    DataStream<Double> errors = newData.connect(model.broadcast()).flatMap(new Evaluator());
    errors.print();

    env.execute("Streaming Incremental Learning");
}
// *************************************************************************
// USER FUNCTIONS
// *************************************************************************
/**
 * Finite test source that emits the Integer {@code 1} about every 5 ms (after an
 * initial 15 ms delay) until its counter, which starts at the constructor value,
 * reaches 50, or until the source is cancelled.
 *
 * <p>It is not wired into the job built by {@code main}.
 */
public static class FiniteNewDataSource implements SourceFunction<Integer> {
    private static final long serialVersionUID = 1L;
    /** Emission stops once {@code counter} reaches this value. */
    private static final int MAX_COUNT = 50;

    private int counter;
    // NOTE(review): never read; kept only so the constructor signature stays unchanged.
    private String filepath;
    /** Cleared by cancel(); volatile because cancel() may run on a different thread than run(). */
    private volatile boolean running = true;

    /**
     * @param counter  initial counter value; elements are emitted while it is below 50
     * @param filepath currently unused
     */
    public FiniteNewDataSource(int counter, String filepath) {
        this.counter = counter;
        this.filepath = filepath;
    }

    @Override
    public void run(SourceContext<Integer> ctx) throws Exception {
        Thread.sleep(15);
        while (running && counter < MAX_COUNT) {
            ctx.collect(getNewData());
        }
    }

    @Override
    public void cancel() {
        // Makes the emission loop in run() exit at its next check.
        running = false;
    }

    /** Sleeps about 5 ms, advances the counter, and returns the constant element {@code 1}. */
    private Integer getNewData() throws InterruptedException {
        Thread.sleep(5);
        counter++;
        return 1;
    }
}
/**
 * Computes a linear-model prediction: the dot product of the model weights with the
 * first {@code model.size()} entries of {@code example}.
 *
 * <p>Entries of {@code example} beyond {@code model.size()} (such as a trailing label)
 * are ignored, and no intercept term is added.
 *
 * @param model   weight vector
 * @param example sample vector; must have at least {@code model.size()} entries
 * @return the weighted sum of the leading features
 * @throws IndexOutOfBoundsException if {@code example} is shorter than {@code model}
 */
private static Double predict(ArrayList<Double> model, ArrayList<Double> example) {
    // Accumulate in a primitive so no new Double is boxed on every iteration.
    double prediction = 0.0;
    for (int i = 0; i < model.size(); i++) {
        prediction += model.get(i) * example.get(i);
    }
    return prediction;
}
/**
 * Builds up-to-date partial models on new training data.
 *
 * <p>Each fired count window is treated as one mini-batch. The function averages the
 * squared-error gradient of a linear model (no intercept) over the batch and takes one
 * gradient-descent step with step size {@code learningRate / sqrt(t)}, where {@code t}
 * is the number of windows this instance has processed. The updated weights are stored
 * in Flink value state (initially all zeros) and emitted downstream.
 *
 * <p>NOTE(review): {@code t} lives in a plain field, not in state, so it restarts from
 * zero after a failure/restore while the weights may be restored.
 */
public static class PartialModelBuilder extends RichAllWindowFunction<ArrayList<Double>, ArrayList<Double>, GlobalWindow> {
    private static final long serialVersionUID = 1L;

    /** Base step size before the 1/sqrt(t) decay. */
    private final double learningRate;
    /** Number of model weights. */
    private final int dimensions;
    /** Number of windows processed by this instance. */
    private int applyCount = 0;
    private transient ValueState<ArrayList<Double>> modelState;

    /**
     * @param learningRate base SGD step size; must be non-null
     * @param dimensions   number of model weights (presumably the feature count per
     *                     sample, excluding the trailing label)
     */
    public PartialModelBuilder(Double learningRate, int dimensions) {
        this.learningRate = learningRate;
        this.dimensions = dimensions;
    }

    @Override
    public void open(Configuration config) {
        // TODO: Do random assignment of weights instead of all zeros?
        ArrayList<Double> allZeroes = new ArrayList<>(Collections.nCopies(dimensions, 0.0));
        ValueStateDescriptor<ArrayList<Double>> descriptor =
                new ValueStateDescriptor<>(
                        // state name
                        "modelState",
                        // type information of state
                        TypeInformation.of(new TypeHint<ArrayList<Double>>() {}),
                        // default value of state
                        allZeroes);
        modelState = getRuntimeContext().getState(descriptor);
    }

    /**
     * Applies one mini-batch gradient step to the current model.
     *
     * @param trainingBatch samples whose leading entries are features and whose last
     *                      entry is the label
     * @return a new list holding the updated weights; the list held in state is not mutated
     * @throws Exception if reading the state fails
     */
    private ArrayList<Double> buildPartialModel(Iterable<ArrayList<Double>> trainingBatch) throws Exception {
        // Work on a copy so the list held in state (and any list already emitted
        // downstream) is never modified in place.
        ArrayList<Double> regressionModel = new ArrayList<>(modelState.value());
        double[] gradientSum = new double[regressionModel.size()];
        int batchSize = 0;
        for (ArrayList<Double> sample : trainingBatch) {
            batchSize++;
            double truth = sample.get(sample.size() - 1);
            // Gradient of 0.5 * (truth - prediction)^2 w.r.t. weight i is (prediction - truth) * x_i.
            double derivative = predict(regressionModel, sample) - truth;
            for (int i = 0; i < gradientSum.length; i++) {
                gradientSum[i] += derivative * sample.get(i);
            }
        }
        if (batchSize == 0) {
            // Nothing to learn from; dividing by zero below would turn every weight into NaN.
            return regressionModel;
        }
        // Same decayed step size for every weight, so compute it once.
        double currentLR = learningRate / Math.sqrt(applyCount);
        for (int i = 0; i < gradientSum.length; i++) {
            regressionModel.set(i, regressionModel.get(i) - currentLR * (gradientSum[i] / batchSize));
        }
        return regressionModel;
    }

    @Override
    public void apply(GlobalWindow window, Iterable<ArrayList<Double>> values, Collector<ArrayList<Double>> out) throws Exception {
        // Incremented before the update so the first step uses sqrt(1) rather than sqrt(0).
        applyCount++;
        ArrayList<Double> updatedModel = buildPartialModel(values);
        modelState.update(updatedModel);
        out.collect(updatedModel);
    }
}
/**
 * Scores incoming examples against the most recent model received on the second input.
 *
 * <p>Input 1 carries examples (features followed by the label). For each one, the residual
 * {@code label - prediction} is emitted. Examples that arrive before any model has been
 * received are dropped. Input 2 carries model updates, which only replace the current
 * model and emit nothing.
 *
 * <p>NOTE(review): the current model lives in a plain field, so it is not checkpointed.
 * After a restart it is empty until the next model update arrives.
 */
public static class Evaluator implements CoFlatMapFunction<ArrayList<Double>, ArrayList<Double>, Double> {
    /** Latest model from input 2, or {@code null} before the first update. */
    ArrayList<Double> partialModel = null;

    @Override
    public void flatMap1(ArrayList<Double> example, Collector<Double> out) throws Exception {
        if (partialModel == null) {
            // No model yet, so there is nothing to evaluate against.
            return;
        }
        Double prediction = predict(partialModel, example);
        Double error = example.get(example.size() - 1) - prediction;
        out.collect(error);
    }

    @Override
    public void flatMap2(ArrayList<Double> curModel, Collector<Double> out) throws Exception {
        // Only swap the model. Emitting here would mix model weights into the residual stream.
        partialModel = curModel;
    }
}
/**
 * Co-map that turns each example on input 1 into a prediction using the most recent
 * model seen on input 2.
 *
 * <p>Before the first model arrives, examples map to {@link Double#NaN}. Every model
 * update on input 2 is stored and mapped to the sentinel {@code -1.0}, which is
 * indistinguishable from a genuine prediction of {@code -1.0}. Each prediction made
 * with a model also prints "Partial model ready!" to stdout.
 */
public static class Predictor implements CoMapFunction<ArrayList<Double>, ArrayList<Double>, Double> {
    /** Most recent model received on input 2; {@code null} until the first update. */
    ArrayList<Double> partialModel = null;

    @Override
    public Double map1(ArrayList<Double> example) throws Exception {
        final ArrayList<Double> current = partialModel;
        if (current == null) {
            return Double.NaN;
        }
        System.out.println("Partial model ready!");
        return predict(current, example);
    }

    @Override
    public Double map2(ArrayList<Double> curModel) throws Exception {
        // A CoMapFunction must return a value for every input; -1.0 marks "model updated".
        partialModel = curModel;
        final double modelUpdatedMarker = -1.0;
        return modelUpdatedMarker;
    }
}
/**
 * Parses one comma-separated text line into a vector of doubles.
 *
 * <p>Each field is parsed with {@link Double#valueOf(String)}, which ignores leading and
 * trailing whitespace. In this job the last field of a line is the label.
 */
public static class VectorExtractor implements MapFunction<String, ArrayList<Double>> {
    /**
     * @param s one input line, e.g. {@code "1.0,2.5,3.0"}
     * @return the parsed fields, in order
     * @throws NumberFormatException if any field (including one from a blank line) is not
     *                               a valid double; the message includes the offending line
     */
    @Override
    public ArrayList<Double> map(String s) throws Exception {
        String[] elements = s.split(",");
        ArrayList<Double> doubleElements = new ArrayList<>(elements.length);
        for (String element : elements) {
            try {
                // Double.valueOf replaces the deprecated Double(String) constructor.
                doubleElements.add(Double.valueOf(element));
            } catch (NumberFormatException e) {
                // Keep the exception type, but add the line so bad input is easy to find.
                NumberFormatException withContext =
                        new NumberFormatException("Cannot parse '" + element + "' in line: \"" + s + "\"");
                withContext.initCause(e);
                throw withContext;
            }
        }
        return doubleElements;
    }
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment