From 7b2ebc5fe9b60453c2e4d4623fd7428e63ec1bce Mon Sep 17 00:00:00 2001 From: Jacob Date: Wed, 19 Mar 2025 20:51:07 -0400 Subject: [PATCH] feat: initial store implementation --- Cargo.toml | 7 ++- src/lib.rs | 1 + src/store/mod.rs | 116 +++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 123 insertions(+), 1 deletion(-) create mode 100644 src/store/mod.rs diff --git a/Cargo.toml b/Cargo.toml index 5482764..73da326 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,7 +30,12 @@ serde_json = "1.0.140" shadow-rs = { version = "1.0.1", default-features = false } snafu = "0.8.5" strum = { version = "0.27.1", features = ["derive"] } -tokio = { version = "1.32.0", features = ["rt", "rt-multi-thread", "time"] } +tokio = { version = "1.32.0", features = [ + "rt", + "rt-multi-thread", + "sync", + "time", +] } tracing = "0.1.37" tracing-appender = "0.2.3" tracing-subscriber = "0.3.17" diff --git a/src/lib.rs b/src/lib.rs index a30816d..39f3769 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -17,6 +17,7 @@ use tracing_to_home_assistant::TracingToHomeAssistant; mod arbitrary; mod home_assistant; mod python_utils; +mod store; mod tracing_to_home_assistant; shadow!(build_info); diff --git a/src/store/mod.rs b/src/store/mod.rs new file mode 100644 index 0000000..b1836bb --- /dev/null +++ b/src/store/mod.rs @@ -0,0 +1,116 @@ +use std::future::Future; + +use tokio::{ + sync::{mpsc, watch}, + task::{JoinError, JoinHandle}, +}; + +#[derive(Debug)] +pub struct PublisherStream { + receiver: mpsc::Receiver>, +} + +impl PublisherStream { + pub async fn wait(&mut self) -> Option> { + self.receiver.recv().await + } +} + +#[derive(Debug)] +pub struct Publisher { + sender: watch::Sender, +} + +impl Publisher { + pub async fn all_unsubscribed(&self) { + self.sender.closed().await + } + + pub fn publish(&self, value: T) { + self.sender.send_replace(value); + } +} + +#[derive(Debug)] +pub struct Store { + sender: watch::Sender, + publisher_sender: mpsc::Sender>, + producer_join_handle: JoinHandle<()>, +} + +impl Store { + pub fn new + Send + 'static>( + initial: T, + producer: impl FnOnce(PublisherStream) -> Fut, + ) -> Self { + let (sender, _) = watch::channel(initial); + let (publisher_sender, publisher_receiver) = mpsc::channel(1); + + let subscribers_stream = PublisherStream { + receiver: publisher_receiver, + }; + + let producer_join_handle = tokio::spawn(producer(subscribers_stream)); + + Self { + publisher_sender, + sender, + producer_join_handle, + } + } + + pub fn subscribe(&self) -> Result, ProducerExited> { + let receiver = self.sender.subscribe(); + + if self.sender.receiver_count() == 1 { + if let Err(e) = self.publisher_sender.try_send(Publisher { + sender: self.sender.clone(), + }) { + match e { + mpsc::error::TrySendError::Full(_) => unreachable!(), + mpsc::error::TrySendError::Closed(_) => return Err(ProducerExited), + } + } + } + + Ok(Subscription { receiver }) + } + + /// Signify that no one can ever subscribe again, + /// and wait for the producer task to complete. + pub fn run(self) -> impl Future> { + self.producer_join_handle + } +} + +pub struct Subscription { + receiver: watch::Receiver, +} + +#[derive(Debug, Clone, Copy)] +pub struct ProducerExited; + +impl Subscription { + pub async fn changed(&mut self) -> Result<(), ProducerExited> { + self.receiver.changed().await.map_err(|_| ProducerExited) + } + + pub fn get(&mut self) -> T::Owned + where + T: ToOwned, + { + self.receiver.borrow_and_update().to_owned() + } + + pub async fn for_each>(mut self, mut func: impl FnMut(T::Owned) -> Fut) + where + T: ToOwned, + { + loop { + func(self.get()).await; + if self.changed().await.is_err() { + return; + } + } + } +}