diff --git a/bare_cnr/Cargo.toml b/bare_cnr/Cargo.toml new file mode 100644 index 0000000..063b170 --- /dev/null +++ b/bare_cnr/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "bare_cnr" +version = "0.1.0-alpha.1" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +thiserror = "1.0.31" +serde = "1.0.137" +serde_bare = "0.5.0" +dashmap = "5.3.4" +log = "0.4.17" + +[dependencies.tokio] +# TODO optional = true +version = "1.18.2" +features = ["full"] # TODO restrict this + + +[features] +# TODO default = ["tokio"] \ No newline at end of file diff --git a/bare_cnr/src/lib.rs b/bare_cnr/src/lib.rs new file mode 100644 index 0000000..13debef --- /dev/null +++ b/bare_cnr/src/lib.rs @@ -0,0 +1,19 @@ +use thiserror::Error; + +mod multiplexer; +mod transport; + +#[derive(Error, Debug)] +pub enum CnrError { + #[error("Version mismatch.")] + VersionMismatch, + + #[error("Ser/deserialisation error.")] + Serde(#[from] serde_bare::error::Error), + + #[error("Input/output error.")] + Io(#[from] std::io::Error), + + #[error("Channel closed.")] + Closed, +} diff --git a/bare_cnr/src/multiplexer.rs b/bare_cnr/src/multiplexer.rs new file mode 100644 index 0000000..a3fa8e9 --- /dev/null +++ b/bare_cnr/src/multiplexer.rs @@ -0,0 +1,207 @@ +use crate::CnrError; +use crate::CnrError::Closed; +use dashmap::mapref::entry::Entry; +use dashmap::DashMap; +use log::error; +use serde::de::DeserializeOwned; +use serde::Serialize; +use std::marker::PhantomData; +use std::sync::Arc; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +use tokio::sync::mpsc::{Receiver, Sender}; +use tokio::sync::{Mutex, MutexGuard}; +use tokio::task::JoinHandle; + +pub struct TransportMultiplexer { + task_tx: JoinHandle<()>, + task_rx: JoinHandle<()>, + + /// Whether this side is the initiator. + /// The initiator can open even-numbered channels. + initiator: bool, + + /// Queue to send new bytes out on the wire + tx_queue: Sender<(u16, Vec)>, + + /// Senders for sending messages from the wire to channels + channels: Arc>>>, + + /// Channel receivers for channels that have received messages but haven't yet been claimed + unclaimed_channels: Arc>>>, + // TODO channels: +} + +impl TransportMultiplexer { + pub fn new( + rx: R, + tx: W, + initiator: bool, + ) -> Result { + let (txqueue_tx, txqueue_rx) = tokio::sync::mpsc::channel(8); + + let channels = Arc::new(Default::default()); + let unclaimed_channels = Arc::new(Default::default()); + + let task_tx = tokio::spawn(async move { + if let Err(err) = TransportMultiplexer::handle_tx(tx, txqueue_rx).await { + error!("TX handler failed: {:?}", err); + } + }); + let channels2 = Arc::clone(&channels); + let unclaimed_channels2 = Arc::clone(&unclaimed_channels); + let task_rx = tokio::spawn(async move { + if let Err(err) = + TransportMultiplexer::handle_rx(rx, channels2, unclaimed_channels2).await + { + error!("RX handler failed: {:?}", err); + } + }); + + Ok(TransportMultiplexer { + task_tx, + task_rx, + initiator, + tx_queue: txqueue_tx, + channels, + unclaimed_channels, + }) + } + + pub fn open_channel_with_id(&self, channel_id: u16) -> Option> { + match self.channels.entry(channel_id) { + Entry::Occupied(oe) => { + match self.unclaimed_channels.remove(&channel_id) { + Some((_k, chan_rx)) => Some(Channel { + id: channel_id, + inner: Arc::new(Mutex::new(ChannelInner { + chan_id: channel_id, + tx: self.tx_queue.clone(), + rx: chan_rx, + })), + marker: Default::default(), + }), + None => { + // Channel ID already in use. + None + } + } + } + Entry::Vacant(ve) => { + let (chan_tx, chan_rx) = tokio::sync::mpsc::channel(8); + ve.insert(chan_tx); + Some(Channel { + id: channel_id, + inner: Arc::new(Mutex::new(ChannelInner { + chan_id: channel_id, + tx: self.tx_queue.clone(), + rx: chan_rx, + })), + marker: Default::default(), + }) + } + } + } + + async fn handle_tx( + mut tx: W, + mut txqueue_rx: Receiver<(u16, Vec)>, + ) -> Result<(), CnrError> { + while let Some((chan_id, next_msg)) = txqueue_rx.recv().await { + tx.write_u16(chan_id).await?; + tx.write_u32(next_msg.len().try_into().unwrap()).await?; + tx.write_all(&next_msg).await?; + // TODO(performance) would be nice not to flush if something else is just behind. + tx.flush().await?; + } + Ok(()) + } + + async fn handle_rx( + mut rx: R, + channels: Arc>>>, + unclaimed_channels: Arc>>>, + ) -> Result<(), CnrError> { + loop { + // TODO(features): need to be able to support graceful EOF + let chan_id = rx.read_u16().await?; + let length = rx.read_u32().await? as usize; + // TODO(perf): use uninit? + let mut buf = vec![0u8; length]; + + rx.read_exact(&mut buf[..]).await?; + + match channels.entry(chan_id) { + Entry::Occupied(oe) => { + if let Err(err) = oe.get().send(buf).await { + // TODO this channel has died. What can we do about it? + todo!(); + } + } + Entry::Vacant(ve) => { + let (chan_tx, chan_rx) = tokio::sync::mpsc::channel(8); + unclaimed_channels.insert(chan_id, chan_rx); + chan_tx.try_send(buf).expect("empty channel succeeds"); + ve.insert(chan_tx); + } + } + } + } +} + +pub struct ChannelHandle { + marker: PhantomData<(LTR, RTL)>, +} + +struct ChannelInner { + chan_id: u16, + tx: Sender<(u16, Vec)>, + rx: Receiver>, +} + +pub struct ChannelLock<'a, R, W> { + // TODO we could use an OwnedMutexGuard if needed + guard: MutexGuard<'a, ChannelInner>, + marker: PhantomData<(R, W)>, +} + +impl<'a, R: DeserializeOwned, W: Serialize> ChannelLock<'a, R, W> { + pub async fn send(&mut self, message: &W) -> Result<(), CnrError> { + let bytes = serde_bare::to_vec(message)?; + self.guard + .tx + .send((self.guard.chan_id, bytes)) + .await + .map_err(|_| CnrError::Closed)?; + Ok(()) + } + + pub async fn recv(&mut self) -> Result { + match self.guard.rx.recv().await { + Some(bytes) => Ok(serde_bare::from_slice(&bytes)?), + None => Err(CnrError::Closed), + } + } +} + +pub struct Channel { + pub id: u16, + inner: Arc>, + marker: PhantomData<(R, W)>, +} + +impl Channel { + pub async fn lock(&self) -> ChannelLock<'_, R, W> { + ChannelLock { + guard: self.inner.lock().await, + marker: Default::default(), + } + } + + pub async fn send(&mut self, message: &W) -> Result<(), CnrError> { + self.lock().await.send(message).await + } + + pub async fn recv(&mut self) -> Result { + self.lock().await.recv().await + } +} diff --git a/bare_cnr/src/transport.rs b/bare_cnr/src/transport.rs new file mode 100644 index 0000000..561e741 --- /dev/null +++ b/bare_cnr/src/transport.rs @@ -0,0 +1,43 @@ +use serde::de::DeserializeOwned; +use serde::Serialize; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; + +use crate::CnrError; + +pub struct BareTransport { + writer: W, + reader: R, +} + +impl BareTransport { + pub fn new(writer: W, reader: R) -> Self { + Self { writer, reader } + } + + pub async fn write_one_message(&mut self, message: &M) -> Result<(), CnrError> { + let bytes = serde_bare::to_vec(message)?; + self.writer + .write_u32(bytes.len().try_into().unwrap()) + .await?; + self.writer.write_all(&bytes).await?; + // TODO flush? + Ok(()) + } + + pub async fn read_one_message(&mut self) -> Result, CnrError> { + let length = self.reader.read_u32().await? as usize; + // TODO(perf): use uninit? + let mut buf = vec![0u8; length]; + + let first_read = self.reader.read(&mut buf[..]).await?; + if first_read == 0 { + return Ok(None); + } + + if first_read != length { + self.reader.read_exact(&mut buf[first_read..]).await?; + } + + return Ok(Some(serde_bare::from_slice(&buf)?)); + } +}