Gateway: Fix repeat joins on same channel from stalling (#47)
Joining a channel returns a future which fires on receipt of two messages from discord (by locally storing a channel). However, joining this same channel again after a success returns only *one* such message, causing the command to hang until another join fires or the channel is left. This alters internal behaviour to correctly cancel an in-progress connection attempt, or return success with known data if such a connection is present. This introduces a breaking change on `Call::update_state` to include the target `ChannelId`. The reason for this is that although the `ChannelId` of a target channel was being stored, server admins may move or kick a bot from its voice channel. This changes the true channel, and may accidentally trigger a "double join" elsewhere. This fix was tested by using an example to have a bot join its channel twice, to do so in a channel it had been moved to, and to move from a channel it had been moved to.
This commit is contained in:
108
src/handler.rs
108
src/handler.rs
@@ -32,7 +32,7 @@ enum Return {
|
||||
/// [`Driver`]: struct@Driver
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Call {
|
||||
connection: Option<(ChannelId, ConnectionProgress, Return)>,
|
||||
connection: Option<(ConnectionProgress, Return)>,
|
||||
|
||||
#[cfg(feature = "driver-core")]
|
||||
/// The internal controller of the voice connection monitor thread.
|
||||
@@ -132,12 +132,12 @@ impl Call {
|
||||
#[instrument(skip(self))]
|
||||
fn do_connect(&mut self) {
|
||||
match &self.connection {
|
||||
Some((_, ConnectionProgress::Complete(c), Return::Info(tx))) => {
|
||||
Some((ConnectionProgress::Complete(c), Return::Info(tx))) => {
|
||||
// It's okay if the receiver hung up.
|
||||
let _ = tx.send(c.clone());
|
||||
},
|
||||
#[cfg(feature = "driver-core")]
|
||||
Some((_, ConnectionProgress::Complete(c), Return::Conn(tx))) => {
|
||||
Some((ConnectionProgress::Complete(c), Return::Conn(tx))) => {
|
||||
self.driver.raw_connect(c.clone(), tx.clone());
|
||||
},
|
||||
_ => {},
|
||||
@@ -171,6 +171,31 @@ impl Call {
|
||||
self.self_deaf
|
||||
}
|
||||
|
||||
async fn should_actually_join<F, G>(
|
||||
&mut self,
|
||||
completion_generator: F,
|
||||
tx: &Sender<G>,
|
||||
channel_id: ChannelId,
|
||||
) -> JoinResult<bool>
|
||||
where
|
||||
F: FnOnce(&Self) -> G,
|
||||
{
|
||||
Ok(if let Some(conn) = &self.connection {
|
||||
if conn.0.in_progress() {
|
||||
self.leave().await?;
|
||||
true
|
||||
} else if conn.0.channel_id() == channel_id {
|
||||
let _ = tx.send(completion_generator(&self));
|
||||
false
|
||||
} else {
|
||||
// not in progress, and/or a channel change.
|
||||
true
|
||||
}
|
||||
} else {
|
||||
true
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(feature = "driver-core")]
|
||||
/// Connect or switch to the given voice channel by its Id.
|
||||
///
|
||||
@@ -190,13 +215,20 @@ impl Call {
|
||||
) -> JoinResult<RecvFut<'static, ConnectionResult<()>>> {
|
||||
let (tx, rx) = flume::unbounded();
|
||||
|
||||
self.connection = Some((
|
||||
channel_id,
|
||||
ConnectionProgress::new(self.guild_id, self.user_id),
|
||||
Return::Conn(tx),
|
||||
));
|
||||
let do_conn = self
|
||||
.should_actually_join(|_| Ok(()), &tx, channel_id)
|
||||
.await?;
|
||||
|
||||
self.update().await.map(|_| rx.into_recv_async())
|
||||
if do_conn {
|
||||
self.connection = Some((
|
||||
ConnectionProgress::new(self.guild_id, self.user_id, channel_id),
|
||||
Return::Conn(tx),
|
||||
));
|
||||
|
||||
self.update().await.map(|_| rx.into_recv_async())
|
||||
} else {
|
||||
Ok(rx.into_recv_async())
|
||||
}
|
||||
}
|
||||
|
||||
/// Join the selected voice channel, *without* running/starting an RTP
|
||||
@@ -221,13 +253,24 @@ impl Call {
|
||||
) -> JoinResult<RecvFut<'static, ConnectionInfo>> {
|
||||
let (tx, rx) = flume::unbounded();
|
||||
|
||||
self.connection = Some((
|
||||
channel_id,
|
||||
ConnectionProgress::new(self.guild_id, self.user_id),
|
||||
Return::Info(tx),
|
||||
));
|
||||
let do_conn = self
|
||||
.should_actually_join(
|
||||
|call| call.connection.as_ref().unwrap().0.info().unwrap(),
|
||||
&tx,
|
||||
channel_id,
|
||||
)
|
||||
.await?;
|
||||
|
||||
self.update().await.map(|_| rx.into_recv_async())
|
||||
if do_conn {
|
||||
self.connection = Some((
|
||||
ConnectionProgress::new(self.guild_id, self.user_id, channel_id),
|
||||
Return::Info(tx),
|
||||
));
|
||||
|
||||
self.update().await.map(|_| rx.into_recv_async())
|
||||
} else {
|
||||
Ok(rx.into_recv_async())
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the current voice connection details for this Call,
|
||||
@@ -235,7 +278,7 @@ impl Call {
|
||||
#[instrument(skip(self))]
|
||||
pub fn current_connection(&self) -> Option<&ConnectionInfo> {
|
||||
match &self.connection {
|
||||
Some((_, progress, _)) => progress.get_connection_info(),
|
||||
Some((progress, _)) => progress.get_connection_info(),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
@@ -265,13 +308,17 @@ impl Call {
|
||||
/// [`standalone`]: Call::standalone
|
||||
#[instrument(skip(self))]
|
||||
pub async fn leave(&mut self) -> JoinResult<()> {
|
||||
self.leave_local();
|
||||
|
||||
// Only send an update if we were in a voice channel.
|
||||
self.update().await
|
||||
}
|
||||
|
||||
fn leave_local(&mut self) {
|
||||
self.connection = None;
|
||||
|
||||
#[cfg(feature = "driver-core")]
|
||||
self.driver.leave();
|
||||
|
||||
self.update().await
|
||||
}
|
||||
|
||||
/// Sets whether the current connection is to be muted.
|
||||
@@ -307,7 +354,7 @@ impl Call {
|
||||
/// [`standalone`]: Call::standalone
|
||||
#[instrument(skip(self, token))]
|
||||
pub fn update_server(&mut self, endpoint: String, token: String) {
|
||||
let try_conn = if let Some((_, ref mut progress, _)) = self.connection.as_mut() {
|
||||
let try_conn = if let Some((ref mut progress, _)) = self.connection.as_mut() {
|
||||
progress.apply_server_update(endpoint, token)
|
||||
} else {
|
||||
false
|
||||
@@ -325,15 +372,20 @@ impl Call {
|
||||
///
|
||||
/// [`standalone`]: Call::standalone
|
||||
#[instrument(skip(self))]
|
||||
pub fn update_state(&mut self, session_id: String) {
|
||||
let try_conn = if let Some((_, ref mut progress, _)) = self.connection.as_mut() {
|
||||
progress.apply_state_update(session_id)
|
||||
} else {
|
||||
false
|
||||
};
|
||||
pub fn update_state(&mut self, session_id: String, channel_id: Option<ChannelId>) {
|
||||
if let Some(channel_id) = channel_id {
|
||||
let try_conn = if let Some((ref mut progress, _)) = self.connection.as_mut() {
|
||||
progress.apply_state_update(session_id, channel_id)
|
||||
} else {
|
||||
false
|
||||
};
|
||||
|
||||
if try_conn {
|
||||
self.do_connect();
|
||||
if try_conn {
|
||||
self.do_connect();
|
||||
}
|
||||
} else {
|
||||
// Likely that we were disconnected by an admin.
|
||||
self.leave_local();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -348,7 +400,7 @@ impl Call {
|
||||
let map = json!({
|
||||
"op": 4,
|
||||
"d": {
|
||||
"channel_id": self.connection.as_ref().map(|c| c.0.0),
|
||||
"channel_id": self.connection.as_ref().map(|c| c.0.channel_id().0),
|
||||
"guild_id": self.guild_id.0,
|
||||
"self_deaf": self.self_deaf,
|
||||
"self_mute": self.self_mute,
|
||||
|
||||
Reference in New Issue
Block a user