Introduce channel handles that can be passed in serde messages
This commit is contained in:
		
							parent
							
								
									183f365032
								
							
						
					
					
						commit
						b659a5ddac
					
				| @ -2,11 +2,13 @@ use crate::CnrError; | |||||||
| use crate::CnrError::Closed; | use crate::CnrError::Closed; | ||||||
| use dashmap::mapref::entry::Entry; | use dashmap::mapref::entry::Entry; | ||||||
| use dashmap::DashMap; | use dashmap::DashMap; | ||||||
| use log::error; | use log::{error, warn}; | ||||||
| use serde::de::DeserializeOwned; | use serde::de::DeserializeOwned; | ||||||
| use serde::Serialize; | use serde::{Deserialize, Deserializer, Serialize, Serializer}; | ||||||
|  | use std::borrow::{Borrow, BorrowMut}; | ||||||
|  | use std::cell::RefCell; | ||||||
| use std::marker::PhantomData; | use std::marker::PhantomData; | ||||||
| use std::sync::Arc; | use std::sync::{Arc, Weak}; | ||||||
| use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; | use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; | ||||||
| use tokio::sync::mpsc::{Receiver, Sender}; | use tokio::sync::mpsc::{Receiver, Sender}; | ||||||
| use tokio::sync::{Mutex, MutexGuard}; | use tokio::sync::{Mutex, MutexGuard}; | ||||||
| @ -28,7 +30,10 @@ pub struct TransportMultiplexer { | |||||||
| 
 | 
 | ||||||
|     /// Channel receivers for channels that have received messages but haven't yet been claimed
 |     /// Channel receivers for channels that have received messages but haven't yet been claimed
 | ||||||
|     unclaimed_channels: Arc<DashMap<u16, Receiver<Vec<u8>>>>, |     unclaimed_channels: Arc<DashMap<u16, Receiver<Vec<u8>>>>, | ||||||
|     // TODO channels:
 | } | ||||||
|  | 
 | ||||||
|  | thread_local! { | ||||||
|  |     static CURRENT_MULTIPLEXER: RefCell<Weak<TransportMultiplexer>> = Default::default(); | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| impl TransportMultiplexer { | impl TransportMultiplexer { | ||||||
| @ -36,7 +41,7 @@ impl TransportMultiplexer { | |||||||
|         rx: R, |         rx: R, | ||||||
|         tx: W, |         tx: W, | ||||||
|         initiator: bool, |         initiator: bool, | ||||||
|     ) -> Result<TransportMultiplexer, CnrError> { |     ) -> Result<Arc<TransportMultiplexer>, CnrError> { | ||||||
|         let (txqueue_tx, txqueue_rx) = tokio::sync::mpsc::channel(8); |         let (txqueue_tx, txqueue_rx) = tokio::sync::mpsc::channel(8); | ||||||
| 
 | 
 | ||||||
|         let channels = Arc::new(Default::default()); |         let channels = Arc::new(Default::default()); | ||||||
| @ -57,19 +62,19 @@ impl TransportMultiplexer { | |||||||
|             } |             } | ||||||
|         }); |         }); | ||||||
| 
 | 
 | ||||||
|         Ok(TransportMultiplexer { |         Ok(Arc::new(TransportMultiplexer { | ||||||
|             task_tx, |             task_tx, | ||||||
|             task_rx, |             task_rx, | ||||||
|             initiator, |             initiator, | ||||||
|             tx_queue: txqueue_tx, |             tx_queue: txqueue_tx, | ||||||
|             channels, |             channels, | ||||||
|             unclaimed_channels, |             unclaimed_channels, | ||||||
|         }) |         })) | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     pub fn open_channel_with_id<R, W>(&self, channel_id: u16) -> Option<Channel<R, W>> { |     pub fn open_channel_with_id<R, W>(self: &Arc<Self>, channel_id: u16) -> Option<Channel<R, W>> { | ||||||
|         match self.channels.entry(channel_id) { |         match self.channels.entry(channel_id) { | ||||||
|             Entry::Occupied(oe) => { |             Entry::Occupied(_) => { | ||||||
|                 match self.unclaimed_channels.remove(&channel_id) { |                 match self.unclaimed_channels.remove(&channel_id) { | ||||||
|                     Some((_k, chan_rx)) => Some(Channel { |                     Some((_k, chan_rx)) => Some(Channel { | ||||||
|                         id: channel_id, |                         id: channel_id, | ||||||
| @ -77,6 +82,7 @@ impl TransportMultiplexer { | |||||||
|                             chan_id: channel_id, |                             chan_id: channel_id, | ||||||
|                             tx: self.tx_queue.clone(), |                             tx: self.tx_queue.clone(), | ||||||
|                             rx: chan_rx, |                             rx: chan_rx, | ||||||
|  |                             multiplexer: Arc::downgrade(&self), | ||||||
|                         })), |                         })), | ||||||
|                         marker: Default::default(), |                         marker: Default::default(), | ||||||
|                     }), |                     }), | ||||||
| @ -95,6 +101,7 @@ impl TransportMultiplexer { | |||||||
|                         chan_id: channel_id, |                         chan_id: channel_id, | ||||||
|                         tx: self.tx_queue.clone(), |                         tx: self.tx_queue.clone(), | ||||||
|                         rx: chan_rx, |                         rx: chan_rx, | ||||||
|  |                         multiplexer: Arc::downgrade(&self), | ||||||
|                     })), |                     })), | ||||||
|                     marker: Default::default(), |                     marker: Default::default(), | ||||||
|                 }) |                 }) | ||||||
| @ -134,7 +141,8 @@ impl TransportMultiplexer { | |||||||
|                 Entry::Occupied(oe) => { |                 Entry::Occupied(oe) => { | ||||||
|                     if let Err(err) = oe.get().send(buf).await { |                     if let Err(err) = oe.get().send(buf).await { | ||||||
|                         // TODO this channel has died. What can we do about it?
 |                         // TODO this channel has died. What can we do about it?
 | ||||||
|                         todo!(); |                         warn!("Message received but channel {} dead", chan_id); | ||||||
|  |                         // TODO maybe we should clean it up at this point...
 | ||||||
|                     } |                     } | ||||||
|                 } |                 } | ||||||
|                 Entry::Vacant(ve) => { |                 Entry::Vacant(ve) => { | ||||||
| @ -148,14 +156,51 @@ impl TransportMultiplexer { | |||||||
|     } |     } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| pub struct ChannelHandle<LTR, RTL, const initiator: bool> { | pub struct ChannelHandle<LTR, RTL, const INITIATOR: bool> { | ||||||
|     marker: PhantomData<(LTR, RTL)>, |     marker: PhantomData<(LTR, RTL)>, | ||||||
|  |     chan_id: u16, | ||||||
|  |     multiplexer: Weak<TransportMultiplexer>, | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl<LTR, RTL> ChannelHandle<LTR, RTL, false> { | ||||||
|  |     pub fn into_channel(self) -> Option<Channel<RTL, LTR>> { | ||||||
|  |         let multiplexer = self.multiplexer.upgrade()?; | ||||||
|  |         multiplexer.open_channel_with_id(self.chan_id) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | //impl<LTR, RTL, const INITIATOR: bool> Serialize for ChannelHandle<LTR, RTL, INITIATOR> {
 | ||||||
|  | impl<LTR, RTL> Serialize for ChannelHandle<LTR, RTL, true> { | ||||||
|  |     fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> | ||||||
|  |     where | ||||||
|  |         S: Serializer, | ||||||
|  |     { | ||||||
|  |         self.chan_id.serialize(serializer) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | //impl<'de, LTR, RTL, const INITIATOR: bool> Deserialize<'de> for ChannelHandle<LTR, RTL, INITIATOR> {
 | ||||||
|  | impl<'de, LTR, RTL> Deserialize<'de> for ChannelHandle<LTR, RTL, false> { | ||||||
|  |     fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> | ||||||
|  |     where | ||||||
|  |         D: Deserializer<'de>, | ||||||
|  |     { | ||||||
|  |         // TODO We could check the initiator flags are the right way around by embedding a bool in
 | ||||||
|  |         //      the format.
 | ||||||
|  |         let chan_id = u16::deserialize(deserializer)?; | ||||||
|  |         Ok(ChannelHandle { | ||||||
|  |             marker: Default::default(), | ||||||
|  |             chan_id, | ||||||
|  |             multiplexer: CURRENT_MULTIPLEXER.with(|refcell| refcell.borrow().clone()), | ||||||
|  |         }) | ||||||
|  |     } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| struct ChannelInner { | struct ChannelInner { | ||||||
|     chan_id: u16, |     chan_id: u16, | ||||||
|     tx: Sender<(u16, Vec<u8>)>, |     tx: Sender<(u16, Vec<u8>)>, | ||||||
|     rx: Receiver<Vec<u8>>, |     rx: Receiver<Vec<u8>>, | ||||||
|  |     multiplexer: Weak<TransportMultiplexer>, | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| pub struct ChannelLock<'a, R, W> { | pub struct ChannelLock<'a, R, W> { | ||||||
| @ -176,10 +221,18 @@ impl<'a, R: DeserializeOwned, W: Serialize> ChannelLock<'a, R, W> { | |||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     pub async fn recv(&mut self) -> Result<R, CnrError> { |     pub async fn recv(&mut self) -> Result<R, CnrError> { | ||||||
|         match self.guard.rx.recv().await { |         let new_weak = self.guard.multiplexer.clone(); | ||||||
|  |         CURRENT_MULTIPLEXER.with(move |refcell| { | ||||||
|  |             *(refcell.borrow_mut()) = new_weak; | ||||||
|  |         }); | ||||||
|  |         let res = match self.guard.rx.recv().await { | ||||||
|             Some(bytes) => Ok(serde_bare::from_slice(&bytes)?), |             Some(bytes) => Ok(serde_bare::from_slice(&bytes)?), | ||||||
|             None => Err(CnrError::Closed), |             None => Err(CnrError::Closed), | ||||||
|         } |         }; | ||||||
|  |         CURRENT_MULTIPLEXER.with(|refcell| { | ||||||
|  |             *(refcell.borrow_mut()) = Weak::new(); | ||||||
|  |         }); | ||||||
|  |         res | ||||||
|     } |     } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| @ -189,6 +242,17 @@ pub struct Channel<R, W> { | |||||||
|     marker: PhantomData<(R, W)>, |     marker: PhantomData<(R, W)>, | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | impl<R, W> Channel<R, W> { | ||||||
|  |     /// Channel handles only get sent in one direction.
 | ||||||
|  |     pub fn handle(&self) -> ChannelHandle<R, W, true> { | ||||||
|  |         ChannelHandle { | ||||||
|  |             marker: Default::default(), | ||||||
|  |             chan_id: self.id, | ||||||
|  |             multiplexer: Default::default(), | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
| impl<R: Send + DeserializeOwned, W: Serialize> Channel<R, W> { | impl<R: Send + DeserializeOwned, W: Serialize> Channel<R, W> { | ||||||
|     pub async fn lock(&self) -> ChannelLock<'_, R, W> { |     pub async fn lock(&self) -> ChannelLock<'_, R, W> { | ||||||
|         ChannelLock { |         ChannelLock { | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user