import akka.actor._
import akka.actor.{ ExtensionKey, Extension, ExtendedActorSystem }
import akka.stream.scaladsl.Flow
import org.apache.spark.streaming.StreamingContext
import org.apache.spark.streaming.dstream.ReceiverInputDStream
import org.apache.spark.streaming.receiver.ActorHelper

import scala.reflect.ClassTag

object AkkaStreamSparkIntegration {

  /**
   * Returns an InputDStream of type FlowElementType along with a Flow map element that you can use to attach to
   * your flow.
   *
   * Every element that passes through the returned Flow is sent to an internal feeder actor and then emitted
   * downstream unchanged. The feeder forwards elements to the actor(s) that subscribed to it; those actors are the
   * ones created through `streamingContext.actorStream`.
   *
   * Requires your system to have proper Akka remote configurations set up.
   *
   * @param actorName name passed to `streamingContext.actorStream` for the receiving actor; defaults to a random UUID
   * @param flowBufferSize In the event that the InputStream is not yet ready, how many elements from the Akka stream
   *                       should be buffered before dropping oldest entries
   * @param actorSystem actor system in which the feeder actor is created and whose default address is used to
   *                    build the feeder's absolute path
   * @param streamingContext Spark streaming context on which `actorStream` is invoked
   * @tparam FlowElementType type of the elements travelling through the flow and stored into the DStream
   * @return a pair of the receiver-backed input DStream and the pass-through Flow that feeds it
   * @example
   *
   * Format: OFF
   * {{{
   * import akka.actor.ActorSystem
   * import akka.stream.{ActorMaterializer, ClosedShape}
   * import akka.stream.javadsl.{Sink, RunnableGraph}
   * import akka.stream.scaladsl._
   * import com.beachape.sparkka._
   * implicit val actSys = ActorSystem()
   * implicit val materializer = ActorMaterializer()
   * // an implicit StreamingContext must be in scope as well
   *
   * val g = RunnableGraph.fromGraph(GraphDSL.create() { implicit builder =>
   *
   *  import GraphDSL.Implicits._
   *
   *  val source = Source(1 to 10)
   *
   *  val sink = builder.add(Sink.ignore)
   *  val bCast = builder.add(Broadcast[Int](2))
   *  val merge = builder.add(Merge[Int](2))
   *
   *  // InputDStream can then be used to build elements of the graph that require integration with Spark
   *  val (inputDStream, feedDInput) = AkkaStreamSparkIntegration.streamConnection[Int]()
   *
   *  val add1 = Flow[Int].map(_ + 1)
   *  val times3 = Flow[Int].map(_ * 3)
   *  source ~> bCast ~> add1 ~> merge ~> sink
   *  bCast ~> times3 ~> feedDInput ~> merge
   *
   *  ClosedShape
   * })
   * }}}
   * Format: ON
   */
  def streamConnection[FlowElementType: ClassTag](actorName: String = uuid(), flowBufferSize: Int = 5000)(implicit actorSystem: ActorSystem, streamingContext: StreamingContext): (ReceiverInputDStream[FlowElementType], Flow[FlowElementType, FlowElementType, Unit]) = {
    val feederActor = actorSystem.actorOf(Props(new FlowShimFeeder[FlowElementType](flowBufferSize)))
    // Build an absolute path (including the system's default address) so the publisher can look the feeder up
    // through an actor selection.
    val remoteAddress = RemoteAddressExtension(actorSystem).address
    val feederActorPath = feederActor.path.toStringWithAddress(remoteAddress)
    // The type argument is given explicitly (as for the feeder) so the publisher matches on the same element type
    // instead of relying on type inference.
    val inputDStreamFromActor = streamingContext.actorStream[FlowElementType](
      Props(new FlowShimPublisher[FlowElementType](feederActorPath)),
      actorName
    )
    // Pass-through stage: hand each element to the feeder, then emit it unchanged.
    val flow = Flow[FlowElementType].map { p =>
      feederActor ! p
      p
    }
    (inputDStreamFromActor, flow)
  }

  // Seems rather daft to need 2 actors to do this, but otherwise we run into serialisation problems with the Akka Stream

  /**
   * Receives elements from the flow and forwards them to subscribed actors.
   *
   * Until the first [[Subscribe]] arrives, elements are buffered (at most `flowBufferSize`, dropping the oldest);
   * the buffer is flushed to the first subscriber.
   *
   * NOTE(review): once in the subscribed state, elements are dropped if every subscriber has unsubscribed —
   * confirm whether falling back to buffering would be preferable.
   */
  private class FlowShimFeeder[FlowElementType: ClassTag](flowBufferSize: Int) extends Actor with ActorLogging {

    import context.become

    // Vector gives cheap append / takeRight for the bounded buffer.
    def receive: Receive = awaitingSubscriber(Vector.empty)

    /** Behaviour while nobody has subscribed yet: keep the most recent `flowBufferSize` elements. */
    def awaitingSubscriber(toSend: Seq[FlowElementType]): Receive = {
      case d: FlowElementType =>
        // Append first, then trim, so the buffer never exceeds flowBufferSize elements.
        become(awaitingSubscriber((toSend :+ d).takeRight(flowBufferSize)))
      case Subscribe(ref) =>
        toSend.foreach(ref ! _)
        become(subscribed(Seq(ref)))
      case UnSubscribe(ref) =>
        // Nobody is subscribed in this state, so there is nothing to remove.
        log.debug(s"Ignoring UnSubscribe from $ref: no subscribers registered yet")
      case other => log.error(s"Received a random message: $other")
    }

    /** Behaviour once at least one subscription has been made: fan each element out to all subscribers. */
    def subscribed(subscribers: Seq[ActorRef]): Receive = {
      case p: FlowElementType => subscribers.foreach(_ ! p)
      case Subscribe(ref) =>
        // Guard against registering the same ref twice, which would deliver every element to it twice.
        if (!subscribers.contains(ref)) become(subscribed(subscribers :+ ref))
      case UnSubscribe(ref) => become(subscribed(subscribers.filterNot(_ == ref)))
      case other => log.error(s"Received a random message: $other")
    }
  }

  /**
   * Actor handed to `actorStream`: subscribes to the feeder located at `feederAbsoluteAddress` when it starts,
   * unsubscribes when it stops, and passes every received element to `store`.
   */
  private class FlowShimPublisher[FlowElementType: ClassTag](feederAbsoluteAddress: String) extends Actor with ActorHelper {

    private lazy val feederActor = context.system.actorSelection(feederAbsoluteAddress)

    override def preStart(): Unit = feederActor ! Subscribe(self)

    override def postStop(): Unit = feederActor ! UnSubscribe(self)

    def receive: Receive = {
      case p: FlowElementType => store(p)
      case other => log.error(s"Received a random message: $other")
    }
  }

  /** Asks the feeder to start forwarding elements to `ref`. */
  private final case class Subscribe(ref: ActorRef)

  /** Asks the feeder to stop forwarding elements to `ref`. */
  private final case class UnSubscribe(ref: ActorRef)

  /** Random UUID string, used as the default actor name. */
  private def uuid() = java.util.UUID.randomUUID.toString

}

// Helper classes for resolving absolute actor address

/** Akka extension exposing the default address of the actor system's provider. */
class RemoteAddressExtensionImpl(system: ExtendedActorSystem) extends Extension {
  def address: Address = system.provider.getDefaultAddress
}

/** Extension key used as `RemoteAddressExtension(system).address`. */
object RemoteAddressExtension extends ExtensionKey[RemoteAddressExtensionImpl]