spirv: remove the Optionality around entry points: compute them statically with traits
This commit is contained in:
@@ -7,7 +7,7 @@ use crate::sim::{GenericSim, MaterialSim, SampleableSim};
|
||||
use crate::sim::legacy::{self, SimState};
|
||||
use crate::sim::legacy::mat::Pml;
|
||||
use crate::sim::units::{Frame, Time};
|
||||
use crate::sim::spirv::SpirvSim;
|
||||
use crate::sim::spirv::{self, SpirvSim};
|
||||
use crate::stim::AbstractStimulus;
|
||||
use crate::types::vec::Vec3;
|
||||
|
||||
@@ -47,7 +47,9 @@ impl<R: Real, M: Default> Driver<SimState<R, M>> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: Real, M: Default + Send + Sync + 'static> SpirvDriver<R, M>
|
||||
impl<R: Real, M: Default + Send + Sync> SpirvDriver<R, M>
|
||||
where
|
||||
spirv::WgpuBackend<R, M>: spirv::SimBackend<R, M>
|
||||
{
|
||||
pub fn new_spirv<C: Coord>(size: C, feature_size: f32) -> Self {
|
||||
Self::new_with_state(SpirvSim::new(size.to_index(feature_size), feature_size))
|
||||
|
@@ -9,30 +9,27 @@ use wgpu::util::DeviceExt as _;
|
||||
use coremem_types::vec::Vec3;
|
||||
use coremem_types::step::SimMeta;
|
||||
|
||||
use spirv_backend::HasEntryPoints;
|
||||
|
||||
use super::SimBackend;
|
||||
|
||||
pub fn entry_points<L: 'static>() -> Option<(&'static str, &'static str)>
|
||||
{
|
||||
spirv_backend::entry_points::<L>().into()
|
||||
}
|
||||
|
||||
pub struct WgpuBackend<R, M: 'static> {
|
||||
pub struct WgpuBackend<R, M> {
|
||||
step_bind_group_layout: wgpu::BindGroupLayout,
|
||||
step_e_pipeline: wgpu::ComputePipeline,
|
||||
step_h_pipeline: wgpu::ComputePipeline,
|
||||
device: wgpu::Device,
|
||||
queue: wgpu::Queue,
|
||||
real: PhantomData<R>,
|
||||
mat: PhantomData<&'static M>,
|
||||
mat: PhantomData<M>,
|
||||
}
|
||||
|
||||
impl<R, M: 'static> Default for WgpuBackend<R, M> {
|
||||
impl<R, M: HasEntryPoints<R>> Default for WgpuBackend<R, M> {
|
||||
fn default() -> Self {
|
||||
Self::new_with_vol(0)
|
||||
}
|
||||
}
|
||||
|
||||
impl<R, M: 'static> WgpuBackend<R, M> {
|
||||
impl<R, M: HasEntryPoints<R>> WgpuBackend<R, M> {
|
||||
fn new_with_vol(volume: u64) -> Self {
|
||||
info!("WgpuBackend::new_with_vol({})", volume);
|
||||
use std::mem::size_of;
|
||||
@@ -40,7 +37,7 @@ impl<R, M: 'static> WgpuBackend<R, M> {
|
||||
let max_array_size = volume * max_elem_size as u64;
|
||||
let max_buf_size = max_array_size + 0x1000; // allow some overhead
|
||||
|
||||
let entry_names = entry_points::<M>().unwrap_or(("invalid_mat", "invalid_mat"));
|
||||
let entry_names = (M::step_h(), M::step_e());
|
||||
let (device, queue) = futures::executor::block_on(open_device(max_buf_size));
|
||||
let shader_binary = get_shader();
|
||||
let shader_module = unsafe { device.create_shader_module_spirv(&shader_binary) };
|
||||
@@ -57,7 +54,7 @@ impl<R, M: 'static> WgpuBackend<R, M> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: Copy, M: Send + Sync + 'static> SimBackend<R, M> for WgpuBackend<R, M> {
|
||||
impl<R: Copy, M: Send + Sync + HasEntryPoints<R>> SimBackend<R, M> for WgpuBackend<R, M> {
|
||||
fn new(volume: u64) -> Self
|
||||
{
|
||||
Self::new_with_vol(volume)
|
||||
|
@@ -15,7 +15,6 @@ use spirv_std::macros::spirv;
|
||||
mod adapt;
|
||||
mod support;
|
||||
|
||||
use coremem_types::compound::Optional;
|
||||
use coremem_types::mat::{Ferroxcube3R1MH, FullyGenericMaterial, IsoConductorOr};
|
||||
use coremem_types::step::SimMeta;
|
||||
use coremem_types::vec::{Vec3, Vec3u};
|
||||
@@ -26,28 +25,26 @@ fn glam_vec_to_internal(v: glam::UVec3) -> Vec3u {
|
||||
Vec3u::new(v.x, v.y, v.z)
|
||||
}
|
||||
|
||||
/// Return the step_h/step_e entry point names for the provided material
|
||||
pub fn entry_points<M: 'static>() -> Optional<(&'static str, &'static str)> {
|
||||
use core::any::TypeId;
|
||||
let mappings = [
|
||||
(TypeId::of::<FullyGenericMaterial<f32>>(),
|
||||
("step_h_generic_material", "step_e_generic_material")
|
||||
),
|
||||
(TypeId::of::<Iso3R1<f32>>(),
|
||||
("step_h_iso_3r1", "step_e_iso_3r1")
|
||||
),
|
||||
];
|
||||
mod private {
|
||||
pub trait Sealed {}
|
||||
}
|
||||
|
||||
for (id, names) in mappings {
|
||||
if id == TypeId::of::<M>() {
|
||||
return Optional::some(names);
|
||||
}
|
||||
}
|
||||
Optional::none()
|
||||
pub trait HasEntryPoints<R>: private::Sealed {
|
||||
fn step_h() -> &'static str;
|
||||
fn step_e() -> &'static str;
|
||||
}
|
||||
|
||||
macro_rules! steps {
|
||||
($flt:ty, $mat:ty, $step_h:ident, $step_e:ident) => {
|
||||
impl private::Sealed for $mat { }
|
||||
impl HasEntryPoints<$flt> for $mat {
|
||||
fn step_h() -> &'static str {
|
||||
stringify!($step_h)
|
||||
}
|
||||
fn step_e() -> &'static str {
|
||||
stringify!($step_e)
|
||||
}
|
||||
}
|
||||
// LocalSize/numthreads
|
||||
#[spirv(compute(threads(4, 4, 4)))]
|
||||
pub fn $step_h(
|
||||
|
Reference in New Issue
Block a user