Start a half decent Bare CnR crate
This commit is contained in:
		
							parent
							
								
									d0ed984dca
								
							
						
					
					
						commit
						183f365032
					
				
							
								
								
									
										22
									
								
								bare_cnr/Cargo.toml
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										22
									
								
								bare_cnr/Cargo.toml
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,22 @@
 | 
			
		||||
[package]
 | 
			
		||||
name = "bare_cnr"
 | 
			
		||||
version = "0.1.0-alpha.1"
 | 
			
		||||
edition = "2021"
 | 
			
		||||
 | 
			
		||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
 | 
			
		||||
 | 
			
		||||
[dependencies]
 | 
			
		||||
thiserror = "1.0.31"
 | 
			
		||||
serde = "1.0.137"
 | 
			
		||||
serde_bare = "0.5.0"
 | 
			
		||||
dashmap = "5.3.4"
 | 
			
		||||
log = "0.4.17"
 | 
			
		||||
 | 
			
		||||
[dependencies.tokio]
 | 
			
		||||
# TODO optional = true
 | 
			
		||||
version = "1.18.2"
 | 
			
		||||
features = ["full"] # TODO restrict this
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
[features]
 | 
			
		||||
# TODO default = ["tokio"]
 | 
			
		||||
							
								
								
									
										19
									
								
								bare_cnr/src/lib.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										19
									
								
								bare_cnr/src/lib.rs
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,19 @@
 | 
			
		||||
use thiserror::Error;
 | 
			
		||||
 | 
			
		||||
mod multiplexer;
 | 
			
		||||
mod transport;
 | 
			
		||||
 | 
			
		||||
#[derive(Error, Debug)]
 | 
			
		||||
pub enum CnrError {
 | 
			
		||||
    #[error("Version mismatch.")]
 | 
			
		||||
    VersionMismatch,
 | 
			
		||||
 | 
			
		||||
    #[error("Ser/deserialisation error.")]
 | 
			
