spirv: port all backends to use R for the stimulus
In particular, this patches over a difference where the GPU backend expected the stimulus to be R, while the CPU backend thought it should be f32. That mismatch would likely have revealed a crash if we had tested it with f64 (TODO).
This commit is contained in:
@@ -16,8 +16,8 @@ impl<R: Real, M: Material<R>> SimBackend<R, M> for CpuBackend {
|
|||||||
&self,
|
&self,
|
||||||
num_steps: u32,
|
num_steps: u32,
|
||||||
meta: SimMeta<R>,
|
meta: SimMeta<R>,
|
||||||
stim_e: &[Vec3<f32>],
|
stim_e: &[Vec3<R>],
|
||||||
stim_h: &[Vec3<f32>],
|
stim_h: &[Vec3<R>],
|
||||||
mat: &[M],
|
mat: &[M],
|
||||||
e: &mut [Vec3<R>],
|
e: &mut [Vec3<R>],
|
||||||
h: &mut [Vec3<R>],
|
h: &mut [Vec3<R>],
|
||||||
@@ -32,7 +32,7 @@ impl<R: Real, M: Material<R>> SimBackend<R, M> for CpuBackend {
|
|||||||
|
|
||||||
fn step_e<R: Real, M: Material<R>>(
|
fn step_e<R: Real, M: Material<R>>(
|
||||||
meta: SimMeta<R>,
|
meta: SimMeta<R>,
|
||||||
stim_e: &[Vec3<f32>],
|
stim_e: &[Vec3<R>],
|
||||||
mat: &[M],
|
mat: &[M],
|
||||||
e: &mut [Vec3<R>],
|
e: &mut [Vec3<R>],
|
||||||
h: &[Vec3<R>],
|
h: &[Vec3<R>],
|
||||||
@@ -44,7 +44,7 @@ fn step_e<R: Real, M: Material<R>>(
|
|||||||
}
|
}
|
||||||
fn step_h<R: Real, M: Material<R>>(
|
fn step_h<R: Real, M: Material<R>>(
|
||||||
meta: SimMeta<R>,
|
meta: SimMeta<R>,
|
||||||
stim_h: &[Vec3<f32>],
|
stim_h: &[Vec3<R>],
|
||||||
mat: &[M],
|
mat: &[M],
|
||||||
e: &[Vec3<R>],
|
e: &[Vec3<R>],
|
||||||
h: &mut [Vec3<R>],
|
h: &mut [Vec3<R>],
|
||||||
@@ -58,7 +58,7 @@ fn step_h<R: Real, M: Material<R>>(
|
|||||||
fn step_e_cell<R: Real, M: Material<R>>(
|
fn step_e_cell<R: Real, M: Material<R>>(
|
||||||
idx: Vec3u,
|
idx: Vec3u,
|
||||||
meta: SimMeta<R>,
|
meta: SimMeta<R>,
|
||||||
stim_e: &[Vec3<f32>],
|
stim_e: &[Vec3<R>],
|
||||||
mat: &[M],
|
mat: &[M],
|
||||||
e: &mut [Vec3<R>],
|
e: &mut [Vec3<R>],
|
||||||
h: &[Vec3<R>],
|
h: &[Vec3<R>],
|
||||||
@@ -78,7 +78,7 @@ fn step_e_cell<R: Real, M: Material<R>>(
|
|||||||
let update_state = StepEContext {
|
let update_state = StepEContext {
|
||||||
inv_feature_size: meta.inv_feature_size,
|
inv_feature_size: meta.inv_feature_size,
|
||||||
time_step: meta.time_step,
|
time_step: meta.time_step,
|
||||||
stim_e: stim_e.cast(),
|
stim_e,
|
||||||
mat,
|
mat,
|
||||||
in_h,
|
in_h,
|
||||||
in_e,
|
in_e,
|
||||||
@@ -89,7 +89,7 @@ fn step_e_cell<R: Real, M: Material<R>>(
|
|||||||
fn step_h_cell<R: Real, M: Material<R>>(
|
fn step_h_cell<R: Real, M: Material<R>>(
|
||||||
idx: Vec3u,
|
idx: Vec3u,
|
||||||
meta: SimMeta<R>,
|
meta: SimMeta<R>,
|
||||||
stim_h: &[Vec3<f32>],
|
stim_h: &[Vec3<R>],
|
||||||
mat: &[M],
|
mat: &[M],
|
||||||
e: &[Vec3<R>],
|
e: &[Vec3<R>],
|
||||||
h: &mut [Vec3<R>],
|
h: &mut [Vec3<R>],
|
||||||
@@ -112,7 +112,7 @@ fn step_h_cell<R: Real, M: Material<R>>(
|
|||||||
let update_state = StepHContext {
|
let update_state = StepHContext {
|
||||||
inv_feature_size: meta.inv_feature_size,
|
inv_feature_size: meta.inv_feature_size,
|
||||||
time_step: meta.time_step,
|
time_step: meta.time_step,
|
||||||
stim_h: stim_h.cast(),
|
stim_h,
|
||||||
mat,
|
mat,
|
||||||
in_e,
|
in_e,
|
||||||
in_h,
|
in_h,
|
||||||
|
@@ -67,8 +67,8 @@ impl<R: Copy, M: Send + Sync + 'static> SimBackend<R, M> for WgpuBackend<R, M> {
|
|||||||
&self,
|
&self,
|
||||||
num_steps: u32,
|
num_steps: u32,
|
||||||
meta: SimMeta<R>,
|
meta: SimMeta<R>,
|
||||||
stim_cpu_e: &[Vec3<f32>],
|
stim_cpu_e: &[Vec3<R>],
|
||||||
stim_cpu_h: &[Vec3<f32>],
|
stim_cpu_h: &[Vec3<R>],
|
||||||
mat: &[M],
|
mat: &[M],
|
||||||
e: &mut [Vec3<R>],
|
e: &mut [Vec3<R>],
|
||||||
h: &mut [Vec3<R>],
|
h: &mut [Vec3<R>],
|
||||||
|
@@ -22,8 +22,8 @@ pub trait SimBackend<R, M> {
|
|||||||
&self,
|
&self,
|
||||||
num_steps: u32,
|
num_steps: u32,
|
||||||
meta: SimMeta<R>,
|
meta: SimMeta<R>,
|
||||||
stim_e: &[Vec3<f32>],
|
stim_e: &[Vec3<R>],
|
||||||
stim_h: &[Vec3<f32>],
|
stim_h: &[Vec3<R>],
|
||||||
mat: &[M],
|
mat: &[M],
|
||||||
e: &mut [Vec3<R>],
|
e: &mut [Vec3<R>],
|
||||||
h: &mut [Vec3<R>],
|
h: &mut [Vec3<R>],
|
||||||
@@ -227,7 +227,7 @@ where
|
|||||||
#[allow(unused)] // used for test
|
#[allow(unused)] // used for test
|
||||||
fn apply_stimulus(&mut self, stim: &dyn AbstractStimulus) {
|
fn apply_stimulus(&mut self, stim: &dyn AbstractStimulus) {
|
||||||
trace!("apply_stimulus begin");
|
trace!("apply_stimulus begin");
|
||||||
iterate_stim(stim, self.size(), self.feature_size(), self.time(), self.timestep(), |pos_idx, value_e, value_h| {
|
iterate_stim(stim, self.size(), self.feature_size(), self.time(), self.meta.time_step, |pos_idx, value_e, value_h| {
|
||||||
let flat_idx = self.flat_index(pos_idx).unwrap();
|
let flat_idx = self.flat_index(pos_idx).unwrap();
|
||||||
self.e[flat_idx] += value_e.cast();
|
self.e[flat_idx] += value_e.cast();
|
||||||
self.h[flat_idx] += value_h.cast();
|
self.h[flat_idx] += value_h.cast();
|
||||||
@@ -236,13 +236,13 @@ where
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn eval_stimulus<S: AbstractStimulus>(&self, stim: &S)
|
fn eval_stimulus<S: AbstractStimulus>(&self, stim: &S)
|
||||||
-> (Array3<Vec3<f32>>, Array3<Vec3<f32>>)
|
-> (Array3<Vec3<R>>, Array3<Vec3<R>>)
|
||||||
{
|
{
|
||||||
trace!("eval_stimulus begin");
|
trace!("eval_stimulus begin");
|
||||||
let dim = self.size();
|
let dim = self.size();
|
||||||
let feature_size = self.feature_size();
|
let feature_size = self.feature_size();
|
||||||
let t_sec = self.time();
|
let t_sec = self.time();
|
||||||
let timestep = self.timestep();
|
let timestep = self.meta.time_step;
|
||||||
|
|
||||||
// TODO(perf): do this in one loop!
|
// TODO(perf): do this in one loop!
|
||||||
let e = ndarray::Zip::from(ndarray::indices(
|
let e = ndarray::Zip::from(ndarray::indices(
|
||||||
@@ -251,7 +251,7 @@ where
|
|||||||
let pos_idx = Index::new(x as _, y as _, z as _);
|
let pos_idx = Index::new(x as _, y as _, z as _);
|
||||||
let pos_meters = pos_idx.to_meters(feature_size);
|
let pos_meters = pos_idx.to_meters(feature_size);
|
||||||
let (density_e, _density_h) = stim.at(t_sec, pos_meters);
|
let (density_e, _density_h) = stim.at(t_sec, pos_meters);
|
||||||
density_e * timestep
|
density_e.cast::<R>() * timestep
|
||||||
});
|
});
|
||||||
let h = ndarray::Zip::from(ndarray::indices(
|
let h = ndarray::Zip::from(ndarray::indices(
|
||||||
[dim.z() as usize, dim.y() as usize, dim.x() as usize]
|
[dim.z() as usize, dim.y() as usize, dim.x() as usize]
|
||||||
@@ -259,16 +259,21 @@ where
|
|||||||
let pos_idx = Index::new(x as _, y as _, z as _);
|
let pos_idx = Index::new(x as _, y as _, z as _);
|
||||||
let pos_meters = pos_idx.to_meters(feature_size);
|
let pos_meters = pos_idx.to_meters(feature_size);
|
||||||
let (_density_e, density_h) = stim.at(t_sec, pos_meters);
|
let (_density_e, density_h) = stim.at(t_sec, pos_meters);
|
||||||
density_h * timestep
|
density_h.cast::<R>() * timestep
|
||||||
});
|
});
|
||||||
trace!("eval_stimulus end");
|
trace!("eval_stimulus end");
|
||||||
(e, h)
|
(e, h)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn iterate_stim<S: AbstractStimulus + ?Sized, F: FnMut(Index, Vec3<f32>, Vec3<f32>)>(
|
fn iterate_stim<R, S, F>(
|
||||||
stim: &S, dim: Index, feature_size: f32, t_sec: f32, timestep: f32, mut f: F
|
stim: &S, dim: Index, feature_size: f32, t_sec: f32, timestep: R, mut f: F
|
||||||
) {
|
)
|
||||||
|
where
|
||||||
|
R: Real,
|
||||||
|
S: AbstractStimulus + ?Sized,
|
||||||
|
F: FnMut(Index, Vec3<R>, Vec3<R>),
|
||||||
|
{
|
||||||
// TODO: parallelize
|
// TODO: parallelize
|
||||||
for z in 0..dim.z() {
|
for z in 0..dim.z() {
|
||||||
for y in 0..dim.y() {
|
for y in 0..dim.y() {
|
||||||
@@ -276,7 +281,7 @@ fn iterate_stim<S: AbstractStimulus + ?Sized, F: FnMut(Index, Vec3<f32>, Vec3<f3
|
|||||||
let pos_idx = Index::new(x, y, z);
|
let pos_idx = Index::new(x, y, z);
|
||||||
let pos_meters = pos_idx.to_meters(feature_size);
|
let pos_meters = pos_idx.to_meters(feature_size);
|
||||||
let (density_e, density_h) = stim.at(t_sec, pos_meters);
|
let (density_e, density_h) = stim.at(t_sec, pos_meters);
|
||||||
let (value_e, value_h) = (density_e * timestep, density_h * timestep);
|
let (value_e, value_h) = (density_e.cast::<R>() * timestep, density_h.cast::<R>() * timestep);
|
||||||
f(pos_idx, value_e, value_h)
|
f(pos_idx, value_e, value_h)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user