/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package se.sics.quickstart;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.common.typeinfo.TypeHint;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.utils.ParameterTool;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.streaming.api.TimeCharacteristic;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.co.CoFlatMapFunction;
import org.apache.flink.streaming.api.functions.co.CoMapFunction;
import org.apache.flink.streaming.api.functions.source.SourceFunction;
import org.apache.flink.streaming.api.functions.windowing.RichAllWindowFunction;
import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
import org.apache.flink.util.Collector;
import java.util.ArrayList;
import java.util.Collections;
/**
 * Skeleton for an incremental machine learning job. It consists of a model
 * that is updated incrementally as new training inputs arrive, and a stream of
 * new input data for which the job provides predictions.
 *
 * <p>This may serve as a base for a number of algorithms, e.g. updating an
 * incremental Alternating Least Squares model while also providing
 * predictions.
 *
 * <p>This example shows how to use:
 * <ul>
 * <li>Connected streams
 * <li>CoFunctions
 * <li>Tuple data types
 * </ul>
 */
public class IncrementalLearning {
// *************************************************************************
// PROGRAM
// *************************************************************************
public static void main(String[] args) throws Exception {
// Checking input parameters
final ParameterTool params = ParameterTool.fromArgs(args);
Double learningRate = params.has("learningRate") ? new Double(params.get("learningRate")) : 0.001;
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
env.setStreamTimeCharacteristic(TimeCharacteristic.ProcessingTime);
// env.setParallelism(1);
// To simplify we make the assumption that the last element in each line is the dependent variable
DataStream> trainingData = env.readTextFile(params.get("training"))
.map(new VectorExtractor());
DataStream> newData = env.readTextFile(params.get("test"))
.map(new VectorExtractor());
// build new model on every second of new data
DataStream> model = trainingData
.countWindowAll(Integer.parseInt(params.get("batchsize")))
.apply(new PartialModelBuilder(learningRate, Integer.parseInt(params.get("dimensions"))));
// model.print();
// use partial model for newData
DataStream errors = newData.connect(model).flatMap(new Evaluator());
errors.print();
// emit result
// if (params.has("output")) {
// prediction.writeAsText(params.get("output"));
// } else {
// System.out.println("Printing result to stdout. Use --output to specify output path.");
// prediction.print();
// }
// execute program
env.execute("Streaming Incremental Learning");
}
// *************************************************************************
// USER FUNCTIONS
// *************************************************************************
/**
* Feeds new data for newData. By default it is implemented as constantly
* emitting the Integer 1 in a loop.
*/
public static class FiniteNewDataSource implements SourceFunction {
private static final long serialVersionUID = 1L;
private int counter;
private String filepath;
public FiniteNewDataSource(int counter, String filepath) {
this.counter = counter;
this.filepath = filepath;
}
@Override
public void run(SourceContext ctx) throws Exception {
Thread.sleep(15);
while (counter < 50) {
ctx.collect(getNewData());
}
}
@Override
public void cancel() {
// No cleanup needed
}
private Integer getNewData() throws InterruptedException {
Thread.sleep(5);
counter++;
return 1;
}
}
private static Double predict(ArrayList model, ArrayList example) {
Double prediction = 0.0;
for (int i = 0; i < model.size(); i++) {
prediction += model.get(i) * example.get(i);
}
return prediction;
}
/**
* Builds up-to-date partial models on new training data.
*/
public static class PartialModelBuilder extends RichAllWindowFunction, ArrayList, GlobalWindow> {
public PartialModelBuilder(Double learningRate, int dimensions) {
this.learningRate = learningRate;
this.dimensions = dimensions;
}
private Double learningRate;
private int dimensions;
private int applyCount = 0;
private static final long serialVersionUID = 1L;
private transient ValueState> modelState;
@Override
public void open(Configuration config) {
ArrayList allZeroes = new ArrayList<>(Collections.nCopies(dimensions, 0.0));
// obtain key-value state for prediction model
// TODO: Do random assignment of weights instead of all zeros?
ValueStateDescriptor> descriptor =
new ValueStateDescriptor<>(
// state name
"modelState",
// type information of state
TypeInformation.of(new TypeHint>() {}),
// default value of state
allZeroes);
modelState = getRuntimeContext().getState(descriptor);
}
private Double squaredError(Double truth, Double prediction) {
return 0.5 * (truth - prediction) * (truth - prediction);
}
private ArrayList buildPartialModel(Iterable> trainingBatch) throws Exception{
int batchSize = 0;
ArrayList regressionModel = modelState.value();
ArrayList gradientSum = new ArrayList<>(Collections.nCopies(dimensions, 0.0));
for (ArrayList sample : trainingBatch) {
batchSize++;
Double truth = sample.get(sample.size() - 1);
Double prediction = predict(regressionModel, sample);
Double error = squaredError(truth, prediction);
Double derivative = prediction - truth;
for (int i = 0; i < regressionModel.size(); i++) {
Double weightGradient = derivative * sample.get(i);
Double currentSum = gradientSum.get(i);
gradientSum.set(i, currentSum + weightGradient);
}
}
for (int i = 0; i < regressionModel.size(); i++) {
Double oldWeight = regressionModel.get(i);
Double currentLR = learningRate / Math.sqrt(applyCount);
Double change = currentLR * (gradientSum.get(i) / batchSize);
regressionModel.set(i, oldWeight - change);
}
return regressionModel;
}
@Override
public void apply(GlobalWindow window, Iterable> values, Collector> out) throws Exception {
this.applyCount++;
ArrayList updatedModel = buildPartialModel(values);
modelState.update(updatedModel);
out.collect(updatedModel);
}
}
public static class Evaluator implements CoFlatMapFunction, ArrayList, Double> {
ArrayList partialModel = null;
@Override
public void flatMap1(ArrayList example, Collector out) throws Exception {
// System.out.format("Example: %f\n", example.get(0));
if (partialModel != null) {
System.out.println("Model was not null!");
Double prediction = predict(partialModel, example);
Double error = example.get(example.size() - 1) - prediction;
out.collect(error);
} else {
System.out.println("Model was null!");
}
}
@Override
public void flatMap2(ArrayList curModel, Collector out) throws Exception {
partialModel = curModel;
out.collect(partialModel.get(0));
}
}
public static class Predictor implements CoMapFunction, ArrayList, Double> {
ArrayList partialModel = null;
@Override
public Double map1(ArrayList example) throws Exception {
if (partialModel != null) {
System.out.println("Partial model ready!");
return predict(partialModel, example);
} else {
return Double.NaN;
}
}
@Override
public Double map2(ArrayList curModel) throws Exception {
partialModel = curModel;
return -1.0;
}
}
public static class VectorExtractor implements MapFunction> {
@Override
public ArrayList map(String s) throws Exception {
String[] elements = s.split(",");
ArrayList doubleElements = new ArrayList<>(elements.length);
for (int i = 0; i < elements.length; i++) {
doubleElements.add(new Double(elements[i]));
}
return doubleElements;
}
}
}