diff --git a/src/command/mod.rs b/src/command/mod.rs index 4e5ef3d..142cb48 100644 --- a/src/command/mod.rs +++ b/src/command/mod.rs @@ -13,7 +13,7 @@ use twilight_model::{ }, }; -use crate::VCs; +use crate::{VCs, track_vcs::GuildVoiceChannelToTextChannel}; mod debug; mod join; @@ -31,6 +31,7 @@ pub struct State { pub discord_bot_owner_user_id: Id, pub discord_client: Arc, pub discord_user_id: Id, + pub discord_voice_channel_corresponding_text_channel: Arc, pub recording_data: Operator, pub songbird: Arc, pub user_data: Operator, diff --git a/src/lib.rs b/src/lib.rs index ff05ddd..9de2c20 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,7 +11,7 @@ pub use one_to_many::OneToManyUniqueBTreeMap; pub use one_to_many_with_data::OneToManyUniqueBTreeMapWithData; pub use one_to_one::OneToOneBTreeMap; pub use storage::Storage; -pub use track_vcs::{VCs, initialize_vcs, update_vcs}; +pub use track_vcs::{GuildVoiceChannelToTextChannel, VCs, initialize_vcs, update_vcs}; pub use vc_user::{UserInVCData, VoiceStatus}; capnp::generated_code!(pub mod user_capnp); diff --git a/src/main.rs b/src/main.rs index 440ac58..2ff8ae0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,13 +1,16 @@ use clap::Parser; -use fomo_reducer::{CommandRouter, State, Storage, all_commands, initialize_vcs, update_vcs}; +use fomo_reducer::{ + CommandRouter, GuildVoiceChannelToTextChannel, State, Storage, all_commands, initialize_vcs, + update_vcs, +}; use secrecy::{ExposeSecret, SecretString}; -use snafu::Snafu; +use snafu::{OptionExt, ResultExt, Snafu}; use songbird::{ Config, Songbird, driver::{Channels, DecodeConfig, SampleRate}, shards::TwilightMap, }; -use std::{fmt::Debug, sync::Arc}; +use std::{fmt::Debug, str::FromStr, sync::Arc}; use strum::EnumString; use tokio::{select, signal::ctrl_c, task::JoinSet}; use tokio_util::{sync::CancellationToken, time::FutureExt as _}; @@ -19,7 +22,10 @@ use twilight_model::{ payload::{incoming::InteractionCreate, outgoing::UpdatePresence}, presence::{ActivityType, MinimalActivity, Status}, }, - id::{Id, marker::UserMarker}, + id::{ + Id, + marker::{ChannelMarker, GuildMarker, UserMarker}, + }, }; #[derive(Clone, Copy, Debug, strum::Display, EnumString)] @@ -63,6 +69,41 @@ impl From for SampleRate { } } +#[derive(Debug, Snafu)] +enum ParseGuildVCToTextChannelError { + NoScope, + + NoRelation, + + ParseGuildError { + source: as FromStr>::Err, + }, + + ParseVoiceChannelError { + source: as FromStr>::Err, + }, + + ParseTextChannelError { + source: as FromStr>::Err, + }, +} + +fn parse_guild_vc_to_text_channel( + source: &str, +) -> Result<(Id, Id, Id), ParseGuildVCToTextChannelError> +{ + let (guild, voice_channel_and_text_channel) = source.split_once(':').context(NoScopeSnafu)?; + let (voice_channel, text_channel) = voice_channel_and_text_channel + .split_once("->") + .context(NoRelationSnafu)?; + + let guild = guild.parse().context(ParseGuildSnafu)?; + let voice_channel = voice_channel.parse().context(ParseVoiceChannelSnafu)?; + let text_channel = text_channel.parse().context(ParseTextChannelSnafu)?; + + Ok((guild, voice_channel, text_channel)) +} + #[derive(Debug, Parser)] struct AppArgs { #[arg(long, env)] @@ -77,6 +118,10 @@ struct AppArgs { #[arg(long, env)] discord_status: Option>, + #[arg(long, env, value_parser = parse_guild_vc_to_text_channel)] + discord_voice_channel_corresponding_text_channel: + Vec<(Id, Id, Id)>, + #[arg(long, env, default_value_t = AudioChannels::Mono)] audio_channels: AudioChannels, @@ -141,6 +186,7 @@ async fn main() -> Result<(), MainError> { discord_bot_owner_user_id, discord_nickname, discord_status, + discord_voice_channel_corresponding_text_channel, audio_channels, audio_sample_rate, bot_data, @@ -244,11 +290,11 @@ async fn main() -> Result<(), MainError> { .await .expect("failed to deserialize set commands"); // TODO + let vcs = initialize_vcs(&discord_client).await; + let command_router = CommandRouter::from_iter(commands); let command_router = Arc::new(command_router); - let vcs = initialize_vcs(&discord_client).await; - let discord_client = Arc::new(discord_client); let songbird = Arc::new(songbird); let vcs = Arc::new(vcs); @@ -257,6 +303,22 @@ async fn main() -> Result<(), MainError> { let recording_data = recording_data.into_inner(); let user_data = user_data.into_inner(); + let discord_voice_channel_corresponding_text_channel = { + let mut map = GuildVoiceChannelToTextChannel::default(); + + for (guild_id, voice_channel_id, text_channel_id) in + discord_voice_channel_corresponding_text_channel + { + map.entry(guild_id) + .or_default() + .insert(voice_channel_id, text_channel_id); + } + + map + }; + let discord_voice_channel_corresponding_text_channel = + Arc::new(discord_voice_channel_corresponding_text_channel); + let state = State { audio_channels, audio_sample_rate, @@ -266,6 +328,7 @@ async fn main() -> Result<(), MainError> { discord_bot_owner_user_id, discord_client, discord_user_id, + discord_voice_channel_corresponding_text_channel, recording_data, songbird, user_data, @@ -344,7 +407,12 @@ async fn handle_events(command_router: Arc, state: State, mut sha Ok(event) => { handle_event(command_router.clone(), state.clone(), event).await; } - Err(reconnect_error) if matches!(reconnect_error.kind(), &twilight_gateway::error::ReceiveMessageErrorType::Reconnect) => { + Err(reconnect_error) + if matches!( + reconnect_error.kind(), + &twilight_gateway::error::ReceiveMessageErrorType::Reconnect + ) => + { tracing::error!(?reconnect_error); return; } diff --git a/src/track_vcs.rs b/src/track_vcs.rs index 6cabb08..d78dbe3 100644 --- a/src/track_vcs.rs +++ b/src/track_vcs.rs @@ -1,3 +1,5 @@ +use std::collections::BTreeMap; + use dashmap::DashMap; use futures::{StreamExt, stream::FuturesUnordered}; use twilight_model::{ @@ -8,10 +10,11 @@ use twilight_model::{ }, }; -use crate::{OneToManyUniqueBTreeMapWithData, UserInVCData, VoiceStatus}; +use crate::{OneToManyUniqueBTreeMapWithData, OneToOneBTreeMap, UserInVCData, VoiceStatus}; + +pub type GuildVoiceChannelToTextChannel = BTreeMap, OneToOneBTreeMap, Id>>; type VCsInGuild = OneToManyUniqueBTreeMapWithData, Id, UserInVCData>; - pub type VCs = DashMap, VCsInGuild>; #[tracing::instrument(skip(discord_client), ret)]