Compare commits

...

3 Commits

9 changed files with 218 additions and 64 deletions

View File

@@ -3,11 +3,13 @@ use chrono_tz::Tz;
use ijson::{IArray, INumber, IObject, IString, IValue}; use ijson::{IArray, INumber, IObject, IString, IValue};
#[cfg(feature = "pyo3")] #[cfg(feature = "pyo3")]
use pyo3::{ use pyo3::{
exceptions::{PyTypeError, PyValueError}, exceptions::{PyException, PyTypeError, PyValueError},
prelude::*, prelude::*,
types::{PyList, PyNone}, types::{PyList, PyNone},
}; };
use snafu::Snafu; use snafu::{ResultExt, Snafu};
use crate::finite_f64::NotFinite;
use super::{finite_f64::FiniteF64, map::Map, map_key::MapKey}; use super::{finite_f64::FiniteF64, map::Map, map_key::MapKey};
@@ -73,22 +75,64 @@ impl From<Arbitrary> for IValue {
} }
#[cfg(feature = "pyo3")] #[cfg(feature = "pyo3")]
impl<'py> FromPyObject<'py> for Arbitrary { #[derive(Debug, Snafu)]
fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult<Self> { pub enum ExtractArbitraryError {
/// error getting the qualified type name when trying to report
/// that an instance of this type cannot be extracted as an [`Arbitrary`]
GetTypeNameError { source: PyErr },
/// error extracting the (successfully retrieved) fully qualified type name as a [`String`]
ExtractTypeNameError { source: PyErr },
/// the float trying to be extracted is not finite, which isn't supported
FloatNotFinite { source: NotFinite },
/// can't extract an arbitrary from a {type_name}
UnsupportedType { type_name: String },
}
#[cfg(feature = "pyo3")]
impl From<ExtractArbitraryError> for PyErr {
fn from(error: ExtractArbitraryError) -> Self {
match &error {
ExtractArbitraryError::GetTypeNameError { .. } => {
PyException::new_err(error.to_string())
}
ExtractArbitraryError::ExtractTypeNameError { .. } => {
PyException::new_err(error.to_string())
}
ExtractArbitraryError::FloatNotFinite { .. } => {
PyValueError::new_err(error.to_string())
}
ExtractArbitraryError::UnsupportedType { .. } => {
PyTypeError::new_err(error.to_string())
}
}
}
}
#[cfg(feature = "pyo3")]
impl<'a, 'py> FromPyObject<'a, 'py> for Arbitrary {
type Error = ExtractArbitraryError;
fn extract(ob: Borrowed<'a, 'py, PyAny>) -> Result<Self, Self::Error> {
if let Ok(map_key) = ob.extract::<MapKey>() { if let Ok(map_key) = ob.extract::<MapKey>() {
Ok(map_key.into()) Ok(map_key.into())
} else if let Ok(map) = ob.extract() { } else if let Ok(map) = ob.extract() {
Ok(Self::Map(map)) Ok(Self::Map(map))
} else if let Ok(f) = ob.extract::<f64>() { } else if let Ok(f) = ob.extract::<f64>() {
let f = FiniteF64::try_from(f).map_err(|err| PyValueError::new_err(err.to_string()))?; let f = FiniteF64::try_from(f).context(FloatNotFiniteSnafu)?;
Ok(Self::Float(f)) Ok(Self::Float(f))
} else if let Ok(vec) = ob.extract() { } else if let Ok(vec) = ob.extract() {
Ok(Self::Array(vec)) Ok(Self::Array(vec))
} else { } else {
let type_name = ob.get_type().fully_qualified_name()?; let type_name = ob
Err(PyTypeError::new_err(format!( .get_type()
"can't extract an arbitrary from a {type_name}" .fully_qualified_name()
))) .context(GetTypeNameSnafu)?
.extract()
.context(ExtractTypeNameSnafu)?;
Err(ExtractArbitraryError::UnsupportedType { type_name })
} }
} }
} }

View File