		||||
    Serde(#[from] serde_bare::error::Error),
 | 
			
		||||
 | 
			
		||||
    #[error("Input/output error.")]
 | 
			
		||||
    Io(#[from] std::io::Error),
 | 
			
		||||
 | 
			
		||||
    #[error("Channel closed.")]
 | 
			
		||||
    Closed,
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										207
									
								
								bare_cnr/src/multiplexer.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										207
									
								
								bare_cnr/src/multiplexer.rs
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,207 @@
 | 
			
		||||
use crate::CnrError;
 | 
			
		||||
use crate::CnrError::Closed;
 | 
			
		||||
use dashmap::mapref::entry::Entry;
 | 
			
		||||
use dashmap::DashMap;
 | 
			
		||||
use log::error;
 | 
			
		||||
use serde::de::DeserializeOwned;
 | 
			
		||||
use serde::Serialize;
 | 
			
		||||
use std::marker::PhantomData;
 | 
			
		||||
use std::sync::Arc;
 | 
			
		||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
 | 
			
		||||
use tokio::sync::mpsc::{Receiver, Sender};
 | 
			
		||||
use tokio::sync::{Mutex, MutexGuard};
 | 
			
		||||
use tokio::task::JoinHandle;
 | 
			
		||||
 | 
			
		||||
pub struct TransportMultiplexer {
 | 
			
		||||
    task_tx: JoinHandle<()>,
 | 
			
		||||
    task_rx: JoinHandle<()>,
 | 
			
		||||
 | 
			
		||||
    /// Whether this side is the initiator.
 | 
			
		||||
    /// The initiator can open even-numbered channels.
 | 
			
		||||
    initiator: bool,
 | 
			
		||||
 | 
			
		||||
    /// Queue to send new bytes out on the wire
 | 
			
		||||
    tx_queue: Sender<(u16, Vec<u8>)>,
 | 
			
		||||
 | 
			
		||||
    /// Senders for sending messages from the wire to channels
 | 
			
		||||
    channels: Arc<DashMap<u16, Sender<Vec<u8>>>>,
 | 
			
		||||
 | 
			
		||||
    /// Channel receivers for channels that have received messages but haven't yet been claimed
 | 
			
		||||
    unclaimed_channels: Arc<DashMap<u16, Receiver<Vec<u8>>>>,
 | 
			
		||||
    // TODO channels:
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl TransportMultiplexer {
 | 
			
		||||
    pub fn new<R: AsyncRead + Send + 'static + Unpin, W: AsyncWrite + Send + 'static + Unpin>(
 | 
			
		||||
        rx: R,
 | 
			
		||||
        tx: W,
 | 
			
		||||
        initiator: bool,
 | 
			
		||||
    ) -> Result<TransportMultiplexer, CnrError> {
 | 
			
		||||
        let (txqueue_tx, txqueue_rx) = tokio::sync::mpsc::channel(8);
 | 
			
		||||
 | 
			
		||||
        let channels = Arc::new(Default::default());
 | 
			
		||||
        let unclaimed_channels = Arc::new(Default::default());
 | 
			
		||||
 | 
			
		||||
        let task_tx = tokio::spawn(async move {
 | 
			
		||||
            if let Err(err) = TransportMultiplexer::handle_tx(tx, txqueue_rx).await {
 | 
			
		||||
                error!("TX handler failed: {:?}", err);
 | 
			
		||||
            }
 | 
			
		||||
        });
 | 
			
		||||
        let channels2 = Arc::clone(&channels);
 | 
			
		||||
        let unclaimed_channels2 = Arc::clone(&unclaimed_channels);
 | 
			
		||||
        let task_rx = tokio::spawn(async move {
 | 
			
		||||
            if let Err(err) =
 | 
			
		||||
                TransportMultiplexer::handle_rx(rx, channels2, unclaimed_channels2).await
 | 
			
		||||
            {
 | 
			
		||||
                error!("RX handler failed: {:?}", err);
 | 
			
		||||
            }
 | 
			
		||||
        });
 | 
			
		||||
 | 
			
		||||
        Ok(TransportMultiplexer {
 | 
			
		||||
            task_tx,
 | 
			
		||||
            task_rx,
 | 
			
		||||
            initiator,
 | 
			
		||||
            tx_queue: txqueue_tx,
 | 
			
		||||
            channels,
 | 
			
		||||
            unclaimed_channels,
 | 
			
		||||
        })
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub fn open_channel_with_id<R, W>(&self, channel_id: u16) -> Option<Channel<R, W>> {
 | 
			
		||||
        match self.channels.entry(channel_id) {
 | 
			
		||||
            Entry::Occupied(oe) => {
 | 
			
		||||
                match self.unclaimed_channels.remove(&channel_id) {
 | 
			
		||||
                    Some((_k, chan_rx)) => Some(Channel {
 | 
			
		||||
                        id: channel_id,
 | 
			
		||||
                        inner: Arc::new(Mutex::new(ChannelInner {
 | 
			
		||||
                            chan_id: channel_id,
 | 
			
		||||
                            tx: self.tx_queue.clone(),
 | 
			
		||||
                            rx: chan_rx,
 | 
			
		||||
                        })),
 | 
			
		||||
                        marker: Default::default(),
 | 
			
		||||
                    }),
 | 
			
		||||
                    None => {
 | 
			
		||||
                        // Channel ID already in use.
 | 
			
		||||
                        None
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
            Entry::Vacant(ve) => {
 | 
			
		||||
                let (chan_tx, chan_rx) = tokio::sync::mpsc::channel(8);
 | 
			
		||||
                ve.insert(chan_tx);
 | 
			
		||||
                Some(Channel {
 | 
			
		||||
                    id: channel_id,
 | 
			
		||||
                    inner: Arc::new(Mutex::new(ChannelInner {
 | 
			
		||||
                        chan_id: channel_id,
 | 
			
		||||
                        tx: self.tx_queue.clone(),
 | 
			
		||||
                        rx: chan_rx,
 | 
			
		||||
                    })),
 | 
			
		||||
                    marker: Default::default(),
 | 
			
		||||
                })
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    async fn handle_tx<W: AsyncWrite + Unpin>(
 | 
			
		||||
        mut tx: W,
 | 
			
		||||
        mut txqueue_rx: Receiver<(u16, Vec<u8>)>,
 | 
			
		||||
    ) -> Result<(), CnrError> {
 | 
			
		||||
        while let Some((chan_id, next_msg)) = txqueue_rx.recv().await {
 | 
			
		||||
            tx.write_u16(chan_id).await?;
 | 
			
		||||
            tx.write_u32(next_msg.len().try_into().unwrap()).await?;
 | 
			
		||||
            tx.write_all(&next_msg).await?;
 | 
			
		||||
            // TODO(performance) would be nice not to flush if something else is just behind.
 | 
			
		||||
            tx.flush().await?;
 | 
			
		||||
        }
 | 
			
		||||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    async fn handle_rx<R: AsyncRead + Unpin>(
 | 
			
		||||
        mut rx: R,
 | 
			
		||||
        channels: Arc<DashMap<u16, Sender<Vec<u8>>>>,
 | 
			
		||||
        unclaimed_channels: Arc<DashMap<u16, Receiver<Vec<u8>>>>,
 | 
			
		||||
    ) -> Result<(), CnrError> {
 | 
			
		||||
        loop {
 | 
			
		||||
            // TODO(features): need to be able to support graceful EOF
 | 
			
		||||
            let chan_id = rx.read_u16().await?;
 | 
			
		||||
            let length = rx.read_u32().await? as usize;
 | 
			
		||||
            // TODO(perf): use uninit?
 | 
			
		||||
            let mut buf = vec![0u8; length];
 | 
			
		||||
 | 
			
		||||
            rx.read_exact(&mut buf[..]).await?;
 | 
			
		||||
 | 
			
		||||
            match channels.entry(chan_id) {
 | 
			
		||||
                Entry::Occupied(oe) => {
 | 
			
		||||
                    if let Err(err) = oe.get().send(buf).await {
 | 
			
		||||
                        // TODO this channel has died. What can we do about it?
 | 
			
		||||
                        todo!();
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
                Entry::Vacant(ve) => {
 | 
			
		||||
                    let (chan_tx, chan_rx) = tokio::sync::mpsc::channel(8);
 | 
			
		||||
                    unclaimed_channels.insert(chan_id, chan_rx);
 | 
			
		||||
                    chan_tx.try_send(buf).expect("empty channel succeeds");
 | 
			
		||||
                    ve.insert(chan_tx);
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub struct ChannelHandle<LTR, RTL, const initiator: bool> {
 | 
			
		||||
    marker: PhantomData<(LTR, RTL)>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
struct ChannelInner {
 | 
			
		||||
    chan_id: u16,
 | 
			
		||||
    tx: Sender<(u16, Vec<u8>)>,
 | 
			
		||||
    rx: Receiver<Vec<u8>>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub struct ChannelLock<'a, R, W> {
 | 
			
		||||
    // TODO we could use an OwnedMutexGuard if needed
 | 
			
		||||
    guard: MutexGuard<'a, ChannelInner>,
 | 
			
		||||
    marker: PhantomData<(R, W)>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl<'a, R: DeserializeOwned, W: Serialize> ChannelLock<'a, R, W> {
 | 
			
		||||
    pub async fn send(&mut self, message: &W) -> Result<(), CnrError> {
 | 
			
		||||
        let bytes = serde_bare::to_vec(message)?;
 | 
			
		||||
        self.guard
 | 
			
		||||
            .tx
 | 
			
		||||
            .send((self.guard.chan_id, bytes))
 | 
			
		||||
            .await
 | 
			
		||||
            .map_err(|_| CnrError::Closed)?;
 | 
			
		||||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub async fn recv(&mut self) -> Result<R, CnrError> {
 | 
			
		||||
        match self.guard.rx.recv().await {
 | 
			
		||||
            Some(bytes) => Ok(serde_bare::from_slice(&bytes)?),
 | 
			
		||||
            None => Err(CnrError::Closed),
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
pub struct Channel<R, W> {
 | 
			
		||||
    pub id: u16,
 | 
			
		||||
    inner: Arc<Mutex<ChannelInner>>,
 | 
			
		||||
    marker: PhantomData<(R, W)>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl<R: Send + DeserializeOwned, W: Serialize> Channel<R, W> {
 | 
			
		||||
    pub async fn lock(&self) -> ChannelLock<'_, R, W> {
 | 
			
		||||
        ChannelLock {
 | 
			
		||||
            guard: self.inner.lock().await,
 | 
			
		||||
            marker: Default::default(),
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub async fn send(&mut self, message: &W) -> Result<(), CnrError> {
 | 
			
		||||
        self.lock().await.send(message).await
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub async fn recv(&mut self) -> Result<R, CnrError> {
 | 
			
		||||
        self.lock().await.recv().await
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										43
									
								
								bare_cnr/src/transport.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										43
									
								
								bare_cnr/src/transport.rs
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,43 @@
 | 
			
		||||
use serde::de::DeserializeOwned;
 | 
			
		||||
use serde::Serialize;
 | 
			
		||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
 | 
			
		||||
 | 
			
		||||
use crate::CnrError;
 | 
			
		||||
 | 
			
		||||
pub struct BareTransport<R, W> {
 | 
			
		||||
    writer: W,
 | 
			
		||||
    reader: R,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl<W: AsyncWrite + Unpin, R: AsyncRead + Unpin> BareTransport<R, W> {
 | 
			
		||||
    pub fn new(writer: W, reader: R) -> Self {
 | 
			
		||||
        Self { writer, reader }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub async fn write_one_message<M: Serialize>(&mut self, message: &M) -> Result<(), CnrError> {
 | 
			
		||||
        let bytes = serde_bare::to_vec(message)?;
 | 
			
		||||
        self.writer
 | 
			
		||||
            .write_u32(bytes.len().try_into().unwrap())
 | 
			
		||||
            .await?;
 | 
			
		||||
        self.writer.write_all(&bytes).await?;
 | 
			
		||||
        // TODO flush?
 | 
			
		||||
        Ok(())
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    pub async fn read_one_message<M: DeserializeOwned>(&mut self) -> Result<Option<M>, CnrError> {
 | 
			
		||||
        let length = self.reader.read_u32().await? as usize;
 | 
			
		||||
        // TODO(perf): use uninit?
 | 
			
		||||
        let mut buf = vec![0u8; length];
 | 
			
		||||
 | 
			
		||||
        let first_read = self.reader.read(&mut buf[..]).await?;
 | 
			
		||||
        if first_read == 0 {
 | 
			
		||||
            return Ok(None);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        if first_read != length {
 | 
			
		||||
            self.reader.read_exact(&mut buf[first_read..]).await?;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        return Ok(Some(serde_bare::from_slice(&buf)?));
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user