Use a tokio::watch in the supervisor to avoid races

This commit is contained in:
Félix Saparelli 2022-01-31 00:00:51 +13:00
parent 3e942c4d19
commit 995d38078e
No known key found for this signature in database
GPG Key ID: B948C4BAE44FC474
1 changed files with 20 additions and 48 deletions

View File

@ -1,15 +1,10 @@
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
};
use command_group::AsyncCommandGroup;
use tokio::{
process::Command,
select, spawn,
sync::{
mpsc::{self, Sender},
oneshot,
watch,
},
};
use tracing::{debug, error, trace};
@ -37,12 +32,7 @@ enum Intervention {
pub struct Supervisor {
id: u32,
intervene: Sender<Intervention>,
// why this and not a watch::channel? two reasons:
// 1. I tried the watch and ran into some race conditions???
// 2. This way it's typed-enforced that I send only once
waiter: Option<oneshot::Receiver<()>>,
ongoing: Arc<AtomicBool>,
ongoing: watch::Receiver<bool>,
}
impl Supervisor {
@ -72,11 +62,9 @@ impl Supervisor {
(Process::Ungrouped(proc), id)
};
let ongoing = Arc::new(AtomicBool::new(true));
let (notify, waiter) = oneshot::channel();
let (notify, waiter) = watch::channel(true);
let (int_s, int_r) = mpsc::channel(8);
let going = ongoing.clone();
spawn(async move {
let mut process = process;
let mut int = int_r;
@ -92,9 +80,8 @@ impl Supervisor {
error!(%err, "while waiting on process");
errors.send(err).await.ok();
trace!("marking process as done");
going.store(false, Ordering::SeqCst);
notify.send(false).unwrap_or_else(|e| trace!(%e, "error sending process complete"));
trace!("closing supervisor task early");
notify.send(()).ok();
return;
}
}
@ -167,15 +154,15 @@ impl Supervisor {
}
trace!("marking process as done");
going.store(false, Ordering::SeqCst);
notify
.send(false)
.unwrap_or_else(|e| trace!(%e, "error sending process complete"));
trace!("closing supervisor task");
notify.send(()).ok();
});
Ok(Self {
id,
waiter: Some(waiter),
ongoing,
ongoing: waiter,
intervene: int_s,
})
}
@ -223,7 +210,7 @@ impl Supervisor {
/// This is almost always equivalent to whether the _process_ is still running, but may not be
/// 100% in sync.
pub fn is_running(&self) -> bool {
let ongoing = self.ongoing.load(Ordering::SeqCst);
let ongoing = *self.ongoing.borrow();
trace!(?ongoing, "supervisor state");
ongoing
}
@ -232,36 +219,21 @@ impl Supervisor {
///
/// This is almost always equivalent to waiting for the _process_ to complete, but may not be
/// 100% in sync.
pub async fn wait(&mut self) -> Result<(), RuntimeError> {
if !self.ongoing.load(Ordering::SeqCst) {
pub async fn wait(&self) -> Result<(), RuntimeError> {
if !*self.ongoing.borrow() {
trace!("supervisor already completed (pre)");
return Ok(());
}
if let Some(waiter) = self.waiter.take() {
debug!("waiting on supervisor completion");
waiter
.await
.map_err(|err| RuntimeError::InternalSupervisor(err.to_string()))?;
debug!("supervisor completed");
if self.ongoing.swap(false, Ordering::SeqCst) {
#[cfg(debug_assertions)]
panic!("oneshot completed but ongoing was true, this should never happen");
#[cfg(not(debug_assertions))]
tracing::warn!("oneshot completed but ongoing was true, this should never happen");
}
} else if self.ongoing.load(Ordering::SeqCst) {
#[cfg(debug_assertions)]
panic!("waiter is None but ongoing was true, this should never happen");
#[cfg(not(debug_assertions))]
{
self.ongoing.store(false, Ordering::SeqCst);
tracing::warn!("waiter is None but ongoing was true, this should never happen");
}
} else {
trace!("supervisor already completed (post)");
}
debug!("waiting on supervisor completion");
let mut ongoing = self.ongoing.clone();
// never completes if ongoing is marked false in between the previous check and now!
// TODO: select with something that sleeps a bit and rechecks the ongoing
ongoing
.changed()
.await
.map_err(|err| RuntimeError::InternalSupervisor(err.to_string()))?;
debug!("supervisor completed");
Ok(())
}