Introduce channel handles that can be passed in serde messages
This commit is contained in:
parent
183f365032
commit
b659a5ddac
|
@ -2,11 +2,13 @@ use crate::CnrError;
|
||||||
use crate::CnrError::Closed;
|
use crate::CnrError::Closed;
|
||||||
use dashmap::mapref::entry::Entry;
|
use dashmap::mapref::entry::Entry;
|
||||||
use dashmap::DashMap;
|
use dashmap::DashMap;
|
||||||
use log::error;
|
use log::{error, warn};
|
||||||
use serde::de::DeserializeOwned;
|
use serde::de::DeserializeOwned;
|
||||||
use serde::Serialize;
|
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
||||||
|
use std::borrow::{Borrow, BorrowMut};
|
||||||
|
use std::cell::RefCell;
|
||||||
use std::marker::PhantomData;
|
use std::marker::PhantomData;
|
||||||
use std::sync::Arc;
|
use std::sync::{Arc, Weak};
|
||||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||||
use tokio::sync::mpsc::{Receiver, Sender};
|
use tokio::sync::mpsc::{Receiver, Sender};
|
||||||
use tokio::sync::{Mutex, MutexGuard};
|
use tokio::sync::{Mutex, MutexGuard};
|
||||||
|
@ -28,7 +30,10 @@ pub struct TransportMultiplexer {
|
||||||
|
|
||||||
/// Channel receivers for channels that have received messages but haven't yet been claimed
|
/// Channel receivers for channels that have received messages but haven't yet been claimed
|
||||||
unclaimed_channels: Arc<DashMap<u16, Receiver<Vec<u8>>>>,
|
unclaimed_channels: Arc<DashMap<u16, Receiver<Vec<u8>>>>,
|
||||||
// TODO channels:
|
}
|
||||||
|
|
||||||
|
thread_local! {
|
||||||
|
static CURRENT_MULTIPLEXER: RefCell<Weak<TransportMultiplexer>> = Default::default();
|
||||||
}
|
}
|
||||||
|
|
||||||
impl TransportMultiplexer {
|
impl TransportMultiplexer {
|
||||||
|
@ -36,7 +41,7 @@ impl TransportMultiplexer {
|
||||||
rx: R,
|
rx: R,
|
||||||
tx: W,
|
tx: W,
|
||||||
initiator: bool,
|
initiator: bool,
|
||||||
) -> Result<TransportMultiplexer, CnrError> {
|
) -> Result<Arc<TransportMultiplexer>, CnrError> {
|
||||||
let (txqueue_tx, txqueue_rx) = tokio::sync::mpsc::channel(8);
|
let (txqueue_tx, txqueue_rx) = tokio::sync::mpsc::channel(8);
|
||||||
|
|
||||||
let channels = Arc::new(Default::default());
|
let channels = Arc::new(Default::default());
|
||||||
|
@ -57,19 +62,19 @@ impl TransportMultiplexer {
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
Ok(TransportMultiplexer {
|
Ok(Arc::new(TransportMultiplexer {
|
||||||
task_tx,
|
task_tx,
|
||||||
task_rx,
|
task_rx,
|
||||||
initiator,
|
initiator,
|
||||||
tx_queue: txqueue_tx,
|
tx_queue: txqueue_tx,
|
||||||
channels,
|
channels,
|
||||||
unclaimed_channels,
|
unclaimed_channels,
|
||||||
})
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn open_channel_with_id<R, W>(&self, channel_id: u16) -> Option<Channel<R, W>> {
|
pub fn open_channel_with_id<R, W>(self: &Arc<Self>, channel_id: u16) -> Option<Channel<R, W>> {
|
||||||
match self.channels.entry(channel_id) {
|
match self.channels.entry(channel_id) {
|
||||||
Entry::Occupied(oe) => {
|
Entry::Occupied(_) => {
|
||||||
match self.unclaimed_channels.remove(&channel_id) {
|
match self.unclaimed_channels.remove(&channel_id) {
|
||||||
Some((_k, chan_rx)) => Some(Channel {
|
Some((_k, chan_rx)) => Some(Channel {
|
||||||
id: channel_id,
|
id: channel_id,
|
||||||
|
@ -77,6 +82,7 @@ impl TransportMultiplexer {
|
||||||
chan_id: channel_id,
|
chan_id: channel_id,
|
||||||
tx: self.tx_queue.clone(),
|
tx: self.tx_queue.clone(),
|
||||||
rx: chan_rx,
|
rx: chan_rx,
|
||||||
|
multiplexer: Arc::downgrade(&self),
|
||||||
})),
|
})),
|
||||||
marker: Default::default(),
|
marker: Default::default(),
|
||||||
}),
|
}),
|
||||||
|
@ -95,6 +101,7 @@ impl TransportMultiplexer {
|
||||||
chan_id: channel_id,
|
chan_id: channel_id,
|
||||||
tx: self.tx_queue.clone(),
|
tx: self.tx_queue.clone(),
|
||||||
rx: chan_rx,
|
rx: chan_rx,
|
||||||
|
multiplexer: Arc::downgrade(&self),
|
||||||
})),
|
})),
|
||||||
marker: Default::default(),
|
marker: Default::default(),
|
||||||
})
|
})
|
||||||
|
@ -134,7 +141,8 @@ impl TransportMultiplexer {
|
||||||
Entry::Occupied(oe) => {
|
Entry::Occupied(oe) => {
|
||||||
if let Err(err) = oe.get().send(buf).await {
|
if let Err(err) = oe.get().send(buf).await {
|
||||||
// TODO this channel has died. What can we do about it?
|
// TODO this channel has died. What can we do about it?
|
||||||
todo!();
|
warn!("Message received but channel {} dead", chan_id);
|
||||||
|
// TODO maybe we should clean it up at this point...
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Entry::Vacant(ve) => {
|
Entry::Vacant(ve) => {
|
||||||
|
@ -148,14 +156,51 @@ impl TransportMultiplexer {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct ChannelHandle<LTR, RTL, const initiator: bool> {
|
pub struct ChannelHandle<LTR, RTL, const INITIATOR: bool> {
|
||||||
marker: PhantomData<(LTR, RTL)>,
|
marker: PhantomData<(LTR, RTL)>,
|
||||||
|
chan_id: u16,
|
||||||
|
multiplexer: Weak<TransportMultiplexer>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<LTR, RTL> ChannelHandle<LTR, RTL, false> {
|
||||||
|
pub fn into_channel(self) -> Option<Channel<RTL, LTR>> {
|
||||||
|
let multiplexer = self.multiplexer.upgrade()?;
|
||||||
|
multiplexer.open_channel_with_id(self.chan_id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//impl<LTR, RTL, const INITIATOR: bool> Serialize for ChannelHandle<LTR, RTL, INITIATOR> {
|
||||||
|
impl<LTR, RTL> Serialize for ChannelHandle<LTR, RTL, true> {
|
||||||
|
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||||
|
where
|
||||||
|
S: Serializer,
|
||||||
|
{
|
||||||
|
self.chan_id.serialize(serializer)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//impl<'de, LTR, RTL, const INITIATOR: bool> Deserialize<'de> for ChannelHandle<LTR, RTL, INITIATOR> {
|
||||||
|
impl<'de, LTR, RTL> Deserialize<'de> for ChannelHandle<LTR, RTL, false> {
|
||||||
|
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||||
|
where
|
||||||
|
D: Deserializer<'de>,
|
||||||
|
{
|
||||||
|
// TODO We could check the initiator flags are the right way around by embedding a bool in
|
||||||
|
// the format.
|
||||||
|
let chan_id = u16::deserialize(deserializer)?;
|
||||||
|
Ok(ChannelHandle {
|
||||||
|
marker: Default::default(),
|
||||||
|
chan_id,
|
||||||
|
multiplexer: CURRENT_MULTIPLEXER.with(|refcell| refcell.borrow().clone()),
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ChannelInner {
|
struct ChannelInner {
|
||||||
chan_id: u16,
|
chan_id: u16,
|
||||||
tx: Sender<(u16, Vec<u8>)>,
|
tx: Sender<(u16, Vec<u8>)>,
|
||||||
rx: Receiver<Vec<u8>>,
|
rx: Receiver<Vec<u8>>,
|
||||||
|
multiplexer: Weak<TransportMultiplexer>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct ChannelLock<'a, R, W> {
|
pub struct ChannelLock<'a, R, W> {
|
||||||
|
@ -176,10 +221,18 @@ impl<'a, R: DeserializeOwned, W: Serialize> ChannelLock<'a, R, W> {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn recv(&mut self) -> Result<R, CnrError> {
|
pub async fn recv(&mut self) -> Result<R, CnrError> {
|
||||||
match self.guard.rx.recv().await {
|
let new_weak = self.guard.multiplexer.clone();
|
||||||
|
CURRENT_MULTIPLEXER.with(move |refcell| {
|
||||||
|
*(refcell.borrow_mut()) = new_weak;
|
||||||
|
});
|
||||||
|
let res = match self.guard.rx.recv().await {
|
||||||
Some(bytes) => Ok(serde_bare::from_slice(&bytes)?),
|
Some(bytes) => Ok(serde_bare::from_slice(&bytes)?),
|
||||||
None => Err(CnrError::Closed),
|
None => Err(CnrError::Closed),
|
||||||
}
|
};
|
||||||
|
CURRENT_MULTIPLEXER.with(|refcell| {
|
||||||
|
*(refcell.borrow_mut()) = Weak::new();
|
||||||
|
});
|
||||||
|
res
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -189,6 +242,17 @@ pub struct Channel<R, W> {
|
||||||
marker: PhantomData<(R, W)>,
|
marker: PhantomData<(R, W)>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<R, W> Channel<R, W> {
|
||||||
|
/// Channel handles only get sent in one direction.
|
||||||
|
pub fn handle(&self) -> ChannelHandle<R, W, true> {
|
||||||
|
ChannelHandle {
|
||||||
|
marker: Default::default(),
|
||||||
|
chan_id: self.id,
|
||||||
|
multiplexer: Default::default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<R: Send + DeserializeOwned, W: Serialize> Channel<R, W> {
|
impl<R: Send + DeserializeOwned, W: Serialize> Channel<R, W> {
|
||||||
pub async fn lock(&self) -> ChannelLock<'_, R, W> {
|
pub async fn lock(&self) -> ChannelLock<'_, R, W> {
|
||||||
ChannelLock {
|
ChannelLock {
|
||||||
|
|
Loading…
Reference in New Issue