Introduce channel handles that can be passed in serde messages

This commit is contained in:
Olivier 'reivilibre' 2022-05-28 20:40:04 +01:00
parent 183f365032
commit b659a5ddac
1 changed files with 77 additions and 13 deletions

View File

@ -2,11 +2,13 @@ use crate::CnrError;
use crate::CnrError::Closed;
use dashmap::mapref::entry::Entry;
use dashmap::DashMap;
use log::error;
use log::{error, warn};
use serde::de::DeserializeOwned;
use serde::Serialize;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::borrow::{Borrow, BorrowMut};
use std::cell::RefCell;
use std::marker::PhantomData;
use std::sync::Arc;
use std::sync::{Arc, Weak};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::sync::mpsc::{Receiver, Sender};
use tokio::sync::{Mutex, MutexGuard};
@ -28,7 +30,10 @@ pub struct TransportMultiplexer {
/// Channel receivers for channels that have received messages but haven't yet been claimed
unclaimed_channels: Arc<DashMap<u16, Receiver<Vec<u8>>>>,
// TODO channels:
}
thread_local! {
static CURRENT_MULTIPLEXER: RefCell<Weak<TransportMultiplexer>> = Default::default();
}
impl TransportMultiplexer {
@ -36,7 +41,7 @@ impl TransportMultiplexer {
rx: R,
tx: W,
initiator: bool,
) -> Result<TransportMultiplexer, CnrError> {
) -> Result<Arc<TransportMultiplexer>, CnrError> {
let (txqueue_tx, txqueue_rx) = tokio::sync::mpsc::channel(8);
let channels = Arc::new(Default::default());
@ -57,19 +62,19 @@ impl TransportMultiplexer {
}
});
Ok(TransportMultiplexer {
Ok(Arc::new(TransportMultiplexer {
task_tx,
task_rx,
initiator,
tx_queue: txqueue_tx,
channels,
unclaimed_channels,
})
}))
}
pub fn open_channel_with_id<R, W>(&self, channel_id: u16) -> Option<Channel<R, W>> {
pub fn open_channel_with_id<R, W>(self: &Arc<Self>, channel_id: u16) -> Option<Channel<R, W>> {
match self.channels.entry(channel_id) {
Entry::Occupied(oe) => {
Entry::Occupied(_) => {
match self.unclaimed_channels.remove(&channel_id) {
Some((_k, chan_rx)) => Some(Channel {
id: channel_id,
@ -77,6 +82,7 @@ impl TransportMultiplexer {
chan_id: channel_id,
tx: self.tx_queue.clone(),
rx: chan_rx,
multiplexer: Arc::downgrade(&self),
})),
marker: Default::default(),
}),
@ -95,6 +101,7 @@ impl TransportMultiplexer {
chan_id: channel_id,
tx: self.tx_queue.clone(),
rx: chan_rx,
multiplexer: Arc::downgrade(&self),
})),
marker: Default::default(),
})
@ -134,7 +141,8 @@ impl TransportMultiplexer {
Entry::Occupied(oe) => {
if let Err(err) = oe.get().send(buf).await {
// TODO this channel has died. What can we do about it?
todo!();
warn!("Message received but channel {} dead", chan_id);
// TODO maybe we should clean it up at this point...
}
}
Entry::Vacant(ve) => {
@ -148,14 +156,51 @@ impl TransportMultiplexer {
}
}
pub struct ChannelHandle<LTR, RTL, const initiator: bool> {
pub struct ChannelHandle<LTR, RTL, const INITIATOR: bool> {
marker: PhantomData<(LTR, RTL)>,
chan_id: u16,
multiplexer: Weak<TransportMultiplexer>,
}
impl<LTR, RTL> ChannelHandle<LTR, RTL, false> {
pub fn into_channel(self) -> Option<Channel<RTL, LTR>> {
let multiplexer = self.multiplexer.upgrade()?;
multiplexer.open_channel_with_id(self.chan_id)
}
}
//impl<LTR, RTL, const INITIATOR: bool> Serialize for ChannelHandle<LTR, RTL, INITIATOR> {
impl<LTR, RTL> Serialize for ChannelHandle<LTR, RTL, true> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
self.chan_id.serialize(serializer)
}
}
//impl<'de, LTR, RTL, const INITIATOR: bool> Deserialize<'de> for ChannelHandle<LTR, RTL, INITIATOR> {
impl<'de, LTR, RTL> Deserialize<'de> for ChannelHandle<LTR, RTL, false> {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
// TODO We could check the initiator flags are the right way around by embedding a bool in
// the format.
let chan_id = u16::deserialize(deserializer)?;
Ok(ChannelHandle {
marker: Default::default(),
chan_id,
multiplexer: CURRENT_MULTIPLEXER.with(|refcell| refcell.borrow().clone()),
})
}
}
struct ChannelInner {
chan_id: u16,
tx: Sender<(u16, Vec<u8>)>,
rx: Receiver<Vec<u8>>,
multiplexer: Weak<TransportMultiplexer>,
}
pub struct ChannelLock<'a, R, W> {
@ -176,10 +221,18 @@ impl<'a, R: DeserializeOwned, W: Serialize> ChannelLock<'a, R, W> {
}
pub async fn recv(&mut self) -> Result<R, CnrError> {
match self.guard.rx.recv().await {
let new_weak = self.guard.multiplexer.clone();
CURRENT_MULTIPLEXER.with(move |refcell| {
*(refcell.borrow_mut()) = new_weak;
});
let res = match self.guard.rx.recv().await {
Some(bytes) => Ok(serde_bare::from_slice(&bytes)?),
None => Err(CnrError::Closed),
}
};
CURRENT_MULTIPLEXER.with(|refcell| {
*(refcell.borrow_mut()) = Weak::new();
});
res
}
}
@ -189,6 +242,17 @@ pub struct Channel<R, W> {
marker: PhantomData<(R, W)>,
}
impl<R, W> Channel<R, W> {
/// Channel handles only get sent in one direction.
pub fn handle(&self) -> ChannelHandle<R, W, true> {
ChannelHandle {
marker: Default::default(),
chan_id: self.id,
multiplexer: Default::default(),
}
}
}
impl<R: Send + DeserializeOwned, W: Serialize> Channel<R, W> {
pub async fn lock(&self) -> ChannelLock<'_, R, W> {
ChannelLock {