From 38e89f31f449524f9d89be7356f574842731e622 Mon Sep 17 00:00:00 2001 From: Jacob Date: Mon, 21 Apr 2025 16:41:34 -0400 Subject: [PATCH] chore: convert into a workspace --- Cargo.toml | 5 + entrypoint/Cargo.toml | 62 ++++++++ build.rs => entrypoint/build.rs | 0 .../src}/home_assistant/domain.rs | 0 .../src}/home_assistant/entity_id.rs | 2 +- .../home_assistant/event/context/context.rs | 39 +++++ .../src/home_assistant/event/context/id.rs | 38 +++++ .../src}/home_assistant/event/context/mod.rs | 0 .../src}/home_assistant/event/event.rs | 0 .../src}/home_assistant/event/event_origin.rs | 0 .../src}/home_assistant/event/mod.rs | 0 .../src}/home_assistant/event/specific/mod.rs | 0 .../event/specific/state_changed.rs | 58 ++++++++ .../src}/home_assistant/home_assistant.rs | 7 +- .../src/home_assistant/light/attributes.rs | 8 + entrypoint/src/home_assistant/light/mod.rs | 54 +++++++ .../src/home_assistant/light/protocol.rs | 103 +++++++++++++ .../src/home_assistant/light/service/mod.rs | 2 + .../home_assistant/light/service/turn_off.rs | 33 +++++ .../home_assistant/light/service/turn_on.rs | 32 ++++ entrypoint/src/home_assistant/light/state.rs | 22 +++ .../src}/home_assistant/logger.rs | 16 +- {src => entrypoint/src}/home_assistant/mod.rs | 5 + entrypoint/src/home_assistant/object_id.rs | 21 +++ entrypoint/src/home_assistant/service/mod.rs | 11 ++ .../home_assistant/service/service_domain.rs | 21 +++ .../src/home_assistant/service/service_id.rs | 21 +++ .../src/home_assistant/service_registry.rs | 56 +++++++ .../src/home_assistant/slug.rs | 25 ++-- entrypoint/src/home_assistant/state.rs | 71 +++++++++ .../src}/home_assistant/state_machine.rs | 13 +- entrypoint/src/home_assistant/state_object.rs | 139 ++++++++++++++++++ entrypoint/src/lib.rs | 86 +++++++++++ {src => entrypoint/src}/python_utils.rs | 2 +- .../src}/tracing_to_home_assistant.rs | 0 src/home_assistant/event/context/context.rs | 14 -- src/home_assistant/event/context/id.rs | 20 --- .../event/specific/state_changed.rs | 37 ----- src/home_assistant/state.rs | 17 --- src/lib.rs | 61 -------- src/store/mod.rs | 116 --------------- 41 files changed, 926 insertions(+), 291 deletions(-) create mode 100644 entrypoint/Cargo.toml rename build.rs => entrypoint/build.rs (100%) rename {src => entrypoint/src}/home_assistant/domain.rs (100%) rename {src => entrypoint/src}/home_assistant/entity_id.rs (97%) create mode 100644 entrypoint/src/home_assistant/event/context/context.rs create mode 100644 entrypoint/src/home_assistant/event/context/id.rs rename {src => entrypoint/src}/home_assistant/event/context/mod.rs (100%) rename {src => entrypoint/src}/home_assistant/event/event.rs (100%) rename {src => entrypoint/src}/home_assistant/event/event_origin.rs (100%) rename {src => entrypoint/src}/home_assistant/event/mod.rs (100%) rename {src => entrypoint/src}/home_assistant/event/specific/mod.rs (100%) create mode 100644 entrypoint/src/home_assistant/event/specific/state_changed.rs rename {src => entrypoint/src}/home_assistant/home_assistant.rs (86%) create mode 100644 entrypoint/src/home_assistant/light/attributes.rs create mode 100644 entrypoint/src/home_assistant/light/mod.rs create mode 100644 entrypoint/src/home_assistant/light/protocol.rs create mode 100644 entrypoint/src/home_assistant/light/service/mod.rs create mode 100644 entrypoint/src/home_assistant/light/service/turn_off.rs create mode 100644 entrypoint/src/home_assistant/light/service/turn_on.rs create mode 100644 entrypoint/src/home_assistant/light/state.rs rename {src => entrypoint/src}/home_assistant/logger.rs (92%) rename {src => entrypoint/src}/home_assistant/mod.rs (60%) create mode 100644 entrypoint/src/home_assistant/object_id.rs create mode 100644 entrypoint/src/home_assistant/service/mod.rs create mode 100644 entrypoint/src/home_assistant/service/service_domain.rs create mode 100644 entrypoint/src/home_assistant/service/service_id.rs create mode 100644 entrypoint/src/home_assistant/service_registry.rs rename src/home_assistant/object_id.rs => entrypoint/src/home_assistant/slug.rs (66%) create mode 100644 entrypoint/src/home_assistant/state.rs rename {src => entrypoint/src}/home_assistant/state_machine.rs (69%) create mode 100644 entrypoint/src/home_assistant/state_object.rs create mode 100644 entrypoint/src/lib.rs rename {src => entrypoint/src}/python_utils.rs (90%) rename {src => entrypoint/src}/tracing_to_home_assistant.rs (100%) delete mode 100644 src/home_assistant/event/context/context.rs delete mode 100644 src/home_assistant/event/context/id.rs delete mode 100644 src/home_assistant/event/specific/state_changed.rs delete mode 100644 src/home_assistant/state.rs delete mode 100644 src/lib.rs delete mode 100644 src/store/mod.rs diff --git a/Cargo.toml b/Cargo.toml index 4d8382e..df173c2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,11 @@ [workspace] members = [ "arbitrary-value", + "driver/kasa", + "emitter-and-signal", + "entrypoint", + "protocol", +] resolver = "2" [workspace.dependencies] diff --git a/entrypoint/Cargo.toml b/entrypoint/Cargo.toml new file mode 100644 index 0000000..feb99cd --- /dev/null +++ b/entrypoint/Cargo.toml @@ -0,0 +1,62 @@ +[package] +name = "smart-home-in-rust-with-home-assistant" +version = "0.2.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[lib] +name = "smart_home_in_rust_with_home_assistant" +crate-type = ["cdylib"] + +[dependencies] +arbitrary-value = { path = "../arbitrary-value", features = ["pyo3"] } +arc-swap = "1.7.1" +async-gate = "0.4.0" +axum = { version = "0.8.1", default-features = false, features = [ + "http1", + "tokio", +] } +chrono = { workspace = true } +chrono-tz = { workspace = true } +deranged = { workspace = true, features = ["serde"] } +derive_more = { workspace = true, features = [ + "display", + "from", + "from_str", + "into", + "try_from", + "try_into", +] } +driver-kasa = { path = "../driver/kasa" } +emitter-and-signal = { path = "../emitter-and-signal" } +im = { version = "15.1.0", features = ["rayon"] } +once_cell = "1.21.3" +protocol = { path = "../protocol" } +pyo3 = { workspace = true, features = [ + "auto-initialize", + "chrono", + "extension-module", +] } +pyo3-async-runtimes = { workspace = true, features = [ + "attributes", + "tokio-runtime", +] } +shadow-rs = { version = "1.0.1", default-features = false } +smol_str = "0.3.2" +snafu = { workspace = true } +strum = { version = "0.27.1", features = ["derive"] } +tokio = { workspace = true, features = [ + "macros", + "rt", + "rt-multi-thread", + "sync", + "time", +] } +tracing = { workspace = true } +tracing-appender = "0.2.3" +tracing-subscriber = "0.3.17" +ulid = "1.2.0" +uom = "0.36.0" + +[build-dependencies] +shadow-rs = "1.0.1" diff --git a/build.rs b/entrypoint/build.rs similarity index 100% rename from build.rs rename to entrypoint/build.rs diff --git a/src/home_assistant/domain.rs b/entrypoint/src/home_assistant/domain.rs similarity index 100% rename from src/home_assistant/domain.rs rename to entrypoint/src/home_assistant/domain.rs diff --git a/src/home_assistant/entity_id.rs b/entrypoint/src/home_assistant/entity_id.rs similarity index 97% rename from src/home_assistant/entity_id.rs rename to entrypoint/src/home_assistant/entity_id.rs index ef3356a..282329f 100644 --- a/src/home_assistant/entity_id.rs +++ b/entrypoint/src/home_assistant/entity_id.rs @@ -59,7 +59,7 @@ impl<'py> FromPyObject<'py> for EntityId { } } -impl<'py> IntoPyObject<'py> for &EntityId { +impl<'py> IntoPyObject<'py> for EntityId { type Target = PyString; type Output = Bound<'py, Self::Target>; type Error = Infallible; diff --git a/entrypoint/src/home_assistant/event/context/context.rs b/entrypoint/src/home_assistant/event/context/context.rs new file mode 100644 index 0000000..e073a8e --- /dev/null +++ b/entrypoint/src/home_assistant/event/context/context.rs @@ -0,0 +1,39 @@ +use super::id::Id; +use once_cell::sync::OnceCell; +use pyo3::{prelude::*, types::PyType}; + +/// The context that triggered something. +#[derive(Debug, FromPyObject)] +pub struct Context { + pub id: Id, + pub user_id: Option, + pub parent_id: Option, + /// In order to prevent cycles, the user must decide to pass [`Py`] for the `Event` type here + /// or for the `Context` type in [`Event`] + pub origin_event: Event, +} + +impl<'py, Event: IntoPyObject<'py>> IntoPyObject<'py> for Context { + type Target = PyAny; + + type Output = Bound<'py, Self::Target>; + + type Error = PyErr; + + fn into_pyobject(self, py: Python<'py>) -> Result { + static HOMEASSISTANT_CORE: OnceCell> = OnceCell::new(); + + let homeassistant_core = HOMEASSISTANT_CORE + .get_or_try_init(|| Result::<_, PyErr>::Ok(py.import("homeassistant.core")?.unbind()))? + .bind(py); + + let context_class = homeassistant_core.getattr("Context")?; + let context_class = context_class.downcast_into::()?; + + let context_instance = context_class.call1((self.user_id, self.parent_id, self.id))?; + + context_instance.setattr("origin_event", self.origin_event)?; + + Ok(context_instance) + } +} diff --git a/entrypoint/src/home_assistant/event/context/id.rs b/entrypoint/src/home_assistant/event/context/id.rs new file mode 100644 index 0000000..69e715c --- /dev/null +++ b/entrypoint/src/home_assistant/event/context/id.rs @@ -0,0 +1,38 @@ +use std::convert::Infallible; + +use pyo3::{prelude::*, types::PyString}; +use smol_str::SmolStr; +use ulid::Ulid; + +#[derive(Debug, Clone)] +pub enum Id { + Ulid(Ulid), + Other(SmolStr), +} + +impl<'py> FromPyObject<'py> for Id { + fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult { + let s = ob.extract::()?; + + if let Ok(ulid) = s.parse() { + Ok(Id::Ulid(ulid)) + } else { + Ok(Id::Other(s.into())) + } + } +} + +impl<'py> IntoPyObject<'py> for Id { + type Target = PyString; + + type Output = Bound<'py, Self::Target>; + + type Error = Infallible; + + fn into_pyobject(self, py: Python<'py>) -> Result { + match self { + Id::Ulid(ulid) => ulid.to_string().into_pyobject(py), + Id::Other(id) => id.as_str().into_pyobject(py), + } + } +} diff --git a/src/home_assistant/event/context/mod.rs b/entrypoint/src/home_assistant/event/context/mod.rs similarity index 100% rename from src/home_assistant/event/context/mod.rs rename to entrypoint/src/home_assistant/event/context/mod.rs diff --git a/src/home_assistant/event/event.rs b/entrypoint/src/home_assistant/event/event.rs similarity index 100% rename from src/home_assistant/event/event.rs rename to entrypoint/src/home_assistant/event/event.rs diff --git a/src/home_assistant/event/event_origin.rs b/entrypoint/src/home_assistant/event/event_origin.rs similarity index 100% rename from src/home_assistant/event/event_origin.rs rename to entrypoint/src/home_assistant/event/event_origin.rs diff --git a/src/home_assistant/event/mod.rs b/entrypoint/src/home_assistant/event/mod.rs similarity index 100% rename from src/home_assistant/event/mod.rs rename to entrypoint/src/home_assistant/event/mod.rs diff --git a/src/home_assistant/event/specific/mod.rs b/entrypoint/src/home_assistant/event/specific/mod.rs similarity index 100% rename from src/home_assistant/event/specific/mod.rs rename to entrypoint/src/home_assistant/event/specific/mod.rs diff --git a/entrypoint/src/home_assistant/event/specific/state_changed.rs b/entrypoint/src/home_assistant/event/specific/state_changed.rs new file mode 100644 index 0000000..5643688 --- /dev/null +++ b/entrypoint/src/home_assistant/event/specific/state_changed.rs @@ -0,0 +1,58 @@ +use pyo3::exceptions::PyValueError; +use pyo3::prelude::*; + +use crate::home_assistant::{entity_id::EntityId, state_object::StateObject}; + +#[derive(Debug, Clone)] +pub struct Type; + +impl<'py> FromPyObject<'py> for Type { + fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult { + let s = ob.extract::<&str>()?; + + if s == "state_changed" { + Ok(Type) + } else { + Err(PyValueError::new_err(format!( + "expected a string of value 'state_changed', but got {s}" + ))) + } + } +} + +#[derive(Debug, FromPyObject)] +#[pyo3(from_item_all)] +pub struct Data< + OldState, + OldAttributes, + OldStateContextEvent, + NewState, + NewAttributes, + NewStateContextEvent, +> { + pub entity_id: EntityId, + pub old_state: Option>, + pub new_state: Option>, +} + +/// A state changed event is fired when on state write the state is changed. +pub type Event< + OldState, + OldAttributes, + OldStateContextEvent, + NewState, + NewAttributes, + NewStateContextEvent, + Context, +> = super::super::event::Event< + Type, + Data< + OldState, + OldAttributes, + OldStateContextEvent, + NewState, + NewAttributes, + NewStateContextEvent, + >, + Context, +>; diff --git a/src/home_assistant/home_assistant.rs b/entrypoint/src/home_assistant/home_assistant.rs similarity index 86% rename from src/home_assistant/home_assistant.rs rename to entrypoint/src/home_assistant/home_assistant.rs index 6104eea..3625a6b 100644 --- a/src/home_assistant/home_assistant.rs +++ b/entrypoint/src/home_assistant/home_assistant.rs @@ -4,7 +4,7 @@ use pyo3::prelude::*; use crate::python_utils::{detach, validate_type_by_name}; -use super::state_machine::StateMachine; +use super::{service_registry::ServiceRegistry, state_machine::StateMachine}; #[derive(Debug)] pub struct HomeAssistant(Py); @@ -52,4 +52,9 @@ impl HomeAssistant { let states = self.0.getattr(py, "states")?; states.extract(py) } + + pub fn services(&self, py: Python<'_>) -> Result { + let services = self.0.getattr(py, "services")?; + services.extract(py) + } } diff --git a/entrypoint/src/home_assistant/light/attributes.rs b/entrypoint/src/home_assistant/light/attributes.rs new file mode 100644 index 0000000..0b4854e --- /dev/null +++ b/entrypoint/src/home_assistant/light/attributes.rs @@ -0,0 +1,8 @@ +use pyo3::prelude::*; + +#[derive(Debug, FromPyObject)] +#[pyo3(from_item_all)] +pub struct LightAttributes { + min_color_temp_kelvin: Option, // TODO: only here to allow compilation! + max_color_temp_kelvin: Option, // TODO: only here to allow compilation! +} diff --git a/entrypoint/src/home_assistant/light/mod.rs b/entrypoint/src/home_assistant/light/mod.rs new file mode 100644 index 0000000..675e4c1 --- /dev/null +++ b/entrypoint/src/home_assistant/light/mod.rs @@ -0,0 +1,54 @@ +use attributes::LightAttributes; +use pyo3::prelude::*; +use snafu::{ResultExt, Snafu}; +use state::LightState; + +use crate::home_assistant::state::HomeAssistantState; + +use super::{ + domain::Domain, entity_id::EntityId, home_assistant::HomeAssistant, object_id::ObjectId, + state_object::StateObject, +}; + +mod attributes; +mod protocol; +mod service; +mod state; + +#[derive(Debug)] +pub struct HomeAssistantLight { + pub home_assistant: HomeAssistant, + pub object_id: ObjectId, +} + +impl HomeAssistantLight { + fn entity_id(&self) -> EntityId { + EntityId(Domain::Light, self.object_id.clone()) + } +} + +#[derive(Debug, Snafu)] +pub enum GetStateObjectError { + PythonError { source: PyErr }, + EntityMissing, +} + +impl HomeAssistantLight { + fn get_state_object( + &self, + ) -> Result< + StateObject, LightAttributes, Py>, + GetStateObjectError, + > { + Python::with_gil(|py| { + let states = self.home_assistant.states(py).context(PythonSnafu)?; + let entity_id = self.entity_id(); + let state_object = states + .get(py, entity_id) + .context(PythonSnafu)? + .ok_or(GetStateObjectError::EntityMissing)?; + + Ok(state_object) + }) + } +} diff --git a/entrypoint/src/home_assistant/light/protocol.rs b/entrypoint/src/home_assistant/light/protocol.rs new file mode 100644 index 0000000..9673c2c --- /dev/null +++ b/entrypoint/src/home_assistant/light/protocol.rs @@ -0,0 +1,103 @@ +use super::service::{turn_off::TurnOff, turn_on::TurnOn}; +use super::{state::LightState, GetStateObjectError, HomeAssistantLight}; +use crate::home_assistant::{ + event::context::context::Context, + state::{ErrorState, HomeAssistantState, UnexpectedState}, +}; +use arbitrary_value::arbitrary::Arbitrary; +use protocol::light::Light; +use pyo3::prelude::*; +use snafu::{ResultExt, Snafu}; + +#[derive(Debug, Snafu)] +pub enum IsStateError { + GetStateObjectError { source: GetStateObjectError }, + Error { state: ErrorState }, + UnexpectedError { state: UnexpectedState }, +} + +impl Light for HomeAssistantLight { + type IsOnError = IsStateError; + + async fn is_on(&self) -> Result { + let state_object = self.get_state_object().context(GetStateObjectSnafu)?; + let state = state_object.state; + + match state { + HomeAssistantState::Ok(light_state) => Ok(matches!(light_state, LightState::On)), + HomeAssistantState::Err(state) => Err(IsStateError::Error { state }), + HomeAssistantState::UnexpectedErr(state) => { + Err(IsStateError::UnexpectedError { state }) + } + } + } + + type IsOffError = IsStateError; + + async fn is_off(&self) -> Result { + let state_object = self.get_state_object().context(GetStateObjectSnafu)?; + let state = state_object.state; + + match state { + HomeAssistantState::Ok(light_state) => Ok(matches!(light_state, LightState::Off)), + HomeAssistantState::Err(state) => Err(IsStateError::Error { state }), + HomeAssistantState::UnexpectedErr(state) => { + Err(IsStateError::UnexpectedError { state }) + } + } + } + + type TurnOnError = PyErr; + + async fn turn_on(&mut self) -> Result<(), Self::TurnOnError> { + let context: Option> = None; + let target: Option<()> = None; + + let services = Python::with_gil(|py| self.home_assistant.services(py))?; + // TODO + let service_response: Arbitrary = services + .call_service( + TurnOn { + entity_id: self.entity_id(), + }, + context, + target, + false, + ) + .await?; + + // TODO + tracing::info!(?service_response); + + Ok(()) + } + + type TurnOffError = PyErr; + + async fn turn_off(&mut self) -> Result<(), Self::TurnOffError> { + let context: Option> = None; + let target: Option<()> = None; + + let services = Python::with_gil(|py| self.home_assistant.services(py))?; + // TODO + let service_response: Arbitrary // TODO: a type that validates as None + = services + .call_service( + TurnOff { + entity_id: self.entity_id(), + }, + context, + target, + false, + ) + .await?; + + Ok(()) + } + + type ToggleError = PyErr; + + async fn toggle(&mut self) -> Result<(), Self::ToggleError> { + todo!() + } +} diff --git a/entrypoint/src/home_assistant/light/service/mod.rs b/entrypoint/src/home_assistant/light/service/mod.rs new file mode 100644 index 0000000..44dde93 --- /dev/null +++ b/entrypoint/src/home_assistant/light/service/mod.rs @@ -0,0 +1,2 @@ +pub mod turn_off; +pub mod turn_on; diff --git a/entrypoint/src/home_assistant/light/service/turn_off.rs b/entrypoint/src/home_assistant/light/service/turn_off.rs new file mode 100644 index 0000000..b6c4fa2 --- /dev/null +++ b/entrypoint/src/home_assistant/light/service/turn_off.rs @@ -0,0 +1,33 @@ +use std::str::FromStr; + +use pyo3::IntoPyObject; + +use crate::home_assistant::{ + entity_id::EntityId, + service::{service_domain::ServiceDomain, service_id::ServiceId, IntoServiceCall}, +}; + +#[derive(Debug, Clone)] +pub struct TurnOff { + pub entity_id: EntityId, +} + +#[derive(Debug, Clone, IntoPyObject)] +pub struct TurnOffServiceData { + entity_id: EntityId, +} + +impl IntoServiceCall for TurnOff { + type ServiceData = TurnOffServiceData; + + fn into_service_call(self) -> (ServiceDomain, ServiceId, Self::ServiceData) { + let service_domain = ServiceDomain::from_str("light").expect("statically written and known to be a valid slug; hoping to get compiler checks instead in the future"); + let service_id = ServiceId::from_str("turn_off").expect("statically written and known to be a valid slug; hoping to get compiler checks instead in the future"); + + let Self { entity_id } = self; + + let service_data = TurnOffServiceData { entity_id }; + + (service_domain, service_id, service_data) + } +} diff --git a/entrypoint/src/home_assistant/light/service/turn_on.rs b/entrypoint/src/home_assistant/light/service/turn_on.rs new file mode 100644 index 0000000..9f14810 --- /dev/null +++ b/entrypoint/src/home_assistant/light/service/turn_on.rs @@ -0,0 +1,32 @@ +use std::{convert::Infallible, str::FromStr}; + +use pyo3::IntoPyObject; + +use crate::home_assistant::{ + entity_id::EntityId, + service::{service_domain::ServiceDomain, service_id::ServiceId, IntoServiceCall}, +}; + +#[derive(Debug, Clone)] +pub struct TurnOn { + pub entity_id: EntityId, +} + +#[derive(Debug, Clone, IntoPyObject)] +pub struct TurnOnServiceData { + entity_id: EntityId, +} + +impl IntoServiceCall for TurnOn { + type ServiceData = TurnOnServiceData; + + fn into_service_call(self) -> (ServiceDomain, ServiceId, Self::ServiceData) { + let service_domain = ServiceDomain::from_str("light").expect("statically written and known to be a valid slug; hoping to get compiler checks instead in the future"); + let service_id = ServiceId::from_str("turn_on").expect("statically written and known to be a valid slug; hoping to get compiler checks instead in the future"); + + let Self { entity_id } = self; + let service_data = TurnOnServiceData { entity_id }; + + (service_domain, service_id, service_data) + } +} diff --git a/entrypoint/src/home_assistant/light/state.rs b/entrypoint/src/home_assistant/light/state.rs new file mode 100644 index 0000000..bb418c4 --- /dev/null +++ b/entrypoint/src/home_assistant/light/state.rs @@ -0,0 +1,22 @@ +use std::str::FromStr; + +use pyo3::{exceptions::PyValueError, prelude::*}; +use strum::EnumString; + +#[derive(Debug, Clone, EnumString, strum::Display)] +#[strum(serialize_all = "snake_case")] +pub enum LightState { + On, + Off, +} + +impl<'py> FromPyObject<'py> for LightState { + fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult { + let s = ob.extract::()?; + + let state = + LightState::from_str(&s).map_err(|err| PyValueError::new_err(err.to_string()))?; + + Ok(state) + } +} diff --git a/src/home_assistant/logger.rs b/entrypoint/src/home_assistant/logger.rs similarity index 92% rename from src/home_assistant/logger.rs rename to entrypoint/src/home_assistant/logger.rs index 7f0f190..474d27f 100644 --- a/src/home_assistant/logger.rs +++ b/entrypoint/src/home_assistant/logger.rs @@ -1,10 +1,8 @@ +use crate::python_utils::{detach, validate_type_by_name}; +use arbitrary_value::{arbitrary::Arbitrary, map::Map}; +use once_cell::sync::OnceCell; use pyo3::{prelude::*, types::PyTuple}; -use crate::{ - arbitrary::{arbitrary::Arbitrary, map::Map}, - python_utils::{detach, validate_type_by_name}, -}; - #[derive(Debug)] pub struct HassLogger(Py); @@ -55,8 +53,12 @@ pub struct LogData { impl HassLogger { pub fn new(py: Python<'_>, name: &str) -> PyResult { - let logging = py.import("logging")?; - let logger = logging.call_method1("getLogger", (name,))?; + static LOGGING_MODULE: OnceCell> = OnceCell::new(); + + let logging_module = LOGGING_MODULE + .get_or_try_init(|| Result::<_, PyErr>::Ok(py.import("logging")?.unbind()))? + .bind(py); + let logger = logging_module.call_method1("getLogger", (name,))?; Ok(logger.extract()?) } diff --git a/src/home_assistant/mod.rs b/entrypoint/src/home_assistant/mod.rs similarity index 60% rename from src/home_assistant/mod.rs rename to entrypoint/src/home_assistant/mod.rs index 4282014..80cdbf3 100644 --- a/src/home_assistant/mod.rs +++ b/entrypoint/src/home_assistant/mod.rs @@ -2,7 +2,12 @@ pub mod domain; pub mod entity_id; pub mod event; pub mod home_assistant; +pub mod light; pub mod logger; pub mod object_id; +pub mod service; +pub mod service_registry; +pub mod slug; pub mod state; pub mod state_machine; +pub mod state_object; diff --git a/entrypoint/src/home_assistant/object_id.rs b/entrypoint/src/home_assistant/object_id.rs new file mode 100644 index 0000000..b19a236 --- /dev/null +++ b/entrypoint/src/home_assistant/object_id.rs @@ -0,0 +1,21 @@ +use std::convert::Infallible; + +use pyo3::{prelude::*, types::PyString}; + +use super::slug::Slug; + +pub use super::slug::SlugParsingError as ObjectIdParsingError; + +#[derive(Debug, Clone, derive_more::Display, derive_more::FromStr)] +pub struct ObjectId(pub Slug); + +impl<'py> IntoPyObject<'py> for ObjectId { + type Target = PyString; + type Output = Bound<'py, Self::Target>; + type Error = Infallible; + + fn into_pyobject(self, py: Python<'py>) -> Result { + let s = self.to_string(); + s.into_pyobject(py) + } +} diff --git a/entrypoint/src/home_assistant/service/mod.rs b/entrypoint/src/home_assistant/service/mod.rs new file mode 100644 index 0000000..f378030 --- /dev/null +++ b/entrypoint/src/home_assistant/service/mod.rs @@ -0,0 +1,11 @@ +use service_domain::ServiceDomain; +use service_id::ServiceId; + +pub mod service_domain; +pub mod service_id; + +pub trait IntoServiceCall { + type ServiceData; + + fn into_service_call(self) -> (ServiceDomain, ServiceId, Self::ServiceData); +} diff --git a/entrypoint/src/home_assistant/service/service_domain.rs b/entrypoint/src/home_assistant/service/service_domain.rs new file mode 100644 index 0000000..2a4bc7b --- /dev/null +++ b/entrypoint/src/home_assistant/service/service_domain.rs @@ -0,0 +1,21 @@ +use std::convert::Infallible; + +use pyo3::{prelude::*, types::PyString}; + +use super::super::slug::Slug; + +pub use super::super::slug::SlugParsingError as ServiceDomainParsingError; + +#[derive(Debug, Clone, derive_more::Display, derive_more::FromStr)] +pub struct ServiceDomain(pub Slug); + +impl<'py> IntoPyObject<'py> for ServiceDomain { + type Target = PyString; + type Output = Bound<'py, Self::Target>; + type Error = Infallible; + + fn into_pyobject(self, py: Python<'py>) -> Result { + let s = self.to_string(); + s.into_pyobject(py) + } +} diff --git a/entrypoint/src/home_assistant/service/service_id.rs b/entrypoint/src/home_assistant/service/service_id.rs new file mode 100644 index 0000000..2807424 --- /dev/null +++ b/entrypoint/src/home_assistant/service/service_id.rs @@ -0,0 +1,21 @@ +use std::convert::Infallible; + +use pyo3::{prelude::*, types::PyString}; + +use super::super::slug::Slug; + +pub use super::super::slug::SlugParsingError as ServiceIdParsingError; + +#[derive(Debug, Clone, derive_more::Display, derive_more::FromStr)] +pub struct ServiceId(pub Slug); + +impl<'py> IntoPyObject<'py> for ServiceId { + type Target = PyString; + type Output = Bound<'py, Self::Target>; + type Error = Infallible; + + fn into_pyobject(self, py: Python<'py>) -> Result { + let s = self.to_string(); + s.into_pyobject(py) + } +} diff --git a/entrypoint/src/home_assistant/service_registry.rs b/entrypoint/src/home_assistant/service_registry.rs new file mode 100644 index 0000000..f916ba3 --- /dev/null +++ b/entrypoint/src/home_assistant/service_registry.rs @@ -0,0 +1,56 @@ +use pyo3::prelude::*; + +use crate::python_utils::{detach, validate_type_by_name}; + +use super::{event::context::context::Context, service::IntoServiceCall}; + +#[derive(Debug)] +pub struct ServiceRegistry(Py); + +impl<'py> FromPyObject<'py> for ServiceRegistry { + fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult { + // region: Validation + validate_type_by_name(ob, "ServiceRegistry")?; + // endregion: Validation + + Ok(Self(detach(ob))) + } +} + +impl ServiceRegistry { + pub async fn call_service< + ServiceData: for<'py> IntoPyObject<'py>, + Target: for<'py> IntoPyObject<'py>, + Event: for<'py> IntoPyObject<'py>, + ServiceResponse: for<'py> FromPyObject<'py>, + >( + &self, + service_call: impl IntoServiceCall, + context: Option>, + target: Option, + return_response: bool, + ) -> PyResult { + let (domain, service, service_data) = service_call.into_service_call(); + + let blocking = true; + + let args = ( + domain, + service, + service_data, + blocking, + context, + target, + return_response, + ); + + let future = Python::with_gil::<_, PyResult<_>>(|py| { + let service_registry = self.0.bind(py); + let awaitable = service_registry.call_method("async_call", args, None)?; + pyo3_async_runtimes::tokio::into_future(awaitable) + })?; + + let service_response = future.await?; + Python::with_gil(|py| service_response.extract(py)) + } +} diff --git a/src/home_assistant/object_id.rs b/entrypoint/src/home_assistant/slug.rs similarity index 66% rename from src/home_assistant/object_id.rs rename to entrypoint/src/home_assistant/slug.rs index 454369a..729daf3 100644 --- a/src/home_assistant/object_id.rs +++ b/entrypoint/src/home_assistant/slug.rs @@ -1,19 +1,26 @@ -use std::{str::FromStr, sync::Arc}; +use std::str::FromStr; use pyo3::{exceptions::PyValueError, PyErr}; +use smol_str::SmolStr; use snafu::Snafu; #[derive(Debug, Clone, derive_more::Display)] -pub struct ObjectId(Arc); +pub struct Slug(SmolStr); #[derive(Debug, Clone, Snafu)] #[snafu(display("expected a lowercase ASCII alphabetical character (i.e. a through z) or a digit (i.e. 0 through 9) or an underscore (i.e. _) but encountered {encountered}"))] -pub struct ObjectIdParsingError { +pub struct SlugParsingError { encountered: char, } -impl FromStr for ObjectId { - type Err = ObjectIdParsingError; +impl From for PyErr { + fn from(error: SlugParsingError) -> Self { + PyValueError::new_err(error.to_string()) + } +} + +impl FromStr for Slug { + type Err = SlugParsingError; fn from_str(s: &str) -> Result { for c in s.chars() { @@ -21,16 +28,10 @@ impl FromStr for ObjectId { 'a'..='z' => {} '0'..='9' => {} '_' => {} - _ => return Err(ObjectIdParsingError { encountered: c }), + _ => return Err(SlugParsingError { encountered: c }), } } Ok(Self(s.into())) } } - -impl From for PyErr { - fn from(error: ObjectIdParsingError) -> Self { - PyValueError::new_err(error.to_string()) - } -} diff --git a/entrypoint/src/home_assistant/state.rs b/entrypoint/src/home_assistant/state.rs new file mode 100644 index 0000000..a6b48b7 --- /dev/null +++ b/entrypoint/src/home_assistant/state.rs @@ -0,0 +1,71 @@ +use std::{convert::Infallible, str::FromStr}; + +use pyo3::{exceptions::PyValueError, prelude::*}; +use smol_str::SmolStr; +use strum::EnumString; + +/// A state in Home Assistant that is known to represent an error of some kind: +/// * `unavailable` (the device is likely offline or unreachable from the Home Assistant instance) +/// * `unknown` (I don't know how to explain this one) +#[derive(Debug, Clone, EnumString, strum::Display)] +#[strum(serialize_all = "snake_case")] +pub enum ErrorState { + Unavailable, + Unknown, +} + +impl<'py> FromPyObject<'py> for ErrorState { + fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult { + let s = ob.extract::()?; + + let state = + ErrorState::from_str(&s).map_err(|err| PyValueError::new_err(err.to_string()))?; + + Ok(state) + } +} + +#[derive(Debug, Clone, derive_more::Display, derive_more::FromStr)] +pub struct UnexpectedState(pub SmolStr); + +impl<'py> FromPyObject<'py> for UnexpectedState { + fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult { + let s = ob.extract::()?; + let s = SmolStr::new(s); + + Ok(UnexpectedState(s)) + } +} + +#[derive(Debug, Clone, derive_more::Display)] +pub enum HomeAssistantState { + Ok(State), + Err(ErrorState), + UnexpectedErr(UnexpectedState), +} + +impl FromStr for HomeAssistantState { + type Err = Infallible; + + fn from_str(s: &str) -> Result::Err> { + if let Ok(ok) = State::from_str(s) { + return Ok(HomeAssistantState::Ok(ok)); + } + + if let Ok(error) = ErrorState::from_str(s) { + return Ok(HomeAssistantState::Err(error)); + } + + Ok(HomeAssistantState::UnexpectedErr(UnexpectedState(s.into()))) + } +} + +impl<'py, State: FromStr + FromPyObject<'py>> FromPyObject<'py> for HomeAssistantState { + fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult { + let s = ob.extract::()?; + + let Ok(state) = s.parse(); + + Ok(state) + } +} diff --git a/src/home_assistant/state_machine.rs b/entrypoint/src/home_assistant/state_machine.rs similarity index 69% rename from src/home_assistant/state_machine.rs rename to entrypoint/src/home_assistant/state_machine.rs index 3f2e80b..d2f5a18 100644 --- a/src/home_assistant/state_machine.rs +++ b/entrypoint/src/home_assistant/state_machine.rs @@ -5,7 +5,7 @@ use crate::{ python_utils::{detach, validate_type_by_name}, }; -use super::state::State; +use super::state_object::StateObject; #[derive(Debug)] pub struct StateMachine(Py); @@ -21,11 +21,16 @@ impl<'py> FromPyObject<'py> for StateMachine { } impl StateMachine { - pub fn get FromPyObject<'py>, ContextEvent: for<'py> FromPyObject<'py>>( + pub fn get< + 'py, + State: FromPyObject<'py>, + Attributes: FromPyObject<'py>, + ContextEvent: FromPyObject<'py>, + >( &self, - py: Python<'_>, + py: Python<'py>, entity_id: EntityId, - ) -> PyResult>> { + ) -> PyResult>> { let args = (entity_id.to_string(),); let state = self.0.call_method1(py, "get", args)?; state.extract(py) diff --git a/entrypoint/src/home_assistant/state_object.rs b/entrypoint/src/home_assistant/state_object.rs new file mode 100644 index 0000000..764fe20 --- /dev/null +++ b/entrypoint/src/home_assistant/state_object.rs @@ -0,0 +1,139 @@ +use super::{ + event::{context::context::Context, specific::state_changed}, + home_assistant::HomeAssistant, +}; +use crate::home_assistant::entity_id::EntityId; +use chrono::{DateTime, Utc}; +use emitter_and_signal::signal::Signal; +use once_cell::sync::OnceCell; +use pyo3::{ + prelude::*, + types::{PyCFunction, PyDict, PyTuple}, +}; +use std::{future::Future, sync::Arc}; +use tokio::{select, sync::mpsc}; + +#[derive(Debug, FromPyObject)] +pub struct StateObject { + pub entity_id: EntityId, + pub state: State, + pub attributes: Attributes, + pub last_changed: Option>, + pub last_reported: Option>, + pub last_updated: Option>, + pub context: Context, +} + +impl< + State: Send + Sync + 'static + for<'py> FromPyObject<'py>, + Attributes: Send + Sync + 'static + for<'py> FromPyObject<'py>, + ContextEvent: Send + Sync + 'static + for<'py> FromPyObject<'py>, + > StateObject +{ + pub fn store( + py: Python<'_>, + home_assistant: &HomeAssistant, + entity_id: EntityId, + ) -> PyResult<( + Signal>>, + impl Future>, + )> { + let state_machine = home_assistant.states(py)?; + let current = state_machine.get(py, entity_id.clone())?; + + let py_home_assistant = home_assistant.into_pyobject(py)?.unbind(); + + let (store, task) = Signal::new(current.map(Arc::new), |mut publisher_stream| async move { + while let Some(publisher) = publisher_stream.wait().await { + let (new_state_sender, mut new_state_receiver) = mpsc::channel(8); + + let untrack = Python::with_gil::<_, PyResult<_>>(|py| { + static EVENT_MODULE: OnceCell> = OnceCell::new(); + + let event_module = EVENT_MODULE + .get_or_try_init(|| { + Result::<_, PyErr>::Ok( + py.import("homeassistant.helpers.event")?.unbind(), + ) + })? + .bind(py); + + let untrack = { + let callback = + move |args: &Bound<'_, PyTuple>, + _kwargs: Option<&Bound<'_, PyDict>>| { + tracing::debug!("calling the closure"); + + if let Ok((event,)) = args.extract::<( + state_changed::Event< + State, + Attributes, + ContextEvent, + State, + Attributes, + ContextEvent, + Py, + >, + )>() { + let new_state = event.data.new_state; + + tracing::debug!("sending a new state"); // TODO: remove + new_state_sender.try_send(new_state).unwrap(); + } + }; + let callback = PyCFunction::new_closure(py, None, None, callback)?; + let args = ( + py_home_assistant.clone_ref(py), + vec![entity_id.clone()], + callback, + ); + event_module.call_method1("async_track_state_change_event", args)? + }; + tracing::debug!(?untrack, "as any"); + + let is_callable = untrack.is_callable(); + tracing::debug!(?is_callable); + + // let untrack = untrack.downcast_into::()?; + // tracing::debug!(?untrack, "as downcast"); + + let untrack = untrack.unbind(); + tracing::debug!(?untrack, "as unbound"); + + Ok(untrack) + }); + + if let Ok(untrack) = untrack { + tracing::debug!("untrack is ok, going to wait for the next relevant event..."); + loop { + select! { + biased; + _ = publisher.all_unsubscribed() => { + tracing::debug!("calling untrack"); + let res = Python::with_gil(|py| untrack.call0(py)); + tracing::debug!(?res); + break; + } + new_state = new_state_receiver.recv() => { + match new_state { + Some(new_state) => { + tracing::debug!("publishing new state"); + publisher.publish(new_state.map(Arc::new)) + }, + None => { + tracing::debug!("channel dropped"); + break + }, + } + } + } + } + } else { + tracing::debug!("untrack is err"); + } + } + }); + + Ok((store, task)) + } +} diff --git a/entrypoint/src/lib.rs b/entrypoint/src/lib.rs new file mode 100644 index 0000000..f1fa72b --- /dev/null +++ b/entrypoint/src/lib.rs @@ -0,0 +1,86 @@ +use std::{str::FromStr, time::Duration}; + +use driver_kasa::connection::LB130USHandle; +use home_assistant::{ + home_assistant::HomeAssistant, light::HomeAssistantLight, object_id::ObjectId, +}; +use protocol::light::Light; +use pyo3::prelude::*; +use shadow_rs::shadow; +use tokio::time::interval; +use tracing::{level_filters::LevelFilter, Level}; +use tracing_subscriber::{ + fmt::{self, format::FmtSpan}, + layer::SubscriberExt, + registry, + util::SubscriberInitExt, + Layer, +}; +use tracing_to_home_assistant::TracingToHomeAssistant; + +mod home_assistant; +mod python_utils; +mod tracing_to_home_assistant; + +shadow!(build_info); + +async fn real_main(home_assistant: HomeAssistant) -> ! { + registry() + .with( + fmt::layer() + .pretty() + .with_span_events(FmtSpan::ACTIVE) + .with_filter(LevelFilter::from_level(Level::TRACE)), + ) + .with(TracingToHomeAssistant) + .init(); + + let built_at = build_info::BUILD_TIME; + tracing::info!(built_at); + + // let lamp = HomeAssistantLight { + // home_assistant, + // object_id: ObjectId::from_str("jacob_s_lamp_top").unwrap(), + // }; + + let ip = [10, 0, 3, 71]; + let port = 9999; + + let some_light = LB130USHandle::new( + (ip, port).into(), + Duration::from_secs(10), + (64).try_into().unwrap(), + ); + + let mut interval = interval(Duration::from_secs(20)); + interval.tick().await; + loop { + interval.tick().await; + + tracing::info!("about to call get_sysinfo"); + let sysinfo_res = some_light.get_sysinfo().await; + tracing::info!(?sysinfo_res, "got sys info"); + + // let is_on = lamp.is_on().await; + // tracing::info!(?is_on); + // let is_off = lamp.is_off().await; + // tracing::info!(?is_off); + + // let something = lamp.turn_on().await; + // tracing::info!(?something); + } +} + +#[pyfunction] +fn main<'py>(py: Python<'py>, home_assistant: HomeAssistant) -> PyResult> { + pyo3_async_runtimes::tokio::future_into_py::<_, ()>(py, async { + real_main(home_assistant).await; + }) +} + +/// A Python module implemented in Rust. +#[pymodule] +fn smart_home_in_rust_with_home_assistant(module: &Bound<'_, PyModule>) -> PyResult<()> { + module.add_function(wrap_pyfunction!(main, module)?)?; + Ok(()) +} diff --git a/src/python_utils.rs b/entrypoint/src/python_utils.rs similarity index 90% rename from src/python_utils.rs rename to entrypoint/src/python_utils.rs index 49b0e1d..d958e47 100644 --- a/src/python_utils.rs +++ b/entrypoint/src/python_utils.rs @@ -1,6 +1,6 @@ use pyo3::{exceptions::PyTypeError, prelude::*}; -/// Create a GIL-independent reference (similar to [`Arc`](std::sync::Arc)) +/// Create a GIL-independent reference pub fn detach(bound: &Bound) -> Py { let py = bound.py(); bound.as_unbound().clone_ref(py) diff --git a/src/tracing_to_home_assistant.rs b/entrypoint/src/tracing_to_home_assistant.rs similarity index 100% rename from src/tracing_to_home_assistant.rs rename to entrypoint/src/tracing_to_home_assistant.rs diff --git a/src/home_assistant/event/context/context.rs b/src/home_assistant/event/context/context.rs deleted file mode 100644 index 2018adc..0000000 --- a/src/home_assistant/event/context/context.rs +++ /dev/null @@ -1,14 +0,0 @@ -use pyo3::prelude::*; - -use super::id::Id; - -/// The context that triggered something. -#[derive(Debug, FromPyObject)] -pub struct Context { - pub id: Id, - pub user_id: Option, - pub parent_id: Option, - /// In order to prevent cycles, the user must decide to pass [`Py`] for the `Event` type here - /// or for the `Context` type in [`Event`] - pub origin_event: Event, -} diff --git a/src/home_assistant/event/context/id.rs b/src/home_assistant/event/context/id.rs deleted file mode 100644 index d3517cf..0000000 --- a/src/home_assistant/event/context/id.rs +++ /dev/null @@ -1,20 +0,0 @@ -use pyo3::prelude::*; -use ulid::Ulid; - -#[derive(Debug, Clone)] -pub enum Id { - Ulid(Ulid), - Other(String), -} - -impl<'py> FromPyObject<'py> for Id { - fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult { - let s = ob.extract::()?; - - if let Ok(ulid) = s.parse() { - Ok(Id::Ulid(ulid)) - } else { - Ok(Id::Other(s)) - } - } -} diff --git a/src/home_assistant/event/specific/state_changed.rs b/src/home_assistant/event/specific/state_changed.rs deleted file mode 100644 index 42f776d..0000000 --- a/src/home_assistant/event/specific/state_changed.rs +++ /dev/null @@ -1,37 +0,0 @@ -use pyo3::exceptions::PyValueError; -use pyo3::prelude::*; - -use crate::home_assistant::{entity_id::EntityId, state::State}; - -#[derive(Debug, Clone)] -pub struct Type; - -impl<'py> FromPyObject<'py> for Type { - fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult { - let s = ob.extract::<&str>()?; - - if s == "state_changed" { - Ok(Type) - } else { - Err(PyValueError::new_err(format!( - "expected a string of value 'state_changed', but got {s}" - ))) - } - } -} - -#[derive(Debug, FromPyObject)] -#[pyo3(from_item_all)] -pub struct Data { - pub entity_id: EntityId, - pub old_state: Option>, - pub new_state: Option>, -} - -/// A state changed event is fired when on state write the state is changed. -pub type Event = - super::super::event::Event< - Type, - Data, - Context, - >; diff --git a/src/home_assistant/state.rs b/src/home_assistant/state.rs deleted file mode 100644 index 24b855d..0000000 --- a/src/home_assistant/state.rs +++ /dev/null @@ -1,17 +0,0 @@ -use chrono::{DateTime, Utc}; -use pyo3::prelude::*; - -use crate::home_assistant::entity_id::EntityId; - -use super::event::context::context::Context; - -#[derive(Debug, FromPyObject)] -pub struct State { - pub entity_id: EntityId, - pub state: String, - pub attributes: Attributes, - pub last_changed: Option>, - pub last_reported: Option>, - pub last_updated: Option>, - pub context: Context, -} diff --git a/src/lib.rs b/src/lib.rs deleted file mode 100644 index c57b6b8..0000000 --- a/src/lib.rs +++ /dev/null @@ -1,61 +0,0 @@ -use std::time::Duration; - -use home_assistant::home_assistant::HomeAssistant; -use pyo3::prelude::*; -use shadow_rs::shadow; -use tokio::time::interval; -use tracing::{level_filters::LevelFilter, Level}; -use tracing_subscriber::{ - fmt::{self, format::FmtSpan}, - layer::SubscriberExt, - registry, - util::SubscriberInitExt, - Layer, -}; -use tracing_to_home_assistant::TracingToHomeAssistant; - -mod arbitrary; -mod home_assistant; -mod python_utils; -mod store; -mod tracing_to_home_assistant; - -shadow!(build_info); - -async fn real_main(home_assistant: HomeAssistant) -> ! { - registry() - .with( - fmt::layer() - .pretty() - .with_span_events(FmtSpan::ACTIVE) - .with_filter(LevelFilter::from_level(Level::TRACE)), - ) - .with(TracingToHomeAssistant) - .init(); - - let built_at = build_info::BUILD_TIME; - tracing::info!(built_at); - - let duration = Duration::from_millis(5900); - let mut interval = interval(duration); - - loop { - let instant = interval.tick().await; - - tracing::debug!(?instant, "it is now"); - } -} - -#[pyfunction] -fn main<'py>(py: Python<'py>, home_assistant: HomeAssistant) -> PyResult> { - pyo3_async_runtimes::tokio::future_into_py::<_, ()>(py, async { - real_main(home_assistant).await; - }) -} - -/// A Python module implemented in Rust. -#[pymodule] -fn smart_home_in_rust_with_home_assistant(m: &Bound<'_, PyModule>) -> PyResult<()> { - m.add_function(wrap_pyfunction!(main, m)?)?; - Ok(()) -} diff --git a/src/store/mod.rs b/src/store/mod.rs deleted file mode 100644 index b1836bb..0000000 --- a/src/store/mod.rs +++ /dev/null @@ -1,116 +0,0 @@ -use std::future::Future; - -use tokio::{ - sync::{mpsc, watch}, - task::{JoinError, JoinHandle}, -}; - -#[derive(Debug)] -pub struct PublisherStream { - receiver: mpsc::Receiver>, -} - -impl PublisherStream { - pub async fn wait(&mut self) -> Option> { - self.receiver.recv().await - } -} - -#[derive(Debug)] -pub struct Publisher { - sender: watch::Sender, -} - -impl Publisher { - pub async fn all_unsubscribed(&self) { - self.sender.closed().await - } - - pub fn publish(&self, value: T) { - self.sender.send_replace(value); - } -} - -#[derive(Debug)] -pub struct Store { - sender: watch::Sender, - publisher_sender: mpsc::Sender>, - producer_join_handle: JoinHandle<()>, -} - -impl Store { - pub fn new + Send + 'static>( - initial: T, - producer: impl FnOnce(PublisherStream) -> Fut, - ) -> Self { - let (sender, _) = watch::channel(initial); - let (publisher_sender, publisher_receiver) = mpsc::channel(1); - - let subscribers_stream = PublisherStream { - receiver: publisher_receiver, - }; - - let producer_join_handle = tokio::spawn(producer(subscribers_stream)); - - Self { - publisher_sender, - sender, - producer_join_handle, - } - } - - pub fn subscribe(&self) -> Result, ProducerExited> { - let receiver = self.sender.subscribe(); - - if self.sender.receiver_count() == 1 { - if let Err(e) = self.publisher_sender.try_send(Publisher { - sender: self.sender.clone(), - }) { - match e { - mpsc::error::TrySendError::Full(_) => unreachable!(), - mpsc::error::TrySendError::Closed(_) => return Err(ProducerExited), - } - } - } - - Ok(Subscription { receiver }) - } - - /// Signify that no one can ever subscribe again, - /// and wait for the producer task to complete. - pub fn run(self) -> impl Future> { - self.producer_join_handle - } -} - -pub struct Subscription { - receiver: watch::Receiver, -} - -#[derive(Debug, Clone, Copy)] -pub struct ProducerExited; - -impl Subscription { - pub async fn changed(&mut self) -> Result<(), ProducerExited> { - self.receiver.changed().await.map_err(|_| ProducerExited) - } - - pub fn get(&mut self) -> T::Owned - where - T: ToOwned, - { - self.receiver.borrow_and_update().to_owned() - } - - pub async fn for_each>(mut self, mut func: impl FnMut(T::Owned) -> Fut) - where - T: ToOwned, - { - loop { - func(self.get()).await; - if self.changed().await.is_err() { - return; - } - } - } -}