package flows

import akka.actor.ActorRef
import akka.pattern.ask
import akka.stream.ActorMaterializer
import akka.stream.scaladsl.RunnableGraph
import akka.stream.stage._
import akka.util.Timeout
import flows.DetourStage.AwaitCompletion

import scala.concurrent.ExecutionContext
import scala.util.{Failure, Success, Try}

/** Companion of [[DetourStage]]; holds the message the stage sends to the detour graph. */
object DetourStage {

  /** Message asked of the second actor materialized by the detour graph. */
  case object AwaitCompletion
}

/**
 * Asynchronous stage that routes every incoming element through a freshly run
 * copy of `g` and emits whatever the graph's second actor replies with.
 *
 * For each pushed element the stage runs `g`, asks the second materialized
 * actor for [[DetourStage.AwaitCompletion]], sends the element to the first
 * materialized actor and holds upstream. The reply comes back through the
 * stage's async callback, so only one element is processed at a time.
 *
 * @param g       graph run once per element; materializes to a pair of actors,
 *                the first receiving the element, the second being asked for the result
 * @param timeout timeout handed to the ask on the second actor
 * @tparam In  type of elements arriving from upstream
 * @tparam Out type of elements emitted downstream; the reply is cast to this
 *             type without a runtime check
 */
class DetourStage[In, Out](g: RunnableGraph[(ActorRef, ActorRef)], timeout: Timeout)
                          (implicit materializer: ActorMaterializer, ec: ExecutionContext)
  extends AsyncStage[In, Out, Try[Out]] {

  // Result that arrived while downstream was not being held; handed out on the next pull.
  private var readyElem: Option[Out] = None

  /**
   * Starts a detour for `elem` and holds upstream.
   *
   * The ask is issued before the element is sent to the entry actor, and its
   * outcome (success or failure) is forwarded to `onAsyncInput` via the async callback.
   */
  override def onPush(elem: In, ctx: AsyncContext[Out, Try[Out]]): UpstreamDirective = {
    val (entry, exit) = g.run()
    val reply = ask(exit, AwaitCompletion)(timeout).map(answer => answer.asInstanceOf[Out])
    val resume = ctx.getAsyncCallback()
    reply.onComplete(outcome => resume.invoke(outcome))
    entry.tell(elem, ActorRef.noSender)
    ctx.holdUpstream()
  }

  /** Emits the stored result if there is one (clearing it first); otherwise holds downstream. */
  override def onPull(ctx: AsyncContext[Out, Try[Out]]): DownstreamDirective =
    readyElem.fold[DownstreamDirective](ctx.holdDownstream()) { elem =>
      readyElem = None
      deliver(elem, ctx)
    }

  /**
   * Handles the outcome of a detour.
   *
   * A failure fails the stage. A success is emitted right away when downstream
   * is being held; otherwise it is stored and the event is ignored.
   */
  override def onAsyncInput(event: Try[Out], ctx: AsyncContext[Out, Try[Out]]): Directive =
    event match {
      case Success(elem) =>
        if (ctx.isHoldingDownstream) deliver(elem, ctx)
        else {
          readyElem = Some(elem)
          ctx.ignore()
        }
      case Failure(ex) =>
        ctx.fail(ex)
    }

  /** Finishes immediately unless upstream is being held, in which case termination is absorbed. */
  override def onUpstreamFinish(ctx: AsyncContext[Out, Try[Out]]): TerminationDirective =
    if (!ctx.isHoldingUpstream) ctx.finish()
    else ctx.absorbTermination()

  /** Emits `elem`: push-and-pull normally, push-and-finish when the context is finishing. */
  private def deliver(elem: Out, ctx: AsyncContext[Out, Try[Out]]): DownstreamDirective =
    if (!ctx.isFinishing) ctx.pushAndPull(elem)
    else ctx.pushAndFinish(elem)
}