Files
fomo-reducer/src/main.rs

592 lines
19 KiB
Rust

use clap::Parser;
use fomo_reducer::{
AudioChannels, AudioSampleRate, BotManager, CommandRouter, GuildVoiceChannelToTextChannel,
RecordingManager, RenderManager, State, Storage, UserManager, VCsSender, all_commands, command,
heat_seek, initialize_vcs, update_vcs,
};
use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
use secrecy::{ExposeSecret, SecretString};
use snafu::{OptionExt, ResultExt, Snafu};
use songbird::{Config, Songbird, driver::DecodeConfig, shards::TwilightMap};
use std::{
collections::BTreeMap,
fmt::{Debug, Display},
num::NonZero,
str::FromStr,
sync::Arc,
time::Duration,
};
use tokio::{select, signal::ctrl_c, task::JoinSet};
use tokio_util::{sync::CancellationToken, time::FutureExt as _};
use tracing::Level;
use tracing_subscriber::{
EnvFilter,
fmt::{format::FmtSpan, writer::MakeWriterExt},
};
use twilight_gateway::{Event, EventTypeFlags, Intents, Shard, StreamExt};
use twilight_model::{
application::interaction::InteractionData,
gateway::{
payload::{incoming::InteractionCreate, outgoing::UpdatePresence},
presence::{ActivityType, MinimalActivity, Status},
},
id::{
Id,
marker::{ChannelMarker, GuildMarker, UserMarker},
},
};
// Errors produced by `parse_guild_vc_to_text_channel` when a
// `GUILD_ID:VOICE_CHANNEL_ID->TEXT_CHANNEL_ID` mapping cannot be parsed.
//
// NOTE: the `///` comments on the variants are used by the snafu derive as the
// Display messages, so they are what users see (e.g. in clap's error output).
// Edit them as user-facing text, not as internal documentation.
#[derive(Debug, Snafu)]
enum ParseGuildVCToTextChannelError {
    // No `:` separator was found.
    /// the guild ID needs to be included with : before the voice channel to text channel mapping
    NoScope,
    // No `->` separator was found after the `:`.
    /// a voice channel ID needs to be specified then -> to the corresponding text channel ID
    NoRelation,
    // The text before `:` is not a valid guild ID.
    /// could not parse the guild ID
    ParseGuildError {
        source: <Id<GuildMarker> as FromStr>::Err,
    },
    // The text between `:` and `->` is not a valid channel ID.
    /// could not parse the voice channel ID
    ParseVoiceChannelError {
        source: <Id<ChannelMarker> as FromStr>::Err,
    },
    // The text after `->` is not a valid channel ID.
    /// could not parse the text channel ID
    ParseTextChannelError {
        source: <Id<ChannelMarker> as FromStr>::Err,
    },
}
/// Parses one voice-channel-to-text-channel mapping of the form
/// `GUILD_ID:VOICE_CHANNEL_ID->TEXT_CHANNEL_ID`.
///
/// Whitespace around each ID is ignored, so `1: 2 -> 3` is accepted as well as
/// `1:2->3`. This also makes comma-separated lists such as `1:2->3, 4:5->6`
/// work once clap has split them.
///
/// # Errors
///
/// - [`ParseGuildVCToTextChannelError::NoScope`] if there is no `:`.
/// - [`ParseGuildVCToTextChannelError::NoRelation`] if there is no `->` after the first `:`.
/// - `ParseGuildError` / `ParseVoiceChannelError` / `ParseTextChannelError` if
///   the corresponding part is not a valid Discord ID.
fn parse_guild_vc_to_text_channel(
    source: &str,
) -> Result<(Id<GuildMarker>, Id<ChannelMarker>, Id<ChannelMarker>), ParseGuildVCToTextChannelError>
{
    let (guild, voice_channel_and_text_channel) = source.split_once(':').context(NoScopeSnafu)?;
    let (voice_channel, text_channel) = voice_channel_and_text_channel
        .split_once("->")
        .context(NoRelationSnafu)?;
    // Trim each part so surrounding whitespace doesn't turn a valid ID into a parse error.
    let guild = guild.trim().parse().context(ParseGuildSnafu)?;
    let voice_channel = voice_channel
        .trim()
        .parse()
        .context(ParseVoiceChannelSnafu)?;
    let text_channel = text_channel.trim().parse().context(ParseTextChannelSnafu)?;
    Ok((guild, voice_channel, text_channel))
}
/// A [`Duration`] that parses from and prints as human-readable text
/// (`5s`, `1m 30s`, …) via `humantime`, for use as a clap option value.
#[derive(Clone)]
struct HumanDuration(Duration);
impl FromStr for HumanDuration {
    type Err = humantime::DurationError;
    fn from_str(text: &str) -> Result<Self, Self::Err> {
        let parsed = humantime::parse_duration(text)?;
        Ok(HumanDuration(parsed))
    }
}
impl Debug for HumanDuration {
    /// Formats exactly like the wrapped [`Duration`]'s `Debug` output.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        Debug::fmt(&self.0, formatter)
    }
}
impl Display for HumanDuration {
    /// Formats in humantime notation, which is also what clap shows as the default value.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let pretty = humantime::format_duration(self.0);
        write!(formatter, "{pretty}")
    }
}
// Application settings. Every option can also be supplied through the
// environment variable clap derives from the field name (`env`).
// `Debug` is safe to log: the token is a `SecretString`.
#[derive(Debug, Parser)]
struct AppArgs {
    /// Discord bot token, used for both the HTTP API and the gateway.
    #[arg(long, env)]
    discord_token: SecretString,
    /// Discord user ID of the bot's owner.
    #[arg(long, env)]
    discord_bot_owner_user_id: Id<UserMarker>,
    /// Nickname to apply in each of the bot's guilds (up to the first 200 returned by Discord).
    #[arg(long, env)]
    discord_nickname: Option<Arc<str>>,
    /// Text of the "Listening" activity shown in the bot's presence.
    #[arg(long, env)]
    discord_status: Option<String>,
    /// Voice channel to text channel mapping, as `GUILD_ID:VOICE_CHANNEL_ID->TEXT_CHANNEL_ID`.
    /// Repeat the option, or separate entries with commas, to give several.
    // `value_delimiter` is what lets the environment variable carry more than
    // one mapping; IDs never contain commas, so splitting on ',' is unambiguous.
    #[arg(long, env, value_parser = parse_guild_vc_to_text_channel, value_delimiter = ',')]
    discord_voice_channel_corresponding_text_channel:
        Vec<(Id<GuildMarker>, Id<ChannelMarker>, Id<ChannelMarker>)>,
    /// Channel layout songbird decodes received voice audio into.
    #[arg(long, env, default_value_t = AudioChannels::Mono)]
    audio_channels: AudioChannels,
    /// Sample rate songbird decodes received voice audio at.
    #[arg(long, env, default_value_t = AudioSampleRate::Hz24000)]
    audio_sample_rate: AudioSampleRate,
    /// Storage for the bot manager's data.
    #[arg(long, env)]
    bot_data: Storage,
    /// Storage for the user manager's data.
    #[arg(long, env)]
    user_data: Storage,
    /// Storage for the recording manager's data.
    #[arg(long, env)]
    recording_data: Storage,
    /// Storage for the render manager's data.
    #[arg(long, env)]
    render_data: Storage,
    /// How often the watchdog thread checks that the async runtime is still responsive (e.g. `5s`).
    #[arg(long, env, default_value_t = HumanDuration(Duration::from_secs(5)))]
    watchdog_frequency: HumanDuration,
    /// Capacity of the watchdog heartbeat channel; once it is full, the runtime is
    /// considered deadlocked and the process exits.
    #[arg(long, env, default_value_t = 8.try_into().unwrap())]
    watchdog_channel_size: NonZero<usize>,
}
// Logging configuration, kept apart from `AppArgs` so it can be consumed first
// to set up tracing before anything else runs.
// `env_filter` holds tracing `EnvFilter` directives, taken from
// `--logging-directives` or the `RUST_LOG` environment variable. The default is
// `warn` everywhere plus `debug` for the `fomo_reducer` target.
// (`//` rather than `///` on purpose: doc comments would change clap's `--help`.)
#[derive(Parser)]
struct LoggingArgs {
    #[arg(
        long = "logging-directives",
        env = "RUST_LOG",
        default_value = "warn,fomo_reducer=debug"
    )]
    env_filter: EnvFilter,
}
#[derive(Parser)]
struct Args {
#[clap(flatten)]
app_args: AppArgs,
#[clap(flatten)]
logging_args: LoggingArgs,
}
// The error `main` returns; `#[snafu::report]` on `main` formats it for the user.
// The `///` variant comment doubles as snafu's Display message, so it is
// user-visible text.
#[derive(Debug, Snafu)]
enum MainError {
    // Returned once the shared cancellation token has fired.
    /// the program was cancelled, perhaps by Ctrl-C / SIGINT
    Cancelled,
}
#[snafu::report]
#[tokio::main]
async fn main() -> Result<(), MainError> {
let Args {
logging_args,
app_args,
} = Parser::parse();
let LoggingArgs { env_filter } = logging_args;
let (stdout, _stdout_guard) = tracing_appender::non_blocking(std::io::stdout());
let (stderr, _stderr_guard) = tracing_appender::non_blocking(std::io::stderr());
let writer = stderr.with_max_level(Level::WARN).or_else(stdout);
tracing_subscriber::fmt()
.pretty()
.with_env_filter(env_filter)
.with_span_events(FmtSpan::NEW | FmtSpan::CLOSE)
.with_writer(writer)
.init();
tracing::debug!(?app_args, "using");
let AppArgs {
discord_token,
discord_bot_owner_user_id,
discord_nickname,
discord_status,
discord_voice_channel_corresponding_text_channel,
audio_channels,
audio_sample_rate,
bot_data,
user_data,
recording_data,
render_data,
watchdog_frequency: HumanDuration(watchdog_frequency),
watchdog_channel_size,
} = app_args;
let cancellation_token = CancellationToken::new();
rustls::crypto::aws_lc_rs::default_provider()
.install_default()
.unwrap();
let discord_client = twilight_http::Client::new(discord_token.expose_secret().to_owned());
let guilds = discord_client
.current_user_guilds()
.limit(200)
.await
.expect("TODO")
.model()
.await
.expect("TODO");
JoinSet::from_iter(guilds.into_iter().map(|guild| {
discord_client
.update_current_member(guild.id)
.nick(discord_nickname.as_deref())
.into_future()
}))
.join_all()
.await;
let discord_user = discord_client
.current_user()
.await
.expect("couldn't fetch current user") // TODO
.model()
.await
.expect("couldn't deserialize current user"); // TODO
let discord_user_id = discord_user.id;
let current_application = discord_client
.current_user_application()
.await
.expect("couldn't get current Discord application"); // TODO
let current_application = current_application
.model()
.await
.expect("couldn't get current Discord application"); // TODO
let discord_application_id = current_application.id;
let interaction_client = discord_client.interaction(discord_application_id);
let commands = all_commands();
let returned_commands = interaction_client
.set_global_commands(
Vec::from_iter(
commands
.iter()
.map(|(command, _handler)| (*command).clone()),
)
.as_slice(),
)
.await
.expect("failed to set interaction commands") // TODO
.models()
.await
.expect("failed to deserialize set commands"); // TODO
let mut discord_command_name_to_returned_command = BTreeMap::from_iter(
returned_commands
.into_iter()
.map(|command| (command.name.clone(), command)),
);
let discord_info_command = discord_command_name_to_returned_command
.remove(&command::info::COMMAND.name)
.expect("TODO");
let discord_opt_in_command = discord_command_name_to_returned_command
.remove(&command::opt_in::COMMAND.name)
.expect("TODO");
let discord_opt_out_command = discord_command_name_to_returned_command
.remove(&command::opt_out::COMMAND.name)
.expect("TODO");
let discord_info_command_id = discord_info_command.id.expect("TODO");
let discord_opt_in_command_id = discord_opt_in_command.id.expect("TODO");
let discord_opt_out_command_id = discord_opt_out_command.id.expect("TODO");
let discord_info_command_name: Arc<str> = discord_info_command.name.into();
let discord_opt_in_command_name: Arc<str> = discord_opt_in_command.name.into();
let discord_opt_out_command_name: Arc<str> = discord_opt_out_command.name.into();
let command_router = CommandRouter::from_iter(commands);
let command_router = Arc::new(command_router);
let discord_client = Arc::new(discord_client);
let vcs_sender = VCsSender::new(Default::default());
let bot_data = bot_data.into_inner();
let recording_data = recording_data.into_inner();
let render_data = render_data.into_inner();
let user_data = user_data.into_inner();
let bot_manager = BotManager::new(bot_data);
let recording_manager = RecordingManager::new(recording_data);
let render_manager = RenderManager::new(render_data);
let user_manager = UserManager::new(user_data);
let discord_voice_channel_corresponding_text_channel = {
let mut map = GuildVoiceChannelToTextChannel::default();
for (guild_id, voice_channel_id, text_channel_id) in
discord_voice_channel_corresponding_text_channel
{
map.entry(guild_id)
.or_default()
.insert(voice_channel_id, text_channel_id);
}
map
};
let discord_voice_channel_corresponding_text_channel =
Arc::new(discord_voice_channel_corresponding_text_channel);
tokio::spawn({
let cancellation_token = cancellation_token.clone();
async move {
match ctrl_c().await {
Ok(()) => cancellation_token.cancel(),
Err(error) => tracing::error!(?error, "failed to listen for interrupt signal"),
}
}
});
let (mut watchdog_tx, mut watchdog_rx) =
futures::channel::mpsc::channel(watchdog_channel_size.get());
std::thread::spawn({
let discord_voice_channel_corresponding_text_channel =
discord_voice_channel_corresponding_text_channel.clone();
let discord_client = discord_client.clone();
let vcs_watcher = vcs_sender.subscribe();
move || {
loop {
if watchdog_tx.try_send(()).is_err() {
tracing::error!("tokio runtime deadlocked");
vcs_watcher.borrow().par_iter().for_each(|(&guild_id, vcs_in_guild)| {
if let Some(&voice_channel_id) = vcs_in_guild.get_left_for(&discord_user_id) {
let text_channel_id =
discord_voice_channel_corresponding_text_channel
.get(&guild_id)
.and_then(|guild_mappings| {
guild_mappings.get_right_for(&voice_channel_id).copied()
})
.unwrap_or(voice_channel_id);
let _ = futures::executor::block_on(discord_client.create_message(text_channel_id).content("so sorry I died, I'm in purgatory now, I don't like it here.\nbut I will be back in 5-20 minutes (even if it says I'm still there, I'm not currently recording and will be disconnected soon before later reconnecting and announcing recording again)").into_future());
}
});
std::process::exit(1);
}
std::thread::sleep(watchdog_frequency);
}
}
});
tokio::spawn(async move {
loop {
if watchdog_rx.recv().await.is_err() {
tracing::error!("watchdog died (this should be impossible)");
std::process::exit(1);
}
}
});
loop {
tokio::spawn({
let vcs_sender = vcs_sender.clone();
let discord_client = discord_client.clone();
async move { initialize_vcs(&vcs_sender, &discord_client).await }
});
let intents = Intents::GUILD_VOICE_STATES;
let config =
twilight_gateway::Config::new(discord_token.expose_secret().to_owned(), intents);
let shards =
twilight_gateway::create_recommended(&discord_client, config, |_id, builder| {
builder.build()
})
.await
.expect("TODO");
let shards = Vec::from_iter(shards);
let senders = TwilightMap::new(
shards
.iter()
.map(|shard| (shard.id().number(), shard.sender()))
.collect(),
);
let senders = Arc::new(senders);
let songbird = Songbird::twilight(senders, discord_user_id);
songbird.set_config(
Config::default().decode_mode(songbird::driver::DecodeMode::Decode(DecodeConfig::new(
audio_channels.into(),
audio_sample_rate.into(),
))),
);
if let Some(discord_status) = &discord_status {
shards.iter().for_each(|shard| {
shard.command(
&UpdatePresence::new(
vec![
MinimalActivity {
kind: ActivityType::Listening,
name: discord_status.clone(),
url: None,
}
.into(),
],
false,
None,
Status::Idle,
)
.expect("TODO"),
)
});
}
let songbird = Arc::new(songbird);
let state = State {
audio_channels,
audio_sample_rate,
bot_manager: bot_manager.clone(),
cancellation_token: cancellation_token.clone(),
discord_application_id,
discord_bot_owner_user_id,
discord_client: discord_client.clone(),
discord_info_command_id,
discord_info_command_name: discord_info_command_name.clone(),
discord_opt_in_command_id,
discord_opt_in_command_name: discord_opt_in_command_name.clone(),
discord_opt_out_command_id,
discord_opt_out_command_name: discord_opt_out_command_name.clone(),
discord_user_id,
discord_voice_channel_corresponding_text_channel:
discord_voice_channel_corresponding_text_channel.clone(),
recording_manager: recording_manager.clone(),
render_manager: render_manager.clone(),
songbird,
user_manager: user_manager.clone(),
vcs_sender: vcs_sender.clone(),
};
let mut heat_seeking = tokio::spawn(heat_seek(state.clone()));
let run_shards = shards
.into_iter()
.map(|shard| handle_events(command_router.clone(), state.clone(), shard));
let mut run_shards = JoinSet::from_iter(run_shards);
select! {
_heat_seeking_exited = &mut heat_seeking => {
// this shouldn't happen, but let's try again
continue;
}
_first_shard_exited = run_shards.join_next() => {
heat_seeking.abort();
continue;
}
() = cancellation_token.cancelled() => {
heat_seeking.await.unwrap();
run_shards.join_all().await;
return Err(MainError::Cancelled);
}
}
}
}
/// Pumps gateway events from one shard, then cancels the whole application
/// through `state.cancellation_token`.
///
/// Each event is passed to [`handle_event`] and awaited before the next one is
/// read. The loop stops when:
/// - the token is cancelled;
/// - the shard's event stream ends;
/// - a `GatewayClose` event arrives;
/// - receiving fails with a `Reconnect` error.
///
/// Any other receive error is logged and skipped.
#[tracing::instrument(skip(command_router, shard, state))]
async fn handle_events(command_router: Arc<CommandRouter>, state: State, mut shard: Shard) {
    let wanted_events = EventTypeFlags::GUILD_VOICE_STATES
        | EventTypeFlags::INTERACTION_CREATE
        | EventTypeFlags::VOICE_SERVER_UPDATE
        | EventTypeFlags::VOICE_STATE_UPDATE;
    loop {
        let next = shard
            .next_event(wanted_events)
            .with_cancellation_token(&state.cancellation_token)
            .await;
        // Outer `None`: the token was cancelled. Inner `None`: the stream ended.
        let Some(Some(received)) = next else {
            break;
        };
        let event = match received {
            Ok(event) => event,
            Err(reconnect_error)
                if matches!(
                    reconnect_error.kind(),
                    &twilight_gateway::error::ReceiveMessageErrorType::Reconnect
                ) =>
            {
                tracing::error!(?reconnect_error);
                break;
            }
            Err(error) => {
                tracing::error!(?error);
                continue;
            }
        };
        if let twilight_model::gateway::event::Event::GatewayClose(frame_option) = event {
            tracing::warn!(?frame_option);
            break;
        }
        handle_event(command_router.clone(), state.clone(), event).await;
    }
    state.cancellation_token.cancel();
}
/// Handles one gateway event received by `handle_events`.
///
/// Every event is first passed to Songbird on a separate task, which is awaited
/// before anything else happens. Then, by event kind:
/// - `VoiceStateUpdate` updates the voice-channel state held by
///   `state.vcs_sender` via `update_vcs`.
/// - `VoiceServerUpdate` needs no further handling. The shard subscribes to it
///   for Songbird, which has already seen it.
/// - An application-command `InteractionCreate` is dispatched through
///   `command_router` on its own task, which is not awaited.
/// - Anything else is logged as unexpected.
///
/// # Panics
///
/// Panics if the Songbird task panics.
#[tracing::instrument(skip(command_router, state))]
async fn handle_event(command_router: Arc<CommandRouter>, state: State, event: Event) {
    tokio::spawn({
        let event = event.clone();
        let songbird = state.songbird.clone();
        async move {
            songbird.process(&event).await;
        }
    })
    .await
    .unwrap();
    match event {
        Event::VoiceStateUpdate(voice_state_update) => {
            state
                .vcs_sender
                .send_modify(|vcs| update_vcs(&voice_state_update, vcs));
        }
        // Explicitly requested by the shard's event filter and already handed to
        // Songbird above, so it is expected; previously this hit the
        // "wasn't expected" warning on every voice connection.
        Event::VoiceServerUpdate(_) => {}
        Event::InteractionCreate(interaction_create) => {
            let InteractionCreate(interaction) = *interaction_create;
            match &interaction.data {
                None => {
                    tracing::warn!("missing expected interaction data");
                }
                Some(InteractionData::ApplicationCommand(command_data)) => {
                    let command_name = command_data.name.clone();
                    tokio::spawn(async move {
                        command_router
                            .handle(state, &command_name, interaction)
                            .await;
                    });
                }
                Some(InteractionData::MessageComponent(component_data)) => {
                    tracing::warn!(
                        ?component_data,
                        "wasn't expected because this bot has no message component features"
                    );
                }
                Some(InteractionData::ModalSubmit(modal_data)) => {
                    tracing::warn!(
                        ?modal_data,
                        "wasn't expected because this bot has no modal features"
                    );
                }
                Some(other_interaction_data) => {
                    tracing::warn!(?other_interaction_data, "wasn't expected");
                }
            }
        }
        other_event => {
            tracing::warn!(?other_event, "wasn't expected");
        }
    }
}