feat: early stages of a TP-Link Kasa driver for our smart lights

This commit is contained in:
2025-04-21 16:42:14 -04:00
parent 38e89f31f4
commit f884bc7675
4 changed files with 576 additions and 0 deletions

17
driver/kasa/Cargo.toml Normal file
View File

@@ -0,0 +1,17 @@
[package]
name = "driver-kasa"
version = "0.1.0"
edition = "2021"
[dependencies]
backoff = { workspace = true, features = ["tokio"] }
deranged = { workspace = true }
mac_address = { version = "1.1.8", features = ["serde"] }
protocol = { path = "../../protocol" }
serde = { version = "1.0.219", features = ["derive"] }
serde_json = "1.0.140"
serde_repr = "0.1.20"
serde_with = "3.12.0"
snafu = { workspace = true }
tokio = { workspace = true, features = ["io-util", "net", "sync", "time"] }
tracing = { workspace = true }

View File

@@ -0,0 +1,280 @@
use std::{convert::Infallible, io, net::SocketAddr, num::NonZero, time::Duration};
use backoff::{ExponentialBackoff, ExponentialBackoffBuilder};
use protocol::light::{Kelvin, KelvinLight, Light, Rgb, RgbLight};
use snafu::{ResultExt, Snafu};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt, BufReader, BufWriter},
net::{TcpListener, TcpSocket, TcpStream},
sync::{mpsc, oneshot, OnceCell},
time::timeout,
};
use crate::messages::{GetSysInfo, GetSysInfoResponse, LB130USSys, SysInfo};
struct XorEncryption<const INITIAL_KEY: u8>;
impl<const INITIAL_KEY: u8> XorEncryption<INITIAL_KEY> {
fn encrypt_in_place(bytes: &mut [u8]) {
let mut key = INITIAL_KEY;
for unencrypted_byte in bytes {
let encrypted_byte = key ^ *unencrypted_byte;
key = encrypted_byte;
*unencrypted_byte = encrypted_byte;
}
}
fn decrypt_in_place(bytes: &mut [u8]) {
let mut key = INITIAL_KEY;
for encrypted_byte in bytes {
let unencrypted_byte = key ^ *encrypted_byte;
key = *encrypted_byte;
*encrypted_byte = unencrypted_byte;
}
}
}
fn into_encrypted(mut msg: Vec<u8>) -> Vec<u8> {
let length = msg.len() as u32;
let big_endian = length.to_be_bytes();
XorEncryption::<171>::encrypt_in_place(&mut msg);
let all_together = big_endian.into_iter().chain(msg);
all_together.collect()
}
#[derive(Debug, Snafu)]
pub enum CommunicationError {
SerializeError { source: serde_json::Error },
WriteError { source: std::io::Error },
ReadError { source: std::io::Error },
DeserializeError { source: serde_json::Error },
WrongDevice,
}
#[derive(Debug)]
enum LB130USMessage {
GetSysInfo(oneshot::Sender<Result<LB130USSys, CommunicationError>>),
}
async fn lb130us_actor(
addr: SocketAddr,
disconnect_after_idle: Duration,
mut messages: mpsc::Receiver<LB130USMessage>,
) {
let mut connection_cell = None;
loop {
let (connection, message) = match &mut connection_cell {
Some(connection) => match timeout(disconnect_after_idle, messages.recv()).await {
Ok(Some(message)) => (connection, message),
Ok(None) => return,
Err(timed_out) => {
tracing::warn!(
?addr,
?timed_out,
"disconnecting from the LB130(US) because the idle timeout has been reached",
);
connection_cell.take();
continue;
}
},
None => {
let Some(message) = messages.recv().await else {
return;
};
tracing::info!(
"connecting for a first time / reconnecting after having gone idle..."
);
match backoff::future::retry_notify(
ExponentialBackoff::default(),
|| async {
let stream = TcpStream::connect(addr).await?;
let (reader, writer) = stream.into_split();
let buf_reader = BufReader::new(reader);
let buf_writer = BufWriter::new(writer);
Ok((buf_reader, buf_writer))
},
|err, duration| {
tracing::error!(?err, ?duration);
},
)
.await
{
Ok(connection) => (connection_cell.insert(connection), message),
Err(err) => {
tracing::error!(?addr, ?err, "error connecting to an LB130(US)");
continue;
}
}
}
};
let (reader, writer) = connection;
tracing::info!("yay connected and got a message");
// TODO: do something
match message {
LB130USMessage::GetSysInfo(callback) => {
tracing::info!("going to try to get sys info for you...");
// TODO: extract to its own function
let outgoing = GetSysInfo;
let outgoing = match serde_json::to_vec(&outgoing) {
Ok(outgoing) => outgoing,
Err(err) => {
// TODO (continued) instead of doing stuff like this
let _ =
callback.send(Err(CommunicationError::SerializeError { source: err }));
continue;
}
};
tracing::info!(?outgoing);
let encrypted_outgoing = into_encrypted(outgoing);
tracing::info!(?encrypted_outgoing);
if let Err(err) = writer.write_all(&encrypted_outgoing).await {
connection_cell.take();
let _ = callback.send(Err(CommunicationError::WriteError { source: err }));
continue;
}
if let Err(err) = writer.flush().await {
connection_cell.take();
let _ = callback.send(Err(CommunicationError::WriteError { source: err }));
continue;
}
tracing::info!("sent it, now about to try to get a response");
let incoming_length = match reader.read_u32().await {
Ok(incoming_length) => incoming_length,
Err(err) => {
connection_cell.take();
let _ = callback.send(Err(CommunicationError::ReadError { source: err }));
continue;
}
};
tracing::info!(?incoming_length);
let mut incoming_message = Vec::new();
incoming_message.resize(incoming_length as usize, 0);
if let Err(err) = reader.read_exact(&mut incoming_message).await {
connection_cell.take();
let _ = callback.send(Err(CommunicationError::ReadError { source: err }));
continue;
}
XorEncryption::<171>::decrypt_in_place(&mut incoming_message);
tracing::info!(?incoming_message);
let response: GetSysInfoResponse = match serde_json::from_slice(&incoming_message) {
Ok(response) => response,
Err(err) => {
let _ = callback
.send(Err(CommunicationError::DeserializeError { source: err }));
continue;
}
};
tracing::info!(?response);
let SysInfo::LB130US(lb130us) = response.system.get_sysinfo else {
let _ = callback.send(Err(CommunicationError::WrongDevice));
continue;
};
tracing::info!(?lb130us);
let _ = callback.send(Ok(lb130us));
tracing::info!("cool, gave a response! onto the next message!");
}
}
}
}
#[derive(Debug, Clone)]
pub struct LB130USHandle {
sender: mpsc::Sender<LB130USMessage>,
}
#[derive(Debug, Snafu)]
pub enum HandleError {
CommunicationError { source: CommunicationError },
Dead,
}
impl LB130USHandle {
pub fn new(addr: SocketAddr, disconnect_after_idle: Duration, buffer: NonZero<usize>) -> Self {
let (sender, receiver) = mpsc::channel(buffer.get());
tokio::spawn(lb130us_actor(addr, disconnect_after_idle, receiver));
Self { sender }
}
pub async fn get_sysinfo(&self) -> Result<LB130USSys, HandleError> {
let (sender, receiver) = oneshot::channel();
self.sender
.send(LB130USMessage::GetSysInfo(sender))
.await
.map_err(|_| HandleError::Dead)?;
receiver
.await
.map_err(|_| HandleError::Dead)?
.context(CommunicationSnafu)
}
}
impl Light for LB130USHandle {
type IsOnError = Infallible; // TODO
async fn is_on(&self) -> Result<bool, Self::IsOnError> {
todo!()
}
type IsOffError = Infallible; // TODO
async fn is_off(&self) -> Result<bool, Self::IsOffError> {
todo!()
}
type TurnOnError = Infallible; // TODO
async fn turn_on(&mut self) -> Result<(), Self::TurnOnError> {
todo!()
}
type TurnOffError = Infallible; // TODO
async fn turn_off(&mut self) -> Result<(), Self::TurnOffError> {
todo!()
}
type ToggleError = Infallible; // TODO
async fn toggle(&mut self) -> Result<(), Self::ToggleError> {
todo!()
}
}
impl KelvinLight for LB130USHandle {
type TurnToKelvinError = Infallible; // TODO
async fn turn_to_kelvin(&mut self, temperature: Kelvin) -> Result<(), Self::TurnToKelvinError> {
todo!()
}
}
impl RgbLight for LB130USHandle {
type TurnToRgbError = Infallible; // TODO
async fn turn_to_rgb(&mut self, color: Rgb) -> Result<(), Self::TurnToRgbError> {
todo!()
}
}

