Skip to content

Instantly share code, notes, and snippets.

@tekjar
Created October 29, 2020 16:09
Show Gist options
  • Save tekjar/a5c0f30cb925d21aea370886ddf2b9e9 to your computer and use it in GitHub Desktop.
Save tekjar/a5c0f30cb925d21aea370886ddf2b9e9 to your computer and use it in GitHub Desktop.
rumqttd tls
#[macro_use]
extern crate log;
use futures_util::future::join_all;
use futures_util::sink::Sink;
use futures_util::stream::Stream;
use rumq_core::mqtt4::{codec, Packet};
use tokio::net::TcpListener;
use tokio::sync::mpsc::{channel, Receiver, Sender};
use tokio::task;
use tokio::time::{self, Elapsed};
use tokio_rustls::rustls::internal::pemfile::{certs, pkcs8_private_keys, rsa_private_keys};
use tokio_rustls::rustls::TLSError;
use tokio_rustls::rustls::{AllowAnyAuthenticatedClient, NoClientAuth, RootCertStore, ServerConfig};
use tokio_rustls::TlsAcceptor;
use tokio_util::codec::Framed;
use serde::Deserialize;
use std::fs::File;
use std::io::{self, BufReader};
use std::path::Path;
use std::sync::Arc;
use std::thread;
use std::time::Duration;
mod connection;
mod router;
mod state;
pub use router::{Connection, RouterMessage};
pub use rumq_core as core;
/// Errors produced by the broker.
///
/// Display texts are deliberately short; for the wrapping variants the underlying error
/// is exposed through `std::error::Error::source` (thiserror's `#[from]` implies `#[source]`).
#[derive(Debug, thiserror::Error)]
pub enum Error {
/// Underlying I/O failure, e.g. opening a certificate/key file or binding the listener.
#[error("I/O")]
Io(#[from] io::Error),
/// Error reported by the `rumq_core` MQTT implementation.
#[error("MQTT protocol error")]
Mqtt(#[from] rumq_core::Error),
/// A `tokio::time` timeout elapsed.
/// NOTE(review): not raised in this file; presumably produced by the connection code.
#[error("Timeout")]
Timeout(#[from] Elapsed),
/// Error from the broker's `state` module.
#[error("Broker State")]
State(#[from] state::Error),
/// rustls rejected the TLS configuration, e.g. the certificate/key pair given to
/// `ServerConfig::set_single_cert`.
#[error("TLS")]
Tls(#[from] TLSError),
/// No usable server certificate was available.
#[error("No server cert")]
NoServerCert,
/// No usable server private key was available, e.g. `cert_path` configured without `key_path`.
#[error("No server private key")]
NoServerPrivateKey,
/// The CA PEM file used for client authentication could not be loaded.
#[error("No ca file")]
NoCAFile,
/// The server certificate PEM file could not be parsed.
#[error("No server cert file")]
NoServerCertFile,
/// The server private key PEM file could not be parsed.
#[error("No server key file")]
NoServerKeyFile,
/// Connection closed.
/// NOTE(review): meaning inferred from the name; not constructed in this file.
#[error("Disconnected")]
Disconnected,
}
/// Top-level broker configuration, deserialized with serde.
#[derive(Debug, Deserialize, Clone)]
pub struct Config {
/// One entry per listener; `Broker::start` spawns an accept loop for each entry.
servers: Vec<ServerSettings>,
}
/// Settings for a single listener (one entry of `Config::servers`).
///
/// Only `port`, `next_connection_delay_ms`, `max_payload_size`, `ca_path`, `cert_path`
/// and `key_path` are read in this file; the whole struct is also handed (inside an `Arc`)
/// to `connection::eventloop`. Notes on the other fields are inferred from their names.
#[derive(Debug, Deserialize, Clone)]
pub struct ServerSettings {
/// TCP port; the listener binds `0.0.0.0:<port>`.
pub port: u16,
/// NOTE(review): presumably how long a client has to send CONNECT — confirm in `connection`.
pub connection_timeout_ms: u16,
/// Milliseconds `accept_loop` sleeps after handing off an accepted connection (rate limit).
pub next_connection_delay_ms: u64,
/// NOTE(review): presumably the maximum accepted client-id length — confirm in `connection`.
pub max_client_id_len: usize,
/// NOTE(review): not enforced by `accept_loop`; presumably a cap on concurrent connections.
pub max_connections: usize,
/// NOTE(review): not read in this file; presumably toggles on-disk persistence.
pub disk_persistence: bool,
/// NOTE(review): not read in this file; presumably a throttling delay in milliseconds.
pub throttle_delay_ms: u64,
/// NOTE(review): not read in this file; presumably a disk retention size limit.
pub disk_retention_size: usize,
/// NOTE(review): not read in this file; presumably a disk retention period in seconds.
pub disk_retention_time_sec: usize,
/// NOTE(review): not read in this file; presumably a persistence save interval in seconds.
pub auto_save_interval_sec: u16,
/// Maximum payload size passed to `codec::MqttCodec::new` for every connection.
pub max_payload_size: usize,
/// NOTE(review): not read in this file; meaning to be confirmed in the router/state code.
pub max_inflight_topic_size: usize,
/// CA certificates (PEM). When set together with `cert_path`, clients must present a
/// certificate signed by one of these CAs.
pub ca_path: Option<String>,
/// Server certificate chain (PEM). Setting it enables TLS on this listener.
pub cert_path: Option<String>,
/// Server private key (PEM). Required whenever `cert_path` is set.
pub key_path: Option<String>,
/// NOTE(review): not read in this file; presumably client credentials checked elsewhere.
pub username: Option<String>,
/// NOTE(review): not read in this file; presumably client credentials checked elsewhere.
pub password: Option<String>,
}
async fn tls_connection<P: AsRef<Path>>(ca_path: Option<P>, cert_path: P, key_path: P) -> Result<TlsAcceptor, Error> {
// client authentication with a CA. CA isn't required otherwise
let mut server_config = if let Some(ca_path) = ca_path {
let mut root_cert_store = RootCertStore::empty();
root_cert_store.add_pem_file(&mut BufReader::new(File::open(ca_path)?)).map_err(|_| Error::NoCAFile)?;
ServerConfig::new(AllowAnyAuthenticatedClient::new(root_cert_store))
} else {
ServerConfig::new(NoClientAuth::new())
};
let certs = certs(&mut BufReader::new(File::open(cert_path)?)).map_err(|_| Error::NoServerCertFile)?;
let mut keys = rsa_private_keys(&mut BufReader::new(File::open(key_path)?)).map_err(|_| Error::NoServerKeyFile)?;
server_config.set_single_cert(certs, keys.remove(0))?;
let acceptor = TlsAcceptor::from(Arc::new(server_config));
Ok(acceptor)
}
async fn accept_loop(config: Arc<ServerSettings>, router_tx: Sender<(String, router::RouterMessage)>) -> Result<(), Error> {
let addr = format!("0.0.0.0:{}", config.port);
let connection_config = config.clone();
let acceptor = if let Some(cert_path) = config.cert_path.clone() {
let key_path = config.key_path.clone().ok_or(Error::NoServerPrivateKey)?;
Some(tls_connection(config.ca_path.clone(), cert_path, key_path).await?)
} else {
None
};
info!("Waiting for connections on {}", addr);
// eventloop which accepts connections
let mut listener = TcpListener::bind(addr).await?;
let accept_loop_delay = Duration::from_millis(config.next_connection_delay_ms);
loop {
let (stream, addr) = match listener.accept().await {
Ok(s) => s,
Err(e) => {
error!("Tcp connection error = {:?}", e);
continue;
}
};
info!("Accepting from: {}", addr);
let config = connection_config.clone();
let router_tx = router_tx.clone();
if let Some(acceptor) = &acceptor {
let stream = match acceptor.accept(stream).await {
Ok(s) => s,
Err(e) => {
error!("Tls connection error = {:?}", e);
continue;
}
};
let framed = Framed::new(stream, codec::MqttCodec::new(config.max_payload_size));
task::spawn(async {
match connection::eventloop(config, framed, router_tx).await {
Ok(id) => info!("Connection eventloop done!!. Id = {:?}", id),
Err(e) => error!("Connection eventloop error = {:?}", e),
}
});
} else {
let framed = Framed::new(stream, codec::MqttCodec::new(config.max_payload_size));
task::spawn(async {
match connection::eventloop(config, framed, router_tx).await {
Ok(id) => info!("Connection eventloop done!!. Id = {:?}", id),
Err(e) => error!("Connection eventloop error = {:?}", e),
}
});
};
time::delay_for(accept_loop_delay).await;
}
}
pub trait Network: Stream<Item = Result<Packet, rumq_core::Error>> + Sink<Packet, Error = io::Error> + Unpin + Send {}
impl<T> Network for T where T: Stream<Item = Result<Packet, rumq_core::Error>> + Sink<Packet, Error = io::Error> + Unpin + Send {}
#[tokio::main(core_threads = 1)]
async fn router(rx: Receiver<(String, router::RouterMessage)>) {
let mut router = router::Router::new(rx);
if let Err(e) = router.start().await {
error!("Router stopped. Error = {:?}", e);
}
}
/// Broker handle returned by [`new`]: the listener configuration plus the input side of the
/// router's channel.
pub struct Broker {
/// Listener settings; `Broker::start` moves the `servers` list out of it.
config: Config,
/// Sender half of the channel whose receiver is owned by the router thread started in [`new`].
router_handle: Sender<(String, router::RouterMessage)>,
}
/// Creates a [`Broker`] for `config` and starts its router.
///
/// The router runs on a dedicated OS thread (the `router` function carries its own
/// `#[tokio::main]` runtime). It is fed through a bounded channel of capacity 100, and the
/// returned broker keeps the sending half. The thread's `JoinHandle` is dropped, so the
/// thread is detached.
pub fn new(config: Config) -> Broker {
    let (router_handle, rx) = channel(100);
    let _detached = thread::spawn(move || router(rx));
    Broker { router_handle, config }
}
impl Broker {
pub fn new_router_handle(&self) -> Sender<(String, router::RouterMessage)> {
self.router_handle.clone()
}
pub async fn start(&mut self) -> Vec<Result<(), task::JoinError>> {
let mut servers = Vec::new();
let server_configs = self.config.servers.split_off(0);
for server in server_configs.into_iter() {
let config = Arc::new(server);
let fut = accept_loop(config, self.router_handle.clone());
let o = task::spawn(async {
error!("Accept loop returned = {:?}", fut.await);
});
servers.push(o);
}
join_all(servers).await
}
}
#[cfg(test)]
mod test {
    // Placeholders for behaviour that has no test coverage yet. The bodies are empty, so
    // they are marked `#[ignore]`: otherwise they would be reported as passing and suggest
    // coverage that does not exist. Implement them, then drop the attribute
    // (`cargo test -- --ignored` lists them meanwhile).

    #[test]
    #[ignore]
    fn accept_loop_rate_limits_incoming_connections() {}

    #[test]
    #[ignore]
    fn accept_loop_should_not_allow_more_than_maximum_connections() {}

    #[test]
    #[ignore]
    fn accept_loop_should_accept_new_connection_when_a_client_disconnects_after_max_connections() {}

    #[test]
    #[ignore]
    fn client_loop_should_error_if_connect_packet_is_not_received_in_time() {}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment