use std::future::Future; use tokio::{ sync::{mpsc, watch}, task::{JoinError, JoinHandle}, }; #[derive(Debug)] pub struct PublisherStream { receiver: mpsc::Receiver>, } impl PublisherStream { pub async fn wait(&mut self) -> Option> { self.receiver.recv().await } } #[derive(Debug)] pub struct Publisher { sender: watch::Sender, } impl Publisher { pub async fn all_unsubscribed(&self) { self.sender.closed().await } pub fn publish(&self, value: T) { self.sender.send_replace(value); } } #[derive(Debug)] pub struct Store { sender: watch::Sender, publisher_sender: mpsc::Sender>, producer_join_handle: JoinHandle<()>, } impl Store { pub fn new + Send + 'static>( initial: T, producer: impl FnOnce(PublisherStream) -> Fut, ) -> Self { let (sender, _) = watch::channel(initial); let (publisher_sender, publisher_receiver) = mpsc::channel(1); let subscribers_stream = PublisherStream { receiver: publisher_receiver, }; let producer_join_handle = tokio::spawn(producer(subscribers_stream)); Self { publisher_sender, sender, producer_join_handle, } } pub fn subscribe(&self) -> Result, ProducerExited> { let receiver = self.sender.subscribe(); if self.sender.receiver_count() == 1 { if let Err(e) = self.publisher_sender.try_send(Publisher { sender: self.sender.clone(), }) { match e { mpsc::error::TrySendError::Full(_) => unreachable!(), mpsc::error::TrySendError::Closed(_) => return Err(ProducerExited), } } } Ok(Subscription { receiver }) } /// Signify that no one can ever subscribe again, /// and wait for the producer task to complete. pub fn run(self) -> impl Future> { self.producer_join_handle } } pub struct Subscription { receiver: watch::Receiver, } #[derive(Debug, Clone, Copy)] pub struct ProducerExited; impl Subscription { pub async fn changed(&mut self) -> Result<(), ProducerExited> { self.receiver.changed().await.map_err(|_| ProducerExited) } pub fn get(&mut self) -> T::Owned where T: ToOwned, { self.receiver.borrow_and_update().to_owned() } pub async fn for_each>(mut self, mut func: impl FnMut(T::Owned) -> Fut) where T: ToOwned, { loop { func(self.get()).await; if self.changed().await.is_err() { return; } } } }