Files
fomo-reducer/src/main.rs

506 lines
15 KiB
Rust

use clap::Parser;
use fomo_reducer::{
BotDataManager, CommandRouter, GuildVoiceChannelToTextChannel, State, Storage, UserDataManager, VCsWatcher, all_commands, command, initialize_vcs, update_vcs
};
use secrecy::{ExposeSecret, SecretString};
use snafu::{OptionExt, ResultExt, Snafu};
use songbird::{
Config, Songbird,
driver::{Channels, DecodeConfig, SampleRate},
shards::TwilightMap,
};
use std::{collections::BTreeMap, fmt::Debug, str::FromStr, sync::Arc};
use strum::EnumString;
use tokio::{select, signal::ctrl_c, task::JoinSet};
use tokio_util::{sync::CancellationToken, time::FutureExt as _};
use tracing_subscriber::{EnvFilter, fmt::format::FmtSpan};
use twilight_gateway::{Event, EventTypeFlags, Intents, Shard, StreamExt};
use twilight_model::{
application::interaction::InteractionData,
gateway::{
payload::{incoming::InteractionCreate, outgoing::UpdatePresence},
presence::{ActivityType, MinimalActivity, Status},
},
id::{
Id,
marker::{ChannelMarker, GuildMarker, UserMarker},
},
};
/// Audio channel layout accepted by the `audio_channels` application argument.
///
/// The `strum::Display` and `EnumString` derives supply the string conversions
/// (presumably what clap relies on to print the default and parse the value — confirm).
/// The value is converted into songbird's `Channels` before it reaches the decoder config.
#[derive(Clone, Copy, Debug, strum::Display, EnumString)]
enum AudioChannels {
Mono,
Stereo,
}
/// One-to-one conversion from the argument-level [`AudioChannels`] to songbird's [`Channels`].
impl From<AudioChannels> for Channels {
    fn from(value: AudioChannels) -> Self {
        // Variant names line up exactly, so this is a plain relabelling.
        match value {
            AudioChannels::Mono => Self::Mono,
            AudioChannels::Stereo => Self::Stereo,
        }
    }
}
/// Sample rate accepted by the `audio_sample_rate` application argument.
///
/// Every variant carries an explicit strum serialization of the form `<rate>Hz`
/// (for example `12000Hz`), which is the text used by the derived `Display` and
/// `EnumString` implementations. The value is converted into songbird's `SampleRate`
/// before it reaches the decoder config.
#[derive(Clone, Copy, Debug, strum::Display, EnumString)]
enum AudioSampleRate {
#[strum(serialize = "8000Hz")]
Hz8000,
#[strum(serialize = "12000Hz")]
Hz12000,
#[strum(serialize = "16000Hz")]
Hz16000,
#[strum(serialize = "24000Hz")]
Hz24000,
#[strum(serialize = "48000Hz")]
Hz48000,
}
/// One-to-one conversion from the argument-level [`AudioSampleRate`] to songbird's
/// [`SampleRate`].
impl From<AudioSampleRate> for SampleRate {
    fn from(value: AudioSampleRate) -> Self {
        // Both enums use the same variant names; map each onto its namesake.
        match value {
            AudioSampleRate::Hz8000 => Self::Hz8000,
            AudioSampleRate::Hz12000 => Self::Hz12000,
            AudioSampleRate::Hz16000 => Self::Hz16000,
            AudioSampleRate::Hz24000 => Self::Hz24000,
            AudioSampleRate::Hz48000 => Self::Hz48000,
        }
    }
}
// Failure modes of `parse_guild_vc_to_text_channel`, which reads `GUILD:VOICE->TEXT`.
//
// NOTE(review): plain `//` comments are used on purpose. snafu may turn `///` doc
// comments on variants into their Display messages, which would change the error
// text shown to the user — confirm before converting these to doc comments.
#[derive(Debug, Snafu)]
enum ParseGuildVCToTextChannelError {
// The input has no `:` separating the guild ID from the channel pair.
NoScope,
// The part after `:` has no `->` separating the voice channel from the text channel.
NoRelation,
// The text before `:` did not parse as a guild ID.
ParseGuildError {
source: <Id<GuildMarker> as FromStr>::Err,
},
// The text between `:` and `->` did not parse as a channel ID.
ParseVoiceChannelError {
source: <Id<ChannelMarker> as FromStr>::Err,
},
// The text after `->` did not parse as a channel ID.
ParseTextChannelError {
source: <Id<ChannelMarker> as FromStr>::Err,
},
}
/// Parses one `GUILD:VOICE->TEXT` mapping into its three IDs.
///
/// The input is split at the first `:` and the remainder at the first `->`; each of
/// the three pieces is then parsed as an ID, in guild / voice channel / text channel
/// order.
///
/// # Errors
///
/// * `NoScope` when there is no `:`.
/// * `NoRelation` when there is no `->` after the `:`.
/// * `ParseGuildError`, `ParseVoiceChannelError` or `ParseTextChannelError` when the
///   corresponding piece is not a valid ID.
fn parse_guild_vc_to_text_channel(
    source: &str,
) -> Result<(Id<GuildMarker>, Id<ChannelMarker>, Id<ChannelMarker>), ParseGuildVCToTextChannelError>
{
    // Structural checks first: both separators must be present before any ID is parsed.
    let (guild_part, channels_part) = source.split_once(':').context(NoScopeSnafu)?;
    let (voice_part, text_part) = channels_part
        .split_once("->")
        .context(NoRelationSnafu)?;

    let guild_id: Id<GuildMarker> = guild_part.parse().context(ParseGuildSnafu)?;
    let voice_channel_id: Id<ChannelMarker> =
        voice_part.parse().context(ParseVoiceChannelSnafu)?;
    let text_channel_id: Id<ChannelMarker> =
        text_part.parse().context(ParseTextChannelSnafu)?;

    Ok((guild_id, voice_channel_id, text_channel_id))
}
// Application-level settings. Every field is declared with `#[arg(long, env)]`, so
// each one can be supplied as a long command-line flag or through the environment.
//
// NOTE(review): plain `//` comments are used on purpose. clap's derive may turn
// `///` doc comments into help/about text, which would change `--help` output —
// confirm before converting these to doc comments.
#[derive(Debug, Parser)]
struct AppArgs {
// Discord bot token. Kept in a `SecretString`; `main` exposes it only to build the
// HTTP client and the gateway config.
#[arg(long, env)]
discord_token: SecretString,
// User ID of the bot owner; passed through unchanged into `State`.
#[arg(long, env)]
discord_bot_owner_user_id: Id<UserMarker>,
// Nickname `main` applies to the bot's own member in every guild it lists at startup.
// NOTE(review): when unset, `main` still sends the update with `None` — presumably
// that clears any existing nickname; confirm this is intended.
#[arg(long, env)]
discord_nickname: Option<Arc<str>>,
// When set, `main` sends it to every shard as the name of a `Listening` activity.
#[arg(long, env)]
discord_status: Option<Arc<str>>,
// Voice-channel-to-text-channel mappings, each parsed by
// `parse_guild_vc_to_text_channel` from the form `GUILD:VOICE->TEXT`.
#[arg(long, env, value_parser = parse_guild_vc_to_text_channel)]
discord_voice_channel_corresponding_text_channel:
Vec<(Id<GuildMarker>, Id<ChannelMarker>, Id<ChannelMarker>)>,
// Channel layout handed to songbird's decoder config; defaults to `Mono`.
#[arg(long, env, default_value_t = AudioChannels::Mono)]
audio_channels: AudioChannels,
// Sample rate handed to songbird's decoder config; defaults to `Hz12000`.
#[arg(long, env, default_value_t = AudioSampleRate::Hz12000)]
audio_sample_rate: AudioSampleRate,
// Storage backing `BotDataManager` (unwrapped with `into_inner` in `main`).
#[arg(long, env)]
bot_data: Storage,
// Storage backing `UserDataManager` (unwrapped with `into_inner` in `main`).
#[arg(long, env)]
user_data: Storage,
// Storage placed directly into `State::recording_data` (unwrapped with `into_inner`).
#[arg(long, env)]
recording_data: Storage,
}
// Logging settings, kept separate from `AppArgs` so `main` can initialise tracing
// before it logs the application arguments.
//
// NOTE(review): plain `//` comments are used on purpose; clap's derive may turn
// `///` doc comments into help/about text — confirm before converting.
#[derive(Parser)]
struct LoggingArgs {
// Tracing filter directives. Exposed as `--logging-directives` or the `RUST_LOG`
// environment variable; falls back to `warn,fomo_reducer=debug` when neither is given.
#[arg(
long = "logging-directives",
env = "RUST_LOG",
default_value = "warn,fomo_reducer=debug"
)]
env_filter: EnvFilter,
}
// Top-level argument set parsed by `main`: the application and logging groups
// flattened into a single command line.
//
// NOTE(review): plain `//` comments are used on purpose; clap's derive may turn
// `///` doc comments into help/about text — confirm before converting.
#[derive(Parser)]
struct Args {
// Bot configuration (token, channel mappings, audio format, storage locations).
#[clap(flatten)]
app_args: AppArgs,
// Tracing filter configuration.
#[clap(flatten)]
logging_args: LoggingArgs,
}
// Error returned from `main`. The only variant is produced when the cancellation
// token fires (the Ctrl-C handler spawned in `main` cancels it) and the shard tasks
// have finished shutting down.
//
// NOTE(review): with snafu's derive the variant's `///` doc comment presumably
// doubles as its Display message, so rewording it would change what the user sees —
// confirm before editing it.
#[derive(Debug, Snafu)]
enum MainError {
/// the program was cancelled, perhaps by Ctrl-C / SIGINT
Cancelled,
}
#[snafu::report]
#[tokio::main]
async fn main() -> Result<(), MainError> {
let Args {
logging_args,
app_args,
} = Parser::parse();
let LoggingArgs { env_filter } = logging_args;
tracing_subscriber::fmt()
.pretty()
.with_env_filter(env_filter)
.with_span_events(FmtSpan::NEW | FmtSpan::CLOSE)
.init();
tracing::debug!(?app_args, "using");
let AppArgs {
discord_token,
discord_bot_owner_user_id,
discord_nickname,
discord_status,
discord_voice_channel_corresponding_text_channel,
audio_channels,
audio_sample_rate,
bot_data,
user_data,
recording_data,
} = app_args;
let cancellation_token = CancellationToken::new();
rustls::crypto::aws_lc_rs::default_provider()
.install_default()
.unwrap();
let discord_client = twilight_http::Client::new(discord_token.expose_secret().to_owned());
let guilds = discord_client
.current_user_guilds()
.limit(200)
.await
.expect("TODO")
.model()
.await
.expect("TODO");
JoinSet::from_iter(guilds.into_iter().map(|guild| {
discord_client
.update_current_member(guild.id)
.nick(discord_nickname.as_deref())
.into_future()
}))
.join_all()
.await;
let discord_user = discord_client
.current_user()
.await
.expect("couldn't fetch current user") // TODO
.model()
.await
.expect("couldn't deserialize current user"); // TODO
let discord_user_id = discord_user.id;
let current_application = discord_client
.current_user_application()
.await
.expect("couldn't get current Discord application"); // TODO
let current_application = current_application
.model()
.await
.expect("couldn't get current Discord application"); // TODO
let discord_application_id = current_application.id;
let intents = Intents::GUILD_VOICE_STATES;
let config = twilight_gateway::Config::new(discord_token.expose_secret().to_owned(), intents);
let shards = twilight_gateway::create_recommended(&discord_client, config, |_id, builder| {
builder.build()
})
.await
.expect("TODO");
let shards = Vec::from_iter(shards);
let senders = TwilightMap::new(
shards
.iter()
.map(|shard| (shard.id().number(), shard.sender()))
.collect(),
);
let audio_channels = audio_channels.into();
let audio_sample_rate = audio_sample_rate.into();
let senders = Arc::new(senders);
let songbird = Songbird::twilight(senders, discord_user_id);
songbird.set_config(
Config::default().decode_mode(songbird::driver::DecodeMode::Decode(DecodeConfig::new(
audio_channels,
audio_sample_rate,
))),
);
let interaction_client = discord_client.interaction(discord_application_id);
let commands = all_commands();
let returned_commands = interaction_client
.set_global_commands(
Vec::from_iter(
commands
.iter()
.map(|(command, _handler)| (*command).clone()),
)
.as_slice(),
)
.await
.expect("failed to set interaction commands") // TODO
.models()
.await
.expect("failed to deserialize set commands"); // TODO
let mut discord_command_name_to_returned_command = BTreeMap::from_iter(
returned_commands
.into_iter()
.map(|command| (command.name.clone(), command)),
);
let discord_info_command = discord_command_name_to_returned_command
.remove(&command::info::COMMAND.name)
.expect("TODO");
let discord_opt_in_command = discord_command_name_to_returned_command
.remove(&command::opt_in::COMMAND.name)
.expect("TODO");
let discord_opt_out_command = discord_command_name_to_returned_command
.remove(&command::opt_out::COMMAND.name)
.expect("TODO");
let discord_info_command_id = discord_info_command.id.expect("TODO");
let discord_opt_in_command_id = discord_opt_in_command.id.expect("TODO");
let discord_opt_out_command_id = discord_opt_out_command.id.expect("TODO");
let discord_info_command_name = discord_info_command.name.into();
let discord_opt_in_command_name = discord_opt_in_command.name.into();
let discord_opt_out_command_name = discord_opt_out_command.name.into();
let vcs = initialize_vcs(&discord_client).await;
let command_router = CommandRouter::from_iter(commands);
let command_router = Arc::new(command_router);
let discord_client = Arc::new(discord_client);
let songbird = Arc::new(songbird);
let vcs_watcher = VCsWatcher::new(vcs);
let bot_data = bot_data.into_inner();
let recording_data = recording_data.into_inner();
let user_data = user_data.into_inner();
let bot_data_manager = BotDataManager::new(bot_data);
let user_data_manager = UserDataManager::new(user_data);
let discord_voice_channel_corresponding_text_channel = {
let mut map = GuildVoiceChannelToTextChannel::default();
for (guild_id, voice_channel_id, text_channel_id) in
discord_voice_channel_corresponding_text_channel
{
map.entry(guild_id)
.or_default()
.insert(voice_channel_id, text_channel_id);
}
map
};
let discord_voice_channel_corresponding_text_channel =
Arc::new(discord_voice_channel_corresponding_text_channel);
let state = State {
audio_channels,
audio_sample_rate,
bot_data_manager,
cancellation_token: cancellation_token.clone(),
discord_application_id,
discord_bot_owner_user_id,
discord_client,
discord_info_command_id,
discord_info_command_name,
discord_opt_in_command_id,
discord_opt_in_command_name,
discord_opt_out_command_id,
discord_opt_out_command_name,
discord_user_id,
discord_voice_channel_corresponding_text_channel,
recording_data,
songbird,
user_data_manager,
vcs_watcher,
};
if let Some(discord_status) = discord_status {
shards.iter().for_each(|shard| {
shard.command(
&UpdatePresence::new(
vec![
MinimalActivity {
kind: ActivityType::Listening,
name: (*discord_status).to_owned(),
url: None,
}
.into(),
],
false,
None,
Status::Idle,
)
.expect("TODO"),
)
});
}
let run_shards = JoinSet::from_iter(
shards
.into_iter()
.map(|shard| handle_events(command_router.clone(), state.clone(), shard)),
);
let run_shards = run_shards.join_all();
tokio::pin!(run_shards);
tokio::spawn({
let cancellation_token = cancellation_token.clone();
async move {
match ctrl_c().await {
Ok(()) => cancellation_token.cancel(),
Err(error) => tracing::error!(?error, "failed to listen for interrupt signal"),
}
}
});
select! {
_ = &mut run_shards => {
Ok(())
}
() = cancellation_token.cancelled() => {
tracing::warn!("waiting for tasks to gracefully shut down");
run_shards.await;
Err(MainError::Cancelled)
}
}
}
#[tracing::instrument(skip(command_router, shard, state))]
async fn handle_events(command_router: Arc<CommandRouter>, state: State, mut shard: Shard) {
let event_types = EventTypeFlags::GUILD_VOICE_STATES
| EventTypeFlags::INTERACTION_CREATE
| EventTypeFlags::VOICE_SERVER_UPDATE
| EventTypeFlags::VOICE_STATE_UPDATE;
while let Some(Some(event_res)) = shard
.next_event(event_types)
.with_cancellation_token(&state.cancellation_token)
.await
{
match event_res {
Ok(twilight_model::gateway::event::Event::GatewayClose(frame_option)) => {
tracing::warn!(?frame_option);
return;
}
Ok(event) => {
handle_event(command_router.clone(), state.clone(), event).await;
}
Err(reconnect_error)
if matches!(
reconnect_error.kind(),
&twilight_gateway::error::ReceiveMessageErrorType::Reconnect
) =>
{
tracing::error!(?reconnect_error);
return;
}
Err(error) => {
tracing::error!(?error);
}
}
}
}
/// Processes one gateway event.
///
/// Every event is first forwarded to songbird. After that:
/// * `VoiceStateUpdate` updates the watched voice-channel state through `update_vcs`;
/// * `VoiceServerUpdate` needs nothing further — it is requested by the shard loop
///   and has already been passed to songbird above;
/// * `InteractionCreate` carrying an application command is dispatched to the
///   `command_router` on a freshly spawned task, so a slow command does not hold up
///   the shard's event loop; other interaction kinds are logged as unexpected;
/// * any other event is logged as unexpected.
#[tracing::instrument(skip(command_router, state))]
async fn handle_event(command_router: Arc<CommandRouter>, state: State, event: Event) {
    state.songbird.process(&event).await;
    match event {
        Event::VoiceStateUpdate(voice_state_update) => {
            state
                .vcs_watcher
                .send_modify(|vcs| update_vcs(&voice_state_update, vcs));
        }
        // Explicitly subscribed to for songbird's benefit; without this arm every
        // voice connection would log a spurious "wasn't expected" warning.
        Event::VoiceServerUpdate(_) => {}
        Event::InteractionCreate(interaction_create) => {
            let InteractionCreate(interaction) = *interaction_create;
            match &interaction.data {
                None => {
                    tracing::warn!("missing expected interaction data");
                }
                Some(InteractionData::ApplicationCommand(command_data)) => {
                    // Clone the name so the borrow of `interaction` ends before it is
                    // moved into the spawned task.
                    let command_name = command_data.name.clone();
                    tokio::spawn(async move {
                        command_router
                            .handle(state, &command_name, interaction)
                            .await;
                    });
                }
                Some(InteractionData::MessageComponent(component_data)) => {
                    tracing::warn!(
                        ?component_data,
                        "wasn't expected because this bot has no message component features"
                    );
                }
                Some(InteractionData::ModalSubmit(modal_data)) => {
                    tracing::warn!(
                        ?modal_data,
                        "wasn't expected because this bot has no modal features"
                    );
                }
                Some(other_interaction_data) => {
                    tracing::warn!(?other_interaction_data, "wasn't expected");
                }
            }
        }
        other_event => {
            tracing::warn!(?other_event, "wasn't expected");
        }
    }
}