diff --git a/Cargo.lock b/Cargo.lock index cacf42a..ea6c8f5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -338,7 +338,6 @@ dependencies = [ "spirv-std-macros", "spirv_backend", "spirv_backend_runner", - "threadpool", "wgpu", "y4m", ] diff --git a/crates/coremem/Cargo.toml b/crates/coremem/Cargo.toml index 3060dda..3191f53 100644 --- a/crates/coremem/Cargo.toml +++ b/crates/coremem/Cargo.toml @@ -32,7 +32,6 @@ num = "0.4" # MIT or Apache 2.0 rand = "0.8" # MIT or Apache 2.0 rayon = "1.5" # MIT or Apache 2.0 serde = "1.0" # MIT or Apache 2.0 -threadpool = "1.8" # MIT or Apache 2.0 y4m = "0.7" # MIT wgpu = "0.12" diff --git a/crates/coremem/src/driver.rs b/crates/coremem/src/driver.rs index 6f324ee..e41d61d 100644 --- a/crates/coremem/src/driver.rs +++ b/crates/coremem/src/driver.rs @@ -27,16 +27,13 @@ use serde::{Deserialize, Serialize}; use std::cell::Cell; use std::path::PathBuf; use std::sync::{Arc, Mutex}; -use std::sync::mpsc::{sync_channel, SyncSender, Receiver}; use std::time::Instant; -use threadpool::ThreadPool; pub struct Driver> { state: S, renderer: Arc>, // TODO: use Rayon's thread pool? 
- render_pool: ThreadPool, - render_channel: (SyncSender<()>, Receiver<()>), + render_pool: JobPool, measurements: Vec>>, stimuli: StimAccess, /// simulation end time @@ -52,8 +49,7 @@ impl Driver { Self { state, renderer: Arc::new(MultiRenderer::new()), - render_pool: ThreadPool::new(3), - render_channel: sync_channel(0), + render_pool: JobPool::new(1), measurements: vec![ Arc::new(meas::Time), Arc::new(meas::Meta), @@ -101,7 +97,6 @@ impl Driver { state: self.state, renderer: self.renderer, render_pool: self.render_pool, - render_channel: self.render_channel, measurements: self.measurements, stimuli: StimAccess::new(self.diag.clone(), self.stimuli.into_inner().append(s)), sim_end_time: self.sim_end_time, @@ -114,7 +109,6 @@ impl Driver { state: self.state, renderer: self.renderer, render_pool: self.render_pool, - render_channel: self.render_channel, measurements: self.measurements, stimuli: StimAccess::new(self.diag.clone(), stimuli), sim_end_time: self.sim_end_time, @@ -229,25 +223,32 @@ where Stim: DriverStimulus + Send + 'static, { fn render(&mut self) { - self.diag.instrument_render_prep(|| { - let diag_handle = self.diag.clone(); - let their_state = self.state.clone(); - let their_measurements = self.measurements.clone(); - let renderer = self.renderer.clone(); - let sender = self.render_channel.0.clone(); - self.render_pool.execute(move || { - // unblock the main thread (this limits the number of renders in-flight at any time - sender.send(()).unwrap(); - trace!("render begin"); - diag_handle.instrument_render_cpu_side(|| { - let meas: Vec<&dyn AbstractMeasurement> = their_measurements.iter().map(|m| &**m).collect(); - renderer.render(&their_state, &*meas, Default::default()); + let their_state = self.diag.instrument_render_prep(|| { + if self.render_pool.num_workers() != 3 { + let diag = self.diag.clone(); + // TODO: these measurements will come to differ from the ones in the Driver, + // if the user calls `add_measurement`! 
+ let measurements = self.measurements.clone(); + let renderer = self.renderer.clone(); + self.render_pool.spawn_workers(3, move |state| { + // worker body: render each state snapshot received from the main thread + trace!("render begin"); + diag.instrument_render_cpu_side(|| { + let meas: Vec<&dyn AbstractMeasurement> = measurements.iter().map(|m| &**m).collect(); + renderer.render(&state, &*meas, Default::default()); + }); + trace!("render end"); }); - trace!("render end"); - }); + } + self.state.clone() }); + // TODO: this instrumentation is not 100% accurate. + // - 'prep' and 'blocked' have effectively been folded together. + // - either delete 'prep', or change this block to use a `try_send` (prep) followed by a + // `send` (blocking) self.diag.instrument_render_blocked(|| { - self.render_channel.1.recv().unwrap(); + self.render_pool.tend(); + self.render_pool.send(their_state); }); } /// Return the number of steps actually stepped @@ -327,7 +328,7 @@ where // render the final frame -- unless we already *have* self.render(); } - self.render_pool.join(); + self.render_pool.join_workers(); self.sim_end_time = None; } } diff --git a/crates/coremem/src/worker.rs b/crates/coremem/src/worker.rs index 4ef696d..1e6ece6 100644 --- a/crates/coremem/src/worker.rs +++ b/crates/coremem/src/worker.rs @@ -50,6 +50,23 @@ impl JobPool { pub fn num_workers(&self) -> u32 { self.handles.len().try_into().unwrap() } + pub fn recv(&self) -> R { + self.response_chan.recv().unwrap() + } + /// Non-blocking receive (`try_recv`). Named `tend` because this is often used when we want to ensure no workers are + /// blocked due to lack of space in the output queue. + pub fn tend(&self) -> Option<R> { + self.response_chan.try_recv().ok() + } + pub fn join_workers(&mut self) { + // hang up the sender, to signal workers to exit.
+ let cap = self.command_chan.capacity().unwrap_or(0); + (self.command_chan, self.worker_command_chan) = channel::bounded(cap); + (self.worker_response_chan, self.response_chan) = channel::bounded(cap); + for h in self.handles.drain(..) { + h.join().unwrap(); + } + } } impl JobPool { @@ -75,12 +92,7 @@ impl JobPool { impl Drop for JobPool { fn drop(&mut self) { - // hang up the sender, to signal workers to exit. - (self.command_chan, _) = channel::bounded(0); - (_, self.response_chan) = channel::bounded(0); - for h in self.handles.drain(..) { - h.join().unwrap(); - } + self.join_workers(); } } @@ -90,12 +102,6 @@ impl JobPool { } } -impl JobPool { - pub fn recv(&self) -> R { - self.response_chan.recv().unwrap() - } -} - #[cfg(test)] mod test { use super::*; @@ -182,4 +188,29 @@ mod test { pool.send(1); pool.send(2); } + + #[test] + fn join_workers() { + let mut pool: JobPool = JobPool::new(1); + + pool.spawn_workers(2, |x| x*2); + pool.send(5); + pool.join_workers(); + + pool.spawn_workers(2, |x| x*2); + pool.send(4); + // the earlier response to '5' should be lost in the channel + assert_eq!(pool.recv(), 8); + + // one message in the response queue; one in the send queue, 2 in the worker threads + pool.send(3); pool.send(2); pool.send(1); pool.send(0); + // should still be able to join even though everyone's blocked. + pool.join_workers(); + + pool.spawn_workers(1, |x| x*2); + pool.send(7); + // the old '0' command should be lost in the channel + assert_eq!(pool.recv(), 14); + + } }