Initial commit

This commit is contained in:
Olivier 'reivilibre' 2022-06-13 22:09:48 +01:00
commit 4688f8bd43
3 changed files with 381 additions and 0 deletions

26
Cargo.toml Normal file
View File

@ -0,0 +1,26 @@
[package]
name = "bare_cnr"
version = "0.1.0-alpha.1"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
thiserror = "1.0.31"
serde = "1.0.137"
serde_bare = "0.5.0"
dashmap = "5.3.4"
log = "0.4.17"
rand = "0.8.5"
[dev-dependencies]
tokio-test = "0.4.2"
[dependencies.tokio]
# TODO optional = true
version = "1.18.2"
features = ["full"] # TODO restrict this
[features]
# TODO default = ["tokio"]

23
src/lib.rs Normal file
View File

@ -0,0 +1,23 @@
use thiserror::Error;
mod multiplexer;
#[derive(Error, Debug)]
pub enum CnrError {
#[error("Version mismatch.")]
VersionMismatch,
#[error("Ser/deserialisation error.")]
Serde(#[from] serde_bare::error::Error),
#[error("Input/output error.")]
Io(#[from] std::io::Error),
#[error("Channel closed.")]
Closed,
}
pub use multiplexer::Channel;
pub use multiplexer::ChannelHandle;
pub use multiplexer::ChannelLock;
pub use multiplexer::TransportMultiplexer;

332
src/multiplexer.rs Normal file
View File

@ -0,0 +1,332 @@
use crate::CnrError;
use dashmap::mapref::entry::Entry;
use dashmap::DashMap;
use log::{error, warn};
use serde::de::DeserializeOwned;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::cell::RefCell;
use std::marker::PhantomData;
use std::sync::{Arc, Weak};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::sync::mpsc::{Receiver, Sender};
use tokio::sync::{Mutex, MutexGuard};
use tokio::task::JoinHandle;
pub struct TransportMultiplexer {
task_tx: JoinHandle<()>,
task_rx: JoinHandle<()>,
/// Whether this side is the initiator.
/// The initiator can open even-numbered channels.
initiator: bool,
/// Queue to send new bytes out on the wire
tx_queue: Sender<(u16, Vec<u8>)>,
/// Senders for sending messages from the wire to channels
channels: Arc<DashMap<u16, Sender<Vec<u8>>>>,
/// Channel receivers for channels that have received messages but haven't yet been claimed
unclaimed_channels: Arc<DashMap<u16, Receiver<Vec<u8>>>>,
}
thread_local! {
static CURRENT_MULTIPLEXER: RefCell<Weak<TransportMultiplexer>> = Default::default();
}
impl TransportMultiplexer {
pub fn new<R: AsyncRead + Send + 'static + Unpin, W: AsyncWrite + Send + 'static + Unpin>(
rx: R,
tx: W,
initiator: bool,
) -> Result<Arc<TransportMultiplexer>, CnrError> {
let (txqueue_tx, txqueue_rx) = tokio::sync::mpsc::channel(8);
let channels = Arc::new(Default::default());
let unclaimed_channels = Arc::new(Default::default());
let task_tx = tokio::spawn(async move {
if let Err(err) = TransportMultiplexer::handle_tx(tx, txqueue_rx).await {
error!("TX handler failed: {:?}", err);
}
});
let channels2 = Arc::clone(&channels);
let unclaimed_channels2 = Arc::clone(&unclaimed_channels);
let task_rx = tokio::spawn(async move {
if let Err(err) =
TransportMultiplexer::handle_rx(rx, channels2, unclaimed_channels2).await
{
error!("RX handler failed: {:?}", err);
}
});
Ok(Arc::new(TransportMultiplexer {
task_tx,
task_rx,
initiator,
tx_queue: txqueue_tx,
channels,
unclaimed_channels,
}))
}
pub fn open_channel_with_id<R, W>(self: &Arc<Self>, channel_id: u16) -> Option<Channel<R, W>> {
match self.channels.entry(channel_id) {
Entry::Occupied(_) => {
match self.unclaimed_channels.remove(&channel_id) {
Some((_k, chan_rx)) => Some(Channel {
id: channel_id,
inner: Arc::new(Mutex::new(ChannelInner {
chan_id: channel_id,
tx: self.tx_queue.clone(),
rx: chan_rx,
multiplexer: Arc::downgrade(&self),
})),
marker: Default::default(),
}),
None => {
// Channel ID already in use.
None
}
}
}
Entry::Vacant(ve) => {
let (chan_tx, chan_rx) = tokio::sync::mpsc::channel(8);
ve.insert(chan_tx);
Some(Channel {
id: channel_id,
inner: Arc::new(Mutex::new(ChannelInner {
chan_id: channel_id,
tx: self.tx_queue.clone(),
rx: chan_rx,
multiplexer: Arc::downgrade(&self),
})),
marker: Default::default(),
})
}
}
}
pub fn open_unused_channel<R, W>(self: &Arc<Self>) -> Option<Channel<R, W>> {
loop {
let id: u16 = rand::random();
let id = if self.initiator { id & !1 } else { id | 1 };
if let Some(chan) = self.open_channel_with_id(id) {
return Some(chan);
}
}
}
async fn handle_tx<W: AsyncWrite + Unpin>(
mut tx: W,
mut txqueue_rx: Receiver<(u16, Vec<u8>)>,
) -> Result<(), CnrError> {
while let Some((chan_id, next_msg)) = txqueue_rx.recv().await {
tx.write_u16(chan_id).await?;
tx.write_u32(next_msg.len().try_into().unwrap()).await?;
tx.write_all(&next_msg).await?;
// TODO(performance) would be nice not to flush if something else is just behind.
tx.flush().await?;
}
Ok(())
}
async fn handle_rx<R: AsyncRead + Unpin>(
mut rx: R,
channels: Arc<DashMap<u16, Sender<Vec<u8>>>>,
unclaimed_channels: Arc<DashMap<u16, Receiver<Vec<u8>>>>,
) -> Result<(), CnrError> {
loop {
// TODO(features): need to be able to support graceful EOF
let chan_id = rx.read_u16().await?;
let length = rx.read_u32().await? as usize;
// TODO(perf): use uninit?
let mut buf = vec![0u8; length];
rx.read_exact(&mut buf[..]).await?;
match channels.entry(chan_id) {
Entry::Occupied(oe) => {
if let Err(err) = oe.get().send(buf).await {
// TODO this channel has died. What can we do about it?
warn!("Message received but channel {} dead", chan_id);
// TODO maybe we should clean it up at this point...
}
}
Entry::Vacant(ve) => {
let (chan_tx, chan_rx) = tokio::sync::mpsc::channel(8);
unclaimed_channels.insert(chan_id, chan_rx);
chan_tx.try_send(buf).expect("empty channel succeeds");
ve.insert(chan_tx);
}
}
}
}
}
pub struct ChannelHandle<LTR, RTL, const INITIATOR: bool> {
marker: PhantomData<(LTR, RTL)>,
chan_id: u16,
multiplexer: Weak<TransportMultiplexer>,
}
impl<LTR, RTL> ChannelHandle<LTR, RTL, false> {
pub fn into_channel(self) -> Option<Channel<RTL, LTR>> {
let multiplexer = self.multiplexer.upgrade()?;
multiplexer.open_channel_with_id(self.chan_id)
}
}
//impl<LTR, RTL, const INITIATOR: bool> Serialize for ChannelHandle<LTR, RTL, INITIATOR> {
impl<LTR, RTL> Serialize for ChannelHandle<LTR, RTL, true> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
self.chan_id.serialize(serializer)
}
}
//impl<'de, LTR, RTL, const INITIATOR: bool> Deserialize<'de> for ChannelHandle<LTR, RTL, INITIATOR> {
impl<'de, LTR, RTL> Deserialize<'de> for ChannelHandle<LTR, RTL, false> {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
// TODO We could check the initiator flags are the right way around by embedding a bool in
// the format.
let chan_id = u16::deserialize(deserializer)?;
Ok(ChannelHandle {
marker: Default::default(),
chan_id,
multiplexer: CURRENT_MULTIPLEXER.with(|refcell| refcell.borrow().clone()),
})
}
}
struct ChannelInner {
chan_id: u16,
tx: Sender<(u16, Vec<u8>)>,
rx: Receiver<Vec<u8>>,
multiplexer: Weak<TransportMultiplexer>,
}
impl Drop for ChannelInner {
fn drop(&mut self) {
if let Some(multiplexer) = self.multiplexer.upgrade() {
multiplexer.channels.remove(&self.chan_id);
}
}
}
pub struct ChannelLock<'a, R, W> {
// TODO we could use an OwnedMutexGuard if needed
guard: MutexGuard<'a, ChannelInner>,
marker: PhantomData<(R, W)>,
}
impl<'a, R: DeserializeOwned, W: Serialize> ChannelLock<'a, R, W> {
pub async fn send(&mut self, message: &W) -> Result<(), CnrError> {
let bytes = serde_bare::to_vec(message)?;
self.guard
.tx
.send((self.guard.chan_id, bytes))
.await
.map_err(|_| CnrError::Closed)?;
Ok(())
}
pub async fn recv(&mut self) -> Result<R, CnrError> {
let new_weak = self.guard.multiplexer.clone();
CURRENT_MULTIPLEXER.with(move |refcell| {
*(refcell.borrow_mut()) = new_weak;
});
let res = match self.guard.rx.recv().await {
Some(bytes) => Ok(serde_bare::from_slice(&bytes)?),
None => Err(CnrError::Closed),
};
CURRENT_MULTIPLEXER.with(|refcell| {
*(refcell.borrow_mut()) = Weak::new();
});
res
}
}
pub struct Channel<R, W> {
pub id: u16,
inner: Arc<Mutex<ChannelInner>>,
marker: PhantomData<(R, W)>,
}
impl<R, W> Channel<R, W> {
/// Channel handles only get sent in one direction.
pub fn handle(&self) -> ChannelHandle<R, W, true> {
ChannelHandle {
marker: Default::default(),
chan_id: self.id,
multiplexer: Default::default(),
}
}
}
impl<R: Send + DeserializeOwned, W: Serialize> Channel<R, W> {
pub async fn lock(&self) -> ChannelLock<'_, R, W> {
ChannelLock {
guard: self.inner.lock().await,
marker: Default::default(),
}
}
pub async fn send(&mut self, message: &W) -> Result<(), CnrError> {
self.lock().await.send(message).await
}
pub async fn recv(&mut self) -> Result<R, CnrError> {
self.lock().await.recv().await
}
}
#[cfg(test)]
mod test {
use crate::multiplexer::{Channel, ChannelHandle, TransportMultiplexer};
/// Tests that data reaches the other side!
#[tokio::test]
async fn test_data_reaches_the_other_side() {
let (commander, responder) = tokio::io::duplex(64);
let (commander_r, commander_w) = tokio::io::split(commander);
let (responder_r, responder_w) = tokio::io::split(responder);
let commander = TransportMultiplexer::new(commander_r, commander_w, true).unwrap();
let responder = TransportMultiplexer::new(responder_r, responder_w, false).unwrap();
let mut c_chan0: Channel<u64, u8> = commander.open_channel_with_id(0).unwrap();
c_chan0.send(&32).await.unwrap();
let mut r_chan0: Channel<u8, u64> = responder.open_channel_with_id(0).unwrap();
assert_eq!(r_chan0.recv().await.unwrap(), 32);
r_chan0.send(&2048).await.unwrap();
assert_eq!(c_chan0.recv().await.unwrap(), 2048);
}
/// Tests that you can hand over a channel just by including a ChannelHandle in the message.
#[tokio::test]
async fn test_channel_handover() {
let (commander, responder) = tokio::io::duplex(64);
let (commander_r, commander_w) = tokio::io::split(commander);
let (responder_r, responder_w) = tokio::io::split(responder);
let commander = TransportMultiplexer::new(commander_r, commander_w, true).unwrap();
let responder = TransportMultiplexer::new(responder_r, responder_w, false).unwrap();
let mut c_chan0: Channel<(), ChannelHandle<u8, u64, true>> =
commander.open_channel_with_id(0).unwrap();
let mut c_chan2: Channel<u8, u64> = commander.open_channel_with_id(2).unwrap();
c_chan0.send(&c_chan2.handle()).await.unwrap();
c_chan2.send(&42).await.unwrap();
let mut r_chan0: Channel<ChannelHandle<u8, u64, false>, ()> =
responder.open_channel_with_id(0).unwrap();
let mut r_chan2 = r_chan0.recv().await.unwrap().into_channel().unwrap();
assert_eq!(r_chan2.recv().await.unwrap(), 42);
}
}