Skip to content

Instantly share code, notes, and snippets.

@albrzykowski
Created August 24, 2018 00:55
Show Gist options
  • Save albrzykowski/c44834f2e3fc8049bbf26e710656a6d8 to your computer and use it in GitHub Desktop.

Revisions

  1. albrzykowski created this gist Aug 24, 2018.
    149 changes: 149 additions & 0 deletions TextClassification.java
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,149 @@
    import java.util.Arrays;
    import java.util.List;

    import org.apache.hadoop.yarn.webapp.hamlet.HamletSpec.P;
    import org.apache.spark.SparkConf;
    import org.apache.spark.api.java.JavaSparkContext;
    import org.apache.spark.api.java.function.MapFunction;
    import org.apache.spark.ml.Pipeline;
    import org.apache.spark.ml.PipelineModel;
    import org.apache.spark.ml.PipelineStage;
    import org.apache.spark.ml.classification.LogisticRegression;
    import org.apache.spark.ml.classification.LogisticRegressionModel;
    import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
    import org.apache.spark.ml.feature.Tokenizer;
    import org.apache.spark.ml.linalg.Matrix;
    import org.apache.spark.ml.param.ParamMap;
    import org.apache.spark.ml.tuning.CrossValidator;
    import org.apache.spark.ml.tuning.CrossValidatorModel;
    import org.apache.spark.ml.tuning.ParamGridBuilder;
    import org.apache.spark.sql.Column;
    import org.apache.spark.sql.Dataset;
    import org.apache.spark.sql.RelationalGroupedDataset;
    import org.apache.spark.sql.Row;
    import org.apache.spark.sql.SQLContext;
    import org.apache.spark.sql.SparkSession;
    import org.apache.spark.sql.api.java.UDF1;
    import org.apache.spark.sql.catalyst.expressions.Randn;
    import org.apache.spark.sql.expressions.Window;
    import org.apache.spark.sql.expressions.WindowSpec;
    import org.apache.spark.sql.types.DataTypes;
    import org.apache.spark.sql.types.StructType;
    import org.netlib.util.doubleW;

    import breeze.linalg.randn;
    import scala.Tuple2;

    import org.apache.spark.api.java.JavaPairRDD;
    import org.apache.spark.api.java.JavaRDD;
    import org.apache.spark.mllib.evaluation.MulticlassMetrics;
    import org.apache.spark.ml.feature.HashingTF;
    import org.apache.spark.ml.feature.IDF;
    import org.apache.spark.ml.feature.LabeledPoint;
    import org.apache.spark.ml.feature.StopWordsRemover;
    import org.apache.spark.ml.feature.StringIndexer;
    import static org.apache.spark.sql.functions.*;

    /**
     * Trains and evaluates a multinomial logistic-regression text classifier on the
     * SEL lexicon file using a Spark ML pipeline
     * (StringIndexer → Tokenizer → StopWordsRemover → HashingTF → IDF → LogisticRegression),
     * with hyper-parameters chosen by 2-fold cross-validation on weighted precision.
     */
    public class App {

        public static void main(String[] args) {

            SparkSession spark = SparkSession
                .builder()
                .appName("Java Spark SQL Example")
                .getOrCreate();

            // Schema of the tab-separated input file: the text to classify ("word"),
            // a numeric polarity score, and the class label ("category").
            StructType schema = new StructType()
                .add("word", "string")
                .add("polarity", "double")
                .add("category", "string");

            Dataset<Row> df = spark.read()
                .option("mode", "DROPMALFORMED") // silently drop rows that fail to parse
                .option("delimiter", "\t")
                .option("header", "true")
                .schema(schema)
                .csv("src/main/resources/SEL-utf-8.txt");

            df.show(20);

            // Shuffle before splitting so the 70/30 split is random w.r.t. file order.
            Dataset<Row>[] split = df.orderBy(rand()).randomSplit(new double[] {0.7, 0.3});
            Dataset<Row> training = split[0];
            Dataset<Row> test = split[1];

            // BUG FIX: the original read columns "label" and "text", which do not
            // exist in the schema above and would make Pipeline.fit() fail. Use the
            // actual columns: "category" is the class label, "word" is the text.
            StringIndexer indexer = new StringIndexer()
                .setInputCol("category")
                .setOutputCol("labelIndexed");

            Tokenizer tokenizer = new Tokenizer()
                .setInputCol("word")
                .setOutputCol("tokens");

            // NOTE(review): the input file name suggests the Spanish Emotion Lexicon,
            // yet English stop words are loaded — confirm the intended language.
            StopWordsRemover stopWordsRemover = new StopWordsRemover()
                .setInputCol("tokens")
                .setOutputCol("clearedFromStopwords")
                .setStopWords(StopWordsRemover.loadDefaultStopWords("english"));

            HashingTF hashingTF = new HashingTF()
                .setInputCol("clearedFromStopwords")
                .setOutputCol("rawFeatures")
                .setNumFeatures(50000);

            IDF idf = new IDF()
                .setInputCol("rawFeatures")
                .setOutputCol("features");

            LogisticRegression lr = new LogisticRegression()
                .setMaxIter(10)
                .setRegParam(0.3)
                .setFamily("multinomial")
                .setLabelCol("labelIndexed");

            Pipeline pipeline = new Pipeline()
                .setStages(new PipelineStage[] {indexer, tokenizer, stopWordsRemover, hashingTF, idf, lr});

            // Hyper-parameter grid explored by cross-validation (2 x 2 x 1 x 1 = 4 candidates).
            ParamMap[] paramGrid = new ParamGridBuilder()
                .addGrid(lr.maxIter(), new int[] { 10, 20 })
                .addGrid(lr.regParam(), new double[] { 0.1, 1.0 })
                .addGrid(lr.elasticNetParam(), new double[] { 0.7 })
                .addGrid(hashingTF.numFeatures(), new int[] {50000})
                .build();

            // Single evaluator instance, reused for both model selection and the
            // final test-set evaluation (the original built two identical ones).
            MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
                .setLabelCol("labelIndexed")
                .setPredictionCol("prediction")
                .setMetricName("weightedPrecision");

            CrossValidator validator = new CrossValidator()
                .setNumFolds(2)
                .setEstimator(pipeline)
                .setEvaluator(evaluator)
                .setEstimatorParamMaps(paramGrid);

            PipelineModel model = (PipelineModel) validator.fit(training).bestModel();

            // BUG FIX: the original swallowed save failures silently; at minimum
            // report them so a missing/unwritable path is visible.
            try {
                model.save("src/main/resources/model");
            } catch (Exception e) {
                System.err.println("Failed to save model: " + e);
            }

            Dataset<Row> predictions = model.transform(test);

            // Renamed from "accuracy": the configured metric is weighted precision.
            double weightedPrecision = evaluator.evaluate(predictions);

            // The original's withColumn(name, new Column(name)) calls were no-ops
            // (each replaced a column with itself); a plain select suffices.
            predictions
                .select("category", "prediction", "labelIndexed", "word")
                .show(500);

            System.out.println("Weighted precision: " + weightedPrecision);
        }
    }