feat: heatseeking

This commit is contained in:
2026-05-24 13:20:43 -04:00
parent e1aab0a8fb
commit b598adb498
12 changed files with 288 additions and 40 deletions

View File

@@ -1,31 +1,237 @@
use std::{collections::BTreeMap, num::NonZero};
use std::{collections::BTreeMap, num::NonZero, str::Utf8Error, sync::Arc};
use futures::{StreamExt as _, stream::FuturesUnordered};
use itertools::Itertools as _;
use snafu::{ResultExt as _, Snafu};
use tokio::sync::watch;
use twilight_model::id::{
Id,
marker::{ChannelMarker, GuildMarker},
marker::{ChannelMarker, GuildMarker, UserMarker},
};
use twilight_util::builder::embed::EmbedBuilder;
use crate::{OneToManyUniqueBTreeMap, State, call::join_and_record};
use crate::{
BotDataManager, OneToManyUniqueBTreeMap, State, UserInVCData, bot_data,
call::join_and_record,
track_vcs::VCsInGuild,
vc_user::{Camera, Headphone, Microphone, Stream},
};
type Heat = u64;
type Hot = NonZero<Heat>;
type HotOption = Option<Hot>;
type ChannelHeat = BTreeMap<Id<ChannelMarker>, Heat>;
type HeatMap = OneToManyUniqueBTreeMap<Heat, Id<ChannelMarker>>;
#[tracing::instrument]
pub async fn heat_seek(state: State) {
let mut vcs_watcher = state.vcs_sender.subscribe();
let mut vcs_in_guild_senders = BTreeMap::default();
loop {
for (&guild_id, vcs_in_guild) in &*vcs_watcher.borrow() {
let vcs_in_guild_sender = vcs_in_guild_senders.entry(guild_id).or_insert_with(|| {
let (vcs_in_guild_sender, vcs_in_guild_watcher) =
watch::channel(Default::default());
let (channel_heat_sender, channel_heat_watcher) =
watch::channel(Default::default());
let (heat_map_sender, heat_map_watcher) = watch::channel(Default::default());
let (hottest_vc_sender, hottest_vc_watcher) = watch::channel(Default::default());
tokio::spawn(
evaluate_heat()
.bot_data_manager(state.bot_data_manager.clone())
.bot_owner_user_id(state.discord_bot_owner_user_id)
.bot_user_id(state.discord_user_id)
.channel_heat_sender(channel_heat_sender)
.vcs_in_guild_watcher(vcs_in_guild_watcher)
.call(),
);
tokio::spawn(map_heat(channel_heat_watcher, heat_map_sender));
tokio::spawn(track_hottest_vc(
state.discord_bot_owner_user_id,
heat_map_watcher,
hottest_vc_sender,
));
tokio::spawn(follow_hottest_vc(
state.clone(),
guild_id,
hottest_vc_watcher,
));
vcs_in_guild_sender
});
vcs_in_guild_sender.send_replace(Arc::new(vcs_in_guild.clone()));
}
if let Err(_closed) = vcs_watcher.changed().await {
break;
}
}
}
#[derive(Debug, Snafu)]
enum GetHeatError {
/// couldn't retrieve bot data
WithBotDataError { source: bot_data::WithError },
/// couldn't get the heat script from the bot data
GetHeatScriptError { source: capnp::Error },
/// the heat script is not a valid UTF-8 string
HeatScriptInvalidUtf8 { source: Utf8Error },
/// the heat script is not valid Rhai code
HeatScriptInvalidRhai { source: rhai::ParseError },
/// failed while evaluating the heat script
HeatScriptEvaluationError { source: Box<rhai::EvalAltResult> },
}
#[bon::builder]
#[tracing::instrument]
async fn get_heat(
users_in_vc: &BTreeMap<Id<UserMarker>, UserInVCData>,
bot_user_id: Id<UserMarker>,
bot_owner_user_id: Id<UserMarker>,
bot_data_manager: &BotDataManager,
) -> Result<Heat, GetHeatError> {
let heat_script = bot_data_manager
.with(|bot_data| {
bot_data.has_heat_script().then(|| {
bot_data
.get_heat_script()
.map(|heat_script| heat_script.to_string())
})
})
.await
.context(WithBotDataSnafu)?
.transpose()
.context(GetHeatScriptSnafu)?
.transpose()
.context(HeatScriptInvalidUtf8Snafu)?;
let engine = rhai::Engine::new();
let heat_function = heat_script
.map(|heat_script| engine.compile(heat_script))
.transpose()
.context(HeatScriptInvalidRhaiSnafu)?;
let heat = heat_function
.map(|heat_function| {
let mut scope = Default::default();
let args = (); // TODO
engine.call_fn(&mut scope, &heat_function, "heat", args)
})
.transpose()
.context(HeatScriptEvaluationSnafu)?;
let heat = heat.unwrap_or_else(|| {
tracing::warn!("using default heat scoring algorithm as no script was specified");
let mut users_in_vc = users_in_vc.clone();
let _bot = users_in_vc.remove(&bot_user_id);
let bot_owner = users_in_vc.remove(&bot_owner_user_id);
let mut heat = 0;
for (_user_id, user_in_vc_data) in users_in_vc {
if matches!(user_in_vc_data.microphone, Microphone::Unmuted) {
heat += 1000;
}
if matches!(user_in_vc_data.camera, Camera::Showing) {
heat += 100;
}
if matches!(user_in_vc_data.stream, Stream::Sharing) {
heat += 10;
}
if matches!(user_in_vc_data.headphone, Headphone::Undeafened) {
heat += 1;
}
}
if bot_owner.is_some() {
heat = heat.min(999);
}
heat
});
Ok(heat)
}
#[bon::builder]
#[tracing::instrument]
async fn evaluate_heat(
bot_data_manager: BotDataManager,
bot_owner_user_id: Id<UserMarker>,
bot_user_id: Id<UserMarker>,
mut vcs_in_guild_watcher: watch::Receiver<Arc<VCsInGuild>>,
channel_heat_sender: watch::Sender<ChannelHeat>,
) {
loop {
let vcs_in_guild = { vcs_in_guild_watcher.borrow().clone() };
let channel_heat_results: BTreeMap<_, _> = {
FuturesUnordered::from_iter((&*vcs_in_guild).into_iter().map(
|(&channel_id, users_in_vc)| {
let bot_data_manager = bot_data_manager.clone();
async move {
(
channel_id,
get_heat()
.bot_data_manager(&bot_data_manager)
.bot_owner_user_id(bot_owner_user_id)
.bot_user_id(bot_user_id)
.users_in_vc(users_in_vc)
.call()
.await,
)
}
},
))
}
.collect()
.await;
let (channel_heat, get_heat_errors): (ChannelHeat, Vec<_>) = channel_heat_results
.into_iter()
.map(|(channel_id, heat_result)| heat_result.map(|heat| (channel_id, heat)))
.partition_result();
channel_heat_sender.send_replace(channel_heat);
for get_heat_error in get_heat_errors {
tracing::error!(?get_heat_error, "failed to evaluate heat of channel")
}
if let Err(_closed) = vcs_in_guild_watcher.changed().await {
break;
}
}
}
#[tracing::instrument]
async fn map_heat(
mut channel_heat_watcher: watch::Receiver<ChannelHeat>,
heat_map_sender: watch::Sender<HeatMap>,
) {
loop {
heat_map_sender.send_modify(|heat_map| {
heat_map_sender.send_if_modified(|heat_map| {
let mut changed = false;
for (&channel, &heat) in &*channel_heat_watcher.borrow() {
heat_map.insert(heat, channel);
let existing = heat_map.insert(heat, channel);
if existing.map_or(true, |(old_heat, old_channel)| {
old_heat != heat || channel != old_channel
}) {
changed = true;
}
}
changed
});
if let Err(_closed) = channel_heat_watcher.changed().await {
@@ -36,6 +242,7 @@ async fn map_heat(
#[tracing::instrument]
async fn track_hottest_vc(
bot_owner_id: Id<UserMarker>,
mut heat_map_watcher: watch::Receiver<HeatMap>,
hottest_vc_sender: watch::Sender<Option<Id<ChannelMarker>>>,
) {
@@ -47,7 +254,6 @@ async fn track_hottest_vc(
.and_then(|(&heat, hottest_vcs)| {
let hot_option = Hot::new(heat);
// TODO: tiebreak by whichever one this bot is already in
hot_option.map(|_| *hottest_vcs.first().unwrap())
})
};
@@ -65,7 +271,7 @@ async fn track_hottest_vc(
}
#[tracing::instrument]
async fn follow_heat(
async fn follow_hottest_vc(
state: State,
guild_id: Id<GuildMarker>,
mut hottest_vc_watcher: watch::Receiver<Option<Id<ChannelMarker>>>,
@@ -146,8 +352,3 @@ async fn follow_heat(
}
}
}
#[tracing::instrument]
pub async fn heat_seek(state: State) {
todo!();
}