parameterize WgpuData over the M type
This commit is contained in:
@@ -2,12 +2,12 @@ use futures::FutureExt as _;
|
|||||||
use log::{info, warn};
|
use log::{info, warn};
|
||||||
use ndarray::Array3;
|
use ndarray::Array3;
|
||||||
use std::borrow::Cow;
|
use std::borrow::Cow;
|
||||||
|
use std::marker::PhantomData;
|
||||||
use std::num::NonZeroU64;
|
use std::num::NonZeroU64;
|
||||||
use wgpu;
|
use wgpu;
|
||||||
use wgpu::util::DeviceExt as _;
|
use wgpu::util::DeviceExt as _;
|
||||||
|
|
||||||
use coremem_types::vec::Vec3;
|
use coremem_types::vec::Vec3;
|
||||||
use coremem_types::mat::{FullyGenericMaterial, Material};
|
|
||||||
use coremem_types::step::SimMeta;
|
use coremem_types::step::SimMeta;
|
||||||
|
|
||||||
pub fn entry_points<L: 'static>() -> Option<(&'static str, &'static str)>
|
pub fn entry_points<L: 'static>() -> Option<(&'static str, &'static str)>
|
||||||
@@ -15,16 +15,17 @@ pub fn entry_points<L: 'static>() -> Option<(&'static str, &'static str)>
|
|||||||
spirv_backend::entry_points::<L>().into()
|
spirv_backend::entry_points::<L>().into()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(super) struct WgpuData {
|
pub(super) struct WgpuData<M: 'static> {
|
||||||
step_bind_group_layout: wgpu::BindGroupLayout,
|
step_bind_group_layout: wgpu::BindGroupLayout,
|
||||||
step_e_pipeline: wgpu::ComputePipeline,
|
step_e_pipeline: wgpu::ComputePipeline,
|
||||||
step_h_pipeline: wgpu::ComputePipeline,
|
step_h_pipeline: wgpu::ComputePipeline,
|
||||||
device: wgpu::Device,
|
device: wgpu::Device,
|
||||||
queue: wgpu::Queue,
|
queue: wgpu::Queue,
|
||||||
|
mat: PhantomData<&'static M>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl WgpuData {
|
impl<M: 'static> WgpuData<M> {
|
||||||
pub fn new<M: 'static>(volume: u64) -> Self
|
pub fn new(volume: u64) -> Self
|
||||||
{
|
{
|
||||||
info!("WgpuData::new({})", volume);
|
info!("WgpuData::new({})", volume);
|
||||||
use std::mem::size_of;
|
use std::mem::size_of;
|
||||||
@@ -51,19 +52,20 @@ impl WgpuData {
|
|||||||
step_e_pipeline,
|
step_e_pipeline,
|
||||||
device,
|
device,
|
||||||
queue,
|
queue,
|
||||||
|
mat: PhantomData,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for WgpuData {
|
impl<M: 'static> Default for WgpuData<M> {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
Self::new::<FullyGenericMaterial<f32>>(0)
|
Self::new(0)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
impl WgpuData {
|
impl<M: Send + Sync + 'static> WgpuData<M> {
|
||||||
pub(super) fn step_spirv<M>(
|
pub(super) fn step_spirv(
|
||||||
&self,
|
&self,
|
||||||
num_steps: u32,
|
num_steps: u32,
|
||||||
meta: SimMeta<f32>,
|
meta: SimMeta<f32>,
|
||||||
@@ -74,9 +76,7 @@ impl WgpuData {
|
|||||||
e: &mut [Vec3<f32>],
|
e: &mut [Vec3<f32>],
|
||||||
h: &mut [Vec3<f32>],
|
h: &mut [Vec3<f32>],
|
||||||
m: &mut [Vec3<f32>],
|
m: &mut [Vec3<f32>],
|
||||||
)
|
) {
|
||||||
where M: Send + Sync + Material<f32> + 'static
|
|
||||||
{
|
|
||||||
let field_bytes = meta.dim.product_sum() as usize * std::mem::size_of::<Vec3<f32>>();
|
let field_bytes = meta.dim.product_sum() as usize * std::mem::size_of::<Vec3<f32>>();
|
||||||
|
|
||||||
let device = &self.device;
|
let device = &self.device;
|
||||||
|
@@ -17,6 +17,7 @@ use gpu::WgpuData;
|
|||||||
/// Wrapper around an inner state object which offloads stepping onto a spirv backend (e.g. GPU).
|
/// Wrapper around an inner state object which offloads stepping onto a spirv backend (e.g. GPU).
|
||||||
#[derive(Clone, Default, Serialize, Deserialize)]
|
#[derive(Clone, Default, Serialize, Deserialize)]
|
||||||
pub struct SpirvSim<M=FullyGenericMaterial<f32>>
|
pub struct SpirvSim<M=FullyGenericMaterial<f32>>
|
||||||
|
where M: 'static
|
||||||
{
|
{
|
||||||
// TODO: make this generic over R
|
// TODO: make this generic over R
|
||||||
meta: SimMeta<f32>,
|
meta: SimMeta<f32>,
|
||||||
@@ -29,7 +30,7 @@ pub struct SpirvSim<M=FullyGenericMaterial<f32>>
|
|||||||
// XXX not confident that wgpu is actually properly synchronized for us to omit a Mutex here
|
// XXX not confident that wgpu is actually properly synchronized for us to omit a Mutex here
|
||||||
// though.
|
// though.
|
||||||
#[serde(skip)]
|
#[serde(skip)]
|
||||||
wgpu: Option<Arc<WgpuData>>,
|
wgpu: Option<Arc<WgpuData<M>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<M> MaterialSim for SpirvSim<M>
|
impl<M> MaterialSim for SpirvSim<M>
|
||||||
@@ -94,7 +95,7 @@ where
|
|||||||
fn step_multiple<S: AbstractStimulus>(&mut self, num_steps: u32, stim: &S) {
|
fn step_multiple<S: AbstractStimulus>(&mut self, num_steps: u32, stim: &S) {
|
||||||
let vol = self.size().volume();
|
let vol = self.size().volume();
|
||||||
self.wgpu.get_or_insert_with(
|
self.wgpu.get_or_insert_with(
|
||||||
|| Arc::new(WgpuData::new::<M>(vol))
|
|| Arc::new(WgpuData::new(vol))
|
||||||
);
|
);
|
||||||
|
|
||||||
let (stim_e, stim_h) = self.eval_stimulus(stim);
|
let (stim_e, stim_h) = self.eval_stimulus(stim);
|
||||||
@@ -132,14 +133,14 @@ impl<M> SpirvSim<M>
|
|||||||
where M: Default + 'static
|
where M: Default + 'static
|
||||||
{
|
{
|
||||||
pub fn new(size: Index, feature_size: f32) -> Self {
|
pub fn new(size: Index, feature_size: f32) -> Self {
|
||||||
Self::new_with_wgpu_handle(size, feature_size, Some(Arc::new(WgpuData::new::<M>(size.volume()))))
|
Self::new_with_wgpu_handle(size, feature_size, Some(Arc::new(WgpuData::new(size.volume()))))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn new_no_wgpu(size: Index, feature_size: f32) -> Self {
|
pub fn new_no_wgpu(size: Index, feature_size: f32) -> Self {
|
||||||
Self::new_with_wgpu_handle(size, feature_size, None)
|
Self::new_with_wgpu_handle(size, feature_size, None)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn new_with_wgpu_handle(size: Index, feature_size: f32, wgpu: Option<Arc<WgpuData>>) -> Self {
|
fn new_with_wgpu_handle(size: Index, feature_size: f32, wgpu: Option<Arc<WgpuData<M>>>) -> Self {
|
||||||
info!("SpirvSim::new({:?}, {})", size, feature_size);
|
info!("SpirvSim::new({:?}, {})", size, feature_size);
|
||||||
let flat_size = size.volume() as usize;
|
let flat_size = size.volume() as usize;
|
||||||
if flat_size * std::mem::size_of::<M>() >= 0x40000000 {
|
if flat_size * std::mem::size_of::<M>() >= 0x40000000 {
|
||||||
|
Reference in New Issue
Block a user