spirv: remove the ArrayHandleMut artifacts in Step{H,E}Context

this will make it easier to reuse these blocks on the CPU side.
This commit is contained in:
2022-07-24 22:17:44 -07:00
parent b70cafa205
commit 05f5f75dd3
2 changed files with 49 additions and 33 deletions

View File

@@ -39,8 +39,10 @@ fn step_h<R: Real, M: Material<R>>(
) { ) {
if id.x() < meta.dim.x() && id.y() < meta.dim.y() && id.z() < meta.dim.z() { if id.x() < meta.dim.x() && id.y() < meta.dim.y() && id.z() < meta.dim.z() {
let sim_state = SerializedStepH::new(meta, stimulus_h, material, e, h, m); let sim_state = SerializedStepH::new(meta, stimulus_h, material, e, h, m);
let update_state = sim_state.index(id); let (update_state, mut out_h, mut out_m) = sim_state.index(id);
update_state.step_h(); let (new_h, new_m) = update_state.step_h();
out_h.write(new_h);
out_m.write(new_m);
} }
} }
@@ -55,8 +57,9 @@ fn step_e<R: Real, M: Material<R>>(
if id.x() < meta.dim.x() && id.y() < meta.dim.y() && id.z() < meta.dim.z() { if id.x() < meta.dim.x() && id.y() < meta.dim.y() && id.z() < meta.dim.z() {
let sim_state = SerializedStepE::new(meta, stimulus_e, material, e, h); let sim_state = SerializedStepE::new(meta, stimulus_e, material, e, h);
let update_state = sim_state.index(id); let (update_state, mut out_e) = sim_state.index(id);
update_state.step_e(); let new_e = update_state.step_e();
out_e.write(new_e);
} }
} }

View File

