Driver: replace the other ThreadPool with a JobPool

As a bonus, we can remove the `threadpool` dependency :-)
This commit is contained in:
2022-08-23 00:03:58 -07:00
parent 7586bc8ff2
commit 3c30ac33aa
4 changed files with 69 additions and 39 deletions

1
Cargo.lock generated
View File

@@ -338,7 +338,6 @@ dependencies = [
"spirv-std-macros", "spirv-std-macros",
"spirv_backend", "spirv_backend",
"spirv_backend_runner", "spirv_backend_runner",
"threadpool",
"wgpu", "wgpu",
"y4m", "y4m",
] ]

View File

@@ -32,7 +32,6 @@ num = "0.4" # MIT or Apache 2.0
rand = "0.8" # MIT or Apache 2.0 rand = "0.8" # MIT or Apache 2.0
rayon = "1.5" # MIT or Apache 2.0 rayon = "1.5" # MIT or Apache 2.0
serde = "1.0" # MIT or Apache 2.0 serde = "1.0" # MIT or Apache 2.0
threadpool = "1.8" # MIT or Apache 2.0
y4m = "0.7" # MIT y4m = "0.7" # MIT
wgpu = "0.12" wgpu = "0.12"

View File

@@ -27,16 +27,13 @@ use serde::{Deserialize, Serialize};
use std::cell::Cell; use std::cell::Cell;
use std::path::PathBuf; use std::path::PathBuf;
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
use std::sync::mpsc::{sync_channel, SyncSender, Receiver};
use std::time::Instant; use std::time::Instant;
use threadpool::ThreadPool;
pub struct Driver<R, S, Stim=DriverStimulusDynVec<R>> { pub struct Driver<R, S, Stim=DriverStimulusDynVec<R>> {
state: S, state: S,
renderer: Arc<MultiRenderer<S>>, renderer: Arc<MultiRenderer<S>>,
// TODO: use Rayon's thread pool? // TODO: use Rayon's thread pool?
render_pool: ThreadPool, render_pool: JobPool<S, ()>,
render_channel: (SyncSender<()>, Receiver<()>),
measurements: Vec<Arc<dyn AbstractMeasurement<S>>>, measurements: Vec<Arc<dyn AbstractMeasurement<S>>>,
stimuli: StimAccess<R, Stim>, stimuli: StimAccess<R, Stim>,
/// simulation end time /// simulation end time
@@ -52,8 +49,7 @@ impl<S: AbstractSim, Stim> Driver<S::Real, S, Stim> {
Self { Self {
state, state,
renderer: Arc::new(MultiRenderer::new()), renderer: Arc::new(MultiRenderer::new()),
render_pool: ThreadPool::new(3), render_pool: JobPool::new(1),
render_channel: sync_channel(0),
measurements: vec![ measurements: vec![
Arc::new(meas::Time), Arc::new(meas::Time),
Arc::new(meas::Meta), Arc::new(meas::Meta),
@@ -101,7 +97,6 @@ impl<S: AbstractSim, Stim> Driver<S::Real, S, Stim> {
state: self.state, state: self.state,
renderer: self.renderer, renderer: self.renderer,
render_pool: self.render_pool, render_pool: self.render_pool,
render_channel: self.render_channel,
measurements: self.measurements, measurements: self.measurements,
stimuli: StimAccess::new(self.diag.clone(), self.stimuli.into_inner().append(s)), stimuli: StimAccess::new(self.diag.clone(), self.stimuli.into_inner().append(s)),
sim_end_time: self.sim_end_time, sim_end_time: self.sim_end_time,
@@ -114,7 +109,6 @@ impl<S: AbstractSim, Stim> Driver<S::Real, S, Stim> {
state: self.state, state: self.state,
renderer: self.renderer, renderer: self.renderer,
render_pool: self.render_pool, render_pool: self.render_pool,
render_channel: self.render_channel,
measurements: self.measurements, measurements: self.measurements,
stimuli: StimAccess::new(self.diag.clone(), stimuli), stimuli: StimAccess::new(self.diag.clone(), stimuli),
sim_end_time: self.sim_end_time, sim_end_time: self.sim_end_time,
@@ -229,25 +223,32 @@ where
Stim: DriverStimulus<S::Real> + Send + 'static, Stim: DriverStimulus<S::Real> + Send + 'static,
{ {
fn render(&mut self) { fn render(&mut self) {
self.diag.instrument_render_prep(|| { let their_state = self.diag.instrument_render_prep(|| {
let diag_handle = self.diag.clone(); if self.render_pool.num_workers() != 3 {
let their_state = self.state.clone(); let diag = self.diag.clone();
let their_measurements = self.measurements.clone(); // TODO: these measurements will come to differ from the ones in the Driver,
let renderer = self.renderer.clone(); // if the user calls `add_measurement`!
let sender = self.render_channel.0.clone(); let measurements = self.measurements.clone();
self.render_pool.execute(move || { let renderer = self.renderer.clone();
// unblock the main thread (this limits the number of renders in-flight at any time self.render_pool.spawn_workers(3, move |state| {
sender.send(()).unwrap(); // unblock the main thread (this limits the number of renders in-flight at any time
trace!("render begin"); trace!("render begin");
diag_handle.instrument_render_cpu_side(|| { diag.instrument_render_cpu_side(|| {
let meas: Vec<&dyn AbstractMeasurement<S>> = their_measurements.iter().map(|m| &**m).collect(); let meas: Vec<&dyn AbstractMeasurement<S>> = measurements.iter().map(|m| &**m).collect();
renderer.render(&their_state, &*meas, Default::default()); renderer.render(&state, &*meas, Default::default());
});
trace!("render end");
}); });
trace!("render end"); }
}); self.state.clone()
}); });
// TODO: this instrumentation is not 100% accurate.
// - 'prep' and 'blocked' have effectively been folded together.
// - either delete 'prep', or change this block to use a `try_send` (prep) followed by a
// `send` (blocking)
self.diag.instrument_render_blocked(|| { self.diag.instrument_render_blocked(|| {
self.render_channel.1.recv().unwrap(); self.render_pool.tend();
self.render_pool.send(their_state);
}); });
} }
/// Return the number of steps actually stepped /// Return the number of steps actually stepped
@@ -327,7 +328,7 @@ where
// render the final frame -- unless we already *have* // render the final frame -- unless we already *have*
self.render(); self.render();
} }
self.render_pool.join(); self.render_pool.join_workers();
self.sim_end_time = None; self.sim_end_time = None;
} }
} }

View File

@@ -50,6 +50,23 @@ impl<C, R> JobPool<C, R> {
pub fn num_workers(&self) -> u32 { pub fn num_workers(&self) -> u32 {
self.handles.len().try_into().unwrap() self.handles.len().try_into().unwrap()
} }
/// Receive the next worker response from `response_chan` and return it.
///
/// Unlike `tend`, which goes through `try_recv`, this uses the channel's plain `recv`
/// (presumably waiting until a response is available -- confirm against the channel crate).
///
/// # Panics
/// Panics if `recv` returns an error, because the result is unwrapped.
pub fn recv(&self) -> R {
self.response_chan.recv().unwrap()
}
/// `try_recv`. named `tend` because this is often used when we want to ensure no workers are
/// blocked due to lack of space in the output queue.
///
/// Returns `Some(response)` when `try_recv` on the response channel succeeds, and `None` when it
/// reports any error (the error value itself is discarded by `.ok()`).
pub fn tend(&self) -> Option<R> {
self.response_chan.try_recv().ok()
}
/// Signal every worker to exit, then wait for all of them to finish.
///
/// Both the command channel and the response channel are replaced with fresh bounded channels
/// of the same capacity as the current command channel (or `0` when `capacity()` reports
/// `None`). Overwriting the fields drops the pool's old endpoints, which is the "hang up" that
/// tells workers to stop; anything still queued in the old channels is no longer reachable
/// through the pool. Afterwards `handles` is empty.
///
/// # Panics
/// Panics if joining any worker handle returns an error.
pub fn join_workers(&mut self) {
    let queue_depth = self.command_chan.capacity().unwrap_or(0);

    // Hang up the command side first: the old sender is dropped on assignment.
    let (fresh_command_tx, fresh_command_rx) = channel::bounded(queue_depth);
    self.command_chan = fresh_command_tx;
    self.worker_command_chan = fresh_command_rx;

    // Then swap out the response side, dropping the pool's old sender and receiver.
    let (fresh_response_tx, fresh_response_rx) = channel::bounded(queue_depth);
    self.worker_response_chan = fresh_response_tx;
    self.response_chan = fresh_response_rx;

    // Wait for each worker in spawn order; the handle list ends up drained.
    self.handles.drain(..).for_each(|worker| {
        worker.join().unwrap();
    });
}
} }
impl<C: Send + 'static, R: Send + 'static> JobPool<C, R> { impl<C: Send + 'static, R: Send + 'static> JobPool<C, R> {
@@ -75,12 +92,7 @@ impl<C: Send + 'static, R: Send + 'static> JobPool<C, R> {
impl<C, R> Drop for JobPool<C, R> { impl<C, R> Drop for JobPool<C, R> {
fn drop(&mut self) { fn drop(&mut self) {
// hang up the sender, to signal workers to exit. self.join_workers();
(self.command_chan, _) = channel::bounded(0);
(_, self.response_chan) = channel::bounded(0);
for h in self.handles.drain(..) {
h.join().unwrap();
}
} }
} }
@@ -90,12 +102,6 @@ impl<C: Send + 'static, R> JobPool<C, R> {
} }
} }
impl<C, R> JobPool<C, R> {
/// Receive the next worker response from `response_chan` and return it.
///
/// # Panics
/// Panics if the channel's `recv` returns an error, because the result is unwrapped.
pub fn recv(&self) -> R {
self.response_chan.recv().unwrap()
}
}
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use super::*; use super::*;
@@ -182,4 +188,29 @@ mod test {
pool.send(1); pool.send(1);
pool.send(2); pool.send(2);
} }
// Exercises `JobPool::join_workers` in three rounds against one pool:
// 1. join while a response is still pending, then respawn and check the stale response is gone;
// 2. join while every queue and worker is full/blocked;
// 3. respawn once more and check that a stale command is gone.
#[test]
fn join_workers() {
// NOTE(review): the `1` is presumably the queue capacity (see the accounting comment below) -- confirm.
let mut pool: JobPool<u32, u32> = JobPool::new(1);
// round 1: two doubling workers, one command sent, response never collected before the join.
pool.spawn_workers(2, |x| x*2);
pool.send(5);
pool.join_workers();
// the pool must accept new workers after a join.
pool.spawn_workers(2, |x| x*2);
pool.send(4);
// the earlier response to '5' should be lost in the channel
assert_eq!(pool.recv(), 8);
// round 2: saturate the pool without ever receiving.
// one message in the response queue; one in the send queue, 2 in the worker threads
pool.send(3); pool.send(2); pool.send(1); pool.send(0);
// should still be able to join even though everyone's blocked.
pool.join_workers();
// round 3: a single worker; only the response to the new command should come back.
pool.spawn_workers(1, |x| x*2);
pool.send(7);
// the old '0' command should be lost in the channel
assert_eq!(pool.recv(), 14);
}
} }