diff --git a/Cargo.lock b/Cargo.lock index 0ae9e2b..1417ad4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -154,9 +154,9 @@ checksum = "a23eb6b1614318a8071c9b2521f36b424b2c83db5eb3a0fead4a6c0809af6e61" [[package]] name = "arc-swap" -version = "1.8.2" +version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f9f3647c145568cec02c42054e07bdf9a5a698e15b466fb2341bfc393cd24aa5" +checksum = "a07d1f37ff60921c83bdfc7407723bdefe89b44b98a9b772f225c8f9d67141a6" dependencies = [ "rustversion", ] @@ -1455,6 +1455,7 @@ version = "0.1.0" dependencies = [ "blart", "clap", + "dashmap 6.1.0", "futures", "opendal", "rhai", diff --git a/Cargo.toml b/Cargo.toml index cf3a718..a486fd1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,6 +6,7 @@ edition = "2024" [dependencies] blart = "0.4.0" clap = { version = "4.5.40", features = ["derive", "env"] } +dashmap = "6.1.0" futures = "0.3.32" opendal = { git = "https://github.com/apache/opendal", features = [ "services-azfile", diff --git a/src/command/mod.rs b/src/command/mod.rs index 61de949..a918518 100644 --- a/src/command/mod.rs +++ b/src/command/mod.rs @@ -1,6 +1,7 @@ use std::{ ffi::{CStr, CString}, fmt::Debug, + sync::Arc, }; use blart::TreeMap; @@ -9,12 +10,16 @@ use twilight_model::application::{ command::Command, interaction::application_command::CommandData, }; +use crate::VCs; + mod join; mod leave; mod opt_out; #[derive(Debug, Clone)] -pub struct State {} +pub struct State { + pub vcs: Arc, +} type Return = (); type BoxedHandler = Box BoxFuture<'static, Return>>; @@ -53,7 +58,7 @@ impl Router { self.map.insert(name.parse().unwrap(), boxed_handler); } - pub async fn handle(&self, args: State, command_data: CommandData) -> Return { + pub async fn handle(&self, state: State, command_data: CommandData) -> Return { let name = &command_data.name; let key = CStr::from_bytes_with_nul(name.as_bytes()).unwrap(); @@ -62,19 +67,27 @@ impl Router { .get(key) .expect("asked to handle an inexistent command"); - handler(args, command_data).await + handler(state, command_data).await } } -impl<'a> FromIterator<(&'a CommandData, BoxedHandler)> for Router { - fn from_iter>(iter: T) -> Self { - let mut router = Router::default(); +impl<'a> FromIterator<(&'a Command, BoxedHandler)> for Router { + #[inline] + fn from_iter>(iter: T) -> Self { + let mut this = Self::default(); - for (command, handler) in iter { + this.extend(iter); + + this + } +} + +impl<'a> Extend<(&'a Command, BoxedHandler)> for Router { + #[inline] + fn extend>(&mut self, iter: T) { + for (command, boxed_handler) in iter { let name = &command.name; - router.add_route(name, handler); + self.add_route_already_boxed(name, boxed_handler); } - - router } } diff --git a/src/lib.rs b/src/lib.rs index dd01016..df4dcda 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -9,6 +9,6 @@ pub use one_to_many::OneToManyUniqueBTreeMap; pub use one_to_many_with_data::OneToManyUniqueBTreeMapWithData; pub use one_to_one::OneToOneBTreeMap; -pub use command::all as all_commands; +pub use command::{Router as CommandRouter, State, all as all_commands}; pub use track_vcs::{VCs, initialize_vcs, update_vcs}; pub use vc_user::{UserInVCData, VoiceStatus}; diff --git a/src/main.rs b/src/main.rs index ed80a80..ccb248d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,12 +1,16 @@ use clap::Parser; -use fomo_reducer::{VCs, all_commands, initialize_vcs, update_vcs}; +use fomo_reducer::{CommandRouter, State, VCs, all_commands, initialize_vcs, update_vcs}; use opendal::{IntoOperatorUri, Operator, OperatorUri}; use secrecy::{ExposeSecret, SecretString}; use snafu::Snafu; -use std::{fmt::Debug, str::FromStr}; +use std::{fmt::Debug, str::FromStr, sync::Arc}; use tracing_subscriber::{EnvFilter, fmt::format::FmtSpan}; use twilight_gateway::{Event, EventTypeFlags, Intents, Shard, ShardId, StreamExt}; -use twilight_model::id::{Id, marker::UserMarker}; +use twilight_model::{ + application::interaction::InteractionData, + gateway::payload::incoming::InteractionCreate, + id::{Id, marker::UserMarker}, +}; #[derive(Clone)] struct OpendalOperator { @@ -126,7 +130,7 @@ async fn main() -> Result<(), MainError> { let commands = all_commands(); - let returned_commands = interaction_client + let _returned_commands = interaction_client .set_global_commands( Vec::from_iter( commands @@ -141,13 +145,16 @@ async fn main() -> Result<(), MainError> { .await .expect("failed to deserialize set commands"); // TODO - let mut voice_status = initialize_vcs(&discord_client).await; + let command_router = CommandRouter::from_iter(commands); + + let vcs = initialize_vcs(&discord_client).await; + let vcs = Arc::new(vcs); while let Some(event_res) = next_event.await { match event_res { Ok(event) => { - tracing::debug!(?voice_status, "before handling"); - handle_event(event, &mut voice_status).await; - tracing::debug!(?voice_status, "after handling"); + tracing::debug!(?vcs, "before handling"); + handle_event(&command_router, vcs.clone(), event).await; + tracing::debug!(?vcs, "after handling"); } Err(error) => { tracing::error!(?error); @@ -160,14 +167,45 @@ async fn main() -> Result<(), MainError> { Ok(()) } -#[tracing::instrument(skip(vcs))] -async fn handle_event(event: Event, vcs: &mut VCs) { +#[tracing::instrument(skip(command_router, vcs))] +async fn handle_event(command_router: &CommandRouter, vcs: Arc, event: Event) { match event { Event::VoiceStateUpdate(voice_state_update) => { - update_vcs(&voice_state_update, vcs); + update_vcs(&voice_state_update, &vcs); } - other => { - tracing::warn!(?other, "wasn't expected"); + Event::InteractionCreate(interaction_create) => { + let InteractionCreate(interaction) = *interaction_create; + + match interaction.data { + None => { + tracing::warn!("missing expected interaction data"); + } + Some(InteractionData::ApplicationCommand(command_data)) => { + let state = State { vcs }; + command_router.handle(state, *command_data).await; + } + + Some(InteractionData::MessageComponent(component_data)) => { + tracing::warn!( + ?component_data, + "wasn't expected because this bot has no modal features" + ); + } + + Some(InteractionData::ModalSubmit(modal_data)) => { + tracing::warn!( + ?modal_data, + "wasn't expected because this bot has no modal features" + ); + } + + Some(other_interaction_data) => { + tracing::warn!(?other_interaction_data, "wasn't expected"); + } + } + } + other_event => { + tracing::warn!(?other_event, "wasn't expected"); } } } diff --git a/src/one_to_many_with_data.rs b/src/one_to_many_with_data.rs index a41b993..e539166 100644 --- a/src/one_to_many_with_data.rs +++ b/src/one_to_many_with_data.rs @@ -7,6 +7,7 @@ pub struct OneToManyUniqueBTreeMapWithData { } impl Default for OneToManyUniqueBTreeMapWithData { + #[inline] fn default() -> Self { Self { left_to_rights: Default::default(), @@ -69,3 +70,33 @@ where Some((left, right, right_data)) } } + +impl FromIterator<(Left, Right, RightData)> + for OneToManyUniqueBTreeMapWithData +where + Left: Ord + Clone, + Right: Ord + Clone, +{ + #[inline] + fn from_iter>(iter: T) -> Self { + let mut this = Self::default(); + + this.extend(iter); + + this + } +} + +impl Extend<(Left, Right, RightData)> + for OneToManyUniqueBTreeMapWithData +where + Left: Ord + Clone, + Right: Ord + Clone, +{ + #[inline] + fn extend>(&mut self, iter: T) { + for (left, right, right_data) in iter { + self.insert(left, right, right_data); + } + } +} diff --git a/src/track_vcs.rs b/src/track_vcs.rs index 3de8bbd..7aeddb1 100644 --- a/src/track_vcs.rs +++ b/src/track_vcs.rs @@ -1,9 +1,14 @@ -type VCsInServer = OneToManyUniqueBTreeMapWithData, Id, UserInVCData>; +type VCsInGuild = OneToManyUniqueBTreeMapWithData, Id, UserInVCData>; -pub type VCs = BTreeMap, VCsInServer>; +pub type VCs = DashMap, VCsInGuild>; use std::collections::BTreeMap; +use dashmap::DashMap; +use futures::{ + StreamExt, + stream::{self, FuturesUnordered}, +}; use twilight_model::{ gateway::payload::incoming::VoiceStateUpdate, id::{ @@ -15,51 +20,79 @@ use twilight_model::{ use crate::{OneToManyUniqueBTreeMapWithData, UserInVCData, VoiceStatus}; #[tracing::instrument(skip(discord_client), ret)] -pub async fn initialize_vcs(discord_client: &twilight_http::Client) -> VCs { - let mut vcs = VCs::default(); +async fn initialize_user_in_vc( + discord_client: &twilight_http::Client, + guild_id: Id, + user_id: Id, +) -> Option<(Id, UserInVCData)> { + if let Ok(voice_state_res) = discord_client.user_voice_state(guild_id, user_id).await + && let Ok(voice_state) = voice_state_res.model().await + { + tracing::info!(?user_id, ?voice_state); + let voice_status = VoiceStatus::builder() + .self_deafened(voice_state.self_deaf) + .self_muted(voice_state.self_mute) + .server_deafened(voice_state.deaf) + .server_muted(voice_state.mute) + .camming(voice_state.self_video) + .streaming(voice_state.self_stream) + .build(); + let user_in_vc_data = voice_status.into(); + + voice_state + .channel_id + .map(|channel_id| (channel_id, user_in_vc_data)) + } else { + None // TODO + } +} + +#[tracing::instrument(skip(discord_client), ret)] +async fn initialize_server_vcs( + discord_client: &twilight_http::Client, + id: Id, +) -> VCsInGuild { + if let Ok(guild_members_res) = discord_client.guild_members(id).limit(999).await + && let Ok(guild_members) = guild_members_res.model().await + { + FuturesUnordered::from_iter(guild_members.into_iter().map(|member| async move { + ( + member.user.id, + initialize_user_in_vc(discord_client, id, member.user.id).await, + ) + })) + .filter_map( + |(user_id, channel_id_and_user_in_vc_data_option)| async move { + channel_id_and_user_in_vc_data_option + .map(|(channel_id, user_in_vc_data)| (channel_id, user_id, user_in_vc_data)) + }, + ) + .collect() + .await + } else { + Default::default() + } +} + +#[tracing::instrument(skip(discord_client), ret)] +pub async fn initialize_vcs(discord_client: &twilight_http::Client) -> VCs { if let Ok(guilds_res) = discord_client.current_user_guilds().limit(200).await && let Ok(guilds) = guilds_res.model().await { - for guild in guilds { - if let Ok(guild_members_res) = discord_client.guild_members(guild.id).limit(999).await - && let Ok(guild_members) = guild_members_res.model().await - { - for member in guild_members { - if let Ok(voice_state_res) = discord_client - .user_voice_state(guild.id, member.user.id) - .await - && let Ok(voice_state) = voice_state_res.model().await - { - tracing::info!(?member.user.id, ?voice_state); + FuturesUnordered::from_iter(guilds.into_iter().map(|guild| async move { + let guild_vcs = initialize_server_vcs(discord_client, guild.id).await; - let voice_status = VoiceStatus::builder() - .self_deafened(voice_state.self_deaf) - .self_muted(voice_state.self_mute) - .server_deafened(voice_state.deaf) - .server_muted(voice_state.mute) - .camming(voice_state.self_video) - .streaming(voice_state.self_stream) - .build(); - let user_in_vc_data = voice_status.into(); - - if let Some(channel_id) = voice_state.channel_id { - vcs.entry(guild.id).or_default().insert( - channel_id, - member.user.id, - user_in_vc_data, - ); - } - } - } - } - } + (guild.id, guild_vcs) + })) + .collect() + .await + } else { + Default::default() } - - vcs } -pub fn update_vcs(voice_state_update: &VoiceStateUpdate, vcs: &mut VCs) { +pub fn update_vcs(voice_state_update: &VoiceStateUpdate, vcs: &VCs) { let user_id = voice_state_update.user_id; match voice_state_update.guild_id { Some(guild_id) => match voice_state_update.channel_id { @@ -87,7 +120,7 @@ pub fn update_vcs(voice_state_update: &VoiceStateUpdate, vcs: &mut VCs) { } None => { - if let Some(channel_vcers) = vcs.get_mut(&guild_id) { + if let Some(mut channel_vcers) = vcs.get_mut(&guild_id) { channel_vcers.remove_right(&user_id); }