From e574bf71cacbe47f59f265d28f638477aaa096ac Mon Sep 17 00:00:00 2001 From: Colin Date: Sat, 29 Jan 2022 16:30:29 -0800 Subject: [PATCH] spirv: get the optimized IsoConductorOr plumbed all the way through buffer_proto5 example shows it in action (results weren't verified) --- examples/buffer_proto5.rs | 2 + src/driver.rs | 2 +- src/sim/spirv/bindings.rs | 31 ++++++++++ src/sim/spirv/mod.rs | 21 ++++--- src/sim/spirv/spirv_backend/src/lib.rs | 81 +++++++++++++++++--------- src/sim/spirv/spirv_backend/src/mat.rs | 3 +- 6 files changed, 103 insertions(+), 37 deletions(-) diff --git a/examples/buffer_proto5.rs b/examples/buffer_proto5.rs index 494579f..3e3a0af 100644 --- a/examples/buffer_proto5.rs +++ b/examples/buffer_proto5.rs @@ -719,4 +719,6 @@ fn main() { None => error!("skipping sim because no valid geometry: {:?}", params), } } + + info!("done"); } diff --git a/src/driver.rs b/src/driver.rs index 4ed82ba..015bacb 100644 --- a/src/driver.rs +++ b/src/driver.rs @@ -44,7 +44,7 @@ impl Driver> { } impl SpirvDriver - where M::Ffi: Default + where M::Ffi: Default + 'static { pub fn new_spirv(size: C, feature_size: f32) -> Self { Self::new_with_state(SpirvSim::new(size.to_index(feature_size), feature_size)) diff --git a/src/sim/spirv/bindings.rs b/src/sim/spirv/bindings.rs index d4f5402..fbcd9c0 100644 --- a/src/sim/spirv/bindings.rs +++ b/src/sim/spirv/bindings.rs @@ -7,6 +7,7 @@ use crate::geom::{Index, Vec3, Vec3u}; /// hide the actual spirv backend structures inside a submodule to make their use/boundary clear. mod ffi { + pub use spirv_backend_lib::entry_points; pub use spirv_backend_lib::sim::SerializedSimMeta; pub use spirv_backend_lib::support::{Optional, Vec3Std, UVec3Std}; pub use spirv_backend_lib::mat::{Ferroxcube3R1MH, FullyGenericMaterial, IsoConductorOr, Material, MBPgram, MHPgram}; @@ -34,6 +35,27 @@ pub trait IntoLib { fn into_lib(self) -> Self::Lib; } +macro_rules! identity { + ($($param:ident,)* => $t:ty) => { + impl<$($param: IntoFfi),*> IntoFfi for $t { + type Ffi = $t; + fn into_ffi(self) -> Self::Ffi { + self + } + } + impl<$($param: IntoLib),*> IntoLib for $t { + type Lib = $t; + fn into_lib(self) -> Self::Lib { + self + } + } + }; +} + +// XXX: should work for any other lifetime, not just 'static +identity!(=> &'static str); +identity!(T0, T1, => (T0, T1)); + impl IntoFfi for Option where L::Ffi: Default { @@ -338,3 +360,12 @@ impl<'de, F> Deserialize<'de> for Remote Ok(Remote(local.into_ffi())) } } + +// FUNCTION BINDINGS +pub fn entry_points() -> Option<(&'static str, &'static str)> +where + L: IntoFfi, + L::Ffi: 'static +{ + ffi::entry_points::().into_lib() +} diff --git a/src/sim/spirv/mod.rs b/src/sim/spirv/mod.rs index 9c80855..5da9a84 100644 --- a/src/sim/spirv/mod.rs +++ b/src/sim/spirv/mod.rs @@ -17,7 +17,7 @@ use crate::sim::{CellStateWithM, GenericSim, MaterialSim, Sample, SampleableSim} use crate::stim::AbstractStimulus; mod bindings; -pub use bindings::{IntoFfi, IntoLib, IsoConductorOr, FullyGenericMaterial, Remote, SimMeta, Material}; +pub use bindings::{entry_points, IntoFfi, IntoLib, IsoConductorOr, FullyGenericMaterial, Remote, SimMeta, Material}; /// Wrapper around an inner state object which offloads stepping onto a spirv backend (e.g. GPU). #[derive(Clone, Default, Serialize, Deserialize)] @@ -48,14 +48,17 @@ struct WgpuData { } impl WgpuData { - pub fn new(volume: u64) -> Self { + pub fn new(volume: u64) -> Self + where M::Ffi: 'static + { use std::mem::size_of; let max_elem_size = size_of::().max(size_of::>()); let max_buf_size = volume * max_elem_size as u64 + 0x1000; + let entry_names = entry_points::().unwrap_or(("invalid_mat", "invalid_mat")); let (device, queue) = futures::executor::block_on(open_device(max_buf_size)); let shader_binary = get_shader(); let shader_module = unsafe { device.create_shader_module_spirv(&shader_binary) }; - let (step_bind_group_layout, step_h_pipeline, step_e_pipeline) = make_pipelines(&device, &shader_module); + let (step_bind_group_layout, step_h_pipeline, step_e_pipeline) = make_pipelines(&device, &shader_module, entry_names); Self { step_bind_group_layout, step_h_pipeline, @@ -68,7 +71,7 @@ impl WgpuData { impl Default for WgpuData { fn default() -> Self { - Self::new::<()>(0) + Self::new::(0) } } @@ -152,10 +155,10 @@ where } impl SpirvSim - where M::Ffi: Default + where M::Ffi: Default + 'static { pub fn new(size: Index, feature_size: f32) -> Self { - Self::new_with_wgpu_handle(size, feature_size, Some(Arc::new(WgpuData::new::(size.volume())))) + Self::new_with_wgpu_handle(size, feature_size, Some(Arc::new(WgpuData::new::(size.volume())))) } pub fn new_no_wgpu(size: Index, feature_size: f32) -> Self { @@ -534,7 +537,7 @@ async fn open_device(max_buf_size: u64) -> (wgpu::Device, wgpu::Queue) { (device, queue) } -fn make_pipelines(device: &wgpu::Device, shader_module: &wgpu::ShaderModule) -> ( +fn make_pipelines(device: &wgpu::Device, shader_module: &wgpu::ShaderModule, entry_names: (&'static str, &'static str)) -> ( wgpu::BindGroupLayout, wgpu::ComputePipeline, wgpu::ComputePipeline ) { let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor { @@ -619,14 +622,14 @@ fn make_pipelines(device: &wgpu::Device, shader_module: &wgpu::ShaderModule) -> label: None, layout: Some(&pipeline_layout), module: shader_module, - entry_point: "step_h_generic_material", + entry_point: entry_names.0, }); let compute_step_e_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor { label: None, layout: Some(&pipeline_layout), module: shader_module, - entry_point: "step_e_generic_material", + entry_point: entry_names.1, }); (bind_group_layout, compute_step_h_pipeline, compute_step_e_pipeline) diff --git a/src/sim/spirv/spirv_backend/src/lib.rs b/src/sim/spirv/spirv_backend/src/lib.rs index afbc7ba..3aba89d 100644 --- a/src/sim/spirv/spirv_backend/src/lib.rs +++ b/src/sim/spirv/spirv_backend/src/lib.rs @@ -20,7 +20,9 @@ pub mod support; pub use sim::{SerializedSimMeta, SerializedStepE, SerializedStepH}; pub use support::{Optional, UnsizedArray, UVec3Std, Vec3Std}; -use mat::{FullyGenericMaterial, Material}; +use mat::{IsoConductorOr, Ferroxcube3R1MH, FullyGenericMaterial, Material}; + +type Iso3R1 = IsoConductorOr; fn step_h( id: UVec3, @@ -53,31 +55,58 @@ fn step_e( } } -// LocalSize/numthreads -#[spirv(compute(threads(4, 4, 4)))] -pub fn step_h_generic_material( - #[spirv(global_invocation_id)] id: UVec3, - #[spirv(storage_buffer, descriptor_set = 0, binding = 0)] meta: &SerializedSimMeta, - // XXX: delete this input? - #[spirv(storage_buffer, descriptor_set = 0, binding = 1)] _unused_stimulus: &UnsizedArray, - #[spirv(storage_buffer, descriptor_set = 0, binding = 2)] material: &UnsizedArray, - #[spirv(storage_buffer, descriptor_set = 0, binding = 3)] e: &UnsizedArray, - #[spirv(storage_buffer, descriptor_set = 0, binding = 4)] h: &mut UnsizedArray, - #[spirv(storage_buffer, descriptor_set = 0, binding = 5)] m: &mut UnsizedArray, -) { - step_h(id, meta, material, e, h, m) +/// Return the step_h/step_e entry point names for the provided material +pub fn entry_points() -> Optional<(&'static str, &'static str)> { + use core::any::TypeId; + let mappings = [ + (TypeId::of::(), + ("step_h_generic_material", "step_e_generic_material") + ), + (TypeId::of::(), + ("step_h_iso_3r1", "step_e_iso_3r1") + ), + ]; + + for (id, names) in mappings { + if id == TypeId::of::() { + return Optional::some(names); + } + } + Optional::none() } -#[spirv(compute(threads(4, 4, 4)))] -pub fn step_e_generic_material( - #[spirv(global_invocation_id)] id: UVec3, - #[spirv(storage_buffer, descriptor_set = 0, binding = 0)] meta: &SerializedSimMeta, - #[spirv(storage_buffer, descriptor_set = 0, binding = 1)] stimulus: &UnsizedArray, - #[spirv(storage_buffer, descriptor_set = 0, binding = 2)] material: &UnsizedArray, - #[spirv(storage_buffer, descriptor_set = 0, binding = 3)] e: &mut UnsizedArray, - #[spirv(storage_buffer, descriptor_set = 0, binding = 4)] h: &UnsizedArray, - // XXX: can/should this m input be deleted? - #[spirv(storage_buffer, descriptor_set = 0, binding = 5)] _unused_m: &UnsizedArray, -) { - step_e(id, meta, stimulus, material, e, h) +macro_rules! steps { + ($mat:ty, $step_h:ident, $step_e:ident) => { + // LocalSize/numthreads + #[spirv(compute(threads(4, 4, 4)))] + pub fn $step_h( + #[spirv(global_invocation_id)] id: UVec3, + #[spirv(storage_buffer, descriptor_set = 0, binding = 0)] meta: &SerializedSimMeta, + // XXX: delete this input? + #[spirv(storage_buffer, descriptor_set = 0, binding = 1)] _unused_stimulus: &UnsizedArray, + #[spirv(storage_buffer, descriptor_set = 0, binding = 2)] material: &UnsizedArray<$mat>, + #[spirv(storage_buffer, descriptor_set = 0, binding = 3)] e: &UnsizedArray, + #[spirv(storage_buffer, descriptor_set = 0, binding = 4)] h: &mut UnsizedArray, + #[spirv(storage_buffer, descriptor_set = 0, binding = 5)] m: &mut UnsizedArray, + ) { + step_h(id, meta, material, e, h, m) + } + + #[spirv(compute(threads(4, 4, 4)))] + pub fn $step_e( + #[spirv(global_invocation_id)] id: UVec3, + #[spirv(storage_buffer, descriptor_set = 0, binding = 0)] meta: &SerializedSimMeta, + #[spirv(storage_buffer, descriptor_set = 0, binding = 1)] stimulus: &UnsizedArray, + #[spirv(storage_buffer, descriptor_set = 0, binding = 2)] material: &UnsizedArray<$mat>, + #[spirv(storage_buffer, descriptor_set = 0, binding = 3)] e: &mut UnsizedArray, + #[spirv(storage_buffer, descriptor_set = 0, binding = 4)] h: &UnsizedArray, + // XXX: can/should this m input be deleted? + #[spirv(storage_buffer, descriptor_set = 0, binding = 5)] _unused_m: &UnsizedArray, + ) { + step_e(id, meta, stimulus, material, e, h) + } + }; } + +steps!(FullyGenericMaterial, step_h_generic_material, step_e_generic_material); +steps!(Iso3R1, step_h_iso_3r1, step_e_iso_3r1); diff --git a/src/sim/spirv/spirv_backend/src/mat.rs b/src/sim/spirv/spirv_backend/src/mat.rs index af74250..b1dfe67 100644 --- a/src/sim/spirv/spirv_backend/src/mat.rs +++ b/src/sim/spirv/spirv_backend/src/mat.rs @@ -246,7 +246,8 @@ impl Material for IsoConductorOr { } fn move_b_vec(&self, m: Vec3Std, target_b: Vec3Std) -> Vec3Std { if self.value < 0.0 { - self.mat.move_b_vec(m, target_b) + let mat = self.mat; //< XXX hack for ZST + mat.move_b_vec(m, target_b) } else { m }