import java.util.Arrays; import java.util.List; import org.apache.hadoop.yarn.webapp.hamlet.HamletSpec.P; import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.MapFunction; import org.apache.spark.ml.Pipeline; import org.apache.spark.ml.PipelineModel; import org.apache.spark.ml.PipelineStage; import org.apache.spark.ml.classification.LogisticRegression; import org.apache.spark.ml.classification.LogisticRegressionModel; import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator; import org.apache.spark.ml.feature.Tokenizer; import org.apache.spark.ml.linalg.Matrix; import org.apache.spark.ml.param.ParamMap; import org.apache.spark.ml.tuning.CrossValidator; import org.apache.spark.ml.tuning.CrossValidatorModel; import org.apache.spark.ml.tuning.ParamGridBuilder; import org.apache.spark.sql.Column; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.RelationalGroupedDataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.api.java.UDF1; import org.apache.spark.sql.catalyst.expressions.Randn; import org.apache.spark.sql.expressions.Window; import org.apache.spark.sql.expressions.WindowSpec; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.StructType; import org.netlib.util.doubleW; import breeze.linalg.randn; import scala.Tuple2; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.mllib.evaluation.MulticlassMetrics; import org.apache.spark.ml.feature.HashingTF; import org.apache.spark.ml.feature.IDF; import org.apache.spark.ml.feature.LabeledPoint; import org.apache.spark.ml.feature.StopWordsRemover; import org.apache.spark.ml.feature.StringIndexer; import static org.apache.spark.sql.functions.*; public class App { public static void main( String[] args ) { SparkSession spark = SparkSession .builder() .appName("Java Spark SQL Example") .getOrCreate(); StructType schema = new StructType() .add("word", "string") .add("polarity", "double") .add("category", "string"); Dataset df = spark.read() .option("mode", "DROPMALFORMED") .option("delimiter", "\t") .option("header", "true") .schema(schema) .csv("src/main/resources/SEL-utf-8.txt"); df.show(20); Dataset[] split = df.orderBy(rand()).randomSplit(new double[] {0.7, 0.3}); Dataset training = split[0]; Dataset test = split[1]; StringIndexer indexer = new StringIndexer() .setInputCol("label") .setOutputCol("labelIndexed"); Tokenizer tokenizer = new Tokenizer() .setInputCol("text") .setOutputCol("tokens"); StopWordsRemover stopWordsRemover = new StopWordsRemover() .setInputCol("tokens") .setOutputCol("cleardFromSopwords") .setStopWords(StopWordsRemover.loadDefaultStopWords("english")); HashingTF hashingTF = new HashingTF() .setInputCol("cleardFromSopwords") .setOutputCol("rawFeatures") .setNumFeatures(50000); IDF idf = new IDF() .setInputCol("rawFeatures") .setOutputCol("features"); LogisticRegression lr = new LogisticRegression() .setMaxIter(10) .setRegParam(0.3) .setFamily("multinomial") .setLabelCol("labelIndexed"); Pipeline pipeline = new Pipeline() .setStages(new PipelineStage[] {indexer, tokenizer, stopWordsRemover, hashingTF, idf, lr}); ParamMap[] paramGrid = new ParamGridBuilder() .addGrid(lr.maxIter(), new int[] { 10, 20 }) .addGrid(lr.regParam(), new double[] { 0.1, 1.0 }) .addGrid(lr.elasticNetParam(), new double[] { 0.7 }) .addGrid(hashingTF.numFeatures(), new int[] {50000}) .build(); MulticlassClassificationEvaluator mce = new MulticlassClassificationEvaluator() .setLabelCol("labelIndexed") .setPredictionCol("prediction") .setMetricName("weightedPrecision"); CrossValidator validator = new CrossValidator() .setNumFolds(2) .setEstimator(pipeline) .setEvaluator(mce) .setEstimatorParamMaps(paramGrid); PipelineModel model = (PipelineModel) validator.fit(training).bestModel(); try { model.save("src/main/resources/model"); } catch(Exception e) {} Dataset predictions = model.transform(test); MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator() .setLabelCol("labelIndexed") .setPredictionCol("prediction") .setMetricName("weightedPrecision"); double accuracy = evaluator.evaluate(predictions); predictions .withColumn("label", new Column("label")) .withColumn("labelIndexed", new Column("labelIndexed")) .withColumn("prediction", new Column("prediction")) .withColumn("text", new Column("text")) .select("label", "prediction", "labelIndexed", "text") .show(500); System.out.println("Weighted precision: " + accuracy); } }