diff --git a/crates/spirv_backend/src/lib.rs b/crates/spirv_backend/src/lib.rs index 6fcd8f7..8e223f2 100644 --- a/crates/spirv_backend/src/lib.rs +++ b/crates/spirv_backend/src/lib.rs @@ -39,10 +39,9 @@ fn step_h>( ) { if id.x() < meta.dim.x() && id.y() < meta.dim.y() && id.z() < meta.dim.z() { let sim_state = SerializedStepH::new(meta, stimulus_h, material, e, h, m); - let (update_state, mut out_h, mut out_m) = sim_state.index(id); + let update_state = sim_state.index(id); let (new_h, new_m) = update_state.step_h(); - out_h.write(new_h); - out_m.write(new_m); + sim_state.write_output(id, new_h, new_m); } } @@ -57,9 +56,9 @@ fn step_e>( if id.x() < meta.dim.x() && id.y() < meta.dim.y() && id.z() < meta.dim.z() { let sim_state = SerializedStepE::new(meta, stimulus_e, material, e, h); - let (update_state, mut out_e) = sim_state.index(id); + let update_state = sim_state.index(id); let new_e = update_state.step_e(); - out_e.write(new_e); + sim_state.write_output(id, new_e); } } diff --git a/crates/spirv_backend/src/sim.rs b/crates/spirv_backend/src/sim.rs index f5d3be4..c3c6d48 100644 --- a/crates/spirv_backend/src/sim.rs +++ b/crates/spirv_backend/src/sim.rs @@ -1,6 +1,6 @@ // use spirv_std::RuntimeArray; use crate::support::{ - Array3, Array3Mut, ArrayHandle, ArrayHandleMut, Optional, UnsizedArray + Array3, Array3Mut, Optional, UnsizedArray }; use coremem_types::mat::Material; use coremem_types::real::Real; @@ -41,19 +41,14 @@ impl<'a, R, M> SerializedStepH<'a, R, M> { } impl<'a, R: Real, M> SerializedStepH<'a, R, M> { - /// returns a context which the user can call `step_h` on, - /// plus output handles which the the results can be written to. - pub fn index(self, idx: Vec3u) -> ( - StepHContext<'a, R, M>, - ArrayHandleMut<'a, Vec3>, // out_h - ArrayHandleMut<'a, Vec3>, // out_m - ){ + /// returns a context which the user can call `step_h` on + pub fn index(&self, idx: Vec3u) -> StepHContext<'a, R, M> { let dim = self.meta.dim; let stim_h_matrix = Array3::new(self.stimulus_h, dim); let mat_matrix = Array3::new(self.material, dim); let e = Array3::new(self.e, dim); - let h = Array3Mut::new(self.h, dim); - let m = Array3Mut::new(self.m, dim); + let h = Array3::new(self.h, dim); + let m = Array3::new(self.m, dim); let in_e = VolumeSamplePos { mid: e.get(idx).unwrap(), @@ -61,24 +56,30 @@ impl<'a, R: Real, M> SerializedStepH<'a, R, M> { yp1: e.get(idx + Vec3u::unit_y()), zp1: e.get(idx + Vec3u::unit_z()), }; - let out_h = h.into_mut_handle(idx); - let out_m = m.into_mut_handle(idx); + let in_h = h.get(idx).unwrap(); + let in_m = m.get(idx).unwrap(); - let mat = mat_matrix.into_handle(idx); + let mat = mat_matrix.into_ref(idx); - ( - StepHContext { - inv_feature_size: self.meta.inv_feature_size, - time_step: self.meta.time_step, - stim_h: stim_h_matrix.get(idx).unwrap(), - mat, - in_e, - in_h: out_h.get(), - in_m: out_m.get(), - }, - out_h, - out_m, - ) + StepHContext { + inv_feature_size: self.meta.inv_feature_size, + time_step: self.meta.time_step, + stim_h: stim_h_matrix.get(idx).unwrap(), + mat, + in_e, + in_h, + in_m, + } + } + + pub fn write_output(self, idx: Vec3u, h: Vec3, m: Vec3) { + let dim = self.meta.dim; + let arr_h = Array3Mut::new(self.h, dim); + let arr_m = Array3Mut::new(self.m, dim); + let mut out_h = arr_h.into_mut_handle(idx); + let mut out_m = arr_m.into_mut_handle(idx); + out_h.write(h); + out_m.write(m); } } @@ -106,14 +107,11 @@ impl<'a, R, M> SerializedStepE<'a, R, M> { } impl<'a, R: Real, M> SerializedStepE<'a, R, M> { - pub fn index(self, idx: Vec3u) -> ( - StepEContext<'a, R, M>, - ArrayHandleMut<'a, Vec3> // out_e - ){ + pub fn index(&self, idx: Vec3u) -> StepEContext<'a, R, M> { let dim = self.meta.dim; let stim_e_matrix = Array3::new(self.stimulus_e, dim); let mat_matrix = Array3::new(self.material, dim); - let e = Array3Mut::new(self.e, dim); + let e = Array3::new(self.e, dim); let h = Array3::new(self.h, dim); let xm1 = if idx.x() == 0 { @@ -138,21 +136,25 @@ impl<'a, R: Real, M> SerializedStepE<'a, R, M> { ym1, zm1, }; - let out_e = e.into_mut_handle(idx); + let in_e = e.get(idx).unwrap(); - let mat = mat_matrix.into_handle(idx); + let mat = mat_matrix.into_ref(idx); - ( - StepEContext { - inv_feature_size: self.meta.inv_feature_size, - time_step: self.meta.time_step, - stim_e: stim_e_matrix.get(idx).unwrap(), - mat, - in_h, - in_e: out_e.get(), - }, - out_e, - ) + StepEContext { + inv_feature_size: self.meta.inv_feature_size, + time_step: self.meta.time_step, + stim_e: stim_e_matrix.get(idx).unwrap(), + mat, + in_h, + in_e, + } + } + + pub fn write_output(self, idx: Vec3u, e: Vec3) { + let dim = self.meta.dim; + let arr_e = Array3Mut::new(self.e, dim); + let mut out_e = arr_e.into_mut_handle(idx); + out_e.write(e); } } @@ -292,7 +294,7 @@ pub struct StepEContext<'a, R, M> { inv_feature_size: R, time_step: R, stim_e: Vec3, - mat: ArrayHandle<'a, M>, + mat: &'a M, /// Input field sampled near this location in_h: VolumeSampleNeg, in_e: Vec3, @@ -307,7 +309,7 @@ impl<'a, R: Real, M: Material> StepEContext<'a, R, M> { // $\nabla x H = \epsilon_0 dE/dt + \sigma E$ // no-conductivity version: // let delta_e = nabla_h * (self.time_step * EPS0_INV); - let sigma = self.mat.get_ref().conductivity(); + let sigma = self.mat.conductivity(); let e_prev = self.in_e; let delta_e = (nabla_h - e_prev.elem_mul(sigma)).elem_div( sigma*self.time_step + Vec3::uniform(twice_eps0) @@ -322,7 +324,7 @@ pub struct StepHContext<'a, R, M> { inv_feature_size: R, time_step: R, stim_h: Vec3, - mat: ArrayHandle<'a, M>, + mat: &'a M, /// Input field sampled near this location in_e: VolumeSamplePos, in_h: Vec3, @@ -346,7 +348,7 @@ impl<'a, R: Real, M: Material> StepHContext<'a, R, M> { let old_b = (old_h + old_m) * mu0; let new_b = old_b + delta_b + self.stim_h * mu0; - let mat = self.mat.get_ref(); + let mat = self.mat; let new_m = mat.move_b_vec(old_m, new_b); let new_h = new_b * mu0_inv - new_m; // println!("spirv-step_h delta_h: {:?}", delta_h); diff --git a/crates/spirv_backend/src/support.rs b/crates/spirv_backend/src/support.rs index 34eb203..fb07c13 100644 --- a/crates/spirv_backend/src/support.rs +++ b/crates/spirv_backend/src/support.rs @@ -211,6 +211,13 @@ impl<'a, T> Array3<'a, T> { self.data.get_handle(idx) } } + + pub fn into_ref(self, idx: Vec3u) -> &'a T { + let idx = checked_index(idx, self.dim).unwrap(); + unsafe { + self.data.index_ref(idx) + } + } } impl<'a, T: Copy + Default> Array3<'a, T> {