From 080d99b5bf905dd002049f131bb7a462b2f1fef2 Mon Sep 17 00:00:00 2001 From: David Ellis Date: Tue, 5 Nov 2024 07:21:33 -0600 Subject: [PATCH] Fix Issue 6467: Prevent MacOS crash on invalid workgroup size definition (#6494) --- tests/tests/regression/issue_6467.rs | 75 ++++++++++++++++++++++++++++ tests/tests/root.rs | 1 + wgpu-hal/src/metal/command.rs | 16 +++--- 3 files changed, 85 insertions(+), 7 deletions(-) create mode 100644 tests/tests/regression/issue_6467.rs diff --git a/tests/tests/regression/issue_6467.rs b/tests/tests/regression/issue_6467.rs new file mode 100644 index 0000000000..da40a791d3 --- /dev/null +++ b/tests/tests/regression/issue_6467.rs @@ -0,0 +1,75 @@ +use wgpu::util::DeviceExt; +use wgpu_test::{gpu_test, GpuTestConfiguration, TestParameters}; + +/// Running a compute shader with one or more of the workgroup sizes set to 0 implies that no work +/// should be done, and is a user error. Vulkan and DX12 accept this invalid input with grace, but +/// Metal does not guard against this and eventually the machine will crash. Since this is a public +/// API that may be given untrusted values in a browser, this must be protected again. +/// +/// The following test should successfully do nothing on all platforms. +#[gpu_test] +static ZERO_WORKGROUP_SIZE: GpuTestConfiguration = GpuTestConfiguration::new() + .parameters(TestParameters::default().limits(wgpu::Limits::default())) + .run_async(|ctx| async move { + let module = ctx + .device + .create_shader_module(wgpu::ShaderModuleDescriptor { + label: None, + source: wgpu::ShaderSource::Wgsl(std::borrow::Cow::Borrowed( + " + @group(0) + @binding(0) + var vals: array; + + @compute + @workgroup_size(1) + fn main(@builtin(global_invocation_id) id: vec3u) { + vals[id.x] = vals[id.x] * i32(id.x); + } + ", + )), + }); + let compute_pipeline = + ctx.device + .create_compute_pipeline(&wgpu::ComputePipelineDescriptor { + label: None, + layout: None, + module: &module, + entry_point: Some("main"), + compilation_options: wgpu::PipelineCompilationOptions::default(), + cache: None, + }); + let buffer = DeviceExt::create_buffer_init( + &ctx.device, + &wgpu::util::BufferInitDescriptor { + label: None, + contents: &[1, 1, 1, 1, 1, 1, 1, 1], + usage: wgpu::BufferUsages::STORAGE + | wgpu::BufferUsages::COPY_DST + | wgpu::BufferUsages::COPY_SRC, + }, + ); + let mut encoder = ctx + .device + .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None }); + { + let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { + label: None, + timestamp_writes: None, + }); + cpass.set_pipeline(&compute_pipeline); + let bind_group_layout = compute_pipeline.get_bind_group_layout(0); + let bind_group_entries = [wgpu::BindGroupEntry { + binding: 0, + resource: buffer.as_entire_binding(), + }]; + let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor { + label: None, + layout: &bind_group_layout, + entries: &bind_group_entries, + }); + cpass.set_bind_group(0, &bind_group, &[]); + cpass.dispatch_workgroups(1, 0, 1); + } + ctx.queue.submit(Some(encoder.finish())); + }); diff --git a/tests/tests/root.rs b/tests/tests/root.rs index 7312b3c2ec..5b70d2053f 100644 --- a/tests/tests/root.rs +++ b/tests/tests/root.rs @@ -7,6 +7,7 @@ mod regression { mod issue_4514; mod issue_5553; mod issue_6317; + mod issue_6467; } mod bgra8unorm_storage; diff --git a/wgpu-hal/src/metal/command.rs b/wgpu-hal/src/metal/command.rs index 9e38cf8656..f113639a13 100644 --- a/wgpu-hal/src/metal/command.rs +++ b/wgpu-hal/src/metal/command.rs @@ -1240,13 +1240,15 @@ impl crate::CommandEncoder for super::CommandEncoder { } unsafe fn dispatch(&mut self, count: [u32; 3]) { - let encoder = self.state.compute.as_ref().unwrap(); - let raw_count = metal::MTLSize { - width: count[0] as u64, - height: count[1] as u64, - depth: count[2] as u64, - }; - encoder.dispatch_thread_groups(raw_count, self.state.raw_wg_size); + if count[0] > 0 && count[1] > 0 && count[2] > 0 { + let encoder = self.state.compute.as_ref().unwrap(); + let raw_count = metal::MTLSize { + width: count[0] as u64, + height: count[1] as u64, + depth: count[2] as u64, + }; + encoder.dispatch_thread_groups(raw_count, self.state.raw_wg_size); + } } unsafe fn dispatch_indirect(&mut self, buffer: &super::Buffer, offset: wgt::BufferAddress) {