Driver: replace the other ThreadPool with a JobPool
as a bonus we can remove the threadpool dep :-)
This commit is contained in:
1
Cargo.lock
generated
1
Cargo.lock
generated
@@ -338,7 +338,6 @@ dependencies = [
|
||||
"spirv-std-macros",
|
||||
"spirv_backend",
|
||||
"spirv_backend_runner",
|
||||
"threadpool",
|
||||
"wgpu",
|
||||
"y4m",
|
||||
]
|
||||
|
@@ -32,7 +32,6 @@ num = "0.4" # MIT or Apache 2.0
|
||||
rand = "0.8" # MIT or Apache 2.0
|
||||
rayon = "1.5" # MIT or Apache 2.0
|
||||
serde = "1.0" # MIT or Apache 2.0
|
||||
threadpool = "1.8" # MIT or Apache 2.0
|
||||
y4m = "0.7" # MIT
|
||||
|
||||
wgpu = "0.12"
|
||||
|
@@ -27,16 +27,13 @@ use serde::{Deserialize, Serialize};
|
||||
use std::cell::Cell;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::sync::mpsc::{sync_channel, SyncSender, Receiver};
|
||||
use std::time::Instant;
|
||||
use threadpool::ThreadPool;
|
||||
|
||||
pub struct Driver<R, S, Stim=DriverStimulusDynVec<R>> {
|
||||
state: S,
|
||||
renderer: Arc<MultiRenderer<S>>,
|
||||
// TODO: use Rayon's thread pool?
|
||||
render_pool: ThreadPool,
|
||||
render_channel: (SyncSender<()>, Receiver<()>),
|
||||
render_pool: JobPool<S, ()>,
|
||||
measurements: Vec<Arc<dyn AbstractMeasurement<S>>>,
|
||||
stimuli: StimAccess<R, Stim>,
|
||||
/// simulation end time
|
||||
@@ -52,8 +49,7 @@ impl<S: AbstractSim, Stim> Driver<S::Real, S, Stim> {
|
||||
Self {
|
||||
state,
|
||||
renderer: Arc::new(MultiRenderer::new()),
|
||||
render_pool: ThreadPool::new(3),
|
||||
render_channel: sync_channel(0),
|
||||
render_pool: JobPool::new(1),
|
||||
measurements: vec![
|
||||
Arc::new(meas::Time),
|
||||
Arc::new(meas::Meta),
|
||||
@@ -101,7 +97,6 @@ impl<S: AbstractSim, Stim> Driver<S::Real, S, Stim> {
|
||||
state: self.state,
|
||||
renderer: self.renderer,
|
||||
render_pool: self.render_pool,
|
||||
render_channel: self.render_channel,
|
||||
measurements: self.measurements,
|
||||
stimuli: StimAccess::new(self.diag.clone(), self.stimuli.into_inner().append(s)),
|
||||
sim_end_time: self.sim_end_time,
|
||||
@@ -114,7 +109,6 @@ impl<S: AbstractSim, Stim> Driver<S::Real, S, Stim> {
|
||||
state: self.state,
|
||||
renderer: self.renderer,
|
||||
render_pool: self.render_pool,
|
||||
render_channel: self.render_channel,
|
||||
measurements: self.measurements,
|
||||
stimuli: StimAccess::new(self.diag.clone(), stimuli),
|
||||
sim_end_time: self.sim_end_time,
|
||||
@@ -229,25 +223,32 @@ where
|
||||
Stim: DriverStimulus<S::Real> + Send + 'static,
|
||||
{
|
||||
fn render(&mut self) {
|
||||
self.diag.instrument_render_prep(|| {
|
||||
let diag_handle = self.diag.clone();
|
||||
let their_state = self.state.clone();
|
||||
let their_measurements = self.measurements.clone();
|
||||
let renderer = self.renderer.clone();
|
||||
let sender = self.render_channel.0.clone();
|
||||
self.render_pool.execute(move || {
|
||||
// unblock the main thread (this limits the number of renders in-flight at any time
|
||||
sender.send(()).unwrap();
|
||||
trace!("render begin");
|
||||
diag_handle.instrument_render_cpu_side(|| {
|
||||
let meas: Vec<&dyn AbstractMeasurement<S>> = their_measurements.iter().map(|m| &**m).collect();
|
||||
renderer.render(&their_state, &*meas, Default::default());
|
||||
let their_state = self.diag.instrument_render_prep(|| {
|
||||
if self.render_pool.num_workers() != 3 {
|
||||
let diag = self.diag.clone();
|
||||
// TODO: these measurements will come to differ from the ones in the Driver,
|
||||
// if the user calls `add_measurement`!
|
||||
let measurements = self.measurements.clone();
|
||||
let renderer = self.renderer.clone();
|
||||
self.render_pool.spawn_workers(3, move |state| {
|
||||
// unblock the main thread (this limits the number of renders in-flight at any time
|
||||
trace!("render begin");
|
||||
diag.instrument_render_cpu_side(|| {
|
||||
let meas: Vec<&dyn AbstractMeasurement<S>> = measurements.iter().map(|m| &**m).collect();
|
||||
renderer.render(&state, &*meas, Default::default());
|
||||
});
|
||||
trace!("render end");
|
||||
});
|
||||
trace!("render end");
|
||||
});
|
||||
}
|
||||
self.state.clone()
|
||||
});
|
||||
// TODO: this instrumentation is not 100% accurate.
|
||||
// - 'prep' and 'blocked' have effectively been folded together.
|
||||
// - either delete 'prep', or change this block to use a `try_send` (prep) followed by a
|
||||
// `send` (blocking)
|
||||
self.diag.instrument_render_blocked(|| {
|
||||
self.render_channel.1.recv().unwrap();
|
||||
self.render_pool.tend();
|
||||
self.render_pool.send(their_state);
|
||||
});
|
||||
}
|
||||
/// Return the number of steps actually stepped
|
||||
@@ -327,7 +328,7 @@ where
|
||||
// render the final frame -- unless we already *have*
|
||||
self.render();
|
||||
}
|
||||
self.render_pool.join();
|
||||
self.render_pool.join_workers();
|
||||
self.sim_end_time = None;
|
||||
}
|
||||
}
|
||||
|
@@ -50,6 +50,23 @@ impl<C, R> JobPool<C, R> {
|
||||
pub fn num_workers(&self) -> u32 {
|
||||
self.handles.len().try_into().unwrap()
|
||||
}
|
||||
pub fn recv(&self) -> R {
|
||||
self.response_chan.recv().unwrap()
|
||||
}
|
||||
/// `try_recv`. named `tend` because this is often used when we want to ensure no workers are
|
||||
/// blocked due to lack of space in the output queue.
|
||||
pub fn tend(&self) -> Option<R> {
|
||||
self.response_chan.try_recv().ok()
|
||||
}
|
||||
pub fn join_workers(&mut self) {
|
||||
// hang up the sender, to signal workers to exit.
|
||||
let cap = self.command_chan.capacity().unwrap_or(0);
|
||||
(self.command_chan, self.worker_command_chan) = channel::bounded(cap);
|
||||
(self.worker_response_chan, self.response_chan) = channel::bounded(cap);
|
||||
for h in self.handles.drain(..) {
|
||||
h.join().unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<C: Send + 'static, R: Send + 'static> JobPool<C, R> {
|
||||
@@ -75,12 +92,7 @@ impl<C: Send + 'static, R: Send + 'static> JobPool<C, R> {
|
||||
|
||||
impl<C, R> Drop for JobPool<C, R> {
|
||||
fn drop(&mut self) {
|
||||
// hang up the sender, to signal workers to exit.
|
||||
(self.command_chan, _) = channel::bounded(0);
|
||||
(_, self.response_chan) = channel::bounded(0);
|
||||
for h in self.handles.drain(..) {
|
||||
h.join().unwrap();
|
||||
}
|
||||
self.join_workers();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -90,12 +102,6 @@ impl<C: Send + 'static, R> JobPool<C, R> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<C, R> JobPool<C, R> {
|
||||
pub fn recv(&self) -> R {
|
||||
self.response_chan.recv().unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
@@ -182,4 +188,29 @@ mod test {
|
||||
pool.send(1);
|
||||
pool.send(2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn join_workers() {
|
||||
let mut pool: JobPool<u32, u32> = JobPool::new(1);
|
||||
|
||||
pool.spawn_workers(2, |x| x*2);
|
||||
pool.send(5);
|
||||
pool.join_workers();
|
||||
|
||||
pool.spawn_workers(2, |x| x*2);
|
||||
pool.send(4);
|
||||
// the earlier response to '5' should be lost in the channel
|
||||
assert_eq!(pool.recv(), 8);
|
||||
|
||||
// one message in the response queue; one in the send queue, 2 in the worker threads
|
||||
pool.send(3); pool.send(2); pool.send(1); pool.send(0);
|
||||
// should still be able to join even though everyone's blocked.
|
||||
pool.join_workers();
|
||||
|
||||
pool.spawn_workers(1, |x| x*2);
|
||||
pool.send(7);
|
||||
// the old '0' command should be lost in the channel
|
||||
assert_eq!(pool.recv(), 14);
|
||||
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user