import scala.reflect.ClassTag

import org.apache.spark.rdd.RDD

// Assumes a SparkContext `sc` is in scope (e.g. as a field of the enclosing class).

/** Hive/Pig/Cascading/Scalding-style inner join which will perform a
 * map-side/replicated/broadcast join if the "small" relation has no more
 * than maxNumRows rows, and a reduce-side join otherwise.
 *
 * @param big the large relation
 * @param small the small relation
 * @param maxNumRows the maximum number of rows that the small relation can
 *        have to be a candidate for a map-side/replicated/broadcast join
 * @return a joined RDD with a common key and a tuple of values from the two
 *         relations (the big relation value first, followed by the small one)
 */
private def optimizedInnerJoin[A : ClassTag, B : ClassTag, C : ClassTag]
    (big: RDD[(A, B)], small: RDD[(A, C)], maxNumRows: Long): RDD[(A, (B, C))] = {
  /* Caching is needed for efficiency's sake, since the choice between
   * map- and reduce-side joins is based on the row count of the smaller
   * relation, and the count will materialize the small relation. If it
   * turns out to be too big for a map-side join, it is then already cached
   * for the reduce-side join. Caching is idempotent, so nothing happens if
   * the dataset is already cached. */
  small.cache()
  if (small.count() <= maxNumRows) {
    /* An alternative would be "small.collectAsMap()"
     * (http://ampcamp.berkeley.edu/wp-content/uploads/2012/06/matei-zaharia-amp-camp-2012-advanced-spark.pdf),
     * but that gives incorrect results: a map deduplicates entries with
     * identical keys, while duplicate keys are a normal occurrence in
     * MapReduce frameworks (that is the rationale for grouping entries by
     * key in the reduce stage). The simpler solution silently drops rows in
     * exactly those cases, which constitute the vast majority of key-value
     * RDD use cases. */
    val grouped: Map[A, Array[C]] = small
      .collect()
      .groupBy { case (key, _) => key }
      .map { case (key, kvs) => (key, kvs.map { case (_, v) => v }) }

    /* Broadcast the map representing the small relation to all nodes.
     * Joining against the big dataset is then done locally on each node,
     * for all partitions, at the map stage. This is called a map-side join
     * in Hadoop-land and a replicated join in distributed relational
     * databases; in a Spark context we can also call it a broadcast join. */
    val smallBc = sc.broadcast(grouped)
    val joined = big.flatMap { case (a, b) =>
      // Emit one output row per matching value in the small relation;
      // keys absent from the broadcast map produce no output (inner join).
      smallBc.value.getOrElse(a, Array.empty[C]).map(c => (a, (b, c)))
    }
    // The broadcast variable already holds a copy of the small relation
    // (collect() ran eagerly above), so the cached RDD can be released now.
    small.unpersist(blocking = false)
    joined
  } else {
    /* The "small" relation is too big for a broadcast join - fall back to a
     * regular reduce-side join via the RDD API. The join is lazy, so the
     * cached small relation must stay cached until the result is computed;
     * unpersisting it is left to the caller (or to LRU cache eviction). */
    big.join(small)
  }
}
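
To make the comment about "small.collectAsMap()" concrete, here is a minimal sketch with hypothetical sample data, showing how a plain map silently drops duplicate keys while the group-by approach used above preserves them:

// Hypothetical sample data: the key "a" appears twice in the small relation.
val pairs = Seq(("a", 1), ("a", 2), ("b", 3))

// A Map keeps only the last value per key - the ("a", 1) row is lost.
val asMap: Map[String, Int] = pairs.toMap
// asMap == Map("a" -> 2, "b" -> 3)

// Grouping by key preserves every value, matching inner-join semantics.
val grouped: Map[String, Seq[Int]] =
  pairs.groupBy(_._1).map { case (k, kvs) => (k, kvs.map(_._2)) }
// grouped == Map("a" -> Seq(1, 2), "b" -> Seq(3))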
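
A minimal usage sketch, assuming the method is called from inside its enclosing class (it is private and relies on the class's `sc`); the relation names and data below are made up for illustration:

// Assumed setup: a local SparkContext; `orders` is the big relation,
// `users` a small lookup relation.
val orders = sc.parallelize(Seq((1, "book"), (1, "pen"), (2, "lamp")))
val users  = sc.parallelize(Seq((1, "alice"), (2, "bob")))

// With maxNumRows = 1000 the small relation qualifies for a broadcast join.
val joined = optimizedInnerJoin(orders, users, maxNumRows = 1000L)
joined.collect().foreach(println)
// e.g. (1,(book,alice)), (1,(pen,alice)), (2,(lamp,bob))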