Start a half decent Bare CnR crate
This commit is contained in:
parent
d0ed984dca
commit
183f365032
|
@ -0,0 +1,22 @@
|
|||
[package]
|
||||
name = "bare_cnr"
|
||||
version = "0.1.0-alpha.1"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
thiserror = "1.0.31"
|
||||
serde = "1.0.137"
|
||||
serde_bare = "0.5.0"
|
||||
dashmap = "5.3.4"
|
||||
log = "0.4.17"
|
||||
|
||||
[dependencies.tokio]
|
||||
# TODO optional = true
|
||||
version = "1.18.2"
|
||||
features = ["full"] # TODO restrict this
|
||||
|
||||
|
||||
[features]
|
||||
# TODO default = ["tokio"]
|
|
@ -0,0 +1,19 @@
|
|||
use thiserror::Error;
|
||||
|
||||
mod multiplexer;
|
||||
mod transport;
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum CnrError {
|
||||
#[error("Version mismatch.")]
|
||||
VersionMismatch,
|
||||
|
||||
#[error("Ser/deserialisation error.")]
|
||||
Serde(#[from] serde_bare::error::Error),
|
||||
|
||||
#[error("Input/output error.")]
|
||||
Io(#[from] std::io::Error),
|
||||
|
||||
#[error("Channel closed.")]
|
||||
Closed,
|
||||
}
|
|
@ -0,0 +1,207 @@
|
|||
use crate::CnrError;
|
||||
use crate::CnrError::Closed;
|
||||
use dashmap::mapref::entry::Entry;
|
||||
use dashmap::DashMap;
|
||||
use log::error;
|
||||
use serde::de::DeserializeOwned;
|
||||
use serde::Serialize;
|
||||
use std::marker::PhantomData;
|
||||
use std::sync::Arc;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
use tokio::sync::mpsc::{Receiver, Sender};
|
||||
use tokio::sync::{Mutex, MutexGuard};
|
||||
use tokio::task::JoinHandle;
|
||||
|
||||
pub struct TransportMultiplexer {
|
||||
task_tx: JoinHandle<()>,
|
||||
task_rx: JoinHandle<()>,
|
||||
|
||||
/// Whether this side is the initiator.
|
||||
/// The initiator can open even-numbered channels.
|
||||
initiator: bool,
|
||||
|
||||
/// Queue to send new bytes out on the wire
|
||||
tx_queue: Sender<(u16, Vec<u8>)>,
|
||||
|
||||
/// Senders for sending messages from the wire to channels
|
||||
channels: Arc<DashMap<u16, Sender<Vec<u8>>>>,
|
||||
|
||||
/// Channel receivers for channels that have received messages but haven't yet been claimed
|
||||
unclaimed_channels: Arc<DashMap<u16, Receiver<Vec<u8>>>>,
|
||||
// TODO channels:
|
||||
}
|
||||
|
||||
impl TransportMultiplexer {
|
||||
pub fn new<R: AsyncRead + Send + 'static + Unpin, W: AsyncWrite + Send + 'static + Unpin>(
|
||||
rx: R,
|
||||
tx: W,
|
||||
initiator: bool,
|
||||
) -> Result<TransportMultiplexer, CnrError> {
|
||||
let (txqueue_tx, txqueue_rx) = tokio::sync::mpsc::channel(8);
|
||||
|
||||
let channels = Arc::new(Default::default());
|
||||
let unclaimed_channels = Arc::new(Default::default());
|
||||
|
||||
let task_tx = tokio::spawn(async move {
|
||||
if let Err(err) = TransportMultiplexer::handle_tx(tx, txqueue_rx).await {
|
||||
error!("TX handler failed: {:?}", err);
|
||||
}
|
||||
});
|
||||
let channels2 = Arc::clone(&channels);
|
||||
let unclaimed_channels2 = Arc::clone(&unclaimed_channels);
|
||||
let task_rx = tokio::spawn(async move {
|
||||
if let Err(err) =
|
||||
TransportMultiplexer::handle_rx(rx, channels2, unclaimed_channels2).await
|
||||
{
|
||||
error!("RX handler failed: {:?}", err);
|
||||
}
|
||||
});
|
||||
|
||||
Ok(TransportMultiplexer {
|
||||
task_tx,
|
||||
task_rx,
|
||||
initiator,
|
||||
tx_queue: txqueue_tx,
|
||||
channels,
|
||||
unclaimed_channels,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn open_channel_with_id<R, W>(&self, channel_id: u16) -> Option<Channel<R, W>> {
|
||||
match self.channels.entry(channel_id) {
|
||||
Entry::Occupied(oe) => {
|
||||
match self.unclaimed_channels.remove(&channel_id) {
|
||||
Some((_k, chan_rx)) => Some(Channel {
|
||||
id: channel_id,
|
||||
inner: Arc::new(Mutex::new(ChannelInner {
|
||||
chan_id: channel_id,
|
||||
tx: self.tx_queue.clone(),
|
||||
rx: chan_rx,
|
||||
})),
|
||||
marker: Default::default(),
|
||||
}),
|
||||
None => {
|
||||
// Channel ID already in use.
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
Entry::Vacant(ve) => {
|
||||
let (chan_tx, chan_rx) = tokio::sync::mpsc::channel(8);
|
||||
ve.insert(chan_tx);
|
||||
Some(Channel {
|
||||
id: channel_id,
|
||||
inner: Arc::new(Mutex::new(ChannelInner {
|
||||
chan_id: channel_id,
|
||||
tx: self.tx_queue.clone(),
|
||||
rx: chan_rx,
|
||||
})),
|
||||
marker: Default::default(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_tx<W: AsyncWrite + Unpin>(
|
||||
mut tx: W,
|
||||
mut txqueue_rx: Receiver<(u16, Vec<u8>)>,
|
||||
) -> Result<(), CnrError> {
|
||||
while let Some((chan_id, next_msg)) = txqueue_rx.recv().await {
|
||||
tx.write_u16(chan_id).await?;
|
||||
tx.write_u32(next_msg.len().try_into().unwrap()).await?;
|
||||
tx.write_all(&next_msg).await?;
|
||||
// TODO(performance) would be nice not to flush if something else is just behind.
|
||||
tx.flush().await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_rx<R: AsyncRead + Unpin>(
|
||||
mut rx: R,
|
||||
channels: Arc<DashMap<u16, Sender<Vec<u8>>>>,
|
||||
unclaimed_channels: Arc<DashMap<u16, Receiver<Vec<u8>>>>,
|
||||
) -> Result<(), CnrError> {
|
||||
loop {
|
||||
// TODO(features): need to be able to support graceful EOF
|
||||
let chan_id = rx.read_u16().await?;
|
||||
let length = rx.read_u32().await? as usize;
|
||||
// TODO(perf): use uninit?
|
||||
let mut buf = vec![0u8; length];
|
||||
|
||||
rx.read_exact(&mut buf[..]).await?;
|
||||
|
||||
match channels.entry(chan_id) {
|
||||
Entry::Occupied(oe) => {
|
||||
if let Err(err) = oe.get().send(buf).await {
|
||||
// TODO this channel has died. What can we do about it?
|
||||
todo!();
|
||||
}
|
||||
}
|
||||
Entry::Vacant(ve) => {
|
||||
let (chan_tx, chan_rx) = tokio::sync::mpsc::channel(8);
|
||||
unclaimed_channels.insert(chan_id, chan_rx);
|
||||
chan_tx.try_send(buf).expect("empty channel succeeds");
|
||||
ve.insert(chan_tx);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ChannelHandle<LTR, RTL, const initiator: bool> {
|
||||
marker: PhantomData<(LTR, RTL)>,
|
||||
}
|
||||
|
||||
struct ChannelInner {
|
||||
chan_id: u16,
|
||||
tx: Sender<(u16, Vec<u8>)>,
|
||||
rx: Receiver<Vec<u8>>,
|
||||
}
|
||||
|
||||
pub struct ChannelLock<'a, R, W> {
|
||||
// TODO we could use an OwnedMutexGuard if needed
|
||||
guard: MutexGuard<'a, ChannelInner>,
|
||||
marker: PhantomData<(R, W)>,
|
||||
}
|
||||
|
||||
impl<'a, R: DeserializeOwned, W: Serialize> ChannelLock<'a, R, W> {
|
||||
pub async fn send(&mut self, message: &W) -> Result<(), CnrError> {
|
||||
let bytes = serde_bare::to_vec(message)?;
|
||||
self.guard
|
||||
.tx
|
||||
.send((self.guard.chan_id, bytes))
|
||||
.await
|
||||
.map_err(|_| CnrError::Closed)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn recv(&mut self) -> Result<R, CnrError> {
|
||||
match self.guard.rx.recv().await {
|
||||
Some(bytes) => Ok(serde_bare::from_slice(&bytes)?),
|
||||
None => Err(CnrError::Closed),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Channel<R, W> {
|
||||
pub id: u16,
|
||||
inner: Arc<Mutex<ChannelInner>>,
|
||||
marker: PhantomData<(R, W)>,
|
||||
}
|
||||
|
||||
impl<R: Send + DeserializeOwned, W: Serialize> Channel<R, W> {
|
||||
pub async fn lock(&self) -> ChannelLock<'_, R, W> {
|
||||
ChannelLock {
|
||||
guard: self.inner.lock().await,
|
||||
marker: Default::default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn send(&mut self, message: &W) -> Result<(), CnrError> {
|
||||
self.lock().await.send(message).await
|
||||
}
|
||||
|
||||
pub async fn recv(&mut self) -> Result<R, CnrError> {
|
||||
self.lock().await.recv().await
|
||||
}
|
||||
}
|
|
@ -0,0 +1,43 @@
|
|||
use serde::de::DeserializeOwned;
|
||||
use serde::Serialize;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
|
||||
use crate::CnrError;
|
||||
|
||||
pub struct BareTransport<R, W> {
|
||||
writer: W,
|
||||
reader: R,
|
||||
}
|
||||
|
||||
impl<W: AsyncWrite + Unpin, R: AsyncRead + Unpin> BareTransport<R, W> {
|
||||
pub fn new(writer: W, reader: R) -> Self {
|
||||
Self { writer, reader }
|
||||
}
|
||||
|
||||
pub async fn write_one_message<M: Serialize>(&mut self, message: &M) -> Result<(), CnrError> {
|
||||
let bytes = serde_bare::to_vec(message)?;
|
||||
self.writer
|
||||
.write_u32(bytes.len().try_into().unwrap())
|
||||
.await?;
|
||||
self.writer.write_all(&bytes).await?;
|
||||
// TODO flush?
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn read_one_message<M: DeserializeOwned>(&mut self) -> Result<Option<M>, CnrError> {
|
||||
let length = self.reader.read_u32().await? as usize;
|
||||
// TODO(perf): use uninit?
|
||||
let mut buf = vec![0u8; length];
|
||||
|
||||
let first_read = self.reader.read(&mut buf[..]).await?;
|
||||
if first_read == 0 {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
if first_read != length {
|
||||
self.reader.read_exact(&mut buf[first_read..]).await?;
|
||||
}
|
||||
|
||||
return Ok(Some(serde_bare::from_slice(&buf)?));
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue