parameterize WgpuData over the M type
This commit is contained in:
@@ -2,12 +2,12 @@ use futures::FutureExt as _;
|
||||
use log::{info, warn};
|
||||
use ndarray::Array3;
|
||||
use std::borrow::Cow;
|
||||
use std::marker::PhantomData;
|
||||
use std::num::NonZeroU64;
|
||||
use wgpu;
|
||||
use wgpu::util::DeviceExt as _;
|
||||
|
||||
use coremem_types::vec::Vec3;
|
||||
use coremem_types::mat::{FullyGenericMaterial, Material};
|
||||
use coremem_types::step::SimMeta;
|
||||
|
||||
pub fn entry_points<L: 'static>() -> Option<(&'static str, &'static str)>
|
||||
@@ -15,16 +15,17 @@ pub fn entry_points<L: 'static>() -> Option<(&'static str, &'static str)>
|
||||
spirv_backend::entry_points::<L>().into()
|
||||
}
|
||||
|
||||
pub(super) struct WgpuData {
|
||||
pub(super) struct WgpuData<M: 'static> {
|
||||
step_bind_group_layout: wgpu::BindGroupLayout,
|
||||
step_e_pipeline: wgpu::ComputePipeline,
|
||||
step_h_pipeline: wgpu::ComputePipeline,
|
||||
device: wgpu::Device,
|
||||
queue: wgpu::Queue,
|
||||
mat: PhantomData<&'static M>,
|
||||
}
|
||||
|
||||
impl WgpuData {
|
||||
pub fn new<M: 'static>(volume: u64) -> Self
|
||||
impl<M: 'static> WgpuData<M> {
|
||||
pub fn new(volume: u64) -> Self
|
||||
{
|
||||
info!("WgpuData::new({})", volume);
|
||||
use std::mem::size_of;
|
||||
@@ -51,19 +52,20 @@ impl WgpuData {
|
||||
step_e_pipeline,
|
||||
device,
|
||||
queue,
|
||||
mat: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for WgpuData {
|
||||
impl<M: 'static> Default for WgpuData<M> {
|
||||
fn default() -> Self {
|
||||
Self::new::<FullyGenericMaterial<f32>>(0)
|
||||
Self::new(0)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
impl WgpuData {
|
||||
pub(super) fn step_spirv<M>(
|
||||
impl<M: Send + Sync + 'static> WgpuData<M> {
|
||||
pub(super) fn step_spirv(
|
||||
&self,
|
||||
num_steps: u32,
|
||||
meta: SimMeta<f32>,
|
||||
@@ -74,9 +76,7 @@ impl WgpuData {
|
||||
e: &mut [Vec3<f32>],
|
||||
h: &mut [Vec3<f32>],
|
||||
m: &mut [Vec3<f32>],
|
||||
)
|
||||
where M: Send + Sync + Material<f32> + 'static
|
||||
{
|
||||
) {
|
||||
let field_bytes = meta.dim.product_sum() as usize * std::mem::size_of::<Vec3<f32>>();
|
||||
|
||||
let device = &self.device;
|
||||
|
@@ -17,6 +17,7 @@ use gpu::WgpuData;
|
||||
/// Wrapper around an inner state object which offloads stepping onto a spirv backend (e.g. GPU).
|
||||
#[derive(Clone, Default, Serialize, Deserialize)]
|
||||
pub struct SpirvSim<M=FullyGenericMaterial<f32>>
|
||||
where M: 'static
|
||||
{
|
||||
// TODO: make this generic over R
|
||||
meta: SimMeta<f32>,
|
||||
@@ -29,7 +30,7 @@ pub struct SpirvSim<M=FullyGenericMaterial<f32>>
|
||||
// XXX not confident that wgpu is actually properly synchronized for us to omit a Mutex here
|
||||
// though.
|
||||
#[serde(skip)]
|
||||
wgpu: Option<Arc<WgpuData>>,
|
||||
wgpu: Option<Arc<WgpuData<M>>>,
|
||||
}
|
||||
|
||||
impl<M> MaterialSim for SpirvSim<M>
|
||||
@@ -94,7 +95,7 @@ where
|
||||
fn step_multiple<S: AbstractStimulus>(&mut self, num_steps: u32, stim: &S) {
|
||||
let vol = self.size().volume();
|
||||
self.wgpu.get_or_insert_with(
|
||||
|| Arc::new(WgpuData::new::<M>(vol))
|
||||
|| Arc::new(WgpuData::new(vol))
|
||||
);
|
||||
|
||||
let (stim_e, stim_h) = self.eval_stimulus(stim);
|
||||
@@ -132,14 +133,14 @@ impl<M> SpirvSim<M>
|
||||
where M: Default + 'static
|
||||
{
|
||||
pub fn new(size: Index, feature_size: f32) -> Self {
|
||||
Self::new_with_wgpu_handle(size, feature_size, Some(Arc::new(WgpuData::new::<M>(size.volume()))))
|
||||
Self::new_with_wgpu_handle(size, feature_size, Some(Arc::new(WgpuData::new(size.volume()))))
|
||||
}
|
||||
|
||||
pub fn new_no_wgpu(size: Index, feature_size: f32) -> Self {
|
||||
Self::new_with_wgpu_handle(size, feature_size, None)
|
||||
}
|
||||
|
||||
fn new_with_wgpu_handle(size: Index, feature_size: f32, wgpu: Option<Arc<WgpuData>>) -> Self {
|
||||
fn new_with_wgpu_handle(size: Index, feature_size: f32, wgpu: Option<Arc<WgpuData<M>>>) -> Self {
|
||||
info!("SpirvSim::new({:?}, {})", size, feature_size);
|
||||
let flat_size = size.volume() as usize;
|
||||
if flat_size * std::mem::size_of::<M>() >= 0x40000000 {
|
||||
|
Reference in New Issue
Block a user