Skip to content

Instantly share code, notes, and snippets.

@yaravind
Last active July 3, 2020 23:16
Show Gist options
  • Select an option

  • Save yaravind/3847afbd87d9f25ef9ae4ccf79b5f39d to your computer and use it in GitHub Desktop.

Select an option

Save yaravind/3847afbd87d9f25ef9ae4ccf79b5f39d to your computer and use it in GitHub Desktop.

Revisions

  1. yaravind revised this gist Jul 3, 2020. 1 changed file with 1 addition and 1 deletion.
    2 changes: 1 addition & 1 deletion KMeansSparkMLToMLLib.scala
    Original file line number Diff line number Diff line change
    @@ -1,4 +1,4 @@
    mport org.apache.spark.mllib.clustering.BisectingKMeans
    import org.apache.spark.mllib.clustering.BisectingKMeans
    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.linalg.Vector

  2. yaravind revised this gist Jul 3, 2020. No changes.
  3. yaravind created this gist Jul 3, 2020.
    39 changes: 39 additions & 0 deletions KMeansSparkMLToMLLib.scala
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,39 @@
    mport org.apache.spark.mllib.clustering.BisectingKMeans
    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.linalg.Vector

    //std_features col is of type vector
    scaledFeatures.select($"std_features").printSchema()

    val tempFeatureRdd = scaledFeatures.select($"std_features").rdd

    import scala.reflect.runtime.universe._
    def getType[T: TypeTag](value: T) = typeOf[T]
    println("-------BEFORE")
    println("Type of RDD: "+getType(tempFeatureRdd))
    println("Type of column: "+getType(tempFeatureRdd.first()))

    /**
    create a new df of type RDD[org.apache.spark.mllib.linalg.Vector] by mapping
    RDD[org.apache.spark.sql.Row] to RDD[org.apache.spark.mllib.linalg.Vector]
    as BisectingKMeans works only with Vector type
    **/
    val input = scaledFeatures
    .select($"std_features")
    .rdd
    .map(v => Vectors.fromML(v.getAs[org.apache.spark.ml.linalg.Vector](0)))
    .cache() //important for ML algos to run faster
    println("-------AFTER")
    println("Type of RDD: "+getType(input))
    println("Type of column: "+getType(input.first()))

    println("Total rows: "+input.count())

    // Clustering the data into 9 clusters by BisectingKMeans.
    val bkm = new BisectingKMeans().setK(9)
    val model = bkm.run(input)

    println(s"Compute Cost: ${model.computeCost(input)}")
    model.clusterCenters.zipWithIndex.foreach { case (center, idx) =>
    println(s"Cluster Center ${idx}: ${center}")
    }