Skip to content

Commit

Permalink
Override-expressions for Fixed Size Arrays (#6654)
Browse files Browse the repository at this point in the history
  • Loading branch information
kentslaney authored Dec 10, 2024
1 parent 8b93a71 commit 5bec461
Show file tree
Hide file tree
Showing 28 changed files with 557 additions and 103 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ By @ErichDonGubler in [#6456](https://github.com/gfx-rs/wgpu/pull/6456), [#6148]
- Allow for override-expressions in `workgroup_size`. By @KentSlaney in [#6635](https://github.com/gfx-rs/wgpu/pull/6635).
- Add support for OpAtomicCompareExchange in SPIR-V frontend. By @schell in [#6590](https://github.com/gfx-rs/wgpu/pull/6590).
- Implement type inference for abstract arguments to user-defined functions. By @jamienicol in [#6577](https://github.com/gfx-rs/wgpu/pull/6577).
- Allow for override-expressions in array sizes. By @KentSlaney in [#6654](https://github.com/gfx-rs/wgpu/pull/6654).

#### General

Expand Down
2 changes: 2 additions & 0 deletions naga/src/back/glsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -978,6 +978,7 @@ impl<'a, W: Write> Writer<'a, W> {
crate::ArraySize::Constant(size) => {
write!(self.out, "{size}")?;
}
crate::ArraySize::Pending(_) => unreachable!(),
crate::ArraySize::Dynamic => (),
}

Expand Down Expand Up @@ -4459,6 +4460,7 @@ impl<'a, W: Write> Writer<'a, W> {
.expect("Bad array size")
{
proc::IndexableLength::Known(count) => count,
proc::IndexableLength::Pending => unreachable!(),
proc::IndexableLength::Dynamic => return Ok(()),
};
self.write_type(base)?;
Expand Down
1 change: 1 addition & 0 deletions naga/src/back/hlsl/conv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ impl crate::TypeInner {
let count = match size {
crate::ArraySize::Constant(size) => size.get(),
// A dynamically-sized array has to have at least one element
crate::ArraySize::Pending(_) => unreachable!(),
crate::ArraySize::Dynamic => 1,
};
let last_el_size = gctx.types[base].inner.size_hlsl(gctx);
Expand Down
2 changes: 2 additions & 0 deletions naga/src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -988,6 +988,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
crate::ArraySize::Constant(size) => {
write!(self.out, "{size}")?;
}
crate::ArraySize::Pending(_) => unreachable!(),
crate::ArraySize::Dynamic => unreachable!(),
}

Expand Down Expand Up @@ -2634,6 +2635,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
index::IndexableLength::Known(limit) => {
write!(self.out, "{}u", limit - 1)?;
}
index::IndexableLength::Pending => unreachable!(),
index::IndexableLength::Dynamic => unreachable!(),
}
write!(self.out, ")")?;
Expand Down
6 changes: 6 additions & 0 deletions naga/src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2407,6 +2407,7 @@ impl<W: Write> Writer<W> {
self.out.write_str(") < ")?;
match length {
index::IndexableLength::Known(value) => write!(self.out, "{value}")?,
index::IndexableLength::Pending => unreachable!(),
index::IndexableLength::Dynamic => {
let global =
context.function.originating_global(base).ok_or_else(|| {
Expand Down Expand Up @@ -2569,6 +2570,7 @@ impl<W: Write> Writer<W> {
index::IndexableLength::Known(limit) => {
write!(self.out, "{}u", limit - 1)?;
}
index::IndexableLength::Pending => unreachable!(),
index::IndexableLength::Dynamic => {
let global = context.function.originating_global(base).ok_or_else(|| {
Error::GenericValidation("Could not find originating global".into())
Expand Down Expand Up @@ -3740,6 +3742,9 @@ impl<W: Write> Writer<W> {
)?;
writeln!(self.out, "}};")?;
}
crate::ArraySize::Pending(_) => {
unreachable!()
}
crate::ArraySize::Dynamic => {
writeln!(self.out, "typedef {base_name} {name}[1];")?;
}
Expand Down Expand Up @@ -6008,6 +6013,7 @@ mod workgroup_mem_init {
let count = match size.to_indexable_length(module).expect("Bad array size")
{
proc::IndexableLength::Known(count) => count,
proc::IndexableLength::Pending => unreachable!(),
proc::IndexableLength::Dynamic => unreachable!(),
};

Expand Down
60 changes: 60 additions & 0 deletions naga/src/back/pipeline_constants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,8 @@ pub fn process_overrides<'a>(
}
module.entry_points = entry_points;

process_pending(&mut module, &override_map, &adjusted_global_expressions)?;

// Now that we've rewritten all the expressions, we need to
// recompute their types and other metadata. For the time being,
// do a full re-validation.
Expand All @@ -205,6 +207,64 @@ pub fn process_overrides<'a>(
Ok((Cow::Owned(module), Cow::Owned(module_info)))
}

fn process_pending(
module: &mut Module,
override_map: &HandleVec<Override, Handle<Constant>>,
adjusted_global_expressions: &HandleVec<Expression, Handle<Expression>>,
) -> Result<(), PipelineConstantError> {
for (handle, ty) in module.types.clone().iter() {
if let crate::TypeInner::Array {
base,
size: crate::ArraySize::Pending(size),
stride,
} = ty.inner
{
let expr = match size {
crate::PendingArraySize::Expression(size_expr) => {
adjusted_global_expressions[size_expr]
}
crate::PendingArraySize::Override(size_override) => {
module.constants[override_map[size_override]].init
}
};
let value = module
.to_ctx()
.eval_expr_to_u32(expr)
.map(|n| {
if n == 0 {
Err(PipelineConstantError::ValidationError(
WithSpan::new(ValidationError::ArraySizeError { handle: expr })
.with_span(
module.global_expressions.get_span(expr),
"evaluated to zero",
),
))
} else {
Ok(std::num::NonZeroU32::new(n).unwrap())
}
})
.map_err(|_| {
PipelineConstantError::ValidationError(
WithSpan::new(ValidationError::ArraySizeError { handle: expr })
.with_span(module.global_expressions.get_span(expr), "negative"),
)
})??;
module.types.replace(
handle,
crate::Type {
name: None,
inner: crate::TypeInner::Array {
base,
size: crate::ArraySize::Constant(value),
stride,
},
},
);
}
}
Ok(())
}

fn process_workgroup_size_override(
module: &mut Module,
adjusted_global_expressions: &HandleVec<Expression, Handle<Expression>>,
Expand Down
3 changes: 3 additions & 0 deletions naga/src/back/spv/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,9 @@ impl BlockContext<'_> {
Ok(crate::proc::IndexableLength::Known(known_length)) => {
Ok(MaybeKnown::Known(known_length))
}
Ok(crate::proc::IndexableLength::Pending) => {
unreachable!()
}
Ok(crate::proc::IndexableLength::Dynamic) => {
let length_id = self.write_runtime_array_length(sequence, block)?;
Ok(MaybeKnown::Computed(length_id))
Expand Down
2 changes: 2 additions & 0 deletions naga/src/back/spv/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -971,6 +971,7 @@ impl Writer {
let length_id = self.get_index_constant(length.get());
Instruction::type_array(id, type_id, length_id)
}
crate::ArraySize::Pending(_) => unreachable!(),
crate::ArraySize::Dynamic => Instruction::type_runtime_array(id, type_id),
}
}
Expand All @@ -981,6 +982,7 @@ impl Writer {
let length_id = self.get_index_constant(length.get());
Instruction::type_array(id, type_id, length_id)
}
crate::ArraySize::Pending(_) => unreachable!(),
crate::ArraySize::Dynamic => Instruction::type_runtime_array(id, type_id),
}
}
Expand Down
6 changes: 6 additions & 0 deletions naga/src/back/wgsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,9 @@ impl<W: Write> Writer<W> {
self.write_type(module, base)?;
write!(self.out, ", {len}")?;
}
crate::ArraySize::Pending(_) => {
unreachable!();
}
crate::ArraySize::Dynamic => {
self.write_type(module, base)?;
}
Expand All @@ -534,6 +537,9 @@ impl<W: Write> Writer<W> {
self.write_type(module, base)?;
write!(self.out, ", {len}")?;
}
crate::ArraySize::Pending(_) => {
unreachable!();
}
crate::ArraySize::Dynamic => {
self.write_type(module, base)?;
}
Expand Down
34 changes: 34 additions & 0 deletions naga/src/compact/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,16 @@ pub fn compact(module: &mut crate::Module) {
}
}

for (_, ty) in module.types.iter() {
if let crate::TypeInner::Array {
size: crate::ArraySize::Pending(crate::PendingArraySize::Expression(size_expr)),
..
} = ty.inner
{
module_tracer.global_expressions_used.insert(size_expr);
}
}

for e in module.entry_points.iter() {
if let Some(sizes) = e.workgroup_size_overrides {
for size in sizes.iter().filter_map(|x| *x) {
Expand Down Expand Up @@ -206,6 +216,30 @@ pub fn compact(module: &mut crate::Module) {
}
}

for (handle, ty) in module.types.clone().iter() {
if let crate::TypeInner::Array {
base,
size: crate::ArraySize::Pending(crate::PendingArraySize::Expression(mut size_expr)),
stride,
} = ty.inner
{
module_map.global_expressions.adjust(&mut size_expr);
module.types.replace(
handle,
crate::Type {
name: None,
inner: crate::TypeInner::Array {
base,
size: crate::ArraySize::Pending(crate::PendingArraySize::Expression(
size_expr,
)),
stride,
},
},
);
}
}

// Temporary storage to help us reuse allocations of existing
// named expression tables.
let mut reused_named_expressions = crate::NamedExpressions::default();
Expand Down
1 change: 1 addition & 0 deletions naga/src/front/glsl/offset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ pub fn calculate_offset(

let span = match size {
crate::ArraySize::Constant(size) => size.get() * stride,
crate::ArraySize::Pending(_) => unreachable!(),
crate::ArraySize::Dynamic => stride,
};

Expand Down
3 changes: 3 additions & 0 deletions naga/src/front/spv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1329,6 +1329,9 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
crate::TypeInner::Array { size, .. } => {
let size = match size {
crate::ArraySize::Constant(size) => size.get(),
crate::ArraySize::Pending(_) => {
unreachable!();
}
// A runtime sized array is not a composite type
crate::ArraySize::Dynamic => {
return Err(Error::InvalidAccessType(root_type_id))
Expand Down
71 changes: 57 additions & 14 deletions naga/src/front/wgsl/lower/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3056,26 +3056,69 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
Ok(match size {
ast::ArraySize::Constant(expr) => {
let span = ctx.ast_expressions.get_span(expr);
let const_expr = self.expression(expr, &mut ctx.as_const())?;
let len =
ctx.module
.to_ctx()
.eval_expr_to_u32(const_expr)
.map_err(|err| match err {
crate::proc::U32EvalError::NonConst => {
Error::ExpectedConstExprConcreteIntegerScalar(span)
}
crate::proc::U32EvalError::Negative => {
Error::ExpectedPositiveArrayLength(span)
let const_expr = self.expression(expr, &mut ctx.as_const());
match const_expr {
Ok(value) => {
let len =
ctx.module.to_ctx().eval_expr_to_u32(value).map_err(
|err| match err {
crate::proc::U32EvalError::NonConst => {
Error::ExpectedConstExprConcreteIntegerScalar(span)
}
crate::proc::U32EvalError::Negative => {
Error::ExpectedPositiveArrayLength(span)
}
},
)?;
let size =
NonZeroU32::new(len).ok_or(Error::ExpectedPositiveArrayLength(span))?;
crate::ArraySize::Constant(size)
}
err => {
if let Err(Error::ConstantEvaluatorError(ref ty, _)) = err {
match **ty {
crate::proc::ConstantEvaluatorError::OverrideExpr => {
crate::ArraySize::Pending(self.array_size_override(
expr,
&mut ctx.as_override(),
span,
)?)
}
_ => {
err?;
unreachable!()
}
}
})?;
let size = NonZeroU32::new(len).ok_or(Error::ExpectedPositiveArrayLength(span))?;
crate::ArraySize::Constant(size)
} else {
err?;
unreachable!()
}
}
}
}
ast::ArraySize::Dynamic => crate::ArraySize::Dynamic,
})
}

fn array_size_override(
&mut self,
size_expr: Handle<ast::Expression<'source>>,
ctx: &mut ExpressionContext<'source, '_, '_>,
span: Span,
) -> Result<crate::PendingArraySize, Error<'source>> {
let expr = self.expression(size_expr, ctx)?;
match resolve_inner!(ctx, expr).scalar_kind().ok_or(0) {
Ok(crate::ScalarKind::Sint) | Ok(crate::ScalarKind::Uint) => Ok({
if let crate::Expression::Override(handle) = ctx.module.global_expressions[expr] {
crate::PendingArraySize::Override(handle)
} else {
crate::PendingArraySize::Expression(expr)
}
}),
_ => Err(Error::ExpectedConstExprConcreteIntegerScalar(span)),
}
}

/// Build the Naga equivalent of a named AST type.
///
/// Return a Naga `Handle<Type>` representing the front-end type
Expand Down
Loading

0 comments on commit 5bec461

Please sign in to comment.