SpirvSim: explicitly pass the backend in when initialized

This commit is contained in:
2022-07-27 13:51:10 -07:00
parent 7698e0e5ba
commit 932bb163c3
4 changed files with 83 additions and 93 deletions

View File

@@ -52,7 +52,11 @@ where
spirv::WgpuBackend<R, M>: spirv::SimBackend<R, M>
{
pub fn new_spirv<C: Coord>(size: C, feature_size: f32) -> Self {
Self::new_with_state(SpirvSim::new(size.to_index(feature_size), feature_size))
Self::new_with_state(SpirvSim::new(
size.to_index(feature_size),
feature_size,
spirv::WgpuBackend::default(),
))
}
}

View File

@@ -5,12 +5,11 @@ use coremem_types::vec::{Vec3, Vec3u};
use super::SimBackend;
/// Stateless simulation backend: a unit struct with nothing to configure.
///
/// Its `SimBackend::set_meta` implementation is a no-op, so a default-constructed
/// value needs no further setup.
#[derive(Default)]
pub struct CpuBackend;
impl<R: Real, M: Material<R>> SimBackend<R, M> for CpuBackend {
fn new(_volume: u64) -> Self {
CpuBackend
}
fn set_meta(&mut self, _: SimMeta<R>) {}
fn step_n(
&self,
meta: SimMeta<R>,

View File

@@ -13,26 +13,26 @@ use spirv_backend::HasEntryPoints;
use super::SimBackend;
/// Simulation backend that owns the wgpu resources used for stepping.
///
/// A default-constructed value holds no GPU handles. `set_meta` populates `handles`,
/// and `step_n` unwraps it, so `set_meta` must have been called before the first step.
///
/// NOTE(review): the derived `Default` impl is bounded on `R: Default` and `M: Default`
/// even though both are only used via `PhantomData` — confirm every instantiation
/// satisfies that, or switch to a manual impl.
#[derive(Default)]
pub struct WgpuBackend<R, M> {
// wgpu device, queue and compute pipelines; `None` until `set_meta` has run.
handles: Option<WgpuHandles>,
// zero-sized markers that tie the backend to its `R` and `M` type parameters.
real: PhantomData<R>,
mat: PhantomData<M>,
}
struct WgpuHandles {
step_bind_group_layout: wgpu::BindGroupLayout,
step_e_pipeline: wgpu::ComputePipeline,
step_h_pipeline: wgpu::ComputePipeline,
device: wgpu::Device,
queue: wgpu::Queue,
real: PhantomData<R>,
mat: PhantomData<M>,
}
impl<R, M: HasEntryPoints<R>> Default for WgpuBackend<R, M> {
fn default() -> Self {
Self::new_with_vol(0)
}
}
impl<R, M: HasEntryPoints<R>> WgpuBackend<R, M> {
fn new_with_vol(volume: u64) -> Self {
info!("WgpuBackend::new_with_vol({})", volume);
impl<R: Copy, M: Send + Sync + HasEntryPoints<R>> SimBackend<R, M> for WgpuBackend<R, M> {
fn set_meta(&mut self, meta: SimMeta<R>) {
info!("WgpuBackend::set_meta({:?})", meta.dim);
use std::mem::size_of;
let volume = meta.dim.product_sum_usize() as u64;
let max_elem_size = size_of::<M>().max(size_of::<Vec3<R>>());
let max_array_size = volume * max_elem_size as u64;
let max_buf_size = max_array_size + 0x1000; // allow some overhead
@@ -43,22 +43,13 @@ impl<R, M: HasEntryPoints<R>> WgpuBackend<R, M> {
let (step_bind_group_layout, step_h_pipeline, step_e_pipeline) = make_pipelines(
&device, &shader_module, M::step_h(), M::step_e()
);
Self {
self.handles = Some(WgpuHandles {
step_bind_group_layout,
step_h_pipeline,
step_e_pipeline,
device,
queue,
real: PhantomData,
mat: PhantomData,
}
}
}
impl<R: Copy, M: Send + Sync + HasEntryPoints<R>> SimBackend<R, M> for WgpuBackend<R, M> {
fn new(volume: u64) -> Self
{
Self::new_with_vol(volume)
});
}
fn step_n(
@@ -74,11 +65,12 @@ impl<R: Copy, M: Send + Sync + HasEntryPoints<R>> SimBackend<R, M> for WgpuBacke
) {
let field_bytes = meta.dim.product_sum() as usize * std::mem::size_of::<Vec3<f32>>();
let device = &self.device;
let queue = &self.queue;
let step_bind_group_layout = &self.step_bind_group_layout;
let step_e_pipeline = &self.step_e_pipeline;
let step_h_pipeline = &self.step_h_pipeline;
let handles = self.handles.as_ref().unwrap();
let device = &handles.device;
let queue = &handles.queue;
let step_bind_group_layout = &handles.step_bind_group_layout;
let step_e_pipeline = &handles.step_e_pipeline;
let step_h_pipeline = &handles.step_h_pipeline;
let sim_meta_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("gpu-side simulation metadata"),
contents: to_bytes(&[meta][..]),

View File

@@ -1,6 +1,5 @@
use ndarray::{self, Array3};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use log::{info, trace, warn};
use crate::geom::{Coord, Index, Meters};
@@ -23,7 +22,7 @@ pub use gpu::WgpuBackend;
/// it could advance the state locally, on the CPU,
/// or it could dispatch to a SPIR-V device, like a GPU.
pub trait SimBackend<R, M> {
fn new(volume: u64) -> Self;
fn set_meta(&mut self, meta: SimMeta<R>);
fn step_n(
&self,
meta: SimMeta<R>,
@@ -49,15 +48,12 @@ where M: 'static
m: Vec<Vec3<R>>,
mat: Vec<M>,
step_no: u64,
// hidden behind an Arc to allow for cheap clones.
// XXX not confident that wgpu is actually properly synchronized for us to omit a Mutex here
// though.
#[serde(skip)]
backend: Option<Arc<B>>,
backend: B,
}
// B isn't always cloneable (e.g. the GPU backend), so Rust can't auto-derive this.
impl<R: Clone, M: Clone, B> Clone for SpirvSim<R, M, B> {
impl<R: Clone, M: Clone, B: Default> Clone for SpirvSim<R, M, B> {
fn clone(&self) -> Self {
Self {
meta: self.meta.clone(),
@@ -66,7 +62,10 @@ impl<R: Clone, M: Clone, B> Clone for SpirvSim<R, M, B> {
m: self.m.clone(),
mat: self.mat.clone(),
step_no: self.step_no.clone(),
backend: self.backend.clone(),
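// Backends can be expensive to clone (or not cloneable at all), so the clone
// receives a fresh `Default` backend rather than a copy of this one.
// Callers that need a working backend on the clone must initialize it explicitly.
// TODO: this probably shouldn't be a `Clone` impl, but a dedicated method like `read_only_clone()`.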
backend: Default::default(),
}
}
}
@@ -137,12 +136,10 @@ where
B: Send + Sync + SimBackend<R, M>,
{
fn step_multiple<S: AbstractStimulus>(&mut self, num_steps: u32, stim: &S) {
let vol = self.size().volume();
self.backend.get_or_insert_with(|| Arc::new(B::new(vol)));
let (stim_e, stim_h) = self.eval_stimulus(stim);
let backend = &**self.backend.as_ref().unwrap();
backend.step_n(
// TODO: we probably need to call `SimBackend::set_meta` here,
// particularly in case `self` was just deserialized (the `backend` field is `#[serde(skip)]`).
self.backend.step_n(
self.meta,
self.mat.as_slice(),
stim_e.as_slice().unwrap(),
@@ -177,15 +174,13 @@ where
M: Default + Send + Sync + 'static,
B: SimBackend<R, M>,
{
pub fn new(size: Index, feature_size: f32) -> Self {
Self::new_with_backend(size, feature_size, Some(Arc::new(B::new(size.volume()))))
pub fn new(size: Index, feature_size: f32, backend: B) -> Self {
let mut me = Self::new_lazy_backend(size, feature_size, backend);
me.backend.set_meta(me.meta);
me
}
pub fn new_no_backend(size: Index, feature_size: f32) -> Self {
Self::new_with_backend(size, feature_size, None)
}
fn new_with_backend(size: Index, feature_size: f32, backend: Option<Arc<B>>) -> Self {
pub fn new_lazy_backend(size: Index, feature_size: f32, backend: B) -> Self {
info!("SpirvSim::new({:?}, {})", size, feature_size);
let flat_size = size.volume() as usize;
if flat_size * std::mem::size_of::<M>() >= 0x40000000 {
@@ -364,55 +359,55 @@ mod test {
mod backend_agnostic {
use super::*;
use crate::stim::{NoopStimulus, RngStimulus, UniformStimulus};
pub fn do_smoke_small<B: SimBackend<f32, FullyGenericMaterial<f32>> + Send + Sync>() {
pub fn do_smoke_small<B: SimBackend<f32, FullyGenericMaterial<f32>> + Send + Sync + Default>() {
let mut state: SpirvSim<f32, FullyGenericMaterial<f32>, B> =
SpirvSim::new(Index::new(8, 8, 8), 1e-3);
SpirvSim::new(Index::new(8, 8, 8), 1e-3, B::default());
state.step();
}
pub fn do_smoke_med_7bit<B: SimBackend<f32, FullyGenericMaterial<f32>> + Send + Sync>() {
pub fn do_smoke_med_7bit<B: SimBackend<f32, FullyGenericMaterial<f32>> + Send + Sync + Default>() {
let mut state: SpirvSim<f32, FullyGenericMaterial<f32>, B> =
SpirvSim::new(Index::new(124, 124, 124), 1e-3);
SpirvSim::new(Index::new(124, 124, 124), 1e-3, B::default());
state.step();
}
pub fn do_smoke_med128<B: SimBackend<f32, FullyGenericMaterial<f32>> + Send + Sync>() {
pub fn do_smoke_med128<B: SimBackend<f32, FullyGenericMaterial<f32>> + Send + Sync + Default>() {
let mut state: SpirvSim<f32, FullyGenericMaterial<f32>, B> =
SpirvSim::new(Index::new(128, 128, 128), 1e-3);
SpirvSim::new(Index::new(128, 128, 128), 1e-3, B::default());
state.step();
}
pub fn do_smoke_med_23bit<B: SimBackend<f32, FullyGenericMaterial<f32>> + Send + Sync>() {
pub fn do_smoke_med_23bit<B: SimBackend<f32, FullyGenericMaterial<f32>> + Send + Sync + Default>() {
let mut state: SpirvSim<f32, FullyGenericMaterial<f32>, B> =
SpirvSim::new(Index::new(127, 256, 256), 1e-3);
SpirvSim::new(Index::new(127, 256, 256), 1e-3, B::default());
state.step();
}
pub fn do_smoke_med_0x800000_indexing<B: SimBackend<f32, FullyGenericMaterial<f32>> + Send + Sync>() {
pub fn do_smoke_med_0x800000_indexing<B: SimBackend<f32, FullyGenericMaterial<f32>> + Send + Sync + Default>() {
let mut state: SpirvSim<f32, FullyGenericMaterial<f32>, B> =
SpirvSim::new(Index::new(128, 256, 256), 1e-3);
SpirvSim::new(Index::new(128, 256, 256), 1e-3, B::default());
state.step();
}
pub fn do_smoke_med_0x800000_address_space<B: SimBackend<f32, FullyGenericMaterial<f32>> + Send + Sync>() {
pub fn do_smoke_med_0x800000_address_space<B: SimBackend<f32, FullyGenericMaterial<f32>> + Send + Sync + Default>() {
let mut state: SpirvSim<f32, FullyGenericMaterial<f32>, B> =
SpirvSim::new(Index::new(170, 256, 256), 1e-3);
SpirvSim::new(Index::new(170, 256, 256), 1e-3, B::default());
state.step();
}
pub fn do_smoke_large<B: SimBackend<f32, FullyGenericMaterial<f32>> + Send + Sync>() {
pub fn do_smoke_large<B: SimBackend<f32, FullyGenericMaterial<f32>> + Send + Sync + Default>() {
let mut state: SpirvSim<f32, FullyGenericMaterial<f32>, B> =
SpirvSim::new(Index::new(326, 252, 160), 1e-3);
SpirvSim::new(Index::new(326, 252, 160), 1e-3, B::default());
state.step();
}
pub fn do_smoke_not_multiple_of_4<B: SimBackend<f32, FullyGenericMaterial<f32>> + Send + Sync>() {
pub fn do_smoke_not_multiple_of_4<B: SimBackend<f32, FullyGenericMaterial<f32>> + Send + Sync + Default>() {
let mut state: SpirvSim<f32, FullyGenericMaterial<f32>, B> =
SpirvSim::new(Index::new(3, 2, 5), 1e-3);
SpirvSim::new(Index::new(3, 2, 5), 1e-3, B::default());
state.step();
}
fn test_same_explicit<B: SimBackend<f32, FullyGenericMaterial<f32>> + Send + Sync>(
fn test_same_explicit<B: SimBackend<f32, FullyGenericMaterial<f32>> + Send + Sync + Default>(
mut ref_state: SimState,
mut dut_state: SpirvSim<f32, FullyGenericMaterial<f32>, B>,
step_iters: u64,
@@ -425,49 +420,49 @@ mod test {
}
}
fn test_same<B: SimBackend<f32, FullyGenericMaterial<f32>> + Send + Sync>(seed: u64, step_iters: u64, steps_per_iter: u32, size: Index) {
fn test_same<B: SimBackend<f32, FullyGenericMaterial<f32>> + Send + Sync + Default>(seed: u64, step_iters: u64, steps_per_iter: u32, size: Index) {
let mut cpu_state = SimState::new(size, 1e-3);
cpu_state.apply_stimulus(&RngStimulus::new(seed));
let mut spirv_state = SpirvSim::<f32, FullyGenericMaterial<f32>, B>::new(size, 1e-3);
let mut spirv_state = SpirvSim::<f32, FullyGenericMaterial<f32>, B>::new(size, 1e-3, B::default());
spirv_state.apply_stimulus(&RngStimulus::new(seed));
test_same_explicit(cpu_state, spirv_state, step_iters, steps_per_iter);
}
pub fn do_same_1_step_no_stim<B: SimBackend<f32, FullyGenericMaterial<f32>> + Send + Sync>() {
pub fn do_same_no_step_no_stim<B: SimBackend<f32, FullyGenericMaterial<f32>> + Send + Sync + Default>() {
let ref_state = SimState::new(Index::new(8, 8, 8), 1e-3);
let dut_state = SpirvSim::<f32, FullyGenericMaterial<f32>, B>::new_no_backend(Index::new(8, 8, 8), 1e-3);
let dut_state = SpirvSim::new_lazy_backend(Index::new(8, 8, 8), 1e-3, B::default());
test_same_explicit(ref_state, dut_state, 1, 1);
}
pub fn do_same_1_step<B: SimBackend<f32, FullyGenericMaterial<f32>> + Send + Sync>() {
pub fn do_same_1_step<B: SimBackend<f32, FullyGenericMaterial<f32>> + Send + Sync + Default>() {
test_same::<B>(0x1234, 1, 1, Index::new(4, 4, 4));
}
pub fn do_same_1000_step<B: SimBackend<f32, FullyGenericMaterial<f32>> + Send + Sync>() {
pub fn do_same_1000_step<B: SimBackend<f32, FullyGenericMaterial<f32>> + Send + Sync + Default>() {
test_same::<B>(0x1234, 1000, 1, Index::new(4, 4, 4));
}
pub fn do_same_not_multiple_of_4<B: SimBackend<f32, FullyGenericMaterial<f32>> + Send + Sync>() {
pub fn do_same_not_multiple_of_4<B: SimBackend<f32, FullyGenericMaterial<f32>> + Send + Sync + Default>() {
test_same::<B>(0x1234, 100, 1, Index::new(3, 2, 5));
}
pub fn do_same_100_steps_larger<B: SimBackend<f32, FullyGenericMaterial<f32>> + Send + Sync>() {
pub fn do_same_100_steps_larger<B: SimBackend<f32, FullyGenericMaterial<f32>> + Send + Sync + Default>() {
test_same::<B>(0x1234, 100, 1, Index::new(24, 20, 44));
}
pub fn do_same_100_steps_of_10<B: SimBackend<f32, FullyGenericMaterial<f32>> + Send + Sync>() {
pub fn do_same_100_steps_of_10<B: SimBackend<f32, FullyGenericMaterial<f32>> + Send + Sync + Default>() {
test_same::<B>(0x1234, 100, 10, Index::new(24, 20, 44));
}
fn test_same_conductor<B: SimBackend<f32, FullyGenericMaterial<f32>> + Send + Sync>(
fn test_same_conductor<B: SimBackend<f32, FullyGenericMaterial<f32>> + Send + Sync + Default>(
seed: u64, step_iters: u64, steps_per_iter: u32, size: Index
) {
use rand::{Rng as _, SeedableRng as _};
let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
let mut ref_state = SimState::new(size, 1e-3);
let mut dut_state = SpirvSim::<f32, FullyGenericMaterial<f32>, B>::new(size, 1e-3);
let mut dut_state = SpirvSim::<f32, FullyGenericMaterial<f32>, B>::new(size, 1e-3, B::default());
for z in 0..size.z() {
for y in 0..size.y() {
@@ -489,21 +484,21 @@ mod test {
test_same_explicit(ref_state, dut_state, step_iters, steps_per_iter);
}
pub fn do_conductor_1_step<B: SimBackend<f32, FullyGenericMaterial<f32>> + Send + Sync>() {
pub fn do_conductor_1_step<B: SimBackend<f32, FullyGenericMaterial<f32>> + Send + Sync + Default>() {
test_same_conductor::<B>(0x1234, 1, 1, Index::new(4, 4, 4));
}
pub fn do_conductor_many_steps_larger<B: SimBackend<f32, FullyGenericMaterial<f32>> + Send + Sync>() {
pub fn do_conductor_many_steps_larger<B: SimBackend<f32, FullyGenericMaterial<f32>> + Send + Sync + Default>() {
test_same_conductor::<B>(0x1234, 100, 10, Index::new(96, 16, 8));
}
fn test_same_mb_ferromagnet<B: SimBackend<f32, FullyGenericMaterial<f32>> + Send + Sync>(
fn test_same_mb_ferromagnet<B: SimBackend<f32, FullyGenericMaterial<f32>> + Send + Sync + Default>(
seed: u64, step_iters: u64, steps_per_iter: u32, size: Index
) {
use rand::{Rng as _, SeedableRng as _};
let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
let mut ref_state = SimState::new(size, 1e-3);
let mut dut_state = SpirvSim::<f32, FullyGenericMaterial<f32>, B>::new(size, 1e-3);
let mut dut_state = SpirvSim::<f32, FullyGenericMaterial<f32>, B>::new(size, 1e-3, B::default());
for z in 0..size.z() {
for y in 0..size.y() {
@@ -537,11 +532,11 @@ mod test {
test_same_explicit(ref_state, dut_state, step_iters, steps_per_iter);
}
pub fn do_mb_ferromagnet_1_step<B: SimBackend<f32, FullyGenericMaterial<f32>> + Send + Sync>() {
pub fn do_mb_ferromagnet_1_step<B: SimBackend<f32, FullyGenericMaterial<f32>> + Send + Sync + Default>() {
test_same_mb_ferromagnet::<B>(0x1234, 1, 1, Index::new(4, 4, 4));
}
pub fn do_mb_ferromagnet_100_steps_larger<B: SimBackend<f32, FullyGenericMaterial<f32>> + Send + Sync>() {
pub fn do_mb_ferromagnet_100_steps_larger<B: SimBackend<f32, FullyGenericMaterial<f32>> + Send + Sync + Default>() {
test_same_mb_ferromagnet::<B>(0x1234, 10, 10, Index::new(96, 16, 8));
}
@@ -582,13 +577,13 @@ mod test {
// test_same_explicit(ref_state, dut_state, steps);
// }
fn test_smoke_mh_ferromagnet<B: SimBackend<f32, FullyGenericMaterial<f32>> + Send + Sync>(
fn test_smoke_mh_ferromagnet<B: SimBackend<f32, FullyGenericMaterial<f32>> + Send + Sync + Default>(
seed: u64, steps: u64, size: Index
) {
// XXX This doesn't do anything, except make sure we don't crash!
use rand::{Rng as _, SeedableRng as _};
let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
let mut dut_state = SpirvSim::<f32, FullyGenericMaterial<f32>, B>::new(size, 1e-3);
let mut dut_state = SpirvSim::<f32, FullyGenericMaterial<f32>, B>::new(size, 1e-3, B::default());
for z in 0..size.z() {
for y in 0..size.y() {
@@ -612,18 +607,18 @@ mod test {
}
}
pub fn do_mh_ferromagnet_smoke_1_step<B: SimBackend<f32, FullyGenericMaterial<f32>> + Send + Sync>() {
pub fn do_mh_ferromagnet_smoke_1_step<B: SimBackend<f32, FullyGenericMaterial<f32>> + Send + Sync + Default>() {
test_smoke_mh_ferromagnet::<B>(0x1234, 1, Index::new(4, 4, 4));
}
pub fn do_mh_ferromagnet_smoke_100_steps_larger<B: SimBackend<f32, FullyGenericMaterial<f32>> + Send + Sync>() {
pub fn do_mh_ferromagnet_smoke_100_steps_larger<B: SimBackend<f32, FullyGenericMaterial<f32>> + Send + Sync + Default>() {
test_smoke_mh_ferromagnet::<B>(0x1234, 100, Index::new(328, 252, 160));
}
pub fn do_step_multiple_with_stim<B: SimBackend<f32, FullyGenericMaterial<f32>> + Send + Sync>() {
pub fn do_step_multiple_with_stim<B: SimBackend<f32, FullyGenericMaterial<f32>> + Send + Sync + Default>() {
let size = Index::new(4, 12, 8);
let mut ref_state = SimState::new(size, 1e-3);
let mut dut_state = SpirvSim::<f32, FullyGenericMaterial<f32>, B>::new(size, 1e-3);
let mut dut_state = SpirvSim::<f32, FullyGenericMaterial<f32>, B>::new(size, 1e-3, B::default());
let stim = UniformStimulus::new_e(Vec3::new(1.0e15, 2.0e15, -3.0e15));
for _ in 0..5 {
ref_state.step_multiple(100, &stim);
@@ -672,8 +667,8 @@ mod test {
do_smoke_not_multiple_of_4::<$backend>();
}
#[test]
fn same_1_step_no_stim() {
do_same_1_step_no_stim::<$backend>();
fn same_no_step_no_stim() {
do_same_no_step_no_stim::<$backend>();
}
#[test]
fn same_1_step() {