From d2511f7a55d71c0e20ee9d1b8328b5ade760e860 Mon Sep 17 00:00:00 2001 From: Jacob Date: Wed, 8 Apr 2026 22:18:32 -0400 Subject: [PATCH] feat: graceful shutdown, try making join and leave work (but some bug fixes are still needed) --- Cargo.lock | 34 +++++++++++-- Cargo.toml | 11 ++-- src/command/join.rs | 85 +++++++++++++------------------ src/command/leave.rs | 118 +++++++++++++++++++++++++++++++++++++++---- src/command/mod.rs | 7 ++- src/main.rs | 114 +++++++++++++++++++++++++++++++---------- 6 files changed, 272 insertions(+), 97 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 85142f8..7c3c571 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -578,6 +578,24 @@ dependencies = [ "serde_core", ] +[[package]] +name = "capnp" +version = "0.25.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d1c82ec25a9501d60e22eef4be1b2c271769b5a96e224d0875baef28529cf30" +dependencies = [ + "embedded-io", +] + +[[package]] +name = "capnpc" +version = "0.25.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fca02be865c8c5a78bfc24b9819006ab6b59bef238467203928e26459557af93" +dependencies = [ + "capnp", +] + [[package]] name = "cargo-platform" version = "0.1.9" @@ -1409,6 +1427,12 @@ dependencies = [ "zeroize", ] +[[package]] +name = "embedded-io" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9eb1aa714776b75c7e67e1da744b81a129b3ff919c8712b5e1b32252c1f07cc7" + [[package]] name = "encoding_rs" version = "0.8.35" @@ -1628,6 +1652,8 @@ checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" name = "fomo-reducer" version = "0.1.0" dependencies = [ + "capnp", + "capnpc", "clap", "dashmap 6.1.0", "futures", @@ -1639,6 +1665,7 @@ dependencies = [ "snafu", "songbird", "tokio", + "tokio-util", "tracing", "tracing-subscriber", "twilight-gateway", @@ -6473,7 +6500,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0136791f7c95b1f6dd99f9cc786b91bb81c3800b639b3478e561ddb7be95e5f1" dependencies = [ "fastrand", - "getrandom 0.3.4", + "getrandom 0.4.2", "once_cell", "rustix", "windows-sys 0.61.2", @@ -6714,9 +6741,9 @@ dependencies = [ [[package]] name = "tokio-util" -version = "0.7.17" +version = "0.7.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2efa149fe76073d6e8fd97ef4f4eca7b67f599660115591483572e406e165594" +checksum = "9ae9cec805b01e8fc3fd2fe289f89149a9b66dd16786abd8b19cfa7b48cb0098" dependencies = [ "bytes", "futures-core", @@ -7042,6 +7069,7 @@ dependencies = [ "tokio-websockets 0.13.2", "tracing", "twilight-gateway-queue", + "twilight-http", "twilight-model", ] diff --git a/Cargo.toml b/Cargo.toml index b0cca07..b11e71b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,6 +4,7 @@ version = "0.1.0" edition = "2024" [dependencies] +capnp = "0.25.3" clap = { version = "4.5.40", features = ["derive", "env"] } dashmap = "6.1.0" futures = "0.3.32" @@ -57,12 +58,11 @@ songbird = { version = "0.6.0", default-features = false, features = [ "twilight", "tws", ] } -tokio = { version = "1.46.0", features = ["rt-multi-thread", "macros"] } +tokio = { version = "1.46.0", features = ["rt-multi-thread", "macros", "signal"] } +tokio-util = "0.7.18" tracing = "0.1.41" tracing-subscriber = { version = "0.3.19", features = ["env-filter"] } -twilight-gateway = { version = "0.17", default-features = false, features = [ - "rustls-webpki-roots", -] } +twilight-gateway = { version = "0.17", default-features = false, features = ["rustls-webpki-roots", "twilight-http"] } twilight-http = { version = "0.17", default-features = false, features = [ "rustls-webpki-roots", "hickory", @@ -71,3 +71,6 @@ twilight-http = { version = "0.17", default-features = false, features = [ twilight-model = "0.17" twilight-util = { version = "0.17", features = ["builder"] } typed-builder = "0.23.2" + +[build-dependencies] +capnpc = "0.25.3" diff --git a/src/command/join.rs b/src/command/join.rs index bcdefce..f8e32e4 100644 --- a/src/command/join.rs +++ b/src/command/join.rs @@ -1,6 +1,6 @@ -use std::sync::LazyLock; - +use crate::{VCs, command::State}; use snafu::{OptionExt, Snafu}; +use std::sync::LazyLock; use twilight_model::{ application::{ command::{Command, CommandType}, @@ -17,8 +17,6 @@ use twilight_util::builder::{ InteractionResponseDataBuilder, command::CommandBuilder, embed::EmbedBuilder, }; -use crate::{VCs, command::State}; - const NAME: &str = "join"; const DESCRIPTION: &str = "The bot will join the same VC as you (with intention to record)"; @@ -30,7 +28,7 @@ pub static COMMAND: LazyLock = LazyLock::new(|| { }); #[derive(Debug, Snafu)] -enum GetGuildAndChannelIdError { +enum GetGuildAndVoiceChannelIdError { /// this command was not used inside a guild (Discord server) NotInGuild, @@ -44,29 +42,11 @@ enum GetGuildAndChannelIdError { UserNotInVC, } -impl From for Embed { - fn from(error: GetGuildAndChannelIdError) -> Embed { - match error { - GetGuildAndChannelIdError::NotInGuild => { - EmbedBuilder::new().title("Use this in a server").description("This bot can't find a VC to join if the command is used outside of a server (you might've used it in a DM?).").validate().unwrap().build() - } - GetGuildAndChannelIdError::NoUser => EmbedBuilder::new().title("Not invoked by a user").description("This command works by joining the same VC as the user, but this bot didn't receive any user data. So did no user invoke it?! (This error should be impossible!)").validate().unwrap().build(), - GetGuildAndChannelIdError::NoVCsInGuild => { - EmbedBuilder::new().title("No VCs in this server").description("This bot can't find a VC to join because there aren't any in this server right now.").validate().unwrap().build() - }, - GetGuildAndChannelIdError::UserNotInVC => { - - EmbedBuilder::new().title("You're not in a VC").description("This bot can't follow you into VC if you aren't in one in this server.").validate().unwrap().build() - }, - } - } -} - #[tracing::instrument] -fn get_guild_and_channel_id( +fn get_guild_and_voice_channel_id( interaction: &Interaction, vcs: &VCs, -) -> Result<(Id, Id), GetGuildAndChannelIdError> { +) -> Result<(Id, Id), GetGuildAndVoiceChannelIdError> { let guild_id = interaction.guild_id.context(NotInGuildSnafu)?; let user_id = interaction @@ -77,17 +57,34 @@ fn get_guild_and_channel_id( let guild_vcs = vcs.get(&guild_id).context(NoVCsInGuildSnafu)?; - let &channel_id = guild_vcs.get_left_for(&user_id).context(UserNotInVCSnafu)?; + let &voice_channel_id = guild_vcs.get_left_for(&user_id).context(UserNotInVCSnafu)?; - Ok((guild_id, channel_id)) + Ok((guild_id, voice_channel_id)) +} + +fn get_guild_and_vc_error_to_embed(error: GetGuildAndVoiceChannelIdError) -> Embed { + match error { + GetGuildAndVoiceChannelIdError::NotInGuild => { + EmbedBuilder::new().title("Use this in a server").description("This bot can't find a VC to join if the command is used outside of a server (you might've used it in a DM?).").validate().unwrap().build() + } + GetGuildAndVoiceChannelIdError::NoUser => { + EmbedBuilder::new().title("Not invoked by a user").description("This command works by joining the same VC as the user, but this bot didn't receive any user data. So did no user invoke it?! (This error should be impossible!)").validate().unwrap().build() + }, + GetGuildAndVoiceChannelIdError::NoVCsInGuild => { + EmbedBuilder::new().title("No VCs in this server").description("This bot can't find a VC to join because there aren't any in this server right now.").validate().unwrap().build() + }, + GetGuildAndVoiceChannelIdError::UserNotInVC => { + EmbedBuilder::new().title("You're not in a VC").description("This bot can't follow you into VC if you aren't in one in this server.").validate().unwrap().build() + }, + } } #[tracing::instrument(skip(state))] pub async fn handle(state: State, interaction: Interaction) { let vcs = state.vcs; - let (guild_id, channel_id) = match get_guild_and_channel_id(&interaction, &vcs) { - Ok((guild_id, channel_id)) => (guild_id, channel_id), + let (guild_id, voice_channel_id) = match get_guild_and_voice_channel_id(&interaction, &vcs) { + Ok((guild_id, voice_channel_id)) => (guild_id, voice_channel_id), Err(error) => { state .discord_client @@ -99,7 +96,7 @@ pub async fn handle(state: State, interaction: Interaction) { kind: InteractionResponseType::ChannelMessageWithSource, data: Some( InteractionResponseDataBuilder::new() - .embeds([error.into()]) + .embeds([get_guild_and_vc_error_to_embed(error)]) .flags(MessageFlags::EPHEMERAL) .build(), ), @@ -126,24 +123,15 @@ pub async fn handle(state: State, interaction: Interaction) { .await .expect("TODO"); - let call = loop { - tracing::error!("TODO: about to try joining"); - - match state.songbird.join(guild_id, channel_id).await { - Ok(call) => break call, - Err(error) => { - tracing::error!(?error, "I'm still here"); + let call = state + .songbird + .join(guild_id, voice_channel_id) + .await + .expect("TODO"); - if error.should_leave_server() { - state.songbird.leave(guild_id).await.expect("TODO"); - } else if error.should_reconnect_driver() { - todo!(); - } - } - } - }; + tracing::error!(?call, "successfully joined"); - let channel_mention = format!("<#{channel_id}>"); + let channel_mention = format!("<#{voice_channel_id}>"); state .discord_client @@ -155,9 +143,4 @@ pub async fn handle(state: State, interaction: Interaction) { ])) .await .expect("TODO"); - - tracing::error!(?call, "TODO"); - - let call_guard = call.lock().await; - tracing::error!(?call_guard, "TODO"); } diff --git a/src/command/leave.rs b/src/command/leave.rs index 65315af..5c2ecaf 100644 --- a/src/command/leave.rs +++ b/src/command/leave.rs @@ -1,12 +1,23 @@ -use std::sync::LazyLock; - -use twilight_model::application::{ - command::{Command, CommandType}, - interaction::Interaction, -}; -use twilight_util::builder::command::CommandBuilder; - +use crate::VCs; use crate::command::State; +use snafu::{OptionExt, Snafu}; +use std::sync::LazyLock; +use twilight_model::channel::message::{Embed, MessageFlags}; +use twilight_model::http::interaction::{InteractionResponse, InteractionResponseType}; +use twilight_model::id::marker::UserMarker; +use twilight_model::{ + application::{ + command::{Command, CommandType}, + interaction::Interaction, + }, + id::{ + Id, + marker::{ChannelMarker, GuildMarker}, + }, +}; +use twilight_util::builder::InteractionResponseDataBuilder; +use twilight_util::builder::command::CommandBuilder; +use twilight_util::builder::embed::EmbedBuilder; const NAME: &str = "leave"; const DESCRIPTION: &str = "The bot will leave the VC it's in (so it won't record anyone anymore)"; @@ -18,7 +29,96 @@ pub static COMMAND: LazyLock = LazyLock::new(|| { .build() }); +#[derive(Debug, Snafu)] +pub enum GetGuildAndVoiceChannelIdError { + /// this command was not used inside a guild (Discord server) + NotInGuild, + + /// there are no voice chats in this guild + NoVCsInGuild, + + /// the bot is not in a voice chat in this guild + BotNotInVC, +} + +#[tracing::instrument] +pub fn get_guild_and_voice_channel_id( + bot_user_id: Id, + interaction: &Interaction, + vcs: &VCs, +) -> Result<(Id, Id), GetGuildAndVoiceChannelIdError> { + let guild_id = interaction.guild_id.context(NotInGuildSnafu)?; + + let guild_vcs = vcs.get(&guild_id).context(NoVCsInGuildSnafu)?; + + let &voice_channel_id = guild_vcs + .get_left_for(&bot_user_id) + .context(BotNotInVCSnafu)?; + + Ok((guild_id, voice_channel_id)) +} + +fn get_guild_and_vc_error_to_embed(error: GetGuildAndVoiceChannelIdError) -> Embed { + match error { + GetGuildAndVoiceChannelIdError::NotInGuild => { + EmbedBuilder::new().title("Use this in a server").description("This bot can't tell which VC to leave if the command is used outside of a server (you might've used it in a DM?).").validate().unwrap().build() + } + GetGuildAndVoiceChannelIdError::NoVCsInGuild => { + EmbedBuilder::new().title("No VCs in this server").description("This bot can't leave VC because there aren't any in this server right now (therefore the bot must not be in any).").validate().unwrap().build() + }, + GetGuildAndVoiceChannelIdError::BotNotInVC => { + EmbedBuilder::new().title("Not in a VC").description("This bot can't leave VC if it isn't in one in this server.").validate().unwrap().build() + }, + } +} + #[tracing::instrument] pub async fn handle(state: State, interaction: Interaction) { - todo!(); + let (guild_id, voice_channel_id) = + match get_guild_and_voice_channel_id(state.discord_user_id, &interaction, &state.vcs) { + Ok((guild_id, voice_channel_id)) => (guild_id, voice_channel_id), + Err(error) => { + state + .discord_client + .interaction(state.discord_application_id) + .create_response( + interaction.id, + &interaction.token, + &InteractionResponse { + kind: InteractionResponseType::ChannelMessageWithSource, + data: Some( + InteractionResponseDataBuilder::new() + .embeds([get_guild_and_vc_error_to_embed(error)]) + .flags(MessageFlags::EPHEMERAL) + .build(), + ), + }, + ) + .await + .expect("TODO"); + + return; + } + }; + + state.songbird.leave(guild_id).await.expect("TODO"); + + tracing::error!("TODO: successfully left the call"); + + let channel_mention = format!("<#{voice_channel_id}>"); + + state + .discord_client + .interaction(state.discord_application_id) + .update_response(&interaction.token) + .embeds(Some(&[EmbedBuilder::new() + .title("Left VC") + .description(format!( + "This bot left {channel_mention} (and is thereby unable to record anymore)." + )) + .validate() + .unwrap() + .build()])) + .await + .expect("TODO"); } diff --git a/src/command/mod.rs b/src/command/mod.rs index cf821af..d28c58a 100644 --- a/src/command/mod.rs +++ b/src/command/mod.rs @@ -3,9 +3,10 @@ use std::{fmt::Debug, sync::Arc}; use futures::future::BoxFuture; use patricia_tree::StringPatriciaMap; use songbird::Songbird; +use tokio_util::sync::CancellationToken; use twilight_model::{ application::{command::Command, interaction::Interaction}, - id::{Id, marker::ApplicationMarker}, + id::{Id, marker::{ApplicationMarker, UserMarker}}, }; use crate::VCs; @@ -16,8 +17,10 @@ mod opt_out; #[derive(Debug, Clone)] pub struct State { - pub discord_client: Arc, + pub cancellation_token: CancellationToken, pub discord_application_id: Id, + pub discord_client: Arc, + pub discord_user_id: Id, pub songbird: Arc, pub vcs: Arc, } diff --git a/src/main.rs b/src/main.rs index 360dfff..f072a29 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,6 +5,8 @@ use secrecy::{ExposeSecret, SecretString}; use snafu::Snafu; use songbird::{Songbird, shards::TwilightMap}; use std::{fmt::Debug, str::FromStr, sync::Arc}; +use tokio::{select, signal::ctrl_c, task::JoinSet}; +use tokio_util::{sync::CancellationToken, time::FutureExt as _}; use tracing_subscriber::{EnvFilter, fmt::format::FmtSpan}; use twilight_gateway::{Event, EventTypeFlags, Intents, Shard, ShardId, StreamExt}; use twilight_model::{ @@ -14,12 +16,12 @@ use twilight_model::{ }; #[derive(Clone)] -struct OpendalOperator { +struct Storage { uri: OperatorUri, operator: Operator, } -impl FromStr for OpendalOperator { +impl FromStr for Storage { type Err = opendal::Error; fn from_str(s: &str) -> Result { @@ -30,19 +32,19 @@ impl FromStr for OpendalOperator { } } -impl OpendalOperator { +impl Storage { fn into_inner(self) -> Operator { self.operator } } -impl From for Operator { - fn from(wrapper: OpendalOperator) -> Self { +impl From for Operator { + fn from(wrapper: Storage) -> Self { wrapper.into_inner() } } -impl Debug for OpendalOperator { +impl Debug for Storage { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { Debug::fmt(&self.uri, f) } @@ -54,10 +56,16 @@ struct AppArgs { discord_token: SecretString, #[arg(long, env)] - storage: OpendalOperator, + bot_owner: Id, #[arg(long, env)] - bot_owner: Id, + bot_data: Storage, + + #[arg(long, env)] + user_data: Storage, + + #[arg(long, env)] + recording_data: Storage, } #[derive(Parser)] @@ -80,7 +88,10 @@ struct Args { } #[derive(Debug, Snafu)] -enum MainError {} +enum MainError { + /// the program was cancelled, perhaps by Ctrl-C / SIGINT + Cancelled, +} #[snafu::report] #[tokio::main] @@ -102,17 +113,21 @@ async fn main() -> Result<(), MainError> { let AppArgs { discord_token, - storage, bot_owner, + bot_data, + user_data, + recording_data, } = app_args; + let cancellation_token = CancellationToken::new(); + rustls::crypto::aws_lc_rs::default_provider() .install_default() .unwrap(); let discord_client = twilight_http::Client::new(discord_token.expose_secret().to_owned()); - let user = discord_client + let discord_user = discord_client .current_user() .await .expect("couldn't fetch current user") // TODO @@ -120,7 +135,7 @@ async fn main() -> Result<(), MainError> { .await .expect("couldn't deserialize current user"); // TODO - let user_id = user.id; + let discord_user_id = discord_user.id; let current_application = discord_client .current_user_application() @@ -134,21 +149,26 @@ async fn main() -> Result<(), MainError> { let discord_application_id = current_application.id; - let shard_id = ShardId::new(0, 1); let intents = Intents::GUILD_VOICE_STATES; - let mut shard = Shard::new(shard_id, discord_token.expose_secret().to_owned(), intents); + let config = twilight_gateway::Config::new(discord_token.expose_secret().to_owned(), intents); - let senders = TwilightMap::new(FromIterator::from_iter([( - shard.id().number(), - shard.sender(), - )])); + let shards = twilight_gateway::create_recommended(&discord_client, config, |_id, builder| { + builder.build() + }) + .await + .expect("TODO"); + let shards = Vec::from_iter(shards); + + let senders = TwilightMap::new( + shards + .iter() + .map(|shard| (shard.id().number(), shard.sender())) + .collect(), + ); let senders = Arc::new(senders); - let songbird = Songbird::twilight(senders, user_id); - - let event_types = EventTypeFlags::GUILD_VOICE_STATES | EventTypeFlags::INTERACTION_CREATE; - let mut next_event = shard.next_event(event_types); + let songbird = Songbird::twilight(senders, discord_user_id); let interaction_client = discord_client.interaction(discord_application_id); @@ -179,13 +199,54 @@ async fn main() -> Result<(), MainError> { let vcs = Arc::new(vcs); let state = State { + cancellation_token: cancellation_token.clone(), discord_application_id, discord_client, + discord_user_id, songbird, vcs, }; - while let Some(event_res) = next_event.await { + let run_shards = JoinSet::from_iter( + shards + .into_iter() + .map(|shard| handle_events(command_router.clone(), state.clone(), shard)), + ); + let run_shards = run_shards.join_all(); + tokio::pin!(run_shards); + + tokio::spawn({ + let cancellation_token = cancellation_token.clone(); + async move { + match ctrl_c().await { + Ok(()) => cancellation_token.cancel(), + Err(error) => tracing::error!(?error, "failed to listen for interrupt signal"), + } + } + }); + + select! { + _ = &mut run_shards => { + Ok(()) + } + () = cancellation_token.cancelled() => { + tracing::warn!("waiting for tasks to gracefully shut down"); + run_shards.await; + + Err(MainError::Cancelled) + } + } +} + +#[tracing::instrument(skip(command_router, state))] +async fn handle_events(command_router: Arc, state: State, mut shard: Shard) { + let event_types = EventTypeFlags::GUILD_VOICE_STATES | EventTypeFlags::INTERACTION_CREATE; + + while let Some(Some(event_res)) = shard + .next_event(event_types) + .with_cancellation_token(&state.cancellation_token) + .await + { match event_res { Ok(event) => { handle_event(command_router.clone(), state.clone(), event).await; @@ -194,16 +255,13 @@ async fn main() -> Result<(), MainError> { tracing::error!(?error); } } - - next_event = shard.next_event(event_types); } - - Ok(()) } -#[tracing::instrument(skip(command_router))] +#[tracing::instrument(skip(command_router, state))] async fn handle_event(command_router: Arc, state: State, event: Event) { state.songbird.process(&event).await; + match event { Event::VoiceStateUpdate(voice_state_update) => { update_vcs(&voice_state_update, &state.vcs);