diff --git a/src/command/join.rs b/src/command/join.rs index 8809ae0..e5b2c91 100644 --- a/src/command/join.rs +++ b/src/command/join.rs @@ -67,10 +67,10 @@ fn get_guild_and_voice_channel_id( .and_then(|member| member.user.as_ref().map(|user| user.id)) .context(NoUserSnafu)?; - let voice_channel_id = vcs - .with_guild(guild_id, |guild_vcs| { - guild_vcs.get_left_for(&user_id).copied() - }) + let &voice_channel_id = vcs + .get(&guild_id) + .context(UserNotInVCSnafu)? + .get_left_for(&user_id) .context(UserNotInVCSnafu)?; Ok((guild_id, voice_channel_id)) @@ -236,33 +236,33 @@ impl EventHandler for Handler { #[tracing::instrument(skip(state))] pub async fn handle(state: State, interaction: Interaction) { - let vcs = state.vcs; + let guild_and_voice_channel_id_res = get_guild_and_voice_channel_id(&interaction, &state.vcs_watcher.borrow()); + let (guild_id, voice_channel_id) = + match guild_and_voice_channel_id_res { + Ok((guild_id, voice_channel_id)) => (guild_id, voice_channel_id), + Err(error) => { + state + .discord_client + .interaction(state.discord_application_id) + .create_response( + interaction.id, + &interaction.token, + &InteractionResponse { + kind: InteractionResponseType::ChannelMessageWithSource, + data: Some( + InteractionResponseDataBuilder::new() + .embeds([get_guild_and_vc_error_to_embed(error)]) + .flags(MessageFlags::EPHEMERAL) + .build(), + ), + }, + ) + .await + .expect("TODO"); - let (guild_id, voice_channel_id) = match get_guild_and_voice_channel_id(&interaction, &vcs) { - Ok((guild_id, voice_channel_id)) => (guild_id, voice_channel_id), - Err(error) => { - state - .discord_client - .interaction(state.discord_application_id) - .create_response( - interaction.id, - &interaction.token, - &InteractionResponse { - kind: InteractionResponseType::ChannelMessageWithSource, - data: Some( - InteractionResponseDataBuilder::new() - .embeds([get_guild_and_vc_error_to_embed(error)]) - .flags(MessageFlags::EPHEMERAL) - .build(), - ), - }, - ) - .await - .expect("TODO"); - - return; - } - }; + return; + } + }; state .discord_client diff --git a/src/command/leave.rs b/src/command/leave.rs index 02b5d8d..7561df6 100644 --- a/src/command/leave.rs +++ b/src/command/leave.rs @@ -55,10 +55,10 @@ pub fn get_user_and_guild_and_voice_channel_id( let guild_id = interaction.guild_id.context(NotInGuildSnafu)?; - let voice_channel_id = vcs - .with_guild(guild_id, |guild_vcs| { - guild_vcs.get_left_for(&bot_user_id).copied() - }) + let &voice_channel_id = vcs + .get(&guild_id) + .context(BotNotInVCSnafu)? + .get_left_for(&bot_user_id) .context(BotNotInVCSnafu)?; Ok((user_id, guild_id, voice_channel_id)) @@ -80,11 +80,12 @@ fn get_guild_and_vc_error_to_embed(error: GetGuildAndVoiceChannelIdError) -> Emb #[tracing::instrument] pub async fn handle(state: State, interaction: Interaction) { - let (user_id, guild_id, voice_channel_id) = match get_user_and_guild_and_voice_channel_id( + let user_and_guild_and_voice_channel_id_res = get_user_and_guild_and_voice_channel_id( state.discord_user_id, &interaction, - &state.vcs, - ) { + &state.vcs_watcher.borrow(), + ); + let (user_id, guild_id, voice_channel_id) = match user_and_guild_and_voice_channel_id_res { Ok((user_id, guild_id, voice_channel_id)) => (user_id, guild_id, voice_channel_id), Err(error) => { state diff --git a/src/command/mod.rs b/src/command/mod.rs index af98444..4175443 100644 --- a/src/command/mod.rs +++ b/src/command/mod.rs @@ -16,7 +16,7 @@ use twilight_model::{ }, }; -use crate::{BotDataManager, GuildVoiceChannelToTextChannel, UserDataManager, VCs}; +use crate::{BotDataManager, GuildVoiceChannelToTextChannel, UserDataManager, VCs, track_vcs::VCsWatcher}; pub mod info; pub mod join; @@ -42,7 +42,7 @@ pub struct State { pub recording_data: Operator, pub songbird: Arc, pub user_data_manager: UserDataManager, - pub vcs: Arc, + pub vcs_watcher: VCsWatcher, } type Return = (); diff --git a/src/lib.rs b/src/lib.rs index 90168b9..ae754c5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -20,6 +20,6 @@ pub use one_to_many_with_data::OneToManyUniqueBTreeMapWithData; pub use one_to_one::OneToOneBTreeMap; pub use operator_ext::OperatorExt; pub use storage::Storage; -pub use track_vcs::{GuildVoiceChannelToTextChannel, VCs, initialize_vcs, update_vcs}; +pub use track_vcs::{GuildVoiceChannelToTextChannel, VCs, VCsWatcher, initialize_vcs, update_vcs}; pub use user_data::UserDataManager; pub use vc_user::{UserInVCData, VoiceStatus}; diff --git a/src/main.rs b/src/main.rs index 464f2d6..f9d1ce1 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,7 +1,6 @@ use clap::Parser; use fomo_reducer::{ - BotDataManager, CommandRouter, GuildVoiceChannelToTextChannel, State, Storage, UserDataManager, - all_commands, command, initialize_vcs, update_vcs, + BotDataManager, CommandRouter, GuildVoiceChannelToTextChannel, State, Storage, UserDataManager, VCsWatcher, all_commands, command, initialize_vcs, update_vcs }; use secrecy::{ExposeSecret, SecretString}; use snafu::{OptionExt, ResultExt, Snafu}; @@ -316,7 +315,7 @@ async fn main() -> Result<(), MainError> { let discord_client = Arc::new(discord_client); let songbird = Arc::new(songbird); - let vcs = Arc::new(vcs); + let vcs_watcher = VCsWatcher::new(vcs); let bot_data = bot_data.into_inner(); let recording_data = recording_data.into_inner(); @@ -358,7 +357,7 @@ async fn main() -> Result<(), MainError> { recording_data, songbird, user_data_manager, - vcs, + vcs_watcher, }; if let Some(discord_status) = discord_status { @@ -455,7 +454,7 @@ async fn handle_event(command_router: Arc, state: State, event: E match event { Event::VoiceStateUpdate(voice_state_update) => { - update_vcs(&voice_state_update, &state.vcs); + state.vcs_watcher.send_modify(|vcs| update_vcs(&voice_state_update, vcs)); } Event::InteractionCreate(interaction_create) => { let InteractionCreate(interaction) = *interaction_create; diff --git a/src/track_vcs.rs b/src/track_vcs.rs index db34e3d..2741db2 100644 --- a/src/track_vcs.rs +++ b/src/track_vcs.rs @@ -1,6 +1,5 @@ use std::collections::BTreeMap; -use dashmap::DashMap; use futures::{StreamExt, stream::FuturesUnordered}; use tokio::sync::watch; use twilight_model::{ @@ -16,38 +15,9 @@ use crate::{OneToManyUniqueBTreeMapWithData, OneToOneBTreeMap, UserInVCData, Voi pub type GuildVoiceChannelToTextChannel = BTreeMap, OneToOneBTreeMap, Id>>; -type VCsInGuild = OneToManyUniqueBTreeMapWithData, Id, UserInVCData>; - -#[derive(Debug, Default)] -pub struct VCs(DashMap, watch::Sender>); - -impl Extend<(Id, VCsInGuild)> for VCs { - fn extend, VCsInGuild)>>(&mut self, iter: T) { - for (id, guild_vcs) in iter { - self.0.insert(id, watch::Sender::new(guild_vcs)); - } - } -} - -impl VCs { - pub fn with_guild(&self, id: Id, f: impl FnOnce(&VCsInGuild) -> R) -> R { - f(&*self.0.entry(id).or_default().borrow()) - } - - pub fn update_guild(&self, id: Id, f: impl FnOnce(&mut VCsInGuild) -> R) -> R { - let mut ret_opt = None; - self.0.entry(id).or_default().send_modify(|guild_vcs| { - let ret = f(guild_vcs); - _ = ret_opt.insert(ret); - }); - let ret = ret_opt.unwrap(); - ret - } - - pub fn subscribe_to_guild(&self, id: Id) -> watch::Receiver { - self.0.entry(id).or_default().subscribe() - } -} +pub type VCsInGuild = OneToManyUniqueBTreeMapWithData, Id, UserInVCData>; +pub type VCs = BTreeMap, VCsInGuild>; +pub type VCsWatcher = watch::Sender; #[tracing::instrument(skip(discord_client), ret)] async fn initialize_user_in_vc( @@ -123,7 +93,7 @@ pub async fn initialize_vcs(discord_client: &twilight_http::Client) -> VCs { } #[tracing::instrument(skip(vcs))] -pub fn update_vcs(voice_state_update: &VoiceStateUpdate, vcs: &VCs) { +pub fn update_vcs(voice_state_update: &VoiceStateUpdate, vcs: &mut VCs) { let user_id = voice_state_update.user_id; match voice_state_update.guild_id { Some(guild_id) => match voice_state_update.channel_id { @@ -138,9 +108,7 @@ pub fn update_vcs(voice_state_update: &VoiceStateUpdate, vcs: &VCs) { .build(); let user_in_vc_data = voice_status.into(); - vcs.update_guild(guild_id, |guild_vcs| { - guild_vcs.insert(channel_id, user_id, user_in_vc_data) - }); + vcs.entry(guild_id).or_default().insert(channel_id, user_id, user_in_vc_data); tracing::info!( ?guild_id, @@ -151,7 +119,7 @@ pub fn update_vcs(voice_state_update: &VoiceStateUpdate, vcs: &VCs) { } None => { - vcs.update_guild(guild_id, |guild_vcs| guild_vcs.remove_right(&user_id)); + vcs.entry(guild_id).or_default().remove_right(&user_id); tracing::info!(?guild_id, ?user_id, "disconnected"); }