Created October 14, 2015 19:58
Spark SQL sumVector UDAF
import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
import org.apache.spark.mllib.linalg.SparseVector
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.linalg.VectorUDT

// MLG: this is highly unoptimized, but likely good enough for now
object sumVector extends UserDefinedAggregateFunction {

  private val vectorUDT = new VectorUDT

  // Element-wise addition of arr into agg; assumes agg is at least as long as arr
  private def addArray(agg: Array[Double], arr: Array[Double]): Unit = {
    var i = 0
    while (i < arr.length) {
      agg(i) = agg(i) + arr(i)
      i += 1
    }
  }

  // Grow agg to the requested size if needed, preserving the existing values
  private def ensureArraySize(agg: Array[Double], size: Int): Array[Double] = {
    if (size > agg.length) {
      val newAgg = new Array[Double](size)
      Array.copy(agg, 0, newAgg, 0, agg.length)
      newAgg
    } else {
      agg
    }
  }

  // Schema of the input column: a single MLlib Vector
  def inputSchema = new StructType().add("vec", vectorUDT)

  // Schema of the intermediate aggregation buffer: an array of doubles
  def bufferSchema = new StructType().add("arr", ArrayType(DoubleType, false))

  // Type of the returned value
  def dataType = vectorUDT

  // The function always produces the same result for the same input
  def deterministic = true

  // Zero value: start from an empty array and grow it as vectors arrive
  def initialize(buffer: MutableAggregationBuffer) = buffer.update(0, Array[Double]())

  // Similar to seqOp in aggregate: add one input vector into the buffer
  def update(buffer: MutableAggregationBuffer, input: Row) = {
    if (!input.isNullAt(0)) {
      val vec = input.getAs[Vector](0)
      val arr: Array[Double] = vec.toArray
      val agg: Array[Double] = ensureArraySize(buffer.getSeq[Double](0).toArray, arr.length)
      addArray(agg, arr)
      buffer.update(0, agg.toSeq)
    }
  }

  // Similar to combOp in aggregate: merge two partial sums
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
    val agg2: Array[Double] = buffer2.getSeq[Double](0).toArray
    val agg1: Array[Double] = ensureArraySize(buffer1.getSeq[Double](0).toArray, agg2.length)
    addArray(agg1, agg2)
    buffer1.update(0, agg1.toSeq)
  }

  // Called at the end to produce the return value: a Vector of the element-wise sums
  def evaluate(buffer: Row) = Vectors.dense(buffer.getSeq[Double](0).toArray).compressed
}
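A minimal usage sketch, assuming a Spark 1.5+ SQLContext named sqlContext and a DataFrame df with a "user" column and a Vector-typed "features" column (these names, and the "events" temp table, are assumptions for illustration and not part of the gist):

import org.apache.spark.sql.functions.col

// Register the UDAF so it can be used inside SQL query strings
sqlContext.udf.register("sumVector", sumVector)
df.registerTempTable("events")
sqlContext.sql("SELECT user, sumVector(features) AS feature_sum FROM events GROUP BY user")

// Or call the UDAF directly through the DataFrame API
df.groupBy("user").agg(sumVector(col("features")).as("feature_sum"))

Note that evaluate returns Vectors.dense(...).compressed, so each group's sum comes back in whichever representation (dense or sparse) uses less storage.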