parameterize WgpuData over the M type

This commit is contained in:
2022-07-25 13:15:41 -07:00
parent 0801a0dca3
commit ff1d9867ab
2 changed files with 16 additions and 15 deletions

View File

@@ -2,12 +2,12 @@ use futures::FutureExt as _;
use log::{info, warn};
use ndarray::Array3;
use std::borrow::Cow;
use std::marker::PhantomData;
use std::num::NonZeroU64;
use wgpu;
use wgpu::util::DeviceExt as _;
use coremem_types::vec::Vec3;
use coremem_types::mat::{FullyGenericMaterial, Material};
use coremem_types::step::SimMeta;
pub fn entry_points<L: 'static>() -> Option<(&'static str, &'static str)>
@@ -15,16 +15,17 @@ pub fn entry_points<L: 'static>() -> Option<(&'static str, &'static str)>
spirv_backend::entry_points::<L>().into()
}
pub(super) struct WgpuData {
pub(super) struct WgpuData<M: 'static> {
step_bind_group_layout: wgpu::BindGroupLayout,
step_e_pipeline: wgpu::ComputePipeline,
step_h_pipeline: wgpu::ComputePipeline,
device: wgpu::Device,
queue: wgpu::Queue,
mat: PhantomData<&'static M>,
}
impl WgpuData {
pub fn new<M: 'static>(volume: u64) -> Self
impl<M: 'static> WgpuData<M> {
pub fn new(volume: u64) -> Self
{
info!("WgpuData::new({})", volume);
use std::mem::size_of;
@@ -51,19 +52,20 @@ impl WgpuData {
step_e_pipeline,
device,
queue,
mat: PhantomData,
}
}
}
impl Default for WgpuData {
impl<M: 'static> Default for WgpuData<M> {
fn default() -> Self {
Self::new::<FullyGenericMaterial<f32>>(0)
Self::new(0)
}
}
impl WgpuData {
pub(super) fn step_spirv<M>(
impl<M: Send + Sync + 'static> WgpuData<M> {
pub(super) fn step_spirv(
&self,
num_steps: u32,
meta: SimMeta<f32>,
@@ -74,9 +76,7 @@ impl WgpuData {
e: &mut [Vec3<f32>],
h: &mut [Vec3<f32>],
m: &mut [Vec3<f32>],
)
where M: Send + Sync + Material<f32> + 'static
{
) {
let field_bytes = meta.dim.product_sum() as usize * std::mem::size_of::<Vec3<f32>>();
let device = &self.device;

View File

@@ -17,6 +17,7 @@ use gpu::WgpuData;
/// Wrapper around an inner state object which offloads stepping onto a spirv backend (e.g. GPU).
#[derive(Clone, Default, Serialize, Deserialize)]
pub struct SpirvSim<M=FullyGenericMaterial<f32>>
where M: 'static
{
// TODO: make this generic over R
meta: SimMeta<f32>,
@@ -29,7 +30,7 @@ pub struct SpirvSim<M=FullyGenericMaterial<f32>>
// XXX not confident that wgpu is actually properly synchronized for us to omit a Mutex here
// though.
#[serde(skip)]
wgpu: Option<Arc<WgpuData>>,
wgpu: Option<Arc<WgpuData<M>>>,
}
impl<M> MaterialSim for SpirvSim<M>
@@ -94,7 +95,7 @@ where
fn step_multiple<S: AbstractStimulus>(&mut self, num_steps: u32, stim: &S) {
let vol = self.size().volume();
self.wgpu.get_or_insert_with(
|| Arc::new(WgpuData::new::<M>(vol))
|| Arc::new(WgpuData::new(vol))
);
let (stim_e, stim_h) = self.eval_stimulus(stim);
@@ -132,14 +133,14 @@ impl<M> SpirvSim<M>
where M: Default + 'static
{
pub fn new(size: Index, feature_size: f32) -> Self {
Self::new_with_wgpu_handle(size, feature_size, Some(Arc::new(WgpuData::new::<M>(size.volume()))))
Self::new_with_wgpu_handle(size, feature_size, Some(Arc::new(WgpuData::new(size.volume()))))
}
pub fn new_no_wgpu(size: Index, feature_size: f32) -> Self {
Self::new_with_wgpu_handle(size, feature_size, None)
}
fn new_with_wgpu_handle(size: Index, feature_size: f32, wgpu: Option<Arc<WgpuData>>) -> Self {
fn new_with_wgpu_handle(size: Index, feature_size: f32, wgpu: Option<Arc<WgpuData<M>>>) -> Self {
info!("SpirvSim::new({:?}, {})", size, feature_size);
let flat_size = size.volume() as usize;
if flat_size * std::mem::size_of::<M>() >= 0x40000000 {