feat: heatseeking

This commit is contained in:
2026-05-24 13:20:43 -04:00
parent e1aab0a8fb
commit b598adb498
12 changed files with 288 additions and 40 deletions

22
Cargo.lock generated
View File

@@ -1732,7 +1732,7 @@ dependencies = [
"futures-core",
"futures-sink",
"nanorand",
"spin",
"spin 0.9.8",
]
[[package]]
@@ -1777,6 +1777,7 @@ dependencies = [
"extension-traits",
"futures",
"hound",
"itertools",
"moka",
"opendal",
"opus2",
@@ -1801,7 +1802,6 @@ dependencies = [
"twilight-model",
"twilight-util",
"typed-builder 0.23.2",
"yoke",
]
[[package]]
@@ -2967,7 +2967,7 @@ version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
dependencies = [
"spin",
"spin 0.9.8",
]
[[package]]
@@ -3611,6 +3611,15 @@ dependencies = [
"getrandom 0.2.17",
]
[[package]]
name = "no-std-compat"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b93853da6d84c2e3c7d730d6473e8817692dd89be387eb01b94d7f108ecb5b8c"
dependencies = [
"spin 0.5.2",
]
[[package]]
name = "no-std-net"
version = "0.6.0"
@@ -5453,6 +5462,7 @@ checksum = "1f9ef5dabe4c0b43d8f1187dc6beb67b53fe607fff7e30c5eb7f71b814b8c2c1"
dependencies = [
"ahash",
"bitflags 2.11.1",
"no-std-compat",
"num-traits",
"once_cell",
"rhai_codegen",
@@ -6266,6 +6276,12 @@ dependencies = [
"uuid",
]
[[package]]
name = "spin"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d"
[[package]]
name = "spin"
version = "0.9.8"

View File

@@ -14,6 +14,7 @@ dashmap = "6.1.0"
extension-traits = "2.0.2"
futures = "0.3.32"
hound = "3.5.1"
itertools = "0.14.0"
moka = { version = "0.12.15", features = ["future"] }
opendal = { git = "https://github.com/apache/opendal", rev = "ecf840b04afd2be109830b9978ba89759adfee79", features = [
"services-azfile",
@@ -54,7 +55,7 @@ opendal = { git = "https://github.com/apache/opendal", rev = "ecf840b04afd2be109
] }
opus2 = "0.4.0"
patricia_tree = "0.10.1"
rhai = "1.23.6"
rhai = { version = "1.23.6", features = ["sync"] }
rustls = "0.23"
secrecy = { version = "0.10.3", features = ["serde"] }
shadow-rs = { version = "2.0.0", default-features = false }
@@ -78,6 +79,7 @@ tokio-websockets-0-11 = { package = "tokio-websockets", version = "0.11", featur
"rustls-webpki-roots",
] }
tracing = "0.1.41"
tracing-appender = "0.2.5"
tracing-subscriber = { version = "0.3.19", features = ["env-filter"] }
twilight-gateway = { version = "0.17", default-features = false, features = [
"rustls-webpki-roots",
@@ -91,8 +93,6 @@ twilight-http = { version = "0.17", default-features = false, features = [
twilight-model = "0.17"
twilight-util = { version = "0.17", features = ["builder"] }
typed-builder = "0.23.2"
yoke = "0.8.2"
tracing-appender = "0.2.5"
[build-dependencies]
capnpc = "0.25.3"

View File

@@ -1,5 +1,5 @@
use crate::{
OneToManyUniqueBTreeMap, UserDataManager, VCs, command::State, option_ext::OptionExt as _,
OneToManyUniqueBTreeMap, UserDataManager, option_ext::OptionExt as _,
user_capnp::user::Consent, user_data::RECORD_IF_CONSENT_UNSPECIFIED,
};
use async_trait::async_trait;

View File

@@ -82,7 +82,7 @@ fn get_guild_and_vc_error_to_embed(error: GetGuildAndVoiceChannelIdError) -> Emb
#[tracing::instrument(skip(state))]
pub async fn handle(state: State, interaction: Interaction) {
let guild_and_voice_channel_id_res =
{ get_guild_and_voice_channel_id(&interaction, &state.vcs_watcher.borrow()) };
{ get_guild_and_voice_channel_id(&interaction, &state.vcs_sender.borrow()) };
let (guild_id, voice_channel_id) = match guild_and_voice_channel_id_res {
Ok((guild_id, voice_channel_id)) => (guild_id, voice_channel_id),
Err(error) => {

View File

@@ -84,7 +84,7 @@ pub async fn handle(state: State, interaction: Interaction) {
get_user_and_guild_and_voice_channel_id(
state.discord_user_id,
&interaction,
&state.vcs_watcher.borrow(),
&state.vcs_sender.borrow(),
)
};
let (user_id, guild_id, voice_channel_id) = match user_and_guild_and_voice_channel_id_res {

View File

@@ -16,7 +16,7 @@ use twilight_model::{
},
};
use crate::{BotDataManager, GuildVoiceChannelToTextChannel, UserDataManager, VCsWatcher};
use crate::{BotDataManager, GuildVoiceChannelToTextChannel, UserDataManager, VCsSender};
pub mod info;
pub mod join;
@@ -45,7 +45,7 @@ pub struct State {
pub recording_data: Operator,
pub songbird: Arc<Songbird>,
pub user_data_manager: UserDataManager,
pub vcs_watcher: VCsWatcher,
pub vcs_sender: VCsSender,
}
type Return = ();

View File

@@ -1,31 +1,237 @@
use std::{collections::BTreeMap, num::NonZero};
use std::{collections::BTreeMap, num::NonZero, str::Utf8Error, sync::Arc};
use futures::{StreamExt as _, stream::FuturesUnordered};
use itertools::Itertools as _;
use snafu::{ResultExt as _, Snafu};
use tokio::sync::watch;
use twilight_model::id::{
Id,
marker::{ChannelMarker, GuildMarker},
marker::{ChannelMarker, GuildMarker, UserMarker},
};
use twilight_util::builder::embed::EmbedBuilder;
use crate::{OneToManyUniqueBTreeMap, State, call::join_and_record};
use crate::{
BotDataManager, OneToManyUniqueBTreeMap, State, UserInVCData, bot_data,
call::join_and_record,
track_vcs::VCsInGuild,
vc_user::{Camera, Headphone, Microphone, Stream},
};
type Heat = u64;
type Hot = NonZero<Heat>;
type HotOption = Option<Hot>;
type ChannelHeat = BTreeMap<Id<ChannelMarker>, Heat>;
type HeatMap = OneToManyUniqueBTreeMap<Heat, Id<ChannelMarker>>;
#[tracing::instrument]
pub async fn heat_seek(state: State) {
let mut vcs_watcher = state.vcs_sender.subscribe();
let mut vcs_in_guild_senders = BTreeMap::default();
loop {
for (&guild_id, vcs_in_guild) in &*vcs_watcher.borrow() {
let vcs_in_guild_sender = vcs_in_guild_senders.entry(guild_id).or_insert_with(|| {
let (vcs_in_guild_sender, vcs_in_guild_watcher) =
watch::channel(Default::default());
let (channel_heat_sender, channel_heat_watcher) =
watch::channel(Default::default());
let (heat_map_sender, heat_map_watcher) = watch::channel(Default::default());
let (hottest_vc_sender, hottest_vc_watcher) = watch::channel(Default::default());
tokio::spawn(
evaluate_heat()
.bot_data_manager(state.bot_data_manager.clone())
.bot_owner_user_id(state.discord_bot_owner_user_id)
.bot_user_id(state.discord_user_id)
.channel_heat_sender(channel_heat_sender)
.vcs_in_guild_watcher(vcs_in_guild_watcher)
.call(),
);
tokio::spawn(map_heat(channel_heat_watcher, heat_map_sender));
tokio::spawn(track_hottest_vc(
state.discord_bot_owner_user_id,
heat_map_watcher,
hottest_vc_sender,
));
tokio::spawn(follow_hottest_vc(
state.clone(),
guild_id,
hottest_vc_watcher,
));
vcs_in_guild_sender
});
vcs_in_guild_sender.send_replace(Arc::new(vcs_in_guild.clone()));
}
if let Err(_closed) = vcs_watcher.changed().await {
break;
}
}
}
#[derive(Debug, Snafu)]
enum GetHeatError {
/// couldn't retrieve bot data
WithBotDataError { source: bot_data::WithError },
/// couldn't get the heat script from the bot data
GetHeatScriptError { source: capnp::Error },
/// the heat script is not a valid UTF-8 string
HeatScriptInvalidUtf8 { source: Utf8Error },
/// the heat script is not valid Rhai code
HeatScriptInvalidRhai { source: rhai::ParseError },
/// failed while evaluating the heat script
HeatScriptEvaluationError { source: Box<rhai::EvalAltResult> },
}
#[bon::builder]
#[tracing::instrument]
async fn get_heat(
users_in_vc: &BTreeMap<Id<UserMarker>, UserInVCData>,
bot_user_id: Id<UserMarker>,
bot_owner_user_id: Id<UserMarker>,
bot_data_manager: &BotDataManager,
) -> Result<Heat, GetHeatError> {
let heat_script = bot_data_manager
.with(|bot_data| {
bot_data.has_heat_script().then(|| {
bot_data
.get_heat_script()
.map(|heat_script| heat_script.to_string())
})
})
.await
.context(WithBotDataSnafu)?
.transpose()
.context(GetHeatScriptSnafu)?
.transpose()
.context(HeatScriptInvalidUtf8Snafu)?;
let engine = rhai::Engine::new();
let heat_function = heat_script
.map(|heat_script| engine.compile(heat_script))
.transpose()
.context(HeatScriptInvalidRhaiSnafu)?;
let heat = heat_function
.map(|heat_function| {
let mut scope = Default::default();
let args = (); // TODO
engine.call_fn(&mut scope, &heat_function, "heat", args)
})
.transpose()
.context(HeatScriptEvaluationSnafu)?;
let heat = heat.unwrap_or_else(|| {
tracing::warn!("using default heat scoring algorithm as no script was specified");
let mut users_in_vc = users_in_vc.clone();
let _bot = users_in_vc.remove(&bot_user_id);
let bot_owner = users_in_vc.remove(&bot_owner_user_id);
let mut heat = 0;
for (_user_id, user_in_vc_data) in users_in_vc {
if matches!(user_in_vc_data.microphone, Microphone::Unmuted) {
heat += 1000;
}
if matches!(user_in_vc_data.camera, Camera::Showing) {
heat += 100;
}
if matches!(user_in_vc_data.stream, Stream::Sharing) {
heat += 10;
}
if matches!(user_in_vc_data.headphone, Headphone::Undeafened) {
heat += 1;
}
}
if bot_owner.is_some() {
heat = heat.min(999);
}
heat
});
Ok(heat)
}
#[bon::builder]
#[tracing::instrument]
async fn evaluate_heat(
bot_data_manager: BotDataManager,
bot_owner_user_id: Id<UserMarker>,
bot_user_id: Id<UserMarker>,
mut vcs_in_guild_watcher: watch::Receiver<Arc<VCsInGuild>>,
channel_heat_sender: watch::Sender<ChannelHeat>,
) {
loop {
let vcs_in_guild = { vcs_in_guild_watcher.borrow().clone() };
let channel_heat_results: BTreeMap<_, _> = {
FuturesUnordered::from_iter((&*vcs_in_guild).into_iter().map(
|(&channel_id, users_in_vc)| {
let bot_data_manager = bot_data_manager.clone();
async move {
(
channel_id,
get_heat()
.bot_data_manager(&bot_data_manager)
.bot_owner_user_id(bot_owner_user_id)
.bot_user_id(bot_user_id)
.users_in_vc(users_in_vc)
.call()
.await,
)
}
},
))
}
.collect()
.await;
let (channel_heat, get_heat_errors): (ChannelHeat, Vec<_>) = channel_heat_results
.into_iter()
.map(|(channel_id, heat_result)| heat_result.map(|heat| (channel_id, heat)))
.partition_result();
channel_heat_sender.send_replace(channel_heat);
for get_heat_error in get_heat_errors {
tracing::error!(?get_heat_error, "failed to evaluate heat of channel")
}
if let Err(_closed) = vcs_in_guild_watcher.changed().await {
break;
}
}
}
#[tracing::instrument]
async fn map_heat(
mut channel_heat_watcher: watch::Receiver<ChannelHeat>,
heat_map_sender: watch::Sender<HeatMap>,
) {
loop {
heat_map_sender.send_modify(|heat_map| {
heat_map_sender.send_if_modified(|heat_map| {
let mut changed = false;
for (&channel, &heat) in &*channel_heat_watcher.borrow() {
heat_map.insert(heat, channel);
let existing = heat_map.insert(heat, channel);
if existing.map_or(true, |(old_heat, old_channel)| {
old_heat != heat || channel != old_channel
}) {
changed = true;
}
}
changed
});
if let Err(_closed) = channel_heat_watcher.changed().await {
@@ -36,6 +242,7 @@ async fn map_heat(
#[tracing::instrument]
async fn track_hottest_vc(
bot_owner_id: Id<UserMarker>,
mut heat_map_watcher: watch::Receiver<HeatMap>,
hottest_vc_sender: watch::Sender<Option<Id<ChannelMarker>>>,
) {
@@ -47,7 +254,6 @@ async fn track_hottest_vc(
.and_then(|(&heat, hottest_vcs)| {
let hot_option = Hot::new(heat);
// TODO: tiebreak by whichever one this bot is already in
hot_option.map(|_| *hottest_vcs.first().unwrap())
})
};
@@ -65,7 +271,7 @@ async fn track_hottest_vc(
}
#[tracing::instrument]
async fn follow_heat(
async fn follow_hottest_vc(
state: State,
guild_id: Id<GuildMarker>,
mut hottest_vc_watcher: watch::Receiver<Option<Id<ChannelMarker>>>,
@@ -146,8 +352,3 @@ async fn follow_heat(
}
}
}
#[tracing::instrument]
pub async fn heat_seek(state: State) {
todo!();
}

View File

@@ -17,11 +17,12 @@ shadow_rs::shadow!(build_info);
pub use bot_data::BotDataManager;
pub use command::{Router as CommandRouter, State, all as all_commands};
pub use heat_seek::heat_seek;
pub use one_to_many::OneToManyUniqueBTreeMap;
pub use one_to_many_with_data::OneToManyUniqueBTreeMapWithData;
pub use one_to_one::OneToOneBTreeMap;
pub use operator_ext::OperatorExt;
pub use storage::Storage;
pub use track_vcs::{GuildVoiceChannelToTextChannel, VCs, VCsWatcher, initialize_vcs, update_vcs};
pub use track_vcs::{GuildVoiceChannelToTextChannel, VCs, VCsSender, initialize_vcs, update_vcs};
pub use user_data::UserDataManager;
pub use vc_user::{UserInVCData, VoiceStatus};

View File

@@ -1,7 +1,6 @@
use clap::Parser;
use fomo_reducer::{
BotDataManager, CommandRouter, GuildVoiceChannelToTextChannel, State, Storage, UserDataManager,
VCsWatcher, all_commands, command, initialize_vcs, update_vcs,
BotDataManager, CommandRouter, GuildVoiceChannelToTextChannel, State, Storage, UserDataManager, VCsSender, all_commands, command, heat_seek, initialize_vcs, update_vcs
};
use secrecy::{ExposeSecret, SecretString};
use snafu::{OptionExt, ResultExt, Snafu};
@@ -336,7 +335,7 @@ async fn main() -> Result<(), MainError> {
let discord_client = Arc::new(discord_client);
let songbird = Arc::new(songbird);
let vcs_watcher = VCsWatcher::new(vcs);
let vcs_sender = VCsSender::new(vcs);
let bot_data = bot_data.into_inner();
let recording_data = recording_data.into_inner();
@@ -380,9 +379,11 @@ async fn main() -> Result<(), MainError> {
recording_data,
songbird,
user_data_manager,
vcs_watcher,
vcs_sender,
};
let heat_seeking = tokio::spawn(heat_seek(state.clone()));
if let Some(discord_status) = discord_status {
shards.iter().for_each(|shard| {
shard.command(
@@ -409,7 +410,6 @@ async fn main() -> Result<(), MainError> {
.map(|shard| handle_events(command_router.clone(), state.clone(), shard));
let run_shards = JoinSet::from_iter(run_shards);
let run_shards = run_shards.join_all();
tokio::pin!(run_shards);
tokio::spawn({
let cancellation_token = cancellation_token.clone();
@@ -432,13 +432,19 @@ async fn main() -> Result<(), MainError> {
}
});
let finished_naturally = async move {
heat_seeking.await.unwrap();
run_shards.await;
};
tokio::pin!(finished_naturally);
select! {
_ = &mut run_shards => {
_ = &mut finished_naturally => {
Ok(())
}
() = cancellation_token.cancelled() => {
tracing::warn!("waiting for tasks to gracefully shut down");
run_shards.await;
finished_naturally.await;
Err(MainError::Cancelled)
}
@@ -496,7 +502,7 @@ async fn handle_event(command_router: Arc<CommandRouter>, state: State, event: E
match event {
Event::VoiceStateUpdate(voice_state_update) => {
state
.vcs_watcher
.vcs_sender
.send_modify(|vcs| update_vcs(&voice_state_update, vcs));
}
Event::InteractionCreate(interaction_create) => {

View File

@@ -94,6 +94,30 @@ where
}
}
impl<Left, Right, RightData> IntoIterator
for OneToManyUniqueBTreeMapWithData<Left, Right, RightData>
{
type Item = (Left, BTreeMap<Right, RightData>);
type IntoIter = <BTreeMap<Left, BTreeMap<Right, RightData>> as IntoIterator>::IntoIter;
fn into_iter(self) -> Self::IntoIter {
self.left_to_rights.into_iter()
}
}
impl<'a, Left, Right, RightData> IntoIterator
for &'a OneToManyUniqueBTreeMapWithData<Left, Right, RightData>
{
type Item = (&'a Left, &'a BTreeMap<Right, RightData>);
type IntoIter = <&'a BTreeMap<Left, BTreeMap<Right, RightData>> as IntoIterator>::IntoIter;
fn into_iter(self) -> Self::IntoIter {
self.left_to_rights.iter()
}
}
impl<Left, Right, RightData> FromIterator<(Left, Right, RightData)>
for OneToManyUniqueBTreeMapWithData<Left, Right, RightData>
where

View File

@@ -18,7 +18,7 @@ pub type GuildVoiceChannelToTextChannel =
pub type VCsInGuild =
OneToManyUniqueBTreeMapWithData<Id<ChannelMarker>, Id<UserMarker>, UserInVCData>;
pub type VCs = BTreeMap<Id<GuildMarker>, VCsInGuild>;
pub type VCsWatcher = watch::Sender<VCs>;
pub type VCsSender = watch::Sender<VCs>;
#[tracing::instrument(skip(discord_client), ret)]
async fn initialize_user_in_vc(

View File

@@ -1,20 +1,20 @@
use typed_builder::TypedBuilder;
#[derive(Debug)]
#[derive(Debug, Clone, Copy)]
pub enum Microphone {
Unmuted,
ServerMuted,
Muted,
}
#[derive(Debug)]
#[derive(Debug, Clone, Copy)]
pub enum Headphone {
Undeafened,
ServerDeafened,
Deafened,
}
#[derive(Debug)]
#[derive(Debug, Clone, Copy)]
pub enum Camera {
Showing,
Off,
@@ -30,7 +30,7 @@ impl From<bool> for Camera {
}
}
#[derive(Debug)]
#[derive(Debug, Clone, Copy)]
pub enum Stream {
Sharing,
None,
@@ -46,7 +46,7 @@ impl From<bool> for Stream {
}
}
#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct UserInVCData {
pub microphone: Microphone,
pub headphone: Headphone,
@@ -54,7 +54,7 @@ pub struct UserInVCData {
pub stream: Stream,
}
#[derive(Debug, TypedBuilder)]
#[derive(Debug, Clone, TypedBuilder)]
pub struct VoiceStatus {
server_deafened: bool,
self_deafened: bool,