diff --git a/Cargo.lock b/Cargo.lock index 79d4c6a..13407f3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1732,7 +1732,7 @@ dependencies = [ "futures-core", "futures-sink", "nanorand", - "spin", + "spin 0.9.8", ] [[package]] @@ -1777,6 +1777,7 @@ dependencies = [ "extension-traits", "futures", "hound", + "itertools", "moka", "opendal", "opus2", @@ -1801,7 +1802,6 @@ dependencies = [ "twilight-model", "twilight-util", "typed-builder 0.23.2", - "yoke", ] [[package]] @@ -2967,7 +2967,7 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" dependencies = [ - "spin", + "spin 0.9.8", ] [[package]] @@ -3611,6 +3611,15 @@ dependencies = [ "getrandom 0.2.17", ] +[[package]] +name = "no-std-compat" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b93853da6d84c2e3c7d730d6473e8817692dd89be387eb01b94d7f108ecb5b8c" +dependencies = [ + "spin 0.5.2", +] + [[package]] name = "no-std-net" version = "0.6.0" @@ -5453,6 +5462,7 @@ checksum = "1f9ef5dabe4c0b43d8f1187dc6beb67b53fe607fff7e30c5eb7f71b814b8c2c1" dependencies = [ "ahash", "bitflags 2.11.1", + "no-std-compat", "num-traits", "once_cell", "rhai_codegen", @@ -6266,6 +6276,12 @@ dependencies = [ "uuid", ] +[[package]] +name = "spin" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" + [[package]] name = "spin" version = "0.9.8" diff --git a/Cargo.toml b/Cargo.toml index 47dfcda..05418b8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,6 +14,7 @@ dashmap = "6.1.0" extension-traits = "2.0.2" futures = "0.3.32" hound = "3.5.1" +itertools = "0.14.0" moka = { version = "0.12.15", features = ["future"] } opendal = { git = "https://github.com/apache/opendal", rev = "ecf840b04afd2be109830b9978ba89759adfee79", features = [ "services-azfile", @@ -54,7 +55,7 @@ opendal = { git = "https://github.com/apache/opendal", rev = "ecf840b04afd2be109 ] } opus2 = "0.4.0" patricia_tree = "0.10.1" -rhai = "1.23.6" +rhai = { version = "1.23.6", features = ["sync"] } rustls = "0.23" secrecy = { version = "0.10.3", features = ["serde"] } shadow-rs = { version = "2.0.0", default-features = false } @@ -78,6 +79,7 @@ tokio-websockets-0-11 = { package = "tokio-websockets", version = "0.11", featur "rustls-webpki-roots", ] } tracing = "0.1.41" +tracing-appender = "0.2.5" tracing-subscriber = { version = "0.3.19", features = ["env-filter"] } twilight-gateway = { version = "0.17", default-features = false, features = [ "rustls-webpki-roots", @@ -91,8 +93,6 @@ twilight-http = { version = "0.17", default-features = false, features = [ twilight-model = "0.17" twilight-util = { version = "0.17", features = ["builder"] } typed-builder = "0.23.2" -yoke = "0.8.2" -tracing-appender = "0.2.5" [build-dependencies] capnpc = "0.25.3" diff --git a/src/call.rs b/src/call.rs index 228a602..35467aa 100644 --- a/src/call.rs +++ b/src/call.rs @@ -1,5 +1,5 @@ use crate::{ - OneToManyUniqueBTreeMap, UserDataManager, VCs, command::State, option_ext::OptionExt as _, + OneToManyUniqueBTreeMap, UserDataManager, option_ext::OptionExt as _, user_capnp::user::Consent, user_data::RECORD_IF_CONSENT_UNSPECIFIED, }; use async_trait::async_trait; diff --git a/src/command/join.rs b/src/command/join.rs index 86a2939..3d8a2e2 100644 --- a/src/command/join.rs +++ b/src/command/join.rs @@ -82,7 +82,7 @@ fn get_guild_and_vc_error_to_embed(error: GetGuildAndVoiceChannelIdError) -> Emb #[tracing::instrument(skip(state))] pub async fn handle(state: State, interaction: Interaction) { let guild_and_voice_channel_id_res = - { get_guild_and_voice_channel_id(&interaction, &state.vcs_watcher.borrow()) }; + { get_guild_and_voice_channel_id(&interaction, &state.vcs_sender.borrow()) }; let (guild_id, voice_channel_id) = match guild_and_voice_channel_id_res { Ok((guild_id, voice_channel_id)) => (guild_id, voice_channel_id), Err(error) => { diff --git a/src/command/leave.rs b/src/command/leave.rs index 8c2fe08..c6991b1 100644 --- a/src/command/leave.rs +++ b/src/command/leave.rs @@ -84,7 +84,7 @@ pub async fn handle(state: State, interaction: Interaction) { get_user_and_guild_and_voice_channel_id( state.discord_user_id, &interaction, - &state.vcs_watcher.borrow(), + &state.vcs_sender.borrow(), ) }; let (user_id, guild_id, voice_channel_id) = match user_and_guild_and_voice_channel_id_res { diff --git a/src/command/mod.rs b/src/command/mod.rs index 2408adf..590481e 100644 --- a/src/command/mod.rs +++ b/src/command/mod.rs @@ -16,7 +16,7 @@ use twilight_model::{ }, }; -use crate::{BotDataManager, GuildVoiceChannelToTextChannel, UserDataManager, VCsWatcher}; +use crate::{BotDataManager, GuildVoiceChannelToTextChannel, UserDataManager, VCsSender}; pub mod info; pub mod join; @@ -45,7 +45,7 @@ pub struct State { pub recording_data: Operator, pub songbird: Arc, pub user_data_manager: UserDataManager, - pub vcs_watcher: VCsWatcher, + pub vcs_sender: VCsSender, } type Return = (); diff --git a/src/heat_seek.rs b/src/heat_seek.rs index 88cb54c..5e95101 100644 --- a/src/heat_seek.rs +++ b/src/heat_seek.rs @@ -1,31 +1,237 @@ -use std::{collections::BTreeMap, num::NonZero}; +use std::{collections::BTreeMap, num::NonZero, str::Utf8Error, sync::Arc}; +use futures::{StreamExt as _, stream::FuturesUnordered}; +use itertools::Itertools as _; +use snafu::{ResultExt as _, Snafu}; use tokio::sync::watch; use twilight_model::id::{ Id, - marker::{ChannelMarker, GuildMarker}, + marker::{ChannelMarker, GuildMarker, UserMarker}, }; use twilight_util::builder::embed::EmbedBuilder; -use crate::{OneToManyUniqueBTreeMap, State, call::join_and_record}; +use crate::{ + BotDataManager, OneToManyUniqueBTreeMap, State, UserInVCData, bot_data, + call::join_and_record, + track_vcs::VCsInGuild, + vc_user::{Camera, Headphone, Microphone, Stream}, +}; type Heat = u64; type Hot = NonZero; -type HotOption = Option; type ChannelHeat = BTreeMap, Heat>; type HeatMap = OneToManyUniqueBTreeMap>; +#[tracing::instrument] +pub async fn heat_seek(state: State) { + let mut vcs_watcher = state.vcs_sender.subscribe(); + let mut vcs_in_guild_senders = BTreeMap::default(); + + loop { + for (&guild_id, vcs_in_guild) in &*vcs_watcher.borrow() { + let vcs_in_guild_sender = vcs_in_guild_senders.entry(guild_id).or_insert_with(|| { + let (vcs_in_guild_sender, vcs_in_guild_watcher) = + watch::channel(Default::default()); + let (channel_heat_sender, channel_heat_watcher) = + watch::channel(Default::default()); + let (heat_map_sender, heat_map_watcher) = watch::channel(Default::default()); + let (hottest_vc_sender, hottest_vc_watcher) = watch::channel(Default::default()); + + tokio::spawn( + evaluate_heat() + .bot_data_manager(state.bot_data_manager.clone()) + .bot_owner_user_id(state.discord_bot_owner_user_id) + .bot_user_id(state.discord_user_id) + .channel_heat_sender(channel_heat_sender) + .vcs_in_guild_watcher(vcs_in_guild_watcher) + .call(), + ); + tokio::spawn(map_heat(channel_heat_watcher, heat_map_sender)); + tokio::spawn(track_hottest_vc( + state.discord_bot_owner_user_id, + heat_map_watcher, + hottest_vc_sender, + )); + tokio::spawn(follow_hottest_vc( + state.clone(), + guild_id, + hottest_vc_watcher, + )); + + vcs_in_guild_sender + }); + vcs_in_guild_sender.send_replace(Arc::new(vcs_in_guild.clone())); + } + + if let Err(_closed) = vcs_watcher.changed().await { + break; + } + } +} + +#[derive(Debug, Snafu)] +enum GetHeatError { + /// couldn't retrieve bot data + WithBotDataError { source: bot_data::WithError }, + + /// couldn't get the heat script from the bot data + GetHeatScriptError { source: capnp::Error }, + + /// the heat script is not a valid UTF-8 string + HeatScriptInvalidUtf8 { source: Utf8Error }, + + /// the heat script is not valid Rhai code + HeatScriptInvalidRhai { source: rhai::ParseError }, + + /// failed while evaluating the heat script + HeatScriptEvaluationError { source: Box }, +} + +#[bon::builder] +#[tracing::instrument] +async fn get_heat( + users_in_vc: &BTreeMap, UserInVCData>, + bot_user_id: Id, + bot_owner_user_id: Id, + bot_data_manager: &BotDataManager, +) -> Result { + let heat_script = bot_data_manager + .with(|bot_data| { + bot_data.has_heat_script().then(|| { + bot_data + .get_heat_script() + .map(|heat_script| heat_script.to_string()) + }) + }) + .await + .context(WithBotDataSnafu)? + .transpose() + .context(GetHeatScriptSnafu)? + .transpose() + .context(HeatScriptInvalidUtf8Snafu)?; + + let engine = rhai::Engine::new(); + let heat_function = heat_script + .map(|heat_script| engine.compile(heat_script)) + .transpose() + .context(HeatScriptInvalidRhaiSnafu)?; + + let heat = heat_function + .map(|heat_function| { + let mut scope = Default::default(); + + let args = (); // TODO + + engine.call_fn(&mut scope, &heat_function, "heat", args) + }) + .transpose() + .context(HeatScriptEvaluationSnafu)?; + + let heat = heat.unwrap_or_else(|| { + tracing::warn!("using default heat scoring algorithm as no script was specified"); + + let mut users_in_vc = users_in_vc.clone(); + + let _bot = users_in_vc.remove(&bot_user_id); + let bot_owner = users_in_vc.remove(&bot_owner_user_id); + + let mut heat = 0; + + for (_user_id, user_in_vc_data) in users_in_vc { + if matches!(user_in_vc_data.microphone, Microphone::Unmuted) { + heat += 1000; + } + if matches!(user_in_vc_data.camera, Camera::Showing) { + heat += 100; + } + if matches!(user_in_vc_data.stream, Stream::Sharing) { + heat += 10; + } + if matches!(user_in_vc_data.headphone, Headphone::Undeafened) { + heat += 1; + } + } + + if bot_owner.is_some() { + heat = heat.min(999); + } + + heat + }); + + Ok(heat) +} + +#[bon::builder] +#[tracing::instrument] +async fn evaluate_heat( + bot_data_manager: BotDataManager, + bot_owner_user_id: Id, + bot_user_id: Id, + + mut vcs_in_guild_watcher: watch::Receiver>, + channel_heat_sender: watch::Sender, +) { + loop { + let vcs_in_guild = { vcs_in_guild_watcher.borrow().clone() }; + + let channel_heat_results: BTreeMap<_, _> = { + FuturesUnordered::from_iter((&*vcs_in_guild).into_iter().map( + |(&channel_id, users_in_vc)| { + let bot_data_manager = bot_data_manager.clone(); + async move { + ( + channel_id, + get_heat() + .bot_data_manager(&bot_data_manager) + .bot_owner_user_id(bot_owner_user_id) + .bot_user_id(bot_user_id) + .users_in_vc(users_in_vc) + .call() + .await, + ) + } + }, + )) + } + .collect() + .await; + + let (channel_heat, get_heat_errors): (ChannelHeat, Vec<_>) = channel_heat_results + .into_iter() + .map(|(channel_id, heat_result)| heat_result.map(|heat| (channel_id, heat))) + .partition_result(); + + channel_heat_sender.send_replace(channel_heat); + + for get_heat_error in get_heat_errors { + tracing::error!(?get_heat_error, "failed to evaluate heat of channel") + } + + if let Err(_closed) = vcs_in_guild_watcher.changed().await { + break; + } + } +} + #[tracing::instrument] async fn map_heat( mut channel_heat_watcher: watch::Receiver, heat_map_sender: watch::Sender, ) { loop { - heat_map_sender.send_modify(|heat_map| { + heat_map_sender.send_if_modified(|heat_map| { + let mut changed = false; for (&channel, &heat) in &*channel_heat_watcher.borrow() { - heat_map.insert(heat, channel); + let existing = heat_map.insert(heat, channel); + if existing.map_or(true, |(old_heat, old_channel)| { + old_heat != heat || channel != old_channel + }) { + changed = true; + } } + changed }); if let Err(_closed) = channel_heat_watcher.changed().await { @@ -36,6 +242,7 @@ async fn map_heat( #[tracing::instrument] async fn track_hottest_vc( + bot_owner_id: Id, mut heat_map_watcher: watch::Receiver, hottest_vc_sender: watch::Sender>>, ) { @@ -47,7 +254,6 @@ async fn track_hottest_vc( .and_then(|(&heat, hottest_vcs)| { let hot_option = Hot::new(heat); - // TODO: tiebreak by whichever one this bot is already in hot_option.map(|_| *hottest_vcs.first().unwrap()) }) }; @@ -65,7 +271,7 @@ async fn track_hottest_vc( } #[tracing::instrument] -async fn follow_heat( +async fn follow_hottest_vc( state: State, guild_id: Id, mut hottest_vc_watcher: watch::Receiver>>, @@ -146,8 +352,3 @@ async fn follow_heat( } } } - -#[tracing::instrument] -pub async fn heat_seek(state: State) { - todo!(); -} diff --git a/src/lib.rs b/src/lib.rs index ce8e1d3..5f37eff 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -17,11 +17,12 @@ shadow_rs::shadow!(build_info); pub use bot_data::BotDataManager; pub use command::{Router as CommandRouter, State, all as all_commands}; +pub use heat_seek::heat_seek; pub use one_to_many::OneToManyUniqueBTreeMap; pub use one_to_many_with_data::OneToManyUniqueBTreeMapWithData; pub use one_to_one::OneToOneBTreeMap; pub use operator_ext::OperatorExt; pub use storage::Storage; -pub use track_vcs::{GuildVoiceChannelToTextChannel, VCs, VCsWatcher, initialize_vcs, update_vcs}; +pub use track_vcs::{GuildVoiceChannelToTextChannel, VCs, VCsSender, initialize_vcs, update_vcs}; pub use user_data::UserDataManager; pub use vc_user::{UserInVCData, VoiceStatus}; diff --git a/src/main.rs b/src/main.rs index 9b29726..a35c109 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,7 +1,6 @@ use clap::Parser; use fomo_reducer::{ - BotDataManager, CommandRouter, GuildVoiceChannelToTextChannel, State, Storage, UserDataManager, - VCsWatcher, all_commands, command, initialize_vcs, update_vcs, + BotDataManager, CommandRouter, GuildVoiceChannelToTextChannel, State, Storage, UserDataManager, VCsSender, all_commands, command, heat_seek, initialize_vcs, update_vcs }; use secrecy::{ExposeSecret, SecretString}; use snafu::{OptionExt, ResultExt, Snafu}; @@ -336,7 +335,7 @@ async fn main() -> Result<(), MainError> { let discord_client = Arc::new(discord_client); let songbird = Arc::new(songbird); - let vcs_watcher = VCsWatcher::new(vcs); + let vcs_sender = VCsSender::new(vcs); let bot_data = bot_data.into_inner(); let recording_data = recording_data.into_inner(); @@ -380,9 +379,11 @@ async fn main() -> Result<(), MainError> { recording_data, songbird, user_data_manager, - vcs_watcher, + vcs_sender, }; + let heat_seeking = tokio::spawn(heat_seek(state.clone())); + if let Some(discord_status) = discord_status { shards.iter().for_each(|shard| { shard.command( @@ -409,7 +410,6 @@ async fn main() -> Result<(), MainError> { .map(|shard| handle_events(command_router.clone(), state.clone(), shard)); let run_shards = JoinSet::from_iter(run_shards); let run_shards = run_shards.join_all(); - tokio::pin!(run_shards); tokio::spawn({ let cancellation_token = cancellation_token.clone(); @@ -432,13 +432,19 @@ async fn main() -> Result<(), MainError> { } }); + let finished_naturally = async move { + heat_seeking.await.unwrap(); + run_shards.await; + }; + tokio::pin!(finished_naturally); + select! { - _ = &mut run_shards => { + _ = &mut finished_naturally => { Ok(()) } () = cancellation_token.cancelled() => { tracing::warn!("waiting for tasks to gracefully shut down"); - run_shards.await; + finished_naturally.await; Err(MainError::Cancelled) } @@ -496,7 +502,7 @@ async fn handle_event(command_router: Arc, state: State, event: E match event { Event::VoiceStateUpdate(voice_state_update) => { state - .vcs_watcher + .vcs_sender .send_modify(|vcs| update_vcs(&voice_state_update, vcs)); } Event::InteractionCreate(interaction_create) => { diff --git a/src/one_to_many_with_data.rs b/src/one_to_many_with_data.rs index 2b2b7a0..14346e4 100644 --- a/src/one_to_many_with_data.rs +++ b/src/one_to_many_with_data.rs @@ -94,6 +94,30 @@ where } } +impl IntoIterator + for OneToManyUniqueBTreeMapWithData +{ + type Item = (Left, BTreeMap); + + type IntoIter = > as IntoIterator>::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.left_to_rights.into_iter() + } +} + +impl<'a, Left, Right, RightData> IntoIterator + for &'a OneToManyUniqueBTreeMapWithData +{ + type Item = (&'a Left, &'a BTreeMap); + + type IntoIter = <&'a BTreeMap> as IntoIterator>::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.left_to_rights.iter() + } +} + impl FromIterator<(Left, Right, RightData)> for OneToManyUniqueBTreeMapWithData where diff --git a/src/track_vcs.rs b/src/track_vcs.rs index 8f142a3..6586a11 100644 --- a/src/track_vcs.rs +++ b/src/track_vcs.rs @@ -18,7 +18,7 @@ pub type GuildVoiceChannelToTextChannel = pub type VCsInGuild = OneToManyUniqueBTreeMapWithData, Id, UserInVCData>; pub type VCs = BTreeMap, VCsInGuild>; -pub type VCsWatcher = watch::Sender; +pub type VCsSender = watch::Sender; #[tracing::instrument(skip(discord_client), ret)] async fn initialize_user_in_vc( diff --git a/src/vc_user.rs b/src/vc_user.rs index 4dd8eac..ea09913 100644 --- a/src/vc_user.rs +++ b/src/vc_user.rs @@ -1,20 +1,20 @@ use typed_builder::TypedBuilder; -#[derive(Debug)] +#[derive(Debug, Clone, Copy)] pub enum Microphone { Unmuted, ServerMuted, Muted, } -#[derive(Debug)] +#[derive(Debug, Clone, Copy)] pub enum Headphone { Undeafened, ServerDeafened, Deafened, } -#[derive(Debug)] +#[derive(Debug, Clone, Copy)] pub enum Camera { Showing, Off, @@ -30,7 +30,7 @@ impl From for Camera { } } -#[derive(Debug)] +#[derive(Debug, Clone, Copy)] pub enum Stream { Sharing, None, @@ -46,7 +46,7 @@ impl From for Stream { } } -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct UserInVCData { pub microphone: Microphone, pub headphone: Headphone, @@ -54,7 +54,7 @@ pub struct UserInVCData { pub stream: Stream, } -#[derive(Debug, TypedBuilder)] +#[derive(Debug, Clone, TypedBuilder)] pub struct VoiceStatus { server_deafened: bool, self_deafened: bool,