diff --git a/crates/coremem/src/sim/spirv/gpu.rs b/crates/coremem/src/sim/spirv/gpu.rs index dcb7849..f0f2481 100644 --- a/crates/coremem/src/sim/spirv/gpu.rs +++ b/crates/coremem/src/sim/spirv/gpu.rs @@ -2,12 +2,12 @@ use futures::FutureExt as _; use log::{info, warn}; use ndarray::Array3; use std::borrow::Cow; +use std::marker::PhantomData; use std::num::NonZeroU64; use wgpu; use wgpu::util::DeviceExt as _; use coremem_types::vec::Vec3; -use coremem_types::mat::{FullyGenericMaterial, Material}; use coremem_types::step::SimMeta; pub fn entry_points() -> Option<(&'static str, &'static str)> @@ -15,16 +15,17 @@ pub fn entry_points() -> Option<(&'static str, &'static str)> spirv_backend::entry_points::().into() } -pub(super) struct WgpuData { +pub(super) struct WgpuData { step_bind_group_layout: wgpu::BindGroupLayout, step_e_pipeline: wgpu::ComputePipeline, step_h_pipeline: wgpu::ComputePipeline, device: wgpu::Device, queue: wgpu::Queue, + mat: PhantomData<&'static M>, } -impl WgpuData { - pub fn new(volume: u64) -> Self +impl WgpuData { + pub fn new(volume: u64) -> Self { info!("WgpuData::new({})", volume); use std::mem::size_of; @@ -51,19 +52,20 @@ impl WgpuData { step_e_pipeline, device, queue, + mat: PhantomData, } } } -impl Default for WgpuData { +impl Default for WgpuData { fn default() -> Self { - Self::new::>(0) + Self::new(0) } } -impl WgpuData { - pub(super) fn step_spirv( +impl WgpuData { + pub(super) fn step_spirv( &self, num_steps: u32, meta: SimMeta, @@ -74,9 +76,7 @@ impl WgpuData { e: &mut [Vec3], h: &mut [Vec3], m: &mut [Vec3], - ) - where M: Send + Sync + Material + 'static - { + ) { let field_bytes = meta.dim.product_sum() as usize * std::mem::size_of::>(); let device = &self.device; diff --git a/crates/coremem/src/sim/spirv/mod.rs b/crates/coremem/src/sim/spirv/mod.rs index f3e11d4..ec67ff6 100644 --- a/crates/coremem/src/sim/spirv/mod.rs +++ b/crates/coremem/src/sim/spirv/mod.rs @@ -17,6 +17,7 @@ use gpu::WgpuData; /// Wrapper around an inner state object which offloads stepping onto a spirv backend (e.g. GPU). #[derive(Clone, Default, Serialize, Deserialize)] pub struct SpirvSim> +where M: 'static { // TODO: make this generic over R meta: SimMeta, @@ -29,7 +30,7 @@ pub struct SpirvSim> // XXX not confident that wgpu is actually properly synchronized for us to omit a Mutex here // though. #[serde(skip)] - wgpu: Option>, + wgpu: Option>>, } impl MaterialSim for SpirvSim @@ -94,7 +95,7 @@ where fn step_multiple(&mut self, num_steps: u32, stim: &S) { let vol = self.size().volume(); self.wgpu.get_or_insert_with( - || Arc::new(WgpuData::new::(vol)) + || Arc::new(WgpuData::new(vol)) ); let (stim_e, stim_h) = self.eval_stimulus(stim); @@ -132,14 +133,14 @@ impl SpirvSim where M: Default + 'static { pub fn new(size: Index, feature_size: f32) -> Self { - Self::new_with_wgpu_handle(size, feature_size, Some(Arc::new(WgpuData::new::(size.volume())))) + Self::new_with_wgpu_handle(size, feature_size, Some(Arc::new(WgpuData::new(size.volume())))) } pub fn new_no_wgpu(size: Index, feature_size: f32) -> Self { Self::new_with_wgpu_handle(size, feature_size, None) } - fn new_with_wgpu_handle(size: Index, feature_size: f32, wgpu: Option>) -> Self { + fn new_with_wgpu_handle(size: Index, feature_size: f32, wgpu: Option>>) -> Self { info!("SpirvSim::new({:?}, {})", size, feature_size); let flat_size = size.volume() as usize; if flat_size * std::mem::size_of::() >= 0x40000000 {