402 lines
14 KiB
Rust
402 lines
14 KiB
Rust
use std::{collections::BTreeMap, num::NonZero, str::Utf8Error, sync::Arc};
|
|
|
|
use futures::{StreamExt as _, stream::FuturesUnordered};
|
|
use itertools::Itertools as _;
|
|
use snafu::{ResultExt as _, Snafu};
|
|
use tokio::sync::watch;
|
|
use tokio_util::{sync::CancellationToken, time::FutureExt};
|
|
use twilight_model::id::{
|
|
Id,
|
|
marker::{ChannelMarker, GuildMarker, UserMarker},
|
|
};
|
|
use twilight_util::builder::embed::EmbedBuilder;
|
|
|
|
use crate::{
|
|
BotDataManager, OneToManyUniqueBTreeMap, State, UserInVCData, bot_data,
|
|
call::join_and_record,
|
|
track_vcs::VCsInGuild,
|
|
vc_user::{Camera, Headphone, Microphone, Stream},
|
|
};
|
|
|
|
/// Score of how worth-following a voice channel is; larger values are hotter.
type Heat = u64;
|
|
/// A [`Heat`] that is non-zero; `track_hottest_vc` only picks a channel when its heat fits in this type.
type Hot = NonZero<Heat>;
|
|
|
|
/// Heat of each voice channel of a single guild, keyed by channel.
type ChannelHeat = BTreeMap<Id<ChannelMarker>, Heat>;
|
|
/// Voice channels grouped under their heat; `track_hottest_vc` reads the hottest group via
/// `last_left_and_rights`.
type HeatMap = OneToManyUniqueBTreeMap<Heat, Id<ChannelMarker>>;
|
|
|
|
#[tracing::instrument(skip(state))]
|
|
pub async fn heat_seek(state: State) {
|
|
let mut vcs_watcher = state.vcs_sender.subscribe();
|
|
let mut vcs_in_guild_senders = BTreeMap::default();
|
|
|
|
loop {
|
|
for (&guild_id, vcs_in_guild) in &*vcs_watcher.borrow() {
|
|
let vcs_in_guild_sender = vcs_in_guild_senders.entry(guild_id).or_insert_with(|| {
|
|
let (vcs_in_guild_sender, vcs_in_guild_watcher) =
|
|
watch::channel(Default::default());
|
|
let (channel_heat_sender, channel_heat_watcher) =
|
|
watch::channel(Default::default());
|
|
let (heat_map_sender, heat_map_watcher) = watch::channel(Default::default());
|
|
let (hottest_vc_sender, hottest_vc_watcher) = watch::channel(Default::default());
|
|
|
|
tokio::spawn(
|
|
evaluate_heat()
|
|
.bot_data_manager(state.bot_data_manager.clone())
|
|
.bot_owner_user_id(state.discord_bot_owner_user_id)
|
|
.bot_user_id(state.discord_user_id)
|
|
.cancellation_token(state.cancellation_token.clone())
|
|
.channel_heat_sender(channel_heat_sender)
|
|
.vcs_in_guild_watcher(vcs_in_guild_watcher)
|
|
.call(),
|
|
);
|
|
tokio::spawn(
|
|
map_heat()
|
|
.cancellation_token(state.cancellation_token.clone())
|
|
.channel_heat_watcher(channel_heat_watcher)
|
|
.heat_map_sender(heat_map_sender)
|
|
.call(),
|
|
);
|
|
tokio::spawn(
|
|
track_hottest_vc()
|
|
.cancellation_token(state.cancellation_token.clone())
|
|
.heat_map_watcher(heat_map_watcher)
|
|
.hottest_vc_sender(hottest_vc_sender)
|
|
.call(),
|
|
);
|
|
tokio::spawn(follow_hottest_vc(
|
|
state.clone(),
|
|
guild_id,
|
|
hottest_vc_watcher,
|
|
));
|
|
|
|
vcs_in_guild_sender
|
|
});
|
|
vcs_in_guild_sender.send_replace(Arc::new(vcs_in_guild.clone()));
|
|
}
|
|
|
|
if matches!(
|
|
vcs_watcher
|
|
.changed()
|
|
.with_cancellation_token(&state.cancellation_token)
|
|
.await,
|
|
None | Some(Err(_))
|
|
) {
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
// NOTE(review): snafu uses a variant's doc comment as its `Display` message when no
// `#[snafu(display)]` is given, so the `///` lines on the variants below are user-visible text.
/// Ways [`get_heat`] can fail to score a voice channel.
#[derive(Debug, Snafu)]
|
|
enum GetHeatError {
|
|
/// couldn't retrieve bot data
|
|
WithBotDataError { source: bot_data::WithError },
|
|
|
|
/// couldn't get the heat script from the bot data
|
|
GetHeatScriptError { source: capnp::Error },
|
|
|
|
/// the heat script is not a valid UTF-8 string
|
|
HeatScriptInvalidUtf8 { source: Utf8Error },
|
|
|
|
/// the heat script is not valid Rhai code
|
|
HeatScriptInvalidRhai { source: rhai::ParseError },
|
|
|
|
/// failed while evaluating the heat script
|
|
HeatScriptEvaluationError { source: Box<rhai::EvalAltResult> },
|
|
}
|
|
|
|
#[bon::builder]
|
|
#[tracing::instrument]
|
|
async fn get_heat(
|
|
users_in_vc: &BTreeMap<Id<UserMarker>, UserInVCData>,
|
|
bot_user_id: Id<UserMarker>,
|
|
bot_owner_user_id: Id<UserMarker>,
|
|
bot_data_manager: &BotDataManager,
|
|
) -> Result<Heat, GetHeatError> {
|
|
let heat_script = bot_data_manager
|
|
.with(|bot_data| {
|
|
bot_data.has_heat_script().then(|| {
|
|
bot_data
|
|
.get_heat_script()
|
|
.map(|heat_script| heat_script.to_string())
|
|
})
|
|
})
|
|
.await
|
|
.context(WithBotDataSnafu)?
|
|
.transpose()
|
|
.context(GetHeatScriptSnafu)?
|
|
.transpose()
|
|
.context(HeatScriptInvalidUtf8Snafu)?;
|
|
|
|
let engine = rhai::Engine::new();
|
|
let heat_function = heat_script
|
|
.map(|heat_script| engine.compile(heat_script))
|
|
.transpose()
|
|
.context(HeatScriptInvalidRhaiSnafu)?;
|
|
|
|
let heat = heat_function
|
|
.map(|heat_function| {
|
|
let mut scope = Default::default();
|
|
|
|
let args = (); // TODO
|
|
|
|
engine.call_fn(&mut scope, &heat_function, "heat", args)
|
|
})
|
|
.transpose()
|
|
.context(HeatScriptEvaluationSnafu)?;
|
|
|
|
let heat = heat.unwrap_or_else(|| {
|
|
tracing::warn!("using default heat scoring algorithm as no script was specified");
|
|
|
|
let mut users_in_vc = users_in_vc.clone();
|
|
|
|
let _bot = users_in_vc.remove(&bot_user_id);
|
|
let bot_owner = users_in_vc.remove(&bot_owner_user_id);
|
|
|
|
let mut heat = 0;
|
|
|
|
for (_user_id, user_in_vc_data) in users_in_vc {
|
|
if matches!(user_in_vc_data.microphone, Microphone::Unmuted) {
|
|
heat += 1000;
|
|
}
|
|
if matches!(user_in_vc_data.camera, Camera::Showing) {
|
|
heat += 100;
|
|
}
|
|
if matches!(user_in_vc_data.stream, Stream::Sharing) {
|
|
heat += 10;
|
|
}
|
|
if matches!(user_in_vc_data.headphone, Headphone::Undeafened) {
|
|
heat += 1;
|
|
}
|
|
}
|
|
|
|
let bot_owner_might_be_listening =
|
|
bot_owner.is_some_and(|user_data| matches!(user_data.headphone, Headphone::Undeafened));
|
|
|
|
if bot_owner_might_be_listening {
|
|
heat = heat.min(999);
|
|
}
|
|
|
|
heat
|
|
});
|
|
|
|
Ok(heat)
|
|
}
|
|
|
|
#[bon::builder]
|
|
#[tracing::instrument(skip(vcs_in_guild_watcher, channel_heat_sender))]
|
|
async fn evaluate_heat(
|
|
bot_data_manager: BotDataManager,
|
|
bot_owner_user_id: Id<UserMarker>,
|
|
bot_user_id: Id<UserMarker>,
|
|
cancellation_token: CancellationToken,
|
|
|
|
mut vcs_in_guild_watcher: watch::Receiver<Arc<VCsInGuild>>,
|
|
channel_heat_sender: watch::Sender<ChannelHeat>,
|
|
) {
|
|
loop {
|
|
let vcs_in_guild = { vcs_in_guild_watcher.borrow().clone() };
|
|
|
|
let channel_heat_results: BTreeMap<_, _> = {
|
|
FuturesUnordered::from_iter((&*vcs_in_guild).into_iter().map(
|
|
|(&channel_id, users_in_vc)| {
|
|
let bot_data_manager = bot_data_manager.clone();
|
|
async move {
|
|
(
|
|
channel_id,
|
|
get_heat()
|
|
.bot_data_manager(&bot_data_manager)
|
|
.bot_owner_user_id(bot_owner_user_id)
|
|
.bot_user_id(bot_user_id)
|
|
.users_in_vc(users_in_vc)
|
|
.call()
|
|
.await,
|
|
)
|
|
}
|
|
},
|
|
))
|
|
}
|
|
.collect()
|
|
.await;
|
|
|
|
let (channel_heat, get_heat_errors): (ChannelHeat, Vec<_>) = channel_heat_results
|
|
.into_iter()
|
|
.map(|(channel_id, heat_result)| heat_result.map(|heat| (channel_id, heat)))
|
|
.partition_result();
|
|
|
|
channel_heat_sender.send_replace(channel_heat);
|
|
|
|
for get_heat_error in get_heat_errors {
|
|
tracing::error!(?get_heat_error, "failed to evaluate heat of channel")
|
|
}
|
|
if matches!(
|
|
vcs_in_guild_watcher
|
|
.changed()
|
|
.with_cancellation_token(&cancellation_token)
|
|
.await,
|
|
None | Some(Err(_))
|
|
) {
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
#[bon::builder]
|
|
#[tracing::instrument(skip(channel_heat_watcher, heat_map_sender))]
|
|
async fn map_heat(
|
|
cancellation_token: CancellationToken,
|
|
mut channel_heat_watcher: watch::Receiver<ChannelHeat>,
|
|
heat_map_sender: watch::Sender<HeatMap>,
|
|
) {
|
|
loop {
|
|
heat_map_sender.send_if_modified(|heat_map| {
|
|
let mut changed = false;
|
|
for (&channel, &heat) in &*channel_heat_watcher.borrow() {
|
|
let existing = heat_map.insert(heat, channel);
|
|
if existing.is_none_or(|(old_heat, old_channel)| {
|
|
old_heat != heat || channel != old_channel
|
|
}) {
|
|
changed = true;
|
|
}
|
|
}
|
|
changed
|
|
});
|
|
|
|
if matches!(
|
|
channel_heat_watcher
|
|
.changed()
|
|
.with_cancellation_token(&cancellation_token)
|
|
.await,
|
|
None | Some(Err(_))
|
|
) {
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
#[bon::builder]
|
|
#[tracing::instrument(skip(heat_map_watcher, hottest_vc_sender))]
|
|
async fn track_hottest_vc(
|
|
cancellation_token: CancellationToken,
|
|
|
|
mut heat_map_watcher: watch::Receiver<HeatMap>,
|
|
hottest_vc_sender: watch::Sender<Option<Id<ChannelMarker>>>,
|
|
) {
|
|
loop {
|
|
let new_hottest_vc_option = {
|
|
heat_map_watcher
|
|
.borrow()
|
|
.last_left_and_rights()
|
|
.and_then(|(&heat, hottest_vcs)| {
|
|
let hot_option = Hot::new(heat);
|
|
|
|
hot_option.map(|_| *hottest_vcs.first().unwrap())
|
|
})
|
|
};
|
|
|
|
hottest_vc_sender.send_if_modified(|old_hottest_vc_option| {
|
|
let modified = (*old_hottest_vc_option) != new_hottest_vc_option;
|
|
*old_hottest_vc_option = new_hottest_vc_option;
|
|
modified
|
|
});
|
|
|
|
if matches!(
|
|
heat_map_watcher
|
|
.changed()
|
|
.with_cancellation_token(&cancellation_token)
|
|
.await,
|
|
None | Some(Err(_))
|
|
) {
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
#[tracing::instrument(skip(state, hottest_vc_watcher))]
|
|
async fn follow_hottest_vc(
|
|
state: State,
|
|
guild_id: Id<GuildMarker>,
|
|
mut hottest_vc_watcher: watch::Receiver<Option<Id<ChannelMarker>>>,
|
|
) {
|
|
loop {
|
|
let hottest_vc_option = { *hottest_vc_watcher.borrow() };
|
|
|
|
match hottest_vc_option {
|
|
Some(hottest_vc) => {
|
|
match join_and_record()
|
|
.audio_channels(state.audio_channels)
|
|
.audio_sample_rate(state.audio_sample_rate)
|
|
.guild_id(guild_id)
|
|
.recording_data_manager(state.recording_data_manager.clone())
|
|
.songbird(&state.songbird)
|
|
.user_data_manager(state.user_data_manager.clone())
|
|
.voice_channel_id(hottest_vc)
|
|
.call()
|
|
.await
|
|
{
|
|
Ok(()) => {
|
|
let text_channel = state
|
|
.discord_voice_channel_corresponding_text_channel
|
|
.get(&guild_id)
|
|
.and_then(|guild_mappings| {
|
|
guild_mappings.get_right_for(&hottest_vc).copied()
|
|
})
|
|
.unwrap_or(hottest_vc);
|
|
|
|
let vc_mention = format!("<#{hottest_vc}>");
|
|
|
|
let info_mention = format!(
|
|
"</{}:{}>",
|
|
state.discord_info_command_name, state.discord_info_command_id
|
|
);
|
|
let opt_in_mention = format!(
|
|
"</{}:{}>",
|
|
state.discord_opt_in_command_name, state.discord_opt_in_command_id
|
|
);
|
|
let opt_out_mention = format!(
|
|
"</{}:{}>",
|
|
state.discord_opt_out_command_name, state.discord_opt_out_command_id
|
|
);
|
|
|
|
if let Err(posting_recording_disclosure_error) = state
|
|
.discord_client
|
|
.create_message(text_channel)
|
|
.embeds(&[
|
|
EmbedBuilder::new()
|
|
.title("Joined VC to record")
|
|
.description(format!("This bot joined {vc_mention} and intends to record. You can opt out with {opt_out_mention} or explicitly opt in with {opt_in_mention} (I'd appreciate this one). Please use {info_mention} for more information about this bot."))
|
|
.validate()
|
|
.unwrap()
|
|
.build()
|
|
])
|
|
.await {
|
|
tracing::error!(?text_channel, ?posting_recording_disclosure_error, "couldn't post a recording disclosure");
|
|
}
|
|
}
|
|
Err(joining_to_record_error) => {
|
|
tracing::error!(
|
|
?hottest_vc,
|
|
?joining_to_record_error,
|
|
"couldn't join to record"
|
|
);
|
|
}
|
|
}
|
|
}
|
|
None => {
|
|
if let Err(leaving_error) = state.songbird.leave(guild_id).await {
|
|
tracing::error!(?leaving_error, "couldn't leave vc");
|
|
}
|
|
}
|
|
}
|
|
|
|
if matches!(
|
|
hottest_vc_watcher
|
|
.changed()
|
|
.with_cancellation_token(&state.cancellation_token)
|
|
.await,
|
|
None | Some(Err(_))
|
|
) {
|
|
break;
|
|
}
|
|
}
|
|
}
|