Introduce channel handles that can be passed in serde messages
This commit is contained in:
parent
183f365032
commit
b659a5ddac
|
@ -2,11 +2,13 @@ use crate::CnrError;
|
|||
use crate::CnrError::Closed;
|
||||
use dashmap::mapref::entry::Entry;
|
||||
use dashmap::DashMap;
|
||||
use log::error;
|
||||
use log::{error, warn};
|
||||
use serde::de::DeserializeOwned;
|
||||
use serde::Serialize;
|
||||
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
||||
use std::borrow::{Borrow, BorrowMut};
|
||||
use std::cell::RefCell;
|
||||
use std::marker::PhantomData;
|
||||
use std::sync::Arc;
|
||||
use std::sync::{Arc, Weak};
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
use tokio::sync::mpsc::{Receiver, Sender};
|
||||
use tokio::sync::{Mutex, MutexGuard};
|
||||
|
@ -28,7 +30,10 @@ pub struct TransportMultiplexer {
|
|||
|
||||
/// Channel receivers for channels that have received messages but haven't yet been claimed
|
||||
unclaimed_channels: Arc<DashMap<u16, Receiver<Vec<u8>>>>,
|
||||
// TODO channels:
|
||||
}
|
||||
|
||||
thread_local! {
|
||||
static CURRENT_MULTIPLEXER: RefCell<Weak<TransportMultiplexer>> = Default::default();
|
||||
}
|
||||
|
||||
impl TransportMultiplexer {
|
||||
|
@ -36,7 +41,7 @@ impl TransportMultiplexer {
|
|||
rx: R,
|
||||
tx: W,
|
||||
initiator: bool,
|
||||
) -> Result<TransportMultiplexer, CnrError> {
|
||||
) -> Result<Arc<TransportMultiplexer>, CnrError> {
|
||||
let (txqueue_tx, txqueue_rx) = tokio::sync::mpsc::channel(8);
|
||||
|
||||
let channels = Arc::new(Default::default());
|
||||
|
@ -57,19 +62,19 @@ impl TransportMultiplexer {
|
|||
}
|
||||
});
|
||||
|
||||
Ok(TransportMultiplexer {
|
||||
Ok(Arc::new(TransportMultiplexer {
|
||||
task_tx,
|
||||
task_rx,
|
||||
initiator,
|
||||
tx_queue: txqueue_tx,
|
||||
channels,
|
||||
unclaimed_channels,
|
||||
})
|
||||
}))
|
||||
}
|
||||
|
||||
pub fn open_channel_with_id<R, W>(&self, channel_id: u16) -> Option<Channel<R, W>> {
|
||||
pub fn open_channel_with_id<R, W>(self: &Arc<Self>, channel_id: u16) -> Option<Channel<R, W>> {
|
||||
match self.channels.entry(channel_id) {
|
||||
Entry::Occupied(oe) => {
|
||||
Entry::Occupied(_) => {
|
||||
match self.unclaimed_channels.remove(&channel_id) {
|
||||
Some((_k, chan_rx)) => Some(Channel {
|
||||
id: channel_id,
|
||||
|
@ -77,6 +82,7 @@ impl TransportMultiplexer {
|
|||
chan_id: channel_id,
|
||||
tx: self.tx_queue.clone(),
|
||||
rx: chan_rx,
|
||||
multiplexer: Arc::downgrade(&self),
|
||||
})),
|
||||
marker: Default::default(),
|
||||
}),
|
||||
|
@ -95,6 +101,7 @@ impl TransportMultiplexer {
|
|||
chan_id: channel_id,
|
||||
tx: self.tx_queue.clone(),
|
||||
rx: chan_rx,
|
||||
multiplexer: Arc::downgrade(&self),
|
||||
})),
|
||||
marker: Default::default(),
|
||||
})
|
||||
|
@ -134,7 +141,8 @@ impl TransportMultiplexer {
|
|||
Entry::Occupied(oe) => {
|
||||
if let Err(err) = oe.get().send(buf).await {
|
||||
// TODO this channel has died. What can we do about it?
|
||||
todo!();
|
||||
warn!("Message received but channel {} dead", chan_id);
|
||||
// TODO maybe we should clean it up at this point...
|
||||
}
|
||||
}
|
||||
Entry::Vacant(ve) => {
|
||||
|
@ -148,14 +156,51 @@ impl TransportMultiplexer {
|
|||
}
|
||||
}
|
||||
|
||||
pub struct ChannelHandle<LTR, RTL, const initiator: bool> {
|
||||
pub struct ChannelHandle<LTR, RTL, const INITIATOR: bool> {
|
||||
marker: PhantomData<(LTR, RTL)>,
|
||||
chan_id: u16,
|
||||
multiplexer: Weak<TransportMultiplexer>,
|
||||
}
|
||||
|
||||
impl<LTR, RTL> ChannelHandle<LTR, RTL, false> {
|
||||
pub fn into_channel(self) -> Option<Channel<RTL, LTR>> {
|
||||
let multiplexer = self.multiplexer.upgrade()?;
|
||||
multiplexer.open_channel_with_id(self.chan_id)
|
||||
}
|
||||
}
|
||||
|
||||
//impl<LTR, RTL, const INITIATOR: bool> Serialize for ChannelHandle<LTR, RTL, INITIATOR> {
|
||||
impl<LTR, RTL> Serialize for ChannelHandle<LTR, RTL, true> {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
self.chan_id.serialize(serializer)
|
||||
}
|
||||
}
|
||||
|
||||
//impl<'de, LTR, RTL, const INITIATOR: bool> Deserialize<'de> for ChannelHandle<LTR, RTL, INITIATOR> {
|
||||
impl<'de, LTR, RTL> Deserialize<'de> for ChannelHandle<LTR, RTL, false> {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
// TODO We could check the initiator flags are the right way around by embedding a bool in
|
||||
// the format.
|
||||
let chan_id = u16::deserialize(deserializer)?;
|
||||
Ok(ChannelHandle {
|
||||
marker: Default::default(),
|
||||
chan_id,
|
||||
multiplexer: CURRENT_MULTIPLEXER.with(|refcell| refcell.borrow().clone()),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
struct ChannelInner {
|
||||
chan_id: u16,
|
||||
tx: Sender<(u16, Vec<u8>)>,
|
||||
rx: Receiver<Vec<u8>>,
|
||||
multiplexer: Weak<TransportMultiplexer>,
|
||||
}
|
||||
|
||||
pub struct ChannelLock<'a, R, W> {
|
||||
|
@ -176,10 +221,18 @@ impl<'a, R: DeserializeOwned, W: Serialize> ChannelLock<'a, R, W> {
|
|||
}
|
||||
|
||||
pub async fn recv(&mut self) -> Result<R, CnrError> {
|
||||
match self.guard.rx.recv().await {
|
||||
let new_weak = self.guard.multiplexer.clone();
|
||||
CURRENT_MULTIPLEXER.with(move |refcell| {
|
||||
*(refcell.borrow_mut()) = new_weak;
|
||||
});
|
||||
let res = match self.guard.rx.recv().await {
|
||||
Some(bytes) => Ok(serde_bare::from_slice(&bytes)?),
|
||||
None => Err(CnrError::Closed),
|
||||
}
|
||||
};
|
||||
CURRENT_MULTIPLEXER.with(|refcell| {
|
||||
*(refcell.borrow_mut()) = Weak::new();
|
||||
});
|
||||
res
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -189,6 +242,17 @@ pub struct Channel<R, W> {
|
|||
marker: PhantomData<(R, W)>,
|
||||
}
|
||||
|
||||
impl<R, W> Channel<R, W> {
|
||||
/// Channel handles only get sent in one direction.
|
||||
pub fn handle(&self) -> ChannelHandle<R, W, true> {
|
||||
ChannelHandle {
|
||||
marker: Default::default(),
|
||||
chan_id: self.id,
|
||||
multiplexer: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: Send + DeserializeOwned, W: Serialize> Channel<R, W> {
|
||||
pub async fn lock(&self) -> ChannelLock<'_, R, W> {
|
||||
ChannelLock {
|
||||
|
|
Loading…
Reference in New Issue