use clap::Parser; use fomo_reducer::{ BotDataManager, CommandRouter, GuildVoiceChannelToTextChannel, State, Storage, UserDataManager, VCsWatcher, all_commands, command, initialize_vcs, update_vcs }; use secrecy::{ExposeSecret, SecretString}; use snafu::{OptionExt, ResultExt, Snafu}; use songbird::{ Config, Songbird, driver::{Channels, DecodeConfig, SampleRate}, shards::TwilightMap, }; use std::{collections::BTreeMap, fmt::Debug, str::FromStr, sync::Arc}; use strum::EnumString; use tokio::{select, signal::ctrl_c, task::JoinSet}; use tokio_util::{sync::CancellationToken, time::FutureExt as _}; use tracing_subscriber::{EnvFilter, fmt::format::FmtSpan}; use twilight_gateway::{Event, EventTypeFlags, Intents, Shard, StreamExt}; use twilight_model::{ application::interaction::InteractionData, gateway::{ payload::{incoming::InteractionCreate, outgoing::UpdatePresence}, presence::{ActivityType, MinimalActivity, Status}, }, id::{ Id, marker::{ChannelMarker, GuildMarker, UserMarker}, }, }; #[derive(Clone, Copy, Debug, strum::Display, EnumString)] enum AudioChannels { Mono, Stereo, } impl From for Channels { fn from(value: AudioChannels) -> Self { match value { AudioChannels::Mono => Channels::Mono, AudioChannels::Stereo => Channels::Stereo, } } } #[derive(Clone, Copy, Debug, strum::Display, EnumString)] enum AudioSampleRate { #[strum(serialize = "8000Hz")] Hz8000, #[strum(serialize = "12000Hz")] Hz12000, #[strum(serialize = "16000Hz")] Hz16000, #[strum(serialize = "24000Hz")] Hz24000, #[strum(serialize = "48000Hz")] Hz48000, } impl From for SampleRate { fn from(value: AudioSampleRate) -> Self { match value { AudioSampleRate::Hz8000 => SampleRate::Hz8000, AudioSampleRate::Hz12000 => SampleRate::Hz12000, AudioSampleRate::Hz16000 => SampleRate::Hz16000, AudioSampleRate::Hz24000 => SampleRate::Hz24000, AudioSampleRate::Hz48000 => SampleRate::Hz48000, } } } #[derive(Debug, Snafu)] enum ParseGuildVCToTextChannelError { NoScope, NoRelation, ParseGuildError { source: as FromStr>::Err, }, ParseVoiceChannelError { source: as FromStr>::Err, }, ParseTextChannelError { source: as FromStr>::Err, }, } fn parse_guild_vc_to_text_channel( source: &str, ) -> Result<(Id, Id, Id), ParseGuildVCToTextChannelError> { let (guild, voice_channel_and_text_channel) = source.split_once(':').context(NoScopeSnafu)?; let (voice_channel, text_channel) = voice_channel_and_text_channel .split_once("->") .context(NoRelationSnafu)?; let guild = guild.parse().context(ParseGuildSnafu)?; let voice_channel = voice_channel.parse().context(ParseVoiceChannelSnafu)?; let text_channel = text_channel.parse().context(ParseTextChannelSnafu)?; Ok((guild, voice_channel, text_channel)) } #[derive(Debug, Parser)] struct AppArgs { #[arg(long, env)] discord_token: SecretString, #[arg(long, env)] discord_bot_owner_user_id: Id, #[arg(long, env)] discord_nickname: Option>, #[arg(long, env)] discord_status: Option>, #[arg(long, env, value_parser = parse_guild_vc_to_text_channel)] discord_voice_channel_corresponding_text_channel: Vec<(Id, Id, Id)>, #[arg(long, env, default_value_t = AudioChannels::Mono)] audio_channels: AudioChannels, #[arg(long, env, default_value_t = AudioSampleRate::Hz12000)] audio_sample_rate: AudioSampleRate, #[arg(long, env)] bot_data: Storage, #[arg(long, env)] user_data: Storage, #[arg(long, env)] recording_data: Storage, } #[derive(Parser)] struct LoggingArgs { #[arg( long = "logging-directives", env = "RUST_LOG", default_value = "warn,fomo_reducer=debug" )] env_filter: EnvFilter, } #[derive(Parser)] struct Args { #[clap(flatten)] app_args: AppArgs, #[clap(flatten)] logging_args: LoggingArgs, } #[derive(Debug, Snafu)] enum MainError { /// the program was cancelled, perhaps by Ctrl-C / SIGINT Cancelled, } #[snafu::report] #[tokio::main] async fn main() -> Result<(), MainError> { let Args { logging_args, app_args, } = Parser::parse(); let LoggingArgs { env_filter } = logging_args; tracing_subscriber::fmt() .pretty() .with_env_filter(env_filter) .with_span_events(FmtSpan::NEW | FmtSpan::CLOSE) .init(); tracing::debug!(?app_args, "using"); let AppArgs { discord_token, discord_bot_owner_user_id, discord_nickname, discord_status, discord_voice_channel_corresponding_text_channel, audio_channels, audio_sample_rate, bot_data, user_data, recording_data, } = app_args; let cancellation_token = CancellationToken::new(); rustls::crypto::aws_lc_rs::default_provider() .install_default() .unwrap(); let discord_client = twilight_http::Client::new(discord_token.expose_secret().to_owned()); let guilds = discord_client .current_user_guilds() .limit(200) .await .expect("TODO") .model() .await .expect("TODO"); JoinSet::from_iter(guilds.into_iter().map(|guild| { discord_client .update_current_member(guild.id) .nick(discord_nickname.as_deref()) .into_future() })) .join_all() .await; let discord_user = discord_client .current_user() .await .expect("couldn't fetch current user") // TODO .model() .await .expect("couldn't deserialize current user"); // TODO let discord_user_id = discord_user.id; let current_application = discord_client .current_user_application() .await .expect("couldn't get current Discord application"); // TODO let current_application = current_application .model() .await .expect("couldn't get current Discord application"); // TODO let discord_application_id = current_application.id; let intents = Intents::GUILD_VOICE_STATES; let config = twilight_gateway::Config::new(discord_token.expose_secret().to_owned(), intents); let shards = twilight_gateway::create_recommended(&discord_client, config, |_id, builder| { builder.build() }) .await .expect("TODO"); let shards = Vec::from_iter(shards); let senders = TwilightMap::new( shards .iter() .map(|shard| (shard.id().number(), shard.sender())) .collect(), ); let audio_channels = audio_channels.into(); let audio_sample_rate = audio_sample_rate.into(); let senders = Arc::new(senders); let songbird = Songbird::twilight(senders, discord_user_id); songbird.set_config( Config::default().decode_mode(songbird::driver::DecodeMode::Decode(DecodeConfig::new( audio_channels, audio_sample_rate, ))), ); let interaction_client = discord_client.interaction(discord_application_id); let commands = all_commands(); let returned_commands = interaction_client .set_global_commands( Vec::from_iter( commands .iter() .map(|(command, _handler)| (*command).clone()), ) .as_slice(), ) .await .expect("failed to set interaction commands") // TODO .models() .await .expect("failed to deserialize set commands"); // TODO let mut discord_command_name_to_returned_command = BTreeMap::from_iter( returned_commands .into_iter() .map(|command| (command.name.clone(), command)), ); let discord_info_command = discord_command_name_to_returned_command .remove(&command::info::COMMAND.name) .expect("TODO"); let discord_opt_in_command = discord_command_name_to_returned_command .remove(&command::opt_in::COMMAND.name) .expect("TODO"); let discord_opt_out_command = discord_command_name_to_returned_command .remove(&command::opt_out::COMMAND.name) .expect("TODO"); let discord_info_command_id = discord_info_command.id.expect("TODO"); let discord_opt_in_command_id = discord_opt_in_command.id.expect("TODO"); let discord_opt_out_command_id = discord_opt_out_command.id.expect("TODO"); let discord_info_command_name = discord_info_command.name.into(); let discord_opt_in_command_name = discord_opt_in_command.name.into(); let discord_opt_out_command_name = discord_opt_out_command.name.into(); let vcs = initialize_vcs(&discord_client).await; let command_router = CommandRouter::from_iter(commands); let command_router = Arc::new(command_router); let discord_client = Arc::new(discord_client); let songbird = Arc::new(songbird); let vcs_watcher = VCsWatcher::new(vcs); let bot_data = bot_data.into_inner(); let recording_data = recording_data.into_inner(); let user_data = user_data.into_inner(); let bot_data_manager = BotDataManager::new(bot_data); let user_data_manager = UserDataManager::new(user_data); let discord_voice_channel_corresponding_text_channel = { let mut map = GuildVoiceChannelToTextChannel::default(); for (guild_id, voice_channel_id, text_channel_id) in discord_voice_channel_corresponding_text_channel { map.entry(guild_id) .or_default() .insert(voice_channel_id, text_channel_id); } map }; let discord_voice_channel_corresponding_text_channel = Arc::new(discord_voice_channel_corresponding_text_channel); let state = State { audio_channels, audio_sample_rate, bot_data_manager, cancellation_token: cancellation_token.clone(), discord_application_id, discord_bot_owner_user_id, discord_client, discord_info_command_id, discord_info_command_name, discord_opt_in_command_id, discord_opt_in_command_name, discord_opt_out_command_id, discord_opt_out_command_name, discord_user_id, discord_voice_channel_corresponding_text_channel, recording_data, songbird, user_data_manager, vcs_watcher, }; if let Some(discord_status) = discord_status { shards.iter().for_each(|shard| { shard.command( &UpdatePresence::new( vec![ MinimalActivity { kind: ActivityType::Listening, name: (*discord_status).to_owned(), url: None, } .into(), ], false, None, Status::Idle, ) .expect("TODO"), ) }); } let run_shards = JoinSet::from_iter( shards .into_iter() .map(|shard| handle_events(command_router.clone(), state.clone(), shard)), ); let run_shards = run_shards.join_all(); tokio::pin!(run_shards); tokio::spawn({ let cancellation_token = cancellation_token.clone(); async move { match ctrl_c().await { Ok(()) => cancellation_token.cancel(), Err(error) => tracing::error!(?error, "failed to listen for interrupt signal"), } } }); select! { _ = &mut run_shards => { Ok(()) } () = cancellation_token.cancelled() => { tracing::warn!("waiting for tasks to gracefully shut down"); run_shards.await; Err(MainError::Cancelled) } } } #[tracing::instrument(skip(command_router, shard, state))] async fn handle_events(command_router: Arc, state: State, mut shard: Shard) { let event_types = EventTypeFlags::GUILD_VOICE_STATES | EventTypeFlags::INTERACTION_CREATE | EventTypeFlags::VOICE_SERVER_UPDATE | EventTypeFlags::VOICE_STATE_UPDATE; while let Some(Some(event_res)) = shard .next_event(event_types) .with_cancellation_token(&state.cancellation_token) .await { match event_res { Ok(twilight_model::gateway::event::Event::GatewayClose(frame_option)) => { tracing::warn!(?frame_option); return; } Ok(event) => { handle_event(command_router.clone(), state.clone(), event).await; } Err(reconnect_error) if matches!( reconnect_error.kind(), &twilight_gateway::error::ReceiveMessageErrorType::Reconnect ) => { tracing::error!(?reconnect_error); return; } Err(error) => { tracing::error!(?error); } } } } #[tracing::instrument(skip(command_router, state))] async fn handle_event(command_router: Arc, state: State, event: Event) { state.songbird.process(&event).await; match event { Event::VoiceStateUpdate(voice_state_update) => { state.vcs_watcher.send_modify(|vcs| update_vcs(&voice_state_update, vcs)); } Event::InteractionCreate(interaction_create) => { let InteractionCreate(interaction) = *interaction_create; match &interaction.data { None => { tracing::warn!("missing expected interaction data"); } Some(InteractionData::ApplicationCommand(command_data)) => { let command_name = command_data.name.clone(); tokio::spawn(async move { command_router .handle(state, &command_name, interaction) .await; }); } Some(InteractionData::MessageComponent(component_data)) => { tracing::warn!( ?component_data, "wasn't expected because this bot has no modal features" ); } Some(InteractionData::ModalSubmit(modal_data)) => { tracing::warn!( ?modal_data, "wasn't expected because this bot has no modal features" ); } Some(other_interaction_data) => { tracing::warn!(?other_interaction_data, "wasn't expected"); } } } other_event => { tracing::warn!(?other_event, "wasn't expected"); } } }