@@ -10,6 +10,7 @@ use pyo3::{
prelude::*, prelude::*,
types::{PyNone, PyTuple}, types::{PyNone, PyTuple},
}; };
use snafu::{ResultExt, Snafu};
use super::arbitrary::{Arbitrary, MapKeyFromArbitraryError}; use super::arbitrary::{Arbitrary, MapKeyFromArbitraryError};
@@ -43,9 +44,40 @@ impl Display for MapKey {
} }
#[cfg(feature = "pyo3")] #[cfg(feature = "pyo3")]
impl<'py> FromPyObject<'py> for MapKey { #[derive(Debug, Snafu)]
fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult<Self> { pub enum ExtractMapKeyError {
if let Ok(_none) = ob.downcast::<PyNone>() { /// error getting the qualified type name when trying to report
/// that an instance of this type cannot be extracted as an [`Arbitrary`]
GetTypeNameError { source: PyErr },
/// error extracting the (successfully retrieved) fully qualified type name as a [`String`]
ExtractTypeNameError { source: PyErr },
/// can't extract a map key from a {type_name}
UnsupportedType { type_name: String },
}
#[cfg(feature = "pyo3")]
impl From<ExtractMapKeyError> for PyErr {
fn from(error: ExtractMapKeyError) -> Self {
use pyo3::exceptions::PyException;
match &error {
ExtractMapKeyError::GetTypeNameError { .. } => PyException::new_err(error.to_string()),
ExtractMapKeyError::ExtractTypeNameError { .. } => {
PyException::new_err(error.to_string())
}
ExtractMapKeyError::UnsupportedType { .. } => PyTypeError::new_err(error.to_string()),
}
}
}
#[cfg(feature = "pyo3")]
impl<'a, 'py> FromPyObject<'a, 'py> for MapKey {
type Error = ExtractMapKeyError;
fn extract(ob: Borrowed<'a, 'py, PyAny>) -> Result<Self, Self::Error> {
if let Ok(_none) = ob.cast::<PyNone>() {
Ok(Self::Null) Ok(Self::Null)
} else if let Ok(b) = ob.extract() { } else if let Ok(b) = ob.extract() {
Ok(Self::Bool(b)) Ok(Self::Bool(b))
@@ -56,10 +88,13 @@ impl<'py> FromPyObject<'py> for MapKey {
} else if let Ok(tuple) = ob.extract() { } else if let Ok(tuple) = ob.extract() {
Ok(Self::Tuple(tuple)) Ok(Self::Tuple(tuple))
} else { } else {
let type_name = ob.get_type().fully_qualified_name()?; let type_name = ob
Err(PyTypeError::new_err(format!( .get_type()
"can't extract a map key from a {type_name}" .fully_qualified_name()
))) .context(GetTypeNameSnafu)?
.extract()
.context(ExtractTypeNameSnafu)?;
Err(ExtractMapKeyError::UnsupportedType { type_name })
} }
} }
} }

View File

@@ -85,7 +85,7 @@ pub struct PreferredStateChoice {
} }
#[derive(Debug, SerializeDisplay, DeserializeFromStr)] #[derive(Debug, SerializeDisplay, DeserializeFromStr)]
struct MacAddressWithoutSeparators(MacAddress); pub struct MacAddressWithoutSeparators(MacAddress);
impl FromStr for MacAddressWithoutSeparators { impl FromStr for MacAddressWithoutSeparators {
type Err = MacParseError; type Err = MacParseError;
@@ -119,46 +119,46 @@ impl Display for MacAddressWithoutSeparators {
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
enum ActiveMode { pub enum ActiveMode {
#[serde(rename = "none")] #[serde(rename = "none")]
None, None,
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
struct CtrlProtocols { pub struct CtrlProtocols {
name: String, name: String,
version: String, version: String,
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
struct DeviceId(pub String); pub struct DeviceId(pub String);
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
enum DevState { pub enum DevState {
#[serde(rename = "normal")] #[serde(rename = "normal")]
Normal, Normal,
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
struct HardwareId(pub String); pub struct HardwareId(pub String);
#[derive(Debug, Deserialize_repr)] #[derive(Debug, Deserialize_repr)]
#[repr(u8)] #[repr(u8)]
enum IsColor { pub enum IsColor {
NoColor = 0, NoColor = 0,
Color = 1, Color = 1,
} }
#[derive(Debug, Deserialize_repr)] #[derive(Debug, Deserialize_repr)]
#[repr(u8)] #[repr(u8)]
enum IsDimmable { pub enum IsDimmable {
NotDimmable = 0, NotDimmable = 0,
Dimmable = 1, Dimmable = 1,
} }
#[derive(Debug, Deserialize_repr)] #[derive(Debug, Deserialize_repr)]
#[repr(u8)] #[repr(u8)]
enum IsVariableColorTemp { pub enum IsVariableColorTemp {
NoVariableColorTemp = 0, NoVariableColorTemp = 0,
VariableColorTemp = 1, VariableColorTemp = 1,
} }
@@ -228,13 +228,13 @@ impl<S> FromColor<Hsv<S, f64>> for Hsb {
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
struct KelvinWithBrightness { pub struct KelvinWithBrightness {
kelvin: Kelvin, kelvin: Kelvin,
brightness: Percentage, brightness: Percentage,
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
enum Color { pub enum Color {
HSB(Hsb), HSB(Hsb),
KelvinWithBrightness(KelvinWithBrightness), KelvinWithBrightness(KelvinWithBrightness),
} }
@@ -343,26 +343,26 @@ pub enum LightState {
} }
#[derive(Debug, Clone, Deserialize)] #[derive(Debug, Clone, Deserialize)]
struct DftOnState { pub struct DftOnState {
#[serde(flatten)] #[serde(flatten)]
color: Color, color: Color,
mode: LightStateMode, mode: LightStateMode,
} }
#[derive(Debug, Clone, Deserialize)] #[derive(Debug, Clone, Deserialize)]
enum LightStateMode { pub enum LightStateMode {
#[serde(rename = "normal")] #[serde(rename = "normal")]
Normal, Normal,
} }
#[derive(Debug, Clone, Deserialize)] #[derive(Debug, Clone, Deserialize)]
enum MicType { pub enum MicType {
#[serde(rename = "IOT.SMARTBULB")] #[serde(rename = "IOT.SMARTBULB")]
IotSmartbulb, IotSmartbulb,
} }
#[derive(Debug, Clone, Deserialize)] #[derive(Debug, Clone, Deserialize)]
struct OemId(pub String); pub struct OemId(pub String);
#[derive(Debug, Clone, Serialize)] #[derive(Debug, Clone, Serialize)]
pub struct SetLightStateArgs { pub struct SetLightStateArgs {

View File

@@ -7,4 +7,5 @@ license = { workspace = true }
[dependencies] [dependencies]
deranged = { workspace = true } deranged = { workspace = true }
ext-trait = { workspace = true } ext-trait = { workspace = true }
snafu = { workspace = true }
tokio = { workspace = true, features = ["sync"] } tokio = { workspace = true, features = ["sync"] }

View File

@@ -1,12 +1,9 @@
use std::{future::Future, num::NonZero};
use deranged::RangedUsize;
use tokio::{
sync::{broadcast, mpsc},
task::JoinHandle,
};
use super::ProducerExited; use super::ProducerExited;
use deranged::RangedUsize;
use snafu::Snafu;
use std::{future::Future, num::NonZero};
use tokio::sync::{broadcast, mpsc};
pub use tokio::task::JoinError;
#[derive(Debug)] #[derive(Debug)]
pub struct Publisher<T> { pub struct Publisher<T> {
@@ -47,7 +44,7 @@ impl<T> Emitter<T> {
pub fn new<R, Fut>( pub fn new<R, Fut>(
producer: impl FnOnce(PublisherStream<T>) -> Fut, producer: impl FnOnce(PublisherStream<T>) -> Fut,
capacity: Capacity, capacity: Capacity,
) -> (Self, JoinHandle<R>) ) -> (Self, impl Future<Output = Result<R, JoinError>>)
where where
Fut: Future<Output = R> + Send + 'static, Fut: Future<Output = R> + Send + 'static,
T: Clone, T: Clone,
@@ -93,8 +90,12 @@ pub struct Subscription<T> {
receiver: broadcast::Receiver<T>, receiver: broadcast::Receiver<T>,
} }
#[derive(Debug, Clone, Snafu)]
pub enum NextError { pub enum NextError {
ProducerExited(ProducerExited), /// the producer backing this emitter exited
ProducerExited { source: ProducerExited },
/// the broadcast channel underlying this emitter lagged and {skipped_events} events were skipped
Lagged { skipped_events: NonZero<u64> }, Lagged { skipped_events: NonZero<u64> },
} }
@@ -104,7 +105,9 @@ impl<T> Subscription<T> {
T: Clone, T: Clone,
{ {
self.receiver.recv().await.map_err(|err| match err { self.receiver.recv().await.map_err(|err| match err {
broadcast::error::RecvError::Closed => NextError::ProducerExited(ProducerExited), broadcast::error::RecvError::Closed => NextError::ProducerExited {
source: ProducerExited,
},
broadcast::error::RecvError::Lagged(skipped_events) => NextError::Lagged { broadcast::error::RecvError::Lagged(skipped_events) => NextError::Lagged {
skipped_events: skipped_events skipped_events: skipped_events
.try_into() .try_into()

View File

@@ -1,11 +1,17 @@
use ext_trait::extension; use std::future::Future;
use tokio::{select, task::JoinHandle};
use super::emitter::{Capacity, Emitter, NextError}; use ext_trait::extension;
use tokio::select;
use super::emitter::{Capacity, Emitter, JoinError, NextError};
#[extension(pub trait EmitterExt)] #[extension(pub trait EmitterExt)]
impl<T> Emitter<T> { impl<T> Emitter<T> {
fn map<M, F>(self, mut func: F, capacity: Capacity) -> (Emitter<M>, JoinHandle<()>) fn map<M, F>(
self,
mut func: F,
capacity: Capacity,
) -> (Emitter<M>, impl Future<Output = Result<(), JoinError>>)
where where
T: Send + 'static + Clone, T: Send + 'static + Clone,
M: Send + 'static + Clone, M: Send + 'static + Clone,
@@ -28,7 +34,7 @@ impl<T> Emitter<T> {
match event_res { match event_res {
Ok(event) => publisher.publish(func(event)), Ok(event) => publisher.publish(func(event)),
Err(NextError::Lagged { .. }) => {}, Err(NextError::Lagged { .. }) => {},
Err(NextError::ProducerExited(_)) => return, Err(NextError::ProducerExited { .. }) => return,
} }
} }
} }
@@ -39,7 +45,11 @@ impl<T> Emitter<T> {
) )
} }
fn filter<F>(self, mut func: F, capacity: Capacity) -> (Emitter<T>, JoinHandle<()>) fn filter<F>(
self,
mut func: F,
capacity: Capacity,
) -> (Emitter<T>, impl Future<Output = Result<(), JoinError>>)
where where
T: Send + 'static + Clone, T: Send + 'static + Clone,
F: Send + 'static + FnMut(&T) -> bool, F: Send + 'static + FnMut(&T) -> bool,
@@ -63,7 +73,7 @@ impl<T> Emitter<T> {
publisher.publish(event) publisher.publish(event)
}, },
Err(NextError::Lagged { .. }) => {}, Err(NextError::Lagged { .. }) => {},
Err(NextError::ProducerExited(_)) => return, Err(NextError::ProducerExited { .. }) => return,
} }
} }
} }
@@ -74,7 +84,11 @@ impl<T> Emitter<T> {
) )
} }
fn filter_mut<F>(self, mut func: F, capacity: Capacity) -> (Emitter<T>, JoinHandle<()>) fn filter_mut<F>(
self,
mut func: F,
capacity: Capacity,
) -> (Emitter<T>, impl Future<Output = Result<(), JoinError>>)
where where
T: Send + 'static + Clone, T: Send + 'static + Clone,
F: Send + 'static + FnMut(&mut T) -> bool, F: Send + 'static + FnMut(&mut T) -> bool,
@@ -98,7 +112,7 @@ impl<T> Emitter<T> {
publisher.publish(event) publisher.publish(event)
}, },
Err(NextError::Lagged { .. }) => {}, Err(NextError::Lagged { .. }) => {},
Err(NextError::ProducerExited(_)) => return, Err(NextError::ProducerExited { .. }) => return,
} }
} }
} }
@@ -109,7 +123,11 @@ impl<T> Emitter<T> {
) )
} }
fn filter_map<M, F>(self, mut func: F, capacity: Capacity) -> (Emitter<M>, JoinHandle<()>) fn filter_map<M, F>(
self,
mut func: F,
capacity: Capacity,
) -> (Emitter<M>, impl Future<Output = Result<(), JoinError>>)
where where
T: Send + 'static + Clone, T: Send + 'static + Clone,
M: Send + 'static + Clone, M: Send + 'static + Clone,
@@ -134,7 +152,7 @@ impl<T> Emitter<T> {
publisher.publish(mapped) publisher.publish(mapped)
}, },
Err(NextError::Lagged { .. }) => {}, Err(NextError::Lagged { .. }) => {},
Err(NextError::ProducerExited(_)) => return, Err(NextError::ProducerExited { .. }) => return,
} }
} }
} }

