spirv: port all backends to use R for the stimulus

In particular, this patches over a mismatch where the GPU backend
expected the stimulus to be R, while the CPU backend expected it to be f32.
That mismatch would likely have shown up as a crash had we tested with f64
(TODO).
This commit is contained in:
2022-07-26 18:16:10 -07:00
parent 00dcfb170a
commit 7d16e87b6e
3 changed files with 26 additions and 21 deletions

View File

@@ -16,8 +16,8 @@ impl<R: Real, M: Material<R>> SimBackend<R, M> for CpuBackend {
&self, &self,
num_steps: u32, num_steps: u32,
meta: SimMeta<R>, meta: SimMeta<R>,
stim_e: &[Vec3<f32>], stim_e: &[Vec3<R>],
stim_h: &[Vec3<f32>], stim_h: &[Vec3<R>],
mat: &[M], mat: &[M],
e: &mut [Vec3<R>], e: &mut [Vec3<R>],
h: &mut [Vec3<R>], h: &mut [Vec3<R>],
@@ -32,7 +32,7 @@ impl<R: Real, M: Material<R>> SimBackend<R, M> for CpuBackend {
fn step_e<R: Real, M: Material<R>>( fn step_e<R: Real, M: Material<R>>(
meta: SimMeta<R>, meta: SimMeta<R>,
stim_e: &[Vec3<f32>], stim_e: &[Vec3<R>],
mat: &[M], mat: &[M],
e: &mut [Vec3<R>], e: &mut [Vec3<R>],
h: &[Vec3<R>], h: &[Vec3<R>],
@@ -44,7 +44,7 @@ fn step_e<R: Real, M: Material<R>>(
} }
fn step_h<R: Real, M: Material<R>>( fn step_h<R: Real, M: Material<R>>(
meta: SimMeta<R>, meta: SimMeta<R>,
stim_h: &[Vec3<f32>], stim_h: &[Vec3<R>],
mat: &[M], mat: &[M],
e: &[Vec3<R>], e: &[Vec3<R>],
h: &mut [Vec3<R>], h: &mut [Vec3<R>],
@@ -58,7 +58,7 @@ fn step_h<R: Real, M: Material<R>>(
fn step_e_cell<R: Real, M: Material<R>>( fn step_e_cell<R: Real, M: Material<R>>(
idx: Vec3u, idx: Vec3u,
meta: SimMeta<R>, meta: SimMeta<R>,
stim_e: &[Vec3<f32>], stim_e: &[Vec3<R>],
mat: &[M], mat: &[M],
e: &mut [Vec3<R>], e: &mut [Vec3<R>],
h: &[Vec3<R>], h: &[Vec3<R>],
@@ -78,7 +78,7 @@ fn step_e_cell<R: Real, M: Material<R>>(
let update_state = StepEContext { let update_state = StepEContext {
inv_feature_size: meta.inv_feature_size, inv_feature_size: meta.inv_feature_size,
time_step: meta.time_step, time_step: meta.time_step,
stim_e: stim_e.cast(), stim_e,
mat, mat,
in_h, in_h,
in_e, in_e,
@@ -89,7 +89,7 @@ fn step_e_cell<R: Real, M: Material<R>>(
fn step_h_cell<R: Real, M: Material<R>>( fn step_h_cell<R: Real, M: Material<R>>(
idx: Vec3u, idx: Vec3u,
meta: SimMeta<R>, meta: SimMeta<R>,
stim_h: &[Vec3<f32>], stim_h: &[Vec3<R>],
mat: &[M], mat: &[M],
e: &[Vec3<R>], e: &[Vec3<R>],
h: &mut [Vec3<R>], h: &mut [Vec3<R>],
@@ -112,7 +112,7 @@ fn step_h_cell<R: Real, M: Material<R>>(
let update_state = StepHContext { let update_state = StepHContext {
inv_feature_size: meta.inv_feature_size, inv_feature_size: meta.inv_feature_size,
time_step: meta.time_step, time_step: meta.time_step,
stim_h: stim_h.cast(), stim_h,
mat, mat,
in_e, in_e,
in_h, in_h,

View File

@@ -67,8 +67,8 @@ impl<R: Copy, M: Send + Sync + 'static> SimBackend<R, M> for WgpuBackend<R, M> {
&self, &self,
num_steps: u32, num_steps: u32,
meta: SimMeta<R>, meta: SimMeta<R>,
stim_cpu_e: &[Vec3<f32>], stim_cpu_e: &[Vec3<R>],
stim_cpu_h: &[Vec3<f32>], stim_cpu_h: &[Vec3<R>],
mat: &[M], mat: &[M],
e: &mut [Vec3<R>], e: &mut [Vec3<R>],
h: &mut [Vec3<R>], h: &mut [Vec3<R>],

View File

@@ -22,8 +22,8 @@ pub trait SimBackend<R, M> {
&self, &self,
num_steps: u32, num_steps: u32,
meta: SimMeta<R>, meta: SimMeta<R>,
stim_e: &[Vec3<f32>], stim_e: &[Vec3<R>],
stim_h: &[Vec3<f32>], stim_h: &[Vec3<R>],
mat: &[M], mat: &[M],
e: &mut [Vec3<R>], e: &mut [Vec3<R>],
h: &mut [Vec3<R>], h: &mut [Vec3<R>],
@@ -227,7 +227,7 @@ where
#[allow(unused)] // used for test #[allow(unused)] // used for test
fn apply_stimulus(&mut self, stim: &dyn AbstractStimulus) { fn apply_stimulus(&mut self, stim: &dyn AbstractStimulus) {
trace!("apply_stimulus begin"); trace!("apply_stimulus begin");
iterate_stim(stim, self.size(), self.feature_size(), self.time(), self.timestep(), |pos_idx, value_e, value_h| { iterate_stim(stim, self.size(), self.feature_size(), self.time(), self.meta.time_step, |pos_idx, value_e, value_h| {
let flat_idx = self.flat_index(pos_idx).unwrap(); let flat_idx = self.flat_index(pos_idx).unwrap();
self.e[flat_idx] += value_e.cast(); self.e[flat_idx] += value_e.cast();
self.h[flat_idx] += value_h.cast(); self.h[flat_idx] += value_h.cast();
@@ -236,13 +236,13 @@ where
} }
fn eval_stimulus<S: AbstractStimulus>(&self, stim: &S) fn eval_stimulus<S: AbstractStimulus>(&self, stim: &S)
-> (Array3<Vec3<f32>>, Array3<Vec3<f32>>) -> (Array3<Vec3<R>>, Array3<Vec3<R>>)
{ {
trace!("eval_stimulus begin"); trace!("eval_stimulus begin");
let dim = self.size(); let dim = self.size();
let feature_size = self.feature_size(); let feature_size = self.feature_size();
let t_sec = self.time(); let t_sec = self.time();
let timestep = self.timestep(); let timestep = self.meta.time_step;
// TODO(perf): do this in one loop! // TODO(perf): do this in one loop!
let e = ndarray::Zip::from(ndarray::indices( let e = ndarray::Zip::from(ndarray::indices(
@@ -251,7 +251,7 @@ where
let pos_idx = Index::new(x as _, y as _, z as _); let pos_idx = Index::new(x as _, y as _, z as _);
let pos_meters = pos_idx.to_meters(feature_size); let pos_meters = pos_idx.to_meters(feature_size);
let (density_e, _density_h) = stim.at(t_sec, pos_meters); let (density_e, _density_h) = stim.at(t_sec, pos_meters);
density_e * timestep density_e.cast::<R>() * timestep
}); });
let h = ndarray::Zip::from(ndarray::indices( let h = ndarray::Zip::from(ndarray::indices(
[dim.z() as usize, dim.y() as usize, dim.x() as usize] [dim.z() as usize, dim.y() as usize, dim.x() as usize]
@@ -259,16 +259,21 @@ where
let pos_idx = Index::new(x as _, y as _, z as _); let pos_idx = Index::new(x as _, y as _, z as _);
let pos_meters = pos_idx.to_meters(feature_size); let pos_meters = pos_idx.to_meters(feature_size);
let (_density_e, density_h) = stim.at(t_sec, pos_meters); let (_density_e, density_h) = stim.at(t_sec, pos_meters);
density_h * timestep density_h.cast::<R>() * timestep
}); });
trace!("eval_stimulus end"); trace!("eval_stimulus end");
(e, h) (e, h)
} }
} }
fn iterate_stim<S: AbstractStimulus + ?Sized, F: FnMut(Index, Vec3<f32>, Vec3<f32>)>( fn iterate_stim<R, S, F>(
stim: &S, dim: Index, feature_size: f32, t_sec: f32, timestep: f32, mut f: F stim: &S, dim: Index, feature_size: f32, t_sec: f32, timestep: R, mut f: F
) { )
where
R: Real,
S: AbstractStimulus + ?Sized,
F: FnMut(Index, Vec3<R>, Vec3<R>),
{
// TODO: parallelize // TODO: parallelize
for z in 0..dim.z() { for z in 0..dim.z() {
for y in 0..dim.y() { for y in 0..dim.y() {
@@ -276,7 +281,7 @@ fn iterate_stim<S: AbstractStimulus + ?Sized, F: FnMut(Index, Vec3<f32>, Vec3<f3
let pos_idx = Index::new(x, y, z); let pos_idx = Index::new(x, y, z);
let pos_meters = pos_idx.to_meters(feature_size); let pos_meters = pos_idx.to_meters(feature_size);
let (density_e, density_h) = stim.at(t_sec, pos_meters); let (density_e, density_h) = stim.at(t_sec, pos_meters);
let (value_e, value_h) = (density_e * timestep, density_h * timestep); let (value_e, value_h) = (density_e.cast::<R>() * timestep, density_h.cast::<R>() * timestep);
f(pos_idx, value_e, value_h) f(pos_idx, value_e, value_h)
} }
} }