diff --git a/driver/kasa/Cargo.toml b/driver/kasa/Cargo.toml index 51a1c80..68041e3 100644 --- a/driver/kasa/Cargo.toml +++ b/driver/kasa/Cargo.toml @@ -7,9 +7,11 @@ license = { workspace = true } [dependencies] backon = { workspace = true } deranged = { workspace = true } +derive_more = { workspace = true, features = ["from"] } mac_address = { version = "1.1.8", features = ["serde"] } +palette = { workspace = true } protocol = { path = "../../protocol" } -serde = { version = "1.0.219", features = ["derive"] } +serde = { workspace = true, features = ["derive"] } serde_json = "1.0.140" serde_repr = "0.1.20" serde_with = "3.12.0" diff --git a/driver/kasa/src/connection.rs b/driver/kasa/src/connection.rs index 09b5497..d84841d 100644 --- a/driver/kasa/src/connection.rs +++ b/driver/kasa/src/connection.rs @@ -1,10 +1,14 @@ -use crate::messages::{GetSysInfo, GetSysInfoResponse, LB130USSys, SysInfo}; +use crate::messages::{ + GetSysInfo, GetSysInfoResponse, LB130USSys, LightState, Off, On, SetLightLastOn, SetLightOff, + SetLightState, SetLightStateArgs, SetLightStateResponse, SetLightTo, SysInfo, +}; use backon::{FibonacciBuilder, Retryable}; -use protocol::light::{Kelvin, KelvinLight, Light, Rgb, RgbLight}; + +use serde::{Deserialize, Serialize}; use snafu::{ResultExt, Snafu}; -use std::{convert::Infallible, io, net::SocketAddr, num::NonZero, time::Duration}; +use std::{io, net::SocketAddr, num::NonZero, time::Duration}; use tokio::{ - io::{AsyncReadExt, AsyncWriteExt, BufReader, BufWriter}, + io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader, BufWriter}, net::TcpStream, sync::{mpsc, oneshot}, time::timeout, @@ -51,11 +55,23 @@ pub enum CommunicationError { WrongDevice, } +fn should_try_reconnecting(communication_error: &CommunicationError) -> bool { + matches!( + communication_error, + CommunicationError::WriteError { .. } | CommunicationError::ReadError { .. } + ) +} + #[derive(Debug)] enum LB130USMessage { GetSysInfo(oneshot::Sender>), + SetLightState( + SetLightStateArgs, + oneshot::Sender>, + ), } +#[tracing::instrument(skip(messages))] async fn lb130us_actor( addr: SocketAddr, disconnect_after_idle: Duration, @@ -116,86 +132,105 @@ async fn lb130us_actor( tracing::info!("yay connected and got a message"); - // TODO: do something match message { LB130USMessage::GetSysInfo(callback) => { - tracing::info!("going to try to get sys info for you..."); + let res = handle_get_sysinfo(writer, reader).await; - // TODO: extract to its own function - let outgoing = GetSysInfo; - let outgoing = match serde_json::to_vec(&outgoing) { - Ok(outgoing) => outgoing, - Err(err) => { - // TODO (continued) instead of doing stuff like this - let _ = - callback.send(Err(CommunicationError::SerializeError { source: err })); - continue; - } - }; - - tracing::info!(?outgoing); - - let encrypted_outgoing = into_encrypted(outgoing); - - tracing::info!(?encrypted_outgoing); - - if let Err(err) = writer.write_all(&encrypted_outgoing).await { - connection_cell.take(); - let _ = callback.send(Err(CommunicationError::WriteError { source: err })); - continue; - } - - if let Err(err) = writer.flush().await { - connection_cell.take(); - let _ = callback.send(Err(CommunicationError::WriteError { source: err })); - continue; - } - tracing::info!("sent it, now about to try to get a response"); - - let incoming_length = match reader.read_u32().await { - Ok(incoming_length) => incoming_length, - Err(err) => { + if let Err(communication_error) = &res { + if should_try_reconnecting(communication_error) { connection_cell.take(); - let _ = callback.send(Err(CommunicationError::ReadError { source: err })); - continue; } - }; - tracing::info!(?incoming_length); - - let mut incoming_message = Vec::new(); - incoming_message.resize(incoming_length as usize, 0); - if let Err(err) = reader.read_exact(&mut incoming_message).await { - connection_cell.take(); - let _ = callback.send(Err(CommunicationError::ReadError { source: err })); - continue; } - XorEncryption::<171>::decrypt_in_place(&mut incoming_message); - tracing::info!(?incoming_message); + let _ = callback.send(res); + } + LB130USMessage::SetLightState(args, callback) => { + let res = handle_set_light_state(writer, reader, args).await; - let response: GetSysInfoResponse = match serde_json::from_slice(&incoming_message) { - Ok(response) => response, - Err(err) => { - let _ = callback - .send(Err(CommunicationError::DeserializeError { source: err })); - continue; + if let Err(communication_error) = &res { + if should_try_reconnecting(communication_error) { + connection_cell.take(); } - }; - tracing::info!(?response); + } - let SysInfo::LB130US(lb130us) = response.system.get_sysinfo else { - let _ = callback.send(Err(CommunicationError::WrongDevice)); - continue; - }; - tracing::info!(?lb130us); - - let _ = callback.send(Ok(lb130us)); - tracing::info!("cool, gave a response! onto the next message!"); + let _ = callback.send(res); } } } } +#[tracing::instrument(skip(writer, reader, request))] +async fn send_request< + AW: AsyncWrite + Unpin, + AR: AsyncRead + Unpin, + Request: Serialize, + Response: for<'de> Deserialize<'de>, +>( + writer: &mut AW, + reader: &mut AR, + request: &Request, +) -> Result { + let outgoing = serde_json::to_vec(request).context(SerializeSnafu)?; + tracing::info!(?outgoing); + + let encrypted_outgoing = into_encrypted(outgoing); + tracing::info!(?encrypted_outgoing); + + writer + .write_all(&encrypted_outgoing) + .await + .context(WriteSnafu)?; + writer.flush().await.context(WriteSnafu)?; + tracing::info!("sent it, now about to try to get a response"); + + let incoming_length = reader.read_u32().await.context(ReadSnafu)?; + tracing::info!(?incoming_length); + + let mut incoming_message = Vec::new(); + incoming_message.resize(incoming_length as usize, 0); + reader + .read_exact(&mut incoming_message) + .await + .context(ReadSnafu)?; + + XorEncryption::<171>::decrypt_in_place(&mut incoming_message); + tracing::info!(?incoming_message); + + let response_as_json: serde_json::Value = + serde_json::from_slice(&incoming_message).context(DeserializeSnafu)?; + tracing::info!(?response_as_json); + + let response = Response::deserialize(response_as_json).context(DeserializeSnafu)?; + + Ok(response) +} + +#[tracing::instrument(skip(writer, reader))] +async fn handle_get_sysinfo( + writer: &mut AW, + reader: &mut AR, +) -> Result { + let request = GetSysInfo; + let response: GetSysInfoResponse = send_request(writer, reader, &request).await?; + + let SysInfo::LB130US(lb130us) = response.system.get_sysinfo else { + return Err(CommunicationError::WrongDevice); + }; + tracing::info!(?lb130us); + + Ok(lb130us) +} + +#[tracing::instrument(skip(writer, reader))] +async fn handle_set_light_state( + writer: &mut AW, + reader: &mut AR, + args: SetLightStateArgs, +) -> Result { + let request = SetLightState(args); + send_request(writer, reader, &request).await +} + #[derive(Debug, Clone)] pub struct LB130USHandle { sender: mpsc::Sender, @@ -225,52 +260,19 @@ impl LB130USHandle { .map_err(|_| HandleError::Dead)? .context(CommunicationSnafu) } -} -impl Light for LB130USHandle { - type IsOnError = Infallible; // TODO - - async fn is_on(&self) -> Result { - todo!() - } - - type IsOffError = Infallible; // TODO - - async fn is_off(&self) -> Result { - todo!() - } - - type TurnOnError = Infallible; // TODO - - async fn turn_on(&mut self) -> Result<(), Self::TurnOnError> { - todo!() - } - - type TurnOffError = Infallible; // TODO - - async fn turn_off(&mut self) -> Result<(), Self::TurnOffError> { - todo!() - } - - type ToggleError = Infallible; // TODO - - async fn toggle(&mut self) -> Result<(), Self::ToggleError> { - todo!() - } -} - -impl KelvinLight for LB130USHandle { - type TurnToKelvinError = Infallible; // TODO - - async fn turn_to_kelvin(&mut self, temperature: Kelvin) -> Result<(), Self::TurnToKelvinError> { - todo!() - } -} - -impl RgbLight for LB130USHandle { - type TurnToRgbError = Infallible; // TODO - - async fn turn_to_rgb(&mut self, color: Rgb) -> Result<(), Self::TurnToRgbError> { - todo!() + pub async fn set_light_state( + &self, + args: SetLightStateArgs, + ) -> Result { + let (sender, receiver) = oneshot::channel(); + self.sender + .send(LB130USMessage::SetLightState(args, sender)) + .await + .map_err(|_| HandleError::Dead)?; + receiver + .await + .map_err(|_| HandleError::Dead)? + .context(CommunicationSnafu) } } diff --git a/driver/kasa/src/impl_protocol.rs b/driver/kasa/src/impl_protocol.rs new file mode 100644 index 0000000..6bb9105 --- /dev/null +++ b/driver/kasa/src/impl_protocol.rs @@ -0,0 +1,97 @@ +use std::convert::Infallible; + +use palette::{encoding::Srgb, Hsv, IntoColor}; +use protocol::light::{GetState, Kelvin, SetState, TurnToColor, TurnToTemperature}; +use snafu::{ResultExt, Snafu}; + +use crate::{ + connection::{HandleError, LB130USHandle}, + messages::{ + Angle, Hsb, LightState, Off, On, Percentage, SetLightHsv, SetLightLastOn, SetLightOff, + SetLightStateArgs, SetLightTo, + }, +}; + +#[derive(Debug, Snafu)] +#[snafu(module)] +pub enum GetStateError { + HandleError { source: HandleError }, +} + +impl GetState for LB130USHandle { + type Error = GetStateError; + + async fn get_state(&self) -> Result { + let sys = self + .get_sysinfo() + .await + .context(get_state_error::HandleSnafu)?; + let light_state = sys.sys_info.light_state; + let state = match light_state { + LightState::On { .. } => protocol::light::State::On, + LightState::Off { .. } => protocol::light::State::Off, + }; + + Ok(state) + } +} + +#[derive(Debug, Snafu)] +#[snafu(module)] +pub enum SetStateError { + HandleError { source: HandleError }, +} + +impl SetState for LB130USHandle { + type Error = SetStateError; + + async fn set_state(&mut self, state: protocol::light::State) -> Result<(), Self::Error> { + let to = match state { + protocol::light::State::Off => SetLightTo::Off(SetLightOff { on_off: Off }), + protocol::light::State::On => SetLightTo::LastOn(SetLightLastOn { on_off: On }), + }; + + let args = SetLightStateArgs { + to, + transition: None, + }; + + self.set_light_state(args) + .await + .context(set_state_error::HandleSnafu)?; + + Ok(()) + } +} + +impl TurnToTemperature for LB130USHandle { + type Error = Infallible; // TODO + + async fn turn_to_temperature(&mut self, temperature: Kelvin) -> Result<(), Self::Error> { + todo!() + } +} + +#[derive(Debug, Snafu)] +#[snafu(module)] +pub enum TurnToColorError { + HandleError { source: HandleError }, +} + +impl TurnToColor for LB130USHandle { + type Error = TurnToColorError; + + async fn turn_to_color(&mut self, color: protocol::light::Oklch) -> Result<(), Self::Error> { + let hsv: Hsv = color.into_color(); + let hsb = hsv.into_color(); + + self.set_light_state(SetLightStateArgs { + to: SetLightTo::Hsv(SetLightHsv { on_off: On, hsb }), + transition: None, + }) + .await + .context(turn_to_color_error::HandleSnafu)?; + + Ok(()) + } +} diff --git a/driver/kasa/src/lib.rs b/driver/kasa/src/lib.rs index d323001..90ac624 100644 --- a/driver/kasa/src/lib.rs +++ b/driver/kasa/src/lib.rs @@ -1,2 +1,3 @@ pub mod connection; +mod impl_protocol; pub mod messages; diff --git a/driver/kasa/src/messages.rs b/driver/kasa/src/messages.rs index aeaaacd..0cb8f02 100644 --- a/driver/kasa/src/messages.rs +++ b/driver/kasa/src/messages.rs @@ -1,7 +1,8 @@ -use std::{collections::BTreeMap, fmt::Display, str::FromStr}; +use std::{collections::BTreeMap, fmt::Display, str::FromStr, time::Duration}; use deranged::{RangedU16, RangedU8}; use mac_address::{MacAddress, MacParseError}; +use palette::{FromColor, Hsv}; use serde::{ser::SerializeMap, Deserialize, Deserializer, Serialize}; use serde_repr::Deserialize_repr; use serde_with::{DeserializeFromStr, SerializeDisplay}; @@ -36,38 +37,38 @@ pub struct GetSysInfoResponseSystem { #[derive(Debug, Deserialize)] pub struct CommonSysInfo { - active_mode: ActiveMode, - alias: String, - ctrl_protocols: CtrlProtocols, - description: String, - dev_state: DevState, + pub active_mode: ActiveMode, + pub alias: String, + pub ctrl_protocols: CtrlProtocols, + pub description: String, + pub dev_state: DevState, #[serde(rename = "deviceId")] - device_id: DeviceId, - disco_ver: String, - err_code: i32, // No idea - heapsize: u64, // No idea + pub device_id: DeviceId, + pub disco_ver: String, + pub err_code: i32, // No idea + pub heapsize: u64, // No idea #[serde(rename = "hwId")] - hw_id: HardwareId, - hw_ver: String, - is_color: IsColor, - is_dimmable: IsDimmable, - is_factory: bool, - is_variable_color_temp: IsVariableColorTemp, - light_state: LightState, - mic_mac: MacAddressWithoutSeparators, - mic_type: MicType, + pub hw_id: HardwareId, + pub hw_ver: String, + pub is_color: IsColor, + pub is_dimmable: IsDimmable, + pub is_factory: bool, + pub is_variable_color_temp: IsVariableColorTemp, + pub light_state: LightState, + pub mic_mac: MacAddressWithoutSeparators, + pub mic_type: MicType, // model: Model, #[serde(rename = "oemId")] - oem_id: OemId, - preferred_state: Vec, - rssi: i32, - sw_ver: String, + pub oem_id: OemId, + pub preferred_state: Vec, + pub rssi: i32, + pub sw_ver: String, } #[derive(Debug, Deserialize)] pub struct LB130USSys { #[serde(flatten)] - sys_info: CommonSysInfo, + pub sys_info: CommonSysInfo, } #[derive(Debug, Deserialize)] @@ -78,9 +79,9 @@ pub enum SysInfo { } #[derive(Debug, Deserialize)] -struct PreferredStateChoice { +pub struct PreferredStateChoice { #[serde(flatten)] - color: Color, + pub color: Color, } #[derive(Debug, SerializeDisplay, DeserializeFromStr)] @@ -162,9 +163,9 @@ enum IsVariableColorTemp { VariableColorTemp = 1, } -type Percentage = RangedU8<0, 100>; -type Angle = RangedU16<0, 360>; -type Kelvin = RangedU16<2500, 9000>; +pub type Percentage = RangedU8<0, 100>; +pub type Angle = RangedU16<0, 360>; +pub type Kelvin = RangedU16<2500, 9000>; #[derive(Debug, Clone)] struct MaybeKelvin(Option); @@ -198,13 +199,34 @@ struct RawColor { saturation: Percentage, } -#[derive(Debug, Clone)] -struct Hsb { +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Hsb { hue: Angle, saturation: Percentage, brightness: Percentage, } +impl FromColor> for Hsb { + fn from_color(hsv: Hsv) -> Self { + let (hue, saturation, value) = hsv.into_components(); + + let hue = hue.into_positive_degrees(); + let hue = Angle::new_saturating(hue as u16); + + let saturation = saturation * (Percentage::MAX.get() as f64); + let saturation = Percentage::new_saturating(saturation as u8); + + let brightness = value * (Percentage::MAX.get() as f64); + let brightness = Percentage::new_saturating(brightness as u8); + + Hsb { + hue, + saturation, + brightness, + } + } +} + #[derive(Debug, Clone)] struct KelvinWithBrightness { kelvin: Kelvin, @@ -245,12 +267,86 @@ impl<'de> Deserialize<'de> for Color { } } +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord)] +pub struct Off; + +impl<'de> Deserialize<'de> for Off { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let value = u8::deserialize(deserializer)?; + + if value == 0 { + Ok(Off) + } else { + Err(serde::de::Error::invalid_value( + serde::de::Unexpected::Unsigned(value.into()), + &"0", + )) + } + } +} + +impl Serialize for Off { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + serializer.serialize_u8(0) + } +} + +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord)] +pub struct On; + +impl<'de> Deserialize<'de> for On { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let value = u8::deserialize(deserializer)?; + + if value == 1 { + Ok(On) + } else { + Err(serde::de::Error::invalid_value( + serde::de::Unexpected::Unsigned(value.into()), + &"1", + )) + } + } +} + +impl Serialize for On { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + serializer.serialize_u8(1) + } +} + #[derive(Debug, Clone, Deserialize)] -struct LightState { +#[serde(untagged)] +pub enum LightState { + On { + on_off: On, + #[serde(flatten)] + color: Color, + mode: LightStateMode, + }, + Off { + on_off: Off, + dft_on_state: DftOnState, + }, +} + +#[derive(Debug, Clone, Deserialize)] +struct DftOnState { #[serde(flatten)] color: Color, mode: LightStateMode, - on_off: OnOrOff, } #[derive(Debug, Clone, Deserialize)] @@ -259,14 +355,6 @@ enum LightStateMode { Normal, } -#[derive(Debug, Clone, Deserialize_repr)] -#[repr(u8)] -#[non_exhaustive] -enum OnOrOff { - Off = 0, - On = 1, -} - #[derive(Debug, Clone, Deserialize)] enum MicType { #[serde(rename = "IOT.SMARTBULB")] @@ -275,3 +363,59 @@ enum MicType { #[derive(Debug, Clone, Deserialize)] struct OemId(pub String); + +#[derive(Debug, Clone, Serialize)] +pub struct SetLightStateArgs { + #[serde(flatten)] + pub to: SetLightTo, + pub transition: Option, +} + +#[derive(Debug, Clone, Serialize)] +pub struct SetLightOff { + pub on_off: Off, +} + +#[derive(Debug, Clone, Serialize)] +pub struct SetLightLastOn { + pub on_off: On, +} + +#[derive(Debug, Clone, Serialize)] +pub struct SetLightHsv { + pub on_off: On, + #[serde(flatten)] + pub hsb: Hsb, +} + +#[derive(Debug, Clone, Serialize)] +#[serde(untagged)] +pub enum SetLightTo { + Off(SetLightOff), + LastOn(SetLightLastOn), + Hsv(SetLightHsv), + // TODO: kelvin +} + +#[derive(Debug, Clone, derive_more::From)] +pub struct SetLightState(pub SetLightStateArgs); + +impl Serialize for SetLightState { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let target = "smartlife.iot.smartbulb.lightingservice"; + let cmd = "transition_light_state"; + let arg = &self.0; + + let mut top_level_map = serializer.serialize_map(Some(1))?; + top_level_map.serialize_entry(target, &BTreeMap::from([(cmd, arg)]))?; + top_level_map.end() + } +} + +#[derive(Debug, Clone, Deserialize)] +pub struct SetLightStateResponse { + // TODO +}