Skip to content

Instantly share code, notes, and snippets.

@rkuhn
Created May 22, 2022 10:12
Show Gist options
  • Select an option

  • Save rkuhn/413aa0cb4f7415bbb10c3cddd1fa0615 to your computer and use it in GitHub Desktop.

Select an option

Save rkuhn/413aa0cb4f7415bbb10c3cddd1fa0615 to your computer and use it in GitHub Desktop.

Revisions

  1. rkuhn created this gist May 22, 2022.
    88 changes: 88 additions & 0 deletions behaviour.rs
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,88 @@
    use super::{
    handler::{self, IntoHandler, Request, Response},
    RequestReceived, StreamingResponseConfig,
    };
    use crate::Codec;
    use futures::channel::mpsc;
    use libp2p::{
    core::connection::ConnectionId,
    swarm::{NetworkBehaviour, NetworkBehaviourAction, NotifyHandler, PollParameters},
    PeerId,
    };
    use std::{
    collections::VecDeque,
    marker::PhantomData,
    task::{Context, Poll},
    };

    pub struct StreamingResponse<T: Codec + Send + 'static> {
    config: StreamingResponseConfig,
    events: VecDeque<RequestReceived<T>>,
    requests: VecDeque<NetworkBehaviourAction<RequestReceived<T>, IntoHandler<T>>>,
    _ph: PhantomData<T>,
    }

    impl<T: Codec + Send + 'static> StreamingResponse<T> {
    pub fn new(config: StreamingResponseConfig) -> Self {
    Self {
    config,
    events: VecDeque::default(),
    requests: VecDeque::default(),
    _ph: PhantomData,
    }
    }

    pub fn request(&mut self, peer_id: PeerId, request: T::Request, channel: mpsc::Sender<Response<T::Response>>) {
    self.requests.push_back(NetworkBehaviourAction::NotifyHandler {
    peer_id,
    handler: NotifyHandler::Any,
    event: Request::new(request, channel),
    })
    }
    }

    impl<T: Codec + Send + 'static> NetworkBehaviour for StreamingResponse<T> {
    type ProtocolsHandler = IntoHandler<T>;
    type OutEvent = RequestReceived<T>;

    fn new_handler(&mut self) -> Self::ProtocolsHandler {
    IntoHandler::new(
    self.config.spawner.clone(),
    self.config.max_message_size,
    self.config.request_timeout,
    self.config.response_send_buffer_size,
    self.config.keep_alive,
    )
    }

    fn inject_event(
    &mut self,
    peer_id: PeerId,
    connection: ConnectionId,
    event: <<Self::ProtocolsHandler as libp2p::swarm::IntoProtocolsHandler>::Handler as libp2p::swarm::ProtocolsHandler>::OutEvent,
    ) {
    let handler::RequestReceived { request, channel } = event;
    log::trace!("request received by behaviour: {:?}", request);
    self.events.push_back(RequestReceived {
    peer_id,
    connection,
    request,
    channel,
    });
    }

    fn poll(
    &mut self,
    _cx: &mut Context<'_>,
    _params: &mut impl PollParameters,
    ) -> Poll<NetworkBehaviourAction<Self::OutEvent, Self::ProtocolsHandler>> {
    if let Some(action) = self.requests.pop_front() {
    log::trace!("triggering request action");
    return Poll::Ready(action);
    }
    match self.events.pop_front() {
    Some(e) => Poll::Ready(NetworkBehaviourAction::GenerateEvent(e)),
    None => Poll::Pending,
    }
    }
    }
    317 changes: 317 additions & 0 deletions handler.rs
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,317 @@
    use super::{
    protocol::{self, Requester, Responder},
    ProtocolError,
    };
    use crate::Codec;
    use futures::{
    channel::mpsc, future::BoxFuture, stream::FuturesUnordered, AsyncWriteExt, FutureExt, SinkExt, StreamExt,
    };
    use libp2p::{
    core::{ConnectedPoint, UpgradeError},
    swarm::{
    protocols_handler::{InboundUpgradeSend, OutboundUpgradeSend},
    IntoProtocolsHandler, KeepAlive, ProtocolsHandler, ProtocolsHandlerEvent, ProtocolsHandlerUpgrErr,
    SubstreamProtocol,
    },
    PeerId,
    };
    use std::{
    any::Any,
    collections::VecDeque,
    fmt::Debug,
    io::ErrorKind,
    marker::PhantomData,
    sync::Arc,
    task::{Context, Poll},
    time::Duration,
    };

    #[derive(Debug, PartialEq)]
    pub enum Response<T> {
    Msg(T),
    Error(ProtocolError),
    Finished,
    }

    impl<T> Response<T> {
    pub fn into_msg(self) -> Result<T, ProtocolError> {
    match self {
    Response::Msg(msg) => Ok(msg),
    Response::Error(e) => Err(e),
    Response::Finished => Err(ProtocolError::Io(ErrorKind::UnexpectedEof.into())),
    }
    }
    }

    pub struct Request<T: Codec> {
    request: T::Request,
    channel: mpsc::Sender<Response<T::Response>>,
    }

    impl<T: Codec> Request<T> {
    pub fn new(request: T::Request, channel: mpsc::Sender<Response<T::Response>>) -> Self {
    Self { request, channel }
    }
    }

    impl<T: Codec> Debug for Request<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
    f.debug_struct("Request").field("request", &self.request).finish()
    }
    }

    pub struct RequestReceived<T: Codec> {
    pub(crate) request: T::Request,
    pub(crate) channel: mpsc::Sender<T::Response>,
    }

    impl<T: Codec> Debug for RequestReceived<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
    f.debug_struct("RequestReceived")
    .field("request", &self.request)
    .finish()
    }
    }

    pub struct IntoHandler<T> {
    spawner: Spawner,
    max_message_size: u32,
    request_timeout: Duration,
    response_send_buffer_size: usize,
    keep_alive: bool,
    _ph: PhantomData<T>,
    }

    impl<T> IntoHandler<T> {
    pub fn new(
    spawner: Spawner,
    max_message_size: u32,
    request_timeout: Duration,
    response_send_buffer_size: usize,
    keep_alive: bool,
    ) -> Self {
    Self {
    spawner,
    max_message_size,
    request_timeout,
    response_send_buffer_size,
    keep_alive,
    _ph: PhantomData,
    }
    }
    }

    impl<T: Codec + Send + 'static> IntoProtocolsHandler for IntoHandler<T> {
    type Handler = Handler<T>;

    fn into_handler(self, _remote_peer_id: &PeerId, _connected_point: &ConnectedPoint) -> Self::Handler {
    Handler::new(
    self.spawner,
    self.max_message_size,
    self.request_timeout,
    self.response_send_buffer_size,
    self.keep_alive,
    )
    }

    fn inbound_protocol(&self) -> <Self::Handler as ProtocolsHandler>::InboundProtocol {
    Responder::new(self.max_message_size)
    }
    }

    type ProtocolEvent<T> = ProtocolsHandlerEvent<
    Requester<T>,
    mpsc::Sender<Response<<T as Codec>::Response>>,
    RequestReceived<T>,
    ProtocolError,
    >;
    pub type ResponseFuture = BoxFuture<'static, Box<dyn Any + Send + 'static>>;
    pub type Spawner = Arc<dyn Fn(ResponseFuture) -> ResponseFuture + Send + Sync + 'static>;

    pub struct Handler<T: Codec> {
    events: VecDeque<ProtocolEvent<T>>,
    streams: FuturesUnordered<ResponseFuture>,
    spawner: Spawner,
    max_message_size: u32,
    request_timeout: Duration,
    response_send_buffer_size: usize,
    keep_alive: bool,
    }

    impl<T: Codec> Debug for Handler<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
    f.debug_struct("Handler")
    .field("events", &self.events.len())
    .field("streams", &self.streams.len())
    .finish()
    }
    }

    impl<T: Codec> Handler<T> {
    pub fn new(
    spawner: Spawner,
    max_message_size: u32,
    request_timeout: Duration,
    response_send_buffer_size: usize,
    keep_alive: bool,
    ) -> Self {
    Self {
    events: VecDeque::default(),
    streams: FuturesUnordered::default(),
    spawner,
    max_message_size,
    request_timeout,
    response_send_buffer_size,
    keep_alive,
    }
    }
    }

    impl<T: Codec + Send + 'static> ProtocolsHandler for Handler<T> {
    type InEvent = Request<T>;
    type OutEvent = RequestReceived<T>;
    type Error = ProtocolError;
    type InboundProtocol = Responder<T>;
    type OutboundProtocol = Requester<T>;
    type InboundOpenInfo = ();
    type OutboundOpenInfo = mpsc::Sender<Response<T::Response>>;

    fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo> {
    SubstreamProtocol::new(Responder::new(self.max_message_size), ()).with_timeout(self.request_timeout)
    }

    fn inject_fully_negotiated_inbound(
    &mut self,
    protocol: <Self::InboundProtocol as InboundUpgradeSend>::Output,
    _info: Self::InboundOpenInfo,
    ) {
    let (request, mut stream) = protocol;
    let (channel, mut rx) = mpsc::channel(self.response_send_buffer_size);
    let max_message_size = self.max_message_size;
    log::trace!("handler received request");
    let task = (self.spawner)(
    async move {
    log::trace!("starting send loop");
    let mut buffer = Vec::new();
    loop {
    // only flush once we’re going to sleep
    let response = match rx.try_next() {
    Ok(Some(r)) => r,
    Ok(None) => break,
    Err(_) => {
    log::trace!("flushing stream");
    stream.flush().await?;
    match rx.next().await {
    Some(r) => r,
    None => break,
    }
    }
    };
    protocol::write_msg(&mut stream, response, max_message_size, &mut buffer).await?;
    }
    log::trace!("flushing and closing substream");
    protocol::write_finish(&mut stream).await?;
    Result::<_, ProtocolError>::Ok(())
    }
    .map(|res| -> Box<dyn Any + Send + 'static> { Box::new(res) })
    .boxed(),
    );
    self.streams.push(task);
    self.events
    .push_back(ProtocolsHandlerEvent::Custom(RequestReceived { request, channel }));
    }

    fn inject_fully_negotiated_outbound(
    &mut self,
    mut stream: <Self::OutboundProtocol as OutboundUpgradeSend>::Output,
    mut tx: Self::OutboundOpenInfo,
    ) {
    let max_message_size = self.max_message_size;
    let task = (self.spawner)(
    async move {
    log::trace!("starting receive loop");
    let mut buffer = Vec::new();
    loop {
    match protocol::read_msg(&mut stream, max_message_size, &mut buffer)
    .await
    .unwrap_or_else(Response::Error)
    {
    Response::Msg(msg) => {
    tx.feed(Response::Msg(msg)).await?;
    log::trace!("response sent to client code");
    }
    Response::Error(e) => {
    log::debug!("sending substream error {}", e);
    tx.feed(Response::Error(e)).await?;
    return Result::<_, ProtocolError>::Ok(());
    }
    Response::Finished => {
    log::trace!("finishing substream");
    tx.feed(Response::Finished).await?;
    return Ok(());
    }
    }
    }
    }
    .map(|res| -> Box<dyn Any + Send + 'static> { Box::new(res) })
    .boxed(),
    );
    self.streams.push(task);
    }

    fn inject_event(&mut self, command: Self::InEvent) {
    let Request { request, channel } = command;
    log::trace!("requesting {:?}", request);
    self.events.push_back(ProtocolsHandlerEvent::OutboundSubstreamRequest {
    protocol: SubstreamProtocol::new(Requester::new(self.max_message_size, request), channel)
    .with_timeout(self.request_timeout),
    })
    }

    fn inject_dial_upgrade_error(
    &mut self,
    mut tx: Self::OutboundOpenInfo,
    error: ProtocolsHandlerUpgrErr<<Self::OutboundProtocol as OutboundUpgradeSend>::Error>,
    ) {
    let error = match error {
    ProtocolsHandlerUpgrErr::Timeout => ProtocolError::Timeout,
    ProtocolsHandlerUpgrErr::Timer => ProtocolError::Timeout,
    ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Apply(e)) => e,
    ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(e)) => e.into(),
    };
    log::debug!("dial upgrade error: {}", error);
    if let Err(Response::Error(e)) = tx.try_send(Response::Error(error)).map_err(|e| e.into_inner()) {
    log::warn!("cannot send upgrade error to requester: {}", e);
    }
    }

    fn connection_keep_alive(&self) -> KeepAlive {
    if !self.keep_alive && self.streams.is_empty() {
    KeepAlive::No
    } else {
    KeepAlive::Yes
    }
    }

    fn poll(&mut self, cx: &mut Context<'_>) -> Poll<ProtocolEvent<T>> {
    loop {
    if self.streams.is_empty() {
    break;
    }
    if let Poll::Ready(result) = self.streams.poll_next_unpin(cx) {
    // since the set was not empty, this must be a Some()
    if let Some(Err(e)) = result.and_then(|e| e.downcast::<Result<(), ProtocolError>>().ok().map(|b| *b)) {
    // no need to tear down the connection, substream is already closed
    log::warn!("error in substream task: {}", e);
    }
    } else {
    break;
    }
    }

    match self.events.pop_front() {
    Some(e) => Poll::Ready(e),
    None => Poll::Pending,
    }
    }
    }
    102 changes: 102 additions & 0 deletions mod.rs
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,102 @@
    mod behaviour;
    mod handler;
    mod protocol;
    #[cfg(test)]
    mod tests;

    pub use behaviour::StreamingResponse;
    pub use handler::{Response, ResponseFuture, Spawner};
    pub use protocol::ProtocolError;

    use crate::Codec;
    use futures::channel::mpsc;
    use libp2p::{core::connection::ConnectionId, PeerId};
    use std::{fmt::Debug, sync::Arc, time::Duration};

    pub struct RequestReceived<T: Codec> {
    pub peer_id: PeerId,
    pub connection: ConnectionId,
    pub request: T::Request,
    pub channel: mpsc::Sender<T::Response>,
    }

    impl<T: Codec> Debug for RequestReceived<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
    f.debug_struct("RequestReceived")
    .field("peer_id", &self.peer_id)
    .field("connection", &self.connection)
    .field("request", &self.request)
    .finish()
    }
    }

    pub struct StreamingResponseConfig {
    pub spawner: Spawner,
    pub request_timeout: Duration,
    pub max_message_size: u32,
    pub response_send_buffer_size: usize,
    pub keep_alive: bool,
    }

    impl StreamingResponseConfig {
    /// Spawn response stream handling tasks using the given function
    ///
    /// This function may be called from an arbitrary context, you cannot assume that because
    /// you’re using Tokio this will happen on a Tokio thread. Hence it is necessary to point
    /// to the target thread pool directly, e.g. by using a runtime handle.
    ///
    /// If this method is not used, tasks will be polled via the Swarm, which may be an I/O
    /// bottleneck.
    pub fn with_spawner(self, spawner: impl Fn(ResponseFuture) -> ResponseFuture + Send + Sync + 'static) -> Self {
    Self {
    spawner: Arc::new(spawner),
    ..self
    }
    }
    /// Timeout for the transmission of the request to the peer, default is 10sec
    pub fn with_request_timeout(self, request_timeout: Duration) -> Self {
    Self {
    request_timeout,
    ..self
    }
    }
    /// Maximum message size permitted for requests and responses
    ///
    /// The maximum is 4GiB, the default 1MB. Sending huge messages requires corresponding
    /// buffers and may not be desirable.
    pub fn with_max_message_size(self, max_message_size: u32) -> Self {
    Self {
    max_message_size,
    ..self
    }
    }
    /// Set the queue size in messages for the channel created for incoming requests
    ///
    /// All channels are bounded in size and use back-pressure. This channel size allows some
    /// decoupling between response generation and network transmission. Default is 128.
    pub fn with_response_send_buffer_size(self, response_send_buffer_size: usize) -> Self {
    Self {
    response_send_buffer_size,
    ..self
    }
    }
    /// If this is set to true, then this behaviour will keep the connection alive
    ///
    /// Otherwise the connection is released (i.e. closed if no other behaviour keeps it alive)
    /// when there are no active requests ongoing. Default is `false`.
    pub fn with_keep_alive(self, keep_alive: bool) -> Self {
    Self { keep_alive, ..self }
    }
    }

    impl Default for StreamingResponseConfig {
    fn default() -> Self {
    Self {
    spawner: Arc::new(|f| f),
    request_timeout: Duration::from_secs(10),
    max_message_size: 1_000_000,
    response_send_buffer_size: 128,
    keep_alive: false,
    }
    }
    }
    277 changes: 277 additions & 0 deletions protocol.rs
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,277 @@
    use super::handler::Response;
    use crate::Codec;
    use derive_more::{Display, Error, From};
    use futures::{channel::mpsc, future::BoxFuture, AsyncReadExt, AsyncWriteExt, FutureExt};
    use libp2p::{
    core::{upgrade::NegotiationError, UpgradeInfo},
    swarm::NegotiatedSubstream,
    InboundUpgrade, OutboundUpgrade,
    };
    use serde::de::DeserializeOwned;
    use std::{
    fmt::{Display, Write},
    io::ErrorKind,
    iter::{once, Once},
    marker::PhantomData,
    };

    #[derive(Error, Display, Debug, From)]
    pub enum ProtocolError {
    #[display(fmt = "timeout while waiting for request receive confirmation")]
    Timeout,
    #[display(fmt = "message too large received: {}", _0)]
    #[from(ignore)]
    MessageTooLargeRecv(#[error(ignore)] usize),
    #[display(fmt = "message too large sent: {}", _0)]
    #[from(ignore)]
    MessageTooLargeSent(#[error(ignore)] usize),
    #[display(fmt = "substream protocol negotiation error: {}", _0)]
    Negotiation(NegotiationError),
    #[display(fmt = "I/O error: {}", _0)]
    Io(std::io::Error),
    #[display(fmt = "(de)serialisation error: {}", _0)]
    Serde(serde_cbor::Error),
    #[display(fmt = "internal channel error")]
    Channel(mpsc::SendError),
    /// This variant is useful for implementing the function to pass to
    /// [`with_spawner`](crate::v2::StreamingResponseConfig)
    #[display(fmt = "spawned task failed (cancelled={})", _0)]
    JoinError(#[error(ignore)] bool),
    }

    impl PartialEq for ProtocolError {
    fn eq(&self, other: &Self) -> bool {
    match (self, other) {
    (Self::MessageTooLargeRecv(l0), Self::MessageTooLargeRecv(r0)) => l0 == r0,
    (Self::MessageTooLargeSent(l0), Self::MessageTooLargeSent(r0)) => l0 == r0,
    (Self::Negotiation(l0), Self::Negotiation(r0)) => l0.to_string() == r0.to_string(),
    (Self::Io(l0), Self::Io(r0)) => l0.to_string() == r0.to_string(),
    (Self::Serde(l0), Self::Serde(r0)) => l0.to_string() == r0.to_string(),
    (Self::Channel(l0), Self::Channel(r0)) => l0 == r0,
    (Self::JoinError(l0), Self::JoinError(r0)) => l0 == r0,
    _ => core::mem::discriminant(self) == core::mem::discriminant(other),
    }
    }
    }

    impl ProtocolError {
    pub fn as_code(&self) -> u8 {
    match self {
    ProtocolError::Timeout => 1,
    ProtocolError::MessageTooLargeRecv(_) => 2,
    ProtocolError::MessageTooLargeSent(_) => 3,
    ProtocolError::Negotiation(_) => 4,
    ProtocolError::Io(_) => 5,
    ProtocolError::Serde(_) => 6,
    ProtocolError::Channel(_) => 7,
    ProtocolError::JoinError(_) => 8,
    }
    }
    pub fn from_code(code: u8) -> Self {
    match code {
    1 => ProtocolError::Timeout,
    2 => ProtocolError::MessageTooLargeRecv(0),
    3 => ProtocolError::MessageTooLargeSent(0),
    4 => ProtocolError::Negotiation(NegotiationError::Failed),
    5 => ProtocolError::Io(std::io::Error::new(ErrorKind::Other, "some error on peer")),
    6 => ProtocolError::Serde(std::io::Error::new(ErrorKind::Other, "serde error on peer").into()),
    7 => {
    let (mut tx, _) = mpsc::channel(1);
    let err = tx.try_send(0).unwrap_err().into_send_error();
    ProtocolError::Channel(err)
    }
    8 => ProtocolError::JoinError(false),
    n => ProtocolError::Io(std::io::Error::new(
    ErrorKind::Other,
    format!("unknown error code {}", n),
    )),
    }
    }
    }

    pub async fn write_msg(
    io: &mut NegotiatedSubstream,
    msg: impl serde::Serialize,
    max_size: u32,
    buffer: &mut Vec<u8>,
    ) -> Result<(), ProtocolError> {
    buffer.resize(4, 0);
    let res = serde_cbor::to_writer(&mut *buffer, &msg);
    if let Err(e) = res {
    let err = ProtocolError::Serde(e);
    write_err(io, &err).await?;
    return Err(err);
    }
    let size = buffer.len() - 4;
    if size > (max_size as usize) {
    log::debug!("message size {} too large (max = {})", size, max_size);
    let err = ProtocolError::MessageTooLargeSent(size);
    write_err(io, &err).await?;
    return Err(err);
    }
    log::trace!("sending message of size {}", size);
    buffer.as_mut_slice()[..4].copy_from_slice(&(size as u32).to_be_bytes());
    io.write_all(buffer.as_slice()).await?;
    Ok(())
    }

    pub async fn write_err(io: &mut NegotiatedSubstream, err: &ProtocolError) -> Result<(), std::io::Error> {
    let buf = [255, err.as_code()];
    io.write_all(&buf).await?;
    io.flush().await?;
    io.close().await?;
    Ok(())
    }

    pub async fn write_finish(io: &mut NegotiatedSubstream) -> Result<(), std::io::Error> {
    let buf = [255, 0];
    io.write_all(&buf).await?;
    io.flush().await?;
    io.close().await?;
    Ok(())
    }

    pub async fn read_msg<T: DeserializeOwned>(
    io: &mut NegotiatedSubstream,
    max_size: u32,
    buffer: &mut Vec<u8>,
    ) -> Result<Response<T>, ProtocolError> {
    let mut size_bytes = [0u8; 4];
    let mut to_read = &mut size_bytes[..];
    while !to_read.is_empty() {
    let read = io.read(to_read).await?;
    log::trace!("read {} header bytes", read);
    if read == 0 {
    let len = to_read.len();
    let read = &size_bytes[..4 - len];
    if read.len() != 2 || read[0] != 255 {
    return Err(ProtocolError::Io(ErrorKind::UnexpectedEof.into()));
    } else {
    return match read[1] {
    0 => Ok(Response::Finished),
    n => Err(ProtocolError::from_code(n)),
    };
    }
    }
    to_read = to_read.split_at_mut(read).1;
    }
    let size = u32::from_be_bytes(size_bytes);

    if size > max_size {
    log::debug!("message size {} too large (max = {})", size, max_size);
    let mut bytes = [0u8; 4096];
    bytes[..4].copy_from_slice(&size_bytes);
    let n = io.read(&mut bytes[4..]).await?;
    log::debug!("{:?}", &bytes[..n + 4]);
    return Err(ProtocolError::MessageTooLargeRecv(size as usize));
    }
    log::trace!("received header: msg is {} bytes", size);

    buffer.resize(size as usize, 0);
    io.read_exact(buffer.as_mut_slice()).await?;
    log::trace!("all bytes read");
    Ok(Response::Msg(serde_cbor::from_slice(buffer.as_slice())?))
    }

    #[derive(Debug)]
    pub struct Responder<T> {
    max_message_size: u32,
    _ph: PhantomData<T>,
    }

    impl<T> Responder<T> {
    pub fn new(max_message_size: u32) -> Self {
    Self {
    max_message_size,
    _ph: PhantomData,
    }
    }
    }

    impl<T: Codec> UpgradeInfo for Responder<T> {
    type Info = &'static [u8];
    type InfoIter = Once<&'static [u8]>;

    fn protocol_info(&self) -> Self::InfoIter {
    once(T::protocol_info())
    }
    }

    struct ProtoNameDisplay(&'static [u8]);

    impl Display for ProtoNameDisplay {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
    for byte in self.0 {
    if *byte > 31 && *byte < 128 {
    f.write_char((*byte).into())?;
    } else {
    f.write_char('\u{fffd}')?;
    }
    }
    Ok(())
    }
    }

    impl<T: Codec> InboundUpgrade<NegotiatedSubstream> for Responder<T> {
    type Output = (T::Request, NegotiatedSubstream);
    type Error = ProtocolError;
    type Future = BoxFuture<'static, Result<Self::Output, Self::Error>>;

    fn upgrade_inbound(self, mut socket: NegotiatedSubstream, info: Self::Info) -> Self::Future {
    let max_message_size = self.max_message_size;
    async move {
    log::trace!("starting inbound upgrade `{}`", ProtoNameDisplay(info));
    let msg = read_msg(&mut socket, max_message_size, &mut Vec::new())
    .await?
    .into_msg()?;
    log::trace!("request received: {:?}", msg);
    Ok((msg, socket))
    }
    .boxed()
    }
    }

    #[derive(Debug)]
    pub struct Requester<T: Codec> {
    max_message_size: u32,
    request: T::Request,
    }

    impl<T: Codec> Requester<T> {
    pub fn new(max_message_size: u32, request: T::Request) -> Self {
    Self {
    max_message_size,
    request,
    }
    }
    }

    impl<T: Codec> UpgradeInfo for Requester<T> {
    type Info = &'static [u8];
    type InfoIter = Once<&'static [u8]>;

    fn protocol_info(&self) -> Self::InfoIter {
    once(T::protocol_info())
    }
    }

    impl<T: Codec> OutboundUpgrade<NegotiatedSubstream> for Requester<T> {
    type Output = NegotiatedSubstream;
    type Error = ProtocolError;
    type Future = BoxFuture<'static, Result<Self::Output, Self::Error>>;

    fn upgrade_outbound(self, mut socket: NegotiatedSubstream, info: Self::Info) -> Self::Future {
    let Self {
    max_message_size,
    request,
    } = self;
    async move {
    log::trace!("starting output upgrade `{}`", ProtoNameDisplay(info));
    write_msg(&mut socket, request, max_message_size, &mut Vec::new()).await?;
    socket.flush().await?;
    log::trace!("all bytes sent");
    Ok(socket)
    }
    .boxed()
    }
    }
    340 changes: 340 additions & 0 deletions tests.rs
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,340 @@
    use super::{ProtocolError, StreamingResponse, StreamingResponseConfig};
    use crate::{
    v2::{handler::Response, RequestReceived},
    Codec,
    };
    use futures::{
    channel::mpsc::{self, Receiver, Sender},
    Future, FutureExt, SinkExt, StreamExt,
    };
    use libp2p::{
    core::{transport::MemoryTransport, upgrade::Version},
    identity::Keypair,
    mplex::MplexConfig,
    multiaddr::Protocol,
    plaintext::PlainText2Config,
    swarm::{SwarmBuilder, SwarmEvent},
    Multiaddr, PeerId, Swarm, Transport,
    };
    use tokio::runtime::{Handle, Runtime};
    use tracing_subscriber::{fmt::format::FmtSpan, util::SubscriberInitExt, EnvFilter};

    mod proto;

    const PROTO: &[u8] = b"/my/test";

    fn test_swarm(use_spawner: Option<Handle>) -> Swarm<StreamingResponse<Proto>> {
    let local_key = Keypair::generate_ed25519();
    let local_public_key = local_key.public();
    let local_peer_id = local_public_key.clone().into();
    let transport = MemoryTransport::default()
    .upgrade(Version::V1)
    .authenticate(PlainText2Config { local_public_key })
    .multiplex(MplexConfig::new())
    .boxed();
    let mut config = StreamingResponseConfig::default()
    .with_keep_alive(true)
    .with_max_message_size(100);
    #[allow(clippy::redundant_closure)]
    if let Some(rt) = use_spawner {
    config = config.with_spawner(move |f| rt.spawn(f).map(|r| r.unwrap_or_else(|e| Box::new(e))).boxed());
    }
    let behaviour = StreamingResponse::new(config);
    SwarmBuilder::new(transport, behaviour, local_peer_id).build()
    }

    fn fake_swarm(rt: &Runtime, bytes: &[u8]) -> Swarm<proto::TestBehaviour> {
    let local_key = Keypair::generate_ed25519();
    let local_public_key = local_key.public();
    let local_peer_id = local_public_key.clone().into();
    let transport = MemoryTransport::default()
    .upgrade(Version::V1)
    .authenticate(PlainText2Config { local_public_key })
    .multiplex(MplexConfig::new())
    .boxed();
    let behaviour = proto::TestBehaviour(rt.handle().clone(), bytes.to_owned());
    SwarmBuilder::new(transport, behaviour, local_peer_id).build()
    }

    struct Proto;
    impl Codec for Proto {
    type Request = String;
    type Response = String;

    fn protocol_info() -> &'static [u8] {
    PROTO
    }
    }

    macro_rules! wait4 {
    ($s:ident, $p:pat => $e:expr) => {
    loop {
    let ev = $s.next().await;
    if ev.is_none() {
    panic!("{} STOPPED", stringify!($s))
    }
    let ev = ev.unwrap();
    log::info!("{} got {:?}", stringify!($s), ev);
    if let $p = ev {
    break $e;
    }
    }
    };
    }

    macro_rules! task {
    ($s:ident $(, $p:pat => $e:expr)*) => {
    tokio::spawn(async move {
    while let Some(ev) = $s.next().await {
    log::info!("{} got {:?}", stringify!($s), ev);
    match ev {
    $($p => ($e),)*
    _ => {}
    }
    }
    log::info!("{} STOPPED", stringify!($s));
    })
    };
    }

    fn dbg<T: std::fmt::Debug>(x: T) -> String {
    format!("{:?}", x)
    }

    fn setup_logger() {
    tracing_subscriber::fmt()
    .with_env_filter(EnvFilter::from_default_env())
    .with_span_events(FmtSpan::ENTER | FmtSpan::CLOSE)
    .finish()
    .try_init()
    .ok();
    }

    #[test]
    fn smoke() {
    setup_logger();
    let rt = Runtime::new().unwrap();
    let mut asker = test_swarm(None);
    let asker_id = *asker.local_peer_id();
    let mut responder = test_swarm(None);
    let responder_id = *responder.local_peer_id();

    asker.listen_on(Multiaddr::empty().with(Protocol::Memory(0))).unwrap();

    rt.block_on(async move {
    let addr = wait4!(asker, SwarmEvent::NewListenAddr { address, .. } => address);

    responder.dial(addr).unwrap();
    task!(responder,
    SwarmEvent::Behaviour(RequestReceived { request, peer_id, mut channel, .. }) => {
    tokio::spawn(async move {
    channel.feed(request).await.unwrap();
    channel.feed(peer_id.to_string()).await.unwrap();
    channel.close().await.unwrap();
    });
    }
    );

    let peer_id = wait4!(asker, SwarmEvent::ConnectionEstablished { peer_id, .. } => peer_id);
    assert_eq!(peer_id, responder_id);

    let (tx, rx) = mpsc::channel(10);
    asker.behaviour_mut().request(peer_id, "request".to_owned(), tx);

    task!(asker);

    let response = rx
    .map(|r| match r {
    Response::Msg(m) => Some(m),
    Response::Error(e) => panic!("got error: {:#}", e),
    Response::Finished => None,
    })
    .collect::<Vec<_>>()
    .await;
    assert_eq!(
    response,
    vec![Some("request".to_owned()), Some(asker_id.to_string()), None]
    );
    });
    }

    #[test]
    fn smoke_executor() {
    setup_logger();
    let rt = Runtime::new().unwrap();
    let mut asker = test_swarm(Some(rt.handle().clone()));
    let asker_id = *asker.local_peer_id();
    let mut responder = test_swarm(Some(rt.handle().clone()));
    let responder_id = *responder.local_peer_id();

    asker.listen_on(Multiaddr::empty().with(Protocol::Memory(0))).unwrap();

    rt.block_on(async move {
    let addr = wait4!(asker, SwarmEvent::NewListenAddr { address, .. } => address);

    responder.dial(addr).unwrap();
    task!(responder,
    SwarmEvent::Behaviour(RequestReceived { request, peer_id, mut channel, .. }) => {
    tokio::spawn(async move {
    channel.feed(request).await.unwrap();
    channel.feed(peer_id.to_string()).await.unwrap();
    channel.close().await.unwrap();
    });
    }
    );

    let peer_id = wait4!(asker, SwarmEvent::ConnectionEstablished { peer_id, .. } => peer_id);
    assert_eq!(peer_id, responder_id);

    let (tx, rx) = mpsc::channel(10);
    asker.behaviour_mut().request(peer_id, "request".to_owned(), tx);

    task!(asker);

    let response = rx
    .map(|r| match r {
    Response::Msg(m) => Some(m),
    Response::Error(e) => panic!("got error: {:#}", e),
    Response::Finished => None,
    })
    .collect::<Vec<_>>()
    .await;
    assert_eq!(
    response,
    vec![Some("request".to_owned()), Some(asker_id.to_string()), None]
    );
    });
    }

    fn test_setup<F, Fut, L>(request: String, logic: L, f: F)
    where
    F: FnOnce(Receiver<Response<String>>) -> Fut + Send + 'static,
    Fut: Future,
    L: Fn(String, PeerId, Sender<String>) + Send + 'static,
    {
    setup_logger();
    let rt = Runtime::new().unwrap();
    let mut asker = test_swarm(None);
    let mut responder = test_swarm(None);

    rt.block_on(async move {
    responder
    .listen_on(Multiaddr::empty().with(Protocol::Memory(0)))
    .unwrap();
    let addr = wait4!(responder, SwarmEvent::NewListenAddr{ address, .. } => address);
    task!(responder, SwarmEvent::Behaviour(RequestReceived { request, peer_id, channel, .. }) => logic(request, peer_id, channel));
    asker.dial(addr).unwrap();
    let peer_id = wait4!(asker, SwarmEvent::ConnectionEstablished { peer_id, .. } => peer_id);
    let (tx, rx) = mpsc::channel(10);
    asker.behaviour_mut().request(peer_id, request, tx);
    task!(asker);
    f(rx).await;
    });
    }

    fn fake_setup<F, Fut>(bytes: &[u8], f: F)
    where
    F: FnOnce(Receiver<Response<String>>) -> Fut + Send + 'static,
    Fut: Future,
    {
    setup_logger();
    let rt = Runtime::new().unwrap();
    let mut asker = test_swarm(None);
    let mut responder = fake_swarm(&rt, bytes);

    rt.block_on(async move {
    responder
    .listen_on(Multiaddr::empty().with(Protocol::Memory(0)))
    .unwrap();
    let addr = wait4!(responder, SwarmEvent::NewListenAddr{ address, .. } => address);
    task!(responder);
    asker.dial(addr).unwrap();
    let peer_id = wait4!(asker, SwarmEvent::ConnectionEstablished { peer_id, .. } => peer_id);
    let (tx, rx) = mpsc::channel(10);
    asker.behaviour_mut().request(peer_id, "request".to_owned(), tx);
    task!(asker);
    f(rx).await;
    });
    }

    #[test]
    fn err_size() {
    fake_setup(b"zzzz", |mut rx| async move {
    assert_eq!(
    rx.next().await,
    Some(Response::Error(ProtocolError::MessageTooLargeRecv(2054847098)))
    );
    });
    }

    #[test]
    fn err_nothing() {
    fake_setup(b"", |mut rx| async move {
    assert_eq!(dbg(rx.next().await.unwrap()), "Error(Io(Kind(UnexpectedEof)))");
    });
    }

    #[test]
    fn err_incomplete() {
    fake_setup(b"\0\0\0\x05dabcd\0\0\0\x10abcd", |mut rx| async move {
    assert_eq!(rx.next().await, Some(Response::Msg("abcd".to_owned())));
    assert_eq!(dbg(rx.next().await.unwrap()), "Error(Io(Kind(UnexpectedEof)))");
    });
    }

    #[test]
    fn err_no_finish() {
    fake_setup(b"\0\0\0\x05dabcd", |mut rx| async move {
    assert_eq!(rx.next().await, Some(Response::Msg("abcd".to_owned())));
    assert_eq!(dbg(rx.next().await.unwrap()), "Error(Io(Kind(UnexpectedEof)))");
    });
    }

    #[test]
    fn err_deser() {
    fake_setup(b"\0\0\0\x04abcd", |mut rx| async move {
    assert_eq!(
    dbg(rx.next().await),
    "Some(Error(Serde(ErrorImpl { code: TrailingData, offset: 3 })))"
    );
    });
    }

    #[test]
    fn err_response_size() {
    test_setup(
    "123456789012345678901234567890123456789012345678901234567890".to_owned(),
    |mut request, peer_id, mut channel| {
    tokio::spawn(async move {
    request.push_str(&*peer_id.to_string());
    channel.feed(request).await.unwrap();
    });
    },
    |mut rx| async move {
    assert_eq!(
    rx.next().await,
    Some(Response::Error(ProtocolError::MessageTooLargeSent(0)))
    );
    },
    );
    }

    #[test]
    fn err_request_size() {
    test_setup(
    "1234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890"
    .to_owned(),
    |mut request, peer_id, mut channel| {
    tokio::spawn(async move {
    request.push_str(&*peer_id.to_string());
    channel.feed(request).await.unwrap();
    });
    },
    |mut rx| async move {
    assert_eq!(
    rx.next().await,
    Some(Response::Error(ProtocolError::MessageTooLargeSent(102)))
    );
    },
    );
    }