/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package se.sics.quickstart; import org.apache.flink.api.common.functions.MapFunction; import org.apache.flink.api.common.state.ValueState; import org.apache.flink.api.common.state.ValueStateDescriptor; import org.apache.flink.api.common.typeinfo.TypeHint; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.java.utils.ParameterTool; import org.apache.flink.configuration.Configuration; import org.apache.flink.streaming.api.TimeCharacteristic; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; import org.apache.flink.streaming.api.functions.co.CoFlatMapFunction; import org.apache.flink.streaming.api.functions.co.CoMapFunction; import org.apache.flink.streaming.api.functions.source.SourceFunction; import org.apache.flink.streaming.api.functions.windowing.RichAllWindowFunction; import org.apache.flink.streaming.api.windowing.windows.GlobalWindow; import org.apache.flink.util.Collector; import java.util.ArrayList; import java.util.Collections; /** * Skeleton for incremental machine learning algorithm consisting of a * pre-computed model, which gets updated for the new inputs and new 
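 * input data for which the job provides predictions.
 *
 * <p>This may serve as a base of a number of algorithms, e.g. updating an
 * incremental Alternating Least Squares model while also providing the
 * predictions.
 *
 * <p>This example shows how to use:
 * <ul>
 *   <li>count windows over a non-keyed stream ({@code countWindowAll}),</li>
 *   <li>a {@code ValueState} held by a rich window function,</li>
 *   <li>connected streams processed by a {@code CoFlatMapFunction}.</li>
 * </ul>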
*/ public class IncrementalLearning { // ************************************************************************* // PROGRAM // ************************************************************************* public static void main(String[] args) throws Exception { // Checking input parameters final ParameterTool params = ParameterTool.fromArgs(args); Double learningRate = params.has("learningRate") ? new Double(params.get("learningRate")) : 0.001; StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); env.setStreamTimeCharacteristic(TimeCharacteristic.ProcessingTime); // env.setParallelism(1); // To simplify we make the assumption that the last element in each line is the dependent variable DataStream> trainingData = env.readTextFile(params.get("training")) .map(new VectorExtractor()); DataStream> newData = env.readTextFile(params.get("test")) .map(new VectorExtractor()); // build new model on every second of new data DataStream> model = trainingData .countWindowAll(Integer.parseInt(params.get("batchsize"))) .apply(new PartialModelBuilder(learningRate, Integer.parseInt(params.get("dimensions")))); // model.print(); // use partial model for newData DataStream errors = newData.connect(model).flatMap(new Evaluator()); errors.print(); // emit result // if (params.has("output")) { // prediction.writeAsText(params.get("output")); // } else { // System.out.println("Printing result to stdout. Use --output to specify output path."); // prediction.print(); // } // execute program env.execute("Streaming Incremental Learning"); } // ************************************************************************* // USER FUNCTIONS // ************************************************************************* /** * Feeds new data for newData. By default it is implemented as constantly * emitting the Integer 1 in a loop. 
*/ public static class FiniteNewDataSource implements SourceFunction { private static final long serialVersionUID = 1L; private int counter; private String filepath; public FiniteNewDataSource(int counter, String filepath) { this.counter = counter; this.filepath = filepath; } @Override public void run(SourceContext ctx) throws Exception { Thread.sleep(15); while (counter < 50) { ctx.collect(getNewData()); } } @Override public void cancel() { // No cleanup needed } private Integer getNewData() throws InterruptedException { Thread.sleep(5); counter++; return 1; } } private static Double predict(ArrayList model, ArrayList example) { Double prediction = 0.0; for (int i = 0; i < model.size(); i++) { prediction += model.get(i) * example.get(i); } return prediction; } /** * Builds up-to-date partial models on new training data. */ public static class PartialModelBuilder extends RichAllWindowFunction, ArrayList, GlobalWindow> { public PartialModelBuilder(Double learningRate, int dimensions) { this.learningRate = learningRate; this.dimensions = dimensions; } private Double learningRate; private int dimensions; private int applyCount = 0; private static final long serialVersionUID = 1L; private transient ValueState> modelState; @Override public void open(Configuration config) { ArrayList allZeroes = new ArrayList<>(Collections.nCopies(dimensions, 0.0)); // obtain key-value state for prediction model // TODO: Do random assignment of weights instead of all zeros? 
ValueStateDescriptor> descriptor = new ValueStateDescriptor<>( // state name "modelState", // type information of state TypeInformation.of(new TypeHint>() {}), // default value of state allZeroes); modelState = getRuntimeContext().getState(descriptor); } private Double squaredError(Double truth, Double prediction) { return 0.5 * (truth - prediction) * (truth - prediction); } private ArrayList buildPartialModel(Iterable> trainingBatch) throws Exception{ int batchSize = 0; ArrayList regressionModel = modelState.value(); ArrayList gradientSum = new ArrayList<>(Collections.nCopies(dimensions, 0.0)); for (ArrayList sample : trainingBatch) { batchSize++; Double truth = sample.get(sample.size() - 1); Double prediction = predict(regressionModel, sample); Double error = squaredError(truth, prediction); Double derivative = prediction - truth; for (int i = 0; i < regressionModel.size(); i++) { Double weightGradient = derivative * sample.get(i); Double currentSum = gradientSum.get(i); gradientSum.set(i, currentSum + weightGradient); } } for (int i = 0; i < regressionModel.size(); i++) { Double oldWeight = regressionModel.get(i); Double currentLR = learningRate / Math.sqrt(applyCount); Double change = currentLR * (gradientSum.get(i) / batchSize); regressionModel.set(i, oldWeight - change); } return regressionModel; } @Override public void apply(GlobalWindow window, Iterable> values, Collector> out) throws Exception { this.applyCount++; ArrayList updatedModel = buildPartialModel(values); modelState.update(updatedModel); out.collect(updatedModel); } } public static class Evaluator implements CoFlatMapFunction, ArrayList, Double> { ArrayList partialModel = null; @Override public void flatMap1(ArrayList example, Collector out) throws Exception { // System.out.format("Example: %f\n", example.get(0)); if (partialModel != null) { System.out.println("Model was not null!"); Double prediction = predict(partialModel, example); Double error = example.get(example.size() - 1) - prediction; 
out.collect(error); } else { System.out.println("Model was null!"); } } @Override public void flatMap2(ArrayList curModel, Collector out) throws Exception { partialModel = curModel; out.collect(partialModel.get(0)); } } public static class Predictor implements CoMapFunction, ArrayList, Double> { ArrayList partialModel = null; @Override public Double map1(ArrayList example) throws Exception { if (partialModel != null) { System.out.println("Partial model ready!"); return predict(partialModel, example); } else { return Double.NaN; } } @Override public Double map2(ArrayList curModel) throws Exception { partialModel = curModel; return -1.0; } } public static class VectorExtractor implements MapFunction> { @Override public ArrayList map(String s) throws Exception { String[] elements = s.split(","); ArrayList doubleElements = new ArrayList<>(elements.length); for (int i = 0; i < elements.length; i++) { doubleElements.add(new Double(elements[i])); } return doubleElements; } } }