diff --git a/crates/coremem/src/sim/spirv/cpu.rs b/crates/coremem/src/sim/spirv/cpu.rs index 99f7516..b8a85f2 100644 --- a/crates/coremem/src/sim/spirv/cpu.rs +++ b/crates/coremem/src/sim/spirv/cpu.rs @@ -16,8 +16,8 @@ impl> SimBackend for CpuBackend { &self, num_steps: u32, meta: SimMeta, - stim_e: &[Vec3], - stim_h: &[Vec3], + stim_e: &[Vec3], + stim_h: &[Vec3], mat: &[M], e: &mut [Vec3], h: &mut [Vec3], @@ -32,7 +32,7 @@ impl> SimBackend for CpuBackend { fn step_e>( meta: SimMeta, - stim_e: &[Vec3], + stim_e: &[Vec3], mat: &[M], e: &mut [Vec3], h: &[Vec3], @@ -44,7 +44,7 @@ fn step_e>( } fn step_h>( meta: SimMeta, - stim_h: &[Vec3], + stim_h: &[Vec3], mat: &[M], e: &[Vec3], h: &mut [Vec3], @@ -58,7 +58,7 @@ fn step_h>( fn step_e_cell>( idx: Vec3u, meta: SimMeta, - stim_e: &[Vec3], + stim_e: &[Vec3], mat: &[M], e: &mut [Vec3], h: &[Vec3], @@ -78,7 +78,7 @@ fn step_e_cell>( let update_state = StepEContext { inv_feature_size: meta.inv_feature_size, time_step: meta.time_step, - stim_e: stim_e.cast(), + stim_e, mat, in_h, in_e, @@ -89,7 +89,7 @@ fn step_e_cell>( fn step_h_cell>( idx: Vec3u, meta: SimMeta, - stim_h: &[Vec3], + stim_h: &[Vec3], mat: &[M], e: &[Vec3], h: &mut [Vec3], @@ -112,7 +112,7 @@ fn step_h_cell>( let update_state = StepHContext { inv_feature_size: meta.inv_feature_size, time_step: meta.time_step, - stim_h: stim_h.cast(), + stim_h, mat, in_e, in_h, diff --git a/crates/coremem/src/sim/spirv/gpu.rs b/crates/coremem/src/sim/spirv/gpu.rs index 5121928..ce5f46b 100644 --- a/crates/coremem/src/sim/spirv/gpu.rs +++ b/crates/coremem/src/sim/spirv/gpu.rs @@ -67,8 +67,8 @@ impl SimBackend for WgpuBackend { &self, num_steps: u32, meta: SimMeta, - stim_cpu_e: &[Vec3], - stim_cpu_h: &[Vec3], + stim_cpu_e: &[Vec3], + stim_cpu_h: &[Vec3], mat: &[M], e: &mut [Vec3], h: &mut [Vec3], diff --git a/crates/coremem/src/sim/spirv/mod.rs b/crates/coremem/src/sim/spirv/mod.rs index a69f3be..6a914bd 100644 --- a/crates/coremem/src/sim/spirv/mod.rs +++ b/crates/coremem/src/sim/spirv/mod.rs @@ -22,8 +22,8 @@ pub trait SimBackend { &self, num_steps: u32, meta: SimMeta, - stim_e: &[Vec3], - stim_h: &[Vec3], + stim_e: &[Vec3], + stim_h: &[Vec3], mat: &[M], e: &mut [Vec3], h: &mut [Vec3], @@ -227,7 +227,7 @@ where #[allow(unused)] // used for test fn apply_stimulus(&mut self, stim: &dyn AbstractStimulus) { trace!("apply_stimulus begin"); - iterate_stim(stim, self.size(), self.feature_size(), self.time(), self.timestep(), |pos_idx, value_e, value_h| { + iterate_stim(stim, self.size(), self.feature_size(), self.time(), self.meta.time_step, |pos_idx, value_e, value_h| { let flat_idx = self.flat_index(pos_idx).unwrap(); self.e[flat_idx] += value_e.cast(); self.h[flat_idx] += value_h.cast(); @@ -236,13 +236,13 @@ where } fn eval_stimulus(&self, stim: &S) - -> (Array3>, Array3>) + -> (Array3>, Array3>) { trace!("eval_stimulus begin"); let dim = self.size(); let feature_size = self.feature_size(); let t_sec = self.time(); - let timestep = self.timestep(); + let timestep = self.meta.time_step; // TODO(perf): do this in one loop! let e = ndarray::Zip::from(ndarray::indices( @@ -251,7 +251,7 @@ where let pos_idx = Index::new(x as _, y as _, z as _); let pos_meters = pos_idx.to_meters(feature_size); let (density_e, _density_h) = stim.at(t_sec, pos_meters); - density_e * timestep + density_e.cast::() * timestep }); let h = ndarray::Zip::from(ndarray::indices( [dim.z() as usize, dim.y() as usize, dim.x() as usize] @@ -259,16 +259,21 @@ where let pos_idx = Index::new(x as _, y as _, z as _); let pos_meters = pos_idx.to_meters(feature_size); let (_density_e, density_h) = stim.at(t_sec, pos_meters); - density_h * timestep + density_h.cast::() * timestep }); trace!("eval_stimulus end"); (e, h) } } -fn iterate_stim, Vec3)>( - stim: &S, dim: Index, feature_size: f32, t_sec: f32, timestep: f32, mut f: F -) { +fn iterate_stim( + stim: &S, dim: Index, feature_size: f32, t_sec: f32, timestep: R, mut f: F +) +where + R: Real, + S: AbstractStimulus + ?Sized, + F: FnMut(Index, Vec3, Vec3), +{ // TODO: parallelize for z in 0..dim.z() { for y in 0..dim.y() { @@ -276,7 +281,7 @@ fn iterate_stim, Vec3() * timestep, density_h.cast::() * timestep); f(pos_idx, value_e, value_h) } }