|
|
@@ -0,0 +1,283 @@ |
|
|
/* |
|
|
* Licensed to the Apache Software Foundation (ASF) under one or more |
|
|
* contributor license agreements. See the NOTICE file distributed with |
|
|
* this work for additional information regarding copyright ownership. |
|
|
* The ASF licenses this file to You under the Apache License, Version 2.0 |
|
|
* (the "License"); you may not use this file except in compliance with |
|
|
* the License. You may obtain a copy of the License at |
|
|
* |
|
|
* http://www.apache.org/licenses/LICENSE-2.0 |
|
|
* |
|
|
* Unless required by applicable law or agreed to in writing, software |
|
|
* distributed under the License is distributed on an "AS IS" BASIS, |
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
* See the License for the specific language governing permissions and |
|
|
* limitations under the License. |
|
|
*/ |
|
|
|
|
|
package se.sics.quickstart; |
|
|
|
|
|
import org.apache.flink.api.common.functions.MapFunction; |
|
|
import org.apache.flink.api.common.state.ValueState; |
|
|
import org.apache.flink.api.common.state.ValueStateDescriptor; |
|
|
import org.apache.flink.api.common.typeinfo.TypeHint; |
|
|
import org.apache.flink.api.common.typeinfo.TypeInformation; |
|
|
import org.apache.flink.api.java.utils.ParameterTool; |
|
|
import org.apache.flink.configuration.Configuration; |
|
|
import org.apache.flink.streaming.api.TimeCharacteristic; |
|
|
import org.apache.flink.streaming.api.datastream.DataStream; |
|
|
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; |
|
|
import org.apache.flink.streaming.api.functions.co.CoFlatMapFunction; |
|
|
import org.apache.flink.streaming.api.functions.co.CoMapFunction; |
|
|
import org.apache.flink.streaming.api.functions.source.SourceFunction; |
|
|
import org.apache.flink.streaming.api.functions.windowing.RichAllWindowFunction; |
|
|
import org.apache.flink.streaming.api.windowing.windows.GlobalWindow; |
|
|
import org.apache.flink.util.Collector; |
|
|
|
|
|
import java.util.ArrayList; |
|
|
import java.util.Collections; |
|
|
|
|
|
/** |
|
|
* Skeleton for an incremental machine learning algorithm. It consists of a
* pre-computed model, which is updated as new training inputs arrive, and a
* stream of new input data for which the job provides predictions.
|
|
* |
|
|
* <p> |
|
|
* This may serve as a base of a number of algorithms, e.g. updating an |
|
|
* incremental Alternating Least Squares model while also providing the |
|
|
* predictions. |
|
|
* |
|
|
* <p> |
|
|
* This example shows how to use: |
|
|
* <ul> |
|
|
* <li>Connected streams |
|
|
* <li>CoFunctions |
|
|
* <li>Tuple data types |
|
|
* </ul> |
|
|
*/ |
|
|
public class IncrementalLearning { |
|
|
|
|
|
|
|
|
// ************************************************************************* |
|
|
// PROGRAM |
|
|
// ************************************************************************* |
|
|
|
|
|
public static void main(String[] args) throws Exception { |
|
|
|
|
|
// Checking input parameters |
|
|
final ParameterTool params = ParameterTool.fromArgs(args); |
|
|
|
|
|
Double learningRate = params.has("learningRate") ? new Double(params.get("learningRate")) : 0.001; |
|
|
|
|
|
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); |
|
|
env.setStreamTimeCharacteristic(TimeCharacteristic.ProcessingTime); |
|
|
// env.setParallelism(1); |
|
|
|
|
|
// To simplify we make the assumption that the last element in each line is the dependent variable |
|
|
DataStream<ArrayList<Double>> trainingData = env.readTextFile(params.get("training")) |
|
|
.map(new VectorExtractor()); |
|
|
DataStream<ArrayList<Double>> newData = env.readTextFile(params.get("test")) |
|
|
.map(new VectorExtractor()); |
|
|
|
|
|
// build new model on every second of new data |
|
|
DataStream<ArrayList<Double>> model = trainingData |
|
|
.countWindowAll(Integer.parseInt(params.get("batchsize"))) |
|
|
.apply(new PartialModelBuilder(learningRate, Integer.parseInt(params.get("dimensions")))); |
|
|
|
|
|
// model.print(); |
|
|
|
|
|
// use partial model for newData |
|
|
DataStream<Double> errors = newData.connect(model).flatMap(new Evaluator()); |
|
|
|
|
|
errors.print(); |
|
|
|
|
|
// emit result |
|
|
// if (params.has("output")) { |
|
|
// prediction.writeAsText(params.get("output")); |
|
|
// } else { |
|
|
// System.out.println("Printing result to stdout. Use --output to specify output path."); |
|
|
// prediction.print(); |
|
|
// } |
|
|
|
|
|
// execute program |
|
|
env.execute("Streaming Incremental Learning"); |
|
|
} |
|
|
|
|
|
// ************************************************************************* |
|
|
// USER FUNCTIONS |
|
|
// ************************************************************************* |
|
|
|
|
|
/** |
|
|
* Feeds new data for newData. By default it is implemented as constantly |
|
|
* emitting the Integer 1 in a loop. |
|
|
*/ |
|
|
public static class FiniteNewDataSource implements SourceFunction<Integer> { |
|
|
private static final long serialVersionUID = 1L; |
|
|
private int counter; |
|
|
|
|
|
private String filepath; |
|
|
|
|
|
public FiniteNewDataSource(int counter, String filepath) { |
|
|
this.counter = counter; |
|
|
this.filepath = filepath; |
|
|
} |
|
|
|
|
|
@Override |
|
|
public void run(SourceContext<Integer> ctx) throws Exception { |
|
|
Thread.sleep(15); |
|
|
while (counter < 50) { |
|
|
ctx.collect(getNewData()); |
|
|
} |
|
|
} |
|
|
|
|
|
@Override |
|
|
public void cancel() { |
|
|
// No cleanup needed |
|
|
} |
|
|
|
|
|
private Integer getNewData() throws InterruptedException { |
|
|
Thread.sleep(5); |
|
|
counter++; |
|
|
return 1; |
|
|
} |
|
|
} |
|
|
|
|
|
private static Double predict(ArrayList<Double> model, ArrayList<Double> example) { |
|
|
Double prediction = 0.0; |
|
|
for (int i = 0; i < model.size(); i++) { |
|
|
prediction += model.get(i) * example.get(i); |
|
|
} |
|
|
return prediction; |
|
|
} |
|
|
|
|
|
/** |
|
|
* Builds up-to-date partial models on new training data. |
|
|
*/ |
|
|
public static class PartialModelBuilder extends RichAllWindowFunction<ArrayList<Double>, ArrayList<Double>, GlobalWindow> { |
|
|
|
|
|
public PartialModelBuilder(Double learningRate, int dimensions) { |
|
|
this.learningRate = learningRate; |
|
|
this.dimensions = dimensions; |
|
|
} |
|
|
|
|
|
private Double learningRate; |
|
|
private int dimensions; |
|
|
private int applyCount = 0; |
|
|
|
|
|
private static final long serialVersionUID = 1L; |
|
|
|
|
|
private transient ValueState<ArrayList<Double>> modelState; |
|
|
|
|
|
@Override |
|
|
public void open(Configuration config) { |
|
|
ArrayList<Double> allZeroes = new ArrayList<>(Collections.nCopies(dimensions, 0.0)); |
|
|
// obtain key-value state for prediction model |
|
|
// TODO: Do random assignment of weights instead of all zeros? |
|
|
ValueStateDescriptor<ArrayList<Double>> descriptor = |
|
|
new ValueStateDescriptor<>( |
|
|
// state name |
|
|
"modelState", |
|
|
// type information of state |
|
|
TypeInformation.of(new TypeHint<ArrayList<Double>>() {}), |
|
|
// default value of state |
|
|
allZeroes); |
|
|
modelState = getRuntimeContext().getState(descriptor); |
|
|
} |
|
|
|
|
|
private Double squaredError(Double truth, Double prediction) { |
|
|
return 0.5 * (truth - prediction) * (truth - prediction); |
|
|
} |
|
|
|
|
|
private ArrayList<Double> buildPartialModel(Iterable<ArrayList<Double>> trainingBatch) throws Exception{ |
|
|
int batchSize = 0; |
|
|
ArrayList<Double> regressionModel = modelState.value(); |
|
|
ArrayList<Double> gradientSum = new ArrayList<>(Collections.nCopies(dimensions, 0.0)); |
|
|
for (ArrayList<Double> sample : trainingBatch) { |
|
|
batchSize++; |
|
|
Double truth = sample.get(sample.size() - 1); |
|
|
Double prediction = predict(regressionModel, sample); |
|
|
Double error = squaredError(truth, prediction); |
|
|
Double derivative = prediction - truth; |
|
|
for (int i = 0; i < regressionModel.size(); i++) { |
|
|
Double weightGradient = derivative * sample.get(i); |
|
|
Double currentSum = gradientSum.get(i); |
|
|
gradientSum.set(i, currentSum + weightGradient); |
|
|
} |
|
|
} |
|
|
for (int i = 0; i < regressionModel.size(); i++) { |
|
|
Double oldWeight = regressionModel.get(i); |
|
|
Double currentLR = learningRate / Math.sqrt(applyCount); |
|
|
Double change = currentLR * (gradientSum.get(i) / batchSize); |
|
|
regressionModel.set(i, oldWeight - change); |
|
|
} |
|
|
return regressionModel; |
|
|
} |
|
|
|
|
|
@Override |
|
|
public void apply(GlobalWindow window, Iterable<ArrayList<Double>> values, Collector<ArrayList<Double>> out) throws Exception { |
|
|
this.applyCount++; |
|
|
|
|
|
ArrayList<Double> updatedModel = buildPartialModel(values); |
|
|
modelState.update(updatedModel); |
|
|
out.collect(updatedModel); |
|
|
} |
|
|
} |
|
|
|
|
|
public static class Evaluator implements CoFlatMapFunction<ArrayList<Double>, ArrayList<Double>, Double> { |
|
|
|
|
|
ArrayList<Double> partialModel = null; |
|
|
|
|
|
@Override |
|
|
public void flatMap1(ArrayList<Double> example, Collector<Double> out) throws Exception { |
|
|
// System.out.format("Example: %f\n", example.get(0)); |
|
|
if (partialModel != null) { |
|
|
System.out.println("Model was not null!"); |
|
|
Double prediction = predict(partialModel, example); |
|
|
Double error = example.get(example.size() - 1) - prediction; |
|
|
out.collect(error); |
|
|
} else { |
|
|
System.out.println("Model was null!"); |
|
|
} |
|
|
} |
|
|
|
|
|
@Override |
|
|
public void flatMap2(ArrayList<Double> curModel, Collector<Double> out) throws Exception { |
|
|
partialModel = curModel; |
|
|
out.collect(partialModel.get(0)); |
|
|
} |
|
|
} |
|
|
|
|
|
public static class Predictor implements CoMapFunction<ArrayList<Double>, ArrayList<Double>, Double> { |
|
|
ArrayList<Double> partialModel = null; |
|
|
|
|
|
@Override |
|
|
public Double map1(ArrayList<Double> example) throws Exception { |
|
|
if (partialModel != null) { |
|
|
System.out.println("Partial model ready!"); |
|
|
return predict(partialModel, example); |
|
|
} else { |
|
|
return Double.NaN; |
|
|
} |
|
|
} |
|
|
|
|
|
@Override |
|
|
public Double map2(ArrayList<Double> curModel) throws Exception { |
|
|
partialModel = curModel; |
|
|
return -1.0; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
public static class VectorExtractor implements MapFunction<String, ArrayList<Double>> { |
|
|
@Override |
|
|
public ArrayList<Double> map(String s) throws Exception { |
|
|
String[] elements = s.split(","); |
|
|
ArrayList<Double> doubleElements = new ArrayList<>(elements.length); |
|
|
for (int i = 0; i < elements.length; i++) { |
|
|
doubleElements.add(new Double(elements[i])); |
|
|
} |
|
|
return doubleElements; |
|
|
} |
|
|
} |
|
|
|
|
|
} |