2
driver/kasa/src/lib.rs Normal file
View File

@@ -0,0 +1,2 @@
pub mod connection;
pub mod messages;

277
driver/kasa/src/messages.rs Normal file
View File

@@ -0,0 +1,277 @@
use std::{collections::BTreeMap, fmt::Display, str::FromStr};
use deranged::{RangedU16, RangedU8};
use mac_address::{MacAddress, MacParseError};
use serde::{ser::SerializeMap, Deserialize, Deserializer, Serialize};
use serde_repr::Deserialize_repr;
use serde_with::{DeserializeFromStr, SerializeDisplay};
#[derive(Debug)]
pub struct GetSysInfo;
impl Serialize for GetSysInfo {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let target = "system";
let cmd = "get_sysinfo";
let arg: Option<()> = None;
let mut top_level_map = serializer.serialize_map(Some(1))?;
top_level_map.serialize_entry(target, &BTreeMap::from([(cmd, arg)]))?;
top_level_map.end()
}
}
#[derive(Debug, Deserialize)]
pub struct GetSysInfoResponse {
pub system: GetSysInfoResponseSystem,
}
#[derive(Debug, Deserialize)]
pub struct GetSysInfoResponseSystem {
pub get_sysinfo: SysInfo,
}
#[derive(Debug, Deserialize)]
pub struct CommonSysInfo {
active_mode: ActiveMode,
alias: String,
ctrl_protocols: CtrlProtocols,
description: String,
dev_state: DevState,
#[serde(rename = "deviceId")]
device_id: DeviceId,
disco_ver: String,
err_code: i32, // No idea
heapsize: u64, // No idea
#[serde(rename = "hwId")]
hw_id: HardwareId,
hw_ver: String,
is_color: IsColor,
is_dimmable: IsDimmable,
is_factory: bool,
is_variable_color_temp: IsVariableColorTemp,
light_state: LightState,
mic_mac: MacAddressWithoutSeparators,
mic_type: MicType,
// model: Model,
#[serde(rename = "oemId")]
oem_id: OemId,
preferred_state: Vec<PreferredStateChoice>,
rssi: i32,
sw_ver: String,
}
#[derive(Debug, Deserialize)]
pub struct LB130USSys {
#[serde(flatten)]
sys_info: CommonSysInfo,
}
#[derive(Debug, Deserialize)]
#[serde(tag = "model")]
pub enum SysInfo {
#[serde(rename = "LB130(US)")]
LB130US(LB130USSys),
}
#[derive(Debug, Deserialize)]
struct PreferredStateChoice {
#[serde(flatten)]
color: Color,
}
#[derive(Debug, SerializeDisplay, DeserializeFromStr)]
struct MacAddressWithoutSeparators(MacAddress);
impl FromStr for MacAddressWithoutSeparators {
type Err = MacParseError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let [a, b, c, d, e, f, g, h, i, j, k, l] = s
.as_bytes()
.try_into()
.map_err(|_| MacParseError::InvalidLength)?;
let bytes = [(a, b), (c, d), (e, f), (g, h), (i, j), (k, l)];
let mut digits = [0; 6];
for (i, (one, two)) in bytes.into_iter().enumerate() {
let slice = [one, two];
let as_string = std::str::from_utf8(&slice).map_err(|_| MacParseError::InvalidDigit)?;
let number =
u8::from_str_radix(as_string, 16).map_err(|_| MacParseError::InvalidDigit)?;
digits[i] = number;
}
Ok(MacAddressWithoutSeparators(MacAddress::new(digits)))
}
}
impl Display for MacAddressWithoutSeparators {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
Display::fmt(&self.0, f)
}
}
#[derive(Debug, Deserialize)]
enum ActiveMode {
#[serde(rename = "none")]
None,
}
#[derive(Debug, Deserialize)]
struct CtrlProtocols {
name: String,
version: String,
}
#[derive(Debug, Deserialize)]
struct DeviceId(pub String);
#[derive(Debug, Deserialize)]
enum DevState {
#[serde(rename = "normal")]
Normal,
}
#[derive(Debug, Deserialize)]
struct HardwareId(pub String);
#[derive(Debug, Deserialize_repr)]
#[repr(u8)]
enum IsColor {
NoColor = 0,
Color = 1,
}
#[derive(Debug, Deserialize_repr)]
#[repr(u8)]
enum IsDimmable {
NotDimmable = 0,
Dimmable = 1,
}
#[derive(Debug, Deserialize_repr)]
#[repr(u8)]
enum IsVariableColorTemp {
NoVariableColorTemp = 0,
VariableColorTemp = 1,
}
type Percentage = RangedU8<0, 100>;
type Angle = RangedU16<0, 360>;
type Kelvin = RangedU16<2500, 9000>;
#[derive(Debug, Clone)]
struct MaybeKelvin(Option<Kelvin>);
impl<'de> Deserialize<'de> for MaybeKelvin {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
match u16::deserialize(deserializer)? {
0 => Ok(MaybeKelvin(None)),
value => {
let kelvin = Kelvin::try_from(value).map_err(|e| {
serde::de::Error::custom(format!(
"{value} is not in the range {}..{}",
Kelvin::MIN,
Kelvin::MAX
))
})?;
Ok(MaybeKelvin(Some(kelvin)))
}
}
}
}
#[derive(Debug, Clone, Deserialize)]
struct RawColor {
brightness: Percentage,
color_temp: MaybeKelvin,
hue: Angle,
saturation: Percentage,
}
#[derive(Debug, Clone)]
struct Hsb {
hue: Angle,
saturation: Percentage,
brightness: Percentage,
}
#[derive(Debug, Clone)]
struct KelvinWithBrightness {
kelvin: Kelvin,
brightness: Percentage,
}
#[derive(Debug, Clone)]
enum Color {
HSB(Hsb),
KelvinWithBrightness(KelvinWithBrightness),
}
impl<'de> Deserialize<'de> for Color {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let raw_color = RawColor::deserialize(deserializer)?;
let RawColor {
brightness,
color_temp,
hue,
saturation,
} = raw_color;
match color_temp.0 {
Some(kelvin) => Ok(Color::KelvinWithBrightness(KelvinWithBrightness {
kelvin,
brightness,
})),
None => Ok(Color::HSB(Hsb {
hue,
saturation,
brightness,
})),
}
}
}
#[derive(Debug, Clone, Deserialize)]
struct LightState {
#[serde(flatten)]
color: Color,
mode: LightStateMode,
on_off: OnOrOff,
}
#[derive(Debug, Clone, Deserialize)]
enum LightStateMode {
#[serde(rename = "normal")]
Normal,
}
#[derive(Debug, Clone, Deserialize_repr)]
#[repr(u8)]
#[non_exhaustive]
enum OnOrOff {
Off = 0,
On = 1,
}
#[derive(Debug, Clone, Deserialize)]
enum MicType {
#[serde(rename = "IOT.SMARTBULB")]
IotSmartbulb,
}
#[derive(Debug, Clone, Deserialize)]
struct OemId(pub String);