Driver: replace the other ThreadPool with a JobPool
as a bonus we can remove the threadpool dep :-)
This commit is contained in:
1
Cargo.lock
generated
1
Cargo.lock
generated
@@ -338,7 +338,6 @@ dependencies = [
|
|||||||
"spirv-std-macros",
|
"spirv-std-macros",
|
||||||
"spirv_backend",
|
"spirv_backend",
|
||||||
"spirv_backend_runner",
|
"spirv_backend_runner",
|
||||||
"threadpool",
|
|
||||||
"wgpu",
|
"wgpu",
|
||||||
"y4m",
|
"y4m",
|
||||||
]
|
]
|
||||||
|
@@ -32,7 +32,6 @@ num = "0.4" # MIT or Apache 2.0
|
|||||||
rand = "0.8" # MIT or Apache 2.0
|
rand = "0.8" # MIT or Apache 2.0
|
||||||
rayon = "1.5" # MIT or Apache 2.0
|
rayon = "1.5" # MIT or Apache 2.0
|
||||||
serde = "1.0" # MIT or Apache 2.0
|
serde = "1.0" # MIT or Apache 2.0
|
||||||
threadpool = "1.8" # MIT or Apache 2.0
|
|
||||||
y4m = "0.7" # MIT
|
y4m = "0.7" # MIT
|
||||||
|
|
||||||
wgpu = "0.12"
|
wgpu = "0.12"
|
||||||
|
@@ -27,16 +27,13 @@ use serde::{Deserialize, Serialize};
|
|||||||
use std::cell::Cell;
|
use std::cell::Cell;
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
use std::sync::{Arc, Mutex};
|
use std::sync::{Arc, Mutex};
|
||||||
use std::sync::mpsc::{sync_channel, SyncSender, Receiver};
|
|
||||||
use std::time::Instant;
|
use std::time::Instant;
|
||||||
use threadpool::ThreadPool;
|
|
||||||
|
|
||||||
pub struct Driver<R, S, Stim=DriverStimulusDynVec<R>> {
|
pub struct Driver<R, S, Stim=DriverStimulusDynVec<R>> {
|
||||||
state: S,
|
state: S,
|
||||||
renderer: Arc<MultiRenderer<S>>,
|
renderer: Arc<MultiRenderer<S>>,
|
||||||
// TODO: use Rayon's thread pool?
|
// TODO: use Rayon's thread pool?
|
||||||
render_pool: ThreadPool,
|
render_pool: JobPool<S, ()>,
|
||||||
render_channel: (SyncSender<()>, Receiver<()>),
|
|
||||||
measurements: Vec<Arc<dyn AbstractMeasurement<S>>>,
|
measurements: Vec<Arc<dyn AbstractMeasurement<S>>>,
|
||||||
stimuli: StimAccess<R, Stim>,
|
stimuli: StimAccess<R, Stim>,
|
||||||
/// simulation end time
|
/// simulation end time
|
||||||
@@ -52,8 +49,7 @@ impl<S: AbstractSim, Stim> Driver<S::Real, S, Stim> {
|
|||||||
Self {
|
Self {
|
||||||
state,
|
state,
|
||||||
renderer: Arc::new(MultiRenderer::new()),
|
renderer: Arc::new(MultiRenderer::new()),
|
||||||
render_pool: ThreadPool::new(3),
|
render_pool: JobPool::new(1),
|
||||||
render_channel: sync_channel(0),
|
|
||||||
measurements: vec![
|
measurements: vec![
|
||||||
Arc::new(meas::Time),
|
Arc::new(meas::Time),
|
||||||
Arc::new(meas::Meta),
|
Arc::new(meas::Meta),
|
||||||
@@ -101,7 +97,6 @@ impl<S: AbstractSim, Stim> Driver<S::Real, S, Stim> {
|
|||||||
state: self.state,
|
state: self.state,
|
||||||
renderer: self.renderer,
|
renderer: self.renderer,
|
||||||
render_pool: self.render_pool,
|
render_pool: self.render_pool,
|
||||||
render_channel: self.render_channel,
|
|
||||||
measurements: self.measurements,
|
measurements: self.measurements,
|
||||||
stimuli: StimAccess::new(self.diag.clone(), self.stimuli.into_inner().append(s)),
|
stimuli: StimAccess::new(self.diag.clone(), self.stimuli.into_inner().append(s)),
|
||||||
sim_end_time: self.sim_end_time,
|
sim_end_time: self.sim_end_time,
|
||||||
@@ -114,7 +109,6 @@ impl<S: AbstractSim, Stim> Driver<S::Real, S, Stim> {
|
|||||||
state: self.state,
|
state: self.state,
|
||||||
renderer: self.renderer,
|
renderer: self.renderer,
|
||||||
render_pool: self.render_pool,
|
render_pool: self.render_pool,
|
||||||
render_channel: self.render_channel,
|
|
||||||
measurements: self.measurements,
|
measurements: self.measurements,
|
||||||
stimuli: StimAccess::new(self.diag.clone(), stimuli),
|
stimuli: StimAccess::new(self.diag.clone(), stimuli),
|
||||||
sim_end_time: self.sim_end_time,
|
sim_end_time: self.sim_end_time,
|
||||||
@@ -229,25 +223,32 @@ where
|
|||||||
Stim: DriverStimulus<S::Real> + Send + 'static,
|
Stim: DriverStimulus<S::Real> + Send + 'static,
|
||||||
{
|
{
|
||||||
fn render(&mut self) {
|
fn render(&mut self) {
|
||||||
self.diag.instrument_render_prep(|| {
|
let their_state = self.diag.instrument_render_prep(|| {
|
||||||
let diag_handle = self.diag.clone();
|
if self.render_pool.num_workers() != 3 {
|
||||||
let their_state = self.state.clone();
|
let diag = self.diag.clone();
|
||||||
let their_measurements = self.measurements.clone();
|
// TODO: these measurements will come to differ from the ones in the Driver,
|
||||||
let renderer = self.renderer.clone();
|
// if the user calls `add_measurement`!
|
||||||
let sender = self.render_channel.0.clone();
|
let measurements = self.measurements.clone();
|
||||||
self.render_pool.execute(move || {
|
let renderer = self.renderer.clone();
|
||||||
// unblock the main thread (this limits the number of renders in-flight at any time
|
self.render_pool.spawn_workers(3, move |state| {
|
||||||
sender.send(()).unwrap();
|
// unblock the main thread (this limits the number of renders in-flight at any time
|
||||||
trace!("render begin");
|
trace!("render begin");
|
||||||
diag_handle.instrument_render_cpu_side(|| {
|
diag.instrument_render_cpu_side(|| {
|
||||||
let meas: Vec<&dyn AbstractMeasurement<S>> = their_measurements.iter().map(|m| &**m).collect();
|
let meas: Vec<&dyn AbstractMeasurement<S>> = measurements.iter().map(|m| &**m).collect();
|
||||||
renderer.render(&their_state, &*meas, Default::default());
|
renderer.render(&state, &*meas, Default::default());
|
||||||
|
});
|
||||||
|
trace!("render end");
|
||||||
});
|
});
|
||||||
trace!("render end");
|
}
|
||||||
});
|
self.state.clone()
|
||||||
});
|
});
|
||||||
|
// TODO: this instrumentation is not 100% accurate.
|
||||||
|
// - 'prep' and 'blocked' have effectively been folded together.
|
||||||
|
// - either delete 'prep', or change this block to use a `try_send` (prep) followed by a
|
||||||
|
// `send` (blocking)
|
||||||
self.diag.instrument_render_blocked(|| {
|
self.diag.instrument_render_blocked(|| {
|
||||||
self.render_channel.1.recv().unwrap();
|
self.render_pool.tend();
|
||||||
|
self.render_pool.send(their_state);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
/// Return the number of steps actually stepped
|
/// Return the number of steps actually stepped
|
||||||
@@ -327,7 +328,7 @@ where
|
|||||||
// render the final frame -- unless we already *have*
|
// render the final frame -- unless we already *have*
|
||||||
self.render();
|
self.render();
|
||||||
}
|
}
|
||||||
self.render_pool.join();
|
self.render_pool.join_workers();
|
||||||
self.sim_end_time = None;
|
self.sim_end_time = None;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -50,6 +50,23 @@ impl<C, R> JobPool<C, R> {
|
|||||||
pub fn num_workers(&self) -> u32 {
|
pub fn num_workers(&self) -> u32 {
|
||||||
self.handles.len().try_into().unwrap()
|
self.handles.len().try_into().unwrap()
|
||||||
}
|
}
|
||||||
|
pub fn recv(&self) -> R {
|
||||||
|
self.response_chan.recv().unwrap()
|
||||||
|
}
|
||||||
|
/// `try_recv`. named `tend` because this is often used when we want to ensure no workers are
|
||||||
|
/// blocked due to lack of space in the output queue.
|
||||||
|
pub fn tend(&self) -> Option<R> {
|
||||||
|
self.response_chan.try_recv().ok()
|
||||||
|
}
|
||||||
|
pub fn join_workers(&mut self) {
|
||||||
|
// hang up the sender, to signal workers to exit.
|
||||||
|
let cap = self.command_chan.capacity().unwrap_or(0);
|
||||||
|
(self.command_chan, self.worker_command_chan) = channel::bounded(cap);
|
||||||
|
(self.worker_response_chan, self.response_chan) = channel::bounded(cap);
|
||||||
|
for h in self.handles.drain(..) {
|
||||||
|
h.join().unwrap();
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<C: Send + 'static, R: Send + 'static> JobPool<C, R> {
|
impl<C: Send + 'static, R: Send + 'static> JobPool<C, R> {
|
||||||
@@ -75,12 +92,7 @@ impl<C: Send + 'static, R: Send + 'static> JobPool<C, R> {
|
|||||||
|
|
||||||
impl<C, R> Drop for JobPool<C, R> {
|
impl<C, R> Drop for JobPool<C, R> {
|
||||||
fn drop(&mut self) {
|
fn drop(&mut self) {
|
||||||
// hang up the sender, to signal workers to exit.
|
self.join_workers();
|
||||||
(self.command_chan, _) = channel::bounded(0);
|
|
||||||
(_, self.response_chan) = channel::bounded(0);
|
|
||||||
for h in self.handles.drain(..) {
|
|
||||||
h.join().unwrap();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -90,12 +102,6 @@ impl<C: Send + 'static, R> JobPool<C, R> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<C, R> JobPool<C, R> {
|
|
||||||
pub fn recv(&self) -> R {
|
|
||||||
self.response_chan.recv().unwrap()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod test {
|
mod test {
|
||||||
use super::*;
|
use super::*;
|
||||||
@@ -182,4 +188,29 @@ mod test {
|
|||||||
pool.send(1);
|
pool.send(1);
|
||||||
pool.send(2);
|
pool.send(2);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn join_workers() {
|
||||||
|
let mut pool: JobPool<u32, u32> = JobPool::new(1);
|
||||||
|
|
||||||
|
pool.spawn_workers(2, |x| x*2);
|
||||||
|
pool.send(5);
|
||||||
|
pool.join_workers();
|
||||||
|
|
||||||
|
pool.spawn_workers(2, |x| x*2);
|
||||||
|
pool.send(4);
|
||||||
|
// the earlier response to '5' should be lost in the channel
|
||||||
|
assert_eq!(pool.recv(), 8);
|
||||||
|
|
||||||
|
// one message in the response queue; one in the send queue, 2 in the worker threads
|
||||||
|
pool.send(3); pool.send(2); pool.send(1); pool.send(0);
|
||||||
|
// should still be able to join even though everyone's blocked.
|
||||||
|
pool.join_workers();
|
||||||
|
|
||||||
|
pool.spawn_workers(1, |x| x*2);
|
||||||
|
pool.send(7);
|
||||||
|
// the old '0' command should be lost in the channel
|
||||||
|
assert_eq!(pool.recv(), 14);
|
||||||
|
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user