View File

@@ -1,10 +1,15 @@
pub mod emitter; use snafu::Snafu;
mod emitter_ext;
pub mod signal;
mod signal_ext;
pub mod emitter;
pub mod emitter_ext;
pub mod signal;
pub mod signal_ext;
pub use emitter::Emitter;
pub use emitter_ext::EmitterExt; pub use emitter_ext::EmitterExt;
pub use signal::Signal;
pub use signal_ext::SignalExt; pub use signal_ext::SignalExt;
#[derive(Debug, Clone, Copy)] /// the producer backing this [`Signal`] or [`Emitter`] exited
#[derive(Debug, Clone, Copy, Snafu)]
pub struct ProducerExited; pub struct ProducerExited;

View File

@@ -3,6 +3,8 @@ use std::future::Future;
use tokio::sync::{mpsc, watch}; use tokio::sync::{mpsc, watch};
pub use tokio::task::JoinError; pub use tokio::task::JoinError;
use crate::ProducerExited;
#[derive(Debug)] #[derive(Debug)]
pub struct Publisher<T> { pub struct Publisher<T> {
sender: watch::Sender<T>, sender: watch::Sender<T>,
@@ -87,9 +89,6 @@ pub struct Subscription<T> {
receiver: watch::Receiver<T>, receiver: watch::Receiver<T>,
} }
#[derive(Debug, Clone, Copy)]
pub struct ProducerExited;
impl<T> Subscription<T> { impl<T> Subscription<T> {
pub async fn changed(&mut self) -> Result<(), ProducerExited> { pub async fn changed(&mut self) -> Result<(), ProducerExited> {
self.receiver.changed().await.map_err(|_| ProducerExited) self.receiver.changed().await.map_err(|_| ProducerExited)

View File

@@ -1,6 +1,55 @@
use ext_trait::extension; use std::future::Future;
use super::signal::Signal; use ext_trait::extension;
use snafu::{ResultExt, Snafu};
use tokio::select;
use crate::ProducerExited;
use super::signal::{JoinError, Signal};
#[derive(Debug, Snafu)]
pub struct ProducerAlreadyExited {
source: ProducerExited,
}
#[extension(pub trait SignalExt)] #[extension(pub trait SignalExt)]
impl<T> Signal<T> {} impl<T> Signal<T> {
fn map<M, F>(
self,
mut func: F,
) -> Result<(Signal<M>, impl Future<Output = Result<(), JoinError>>), ProducerAlreadyExited>
where
T: 'static + Sync + Send + Clone,
M: 'static + Sync + Send + Clone,
F: 'static + Send + FnMut(T) -> M,
{
let initial = func(self.subscribe().context(ProducerAlreadyExitedSnafu)?.get());
Ok(Signal::new(initial, |mut publisher_stream| async move {
while let Some(publisher) = publisher_stream.wait().await {
let Ok(mut subscription) = self.subscribe() else {
return;
};
loop {
select! {
biased;
_ = publisher.all_unsubscribed() => {
break;
}
changed_res = subscription.changed() => {
match changed_res {
Ok(()) => {
let value = subscription.get();
publisher.publish(func(value))
},
Err(ProducerExited) => return,
}
}
}
}
}
}))
}
}