diff --git a/tools/clang/lib/SPIRV/AlignmentSizeCalculator.cpp b/tools/clang/lib/SPIRV/AlignmentSizeCalculator.cpp index 318827cdb7..492640c493 100644 --- a/tools/clang/lib/SPIRV/AlignmentSizeCalculator.cpp +++ b/tools/clang/lib/SPIRV/AlignmentSizeCalculator.cpp @@ -22,6 +22,15 @@ inline uint32_t roundToPow2(uint32_t val, uint32_t pow2) { return (val + pow2 - 1) & ~(pow2 - 1); } +/// Returns the smallest value greater than or equal to |val| that is a multiple +/// of |multiple|. +inline uint32_t roundToMultiple(uint32_t val, uint32_t multiple) { + if (val == 0) + return 0; + uint32_t t = (val - 1) / multiple; + return (multiple * (t + 1)); +} + /// Returns true if the given vector type (of the given size) crosses the /// 4-component vector boundary if placed at the given offset. bool improperStraddle(clang::QualType type, int size, int offset) { @@ -411,7 +420,7 @@ std::pair AlignmentSizeCalculator::getAlignmentAndSize( if (rule == SpirvLayoutRule::FxcSBuffer || rule == SpirvLayoutRule::Scalar) { - *stride = size; + *stride = roundToMultiple(size, alignment); // Use element alignment for fxc structured buffers and // VK_EXT_scalar_block_layout return {alignment, size * elemCount}; diff --git a/tools/clang/test/CodeGenSPIRV/array.scalar.layout.hlsl b/tools/clang/test/CodeGenSPIRV/array.scalar.layout.hlsl new file mode 100644 index 0000000000..44c05bf6c1 --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/array.scalar.layout.hlsl @@ -0,0 +1,21 @@ +// RUN: %dxc -T cs_6_2 -E main %s -fvk-use-scalar-layout -spirv | FileCheck %s + +// Check that the array stride and offsets are corrects. The uint64_t has alignment +// 8 and the struct has size 12. So the stride should be the smallest multiple of 8 +// greater than or equal to 12, which is 16. + +// CHECK-DAG: OpMemberDecorate %Data 0 Offset 0 +// CHECK-DAG: OpMemberDecorate %Data 1 Offset 8 +// CHECK-DAG: OpDecorate %_runtimearr_Data ArrayStride 16 +// CHECK-DAG: OpMemberDecorate %type_RWStructuredBuffer_Data 0 Offset 0 +struct Data { + uint64_t y; + uint x; +}; +RWStructuredBuffer buffer; + +[numthreads(1, 1, 1)] +void main() +{ + buffer[0].x = 5; +}