@@ -41,7 +41,13 @@ impl<'a, R, M> SerializedStepH<'a, R, M> {
} }
impl<'a, R: Real, M> SerializedStepH<'a, R, M> { impl<'a, R: Real, M> SerializedStepH<'a, R, M> {
pub fn index(self, idx: Vec3u) -> StepHContext<'a, R, M> { /// returns a context which the user can call `step_h` on,
/// plus output handles which the the results can be written to.
pub fn index(self, idx: Vec3u) -> (
StepHContext<'a, R, M>,
ArrayHandleMut<'a, Vec3<R>>, // out_h
ArrayHandleMut<'a, Vec3<R>>, // out_m
){
let dim = self.meta.dim; let dim = self.meta.dim;
let stim_h_matrix = Array3::new(self.stimulus_h, dim); let stim_h_matrix = Array3::new(self.stimulus_h, dim);
let mat_matrix = Array3::new(self.material, dim); let mat_matrix = Array3::new(self.material, dim);
@@ -60,15 +66,19 @@ impl<'a, R: Real, M> SerializedStepH<'a, R, M> {
let mat = mat_matrix.into_handle(idx); let mat = mat_matrix.into_handle(idx);
StepHContext { (
inv_feature_size: self.meta.inv_feature_size, StepHContext {
time_step: self.meta.time_step, inv_feature_size: self.meta.inv_feature_size,
stim_h: stim_h_matrix.get(idx).unwrap(), time_step: self.meta.time_step,
mat, stim_h: stim_h_matrix.get(idx).unwrap(),
in_e, mat,
in_e,
in_h: out_h.get(),
in_m: out_m.get(),
},
out_h, out_h,
out_m, out_m,
} )
} }
} }
@@ -96,7 +106,10 @@ impl<'a, R, M> SerializedStepE<'a, R, M> {
} }
impl<'a, R: Real, M> SerializedStepE<'a, R, M> { impl<'a, R: Real, M> SerializedStepE<'a, R, M> {
pub fn index(self, idx: Vec3u) -> StepEContext<'a, R, M> { pub fn index(self, idx: Vec3u) -> (
StepEContext<'a, R, M>,
ArrayHandleMut<'a, Vec3<R>> // out_e
){
let dim = self.meta.dim; let dim = self.meta.dim;
let stim_e_matrix = Array3::new(self.stimulus_e, dim); let stim_e_matrix = Array3::new(self.stimulus_e, dim);
let mat_matrix = Array3::new(self.material, dim); let mat_matrix = Array3::new(self.material, dim);
@@ -129,14 +142,17 @@ impl<'a, R: Real, M> SerializedStepE<'a, R, M> {
let mat = mat_matrix.into_handle(idx); let mat = mat_matrix.into_handle(idx);
StepEContext { (
inv_feature_size: self.meta.inv_feature_size, StepEContext {
time_step: self.meta.time_step, inv_feature_size: self.meta.inv_feature_size,
stim_e: stim_e_matrix.get(idx).unwrap(), time_step: self.meta.time_step,
mat, stim_e: stim_e_matrix.get(idx).unwrap(),
in_h, mat,
in_h,
in_e: out_e.get(),
},
out_e, out_e,
} )
} }
} }
@@ -279,12 +295,11 @@ pub struct StepEContext<'a, R, M> {
mat: ArrayHandle<'a, M>, mat: ArrayHandle<'a, M>,
/// Input field sampled near this location /// Input field sampled near this location
in_h: VolumeSampleNeg<R>, in_h: VolumeSampleNeg<R>,
/// Handle to the output field at one specific index. in_e: Vec3<R>,
out_e: ArrayHandleMut<'a, Vec3<R>>,
} }
impl<'a, R: Real, M: Material<R>> StepEContext<'a, R, M> { impl<'a, R: Real, M: Material<R>> StepEContext<'a, R, M> {
pub fn step_e(mut self) { pub fn step_e(self) -> Vec3<R> {
let twice_eps0 = R::twice_eps0(); let twice_eps0 = R::twice_eps0();
let deltas = self.in_h.delta_h(); let deltas = self.in_h.delta_h();
// \nabla x H // \nabla x H
@@ -293,12 +308,12 @@ impl<'a, R: Real, M: Material<R>> StepEContext<'a, R, M> {
// no-conductivity version: // no-conductivity version:
// let delta_e = nabla_h * (self.time_step * EPS0_INV); // let delta_e = nabla_h * (self.time_step * EPS0_INV);
let sigma = self.mat.get_ref().conductivity(); let sigma = self.mat.get_ref().conductivity();
let e_prev = self.out_e.get(); let e_prev = self.in_e;
let delta_e = (nabla_h - e_prev.elem_mul(sigma)).elem_div( let delta_e = (nabla_h - e_prev.elem_mul(sigma)).elem_div(
sigma*self.time_step + Vec3::uniform(twice_eps0) sigma*self.time_step + Vec3::uniform(twice_eps0)
)*(R::two()*self.time_step); )*(R::two()*self.time_step);
// println!("spirv-step_e delta_e: {:?}", delta_e); // println!("spirv-step_e delta_e: {:?}", delta_e);
self.out_e.write(e_prev + delta_e + self.stim_e); e_prev + delta_e + self.stim_e
} }
} }
@@ -310,13 +325,12 @@ pub struct StepHContext<'a, R, M> {
mat: ArrayHandle<'a, M>, mat: ArrayHandle<'a, M>,
/// Input field sampled near this location /// Input field sampled near this location
in_e: VolumeSamplePos<R>, in_e: VolumeSamplePos<R>,
/// Handle to the output field at one specific index. in_h: Vec3<R>,
out_h: ArrayHandleMut<'a, Vec3<R>>, in_m: Vec3<R>,
out_m: ArrayHandleMut<'a, Vec3<R>>,
} }
impl<'a, R: Real, M: Material<R>> StepHContext<'a, R, M> { impl<'a, R: Real, M: Material<R>> StepHContext<'a, R, M> {
pub fn step_h(mut self) { pub fn step_h(self) -> (Vec3<R>, Vec3<R>) {
let mu0 = R::mu0(); let mu0 = R::mu0();
let mu0_inv = R::mu0_inv(); let mu0_inv = R::mu0_inv();
let deltas = self.in_e.delta_e(); let deltas = self.in_e.delta_e();
@@ -327,8 +341,8 @@ impl<'a, R: Real, M: Material<R>> StepHContext<'a, R, M> {
let delta_b = nabla_e * (-self.time_step); let delta_b = nabla_e * (-self.time_step);
// Relation between these is: B = mu0*(H + M) // Relation between these is: B = mu0*(H + M)
let old_h = self.out_h.get(); let old_h = self.in_h;
let old_m = self.out_m.get(); let old_m = self.in_m;
let old_b = (old_h + old_m) * mu0; let old_b = (old_h + old_m) * mu0;
let new_b = old_b + delta_b + self.stim_h * mu0; let new_b = old_b + delta_b + self.stim_h * mu0;
@@ -336,7 +350,6 @@ impl<'a, R: Real, M: Material<R>> StepHContext<'a, R, M> {
let new_m = mat.move_b_vec(old_m, new_b); let new_m = mat.move_b_vec(old_m, new_b);
let new_h = new_b * mu0_inv - new_m; let new_h = new_b * mu0_inv - new_m;
// println!("spirv-step_h delta_h: {:?}", delta_h); // println!("spirv-step_h delta_h: {:?}", delta_h);
self.out_h.write(new_h); (new_h, new_m)
self.out_m.write(new_m);
} }
} }