commit 4688f8bd4343f9f6c1da0ff9424486e2566aec8b Author: Olivier 'reivilibre Date: Mon Jun 13 22:09:48 2022 +0100 Initial commit diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..b1d23a1 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,26 @@ +[package] +name = "bare_cnr" +version = "0.1.0-alpha.1" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +thiserror = "1.0.31" +serde = "1.0.137" +serde_bare = "0.5.0" +dashmap = "5.3.4" +log = "0.4.17" +rand = "0.8.5" + +[dev-dependencies] +tokio-test = "0.4.2" + +[dependencies.tokio] +# TODO optional = true +version = "1.18.2" +features = ["full"] # TODO restrict this + + +[features] +# TODO default = ["tokio"] \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..d7daa50 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,23 @@ +use thiserror::Error; + +mod multiplexer; + +#[derive(Error, Debug)] +pub enum CnrError { + #[error("Version mismatch.")] + VersionMismatch, + + #[error("Ser/deserialisation error.")] + Serde(#[from] serde_bare::error::Error), + + #[error("Input/output error.")] + Io(#[from] std::io::Error), + + #[error("Channel closed.")] + Closed, +} + +pub use multiplexer::Channel; +pub use multiplexer::ChannelHandle; +pub use multiplexer::ChannelLock; +pub use multiplexer::TransportMultiplexer; diff --git a/src/multiplexer.rs b/src/multiplexer.rs new file mode 100644 index 0000000..ccac1dc --- /dev/null +++ b/src/multiplexer.rs @@ -0,0 +1,332 @@ +use crate::CnrError; +use dashmap::mapref::entry::Entry; +use dashmap::DashMap; +use log::{error, warn}; +use serde::de::DeserializeOwned; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use std::cell::RefCell; +use std::marker::PhantomData; +use std::sync::{Arc, Weak}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +use tokio::sync::mpsc::{Receiver, Sender}; +use tokio::sync::{Mutex, MutexGuard}; +use tokio::task::JoinHandle; + +pub struct TransportMultiplexer { + task_tx: JoinHandle<()>, + task_rx: JoinHandle<()>, + + /// Whether this side is the initiator. + /// The initiator can open even-numbered channels. + initiator: bool, + + /// Queue to send new bytes out on the wire + tx_queue: Sender<(u16, Vec)>, + + /// Senders for sending messages from the wire to channels + channels: Arc>>>, + + /// Channel receivers for channels that have received messages but haven't yet been claimed + unclaimed_channels: Arc>>>, +} + +thread_local! { + static CURRENT_MULTIPLEXER: RefCell> = Default::default(); +} + +impl TransportMultiplexer { + pub fn new( + rx: R, + tx: W, + initiator: bool, + ) -> Result, CnrError> { + let (txqueue_tx, txqueue_rx) = tokio::sync::mpsc::channel(8); + + let channels = Arc::new(Default::default()); + let unclaimed_channels = Arc::new(Default::default()); + + let task_tx = tokio::spawn(async move { + if let Err(err) = TransportMultiplexer::handle_tx(tx, txqueue_rx).await { + error!("TX handler failed: {:?}", err); + } + }); + let channels2 = Arc::clone(&channels); + let unclaimed_channels2 = Arc::clone(&unclaimed_channels); + let task_rx = tokio::spawn(async move { + if let Err(err) = + TransportMultiplexer::handle_rx(rx, channels2, unclaimed_channels2).await + { + error!("RX handler failed: {:?}", err); + } + }); + + Ok(Arc::new(TransportMultiplexer { + task_tx, + task_rx, + initiator, + tx_queue: txqueue_tx, + channels, + unclaimed_channels, + })) + } + + pub fn open_channel_with_id(self: &Arc, channel_id: u16) -> Option> { + match self.channels.entry(channel_id) { + Entry::Occupied(_) => { + match self.unclaimed_channels.remove(&channel_id) { + Some((_k, chan_rx)) => Some(Channel { + id: channel_id, + inner: Arc::new(Mutex::new(ChannelInner { + chan_id: channel_id, + tx: self.tx_queue.clone(), + rx: chan_rx, + multiplexer: Arc::downgrade(&self), + })), + marker: Default::default(), + }), + None => { + // Channel ID already in use. + None + } + } + } + Entry::Vacant(ve) => { + let (chan_tx, chan_rx) = tokio::sync::mpsc::channel(8); + ve.insert(chan_tx); + Some(Channel { + id: channel_id, + inner: Arc::new(Mutex::new(ChannelInner { + chan_id: channel_id, + tx: self.tx_queue.clone(), + rx: chan_rx, + multiplexer: Arc::downgrade(&self), + })), + marker: Default::default(), + }) + } + } + } + + pub fn open_unused_channel(self: &Arc) -> Option> { + loop { + let id: u16 = rand::random(); + let id = if self.initiator { id & !1 } else { id | 1 }; + if let Some(chan) = self.open_channel_with_id(id) { + return Some(chan); + } + } + } + + async fn handle_tx( + mut tx: W, + mut txqueue_rx: Receiver<(u16, Vec)>, + ) -> Result<(), CnrError> { + while let Some((chan_id, next_msg)) = txqueue_rx.recv().await { + tx.write_u16(chan_id).await?; + tx.write_u32(next_msg.len().try_into().unwrap()).await?; + tx.write_all(&next_msg).await?; + // TODO(performance) would be nice not to flush if something else is just behind. + tx.flush().await?; + } + Ok(()) + } + + async fn handle_rx( + mut rx: R, + channels: Arc>>>, + unclaimed_channels: Arc>>>, + ) -> Result<(), CnrError> { + loop { + // TODO(features): need to be able to support graceful EOF + let chan_id = rx.read_u16().await?; + let length = rx.read_u32().await? as usize; + // TODO(perf): use uninit? + let mut buf = vec![0u8; length]; + + rx.read_exact(&mut buf[..]).await?; + + match channels.entry(chan_id) { + Entry::Occupied(oe) => { + if let Err(err) = oe.get().send(buf).await { + // TODO this channel has died. What can we do about it? + warn!("Message received but channel {} dead", chan_id); + // TODO maybe we should clean it up at this point... + } + } + Entry::Vacant(ve) => { + let (chan_tx, chan_rx) = tokio::sync::mpsc::channel(8); + unclaimed_channels.insert(chan_id, chan_rx); + chan_tx.try_send(buf).expect("empty channel succeeds"); + ve.insert(chan_tx); + } + } + } + } +} + +pub struct ChannelHandle { + marker: PhantomData<(LTR, RTL)>, + chan_id: u16, + multiplexer: Weak, +} + +impl ChannelHandle { + pub fn into_channel(self) -> Option> { + let multiplexer = self.multiplexer.upgrade()?; + multiplexer.open_channel_with_id(self.chan_id) + } +} + +//impl Serialize for ChannelHandle { +impl Serialize for ChannelHandle { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + self.chan_id.serialize(serializer) + } +} + +//impl<'de, LTR, RTL, const INITIATOR: bool> Deserialize<'de> for ChannelHandle { +impl<'de, LTR, RTL> Deserialize<'de> for ChannelHandle { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + // TODO We could check the initiator flags are the right way around by embedding a bool in + // the format. + let chan_id = u16::deserialize(deserializer)?; + Ok(ChannelHandle { + marker: Default::default(), + chan_id, + multiplexer: CURRENT_MULTIPLEXER.with(|refcell| refcell.borrow().clone()), + }) + } +} + +struct ChannelInner { + chan_id: u16, + tx: Sender<(u16, Vec)>, + rx: Receiver>, + multiplexer: Weak, +} + +impl Drop for ChannelInner { + fn drop(&mut self) { + if let Some(multiplexer) = self.multiplexer.upgrade() { + multiplexer.channels.remove(&self.chan_id); + } + } +} + +pub struct ChannelLock<'a, R, W> { + // TODO we could use an OwnedMutexGuard if needed + guard: MutexGuard<'a, ChannelInner>, + marker: PhantomData<(R, W)>, +} + +impl<'a, R: DeserializeOwned, W: Serialize> ChannelLock<'a, R, W> { + pub async fn send(&mut self, message: &W) -> Result<(), CnrError> { + let bytes = serde_bare::to_vec(message)?; + self.guard + .tx + .send((self.guard.chan_id, bytes)) + .await + .map_err(|_| CnrError::Closed)?; + Ok(()) + } + + pub async fn recv(&mut self) -> Result { + let new_weak = self.guard.multiplexer.clone(); + CURRENT_MULTIPLEXER.with(move |refcell| { + *(refcell.borrow_mut()) = new_weak; + }); + let res = match self.guard.rx.recv().await { + Some(bytes) => Ok(serde_bare::from_slice(&bytes)?), + None => Err(CnrError::Closed), + }; + CURRENT_MULTIPLEXER.with(|refcell| { + *(refcell.borrow_mut()) = Weak::new(); + }); + res + } +} + +pub struct Channel { + pub id: u16, + inner: Arc>, + marker: PhantomData<(R, W)>, +} + +impl Channel { + /// Channel handles only get sent in one direction. + pub fn handle(&self) -> ChannelHandle { + ChannelHandle { + marker: Default::default(), + chan_id: self.id, + multiplexer: Default::default(), + } + } +} + +impl Channel { + pub async fn lock(&self) -> ChannelLock<'_, R, W> { + ChannelLock { + guard: self.inner.lock().await, + marker: Default::default(), + } + } + + pub async fn send(&mut self, message: &W) -> Result<(), CnrError> { + self.lock().await.send(message).await + } + + pub async fn recv(&mut self) -> Result { + self.lock().await.recv().await + } +} + +#[cfg(test)] +mod test { + use crate::multiplexer::{Channel, ChannelHandle, TransportMultiplexer}; + + /// Tests that data reaches the other side! + #[tokio::test] + async fn test_data_reaches_the_other_side() { + let (commander, responder) = tokio::io::duplex(64); + let (commander_r, commander_w) = tokio::io::split(commander); + let (responder_r, responder_w) = tokio::io::split(responder); + let commander = TransportMultiplexer::new(commander_r, commander_w, true).unwrap(); + let responder = TransportMultiplexer::new(responder_r, responder_w, false).unwrap(); + + let mut c_chan0: Channel = commander.open_channel_with_id(0).unwrap(); + c_chan0.send(&32).await.unwrap(); + + let mut r_chan0: Channel = responder.open_channel_with_id(0).unwrap(); + assert_eq!(r_chan0.recv().await.unwrap(), 32); + r_chan0.send(&2048).await.unwrap(); + + assert_eq!(c_chan0.recv().await.unwrap(), 2048); + } + + /// Tests that you can hand over a channel just by including a ChannelHandle in the message. + #[tokio::test] + async fn test_channel_handover() { + let (commander, responder) = tokio::io::duplex(64); + let (commander_r, commander_w) = tokio::io::split(commander); + let (responder_r, responder_w) = tokio::io::split(responder); + let commander = TransportMultiplexer::new(commander_r, commander_w, true).unwrap(); + let responder = TransportMultiplexer::new(responder_r, responder_w, false).unwrap(); + + let mut c_chan0: Channel<(), ChannelHandle> = + commander.open_channel_with_id(0).unwrap(); + let mut c_chan2: Channel = commander.open_channel_with_id(2).unwrap(); + c_chan0.send(&c_chan2.handle()).await.unwrap(); + c_chan2.send(&42).await.unwrap(); + + let mut r_chan0: Channel, ()> = + responder.open_channel_with_id(0).unwrap(); + let mut r_chan2 = r_chan0.recv().await.unwrap().into_channel().unwrap(); + assert_eq!(r_chan2.recv().await.unwrap(), 42); + } +}