feat: graceful shutdown, try making join and leave work (but some bug fixes are still needed)

This commit is contained in:
2026-04-08 22:18:32 -04:00
parent 288a784870
commit d2511f7a55
6 changed files with 272 additions and 97 deletions

View File

@@ -5,6 +5,8 @@ use secrecy::{ExposeSecret, SecretString};
use snafu::Snafu;
use songbird::{Songbird, shards::TwilightMap};
use std::{fmt::Debug, str::FromStr, sync::Arc};
use tokio::{select, signal::ctrl_c, task::JoinSet};
use tokio_util::{sync::CancellationToken, time::FutureExt as _};
use tracing_subscriber::{EnvFilter, fmt::format::FmtSpan};
use twilight_gateway::{Event, EventTypeFlags, Intents, Shard, ShardId, StreamExt};
use twilight_model::{
@@ -14,12 +16,12 @@ use twilight_model::{
};
#[derive(Clone)]
struct OpendalOperator {
struct Storage {
uri: OperatorUri,
operator: Operator,
}
impl FromStr for OpendalOperator {
impl FromStr for Storage {
type Err = opendal::Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
@@ -30,19 +32,19 @@ impl FromStr for OpendalOperator {
}
}
impl OpendalOperator {
impl Storage {
fn into_inner(self) -> Operator {
self.operator
}
}
impl From<OpendalOperator> for Operator {
fn from(wrapper: OpendalOperator) -> Self {
impl From<Storage> for Operator {
fn from(wrapper: Storage) -> Self {
wrapper.into_inner()
}
}
impl Debug for OpendalOperator {
impl Debug for Storage {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
Debug::fmt(&self.uri, f)
}
@@ -54,10 +56,16 @@ struct AppArgs {
discord_token: SecretString,
#[arg(long, env)]
storage: OpendalOperator,
bot_owner: Id<UserMarker>,
#[arg(long, env)]
bot_owner: Id<UserMarker>,
bot_data: Storage,
#[arg(long, env)]
user_data: Storage,
#[arg(long, env)]
recording_data: Storage,
}
#[derive(Parser)]
@@ -80,7 +88,10 @@ struct Args {
}
#[derive(Debug, Snafu)]
enum MainError {}
enum MainError {
/// the program was cancelled, perhaps by Ctrl-C / SIGINT
Cancelled,
}
#[snafu::report]
#[tokio::main]
@@ -102,17 +113,21 @@ async fn main() -> Result<(), MainError> {
let AppArgs {
discord_token,
storage,
bot_owner,
bot_data,
user_data,
recording_data,
} = app_args;
let cancellation_token = CancellationToken::new();
rustls::crypto::aws_lc_rs::default_provider()
.install_default()
.unwrap();
let discord_client = twilight_http::Client::new(discord_token.expose_secret().to_owned());
let user = discord_client
let discord_user = discord_client
.current_user()
.await
.expect("couldn't fetch current user") // TODO
@@ -120,7 +135,7 @@ async fn main() -> Result<(), MainError> {
.await
.expect("couldn't deserialize current user"); // TODO
let user_id = user.id;
let discord_user_id = discord_user.id;
let current_application = discord_client
.current_user_application()
@@ -134,21 +149,26 @@ async fn main() -> Result<(), MainError> {
let discord_application_id = current_application.id;
let shard_id = ShardId::new(0, 1);
let intents = Intents::GUILD_VOICE_STATES;
let mut shard = Shard::new(shard_id, discord_token.expose_secret().to_owned(), intents);
let config = twilight_gateway::Config::new(discord_token.expose_secret().to_owned(), intents);
let senders = TwilightMap::new(FromIterator::from_iter([(
shard.id().number(),
shard.sender(),
)]));
let shards = twilight_gateway::create_recommended(&discord_client, config, |_id, builder| {
builder.build()
})
.await
.expect("TODO");
let shards = Vec::from_iter(shards);
let senders = TwilightMap::new(
shards
.iter()
.map(|shard| (shard.id().number(), shard.sender()))
.collect(),
);
let senders = Arc::new(senders);
let songbird = Songbird::twilight(senders, user_id);
let event_types = EventTypeFlags::GUILD_VOICE_STATES | EventTypeFlags::INTERACTION_CREATE;
let mut next_event = shard.next_event(event_types);
let songbird = Songbird::twilight(senders, discord_user_id);
let interaction_client = discord_client.interaction(discord_application_id);
@@ -179,13 +199,54 @@ async fn main() -> Result<(), MainError> {
let vcs = Arc::new(vcs);
let state = State {
cancellation_token: cancellation_token.clone(),
discord_application_id,
discord_client,
discord_user_id,
songbird,
vcs,
};
while let Some(event_res) = next_event.await {
let run_shards = JoinSet::from_iter(
shards
.into_iter()
.map(|shard| handle_events(command_router.clone(), state.clone(), shard)),
);
let run_shards = run_shards.join_all();
tokio::pin!(run_shards);
tokio::spawn({
let cancellation_token = cancellation_token.clone();
async move {
match ctrl_c().await {
Ok(()) => cancellation_token.cancel(),
Err(error) => tracing::error!(?error, "failed to listen for interrupt signal"),
}
}
});
select! {
_ = &mut run_shards => {
Ok(())
}
() = cancellation_token.cancelled() => {
tracing::warn!("waiting for tasks to gracefully shut down");
run_shards.await;
Err(MainError::Cancelled)
}
}
}
#[tracing::instrument(skip(command_router, state))]
async fn handle_events(command_router: Arc<CommandRouter>, state: State, mut shard: Shard) {
let event_types = EventTypeFlags::GUILD_VOICE_STATES | EventTypeFlags::INTERACTION_CREATE;
while let Some(Some(event_res)) = shard
.next_event(event_types)
.with_cancellation_token(&state.cancellation_token)
.await
{
match event_res {
Ok(event) => {
handle_event(command_router.clone(), state.clone(), event).await;
@@ -194,16 +255,13 @@ async fn main() -> Result<(), MainError> {
tracing::error!(?error);
}
}
next_event = shard.next_event(event_types);
}
Ok(())
}
#[tracing::instrument(skip(command_router))]
#[tracing::instrument(skip(command_router, state))]
async fn handle_event(command_router: Arc<CommandRouter>, state: State, event: Event) {
state.songbird.process(&event).await;
match event {
Event::VoiceStateUpdate(voice_state_update) => {
update_vcs(&voice_state_update, &state.vcs);