Files
fomo-reducer/src/heat_seek.rs

402 lines
14 KiB
Rust

use std::{collections::BTreeMap, num::NonZero, str::Utf8Error, sync::Arc};
use futures::{StreamExt as _, stream::FuturesUnordered};
use itertools::Itertools as _;
use snafu::{ResultExt as _, Snafu};
use tokio::sync::watch;
use tokio_util::{sync::CancellationToken, time::FutureExt};
use twilight_model::id::{
Id,
marker::{ChannelMarker, GuildMarker, UserMarker},
};
use twilight_util::builder::embed::EmbedBuilder;
use crate::{
BotDataManager, OneToManyUniqueBTreeMap, State, UserInVCData, bot_data,
call::join_and_record,
track_vcs::VCsInGuild,
vc_user::{Camera, Headphone, Microphone, Stream},
};
/// Score of how worth recording a voice channel is; `0` means it is not worth joining.
type Heat = u64;

/// Heat of each voice channel in one guild, keyed by channel.
type ChannelHeat = BTreeMap<Id<ChannelMarker>, Heat>;

/// [`ChannelHeat`] inverted: voice channels grouped under the heat they currently have.
type HeatMap = OneToManyUniqueBTreeMap<Heat, Id<ChannelMarker>>;

/// A heat that is known to be non-zero.
type Hot = NonZero<Heat>;
#[tracing::instrument(skip(state))]
pub async fn heat_seek(state: State) {
let mut vcs_watcher = state.vcs_sender.subscribe();
let mut vcs_in_guild_senders = BTreeMap::default();
loop {
for (&guild_id, vcs_in_guild) in &*vcs_watcher.borrow() {
let vcs_in_guild_sender = vcs_in_guild_senders.entry(guild_id).or_insert_with(|| {
let (vcs_in_guild_sender, vcs_in_guild_watcher) =
watch::channel(Default::default());
let (channel_heat_sender, channel_heat_watcher) =
watch::channel(Default::default());
let (heat_map_sender, heat_map_watcher) = watch::channel(Default::default());
let (hottest_vc_sender, hottest_vc_watcher) = watch::channel(Default::default());
tokio::spawn(
evaluate_heat()
.bot_data_manager(state.bot_data_manager.clone())
.bot_owner_user_id(state.discord_bot_owner_user_id)
.bot_user_id(state.discord_user_id)
.cancellation_token(state.cancellation_token.clone())
.channel_heat_sender(channel_heat_sender)
.vcs_in_guild_watcher(vcs_in_guild_watcher)
.call(),
);
tokio::spawn(
map_heat()
.cancellation_token(state.cancellation_token.clone())
.channel_heat_watcher(channel_heat_watcher)
.heat_map_sender(heat_map_sender)
.call(),
);
tokio::spawn(
track_hottest_vc()
.cancellation_token(state.cancellation_token.clone())
.heat_map_watcher(heat_map_watcher)
.hottest_vc_sender(hottest_vc_sender)
.call(),
);
tokio::spawn(follow_hottest_vc(
state.clone(),
guild_id,
hottest_vc_watcher,
));
vcs_in_guild_sender
});
vcs_in_guild_sender.send_replace(Arc::new(vcs_in_guild.clone()));
}
if matches!(
vcs_watcher
.changed()
.with_cancellation_token(&state.cancellation_token)
.await,
None | Some(Err(_))
) {
break;
}
}
}
#[derive(Debug, Snafu)]
enum GetHeatError {
/// couldn't retrieve bot data
WithBotDataError { source: bot_data::WithError },
/// couldn't get the heat script from the bot data
GetHeatScriptError { source: capnp::Error },
/// the heat script is not a valid UTF-8 string
HeatScriptInvalidUtf8 { source: Utf8Error },
/// the heat script is not valid Rhai code
HeatScriptInvalidRhai { source: rhai::ParseError },
/// failed while evaluating the heat script
HeatScriptEvaluationError { source: Box<rhai::EvalAltResult> },
}
#[bon::builder]
#[tracing::instrument]
async fn get_heat(
users_in_vc: &BTreeMap<Id<UserMarker>, UserInVCData>,
bot_user_id: Id<UserMarker>,
bot_owner_user_id: Id<UserMarker>,
bot_data_manager: &BotDataManager,
) -> Result<Heat, GetHeatError> {
let heat_script = bot_data_manager
.with(|bot_data| {
bot_data.has_heat_script().then(|| {
bot_data
.get_heat_script()
.map(|heat_script| heat_script.to_string())
})
})
.await
.context(WithBotDataSnafu)?
.transpose()
.context(GetHeatScriptSnafu)?
.transpose()
.context(HeatScriptInvalidUtf8Snafu)?;
let engine = rhai::Engine::new();
let heat_function = heat_script
.map(|heat_script| engine.compile(heat_script))
.transpose()
.context(HeatScriptInvalidRhaiSnafu)?;
let heat = heat_function
.map(|heat_function| {
let mut scope = Default::default();
let args = (); // TODO
engine.call_fn(&mut scope, &heat_function, "heat", args)
})
.transpose()
.context(HeatScriptEvaluationSnafu)?;
let heat = heat.unwrap_or_else(|| {
tracing::warn!("using default heat scoring algorithm as no script was specified");
let mut users_in_vc = users_in_vc.clone();
let _bot = users_in_vc.remove(&bot_user_id);
let bot_owner = users_in_vc.remove(&bot_owner_user_id);
let mut heat = 0;
for (_user_id, user_in_vc_data) in users_in_vc {
if matches!(user_in_vc_data.microphone, Microphone::Unmuted) {
heat += 1000;
}
if matches!(user_in_vc_data.camera, Camera::Showing) {
heat += 100;
}
if matches!(user_in_vc_data.stream, Stream::Sharing) {
heat += 10;
}
if matches!(user_in_vc_data.headphone, Headphone::Undeafened) {
heat += 1;
}
}
let bot_owner_might_be_listening =
bot_owner.is_some_and(|user_data| matches!(user_data.headphone, Headphone::Undeafened));
if bot_owner_might_be_listening {
heat = heat.min(999);
}
heat
});
Ok(heat)
}
#[bon::builder]
#[tracing::instrument(skip(vcs_in_guild_watcher, channel_heat_sender))]
async fn evaluate_heat(
bot_data_manager: BotDataManager,
bot_owner_user_id: Id<UserMarker>,
bot_user_id: Id<UserMarker>,
cancellation_token: CancellationToken,
mut vcs_in_guild_watcher: watch::Receiver<Arc<VCsInGuild>>,
channel_heat_sender: watch::Sender<ChannelHeat>,
) {
loop {
let vcs_in_guild = { vcs_in_guild_watcher.borrow().clone() };
let channel_heat_results: BTreeMap<_, _> = {
FuturesUnordered::from_iter((&*vcs_in_guild).into_iter().map(
|(&channel_id, users_in_vc)| {
let bot_data_manager = bot_data_manager.clone();
async move {
(
channel_id,
get_heat()
.bot_data_manager(&bot_data_manager)
.bot_owner_user_id(bot_owner_user_id)
.bot_user_id(bot_user_id)
.users_in_vc(users_in_vc)
.call()
.await,
)
}
},
))
}
.collect()
.await;
let (channel_heat, get_heat_errors): (ChannelHeat, Vec<_>) = channel_heat_results
.into_iter()
.map(|(channel_id, heat_result)| heat_result.map(|heat| (channel_id, heat)))
.partition_result();
channel_heat_sender.send_replace(channel_heat);
for get_heat_error in get_heat_errors {
tracing::error!(?get_heat_error, "failed to evaluate heat of channel")
}
if matches!(
vcs_in_guild_watcher
.changed()
.with_cancellation_token(&cancellation_token)
.await,
None | Some(Err(_))
) {
break;
}
}
}
#[bon::builder]
#[tracing::instrument(skip(channel_heat_watcher, heat_map_sender))]
async fn map_heat(
cancellation_token: CancellationToken,
mut channel_heat_watcher: watch::Receiver<ChannelHeat>,
heat_map_sender: watch::Sender<HeatMap>,
) {
loop {
heat_map_sender.send_if_modified(|heat_map| {
let mut changed = false;
for (&channel, &heat) in &*channel_heat_watcher.borrow() {
let existing = heat_map.insert(heat, channel);
if existing.is_none_or(|(old_heat, old_channel)| {
old_heat != heat || channel != old_channel
}) {
changed = true;
}
}
changed
});
if matches!(
channel_heat_watcher
.changed()
.with_cancellation_token(&cancellation_token)
.await,
None | Some(Err(_))
) {
break;
}
}
}
/// Publishes the hottest voice channel of the latest [`HeatMap`], or `None` if it isn't hot.
///
/// The candidate is the first channel of the bucket returned by `last_left_and_rights`
/// (presumably the highest heat, assuming the map orders heats ascending like `BTreeMap`).
/// A heat of zero, or an empty bucket, yields `None`. Subscribers are only notified when the
/// chosen channel actually changes.
///
/// Runs until `cancellation_token` is cancelled or the `heat_map_watcher` sender is dropped.
#[bon::builder]
#[tracing::instrument(skip(heat_map_watcher, hottest_vc_sender))]
async fn track_hottest_vc(
    cancellation_token: CancellationToken,
    mut heat_map_watcher: watch::Receiver<HeatMap>,
    hottest_vc_sender: watch::Sender<Option<Id<ChannelMarker>>>,
) {
    loop {
        // Mark this heat map as seen so `changed()` below only wakes for newer ones.
        let new_hottest_vc_option = heat_map_watcher
            .borrow_and_update()
            .last_left_and_rights()
            .filter(|&(&heat, _)| Hot::new(heat).is_some())
            // An empty bucket means there is no hottest channel, rather than a panic.
            .and_then(|(_, hottest_vcs)| hottest_vcs.first().copied());
        hottest_vc_sender.send_if_modified(|hottest_vc_option| {
            if *hottest_vc_option == new_hottest_vc_option {
                return false;
            }
            *hottest_vc_option = new_hottest_vc_option;
            true
        });
        if matches!(
            heat_map_watcher
                .changed()
                .with_cancellation_token(&cancellation_token)
                .await,
            None | Some(Err(_))
        ) {
            break;
        }
    }
}
#[tracing::instrument(skip(state, hottest_vc_watcher))]
async fn follow_hottest_vc(
state: State,
guild_id: Id<GuildMarker>,
mut hottest_vc_watcher: watch::Receiver<Option<Id<ChannelMarker>>>,
) {
loop {
let hottest_vc_option = { *hottest_vc_watcher.borrow() };
match hottest_vc_option {
Some(hottest_vc) => {
match join_and_record()
.audio_channels(state.audio_channels)
.audio_sample_rate(state.audio_sample_rate)
.guild_id(guild_id)
.recording_data_manager(state.recording_data_manager.clone())
.songbird(&state.songbird)
.user_data_manager(state.user_data_manager.clone())
.voice_channel_id(hottest_vc)
.call()
.await
{
Ok(()) => {
let text_channel = state
.discord_voice_channel_corresponding_text_channel
.get(&guild_id)
.and_then(|guild_mappings| {
guild_mappings.get_right_for(&hottest_vc).copied()
})
.unwrap_or(hottest_vc);
let vc_mention = format!("<#{hottest_vc}>");
let info_mention = format!(
"</{}:{}>",
state.discord_info_command_name, state.discord_info_command_id
);
let opt_in_mention = format!(
"</{}:{}>",
state.discord_opt_in_command_name, state.discord_opt_in_command_id
);
let opt_out_mention = format!(
"</{}:{}>",
state.discord_opt_out_command_name, state.discord_opt_out_command_id
);
if let Err(posting_recording_disclosure_error) = state
.discord_client
.create_message(text_channel)
.embeds(&[
EmbedBuilder::new()
.title("Joined VC to record")
.description(format!("This bot joined {vc_mention} and intends to record. You can opt out with {opt_out_mention} or explicitly opt in with {opt_in_mention} (I'd appreciate this one). Please use {info_mention} for more information about this bot."))
.validate()
.unwrap()
.build()
])
.await {
tracing::error!(?text_channel, ?posting_recording_disclosure_error, "couldn't post a recording disclosure");
}
}
Err(joining_to_record_error) => {
tracing::error!(
?hottest_vc,
?joining_to_record_error,
"couldn't join to record"
);
}
}
}
None => {
if let Err(leaving_error) = state.songbird.leave(guild_id).await {
tracing::error!(?leaving_error, "couldn't leave vc");
}
}
}
if matches!(
hottest_vc_watcher
.changed()
.with_cancellation_token(&state.cancellation_token)
.await,
None | Some(Err(_))
) {
break;
}
}
}