diff --git a/driver/kasa/Cargo.toml b/driver/kasa/Cargo.toml new file mode 100644 index 0000000..9c6ad21 --- /dev/null +++ b/driver/kasa/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "driver-kasa" +version = "0.1.0" +edition = "2021" + +[dependencies] +backoff = { workspace = true, features = ["tokio"] } +deranged = { workspace = true } +mac_address = { version = "1.1.8", features = ["serde"] } +protocol = { path = "../../protocol" } +serde = { version = "1.0.219", features = ["derive"] } +serde_json = "1.0.140" +serde_repr = "0.1.20" +serde_with = "3.12.0" +snafu = { workspace = true } +tokio = { workspace = true, features = ["io-util", "net", "sync", "time"] } +tracing = { workspace = true } diff --git a/driver/kasa/src/connection.rs b/driver/kasa/src/connection.rs new file mode 100644 index 0000000..560b6f0 --- /dev/null +++ b/driver/kasa/src/connection.rs @@ -0,0 +1,280 @@ +use std::{convert::Infallible, io, net::SocketAddr, num::NonZero, time::Duration}; + +use backoff::{ExponentialBackoff, ExponentialBackoffBuilder}; +use protocol::light::{Kelvin, KelvinLight, Light, Rgb, RgbLight}; +use snafu::{ResultExt, Snafu}; +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt, BufReader, BufWriter}, + net::{TcpListener, TcpSocket, TcpStream}, + sync::{mpsc, oneshot, OnceCell}, + time::timeout, +}; + +use crate::messages::{GetSysInfo, GetSysInfoResponse, LB130USSys, SysInfo}; + +struct XorEncryption; + +impl XorEncryption { + fn encrypt_in_place(bytes: &mut [u8]) { + let mut key = INITIAL_KEY; + for unencrypted_byte in bytes { + let encrypted_byte = key ^ *unencrypted_byte; + key = encrypted_byte; + *unencrypted_byte = encrypted_byte; + } + } + + fn decrypt_in_place(bytes: &mut [u8]) { + let mut key = INITIAL_KEY; + for encrypted_byte in bytes { + let unencrypted_byte = key ^ *encrypted_byte; + key = *encrypted_byte; + *encrypted_byte = unencrypted_byte; + } + } +} + +fn into_encrypted(mut msg: Vec) -> Vec { + let length = msg.len() as u32; + let big_endian = length.to_be_bytes(); + XorEncryption::<171>::encrypt_in_place(&mut msg); + + let all_together = big_endian.into_iter().chain(msg); + + all_together.collect() +} + +#[derive(Debug, Snafu)] +pub enum CommunicationError { + SerializeError { source: serde_json::Error }, + WriteError { source: std::io::Error }, + ReadError { source: std::io::Error }, + DeserializeError { source: serde_json::Error }, + WrongDevice, +} + +#[derive(Debug)] +enum LB130USMessage { + GetSysInfo(oneshot::Sender>), +} + +async fn lb130us_actor( + addr: SocketAddr, + disconnect_after_idle: Duration, + mut messages: mpsc::Receiver, +) { + let mut connection_cell = None; + + loop { + let (connection, message) = match &mut connection_cell { + Some(connection) => match timeout(disconnect_after_idle, messages.recv()).await { + Ok(Some(message)) => (connection, message), + Ok(None) => return, + Err(timed_out) => { + tracing::warn!( + ?addr, + ?timed_out, + "disconnecting from the LB130(US) because the idle timeout has been reached", + ); + + connection_cell.take(); + continue; + } + }, + None => { + let Some(message) = messages.recv().await else { + return; + }; + + tracing::info!( + "connecting for a first time / reconnecting after having gone idle..." + ); + + match backoff::future::retry_notify( + ExponentialBackoff::default(), + || async { + let stream = TcpStream::connect(addr).await?; + let (reader, writer) = stream.into_split(); + + let buf_reader = BufReader::new(reader); + let buf_writer = BufWriter::new(writer); + + Ok((buf_reader, buf_writer)) + }, + |err, duration| { + tracing::error!(?err, ?duration); + }, + ) + .await + { + Ok(connection) => (connection_cell.insert(connection), message), + Err(err) => { + tracing::error!(?addr, ?err, "error connecting to an LB130(US)"); + continue; + } + } + } + }; + + let (reader, writer) = connection; + + tracing::info!("yay connected and got a message"); + + // TODO: do something + match message { + LB130USMessage::GetSysInfo(callback) => { + tracing::info!("going to try to get sys info for you..."); + + // TODO: extract to its own function + let outgoing = GetSysInfo; + let outgoing = match serde_json::to_vec(&outgoing) { + Ok(outgoing) => outgoing, + Err(err) => { + // TODO (continued) instead of doing stuff like this + let _ = + callback.send(Err(CommunicationError::SerializeError { source: err })); + continue; + } + }; + + tracing::info!(?outgoing); + + let encrypted_outgoing = into_encrypted(outgoing); + + tracing::info!(?encrypted_outgoing); + + if let Err(err) = writer.write_all(&encrypted_outgoing).await { + connection_cell.take(); + let _ = callback.send(Err(CommunicationError::WriteError { source: err })); + continue; + } + + if let Err(err) = writer.flush().await { + connection_cell.take(); + let _ = callback.send(Err(CommunicationError::WriteError { source: err })); + continue; + } + tracing::info!("sent it, now about to try to get a response"); + + let incoming_length = match reader.read_u32().await { + Ok(incoming_length) => incoming_length, + Err(err) => { + connection_cell.take(); + let _ = callback.send(Err(CommunicationError::ReadError { source: err })); + continue; + } + }; + tracing::info!(?incoming_length); + + let mut incoming_message = Vec::new(); + incoming_message.resize(incoming_length as usize, 0); + if let Err(err) = reader.read_exact(&mut incoming_message).await { + connection_cell.take(); + let _ = callback.send(Err(CommunicationError::ReadError { source: err })); + continue; + } + + XorEncryption::<171>::decrypt_in_place(&mut incoming_message); + tracing::info!(?incoming_message); + + let response: GetSysInfoResponse = match serde_json::from_slice(&incoming_message) { + Ok(response) => response, + Err(err) => { + let _ = callback + .send(Err(CommunicationError::DeserializeError { source: err })); + continue; + } + }; + tracing::info!(?response); + + let SysInfo::LB130US(lb130us) = response.system.get_sysinfo else { + let _ = callback.send(Err(CommunicationError::WrongDevice)); + continue; + }; + tracing::info!(?lb130us); + + let _ = callback.send(Ok(lb130us)); + tracing::info!("cool, gave a response! onto the next message!"); + } + } + } +} + +#[derive(Debug, Clone)] +pub struct LB130USHandle { + sender: mpsc::Sender, +} + +#[derive(Debug, Snafu)] +pub enum HandleError { + CommunicationError { source: CommunicationError }, + Dead, +} + +impl LB130USHandle { + pub fn new(addr: SocketAddr, disconnect_after_idle: Duration, buffer: NonZero) -> Self { + let (sender, receiver) = mpsc::channel(buffer.get()); + tokio::spawn(lb130us_actor(addr, disconnect_after_idle, receiver)); + Self { sender } + } + + pub async fn get_sysinfo(&self) -> Result { + let (sender, receiver) = oneshot::channel(); + self.sender + .send(LB130USMessage::GetSysInfo(sender)) + .await + .map_err(|_| HandleError::Dead)?; + receiver + .await + .map_err(|_| HandleError::Dead)? + .context(CommunicationSnafu) + } +} + +impl Light for LB130USHandle { + type IsOnError = Infallible; // TODO + + async fn is_on(&self) -> Result { + todo!() + } + + type IsOffError = Infallible; // TODO + + async fn is_off(&self) -> Result { + todo!() + } + + type TurnOnError = Infallible; // TODO + + async fn turn_on(&mut self) -> Result<(), Self::TurnOnError> { + todo!() + } + + type TurnOffError = Infallible; // TODO + + async fn turn_off(&mut self) -> Result<(), Self::TurnOffError> { + todo!() + } + + type ToggleError = Infallible; // TODO + + async fn toggle(&mut self) -> Result<(), Self::ToggleError> { + todo!() + } +} + +impl KelvinLight for LB130USHandle { + type TurnToKelvinError = Infallible; // TODO + + async fn turn_to_kelvin(&mut self, temperature: Kelvin) -> Result<(), Self::TurnToKelvinError> { + todo!() + } +} + +impl RgbLight for LB130USHandle { + type TurnToRgbError = Infallible; // TODO + + async fn turn_to_rgb(&mut self, color: Rgb) -> Result<(), Self::TurnToRgbError> { + todo!() + } +} diff --git a/driver/kasa/src/lib.rs b/driver/kasa/src/lib.rs new file mode 100644 index 0000000..d323001 --- /dev/null +++ b/driver/kasa/src/lib.rs @@ -0,0 +1,2 @@ +pub mod connection; +pub mod messages; diff --git a/driver/kasa/src/messages.rs b/driver/kasa/src/messages.rs new file mode 100644 index 0000000..aeaaacd --- /dev/null +++ b/driver/kasa/src/messages.rs @@ -0,0 +1,277 @@ +use std::{collections::BTreeMap, fmt::Display, str::FromStr}; + +use deranged::{RangedU16, RangedU8}; +use mac_address::{MacAddress, MacParseError}; +use serde::{ser::SerializeMap, Deserialize, Deserializer, Serialize}; +use serde_repr::Deserialize_repr; +use serde_with::{DeserializeFromStr, SerializeDisplay}; + +#[derive(Debug)] +pub struct GetSysInfo; + +impl Serialize for GetSysInfo { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let target = "system"; + let cmd = "get_sysinfo"; + let arg: Option<()> = None; + + let mut top_level_map = serializer.serialize_map(Some(1))?; + top_level_map.serialize_entry(target, &BTreeMap::from([(cmd, arg)]))?; + top_level_map.end() + } +} + +#[derive(Debug, Deserialize)] +pub struct GetSysInfoResponse { + pub system: GetSysInfoResponseSystem, +} + +#[derive(Debug, Deserialize)] +pub struct GetSysInfoResponseSystem { + pub get_sysinfo: SysInfo, +} + +#[derive(Debug, Deserialize)] +pub struct CommonSysInfo { + active_mode: ActiveMode, + alias: String, + ctrl_protocols: CtrlProtocols, + description: String, + dev_state: DevState, + #[serde(rename = "deviceId")] + device_id: DeviceId, + disco_ver: String, + err_code: i32, // No idea + heapsize: u64, // No idea + #[serde(rename = "hwId")] + hw_id: HardwareId, + hw_ver: String, + is_color: IsColor, + is_dimmable: IsDimmable, + is_factory: bool, + is_variable_color_temp: IsVariableColorTemp, + light_state: LightState, + mic_mac: MacAddressWithoutSeparators, + mic_type: MicType, + // model: Model, + #[serde(rename = "oemId")] + oem_id: OemId, + preferred_state: Vec, + rssi: i32, + sw_ver: String, +} + +#[derive(Debug, Deserialize)] +pub struct LB130USSys { + #[serde(flatten)] + sys_info: CommonSysInfo, +} + +#[derive(Debug, Deserialize)] +#[serde(tag = "model")] +pub enum SysInfo { + #[serde(rename = "LB130(US)")] + LB130US(LB130USSys), +} + +#[derive(Debug, Deserialize)] +struct PreferredStateChoice { + #[serde(flatten)] + color: Color, +} + +#[derive(Debug, SerializeDisplay, DeserializeFromStr)] +struct MacAddressWithoutSeparators(MacAddress); + +impl FromStr for MacAddressWithoutSeparators { + type Err = MacParseError; + + fn from_str(s: &str) -> Result { + let [a, b, c, d, e, f, g, h, i, j, k, l] = s + .as_bytes() + .try_into() + .map_err(|_| MacParseError::InvalidLength)?; + + let bytes = [(a, b), (c, d), (e, f), (g, h), (i, j), (k, l)]; + + let mut digits = [0; 6]; + + for (i, (one, two)) in bytes.into_iter().enumerate() { + let slice = [one, two]; + let as_string = std::str::from_utf8(&slice).map_err(|_| MacParseError::InvalidDigit)?; + let number = + u8::from_str_radix(as_string, 16).map_err(|_| MacParseError::InvalidDigit)?; + digits[i] = number; + } + + Ok(MacAddressWithoutSeparators(MacAddress::new(digits))) + } +} + +impl Display for MacAddressWithoutSeparators { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + Display::fmt(&self.0, f) + } +} + +#[derive(Debug, Deserialize)] +enum ActiveMode { + #[serde(rename = "none")] + None, +} + +#[derive(Debug, Deserialize)] +struct CtrlProtocols { + name: String, + version: String, +} + +#[derive(Debug, Deserialize)] +struct DeviceId(pub String); + +#[derive(Debug, Deserialize)] +enum DevState { + #[serde(rename = "normal")] + Normal, +} + +#[derive(Debug, Deserialize)] +struct HardwareId(pub String); + +#[derive(Debug, Deserialize_repr)] +#[repr(u8)] +enum IsColor { + NoColor = 0, + Color = 1, +} + +#[derive(Debug, Deserialize_repr)] +#[repr(u8)] +enum IsDimmable { + NotDimmable = 0, + Dimmable = 1, +} + +#[derive(Debug, Deserialize_repr)] +#[repr(u8)] +enum IsVariableColorTemp { + NoVariableColorTemp = 0, + VariableColorTemp = 1, +} + +type Percentage = RangedU8<0, 100>; +type Angle = RangedU16<0, 360>; +type Kelvin = RangedU16<2500, 9000>; + +#[derive(Debug, Clone)] +struct MaybeKelvin(Option); + +impl<'de> Deserialize<'de> for MaybeKelvin { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + match u16::deserialize(deserializer)? { + 0 => Ok(MaybeKelvin(None)), + value => { + let kelvin = Kelvin::try_from(value).map_err(|e| { + serde::de::Error::custom(format!( + "{value} is not in the range {}..{}", + Kelvin::MIN, + Kelvin::MAX + )) + })?; + Ok(MaybeKelvin(Some(kelvin))) + } + } + } +} + +#[derive(Debug, Clone, Deserialize)] +struct RawColor { + brightness: Percentage, + color_temp: MaybeKelvin, + hue: Angle, + saturation: Percentage, +} + +#[derive(Debug, Clone)] +struct Hsb { + hue: Angle, + saturation: Percentage, + brightness: Percentage, +} + +#[derive(Debug, Clone)] +struct KelvinWithBrightness { + kelvin: Kelvin, + brightness: Percentage, +} + +#[derive(Debug, Clone)] +enum Color { + HSB(Hsb), + KelvinWithBrightness(KelvinWithBrightness), +} + +impl<'de> Deserialize<'de> for Color { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let raw_color = RawColor::deserialize(deserializer)?; + + let RawColor { + brightness, + color_temp, + hue, + saturation, + } = raw_color; + + match color_temp.0 { + Some(kelvin) => Ok(Color::KelvinWithBrightness(KelvinWithBrightness { + kelvin, + brightness, + })), + None => Ok(Color::HSB(Hsb { + hue, + saturation, + brightness, + })), + } + } +} + +#[derive(Debug, Clone, Deserialize)] +struct LightState { + #[serde(flatten)] + color: Color, + mode: LightStateMode, + on_off: OnOrOff, +} + +#[derive(Debug, Clone, Deserialize)] +enum LightStateMode { + #[serde(rename = "normal")] + Normal, +} + +#[derive(Debug, Clone, Deserialize_repr)] +#[repr(u8)] +#[non_exhaustive] +enum OnOrOff { + Off = 0, + On = 1, +} + +#[derive(Debug, Clone, Deserialize)] +enum MicType { + #[serde(rename = "IOT.SMARTBULB")] + IotSmartbulb, +} + +#[derive(Debug, Clone, Deserialize)] +struct OemId(pub String);