use super::{ protocol::{self, Requester, Responder}, ProtocolError, }; use crate::Codec; use futures::{ channel::mpsc, future::BoxFuture, stream::FuturesUnordered, AsyncWriteExt, FutureExt, SinkExt, StreamExt, }; use libp2p::{ core::{ConnectedPoint, UpgradeError}, swarm::{ protocols_handler::{InboundUpgradeSend, OutboundUpgradeSend}, IntoProtocolsHandler, KeepAlive, ProtocolsHandler, ProtocolsHandlerEvent, ProtocolsHandlerUpgrErr, SubstreamProtocol, }, PeerId, }; use std::{ any::Any, collections::VecDeque, fmt::Debug, io::ErrorKind, marker::PhantomData, sync::Arc, task::{Context, Poll}, time::Duration, }; #[derive(Debug, PartialEq)] pub enum Response { Msg(T), Error(ProtocolError), Finished, } impl Response { pub fn into_msg(self) -> Result { match self { Response::Msg(msg) => Ok(msg), Response::Error(e) => Err(e), Response::Finished => Err(ProtocolError::Io(ErrorKind::UnexpectedEof.into())), } } } pub struct Request { request: T::Request, channel: mpsc::Sender>, } impl Request { pub fn new(request: T::Request, channel: mpsc::Sender>) -> Self { Self { request, channel } } } impl Debug for Request { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Request").field("request", &self.request).finish() } } pub struct RequestReceived { pub(crate) request: T::Request, pub(crate) channel: mpsc::Sender, } impl Debug for RequestReceived { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("RequestReceived") .field("request", &self.request) .finish() } } pub struct IntoHandler { spawner: Spawner, max_message_size: u32, request_timeout: Duration, response_send_buffer_size: usize, keep_alive: bool, _ph: PhantomData, } impl IntoHandler { pub fn new( spawner: Spawner, max_message_size: u32, request_timeout: Duration, response_send_buffer_size: usize, keep_alive: bool, ) -> Self { Self { spawner, max_message_size, request_timeout, response_send_buffer_size, keep_alive, _ph: PhantomData, } } } impl IntoProtocolsHandler for IntoHandler { type Handler = Handler; fn into_handler(self, _remote_peer_id: &PeerId, _connected_point: &ConnectedPoint) -> Self::Handler { Handler::new( self.spawner, self.max_message_size, self.request_timeout, self.response_send_buffer_size, self.keep_alive, ) } fn inbound_protocol(&self) -> ::InboundProtocol { Responder::new(self.max_message_size) } } type ProtocolEvent = ProtocolsHandlerEvent< Requester, mpsc::Sender::Response>>, RequestReceived, ProtocolError, >; pub type ResponseFuture = BoxFuture<'static, Box>; pub type Spawner = Arc ResponseFuture + Send + Sync + 'static>; pub struct Handler { events: VecDeque>, streams: FuturesUnordered, spawner: Spawner, max_message_size: u32, request_timeout: Duration, response_send_buffer_size: usize, keep_alive: bool, } impl Debug for Handler { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Handler") .field("events", &self.events.len()) .field("streams", &self.streams.len()) .finish() } } impl Handler { pub fn new( spawner: Spawner, max_message_size: u32, request_timeout: Duration, response_send_buffer_size: usize, keep_alive: bool, ) -> Self { Self { events: VecDeque::default(), streams: FuturesUnordered::default(), spawner, max_message_size, request_timeout, response_send_buffer_size, keep_alive, } } } impl ProtocolsHandler for Handler { type InEvent = Request; type OutEvent = RequestReceived; type Error = ProtocolError; type InboundProtocol = Responder; type OutboundProtocol = Requester; type InboundOpenInfo = (); type OutboundOpenInfo = mpsc::Sender>; fn listen_protocol(&self) -> SubstreamProtocol { SubstreamProtocol::new(Responder::new(self.max_message_size), ()).with_timeout(self.request_timeout) } fn inject_fully_negotiated_inbound( &mut self, protocol: ::Output, _info: Self::InboundOpenInfo, ) { let (request, mut stream) = protocol; let (channel, mut rx) = mpsc::channel(self.response_send_buffer_size); let max_message_size = self.max_message_size; log::trace!("handler received request"); let task = (self.spawner)( async move { log::trace!("starting send loop"); let mut buffer = Vec::new(); loop { // only flush once we’re going to sleep let response = match rx.try_next() { Ok(Some(r)) => r, Ok(None) => break, Err(_) => { log::trace!("flushing stream"); stream.flush().await?; match rx.next().await { Some(r) => r, None => break, } } }; protocol::write_msg(&mut stream, response, max_message_size, &mut buffer).await?; } log::trace!("flushing and closing substream"); protocol::write_finish(&mut stream).await?; Result::<_, ProtocolError>::Ok(()) } .map(|res| -> Box { Box::new(res) }) .boxed(), ); self.streams.push(task); self.events .push_back(ProtocolsHandlerEvent::Custom(RequestReceived { request, channel })); } fn inject_fully_negotiated_outbound( &mut self, mut stream: ::Output, mut tx: Self::OutboundOpenInfo, ) { let max_message_size = self.max_message_size; let task = (self.spawner)( async move { log::trace!("starting receive loop"); let mut buffer = Vec::new(); loop { match protocol::read_msg(&mut stream, max_message_size, &mut buffer) .await .unwrap_or_else(Response::Error) { Response::Msg(msg) => { tx.feed(Response::Msg(msg)).await?; log::trace!("response sent to client code"); } Response::Error(e) => { log::debug!("sending substream error {}", e); tx.feed(Response::Error(e)).await?; return Result::<_, ProtocolError>::Ok(()); } Response::Finished => { log::trace!("finishing substream"); tx.feed(Response::Finished).await?; return Ok(()); } } } } .map(|res| -> Box { Box::new(res) }) .boxed(), ); self.streams.push(task); } fn inject_event(&mut self, command: Self::InEvent) { let Request { request, channel } = command; log::trace!("requesting {:?}", request); self.events.push_back(ProtocolsHandlerEvent::OutboundSubstreamRequest { protocol: SubstreamProtocol::new(Requester::new(self.max_message_size, request), channel) .with_timeout(self.request_timeout), }) } fn inject_dial_upgrade_error( &mut self, mut tx: Self::OutboundOpenInfo, error: ProtocolsHandlerUpgrErr<::Error>, ) { let error = match error { ProtocolsHandlerUpgrErr::Timeout => ProtocolError::Timeout, ProtocolsHandlerUpgrErr::Timer => ProtocolError::Timeout, ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply(e)) => e, ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(e)) => e.into(), }; log::debug!("dial upgrade error: {}", error); if let Err(Response::Error(e)) = tx.try_send(Response::Error(error)).map_err(|e| e.into_inner()) { log::warn!("cannot send upgrade error to requester: {}", e); } } fn connection_keep_alive(&self) -> KeepAlive { if !self.keep_alive && self.streams.is_empty() { KeepAlive::No } else { KeepAlive::Yes } } fn poll(&mut self, cx: &mut Context<'_>) -> Poll> { loop { if self.streams.is_empty() { break; } if let Poll::Ready(result) = self.streams.poll_next_unpin(cx) { // since the set was not empty, this must be a Some() if let Some(Err(e)) = result.and_then(|e| e.downcast::>().ok().map(|b| *b)) { // no need to tear down the connection, substream is already closed log::warn!("error in substream task: {}", e); } } else { break; } } match self.events.pop_front() { Some(e) => Poll::Ready(e), None => Poll::Pending, } } }