spirv: remove the Optionality around entry points: compute them statically with traits

This commit is contained in:
2022-07-27 12:32:43 -07:00
parent baaeeb9463
commit 568d61c598
3 changed files with 27 additions and 31 deletions

View File

@@ -7,7 +7,7 @@ use crate::sim::{GenericSim, MaterialSim, SampleableSim};
use crate::sim::legacy::{self, SimState};
use crate::sim::legacy::mat::Pml;
use crate::sim::units::{Frame, Time};
use crate::sim::spirv::SpirvSim;
use crate::sim::spirv::{self, SpirvSim};
use crate::stim::AbstractStimulus;
use crate::types::vec::Vec3;
@@ -47,7 +47,9 @@ impl<R: Real, M: Default> Driver<SimState<R, M>> {
}
}
impl<R: Real, M: Default + Send + Sync + 'static> SpirvDriver<R, M>
impl<R: Real, M: Default + Send + Sync> SpirvDriver<R, M>
where
spirv::WgpuBackend<R, M>: spirv::SimBackend<R, M>
{
pub fn new_spirv<C: Coord>(size: C, feature_size: f32) -> Self {
Self::new_with_state(SpirvSim::new(size.to_index(feature_size), feature_size))

View File

@@ -9,30 +9,27 @@ use wgpu::util::DeviceExt as _;
use coremem_types::vec::Vec3;
use coremem_types::step::SimMeta;
use spirv_backend::HasEntryPoints;
use super::SimBackend;
pub fn entry_points<L: 'static>() -> Option<(&'static str, &'static str)>
{
spirv_backend::entry_points::<L>().into()
}
pub struct WgpuBackend<R, M: 'static> {
pub struct WgpuBackend<R, M> {
step_bind_group_layout: wgpu::BindGroupLayout,
step_e_pipeline: wgpu::ComputePipeline,
step_h_pipeline: wgpu::ComputePipeline,
device: wgpu::Device,
queue: wgpu::Queue,
real: PhantomData<R>,
mat: PhantomData<&'static M>,
mat: PhantomData<M>,
}
impl<R, M: 'static> Default for WgpuBackend<R, M> {
impl<R, M: HasEntryPoints<R>> Default for WgpuBackend<R, M> {
fn default() -> Self {
Self::new_with_vol(0)
}
}
impl<R, M: 'static> WgpuBackend<R, M> {
impl<R, M: HasEntryPoints<R>> WgpuBackend<R, M> {
fn new_with_vol(volume: u64) -> Self {
info!("WgpuBackend::new_with_vol({})", volume);
use std::mem::size_of;
@@ -40,7 +37,7 @@ impl<R, M: 'static> WgpuBackend<R, M> {
let max_array_size = volume * max_elem_size as u64;
let max_buf_size = max_array_size + 0x1000; // allow some overhead
let entry_names = entry_points::<M>().unwrap_or(("invalid_mat", "invalid_mat"));
let entry_names = (M::step_h(), M::step_e());
let (device, queue) = futures::executor::block_on(open_device(max_buf_size));
let shader_binary = get_shader();
let shader_module = unsafe { device.create_shader_module_spirv(&shader_binary) };
@@ -57,7 +54,7 @@ impl<R, M: 'static> WgpuBackend<R, M> {
}
}
impl<R: Copy, M: Send + Sync + 'static> SimBackend<R, M> for WgpuBackend<R, M> {
impl<R: Copy, M: Send + Sync + HasEntryPoints<R>> SimBackend<R, M> for WgpuBackend<R, M> {
fn new(volume: u64) -> Self
{
Self::new_with_vol(volume)

View File

@@ -15,7 +15,6 @@ use spirv_std::macros::spirv;
mod adapt;
mod support;
use coremem_types::compound::Optional;
use coremem_types::mat::{Ferroxcube3R1MH, FullyGenericMaterial, IsoConductorOr};
use coremem_types::step::SimMeta;
use coremem_types::vec::{Vec3, Vec3u};
@@ -26,28 +25,26 @@ fn glam_vec_to_internal(v: glam::UVec3) -> Vec3u {
Vec3u::new(v.x, v.y, v.z)
}
/// Return the step_h/step_e entry point names for the provided material
pub fn entry_points<M: 'static>() -> Optional<(&'static str, &'static str)> {
use core::any::TypeId;
let mappings = [
(TypeId::of::<FullyGenericMaterial<f32>>(),
("step_h_generic_material", "step_e_generic_material")
),
(TypeId::of::<Iso3R1<f32>>(),
("step_h_iso_3r1", "step_e_iso_3r1")
),
];
mod private {
pub trait Sealed {}
}
for (id, names) in mappings {
if id == TypeId::of::<M>() {
return Optional::some(names);
}
}
Optional::none()
pub trait HasEntryPoints<R>: private::Sealed {
fn step_h() -> &'static str;
fn step_e() -> &'static str;
}
macro_rules! steps {
($flt:ty, $mat:ty, $step_h:ident, $step_e:ident) => {
impl private::Sealed for $mat { }
impl HasEntryPoints<$flt> for $mat {
fn step_h() -> &'static str {
stringify!($step_h)
}
fn step_e() -> &'static str {
stringify!($step_e)
}
}
// LocalSize/numthreads
#[spirv(compute(threads(4, 4, 4)))]
pub fn $step_h(