Skip to content

Instantly share code, notes, and snippets.

@mgilham
Created October 14, 2015 19:58
Show Gist options
  • Save mgilham/8fab0a798778bf8e799c to your computer and use it in GitHub Desktop.
Save mgilham/8fab0a798778bf8e799c to your computer and use it in GitHub Desktop.
Spark SQL sumVector UDAF
import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
import org.apache.spark.mllib.linalg.SparseVector
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.linalg.VectorUDT
/**
 * Spark SQL user-defined aggregate function that sums `mllib` vectors element-wise.
 *
 * Rows whose input vector is null are skipped. Inputs of different lengths are summed as if
 * the shorter ones were padded with trailing zeros, so the result has the length of the
 * longest vector seen. A group with no non-null input yields an empty (zero-length) vector.
 *
 * The running sum is kept densely in an `ArrayType(DoubleType)` buffer column; the final
 * value is passed through `Vector.compressed` before being returned.
 *
 * MLG: this is not heavily optimized (the buffer is copied on every update), but likely good
 * enough for now.
 */
object sumVector extends UserDefinedAggregateFunction {
  private val vectorUDT = new VectorUDT

  /** Adds `arr` element-wise into `agg` in place; `agg` must be at least as long as `arr`. */
  private def addArray(agg: Array[Double], arr: Array[Double]): Unit = {
    var i = 0
    while (i < arr.length) {
      agg(i) += arr(i)
      i += 1
    }
  }

  /**
   * Adds `vec` element-wise into `agg` in place; `agg` must be at least `vec.size` long.
   * Sparse vectors contribute only their stored entries, so they are never densified.
   */
  private def addVector(agg: Array[Double], vec: Vector): Unit = vec match {
    case sv: SparseVector =>
      val indices = sv.indices
      val values = sv.values
      var k = 0
      while (k < indices.length) {
        agg(indices(k)) += values(k)
        k += 1
      }
    case other =>
      addArray(agg, other.toArray)
  }

  /** Returns `agg` if it holds at least `size` elements, else a zero-padded copy of length `size`. */
  private def ensureArraySize(agg: Array[Double], size: Int): Array[Double] =
    if (size > agg.length) {
      val grown = new Array[Double](size)
      Array.copy(agg, 0, grown, 0, agg.length)
      grown
    } else {
      agg
    }

  // Schema of the input: a single vector column.
  override def inputSchema: StructType = new StructType().add("vec", vectorUDT)

  // Schema of the aggregation buffer: the running element-wise sum, stored densely.
  override def bufferSchema: StructType =
    new StructType().add("arr", ArrayType(DoubleType, containsNull = false))

  // Type of the value produced by `evaluate`.
  override def dataType: VectorUDT = vectorUDT

  // The same input always produces the same output.
  override def deterministic: Boolean = true

  // Zero value: an empty sum that grows to the size of the first vector added.
  // Stored as a Seq, the same representation `update` and `merge` write back.
  override def initialize(buffer: MutableAggregationBuffer): Unit =
    buffer.update(0, Seq.empty[Double])

  // Folds one input row into the buffer (like seqOp in `aggregate`); null vectors are ignored.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit =
    if (!input.isNullAt(0)) {
      val vec = input.getAs[Vector](0)
      val agg = ensureArraySize(buffer.getSeq[Double](0).toArray, vec.size)
      addVector(agg, vec)
      buffer.update(0, agg.toSeq)
    }

  // Combines two partial sums (like combOp in `aggregate`), padding the shorter with zeros.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val other = buffer2.getSeq[Double](0).toArray
    val agg = ensureArraySize(buffer1.getSeq[Double](0).toArray, other.length)
    addArray(agg, other)
    buffer1.update(0, agg.toSeq)
  }

  // Produces the final vector from the accumulated dense sum.
  override def evaluate(buffer: Row): Vector =
    Vectors.dense(buffer.getSeq[Double](0).toArray).compressed
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment