diff --git a/bare_cnr/src/multiplexer.rs b/bare_cnr/src/multiplexer.rs index a3fa8e9..408f809 100644 --- a/bare_cnr/src/multiplexer.rs +++ b/bare_cnr/src/multiplexer.rs @@ -2,11 +2,13 @@ use crate::CnrError; use crate::CnrError::Closed; use dashmap::mapref::entry::Entry; use dashmap::DashMap; -use log::error; +use log::{error, warn}; use serde::de::DeserializeOwned; -use serde::Serialize; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use std::borrow::{Borrow, BorrowMut}; +use std::cell::RefCell; use std::marker::PhantomData; -use std::sync::Arc; +use std::sync::{Arc, Weak}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::sync::mpsc::{Receiver, Sender}; use tokio::sync::{Mutex, MutexGuard}; @@ -28,7 +30,10 @@ pub struct TransportMultiplexer { /// Channel receivers for channels that have received messages but haven't yet been claimed unclaimed_channels: Arc>>>, - // TODO channels: +} + +thread_local! { + static CURRENT_MULTIPLEXER: RefCell> = Default::default(); } impl TransportMultiplexer { @@ -36,7 +41,7 @@ impl TransportMultiplexer { rx: R, tx: W, initiator: bool, - ) -> Result { + ) -> Result, CnrError> { let (txqueue_tx, txqueue_rx) = tokio::sync::mpsc::channel(8); let channels = Arc::new(Default::default()); @@ -57,19 +62,19 @@ impl TransportMultiplexer { } }); - Ok(TransportMultiplexer { + Ok(Arc::new(TransportMultiplexer { task_tx, task_rx, initiator, tx_queue: txqueue_tx, channels, unclaimed_channels, - }) + })) } - pub fn open_channel_with_id(&self, channel_id: u16) -> Option> { + pub fn open_channel_with_id(self: &Arc, channel_id: u16) -> Option> { match self.channels.entry(channel_id) { - Entry::Occupied(oe) => { + Entry::Occupied(_) => { match self.unclaimed_channels.remove(&channel_id) { Some((_k, chan_rx)) => Some(Channel { id: channel_id, @@ -77,6 +82,7 @@ impl TransportMultiplexer { chan_id: channel_id, tx: self.tx_queue.clone(), rx: chan_rx, + multiplexer: Arc::downgrade(&self), })), marker: Default::default(), }), @@ -95,6 +101,7 @@ impl TransportMultiplexer { chan_id: channel_id, tx: self.tx_queue.clone(), rx: chan_rx, + multiplexer: Arc::downgrade(&self), })), marker: Default::default(), }) @@ -134,7 +141,8 @@ impl TransportMultiplexer { Entry::Occupied(oe) => { if let Err(err) = oe.get().send(buf).await { // TODO this channel has died. What can we do about it? - todo!(); + warn!("Message received but channel {} dead", chan_id); + // TODO maybe we should clean it up at this point... } } Entry::Vacant(ve) => { @@ -148,14 +156,51 @@ impl TransportMultiplexer { } } -pub struct ChannelHandle { +pub struct ChannelHandle { marker: PhantomData<(LTR, RTL)>, + chan_id: u16, + multiplexer: Weak, +} + +impl ChannelHandle { + pub fn into_channel(self) -> Option> { + let multiplexer = self.multiplexer.upgrade()?; + multiplexer.open_channel_with_id(self.chan_id) + } +} + +//impl Serialize for ChannelHandle { +impl Serialize for ChannelHandle { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + self.chan_id.serialize(serializer) + } +} + +//impl<'de, LTR, RTL, const INITIATOR: bool> Deserialize<'de> for ChannelHandle { +impl<'de, LTR, RTL> Deserialize<'de> for ChannelHandle { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + // TODO We could check the initiator flags are the right way around by embedding a bool in + // the format. + let chan_id = u16::deserialize(deserializer)?; + Ok(ChannelHandle { + marker: Default::default(), + chan_id, + multiplexer: CURRENT_MULTIPLEXER.with(|refcell| refcell.borrow().clone()), + }) + } } struct ChannelInner { chan_id: u16, tx: Sender<(u16, Vec)>, rx: Receiver>, + multiplexer: Weak, } pub struct ChannelLock<'a, R, W> { @@ -176,10 +221,18 @@ impl<'a, R: DeserializeOwned, W: Serialize> ChannelLock<'a, R, W> { } pub async fn recv(&mut self) -> Result { - match self.guard.rx.recv().await { + let new_weak = self.guard.multiplexer.clone(); + CURRENT_MULTIPLEXER.with(move |refcell| { + *(refcell.borrow_mut()) = new_weak; + }); + let res = match self.guard.rx.recv().await { Some(bytes) => Ok(serde_bare::from_slice(&bytes)?), None => Err(CnrError::Closed), - } + }; + CURRENT_MULTIPLEXER.with(|refcell| { + *(refcell.borrow_mut()) = Weak::new(); + }); + res } } @@ -189,6 +242,17 @@ pub struct Channel { marker: PhantomData<(R, W)>, } +impl Channel { + /// Channel handles only get sent in one direction. + pub fn handle(&self) -> ChannelHandle { + ChannelHandle { + marker: Default::default(), + chan_id: self.id, + multiplexer: Default::default(), + } + } +} + impl Channel { pub async fn lock(&self) -> ChannelLock<'_, R, W> { ChannelLock {