diff --git a/docs/DXIL.rst b/docs/DXIL.rst index 6cca862d64..c3baf4e454 100644 --- a/docs/DXIL.rst +++ b/docs/DXIL.rst @@ -2337,18 +2337,18 @@ ID Name Description 223 TextureGatherRaw Gather raw elements from 4 texels with no type conversions (SRV type is constrained) 224 SampleCmpLevel samples a texture and compares a single component against the specified comparison value 225 TextureStoreSample stores texel data at specified sample index -226 WaveMatrix_Annotate Annotate a wave matrix pointer with the type information -227 WaveMatrix_Depth Returns depth (K) value for matrix of specified type -228 WaveMatrix_Fill Fill wave matrix with scalar value -229 WaveMatrix_LoadRawBuf Load wave matrix from raw buffer -230 WaveMatrix_LoadGroupShared Load wave matrix from group shared array -231 WaveMatrix_StoreRawBuf Store wave matrix to raw buffer -232 WaveMatrix_StoreGroupShared Store wave matrix to group shared array -233 WaveMatrix_Multiply Mutiply left and right wave matrix and store in accumulator -234 WaveMatrix_MultiplyAccumulate Mutiply left and right wave matrix and accumulate into accumulator -235 WaveMatrix_ScalarOp Perform scalar operation on each element of wave matrix -236 WaveMatrix_SumAccumulate Sum rows or columns of an input matrix into an existing accumulator fragment matrix -237 WaveMatrix_Add Element-wise accumulate, or broadcast add of fragment into accumulator +226 Reserved0 Reserved +227 Reserved1 Reserved +228 Reserved2 Reserved +229 Reserved3 Reserved +230 Reserved4 Reserved +231 Reserved5 Reserved +232 Reserved6 Reserved +233 Reserved7 Reserved +234 Reserved8 Reserved +235 Reserved9 Reserved +236 Reserved10 Reserved +237 Reserved11 Reserved 238 AllocateNodeOutputRecords returns a handle for the output records 239 GetNodeRecordPtr retrieve node input/output record pointer in address space 6 240 IncrementOutputCount Select the next logical output count for an EmptyNodeOutput for the whole group or per thread. diff --git a/include/dxc/DXIL/DxilConstants.h b/include/dxc/DXIL/DxilConstants.h index 85e6fc1c15..2449915754 100644 --- a/include/dxc/DXIL/DxilConstants.h +++ b/include/dxc/DXIL/DxilConstants.h @@ -474,6 +474,20 @@ inline bool IsFeedbackTexture(DXIL::ResourceKind ResourceKind) { // OPCODE-ENUM:BEGIN // Enumeration for operations specified by DXIL enum class OpCode : unsigned { + // + Reserved0 = 226, // Reserved + Reserved1 = 227, // Reserved + Reserved10 = 236, // Reserved + Reserved11 = 237, // Reserved + Reserved2 = 228, // Reserved + Reserved3 = 229, // Reserved + Reserved4 = 230, // Reserved + Reserved5 = 231, // Reserved + Reserved6 = 232, // Reserved + Reserved7 = 233, // Reserved + Reserved8 = 234, // Reserved + Reserved9 = 235, // Reserved + // Amplification shader instructions DispatchMesh = 173, // Amplification shader intrinsic DispatchMesh @@ -946,27 +960,6 @@ enum class OpCode : unsigned { WaveReadLaneAt = 117, // returns the value from the specified lane WaveReadLaneFirst = 118, // returns the value from the first lane - // WaveMatrix - WaveMatrix_Add = 237, // Element-wise accumulate, or broadcast add of fragment - // into accumulator - WaveMatrix_Annotate = - 226, // Annotate a wave matrix pointer with the type information - WaveMatrix_Depth = - 227, // Returns depth (K) value for matrix of specified type - WaveMatrix_Fill = 228, // Fill wave matrix with scalar value - WaveMatrix_LoadGroupShared = 230, // Load wave matrix from group shared array - WaveMatrix_LoadRawBuf = 229, // Load wave matrix from raw buffer - WaveMatrix_Multiply = - 233, // Mutiply left and right wave matrix and store in accumulator - WaveMatrix_MultiplyAccumulate = - 234, // Mutiply left and right wave matrix and accumulate into accumulator - WaveMatrix_ScalarOp = - 235, // Perform scalar operation on each element of wave matrix - WaveMatrix_StoreGroupShared = 232, // Store wave matrix to group shared array - WaveMatrix_StoreRawBuf = 231, // Store wave matrix to raw buffer - WaveMatrix_SumAccumulate = 236, // Sum rows or columns of an input matrix into - // an existing accumulator fragment matrix - // Work Graph intrinsics FinishedCrossGroupSharing = 243, // returns true if the current thread group // is the last to access the input @@ -1003,6 +996,9 @@ enum class OpCode : unsigned { // OPCODECLASS-ENUM:BEGIN // Groups for DXIL operations with equivalent function templates enum class OpCodeClass : unsigned { + // + Reserved, + // Amplification shader instructions DispatchMesh, @@ -1278,18 +1274,6 @@ enum class OpCodeClass : unsigned { WaveReadLaneAt, WaveReadLaneFirst, - // WaveMatrix - WaveMatrix_Accumulate, - WaveMatrix_Annotate, - WaveMatrix_Depth, - WaveMatrix_Fill, - WaveMatrix_LoadGroupShared, - WaveMatrix_LoadRawBuf, - WaveMatrix_Multiply, - WaveMatrix_ScalarOp, - WaveMatrix_StoreGroupShared, - WaveMatrix_StoreRawBuf, - // Work Graph intrinsics FinishedCrossGroupSharing, GetInputRecordCount, @@ -1306,9 +1290,9 @@ enum class OpCodeClass : unsigned { NumOpClasses_Dxil_1_5 = 143, NumOpClasses_Dxil_1_6 = 149, NumOpClasses_Dxil_1_7 = 153, - NumOpClasses_Dxil_1_8 = 183, + NumOpClasses_Dxil_1_8 = 174, - NumOpClasses = 183 // exclusive last value of enumeration + NumOpClasses = 174 // exclusive last value of enumeration }; // OPCODECLASS-ENUM:END @@ -1817,29 +1801,6 @@ enum class SamplerFeedbackType : uint8_t { LastEntry = 2 }; -enum class WaveMatrixKind : uint8_t { - Left = 0, - Right = 1, - LeftColAcc = 2, - RightRowAcc = 3, - Accumulator = 4, - NumKinds = 5, - MaskSide = 1, - MaskClass = 6, // 0 = Left/Right, 2 = Fragment, 4 = Accumulator -}; - -/* hctdb_instrhelp.get_enum_decl("WaveMatrixScalarOpCode")*/ -// WAVEMATRIXSCALAROPCODE-ENUM:BEGIN -// Operation for WaveMatrix_ScalarOp -enum class WaveMatrixScalarOpCode : unsigned { - Add = 0, - Divide = 3, - Invalid = 4, - Multiply = 2, - Subtract = 1, -}; -// WAVEMATRIXSCALAROPCODE-ENUM:END - // Corresponds to MEMORY_TYPE_FLAG enums in HLSL enum class MemoryTypeFlag : uint32_t { UavMemory = 0x00000001, // UAV_MEMORY @@ -1922,8 +1883,7 @@ const uint64_t ShaderFeatureInfo_SampleCmpGradientOrBias = 0x80000000; const uint64_t ShaderFeatureInfo_ExtendedCommandInfo = 0x100000000; // Experimental SM 6.9+ - Reserved, not yet supported. -// WaveMMA slots in between two SM 6.6 feature bits. -const uint64_t ShaderFeatureInfo_WaveMMA = 0x8000000; +const uint64_t ShaderFeatureInfo_Reserved = 0x8000000; // Maximum count without rolling over into another 64-bit field is 40, // so the last flag we can use for a feature requirement is: 0x8000000000 diff --git a/include/dxc/DXIL/DxilInstructions.h b/include/dxc/DXIL/DxilInstructions.h index 7da9e58af6..f5d8759db7 100644 --- a/include/dxc/DXIL/DxilInstructions.h +++ b/include/dxc/DXIL/DxilInstructions.h @@ -8042,500 +8042,6 @@ struct DxilInst_TextureStoreSample { void set_sampleIdx(llvm::Value *val) { Instr->setOperand(10, val); } }; -/// This instruction Annotate a wave matrix pointer with the type information -struct DxilInst_WaveMatrix_Annotate { - llvm::Instruction *Instr; - // Construction and identification - DxilInst_WaveMatrix_Annotate(llvm::Instruction *pInstr) : Instr(pInstr) {} - operator bool() const { - return hlsl::OP::IsDxilOpFuncCallInst( - Instr, hlsl::OP::OpCode::WaveMatrix_Annotate); - } - // Validation support - bool isAllowed() const { return true; } - bool isArgumentListValid() const { - if (3 != llvm::dyn_cast(Instr)->getNumArgOperands()) - return false; - return true; - } - // Metadata - bool requiresUniformInputs() const { return false; } - // Operand indexes - enum OperandIdx { - arg_waveMatrixPtr = 1, - arg_waveMatProps = 2, - }; - // Accessors - llvm::Value *get_waveMatrixPtr() const { return Instr->getOperand(1); } - void set_waveMatrixPtr(llvm::Value *val) { Instr->setOperand(1, val); } - llvm::Value *get_waveMatProps() const { return Instr->getOperand(2); } - void set_waveMatProps(llvm::Value *val) { Instr->setOperand(2, val); } -}; - -/// This instruction Returns depth (K) value for matrix of specified type -struct DxilInst_WaveMatrix_Depth { - llvm::Instruction *Instr; - // Construction and identification - DxilInst_WaveMatrix_Depth(llvm::Instruction *pInstr) : Instr(pInstr) {} - operator bool() const { - return hlsl::OP::IsDxilOpFuncCallInst(Instr, - hlsl::OP::OpCode::WaveMatrix_Depth); - } - // Validation support - bool isAllowed() const { return true; } - bool isArgumentListValid() const { - if (2 != llvm::dyn_cast(Instr)->getNumArgOperands()) - return false; - return true; - } - // Metadata - bool requiresUniformInputs() const { return false; } - // Operand indexes - enum OperandIdx { - arg_waveMatProps = 1, - }; - // Accessors - llvm::Value *get_waveMatProps() const { return Instr->getOperand(1); } - void set_waveMatProps(llvm::Value *val) { Instr->setOperand(1, val); } -}; - -/// This instruction Fill wave matrix with scalar value -struct DxilInst_WaveMatrix_Fill { - llvm::Instruction *Instr; - // Construction and identification - DxilInst_WaveMatrix_Fill(llvm::Instruction *pInstr) : Instr(pInstr) {} - operator bool() const { - return hlsl::OP::IsDxilOpFuncCallInst(Instr, - hlsl::OP::OpCode::WaveMatrix_Fill); - } - // Validation support - bool isAllowed() const { return true; } - bool isArgumentListValid() const { - if (3 != llvm::dyn_cast(Instr)->getNumArgOperands()) - return false; - return true; - } - // Metadata - bool requiresUniformInputs() const { return false; } - // Operand indexes - enum OperandIdx { - arg_waveMatrixPtr = 1, - arg_value = 2, - }; - // Accessors - llvm::Value *get_waveMatrixPtr() const { return Instr->getOperand(1); } - void set_waveMatrixPtr(llvm::Value *val) { Instr->setOperand(1, val); } - llvm::Value *get_value() const { return Instr->getOperand(2); } - void set_value(llvm::Value *val) { Instr->setOperand(2, val); } -}; - -/// This instruction Load wave matrix from raw buffer -struct DxilInst_WaveMatrix_LoadRawBuf { - llvm::Instruction *Instr; - // Construction and identification - DxilInst_WaveMatrix_LoadRawBuf(llvm::Instruction *pInstr) : Instr(pInstr) {} - operator bool() const { - return hlsl::OP::IsDxilOpFuncCallInst( - Instr, hlsl::OP::OpCode::WaveMatrix_LoadRawBuf); - } - // Validation support - bool isAllowed() const { return true; } - bool isArgumentListValid() const { - if (7 != llvm::dyn_cast(Instr)->getNumArgOperands()) - return false; - return true; - } - // Metadata - bool requiresUniformInputs() const { return false; } - // Operand indexes - enum OperandIdx { - arg_waveMatrixPtr = 1, - arg_rawBuf = 2, - arg_offsetInBytes = 3, - arg_strideInBytes = 4, - arg_alignmentInBytes = 5, - arg_colMajor = 6, - }; - // Accessors - llvm::Value *get_waveMatrixPtr() const { return Instr->getOperand(1); } - void set_waveMatrixPtr(llvm::Value *val) { Instr->setOperand(1, val); } - llvm::Value *get_rawBuf() const { return Instr->getOperand(2); } - void set_rawBuf(llvm::Value *val) { Instr->setOperand(2, val); } - llvm::Value *get_offsetInBytes() const { return Instr->getOperand(3); } - void set_offsetInBytes(llvm::Value *val) { Instr->setOperand(3, val); } - llvm::Value *get_strideInBytes() const { return Instr->getOperand(4); } - void set_strideInBytes(llvm::Value *val) { Instr->setOperand(4, val); } - llvm::Value *get_alignmentInBytes() const { return Instr->getOperand(5); } - void set_alignmentInBytes(llvm::Value *val) { Instr->setOperand(5, val); } - int8_t get_alignmentInBytes_val() const { - return (int8_t)(llvm::dyn_cast(Instr->getOperand(5)) - ->getZExtValue()); - } - void set_alignmentInBytes_val(int8_t val) { - Instr->setOperand(5, llvm::Constant::getIntegerValue( - llvm::IntegerType::get(Instr->getContext(), 8), - llvm::APInt(8, (uint64_t)val))); - } - llvm::Value *get_colMajor() const { return Instr->getOperand(6); } - void set_colMajor(llvm::Value *val) { Instr->setOperand(6, val); } - bool get_colMajor_val() const { - return (bool)(llvm::dyn_cast(Instr->getOperand(6)) - ->getZExtValue()); - } - void set_colMajor_val(bool val) { - Instr->setOperand(6, llvm::Constant::getIntegerValue( - llvm::IntegerType::get(Instr->getContext(), 1), - llvm::APInt(1, (uint64_t)val))); - } -}; - -/// This instruction Load wave matrix from group shared array -struct DxilInst_WaveMatrix_LoadGroupShared { - llvm::Instruction *Instr; - // Construction and identification - DxilInst_WaveMatrix_LoadGroupShared(llvm::Instruction *pInstr) - : Instr(pInstr) {} - operator bool() const { - return hlsl::OP::IsDxilOpFuncCallInst( - Instr, hlsl::OP::OpCode::WaveMatrix_LoadGroupShared); - } - // Validation support - bool isAllowed() const { return true; } - bool isArgumentListValid() const { - if (6 != llvm::dyn_cast(Instr)->getNumArgOperands()) - return false; - return true; - } - // Metadata - bool requiresUniformInputs() const { return false; } - // Operand indexes - enum OperandIdx { - arg_waveMatrixPtr = 1, - arg_groupsharedPtr = 2, - arg_startArrayIndex = 3, - arg_strideInElements = 4, - arg_colMajor = 5, - }; - // Accessors - llvm::Value *get_waveMatrixPtr() const { return Instr->getOperand(1); } - void set_waveMatrixPtr(llvm::Value *val) { Instr->setOperand(1, val); } - llvm::Value *get_groupsharedPtr() const { return Instr->getOperand(2); } - void set_groupsharedPtr(llvm::Value *val) { Instr->setOperand(2, val); } - llvm::Value *get_startArrayIndex() const { return Instr->getOperand(3); } - void set_startArrayIndex(llvm::Value *val) { Instr->setOperand(3, val); } - llvm::Value *get_strideInElements() const { return Instr->getOperand(4); } - void set_strideInElements(llvm::Value *val) { Instr->setOperand(4, val); } - llvm::Value *get_colMajor() const { return Instr->getOperand(5); } - void set_colMajor(llvm::Value *val) { Instr->setOperand(5, val); } - bool get_colMajor_val() const { - return (bool)(llvm::dyn_cast(Instr->getOperand(5)) - ->getZExtValue()); - } - void set_colMajor_val(bool val) { - Instr->setOperand(5, llvm::Constant::getIntegerValue( - llvm::IntegerType::get(Instr->getContext(), 1), - llvm::APInt(1, (uint64_t)val))); - } -}; - -/// This instruction Store wave matrix to raw buffer -struct DxilInst_WaveMatrix_StoreRawBuf { - llvm::Instruction *Instr; - // Construction and identification - DxilInst_WaveMatrix_StoreRawBuf(llvm::Instruction *pInstr) : Instr(pInstr) {} - operator bool() const { - return hlsl::OP::IsDxilOpFuncCallInst( - Instr, hlsl::OP::OpCode::WaveMatrix_StoreRawBuf); - } - // Validation support - bool isAllowed() const { return true; } - bool isArgumentListValid() const { - if (7 != llvm::dyn_cast(Instr)->getNumArgOperands()) - return false; - return true; - } - // Metadata - bool requiresUniformInputs() const { return false; } - // Operand indexes - enum OperandIdx { - arg_waveMatrixPtr = 1, - arg_rawBuf = 2, - arg_offsetInBytes = 3, - arg_strideInBytes = 4, - arg_alignmentInBytes = 5, - arg_colMajor = 6, - }; - // Accessors - llvm::Value *get_waveMatrixPtr() const { return Instr->getOperand(1); } - void set_waveMatrixPtr(llvm::Value *val) { Instr->setOperand(1, val); } - llvm::Value *get_rawBuf() const { return Instr->getOperand(2); } - void set_rawBuf(llvm::Value *val) { Instr->setOperand(2, val); } - llvm::Value *get_offsetInBytes() const { return Instr->getOperand(3); } - void set_offsetInBytes(llvm::Value *val) { Instr->setOperand(3, val); } - llvm::Value *get_strideInBytes() const { return Instr->getOperand(4); } - void set_strideInBytes(llvm::Value *val) { Instr->setOperand(4, val); } - llvm::Value *get_alignmentInBytes() const { return Instr->getOperand(5); } - void set_alignmentInBytes(llvm::Value *val) { Instr->setOperand(5, val); } - int8_t get_alignmentInBytes_val() const { - return (int8_t)(llvm::dyn_cast(Instr->getOperand(5)) - ->getZExtValue()); - } - void set_alignmentInBytes_val(int8_t val) { - Instr->setOperand(5, llvm::Constant::getIntegerValue( - llvm::IntegerType::get(Instr->getContext(), 8), - llvm::APInt(8, (uint64_t)val))); - } - llvm::Value *get_colMajor() const { return Instr->getOperand(6); } - void set_colMajor(llvm::Value *val) { Instr->setOperand(6, val); } - bool get_colMajor_val() const { - return (bool)(llvm::dyn_cast(Instr->getOperand(6)) - ->getZExtValue()); - } - void set_colMajor_val(bool val) { - Instr->setOperand(6, llvm::Constant::getIntegerValue( - llvm::IntegerType::get(Instr->getContext(), 1), - llvm::APInt(1, (uint64_t)val))); - } -}; - -/// This instruction Store wave matrix to group shared array -struct DxilInst_WaveMatrix_StoreGroupShared { - llvm::Instruction *Instr; - // Construction and identification - DxilInst_WaveMatrix_StoreGroupShared(llvm::Instruction *pInstr) - : Instr(pInstr) {} - operator bool() const { - return hlsl::OP::IsDxilOpFuncCallInst( - Instr, hlsl::OP::OpCode::WaveMatrix_StoreGroupShared); - } - // Validation support - bool isAllowed() const { return true; } - bool isArgumentListValid() const { - if (6 != llvm::dyn_cast(Instr)->getNumArgOperands()) - return false; - return true; - } - // Metadata - bool requiresUniformInputs() const { return false; } - // Operand indexes - enum OperandIdx { - arg_waveMatrixPtr = 1, - arg_groupsharedPtr = 2, - arg_startArrayIndex = 3, - arg_strideInElements = 4, - arg_colMajor = 5, - }; - // Accessors - llvm::Value *get_waveMatrixPtr() const { return Instr->getOperand(1); } - void set_waveMatrixPtr(llvm::Value *val) { Instr->setOperand(1, val); } - llvm::Value *get_groupsharedPtr() const { return Instr->getOperand(2); } - void set_groupsharedPtr(llvm::Value *val) { Instr->setOperand(2, val); } - llvm::Value *get_startArrayIndex() const { return Instr->getOperand(3); } - void set_startArrayIndex(llvm::Value *val) { Instr->setOperand(3, val); } - llvm::Value *get_strideInElements() const { return Instr->getOperand(4); } - void set_strideInElements(llvm::Value *val) { Instr->setOperand(4, val); } - llvm::Value *get_colMajor() const { return Instr->getOperand(5); } - void set_colMajor(llvm::Value *val) { Instr->setOperand(5, val); } - bool get_colMajor_val() const { - return (bool)(llvm::dyn_cast(Instr->getOperand(5)) - ->getZExtValue()); - } - void set_colMajor_val(bool val) { - Instr->setOperand(5, llvm::Constant::getIntegerValue( - llvm::IntegerType::get(Instr->getContext(), 1), - llvm::APInt(1, (uint64_t)val))); - } -}; - -/// This instruction Mutiply left and right wave matrix and store in accumulator -struct DxilInst_WaveMatrix_Multiply { - llvm::Instruction *Instr; - // Construction and identification - DxilInst_WaveMatrix_Multiply(llvm::Instruction *pInstr) : Instr(pInstr) {} - operator bool() const { - return hlsl::OP::IsDxilOpFuncCallInst( - Instr, hlsl::OP::OpCode::WaveMatrix_Multiply); - } - // Validation support - bool isAllowed() const { return true; } - bool isArgumentListValid() const { - if (4 != llvm::dyn_cast(Instr)->getNumArgOperands()) - return false; - return true; - } - // Metadata - bool requiresUniformInputs() const { return false; } - // Operand indexes - enum OperandIdx { - arg_waveMatrixAccumulator = 1, - arg_waveMatrixLeft = 2, - arg_waveMatrixRight = 3, - }; - // Accessors - llvm::Value *get_waveMatrixAccumulator() const { - return Instr->getOperand(1); - } - void set_waveMatrixAccumulator(llvm::Value *val) { - Instr->setOperand(1, val); - } - llvm::Value *get_waveMatrixLeft() const { return Instr->getOperand(2); } - void set_waveMatrixLeft(llvm::Value *val) { Instr->setOperand(2, val); } - llvm::Value *get_waveMatrixRight() const { return Instr->getOperand(3); } - void set_waveMatrixRight(llvm::Value *val) { Instr->setOperand(3, val); } -}; - -/// This instruction Mutiply left and right wave matrix and accumulate into -/// accumulator -struct DxilInst_WaveMatrix_MultiplyAccumulate { - llvm::Instruction *Instr; - // Construction and identification - DxilInst_WaveMatrix_MultiplyAccumulate(llvm::Instruction *pInstr) - : Instr(pInstr) {} - operator bool() const { - return hlsl::OP::IsDxilOpFuncCallInst( - Instr, hlsl::OP::OpCode::WaveMatrix_MultiplyAccumulate); - } - // Validation support - bool isAllowed() const { return true; } - bool isArgumentListValid() const { - if (4 != llvm::dyn_cast(Instr)->getNumArgOperands()) - return false; - return true; - } - // Metadata - bool requiresUniformInputs() const { return false; } - // Operand indexes - enum OperandIdx { - arg_waveMatrixAccumulator = 1, - arg_waveMatrixLeft = 2, - arg_waveMatrixRight = 3, - }; - // Accessors - llvm::Value *get_waveMatrixAccumulator() const { - return Instr->getOperand(1); - } - void set_waveMatrixAccumulator(llvm::Value *val) { - Instr->setOperand(1, val); - } - llvm::Value *get_waveMatrixLeft() const { return Instr->getOperand(2); } - void set_waveMatrixLeft(llvm::Value *val) { Instr->setOperand(2, val); } - llvm::Value *get_waveMatrixRight() const { return Instr->getOperand(3); } - void set_waveMatrixRight(llvm::Value *val) { Instr->setOperand(3, val); } -}; - -/// This instruction Perform scalar operation on each element of wave matrix -struct DxilInst_WaveMatrix_ScalarOp { - llvm::Instruction *Instr; - // Construction and identification - DxilInst_WaveMatrix_ScalarOp(llvm::Instruction *pInstr) : Instr(pInstr) {} - operator bool() const { - return hlsl::OP::IsDxilOpFuncCallInst( - Instr, hlsl::OP::OpCode::WaveMatrix_ScalarOp); - } - // Validation support - bool isAllowed() const { return true; } - bool isArgumentListValid() const { - if (4 != llvm::dyn_cast(Instr)->getNumArgOperands()) - return false; - return true; - } - // Metadata - bool requiresUniformInputs() const { return false; } - // Operand indexes - enum OperandIdx { - arg_waveMatrixPtr = 1, - arg_op = 2, - arg_value = 3, - }; - // Accessors - llvm::Value *get_waveMatrixPtr() const { return Instr->getOperand(1); } - void set_waveMatrixPtr(llvm::Value *val) { Instr->setOperand(1, val); } - llvm::Value *get_op() const { return Instr->getOperand(2); } - void set_op(llvm::Value *val) { Instr->setOperand(2, val); } - int8_t get_op_val() const { - return (int8_t)(llvm::dyn_cast(Instr->getOperand(2)) - ->getZExtValue()); - } - void set_op_val(int8_t val) { - Instr->setOperand(2, llvm::Constant::getIntegerValue( - llvm::IntegerType::get(Instr->getContext(), 8), - llvm::APInt(8, (uint64_t)val))); - } - llvm::Value *get_value() const { return Instr->getOperand(3); } - void set_value(llvm::Value *val) { Instr->setOperand(3, val); } -}; - -/// This instruction Sum rows or columns of an input matrix into an existing -/// accumulator fragment matrix -struct DxilInst_WaveMatrix_SumAccumulate { - llvm::Instruction *Instr; - // Construction and identification - DxilInst_WaveMatrix_SumAccumulate(llvm::Instruction *pInstr) - : Instr(pInstr) {} - operator bool() const { - return hlsl::OP::IsDxilOpFuncCallInst( - Instr, hlsl::OP::OpCode::WaveMatrix_SumAccumulate); - } - // Validation support - bool isAllowed() const { return true; } - bool isArgumentListValid() const { - if (3 != llvm::dyn_cast(Instr)->getNumArgOperands()) - return false; - return true; - } - // Metadata - bool requiresUniformInputs() const { return false; } - // Operand indexes - enum OperandIdx { - arg_waveMatrixFragment = 1, - arg_waveMatrixInput = 2, - }; - // Accessors - llvm::Value *get_waveMatrixFragment() const { return Instr->getOperand(1); } - void set_waveMatrixFragment(llvm::Value *val) { Instr->setOperand(1, val); } - llvm::Value *get_waveMatrixInput() const { return Instr->getOperand(2); } - void set_waveMatrixInput(llvm::Value *val) { Instr->setOperand(2, val); } -}; - -/// This instruction Element-wise accumulate, or broadcast add of fragment into -/// accumulator -struct DxilInst_WaveMatrix_Add { - llvm::Instruction *Instr; - // Construction and identification - DxilInst_WaveMatrix_Add(llvm::Instruction *pInstr) : Instr(pInstr) {} - operator bool() const { - return hlsl::OP::IsDxilOpFuncCallInst(Instr, - hlsl::OP::OpCode::WaveMatrix_Add); - } - // Validation support - bool isAllowed() const { return true; } - bool isArgumentListValid() const { - if (3 != llvm::dyn_cast(Instr)->getNumArgOperands()) - return false; - return true; - } - // Metadata - bool requiresUniformInputs() const { return false; } - // Operand indexes - enum OperandIdx { - arg_waveMatrixAccumulator = 1, - arg_waveMatrixAccumulatorOrFragment = 2, - }; - // Accessors - llvm::Value *get_waveMatrixAccumulator() const { - return Instr->getOperand(1); - } - void set_waveMatrixAccumulator(llvm::Value *val) { - Instr->setOperand(1, val); - } - llvm::Value *get_waveMatrixAccumulatorOrFragment() const { - return Instr->getOperand(2); - } - void set_waveMatrixAccumulatorOrFragment(llvm::Value *val) { - Instr->setOperand(2, val); - } -}; - /// This instruction returns a handle for the output records struct DxilInst_AllocateNodeOutputRecords { llvm::Instruction *Instr; diff --git a/include/dxc/DXIL/DxilOperations.h b/include/dxc/DXIL/DxilOperations.h index 8a5110e6ce..3514701327 100644 --- a/include/dxc/DXIL/DxilOperations.h +++ b/include/dxc/DXIL/DxilOperations.h @@ -77,8 +77,6 @@ class OP { llvm::Type *GetSplitDoubleType() const; llvm::Type *GetFourI32Type() const; llvm::Type *GetFourI16Type() const; - llvm::StructType *GetWaveMatrixPropertiesType() const; - llvm::PointerType *GetWaveMatPtrType() const; llvm::Type *GetResRetType(llvm::Type *pOverloadType); llvm::Type *GetCBufferRetType(llvm::Type *pOverloadType); @@ -161,8 +159,6 @@ class OP { llvm::Type *m_pSplitDoubleType; llvm::Type *m_pFourI32Type; llvm::Type *m_pFourI16Type; - llvm::StructType *m_pWaveMatInfoType; - llvm::PointerType *m_pWaveMatPtrType; DXIL::LowPrecisionMode m_LowPrecisionMode; diff --git a/include/dxc/DXIL/DxilShaderFlags.h b/include/dxc/DXIL/DxilShaderFlags.h index a0d4f587d0..7b065c63fa 100644 --- a/include/dxc/DXIL/DxilShaderFlags.h +++ b/include/dxc/DXIL/DxilShaderFlags.h @@ -212,10 +212,6 @@ class ShaderFlags { void SetExtendedCommandInfo(bool flag) { m_bExtendedCommandInfo = flag; } bool GetExtendedCommandInfo() const { return m_bExtendedCommandInfo; } - // Experimental SM 6.9+ - Reserved, not yet supported. - void SetWaveMMA(bool flag) { m_bWaveMMA = flag; } - bool GetWaveMMA() const { return m_bWaveMMA; } - // Per-function flags void SetUsesDerivatives(bool flag) { m_bUsesDerivatives = flag; } bool GetUsesDerivatives() const { return m_bUsesDerivatives; } @@ -328,7 +324,21 @@ class ShaderFlags { // Experimental SM 6.9+ - Reserved, not yet supported. // Bit: 36 - unsigned m_bWaveMMA : 1; // SHADER_FEATURE_WAVE_MMA +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunused-private-field" +#elif defined(__GNUC__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-private-field" +#endif + + unsigned m_bReserved : 1; // SHADER_FEATURE_RESERVED + +#ifdef __clang__ +#pragma clang diagnostic pop +#elif defined(__GNUC__) +#pragma GCC diagnostic pop +#endif // SM 6.8+ // Bit: 37 diff --git a/include/dxc/DXIL/DxilUtil.h b/include/dxc/DXIL/DxilUtil.h index cb40837f27..490f335db5 100644 --- a/include/dxc/DXIL/DxilUtil.h +++ b/include/dxc/DXIL/DxilUtil.h @@ -162,8 +162,6 @@ GetHLSLResourceProperties(llvm::Type *Ty); bool IsHLSLResourceType(llvm::Type *Ty); bool IsHLSLObjectType(llvm::Type *Ty); bool IsHLSLRayQueryType(llvm::Type *Ty); -bool IsHLSLWaveMatrixType(llvm::Type *Ty, - DXIL::WaveMatrixKind *pKind = nullptr); bool IsHLSLResourceDescType(llvm::Type *Ty); bool IsResourceSingleComponent(llvm::Type *Ty); uint8_t GetResourceComponentCount(llvm::Type *Ty); diff --git a/include/dxc/DXIL/DxilWaveMatrix.h b/include/dxc/DXIL/DxilWaveMatrix.h deleted file mode 100644 index f70c10d8dd..0000000000 --- a/include/dxc/DXIL/DxilWaveMatrix.h +++ /dev/null @@ -1,53 +0,0 @@ -/////////////////////////////////////////////////////////////////////////////// -// // -// DxilWaveMatrix.h // -// Copyright (C) Microsoft Corporation. All rights reserved. // -// This file is distributed under the University of Illinois Open Source // -// License. See LICENSE.TXT for details. // -// // -// WaveMatrix related types and helpers. // -// // -/////////////////////////////////////////////////////////////////////////////// - -#pragma once - -#include "DxilConstants.h" - -namespace llvm { -class Value; -class Constant; -class Type; -class StructType; -} // namespace llvm - -namespace hlsl { - -struct DxilWaveMatrixProperties { - DXIL::WaveMatrixKind kind; - DXIL::ComponentType compType; - unsigned dimM, dimN; - - DxilWaveMatrixProperties() - : kind(DXIL::WaveMatrixKind::NumKinds), - compType(DXIL::ComponentType::Invalid), dimM(0), dimN(0) {} - bool isValid() const { return kind < DXIL::WaveMatrixKind::NumKinds; } - bool operator==(const DxilWaveMatrixProperties &other) { - return kind == other.kind && compType == other.compType && - dimM == other.dimM && dimN == other.dimN; - } - bool operator!=(const DxilWaveMatrixProperties &other) { - return !(*this == other); - } -}; - -namespace wavemat_helper { - -DxilWaveMatrixProperties LoadInfoFromConstant(llvm::Constant *C); -llvm::Constant *GetInfoConstantFromWaveMatPtr(llvm::Value *waveMatPtr); -DxilWaveMatrixProperties GetInfoFromWaveMatPtr(llvm::Value *waveMatPtr); -llvm::Constant *GetAsConstant(const DxilWaveMatrixProperties &info, - llvm::StructType *infoTy); - -} // namespace wavemat_helper - -} // namespace hlsl diff --git a/include/dxc/DxilContainer/RDAT_LibraryTypes.inl b/include/dxc/DxilContainer/RDAT_LibraryTypes.inl index 7586684a0b..132d272a8e 100644 --- a/include/dxc/DxilContainer/RDAT_LibraryTypes.inl +++ b/include/dxc/DxilContainer/RDAT_LibraryTypes.inl @@ -72,7 +72,7 @@ RDAT_ENUM_START(DxilFeatureInfo1, uint32_t) RDAT_ENUM_VALUE(DerivativesInMeshAndAmpShaders, 0x1000000) RDAT_ENUM_VALUE(ResourceDescriptorHeapIndexing, 0x2000000) RDAT_ENUM_VALUE(SamplerDescriptorHeapIndexing, 0x4000000) - RDAT_ENUM_VALUE(WaveMMA, 0x8000000) + RDAT_ENUM_VALUE(Reserved, 0x8000000) RDAT_ENUM_VALUE(AtomicInt64OnHeapResource, 0x10000000) RDAT_ENUM_VALUE(AdvancedTextureOps, 0x20000000) RDAT_ENUM_VALUE(WriteableMSAATextures, 0x40000000) diff --git a/include/dxc/HLSL/HLOperations.h b/include/dxc/HLSL/HLOperations.h index 12e2766c21..1ccb7f04a2 100644 --- a/include/dxc/HLSL/HLOperations.h +++ b/include/dxc/HLSL/HLOperations.h @@ -49,7 +49,6 @@ enum class HLOpcodeGroup { HLIndexNodeHandle, HLCreateNodeInputRecordHandle, HLAnnotateHandle, - HLWaveMatrix_Annotate, HLAnnotateNodeHandle, HLAnnotateNodeRecordHandle, NumOfHLOps @@ -395,10 +394,6 @@ const unsigned kCreateHandleIndexOpIdx = 2; // Only for array of cbuffer. const unsigned kAnnotateHandleResourcePropertiesOpIdx = 2; const unsigned kAnnotateHandleResourceTypeOpIdx = 3; -// AnnotateWaveMatrix. -const unsigned kAnnotateWaveMatrixPtrOpIdx = 1; -const unsigned kAnnotateWaveMatrixPropertiesOpIdx = 2; - // TraceRay. const unsigned kTraceRayRayDescOpIdx = 7; const unsigned kTraceRayPayLoadOpIdx = 8; @@ -418,20 +413,6 @@ const unsigned kDispatchMeshOpThreadY = 2; const unsigned kDispatchMeshOpThreadZ = 3; const unsigned kDispatchMeshOpPayload = 4; -// WaveMatrix -const unsigned kWaveMatThisOpIdx = 1; -const unsigned kWaveMatFillScalarOpIdx = 2; -const unsigned kWaveMatScalarOpOpIdx = 2; -const unsigned kWaveMatOther1OpIdx = 2; -const unsigned kWaveMatOther2OpIdx = 3; -const unsigned kWaveMatLoadStoreBufOpIdx = 2; -const unsigned kWaveMatLoadStoreStartOpIdx = 3; -const unsigned kWaveMatLoadStoreStrideOpIdx = 4; -// Note: No ColMajor arg for fragments, so align idx is one less. -const unsigned kWaveMatLoadStoreColMajorOpIdx = 5; -const unsigned kWaveMatFragLoadStoreAlignmentOpIdx = 5; -const unsigned kWaveMatLoadStoreAlignmentOpIdx = 6; - // Work Graph const unsigned kIncrementOutputCountCountIdx = 2; diff --git a/include/dxc/HlslIntrinsicOp.h b/include/dxc/HlslIntrinsicOp.h index 9d0f0a2b28..c8ed2cbd2a 100644 --- a/include/dxc/HlslIntrinsicOp.h +++ b/include/dxc/HlslIntrinsicOp.h @@ -343,16 +343,6 @@ enum class IntrinsicOp { MOP_TraceRayInline, MOP_WorldRayDirection, MOP_WorldRayOrigin, - MOP_Fill, - MOP_MatrixDepth, - MOP_ScalarAdd, - MOP_ScalarDivide, - MOP_ScalarMultiply, - MOP_ScalarSubtract, - MOP_SumAccumulate, - MOP_Add, - MOP_Multiply, - MOP_MultiplyAccumulate, MOP_Count, MOP_FinishedCrossGroupSharing, MOP_GetGroupNodeOutputRecords, diff --git a/include/dxc/dxcapi.internal.h b/include/dxc/dxcapi.internal.h index 5c15c401eb..b0f9a467a4 100644 --- a/include/dxc/dxcapi.internal.h +++ b/include/dxc/dxcapi.internal.h @@ -121,18 +121,12 @@ enum LEGAL_INTRINSIC_COMPTYPES { LICOMPTYPE_BYTEADDRESSBUFFER = 45, LICOMPTYPE_RWBYTEADDRESSBUFFER = 46, - LICOMPTYPE_WAVE_MATRIX_LEFT = 47, - LICOMPTYPE_WAVE_MATRIX_RIGHT = 48, - LICOMPTYPE_WAVE_MATRIX_LEFT_COL_ACC = 49, - LICOMPTYPE_WAVE_MATRIX_RIGHT_ROW_ACC = 50, - LICOMPTYPE_WAVE_MATRIX_ACCUMULATOR = 51, - - LICOMPTYPE_NODE_RECORD_OR_UAV = 52, - LICOMPTYPE_ANY_NODE_OUTPUT_RECORD = 53, - LICOMPTYPE_GROUP_NODE_OUTPUT_RECORDS = 54, - LICOMPTYPE_THREAD_NODE_OUTPUT_RECORDS = 55, - - LICOMPTYPE_COUNT = 56 + LICOMPTYPE_NODE_RECORD_OR_UAV = 47, + LICOMPTYPE_ANY_NODE_OUTPUT_RECORD = 48, + LICOMPTYPE_GROUP_NODE_OUTPUT_RECORDS = 49, + LICOMPTYPE_THREAD_NODE_OUTPUT_RECORDS = 50, + + LICOMPTYPE_COUNT = 51 }; static const BYTE IA_SPECIAL_BASE = 0xf0; diff --git a/lib/DXIL/CMakeLists.txt b/lib/DXIL/CMakeLists.txt index 97d1365ce1..f5e75a9e09 100644 --- a/lib/DXIL/CMakeLists.txt +++ b/lib/DXIL/CMakeLists.txt @@ -31,7 +31,6 @@ add_llvm_library(LLVMDXIL DxilUtil.cpp DxilUtilDbgInfoAndMisc.cpp DxilPDB.cpp - DxilWaveMatrix.cpp ADDITIONAL_HEADER_DIRS ${LLVM_MAIN_INCLUDE_DIR}/llvm/IR diff --git a/lib/DXIL/DxilModule.cpp b/lib/DXIL/DxilModule.cpp index 98531b80b1..f4abdd15aa 100644 --- a/lib/DXIL/DxilModule.cpp +++ b/lib/DXIL/DxilModule.cpp @@ -2169,9 +2169,7 @@ static void AdjustMinimumShaderModelAndFlags(const DxilFunctionProps *props, } // Adjust minimum shader model based on flags. - if (flags.GetWaveMMA()) - DXIL::UpdateToMaxOfVersions(minMajor, minMinor, 6, 9); - else if (flags.GetSampleCmpGradientOrBias() || flags.GetExtendedCommandInfo()) + if (flags.GetSampleCmpGradientOrBias() || flags.GetExtendedCommandInfo()) DXIL::UpdateToMaxOfVersions(minMajor, minMinor, 6, 8); else if (flags.GetAdvancedTextureOps() || flags.GetWriteableMSAATextures()) DXIL::UpdateToMaxOfVersions(minMajor, minMinor, 6, 7); diff --git a/lib/DXIL/DxilOperations.cpp b/lib/DXIL/DxilOperations.cpp index b3b3aea49f..dbc6201097 100644 --- a/lib/DXIL/DxilOperations.cpp +++ b/lib/DXIL/DxilOperations.cpp @@ -2289,115 +2289,114 @@ const OP::OpCodeProperty OP::m_OpCodeProps[(unsigned)OP::OpCode::NumOpCodes] = { Attribute::None, }, - // WaveMatrix void, h, f, d, i1, i8, i16, i32, i64, - // udt, obj , function attribute + // void, h, f, d, i1, i8, i16, i32, i64, udt, obj , function attribute { - OC::WaveMatrix_Annotate, - "WaveMatrix_Annotate", - OCC::WaveMatrix_Annotate, - "waveMatrix_Annotate", + OC::Reserved0, + "Reserved0", + OCC::Reserved, + "reserved", {true, false, false, false, false, false, false, false, false, false, false}, - Attribute::ArgMemOnly, + Attribute::None, }, { - OC::WaveMatrix_Depth, - "WaveMatrix_Depth", - OCC::WaveMatrix_Depth, - "waveMatrix_Depth", + OC::Reserved1, + "Reserved1", + OCC::Reserved, + "reserved", {true, false, false, false, false, false, false, false, false, false, false}, - Attribute::ReadNone, + Attribute::None, }, { - OC::WaveMatrix_Fill, - "WaveMatrix_Fill", - OCC::WaveMatrix_Fill, - "waveMatrix_Fill", - {false, true, true, false, false, false, false, true, false, false, + OC::Reserved2, + "Reserved2", + OCC::Reserved, + "reserved", + {true, false, false, false, false, false, false, false, false, false, false}, - Attribute::ArgMemOnly, + Attribute::None, }, { - OC::WaveMatrix_LoadRawBuf, - "WaveMatrix_LoadRawBuf", - OCC::WaveMatrix_LoadRawBuf, - "waveMatrix_LoadRawBuf", + OC::Reserved3, + "Reserved3", + OCC::Reserved, + "reserved", {true, false, false, false, false, false, false, false, false, false, false}, Attribute::None, }, { - OC::WaveMatrix_LoadGroupShared, - "WaveMatrix_LoadGroupShared", - OCC::WaveMatrix_LoadGroupShared, - "waveMatrix_LoadGroupShared", - {false, true, true, false, false, false, false, true, false, false, + OC::Reserved4, + "Reserved4", + OCC::Reserved, + "reserved", + {true, false, false, false, false, false, false, false, false, false, false}, - Attribute::ArgMemOnly, + Attribute::None, }, { - OC::WaveMatrix_StoreRawBuf, - "WaveMatrix_StoreRawBuf", - OCC::WaveMatrix_StoreRawBuf, - "waveMatrix_StoreRawBuf", + OC::Reserved5, + "Reserved5", + OCC::Reserved, + "reserved", {true, false, false, false, false, false, false, false, false, false, false}, Attribute::None, }, { - OC::WaveMatrix_StoreGroupShared, - "WaveMatrix_StoreGroupShared", - OCC::WaveMatrix_StoreGroupShared, - "waveMatrix_StoreGroupShared", - {false, true, true, false, false, false, false, true, false, false, + OC::Reserved6, + "Reserved6", + OCC::Reserved, + "reserved", + {true, false, false, false, false, false, false, false, false, false, false}, - Attribute::ArgMemOnly, + Attribute::None, }, { - OC::WaveMatrix_Multiply, - "WaveMatrix_Multiply", - OCC::WaveMatrix_Multiply, - "waveMatrix_Multiply", + OC::Reserved7, + "Reserved7", + OCC::Reserved, + "reserved", {true, false, false, false, false, false, false, false, false, false, false}, - Attribute::ArgMemOnly, + Attribute::None, }, { - OC::WaveMatrix_MultiplyAccumulate, - "WaveMatrix_MultiplyAccumulate", - OCC::WaveMatrix_Multiply, - "waveMatrix_Multiply", + OC::Reserved8, + "Reserved8", + OCC::Reserved, + "reserved", {true, false, false, false, false, false, false, false, false, false, false}, - Attribute::ArgMemOnly, + Attribute::None, }, { - OC::WaveMatrix_ScalarOp, - "WaveMatrix_ScalarOp", - OCC::WaveMatrix_ScalarOp, - "waveMatrix_ScalarOp", - {false, true, true, false, false, false, false, true, false, false, + OC::Reserved9, + "Reserved9", + OCC::Reserved, + "reserved", + {true, false, false, false, false, false, false, false, false, false, false}, - Attribute::ArgMemOnly, + Attribute::None, }, { - OC::WaveMatrix_SumAccumulate, - "WaveMatrix_SumAccumulate", - OCC::WaveMatrix_Accumulate, - "waveMatrix_Accumulate", + OC::Reserved10, + "Reserved10", + OCC::Reserved, + "reserved", {true, false, false, false, false, false, false, false, false, false, false}, - Attribute::ArgMemOnly, + Attribute::None, }, { - OC::WaveMatrix_Add, - "WaveMatrix_Add", - OCC::WaveMatrix_Accumulate, - "waveMatrix_Accumulate", + OC::Reserved11, + "Reserved11", + OCC::Reserved, + "reserved", {true, false, false, false, false, false, false, false, false, false, false}, - Attribute::ArgMemOnly, + Attribute::None, }, // Create/Annotate Node Handles void, h, f, d, i1, i8, @@ -2836,14 +2835,9 @@ bool OP::IsDxilOpWave(OpCode C) { // WaveReadLaneFirst=118, WaveActiveOp=119, WaveActiveBit=120, // WavePrefixOp=121, QuadReadLaneAt=122, QuadOp=123, WaveAllBitCount=135, // WavePrefixBitCount=136, WaveMatch=165, WaveMultiPrefixOp=166, - // WaveMultiPrefixBitCount=167, QuadVote=222, WaveMatrix_Annotate=226, - // WaveMatrix_Depth=227, WaveMatrix_Fill=228, WaveMatrix_LoadRawBuf=229, - // WaveMatrix_LoadGroupShared=230, WaveMatrix_StoreRawBuf=231, - // WaveMatrix_StoreGroupShared=232, WaveMatrix_Multiply=233, - // WaveMatrix_MultiplyAccumulate=234, WaveMatrix_ScalarOp=235, - // WaveMatrix_SumAccumulate=236, WaveMatrix_Add=237 + // WaveMultiPrefixBitCount=167, QuadVote=222 return (110 <= op && op <= 123) || (135 <= op && op <= 136) || - (165 <= op && op <= 167) || op == 222 || (226 <= op && op <= 237); + (165 <= op && op <= 167) || op == 222; // OPCODE-WAVE:END } @@ -3340,18 +3334,6 @@ void OP::GetMinShaderModelAndMask(OpCode C, bool bWithTranslation, } return; } - // Instructions: WaveMatrix_Annotate=226, WaveMatrix_Depth=227, - // WaveMatrix_Fill=228, WaveMatrix_LoadRawBuf=229, - // WaveMatrix_LoadGroupShared=230, WaveMatrix_StoreRawBuf=231, - // WaveMatrix_StoreGroupShared=232, WaveMatrix_Multiply=233, - // WaveMatrix_MultiplyAccumulate=234, WaveMatrix_ScalarOp=235, - // WaveMatrix_SumAccumulate=236, WaveMatrix_Add=237 - if ((226 <= op && op <= 237)) { - major = 6; - minor = 9; - mask = SFLAG(Library) | SFLAG(Compute); - return; - } // OPCODE-SMMASK:END } @@ -3506,16 +3488,6 @@ OP::OP(LLVMContext &Ctx, Module *pModule) Type::getInt16Ty(m_Ctx)}; // HiHi, HiLo, LoHi, LoLo m_pFourI16Type = GetOrCreateStructType(m_Ctx, FourI16Types, "dx.types.fouri16", pModule); - - Type *WaveMatInfoTypes[4] = {Type::getInt8Ty(m_Ctx), Type::getInt8Ty(m_Ctx), - Type::getInt32Ty(m_Ctx), - Type::getInt32Ty(m_Ctx)}; - m_pWaveMatInfoType = cast(GetOrCreateStructType( - m_Ctx, WaveMatInfoTypes, "dx.types.waveMatProps", pModule)); - m_pWaveMatPtrType = - PointerType::get(GetOrCreateStructType(m_Ctx, Type::getInt8PtrTy(m_Ctx), - "dx.types.waveMatrix", pModule), - 0); } void OP::RefreshCache() { @@ -3613,11 +3585,6 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) { Type *nodeProperty = GetNodePropertiesType(); Type *nodeRecordProperty = GetNodeRecordPropertiesType(); - Type *pWaveMatProps = GetWaveMatrixPropertiesType(); - Type *pWaveMatPtr = GetWaveMatPtrType(); - Type *pGSEltPtrTy = - pETy->isVoidTy() ? nullptr : pETy->getPointerTo(DXIL::kTGSMAddrSpace); - #define A(_x) ArgTypes.emplace_back(_x) #define RRT(_y) A(GetResRetType(_y)) #define CBRT(_y) A(GetCBufferRetType(_y)) @@ -5254,94 +5221,54 @@ Function *OP::GetOpFunc(OpCode opCode, Type *pOverloadType) { A(pI32); break; - // WaveMatrix - case OpCode::WaveMatrix_Annotate: + // + case OpCode::Reserved0: A(pV); A(pI32); - A(pWaveMatPtr); - A(pWaveMatProps); break; - case OpCode::WaveMatrix_Depth: - A(pI32); + case OpCode::Reserved1: + A(pV); A(pI32); - A(pWaveMatProps); break; - case OpCode::WaveMatrix_Fill: + case OpCode::Reserved2: A(pV); A(pI32); - A(pWaveMatPtr); - A(pETy); break; - case OpCode::WaveMatrix_LoadRawBuf: + case OpCode::Reserved3: A(pV); A(pI32); - A(pWaveMatPtr); - A(pRes); - A(pI32); - A(pI32); - A(pI8); - A(pI1); break; - case OpCode::WaveMatrix_LoadGroupShared: + case OpCode::Reserved4: A(pV); A(pI32); - A(pWaveMatPtr); - A(pGSEltPtrTy); - A(pI32); - A(pI32); - A(pI1); break; - case OpCode::WaveMatrix_StoreRawBuf: + case OpCode::Reserved5: A(pV); A(pI32); - A(pWaveMatPtr); - A(pRes); - A(pI32); - A(pI32); - A(pI8); - A(pI1); break; - case OpCode::WaveMatrix_StoreGroupShared: + case OpCode::Reserved6: A(pV); A(pI32); - A(pWaveMatPtr); - A(pGSEltPtrTy); - A(pI32); - A(pI32); - A(pI1); break; - case OpCode::WaveMatrix_Multiply: + case OpCode::Reserved7: A(pV); A(pI32); - A(pWaveMatPtr); - A(pWaveMatPtr); - A(pWaveMatPtr); break; - case OpCode::WaveMatrix_MultiplyAccumulate: + case OpCode::Reserved8: A(pV); A(pI32); - A(pWaveMatPtr); - A(pWaveMatPtr); - A(pWaveMatPtr); break; - case OpCode::WaveMatrix_ScalarOp: + case OpCode::Reserved9: A(pV); A(pI32); - A(pWaveMatPtr); - A(pI8); - A(pETy); break; - case OpCode::WaveMatrix_SumAccumulate: + case OpCode::Reserved10: A(pV); A(pI32); - A(pWaveMatPtr); - A(pWaveMatPtr); break; - case OpCode::WaveMatrix_Add: + case OpCode::Reserved11: A(pV); A(pI32); - A(pWaveMatPtr); - A(pWaveMatPtr); break; // Create/Annotate Node Handles @@ -5631,7 +5558,6 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) { case OpCode::TempRegStore: case OpCode::CallShader: case OpCode::Pack4x8: - case OpCode::WaveMatrix_Fill: if (FT->getNumParams() <= 2) return nullptr; return FT->getParamType(2); @@ -5675,15 +5601,9 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) { return nullptr; return FT->getParamType(15); case OpCode::ReportHit: - case OpCode::WaveMatrix_ScalarOp: if (FT->getNumParams() <= 3) return nullptr; return FT->getParamType(3); - case OpCode::WaveMatrix_LoadGroupShared: - case OpCode::WaveMatrix_StoreGroupShared: - if (FT->getNumParams() <= 2) - return nullptr; - return FT->getParamType(2)->getPointerElementType(); case OpCode::CreateHandle: case OpCode::BufferUpdateCounter: case OpCode::GetDimensions: @@ -5732,14 +5652,18 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) { case OpCode::AnnotateHandle: case OpCode::CreateHandleFromBinding: case OpCode::CreateHandleFromHeap: - case OpCode::WaveMatrix_Annotate: - case OpCode::WaveMatrix_Depth: - case OpCode::WaveMatrix_LoadRawBuf: - case OpCode::WaveMatrix_StoreRawBuf: - case OpCode::WaveMatrix_Multiply: - case OpCode::WaveMatrix_MultiplyAccumulate: - case OpCode::WaveMatrix_SumAccumulate: - case OpCode::WaveMatrix_Add: + case OpCode::Reserved0: + case OpCode::Reserved1: + case OpCode::Reserved2: + case OpCode::Reserved3: + case OpCode::Reserved4: + case OpCode::Reserved5: + case OpCode::Reserved6: + case OpCode::Reserved7: + case OpCode::Reserved8: + case OpCode::Reserved9: + case OpCode::Reserved10: + case OpCode::Reserved11: case OpCode::AllocateNodeOutputRecords: case OpCode::IncrementOutputCount: case OpCode::OutputComplete: @@ -5890,11 +5814,6 @@ Type *OP::GetFourI32Type() const { return m_pFourI32Type; } Type *OP::GetFourI16Type() const { return m_pFourI16Type; } -StructType *OP::GetWaveMatrixPropertiesType() const { - return m_pWaveMatInfoType; -} -PointerType *OP::GetWaveMatPtrType() const { return m_pWaveMatPtrType; } - bool OP::IsResRetType(llvm::Type *Ty) { for (Type *ResTy : m_pResRetType) { if (Ty == ResTy) diff --git a/lib/DXIL/DxilShaderFlags.cpp b/lib/DXIL/DxilShaderFlags.cpp index 4a8229d2b7..7d0799dc64 100644 --- a/lib/DXIL/DxilShaderFlags.cpp +++ b/lib/DXIL/DxilShaderFlags.cpp @@ -45,7 +45,7 @@ ShaderFlags::ShaderFlags() m_bSamplerDescriptorHeapIndexing(false), m_bAtomicInt64OnHeapResource(false), m_bResMayNotAlias(false), m_bAdvancedTextureOps(false), m_bWriteableMSAATextures(false), - m_bWaveMMA(false), m_bSampleCmpGradientOrBias(false), + m_bReserved(false), m_bSampleCmpGradientOrBias(false), m_bExtendedCommandInfo(false), m_bUsesDerivatives(false), m_bRequiresGroup(false), m_align1(0) { // Silence unused field warnings @@ -125,8 +125,6 @@ uint64_t ShaderFlags::GetFeatureInfo() const { ? hlsl::DXIL::ShaderFeatureInfo_WriteableMSAATextures : 0; - Flags |= m_bWaveMMA ? hlsl::DXIL::ShaderFeatureInfo_WaveMMA : 0; - Flags |= m_bSampleCmpGradientOrBias ? hlsl::DXIL::ShaderFeatureInfo_SampleCmpGradientOrBias : 0; @@ -198,7 +196,6 @@ uint64_t ShaderFlags::GetShaderFlagsRawForCollection() { Flags.SetResMayNotAlias(true); Flags.SetAdvancedTextureOps(true); Flags.SetWriteableMSAATextures(true); - Flags.SetWaveMMA(true); Flags.SetSampleCmpGradientOrBias(true); Flags.SetExtendedCommandInfo(true); Flags.SetUsesDerivatives(true); @@ -447,7 +444,6 @@ ShaderFlags ShaderFlags::CollectShaderFlags(const Function *F, bool hasAdvancedTextureOps = false; bool hasSampleCmpGradientOrBias = false; - bool hasWaveMMA = false; bool hasExtendedCommandInfo = false; // UsesDerivatives is used to indicate any derivative use per-function, before @@ -721,20 +717,6 @@ ShaderFlags ShaderFlags::CollectShaderFlags(const Function *F, case DXIL::OpCode::TextureGatherRaw: hasAdvancedTextureOps = true; break; - case DXIL::OpCode::WaveMatrix_Add: - case DXIL::OpCode::WaveMatrix_Annotate: - case DXIL::OpCode::WaveMatrix_Depth: - case DXIL::OpCode::WaveMatrix_Fill: - case DXIL::OpCode::WaveMatrix_LoadGroupShared: - case DXIL::OpCode::WaveMatrix_LoadRawBuf: - case DXIL::OpCode::WaveMatrix_Multiply: - case DXIL::OpCode::WaveMatrix_MultiplyAccumulate: - case DXIL::OpCode::WaveMatrix_ScalarOp: - case DXIL::OpCode::WaveMatrix_StoreGroupShared: - case DXIL::OpCode::WaveMatrix_StoreRawBuf: - case DXIL::OpCode::WaveMatrix_SumAccumulate: - hasWaveMMA = true; - break; case DXIL::OpCode::StartVertexLocation: case DXIL::OpCode::StartInstanceLocation: hasExtendedCommandInfo = true; @@ -862,7 +844,6 @@ ShaderFlags ShaderFlags::CollectShaderFlags(const Function *F, flag.SetWriteableMSAATextures(setWriteableMSAATextures_1_7 ? hasWriteableMSAATextures_1_7 : hasWriteableMSAATextures); - flag.SetWaveMMA(hasWaveMMA); // Only bother setting the flag when there are UAVs. flag.SetResMayNotAlias(canSetResMayNotAlias && hasUAVs && !M->GetResMayAlias()); diff --git a/lib/DXIL/DxilUtil.cpp b/lib/DXIL/DxilUtil.cpp index 29fa33f847..757a0bc3ee 100644 --- a/lib/DXIL/DxilUtil.cpp +++ b/lib/DXIL/DxilUtil.cpp @@ -568,9 +568,6 @@ bool IsHLSLObjectType(llvm::Type *Ty) { if (name.startswith("LineStream<")) return true; - if (IsHLSLWaveMatrixType(Ty)) - return true; - if (IsHLSLNodeIOType(Ty)) return true; } @@ -590,36 +587,6 @@ bool IsHLSLRayQueryType(llvm::Type *Ty) { return false; } -bool IsHLSLWaveMatrixType(llvm::Type *Ty, DXIL::WaveMatrixKind *pKind) { - if (Ty->isPointerTy()) - Ty = Ty->getPointerElementType(); - if (llvm::StructType *ST = dyn_cast(Ty)) { - if (!ST->hasName()) - return false; - StringRef name = ST->getName(); - // TODO: don't check names. - ConsumePrefix(name, "class."); - if (!ConsumePrefix(name, "WaveMatrix")) - return false; - DXIL::WaveMatrixKind kind = DXIL::WaveMatrixKind::NumKinds; - if (name.startswith("Left<")) - kind = DXIL::WaveMatrixKind::Left; - if (name.startswith("Right<")) - kind = DXIL::WaveMatrixKind::Right; - if (name.startswith("LeftColAcc<")) - kind = DXIL::WaveMatrixKind::LeftColAcc; - if (name.startswith("RightRowAcc<")) - kind = DXIL::WaveMatrixKind::RightRowAcc; - if (name.startswith("Accumulator<")) - kind = DXIL::WaveMatrixKind::Accumulator; - if (pKind) - *pKind = kind; - if (kind != DXIL::WaveMatrixKind::NumKinds) - return true; - } - return false; -} - bool IsHLSLResourceDescType(llvm::Type *Ty) { if (llvm::StructType *ST = dyn_cast(Ty)) { if (!ST->hasName()) diff --git a/lib/DXIL/DxilWaveMatrix.cpp b/lib/DXIL/DxilWaveMatrix.cpp deleted file mode 100644 index cdeadff439..0000000000 --- a/lib/DXIL/DxilWaveMatrix.cpp +++ /dev/null @@ -1,78 +0,0 @@ -/////////////////////////////////////////////////////////////////////////////// -// // -// DxilWaveMatrix.cpp // -// Copyright (C) Microsoft Corporation. All rights reserved. // -// This file is distributed under the University of Illinois Open Source // -// License. See LICENSE.TXT for details. // -// // -/////////////////////////////////////////////////////////////////////////////// - -#include "dxc/DXIL/DxilWaveMatrix.h" -#include "dxc/DXIL/DxilInstructions.h" -#include "dxc/DXIL/DxilModule.h" -#include "dxc/DXIL/DxilOperations.h" -#include "dxc/DXIL/DxilShaderModel.h" -#include "dxc/DXIL/DxilUtil.h" -#include "dxc/Support/Global.h" -#include "llvm/IR/Constant.h" -#include "llvm/IR/Constants.h" -#include "llvm/IR/DerivedTypes.h" -#include "llvm/IR/LLVMContext.h" -#include "llvm/IR/Module.h" - -using namespace llvm; - -namespace hlsl { - -DxilWaveMatrixProperties -wavemat_helper::LoadInfoFromConstant(llvm::Constant *C) { - DXASSERT(!isa(C), - "otherwise, DxilWaveMatrixProperties has invalid value"); - const ConstantStruct *CS = cast(C); - DXASSERT(CS->getType()->getNumElements() == 4, - "otherwise, struct is not expected layout"); - DxilWaveMatrixProperties info; - info.kind = (DXIL::WaveMatrixKind)cast(CS->getOperand(0)) - ->getLimitedValue(); - info.compType = (DXIL::ComponentType)cast(CS->getOperand(1)) - ->getLimitedValue(); - info.dimM = (uint32_t)cast(CS->getOperand(2))->getLimitedValue(); - info.dimN = (uint32_t)cast(CS->getOperand(3))->getLimitedValue(); - return info; -} - -Constant * -wavemat_helper::GetInfoConstantFromWaveMatPtr(llvm::Value *waveMatPtr) { - DXASSERT_NOMSG(isa(waveMatPtr)); - for (auto *U : waveMatPtr->users()) { - Instruction *I = cast(U); - DxilInst_WaveMatrix_Annotate annotate(I); - if (annotate) { - DXASSERT_NOMSG(isa(annotate.get_waveMatProps())); - return cast(annotate.get_waveMatProps()); - } - } - return nullptr; -} - -DxilWaveMatrixProperties -wavemat_helper::GetInfoFromWaveMatPtr(llvm::Value *waveMatPtr) { - Constant *infoC = wavemat_helper::GetInfoConstantFromWaveMatPtr(waveMatPtr); - DXASSERT(infoC, "otherwise, no WaveMatAnnotate call found for ptr"); - return wavemat_helper::LoadInfoFromConstant(infoC); -} - -llvm::Constant * -wavemat_helper::GetAsConstant(const DxilWaveMatrixProperties &info, - llvm::StructType *infoTy) { - LLVMContext &Ctx = infoTy->getContext(); - IntegerType *i8Ty = IntegerType::get(Ctx, 8); - IntegerType *i32Ty = IntegerType::get(Ctx, 32); - return ConstantStruct::get(cast(infoTy), - {ConstantInt::get(i8Ty, (unsigned)info.kind), - ConstantInt::get(i8Ty, (unsigned)info.compType), - ConstantInt::get(i32Ty, (unsigned)info.dimM), - ConstantInt::get(i32Ty, (unsigned)info.dimN)}); -} - -} // namespace hlsl diff --git a/lib/HLSL/DxilGenerationPass.cpp b/lib/HLSL/DxilGenerationPass.cpp index 664d5c1fa5..7d902a4ed7 100644 --- a/lib/HLSL/DxilGenerationPass.cpp +++ b/lib/HLSL/DxilGenerationPass.cpp @@ -244,8 +244,6 @@ class DxilGenerationPass : public ModulePass { } } - LowerHLAnnotateWaveMatrix(M); - std::unordered_set UpdateCounterSet; LowerRecordAccessToGetNodeRecordPtr(*m_pHLModule); @@ -319,7 +317,6 @@ class DxilGenerationPass : public ModulePass { std::unordered_set &UpdateCounterSet); void LowerHLCreateHandle( std::unordered_map &HandleToResTypeMap); - void LowerHLAnnotateWaveMatrix(Module &M); // Translate precise attribute into HL function call. void TranslatePreciseAttribute(); @@ -652,37 +649,6 @@ void DxilGenerationPass::LowerHLCreateHandle( } } -void DxilGenerationPass::LowerHLAnnotateWaveMatrix(Module &M) { - hlsl::OP &hlslOP = *m_pHLModule->GetOP(); - Value *opArg = - hlslOP.GetU32Const((unsigned)DXIL::OpCode::WaveMatrix_Annotate); - for (iplist::iterator F : M.getFunctionList()) { - if (F->user_empty()) - continue; - if (hlsl::GetHLOpcodeGroup(F) == HLOpcodeGroup::HLWaveMatrix_Annotate) { - for (auto U = F->user_begin(); U != F->user_end();) { - Value *User = *(U++); - if (!isa(User)) - continue; - // must be call inst - CallInst *CI = cast(User); - IRBuilder<> Builder(CI); - Value *waveMatPtr = - CI->getArgOperand(HLOperandIndex::kAnnotateWaveMatrixPtrOpIdx); - Value *WMP = CI->getArgOperand( - HLOperandIndex::kAnnotateWaveMatrixPropertiesOpIdx); - Function *annotateWaveMatrix = hlslOP.GetOpFunc( - DXIL::OpCode::WaveMatrix_Annotate, Builder.getVoidTy()); - CallInst *newCI = - Builder.CreateCall(annotateWaveMatrix, {opArg, waveMatPtr, WMP}); - if (!CI->user_empty()) - CI->replaceAllUsesWith(Builder.CreateBitCast(newCI, CI->getType())); - CI->eraseFromParent(); - } - } - } -} - static void MarkUavUpdateCounter(Value *LoadOrGEP, DxilResource &res, std::unordered_set &UpdateCounterSet) { diff --git a/lib/HLSL/DxilValidation.cpp b/lib/HLSL/DxilValidation.cpp index 1b64c8519e..db2915afd5 100644 --- a/lib/HLSL/DxilValidation.cpp +++ b/lib/HLSL/DxilValidation.cpp @@ -155,7 +155,6 @@ struct ValidationContext { Module *pDebugModule; DxilModule &DxilMod; const Type *HandleTy; - const Type *WaveMatrixTy; const DataLayout &DL; DebugLoc LastDebugLocEmit; ValidationRule LastRuleEmit; @@ -190,8 +189,6 @@ struct ValidationContext { slotTracker(&llvmModule, true) { DxilMod.GetDxilVersion(m_DxilMajor, m_DxilMinor); HandleTy = DxilMod.GetOP()->GetHandleType(); - WaveMatrixTy = - DxilMod.GetOP()->GetWaveMatPtrType()->getPointerElementType(); for (Function &F : llvmModule.functions()) { if (DxilMod.HasDxilEntryProps(&F)) { @@ -2687,7 +2684,7 @@ static bool ValidateType(Type *Ty, ValidationContext &ValCtx, StringRef Name = ST->getName(); if (Name.startswith("dx.")) { // Allow handle type. - if (ValCtx.HandleTy == Ty || ValCtx.WaveMatrixTy == Ty) + if (ValCtx.HandleTy == Ty) return true; hlsl::OP *hlslOP = ValCtx.DxilMod.GetOP(); if (IsDxilBuiltinStructType(ST, hlslOP)) { diff --git a/lib/HLSL/HLOperationLower.cpp b/lib/HLSL/HLOperationLower.cpp index 50ed9edb15..b07577374f 100644 --- a/lib/HLSL/HLOperationLower.cpp +++ b/lib/HLSL/HLOperationLower.cpp @@ -19,7 +19,6 @@ #include "dxc/DXIL/DxilOperations.h" #include "dxc/DXIL/DxilResourceProperties.h" #include "dxc/DXIL/DxilUtil.h" -#include "dxc/DXIL/DxilWaveMatrix.h" #include "dxc/HLSL/DxilPoisonValues.h" #include "dxc/HLSL/HLLowerUDT.h" #include "dxc/HLSL/HLMatrixLowerHelper.h" @@ -53,11 +52,7 @@ struct HLOperationLowerHelper { DxilFunctionProps *functionProps; DataLayout dataLayout; SmallDenseMap loweredTypes; - typedef std::pair WaveMatrix_Props; - typedef DenseMap WaveMatrix_PropMap; - WaveMatrix_PropMap waveMatPropMap; HLOperationLowerHelper(HLModule &HLM); - const WaveMatrix_Props &GetWaveMatInfo(Value *waveMatPtr); }; HLOperationLowerHelper::HLOperationLowerHelper(HLModule &HLM) @@ -78,19 +73,6 @@ HLOperationLowerHelper::HLOperationLowerHelper(HLModule &HLM) functionProps = &HLM.GetDxilFunctionProps(EntryFunc); } -const HLOperationLowerHelper::WaveMatrix_Props & -HLOperationLowerHelper::GetWaveMatInfo(Value *waveMatPtr) { - auto it = waveMatPropMap.find(waveMatPtr); - if (it == waveMatPropMap.end()) { - Constant *infoC = wavemat_helper::GetInfoConstantFromWaveMatPtr(waveMatPtr); - DxilWaveMatrixProperties info = wavemat_helper::LoadInfoFromConstant(infoC); - it = waveMatPropMap - .insert(std::make_pair(waveMatPtr, std::make_pair(info, infoC))) - .first; - } - return it->second; -} - struct HLObjectOperationLowerHelper { private: // For object intrinsics. @@ -4304,12 +4286,6 @@ void TranslateLoad(ResLoadHelper &helper, HLResource::Kind RK, UpdateStatus(ResRet, helper.status, Builder, OP); } -Value *TranslateWaveMatLoadStore(CallInst *CI, IntrinsicOp IOP, - OP::OpCode opcode, - HLOperationLowerHelper &helper, - HLObjectOperationLowerHelper *pObjHelper, - bool &Translated); - Value *TranslateResourceLoad(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode, HLOperationLowerHelper &helper, HLObjectOperationLowerHelper *pObjHelper, @@ -4317,11 +4293,6 @@ Value *TranslateResourceLoad(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode, hlsl::OP *hlslOP = &helper.hlslOP; Value *handle = CI->getArgOperand(HLOperandIndex::kHandleOpIdx); - // object.Load(...) could be WaveMatrix Load instead of resource method - if (handle->getType() == hlslOP->GetWaveMatPtrType()) - return TranslateWaveMatLoadStore(CI, IOP, opcode, helper, pObjHelper, - Translated); - IRBuilder<> Builder(CI); DXIL::ResourceClass RC = pObjHelper->GetRC(handle); @@ -4609,11 +4580,6 @@ Value *TranslateResourceStore(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode, hlsl::OP *hlslOP = &helper.hlslOP; Value *handle = CI->getArgOperand(HLOperandIndex::kHandleOpIdx); - // object.Store(...) could be WaveMatrix Store instead of resource method - if (handle->getType() == hlslOP->GetWaveMatPtrType()) - return TranslateWaveMatLoadStore(CI, IOP, opcode, helper, pObjHelper, - Translated); - IRBuilder<> Builder(CI); DXIL::ResourceKind RK = pObjHelper->GetRK(handle); @@ -6095,188 +6061,6 @@ Value *TranslateUnpack(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode, return ResVec; } -Value *TranslateWaveMatrixDepth(CallInst *CI, IntrinsicOp IOP, - OP::OpCode opcode, - HLOperationLowerHelper &helper, - HLObjectOperationLowerHelper *pObjHelper, - bool &Translated) { - hlsl::OP *hlslOP = &helper.hlslOP; - - Value *thisWaveMatPtr = CI->getArgOperand(HLOperandIndex::kWaveMatThisOpIdx); - const auto &props = helper.GetWaveMatInfo(thisWaveMatPtr); - - IRBuilder<> Builder(CI); - Function *dxilFunc = hlslOP->GetOpFunc(opcode, helper.voidTy); - Constant *opArg = hlslOP->GetU32Const((unsigned)opcode); - return Builder.CreateCall(dxilFunc, {opArg, props.second}); -} - -Value *TranslateWaveMatrixFill(CallInst *CI, IntrinsicOp IOP, OP::OpCode opcode, - HLOperationLowerHelper &helper, - HLObjectOperationLowerHelper *pObjHelper, - bool &Translated) { - hlsl::OP *hlslOP = &helper.hlslOP; - - Value *thisWaveMatPtr = CI->getArgOperand(HLOperandIndex::kWaveMatThisOpIdx); - Value *val = CI->getArgOperand(HLOperandIndex::kWaveMatFillScalarOpIdx); - - IRBuilder<> Builder(CI); - Function *dxilFunc = hlslOP->GetOpFunc(opcode, val->getType()); - Constant *opArg = hlslOP->GetU32Const((unsigned)opcode); - return Builder.CreateCall(dxilFunc, {opArg, thisWaveMatPtr, val}); -} - -Value *TranslateWaveMatrixScalarOp(CallInst *CI, IntrinsicOp IOP, - OP::OpCode opcode, - HLOperationLowerHelper &helper, - HLObjectOperationLowerHelper *pObjHelper, - bool &Translated) { - hlsl::OP *hlslOP = &helper.hlslOP; - - Value *thisWaveMatPtr = CI->getArgOperand(HLOperandIndex::kWaveMatThisOpIdx); - Value *val = CI->getArgOperand(HLOperandIndex::kWaveMatScalarOpOpIdx); - - DXIL::WaveMatrixScalarOpCode scalarOp = DXIL::WaveMatrixScalarOpCode::Invalid; - switch (IOP) { - case IntrinsicOp::MOP_ScalarAdd: - scalarOp = DXIL::WaveMatrixScalarOpCode::Add; - break; - case IntrinsicOp::MOP_ScalarSubtract: - scalarOp = DXIL::WaveMatrixScalarOpCode::Subtract; - break; - case IntrinsicOp::MOP_ScalarMultiply: - scalarOp = DXIL::WaveMatrixScalarOpCode::Multiply; - break; - case IntrinsicOp::MOP_ScalarDivide: - scalarOp = DXIL::WaveMatrixScalarOpCode::Divide; - break; - default: - DXASSERT(false, "Missing case for WaveMatrix scalar operation"); - } - - IRBuilder<> Builder(CI); - Function *dxilFunc = hlslOP->GetOpFunc(opcode, val->getType()); - Constant *opArg = hlslOP->GetU32Const((unsigned)opcode); - Constant *scalarOpArg = hlslOP->GetU8Const((unsigned)scalarOp); - return Builder.CreateCall(dxilFunc, - {opArg, thisWaveMatPtr, scalarOpArg, val}); -} - -Value *TranslateWaveMatrix_Accumulate(CallInst *CI, IntrinsicOp IOP, - OP::OpCode opcode, - HLOperationLowerHelper &helper, - HLObjectOperationLowerHelper *pObjHelper, - bool &Translated) { - hlsl::OP *hlslOP = &helper.hlslOP; - - Value *thisWaveMatPtr = CI->getArgOperand(HLOperandIndex::kWaveMatThisOpIdx); - Value *otherWaveMatPtr1 = - CI->getArgOperand(HLOperandIndex::kWaveMatOther1OpIdx); - - IRBuilder<> Builder(CI); - Function *dxilFunc = hlslOP->GetOpFunc(opcode, helper.voidTy); - Constant *opArg = hlslOP->GetU32Const((unsigned)opcode); - return Builder.CreateCall(dxilFunc, - {opArg, thisWaveMatPtr, otherWaveMatPtr1}); -} - -Value *TranslateWaveMatrixMultiply(CallInst *CI, IntrinsicOp IOP, - OP::OpCode opcode, - HLOperationLowerHelper &helper, - HLObjectOperationLowerHelper *pObjHelper, - bool &Translated) { - hlsl::OP *hlslOP = &helper.hlslOP; - - Value *thisWaveMatPtr = CI->getArgOperand(HLOperandIndex::kWaveMatThisOpIdx); - Value *otherWaveMatPtr1 = - CI->getArgOperand(HLOperandIndex::kWaveMatOther1OpIdx); - Value *otherWaveMatPtr2 = - CI->getArgOperand(HLOperandIndex::kWaveMatOther2OpIdx); - - IRBuilder<> Builder(CI); - Function *dxilFunc = hlslOP->GetOpFunc(opcode, helper.voidTy); - Constant *opArg = hlslOP->GetU32Const((unsigned)opcode); - return Builder.CreateCall( - dxilFunc, {opArg, thisWaveMatPtr, otherWaveMatPtr1, otherWaveMatPtr2}); -} - -Value *TranslateWaveMatLoadStore(CallInst *CI, IntrinsicOp IOP, - OP::OpCode opcode, - HLOperationLowerHelper &helper, - HLObjectOperationLowerHelper *pObjHelper, - bool &Translated) { - hlsl::OP *hlslOP = &helper.hlslOP; - - // buf is raw buffer handle or groupshared ptr: - Value *buf = CI->getArgOperand(HLOperandIndex::kWaveMatLoadStoreBufOpIdx); - Type *bufETy = buf->getType(); - bool bRawBuf = bufETy == hlslOP->GetHandleType(); - if (!bRawBuf) { - Constant *C = dyn_cast(buf); - if (auto *CE = dyn_cast(C)) - C = CE->getOperand(0)->stripPointerCasts(); - DXASSERT( - C && C->getType()->getPointerAddressSpace() == DXIL::kTGSMAddrSpace, - "otherwise, non-groupshared type passed to groupshared Load/Store"); - bufETy = dxilutil::StripArrayTypes(C->getType()->getPointerElementType()); - buf = ConstantExpr::getPointerBitCastOrAddrSpaceCast( - C, bufETy->getPointerTo(DXIL::kTGSMAddrSpace)); - } - - // Determine if fragment (LeftColAcc/RightRowAcc) - const auto &props = helper.GetWaveMatInfo( - CI->getArgOperand(HLOperandIndex::kWaveMatThisOpIdx)); - DXIL::WaveMatrixKind waveMatKind = props.first.kind; - bool bFragment = waveMatKind == DXIL::WaveMatrixKind::LeftColAcc || - waveMatKind == DXIL::WaveMatrixKind::RightRowAcc; - - if (IOP == IntrinsicOp::MOP_Load) { - opcode = bRawBuf ? OP::OpCode::WaveMatrix_LoadRawBuf - : OP::OpCode::WaveMatrix_LoadGroupShared; - } else if (IOP == IntrinsicOp::MOP_Store) { - opcode = bRawBuf ? OP::OpCode::WaveMatrix_StoreRawBuf - : OP::OpCode::WaveMatrix_StoreGroupShared; - } else { - DXASSERT(0, "otherwise, unexpected IntrinsicOp"); - } - - Function *dxilFunc = - hlslOP->GetOpFunc(opcode, bRawBuf ? helper.voidTy : bufETy); - - IRBuilder<> Builder(CI); - SmallVector args; - args.push_back(hlslOP->GetU32Const((unsigned)opcode)); - args.push_back(CI->getArgOperand(HLOperandIndex::kWaveMatThisOpIdx)); - args.push_back(buf); - args.push_back( - CI->getArgOperand(HLOperandIndex::kWaveMatLoadStoreStartOpIdx)); - - // For fragment, stride is element stride with same argument mapping. - args.push_back( - CI->getArgOperand(HLOperandIndex::kWaveMatLoadStoreStrideOpIdx)); - - // if handle, push align arg - if (bRawBuf) { - Value *align = ConstantInt::get(helper.i8Ty, (uint64_t)0); - const unsigned AlignOpIdx = - bFragment ? HLOperandIndex::kWaveMatFragLoadStoreAlignmentOpIdx - : HLOperandIndex::kWaveMatLoadStoreAlignmentOpIdx; - if (CI->getNumArgOperands() > AlignOpIdx) { - align = CI->getArgOperand(AlignOpIdx); - align = Builder.CreateTrunc(align, helper.i8Ty); - } - args.push_back(align); - } - - // No orientation for matrix fragments, just use i1 0 for unused arg. - args.push_back( - bFragment - ? ConstantInt::get(helper.i1Ty, (uint64_t)0) - : CI->getArgOperand(HLOperandIndex::kWaveMatLoadStoreColMajorOpIdx)); - - return Builder.CreateCall(dxilFunc, args); -} - } // namespace // Resource Handle. @@ -6953,26 +6737,6 @@ IntrinsicLower gLowerTable[] = { DXIL::OpCode::RayQuery_WorldRayDirection}, {IntrinsicOp::MOP_WorldRayOrigin, TranslateRayQueryFloat3Getter, DXIL::OpCode::RayQuery_WorldRayOrigin}, - {IntrinsicOp::MOP_Fill, TranslateWaveMatrixFill, - DXIL::OpCode::WaveMatrix_Fill}, - {IntrinsicOp::MOP_MatrixDepth, TranslateWaveMatrixDepth, - DXIL::OpCode::WaveMatrix_Depth}, - {IntrinsicOp::MOP_ScalarAdd, TranslateWaveMatrixScalarOp, - DXIL::OpCode::WaveMatrix_ScalarOp}, - {IntrinsicOp::MOP_ScalarDivide, TranslateWaveMatrixScalarOp, - DXIL::OpCode::WaveMatrix_ScalarOp}, - {IntrinsicOp::MOP_ScalarMultiply, TranslateWaveMatrixScalarOp, - DXIL::OpCode::WaveMatrix_ScalarOp}, - {IntrinsicOp::MOP_ScalarSubtract, TranslateWaveMatrixScalarOp, - DXIL::OpCode::WaveMatrix_ScalarOp}, - {IntrinsicOp::MOP_SumAccumulate, TranslateWaveMatrix_Accumulate, - DXIL::OpCode::WaveMatrix_SumAccumulate}, - {IntrinsicOp::MOP_Add, TranslateWaveMatrix_Accumulate, - DXIL::OpCode::WaveMatrix_Add}, - {IntrinsicOp::MOP_Multiply, TranslateWaveMatrixMultiply, - DXIL::OpCode::WaveMatrix_Multiply}, - {IntrinsicOp::MOP_MultiplyAccumulate, TranslateWaveMatrixMultiply, - DXIL::OpCode::WaveMatrix_MultiplyAccumulate}, {IntrinsicOp::MOP_Count, TranslateNodeGetInputRecordCount, DXIL::OpCode::GetInputRecordCount}, {IntrinsicOp::MOP_FinishedCrossGroupSharing, diff --git a/lib/HLSL/HLOperations.cpp b/lib/HLSL/HLOperations.cpp index 3b3ada8815..2cb3c489e8 100644 --- a/lib/HLSL/HLOperations.cpp +++ b/lib/HLSL/HLOperations.cpp @@ -46,7 +46,6 @@ static StringRef HLOpcodeGroupNames[]{ "indexnodehandle", // HLIndexNodeHandle: "createnodeinputrecordhandle", // HLCreateNodeInputRecordHandle "annotatehandle", // HLAnnotateHandle, - "wavematrix_annotate", // HLWaveMatrix_Annotate, "annotatenodehandle", // HLAnnotateNodeHandle "annotatenoderecordhandle", // HLAnnotateNodeRecordHandle "numOfHLDXIL", // NumOfHLOps @@ -71,7 +70,6 @@ static StringRef HLOpcodeGroupFullNames[]{ "dx.hl.indexnodehandle", // HLIndexNodeHandle "dx.hl.createnodeinputrecordhandle", // HLCreateNodeInputRecordHandle "dx.hl.annotatehandle", // HLAnnotateHandle, - "dx.hl.wavematrix_annotate", // HLWaveMatrix_Annotate, "dx.hl.annotatenodehandle", // HLAnnotateNodeHandle, "dx.hl.annotatenoderecordhandle", // HLAnnotateNodeRecordHandle "numOfHLDXIL", // NumOfHLOps @@ -96,7 +94,6 @@ static HLOpcodeGroup GetHLOpcodeGroupInternal(StringRef group) { .Case("createnodeinputrecordhandle", HLOpcodeGroup::HLCreateNodeInputRecordHandle) .Case("annotatehandle", HLOpcodeGroup::HLAnnotateHandle) - .Case("wavematrix_annotate", HLOpcodeGroup::HLWaveMatrix_Annotate) .Case("annotatenodehandle", HLOpcodeGroup::HLAnnotateNodeHandle) .Case("annotatenoderecordhandle", HLOpcodeGroup::HLAnnotateNodeRecordHandle) @@ -155,7 +152,6 @@ StringRef GetHLOpcodeGroupName(HLOpcodeGroup op) { case HLOpcodeGroup::HLIndexNodeHandle: case HLOpcodeGroup::HLCreateNodeInputRecordHandle: case HLOpcodeGroup::HLAnnotateHandle: - case HLOpcodeGroup::HLWaveMatrix_Annotate: case HLOpcodeGroup::HLAnnotateNodeHandle: case HLOpcodeGroup::HLAnnotateNodeRecordHandle: return HLOpcodeGroupNames[static_cast(op)]; @@ -180,7 +176,6 @@ StringRef GetHLOpcodeGroupFullName(HLOpcodeGroup op) { case HLOpcodeGroup::HLIndexNodeHandle: case HLOpcodeGroup::HLCreateNodeInputRecordHandle: case HLOpcodeGroup::HLAnnotateHandle: - case HLOpcodeGroup::HLWaveMatrix_Annotate: case HLOpcodeGroup::HLAnnotateNodeHandle: case HLOpcodeGroup::HLAnnotateNodeRecordHandle: return HLOpcodeGroupFullNames[static_cast(op)]; @@ -525,9 +520,6 @@ static AttributeSet GetHLFunctionAttributes(LLVMContext &C, case HLOpcodeGroup::HLAnnotateHandle: { addAttr(Attribute::ReadNone); } break; - case HLOpcodeGroup::HLWaveMatrix_Annotate: { - addAttr(Attribute::ArgMemOnly); - } break; case HLOpcodeGroup::HLIntrinsic: { IntrinsicOp intrinsicOp = static_cast(opcode); switch (intrinsicOp) { diff --git a/lib/Transforms/IPO/PassManagerBuilder.cpp b/lib/Transforms/IPO/PassManagerBuilder.cpp index feb3d974b9..46dc5508b9 100644 --- a/lib/Transforms/IPO/PassManagerBuilder.cpp +++ b/lib/Transforms/IPO/PassManagerBuilder.cpp @@ -289,9 +289,6 @@ void PassManagerBuilder::addHLSLPasses(legacy::PassManagerBase &MPM) { // Verify no undef resource again after promotion MPM.add(createInvalidateUndefResourcesPass()); - // Translate HL WaveMatrix ptrs to final dxil type - MPM.add(createLowerWaveMatTypePass()); - MPM.add(createDxilGenerationPass(NoOpt, this->HLSLExtensionsCodeGen)); // Propagate precise attribute. diff --git a/lib/Transforms/Scalar/LowerTypePasses.cpp b/lib/Transforms/Scalar/LowerTypePasses.cpp index ad0da67428..feeb23a5da 100644 --- a/lib/Transforms/Scalar/LowerTypePasses.cpp +++ b/lib/Transforms/Scalar/LowerTypePasses.cpp @@ -900,145 +900,3 @@ INITIALIZE_PASS(ResourceToHandle, "resource-handle", ModulePass *llvm::createResourceToHandlePass() { return new ResourceToHandle(); } - -//===----------------------------------------------------------------------===// -// Lower WaveMatrix types to single dxil type. -//===----------------------------------------------------------------------===// - -namespace { - -class LowerWaveMatType : public LowerTypePass { -public: - explicit LowerWaveMatType() : LowerTypePass(ID) {} - static char ID; // Pass identification, replacement for typeid -protected: - bool needToLower(Value *V) override; - void lowerUseWithNewValue(Value *V, Value *NewV) override; - Type *lowerType(Type *Ty) override; - Constant *lowerInitVal(Constant *InitVal, Type *NewTy) override; - StringRef getGlobalPrefix() override { return ".res"; } - void initialize(Module &M) override; - -private: - void lowerUserWithNewValue(User *U, Value *V, Value *NewV); - - Type *m_WaveMatTy = nullptr; - HLModule *m_pHLM = nullptr; -}; - -void LowerWaveMatType::initialize(Module &M) { - DXASSERT(M.HasHLModule(), "require HLModule"); - m_pHLM = &M.GetHLModule(); - m_WaveMatTy = m_pHLM->GetOP()->GetWaveMatPtrType()->getPointerElementType(); -} - -bool LowerWaveMatType::needToLower(Value *V) { - return dxilutil::IsHLSLWaveMatrixType(dxilutil::GetArrayEltTy(V->getType())); -} - -Type *LowerWaveMatType::lowerType(Type *Ty) { - if (Ty->isPointerTy()) { - return PointerType::get(lowerType(Ty->getPointerElementType()), - Ty->getPointerAddressSpace()); - } else if (Ty->isArrayTy()) { - llvm::SmallVector OuterToInnerLengths; - Ty = dxilutil::StripArrayTypes(Ty, &OuterToInnerLengths); - DXASSERT(dxilutil::IsHLSLWaveMatrixType(Ty), - "otherwise, unexpected wave matrix type to lower"); - return dxilutil::WrapInArrayTypes(m_WaveMatTy, OuterToInnerLengths); - } else if (dxilutil::IsHLSLWaveMatrixType(Ty)) { - return m_WaveMatTy; - } - DXASSERT(0, "otherwise, unexpected wave matrix type to lower"); - return Ty; -} - -Constant *LowerWaveMatType::lowerInitVal(Constant *InitVal, Type *NewTy) { - DXASSERT(isa(InitVal), "wave matrix cannot have real init val"); - return UndefValue::get(NewTy); -} - -// Rewrite call, replacing argument with new type -static CallInst *RewriteIntrinsicCallForNewArg(CallInst *CI, Value *OldV, - Value *NewV, - Type *NewRet = nullptr) { - Function *F = CI->getCalledFunction(); - HLOpcodeGroup group = GetHLOpcodeGroupByName(F); - unsigned opcode = GetHLOpcode(CI); - SmallVector newArgTypes(CI->getFunctionType()->param_begin(), - CI->getFunctionType()->param_end()); - SmallVector newArgs(CI->arg_operands()); - - for (unsigned i = 1; i < newArgs.size(); i++) { - if (newArgs[i] == OldV) { - newArgTypes[i] = NewV->getType(); - newArgs[i] = NewV; - } - } - - if (NewRet == nullptr) - NewRet = CI->getType(); - - FunctionType *newFuncTy = FunctionType::get(NewRet, newArgTypes, false); - Function *newF = - GetOrCreateHLFunction(*F->getParent(), newFuncTy, group, opcode, - F->getAttributes().getFnAttributes()); - IRBuilder<> Builder(CI); - return Builder.CreateCall(newF, newArgs); -} - -void LowerWaveMatType::lowerUserWithNewValue(User *U, Value *V, Value *NewV) { - if (CallInst *CI = dyn_cast(U)) { - HLOpcodeGroup group = GetHLOpcodeGroupByName(CI->getCalledFunction()); - if (group == HLOpcodeGroup::HLWaveMatrix_Annotate || - group == HLOpcodeGroup::HLIntrinsic) { - Type *NewRet = needToLower(CI) ? lowerType(CI->getType()) : nullptr; - Value *NewU = RewriteIntrinsicCallForNewArg(CI, V, NewV, NewRet); - if (!U->user_empty()) { - if (NewRet) - lowerUseWithNewValue(U, NewU); - else - U->replaceAllUsesWith(NewU); - } - return; - } - } else if (BitCastInst *BI = dyn_cast(U)) { - BI->setOperand(0, NewV); - return; - } - - DXASSERT(0, "invalid operation on WaveMatrix pointer"); -} - -void LowerWaveMatType::lowerUseWithNewValue(Value *V, Value *NewV) { - SmallVector deadInsts; - for (auto it = V->user_begin(); it != V->user_end();) { - User *U = *it; - // Prevent double User iteration when multiple Uses in same User - while (it != V->user_end() && *it == U) - ++it; - if (GEPOperator *GEP = dyn_cast(U)) { - if (!GEP->user_empty()) - lowerUseWithNewValue(U, dxilutil::MirrorGEP(GEP, NewV)); - } else { - lowerUserWithNewValue(U, V, NewV); - } - if (Instruction *I = dyn_cast(U)) - if (I->user_empty()) - deadInsts.push_back(I); - } - for (auto I : deadInsts) - I->eraseFromParent(); -} - -} // namespace - -char LowerWaveMatType::ID = 0; - -INITIALIZE_PASS(LowerWaveMatType, "hlsl-lower-wavematrix-type", - "Lower WaveMatrix types to dxil type", false, false) - -// Public interface to the LowerWaveMatType pass -ModulePass *llvm::createLowerWaveMatTypePass() { - return new LowerWaveMatType(); -} diff --git a/tools/clang/include/clang/AST/HlslTypes.h b/tools/clang/include/clang/AST/HlslTypes.h index 4b0e93fdba..d11fd598e6 100644 --- a/tools/clang/include/clang/AST/HlslTypes.h +++ b/tools/clang/include/clang/AST/HlslTypes.h @@ -388,8 +388,6 @@ clang::CXXRecordDecl *DeclareUIntTemplatedTypeWithHandleInDeclContext( clang::CXXRecordDecl *DeclareConstantBufferViewType(clang::ASTContext &context, bool bTBuf); clang::CXXRecordDecl *DeclareRayQueryType(clang::ASTContext &context); -clang::CXXRecordDecl *DeclareWaveMatrixType(clang::ASTContext &context, - DXIL::WaveMatrixKind kind); clang::CXXRecordDecl *DeclareResourceType(clang::ASTContext &context, bool bSampler); @@ -524,7 +522,6 @@ bool GetIntrinsicOp(const clang::FunctionDecl *FD, unsigned &opcode, llvm::StringRef &group); bool GetIntrinsicLowering(const clang::FunctionDecl *FD, llvm::StringRef &S); -llvm::StringRef GetWaveMatrixName(DXIL::WaveMatrixKind kind); bool IsUserDefinedRecordType(clang::QualType type); bool DoesTypeDefineOverloadedOperator(clang::QualType typeWithOperator, clang::OverloadedOperatorKind opc, diff --git a/tools/clang/lib/AST/ASTContextHLSL.cpp b/tools/clang/lib/AST/ASTContextHLSL.cpp index bf048ac75d..3c058950e0 100644 --- a/tools/clang/lib/AST/ASTContextHLSL.cpp +++ b/tools/clang/lib/AST/ASTContextHLSL.cpp @@ -1161,24 +1161,6 @@ CXXRecordDecl *hlsl::DeclareRayQueryType(ASTContext &context) { return typeDeclBuilder.getRecordDecl(); } -clang::CXXRecordDecl *hlsl::DeclareWaveMatrixType(clang::ASTContext &context, - DXIL::WaveMatrixKind kind) { - StringRef Name = GetWaveMatrixName(kind); - BuiltinTypeDeclBuilder typeDeclBuilder(context.getTranslationUnitDecl(), - Name); - typeDeclBuilder.addTypeTemplateParam("element"); - typeDeclBuilder.addIntegerTemplateParam("dimM", context.UnsignedIntTy); - typeDeclBuilder.addIntegerTemplateParam("dimN", context.UnsignedIntTy); - - typeDeclBuilder.startDefinition(); - CXXRecordDecl *templateRecordDecl = typeDeclBuilder.getRecordDecl(); - - // Add an 'h' field to hold the handle. - typeDeclBuilder.addField("h", context.UnsignedIntTy); - - return templateRecordDecl; -} - CXXRecordDecl *hlsl::DeclareResourceType(ASTContext &context, bool bSampler) { // struct ResourceDescriptor { uint8 desc; } StringRef Name = bSampler ? ".Sampler" : ".Resource"; @@ -1375,15 +1357,6 @@ bool hlsl::GetIntrinsicLowering(const clang::FunctionDecl *FD, return true; } -llvm::StringRef hlsl::GetWaveMatrixName(DXIL::WaveMatrixKind kind) { - DXASSERT_NOMSG(kind < DXIL::WaveMatrixKind::NumKinds); - static const char *typeNames[(unsigned)DXIL::WaveMatrixKind::NumKinds] = { - "WaveMatrixLeft", "WaveMatrixRight", "WaveMatrixLeftColAcc", - "WaveMatrixRightRowAcc", "WaveMatrixAccumulator", - }; - return typeNames[(unsigned)kind]; -} - /// Parses a column or row digit. static bool TryParseColOrRowChar(const char digit, int *count) { if ('1' <= digit && digit <= '4') { diff --git a/tools/clang/lib/AST/HlslTypes.cpp b/tools/clang/lib/AST/HlslTypes.cpp index 34275a48f1..d83b307463 100644 --- a/tools/clang/lib/AST/HlslTypes.cpp +++ b/tools/clang/lib/AST/HlslTypes.cpp @@ -429,15 +429,6 @@ void GetRowsAndColsForAny(QualType type, uint32_t &rowCount, const TemplateArgument &arg1 = argList[1]; llvm::APSInt rowSize = arg1.getAsIntegral(); colCount = rowSize.getLimitedValue(); - } else if (templateDecl->getName().startswith("WaveMatrix")) { - auto name = templateDecl->getName(); - if (name == "WaveMatrixLeft" || name == "WaveMatrixRight" || - name == "WaveMatrixLeftColAcc" || name == "WaveMatrixRightRowAcc" || - name == "WaveMatrixAccumulator") { - const TemplateArgumentList &argList = templateDecl->getTemplateArgs(); - rowCount = argList[1].getAsIntegral().getLimitedValue(); - colCount = argList[2].getAsIntegral().getLimitedValue(); - } } } } diff --git a/tools/clang/lib/CodeGen/CGHLSLMS.cpp b/tools/clang/lib/CodeGen/CGHLSLMS.cpp index c78a018fc7..72f5a791ab 100644 --- a/tools/clang/lib/CodeGen/CGHLSLMS.cpp +++ b/tools/clang/lib/CodeGen/CGHLSLMS.cpp @@ -43,7 +43,6 @@ #include "dxc/DXIL/DxilCBuffer.h" #include "dxc/DXIL/DxilResourceProperties.h" -#include "dxc/DXIL/DxilWaveMatrix.h" #include "dxc/DxilRootSignature/DxilRootSignature.h" #include "dxc/HLSL/DxilExportMap.h" #include "dxc/HLSL/DxilGenerationPass.h" // support pause/resume passes @@ -214,7 +213,6 @@ class CGMSHLSLRuntime : public CGHLSLRuntime { unsigned AddTypeAnnotation(QualType Ty, DxilTypeSystem &dxilTypeSys, unsigned &arrayEltSize); DxilResourceProperties BuildResourceProperty(QualType resTy); - DxilWaveMatrixProperties BuildWaveMatrixProperties(QualType resTy); void ConstructFieldAttributedAnnotation(DxilFieldAnnotation &fieldAnnotation, QualType fieldTy, bool bDefaultRowMajor); @@ -764,31 +762,8 @@ DxilResourceProperties CGMSHLSLRuntime::BuildResourceProperty(QualType resTy) { return RP; } -DxilWaveMatrixProperties -CGMSHLSLRuntime::BuildWaveMatrixProperties(QualType qualTy) { - DxilWaveMatrixProperties props; - llvm::Type *Ty = CGM.getTypes().ConvertType(qualTy); - if (dxilutil::IsHLSLWaveMatrixType(Ty, &props.kind)) { - const CXXRecordDecl *CXXRD = - qualTy.getCanonicalType()->getAsCXXRecordDecl(); - if (const ClassTemplateSpecializationDecl *templateSpecializationDecl = - dyn_cast(CXXRD)) { - const clang::TemplateArgumentList &args = - templateSpecializationDecl->getTemplateInstantiationArgs(); - DXASSERT(args[0].getAsType()->isBuiltinType(), - "otherwise, wrong kind of component type"); - const BuiltinType *BTy = args[0].getAsType()->getAs(); - props.compType = BuiltinTyToCompTy(BTy, false, false); - props.dimM = (unsigned)args[1].getAsIntegral().getExtValue(); - props.dimN = (unsigned)args[2].getAsIntegral().getExtValue(); - } - } - return props; -} - bool CGMSHLSLRuntime::AddValToPropertyMap(Value *V, QualType Ty) { - return objectProperties.AddResource(V, BuildResourceProperty(Ty)) || - objectProperties.AddWaveMatrix(V, BuildWaveMatrixProperties(Ty)); + return objectProperties.AddResource(V, BuildResourceProperty(Ty)); } void CGMSHLSLRuntime::ConstructFieldAttributedAnnotation( diff --git a/tools/clang/lib/CodeGen/CGHLSLMSFinishCodeGen.cpp b/tools/clang/lib/CodeGen/CGHLSLMSFinishCodeGen.cpp index 41678d9dfa..8af96cc3cd 100644 --- a/tools/clang/lib/CodeGen/CGHLSLMSFinishCodeGen.cpp +++ b/tools/clang/lib/CodeGen/CGHLSLMSFinishCodeGen.cpp @@ -37,7 +37,6 @@ #include "dxc/DXIL/DxilResourceProperties.h" #include "dxc/DXIL/DxilTypeSystem.h" #include "dxc/DXIL/DxilUtil.h" -#include "dxc/DXIL/DxilWaveMatrix.h" #include "dxc/DxilRootSignature/DxilRootSignature.h" #include "dxc/HLSL/DxilExportMap.h" #include "dxc/HLSL/DxilGenerationPass.h" @@ -176,18 +175,6 @@ Value *CastHandleToRes(HLModule &HLM, Value *Handle, llvm::Type *ResTy, return Res; } -CallInst *CreateAnnotateWaveMatrix(HLModule &HLM, Value *WaveMatrixPtr, - DxilWaveMatrixProperties &WMP, - IRBuilder<> &Builder) { - Constant *WMPConstant = wavemat_helper::GetAsConstant( - WMP, HLM.GetOP()->GetWaveMatrixPropertiesType()); - CallInst *CI = HLM.EmitHLOperationCall( - Builder, HLOpcodeGroup::HLWaveMatrix_Annotate, - (unsigned)HLOpcodeGroup::HLWaveMatrix_Annotate, WaveMatrixPtr->getType(), - {WaveMatrixPtr, WMPConstant}, *HLM.GetModule()); - return CI; -} - // Lower CBV bitcast use to handle use. // Leave the load/store. void LowerDynamicCBVUseToHandle(HLModule &HLM, @@ -749,36 +736,6 @@ GetResourcePropsFromIntrinsicObjectArg(Value *arg, HLModule &HLM, return RP; } -void AddAnnotateWaveMatrix(HLModule &HLM, - DxilObjectProperties &objectProperties) { - for (auto it : objectProperties.waveMatMap) { - Value *V = it.first; - DxilWaveMatrixProperties &WMP = it.second; - // annotate Alloca, Param, or Global - if (AllocaInst *AI = dyn_cast(V)) { - // Insert annotation after alloca - IRBuilder<> Builder(AI->getNextNode()); - CreateAnnotateWaveMatrix(HLM, V, WMP, Builder); - } else if (GlobalVariable *GV = dyn_cast(V)) { - // Insert annotation in each function's entry block with users - SmallSetVector functions; - for (auto U : GV->users()) - if (Instruction *I = dyn_cast(U)) - functions.insert(I->getParent()->getParent()); - - for (auto F : functions) { - IRBuilder<> Builder(dxilutil::FindAllocaInsertionPt(F)); - CreateAnnotateWaveMatrix(HLM, V, WMP, Builder); - } - } else if (Argument *Arg = dyn_cast(V)) { - IRBuilder<> Builder(dxilutil::FindAllocaInsertionPt(Arg->getParent())); - CreateAnnotateWaveMatrix(HLM, V, WMP, Builder); - } else { - llvm_unreachable("WaveMatrix value is unexpected type"); - } - } -} - void AddOpcodeParamForIntrinsic(HLModule &HLM, Function *F, unsigned opcode, DxilObjectProperties &objectProperties) { llvm::Module &M = *HLM.GetModule(); @@ -3557,9 +3514,6 @@ void FinishIntrinsics( // Lower bitcast use of CBV into cbSubscript. LowerDynamicCBVUseToHandle(HLM, objectProperties); - // Add AnnotateWaveMatrix - AddAnnotateWaveMatrix(HLM, objectProperties); - // translate opcode into parameter for intrinsic functions // Do this before CloneShaderEntry and TranslateRayQueryConstructor to avoid // update valToResPropertiesMap for cloned inst. @@ -4074,27 +4028,4 @@ void DxilObjectProperties::updateGLC(llvm::Value *V) { it->second.Basic.IsGloballyCoherent ^= 1; } -bool DxilObjectProperties::AddWaveMatrix( - llvm::Value *V, const hlsl::DxilWaveMatrixProperties &WMP) { - if (WMP.isValid()) { - DXASSERT(!GetWaveMatrix(V).isValid() || GetWaveMatrix(V) == WMP, - "otherwise, property conflict"); - waveMatMap[V] = WMP; - return true; - } - return false; -} - -bool DxilObjectProperties::IsWaveMatrix(llvm::Value *V) { - return waveMatMap.count(V) != 0; -} - -hlsl::DxilWaveMatrixProperties -DxilObjectProperties::GetWaveMatrix(llvm::Value *V) { - auto it = waveMatMap.find(V); - if (it != waveMatMap.end()) - return it->second; - return DxilWaveMatrixProperties(); -} - } // namespace CGHLSLMSHelper diff --git a/tools/clang/lib/CodeGen/CGHLSLMSHelper.h b/tools/clang/lib/CodeGen/CGHLSLMSHelper.h index 74e1f06347..9058ed4f6d 100644 --- a/tools/clang/lib/CodeGen/CGHLSLMSHelper.h +++ b/tools/clang/lib/CodeGen/CGHLSLMSHelper.h @@ -37,7 +37,6 @@ template class SmallVector; namespace hlsl { class HLModule; struct DxilResourceProperties; -struct DxilWaveMatrixProperties; struct DxilFunctionProps; class DxilFieldAnnotation; enum class IntrinsicOp; @@ -162,13 +161,8 @@ struct DxilObjectProperties { hlsl::DxilResourceProperties GetResource(llvm::Value *V); void updateGLC(llvm::Value *V); - bool AddWaveMatrix(llvm::Value *V, const hlsl::DxilWaveMatrixProperties &WMP); - bool IsWaveMatrix(llvm::Value *V); - hlsl::DxilWaveMatrixProperties GetWaveMatrix(llvm::Value *V); - // MapVector for deterministic iteration order. llvm::MapVector resMap; - llvm::MapVector waveMatMap; }; void CopyAndAnnotateResourceArgument(llvm::Value *Src, llvm::Value *Dest, diff --git a/tools/clang/lib/Sema/SemaHLSL.cpp b/tools/clang/lib/Sema/SemaHLSL.cpp index 4df32f9a65..0d05c6825f 100644 --- a/tools/clang/lib/Sema/SemaHLSL.cpp +++ b/tools/clang/lib/Sema/SemaHLSL.cpp @@ -226,13 +226,6 @@ enum ArBasicKind { AR_OBJECT_RWTEXTURE2DMS, AR_OBJECT_RWTEXTURE2DMS_ARRAY, - // WaveMatrix - AR_OBJECT_WAVE_MATRIX_LEFT, - AR_OBJECT_WAVE_MATRIX_RIGHT, - AR_OBJECT_WAVE_MATRIX_LEFT_COL_ACC, - AR_OBJECT_WAVE_MATRIX_RIGHT_ROW_ACC, - AR_OBJECT_WAVE_MATRIX_ACCUMULATOR, - // Work Graphs AR_OBJECT_EMPTY_NODE_INPUT, AR_OBJECT_DISPATCH_NODE_INPUT_RECORD, @@ -285,15 +278,6 @@ enum ArBasicKind { case AR_OBJECT_DEPTHSTENCIL: \ case AR_OBJECT_STATEBLOCK -#define AR_BASIC_WAVE_MATRIX_INPUT_CASES \ - case AR_OBJECT_WAVE_MATRIX_LEFT: \ - case AR_OBJECT_WAVE_MATRIX_RIGHT - -#define AR_BASIC_WAVE_MATRIX_ACC_FRAG_CASES \ - case AR_OBJECT_WAVE_MATRIX_LEFT_COL_ACC: \ - case AR_OBJECT_WAVE_MATRIX_RIGHT_ROW_ACC: \ - case AR_OBJECT_WAVE_MATRIX_ACCUMULATOR - // // Properties of entries in the ArBasicKind enumeration. // These properties are intended to allow easy identification @@ -353,11 +337,6 @@ enum ArBasicKind { #define BPROP_FEEDBACKTEXTURE \ 0x00800000 // Whether the type is a feedback texture. #define BPROP_ENUM 0x01000000 // Whether the type is a enum -#define BPROP_WAVE_MATRIX_INPUT \ - 0x02000000 // Whether the type is a wave matrix input object (Left/Right) -#define BPROP_WAVE_MATRIX_ACC \ - 0x04000000 // Whether the type is a wave matrix accum object - // (Accumulator/LeftColAcc/RightRowAcc) #define GET_BPROP_PRIM_KIND(_Props) \ ((_Props) & (BPROP_BOOLEAN | BPROP_INTEGER | BPROP_FLOATING)) @@ -397,11 +376,6 @@ enum ArBasicKind { #define IS_BPROP_ENUM(_Props) (((_Props)&BPROP_ENUM) != 0) -#define IS_BPROP_WAVE_MATRIX_INPUT(_Props) \ - (((_Props) & BPROP_WAVE_MATRIX_INPUT) != 0) -#define IS_BPROP_WAVE_MATRIX_ACC(_Props) \ - (((_Props) & BPROP_WAVE_MATRIX_ACC) != 0) - const UINT g_uBasicKindProps[] = { BPROP_PRIMITIVE | BPROP_BOOLEAN | BPROP_INTEGER | BPROP_NUMERIC | BPROP_BITS0, // AR_BASIC_BOOL @@ -599,12 +573,6 @@ const UINT g_uBasicKindProps[] = { BPROP_OBJECT | BPROP_RWBUFFER, // AR_OBJECT_RWTEXTURE2DMS BPROP_OBJECT | BPROP_RWBUFFER, // AR_OBJECT_RWTEXTURE2DMS_ARRAY - BPROP_OBJECT | BPROP_WAVE_MATRIX_INPUT, // AR_OBJECT_WAVE_MATRIX_LEFT - BPROP_OBJECT | BPROP_WAVE_MATRIX_INPUT, // AR_OBJECT_WAVE_MATRIX_RIGHT - BPROP_OBJECT | BPROP_WAVE_MATRIX_ACC, // AR_OBJECT_WAVE_MATRIX_LEFT_COL_ACC - BPROP_OBJECT | BPROP_WAVE_MATRIX_ACC, // AR_OBJECT_WAVE_MATRIX_RIGHT_ROW_ACC - BPROP_OBJECT | BPROP_WAVE_MATRIX_ACC, // AR_OBJECT_WAVE_MATRIX_ACCUMULATOR - // WorkGraphs BPROP_OBJECT, // AR_OBJECT_EMPTY_NODE_INPUT BPROP_OBJECT, // AR_OBJECT_DISPATCH_NODE_INPUT_RECORD @@ -658,13 +626,6 @@ C_ASSERT(ARRAYSIZE(g_uBasicKindProps) == AR_BASIC_MAXIMUM_COUNT); #define IS_BASIC_ENUM(_Kind) IS_BPROP_ENUM(GetBasicKindProps(_Kind)) -#define IS_BASIC_WAVE_MATRIX_INPUT(_Kind) \ - IS_BPROP_WAVE_MATRIX_INPUT(GetBasicKindProps(_Kind)) -#define IS_BASIC_WAVE_MATRIX_ACC(_Kind) \ - IS_BPROP_WAVE_MATRIX_ACC(GetBasicKindProps(_Kind)) -#define IS_BASIC_WAVE_MATRIX(_Kind) \ - (IS_BASIC_WAVE_MATRIX_INPUT(_Kind) || IS_BASIC_WAVE_MATRIX_ACC(_Kind)) - #define BITWISE_ENUM_OPS(_Type) \ inline _Type operator|(_Type F1, _Type F2) { \ return (_Type)((UINT)F1 | (UINT)F2); \ @@ -979,34 +940,6 @@ GetOrCreateVectorSpecialization(ASTContext &context, Sema *sema, return vectorSpecializationType; } -// Gets component type, dimM, and dimN from WaveMatrix* instantiated type. -// Assumes wave matrix type, returns false if anything isn't as expected. -static bool GetWaveMatrixTemplateValues(QualType objType, QualType *compType, - unsigned *dimM, unsigned *dimN) { - const CXXRecordDecl *CXXRD = objType.getCanonicalType()->getAsCXXRecordDecl(); - if (const ClassTemplateSpecializationDecl *templateSpecializationDecl = - dyn_cast(CXXRD)) { - const clang::TemplateArgumentList &args = - templateSpecializationDecl->getTemplateInstantiationArgs(); - if (args.size() != 3) - return false; - if (args[0].getKind() != TemplateArgument::Type || - !args[0].getAsType()->isBuiltinType()) - return false; - if (args[1].getKind() != TemplateArgument::Integral || - args[2].getKind() != TemplateArgument::Integral) - return false; - if (compType) - *compType = args[0].getAsType(); - if (dimM) - *dimM = (unsigned)args[1].getAsIntegral().getExtValue(); - if (dimN) - *dimN = (unsigned)args[2].getAsIntegral().getExtValue(); - return true; - } - return false; -} - /// Instantiates a new *NodeOutputRecords type specialization or gets /// an existing one from the AST. static QualType @@ -1248,21 +1181,6 @@ static const ArBasicKind g_ByteAddressBufferCT[] = { static const ArBasicKind g_RWByteAddressBufferCT[] = { AR_OBJECT_RWBYTEADDRESS_BUFFER, AR_BASIC_UNKNOWN}; -static const ArBasicKind g_WaveMatrixLeftCT[] = {AR_OBJECT_WAVE_MATRIX_LEFT, - AR_BASIC_UNKNOWN}; - -static const ArBasicKind g_WaveMatrixRightCT[] = {AR_OBJECT_WAVE_MATRIX_RIGHT, - AR_BASIC_UNKNOWN}; - -static const ArBasicKind g_WaveMatrixLeftColAccCT[] = { - AR_OBJECT_WAVE_MATRIX_LEFT_COL_ACC, AR_BASIC_UNKNOWN}; - -static const ArBasicKind g_WaveMatrixRightRowAccCT[] = { - AR_OBJECT_WAVE_MATRIX_RIGHT_ROW_ACC, AR_BASIC_UNKNOWN}; - -static const ArBasicKind g_WaveMatrixAccumulatorCT[] = { - AR_OBJECT_WAVE_MATRIX_ACCUMULATOR, AR_BASIC_UNKNOWN}; - static const ArBasicKind g_NodeRecordOrUAVCT[] = { AR_OBJECT_DISPATCH_NODE_INPUT_RECORD, AR_OBJECT_RWDISPATCH_NODE_INPUT_RECORD, @@ -1345,11 +1263,6 @@ const ArBasicKind *g_LegalIntrinsicCompTypes[] = { g_ByteAddressBufferCT, // LICOMPTYPE_BYTEADDRESSBUFFER g_RWByteAddressBufferCT, // LICOMPTYPE_RWBYTEADDRESSBUFFER - g_WaveMatrixLeftCT, // LICOMPTYPE_WAVE_MATRIX_LEFT - g_WaveMatrixRightCT, // LICOMPTYPE_WAVE_MATRIX_RIGHT - g_WaveMatrixLeftColAccCT, // LICOMPTYPE_WAVE_MATRIX_LEFT_COL_ACC - g_WaveMatrixRightRowAccCT, // LICOMPTYPE_WAVE_MATRIX_RIGHT_ROW_ACC - g_WaveMatrixAccumulatorCT, // LICOMPTYPE_WAVE_MATRIX_ACCUMULATOR g_NodeRecordOrUAVCT, // LICOMPTYPE_NODE_RECORD_OR_UAV g_AnyOutputRecordCT, // LICOMPTYPE_ANY_NODE_OUTPUT_RECORD g_GroupNodeOutputRecordsCT, // LICOMPTYPE_GROUP_NODE_OUTPUT_RECORDS @@ -1433,10 +1346,6 @@ static const ArBasicKind g_ArBasicKindsAsTypes[] = { AR_OBJECT_RWTEXTURE2DMS, // RWTexture2DMS AR_OBJECT_RWTEXTURE2DMS_ARRAY, // RWTexture2DMSArray - AR_OBJECT_WAVE_MATRIX_LEFT, AR_OBJECT_WAVE_MATRIX_RIGHT, - AR_OBJECT_WAVE_MATRIX_LEFT_COL_ACC, AR_OBJECT_WAVE_MATRIX_RIGHT_ROW_ACC, - AR_OBJECT_WAVE_MATRIX_ACCUMULATOR, - // Work Graphs AR_OBJECT_EMPTY_NODE_INPUT, AR_OBJECT_DISPATCH_NODE_INPUT_RECORD, AR_OBJECT_RWDISPATCH_NODE_INPUT_RECORD, AR_OBJECT_GROUP_NODE_INPUT_RECORDS, @@ -1546,12 +1455,6 @@ static const uint8_t g_ArBasicKindsTemplateCount[] = { 2, // AR_OBJECT_RWTEXTURE2DMS 2, // AR_OBJECT_RWTEXTURE2DMS_ARRAY - 3, // AR_OBJECT_WAVE_MATRIX_LEFT, - 3, // AR_OBJECT_WAVE_MATRIX_RIGHT, - 3, // AR_OBJECT_WAVE_MATRIX_LEFT_COL_ACC, - 3, // AR_OBJECT_WAVE_MATRIX_RIGHT_ROW_ACC, - 3, // AR_OBJECT_WAVE_MATRIX_ACCUMULATOR, - // WorkGraphs 0, // AR_OBJECT_EMPTY_NODE_INPUT, 1, // AR_OBJECT_DISPATCH_NODE_INPUT_RECORD, @@ -1698,12 +1601,6 @@ static const SubscriptOperatorRecord g_ArBasicKindsSubscripts[] = { {3, MipsFalse, SampleTrue}, // AR_OBJECT_RWTEXTURE2DMS_ARRAY (RWTexture2DMSArray) - {0, MipsFalse, SampleFalse}, // AR_OBJECT_WAVE_MATRIX_LEFT, - {0, MipsFalse, SampleFalse}, // AR_OBJECT_WAVE_MATRIX_RIGHT, - {0, MipsFalse, SampleFalse}, // AR_OBJECT_WAVE_MATRIX_LEFT_COL_ACC, - {0, MipsFalse, SampleFalse}, // AR_OBJECT_WAVE_MATRIX_RIGHT_ROW_ACC, - {0, MipsFalse, SampleFalse}, // AR_OBJECT_WAVE_MATRIX_ACCUMULATOR, - // WorkGraphs {0, MipsFalse, SampleFalse}, // AR_OBJECT_EMPTY_NODE_INPUT {0, MipsFalse, SampleFalse}, // AR_OBJECT_DISPATCH_NODE_INPUT_RECORD @@ -1782,9 +1679,6 @@ static const char *g_ArBasicTypeNames[] = { "RWTexture2DMS", "RWTexture2DMSArray", - "WaveMatrixLeft", "WaveMatrixRight", "WaveMatrixLeftColAcc", - "WaveMatrixRightRowAcc", "WaveMatrixAccumulator", - // Workgraphs "EmptyNodeInput", "DispatchNodeInputRecord", "RWDispatchNodeInputRecord", "GroupNodeInputRecords", "RWGroupNodeInputRecords", "ThreadNodeInputRecord", @@ -2384,26 +2278,6 @@ static void GetIntrinsicMethods(ArBasicKind kind, *intrinsics = g_RWTexture2DMSArrayMethods; *intrinsicCount = _countof(g_RWTexture2DMSArrayMethods); break; - case AR_OBJECT_WAVE_MATRIX_LEFT: - *intrinsics = g_WaveMatrixLeftMethods; - *intrinsicCount = _countof(g_WaveMatrixLeftMethods); - break; - case AR_OBJECT_WAVE_MATRIX_RIGHT: - *intrinsics = g_WaveMatrixRightMethods; - *intrinsicCount = _countof(g_WaveMatrixRightMethods); - break; - case AR_OBJECT_WAVE_MATRIX_LEFT_COL_ACC: - *intrinsics = g_WaveMatrixLeftColAccMethods; - *intrinsicCount = _countof(g_WaveMatrixLeftColAccMethods); - break; - case AR_OBJECT_WAVE_MATRIX_RIGHT_ROW_ACC: - *intrinsics = g_WaveMatrixRightRowAccMethods; - *intrinsicCount = _countof(g_WaveMatrixRightRowAccMethods); - break; - case AR_OBJECT_WAVE_MATRIX_ACCUMULATOR: - *intrinsics = g_WaveMatrixAccumulatorMethods; - *intrinsicCount = _countof(g_WaveMatrixAccumulatorMethods); - break; case AR_OBJECT_EMPTY_NODE_INPUT: *intrinsics = g_EmptyNodeInputMethods; *intrinsicCount = _countof(g_EmptyNodeInputMethods); @@ -3732,11 +3606,6 @@ class HLSLExternalSource : public ExternalSemaSource { m_context->getRecordType(recordDecl), *m_context); } - - } else if (IsWaveMatrixBasicKind(kind)) { - recordDecl = DeclareWaveMatrixType( - *m_context, - (DXIL::WaveMatrixKind)(kind - AR_OBJECT_WAVE_MATRIX_LEFT)); } else if (kind == AR_OBJECT_FEEDBACKTEXTURE2D) { recordDecl = DeclareUIntTemplatedTypeWithHandle( *m_context, "FeedbackTexture2D", "kind"); @@ -4098,14 +3967,6 @@ class HLSLExternalSource : public ExternalSemaSource { return IsRayQueryBasicKind(GetTypeElementKind(type)); } - bool IsWaveMatrixBasicKind(ArBasicKind kind) { - return kind >= AR_OBJECT_WAVE_MATRIX_LEFT && - kind <= AR_OBJECT_WAVE_MATRIX_ACCUMULATOR; - } - bool IsWaveMatrixType(QualType type) { - return IsWaveMatrixBasicKind(GetTypeElementKind(type)); - } - void WarnMinPrecision(QualType Type, SourceLocation Loc) { Type = Type->getCanonicalTypeUnqualified(); if (IsVectorType(m_sema, Type) || IsMatrixType(m_sema, Type)) { @@ -4714,12 +4575,6 @@ class HLSLExternalSource : public ExternalSemaSource { case AR_OBJECT_RWTEXTURE2DMS: case AR_OBJECT_RWTEXTURE2DMS_ARRAY: - case AR_OBJECT_WAVE_MATRIX_LEFT: - case AR_OBJECT_WAVE_MATRIX_RIGHT: - case AR_OBJECT_WAVE_MATRIX_LEFT_COL_ACC: - case AR_OBJECT_WAVE_MATRIX_RIGHT_ROW_ACC: - case AR_OBJECT_WAVE_MATRIX_ACCUMULATOR: - case AR_OBJECT_EMPTY_NODE_INPUT: case AR_OBJECT_DISPATCH_NODE_INPUT_RECORD: case AR_OBJECT_RWDISPATCH_NODE_INPUT_RECORD: @@ -6226,46 +6081,6 @@ bool HLSLExternalSource::IsValidObjectElement(LPCSTR tableName, } } -// Given component type of wave matrix object on which a method is called, -// and given the component type of an argument passed by the user, -// return either the user component type, or a valid component type, -// if the user component type is not valid. -static ArBasicKind GetValidWaveMatrixComponentTypeForArg( - ArBasicKind objKind, // wave matrix type for this - ArBasicKind objEltKind, // element type for this - ArBasicKind argKind, // wave matrix type for arg - ArBasicKind argEltKind) { // element type for arg - if (IS_BASIC_WAVE_MATRIX_ACC(objKind) && - IS_BASIC_WAVE_MATRIX_INPUT(argKind)) { - switch (objEltKind) { - case AR_BASIC_FLOAT32: - switch (argEltKind) { - case AR_BASIC_FLOAT32: - case AR_BASIC_FLOAT16: - return argEltKind; - default: - break; - } - // return a valid type (this will be used for error message) - return AR_BASIC_FLOAT32; - case AR_BASIC_INT32: - switch (argEltKind) { - case AR_BASIC_INT8_4PACKED: - case AR_BASIC_UINT8_4PACKED: - return argEltKind; - default: - break; - } - // return a valid type (this will be used for error message) - return AR_BASIC_INT8_4PACKED; - default: - break; - } - } - // In other cases, we return this element kind. - return objEltKind; -} - bool HLSLExternalSource::MatchArguments( const IntrinsicDefIter &cursor, QualType objectType, QualType objectElement, QualType functionTemplateTypeArg, ArrayRef Args, @@ -6834,81 +6649,46 @@ bool HLSLExternalSource::MatchArguments( if ((0 == i) || !(pArgument->qwUsage & AR_QUAL_OUT)) qwQual |= AR_QUAL_CONST; - // If the type is WaveMatrix, construct a template specialization based - // on the template arguments of this wave matrix object in a special way. - if (IsWaveMatrixBasicKind(pEltType)) { - CXXRecordDecl *templateRecordDecl = - GetBasicKindType(pEltType)->getAsCXXRecordDecl(); - if (!templateRecordDecl->isCompleteDefinition()) { - // If template definition is not completed, no instantiations exist, - // so we can assume this candiate does not apply. - badArgIdx = std::min(badArgIdx, i); - return false; - } - - // read template args of objectType - ArTypeInfo objInfo; - CollectInfo(objectType, &objInfo); - ArTypeInfo argInfo; - CollectInfo(Args[i - 1]->getType(), &argInfo); - ArBasicKind eltKind = GetValidWaveMatrixComponentTypeForArg( - objInfo.ObjKind, objInfo.EltKind, argInfo.ObjKind, argInfo.EltKind); - QualType compType = GetBasicKindType(eltKind); - - // Now construct the expected argument specialization - TemplateArgument templateArgs[3] = { - TemplateArgument(compType), - TemplateArgument(*m_context, - llvm::APSInt(llvm::APInt(32, objInfo.uRows)), - m_context->UnsignedIntTy), - TemplateArgument(*m_context, - llvm::APSInt(llvm::APInt(32, objInfo.uCols)), - m_context->UnsignedIntTy)}; - pNewType = GetOrCreateTemplateSpecialization( - *m_context, *m_sema, - templateRecordDecl->getDescribedClassTemplate(), templateArgs); - } else { - DXASSERT_VALIDBASICKIND(pEltType); - pNewType = NewSimpleAggregateType(Template[pArgument->uTemplateId], - pEltType, qwQual, uRows, uCols); - - // If array type, wrap in the argument's array type. - if (i > 0 && Template[pArgument->uTemplateId] == AR_TOBJ_ARRAY) { - QualType arrayElt = Args[i - 1]->getType(); - SmallVector sizes; - while (arrayElt->isArrayType()) { - UINT size = 0; - if (arrayElt->isConstantArrayType()) { - const ConstantArrayType *arrayType = - (const ConstantArrayType *)arrayElt->getAsArrayTypeUnsafe(); - size = arrayType->getSize().getLimitedValue(); - } - arrayElt = QualType(arrayElt->getAsArrayTypeUnsafe() - ->getArrayElementTypeNoTypeQual(), - 0); - sizes.push_back(size); + DXASSERT_VALIDBASICKIND(pEltType); + pNewType = NewSimpleAggregateType(Template[pArgument->uTemplateId], + pEltType, qwQual, uRows, uCols); + + // If array type, wrap in the argument's array type. + if (i > 0 && Template[pArgument->uTemplateId] == AR_TOBJ_ARRAY) { + QualType arrayElt = Args[i - 1]->getType(); + SmallVector sizes; + while (arrayElt->isArrayType()) { + UINT size = 0; + if (arrayElt->isConstantArrayType()) { + const ConstantArrayType *arrayType = + (const ConstantArrayType *)arrayElt->getAsArrayTypeUnsafe(); + size = arrayType->getSize().getLimitedValue(); } - // Wrap element in matching array dimensions: - while (sizes.size()) { - uint64_t size = sizes.pop_back_val(); - if (size) { - pNewType = m_context->getConstantArrayType( - pNewType, llvm::APInt(32, size, false), - ArrayType::ArraySizeModifier::Normal, 0); - } else { - pNewType = m_context->getIncompleteArrayType( - pNewType, ArrayType::ArraySizeModifier::Normal, 0); - } + arrayElt = QualType( + arrayElt->getAsArrayTypeUnsafe()->getArrayElementTypeNoTypeQual(), + 0); + sizes.push_back(size); + } + // Wrap element in matching array dimensions: + while (sizes.size()) { + uint64_t size = sizes.pop_back_val(); + if (size) { + pNewType = m_context->getConstantArrayType( + pNewType, llvm::APInt(32, size, false), + ArrayType::ArraySizeModifier::Normal, 0); + } else { + pNewType = m_context->getIncompleteArrayType( + pNewType, ArrayType::ArraySizeModifier::Normal, 0); } - if (qwQual & AR_QUAL_CONST) - pNewType = QualType(pNewType.getTypePtr(), Qualifiers::Const); + } + if (qwQual & AR_QUAL_CONST) + pNewType = QualType(pNewType.getTypePtr(), Qualifiers::Const); - if (qwQual & AR_QUAL_GROUPSHARED) - pNewType = - m_context->getAddrSpaceQualType(pNewType, DXIL::kTGSMAddrSpace); + if (qwQual & AR_QUAL_GROUPSHARED) + pNewType = + m_context->getAddrSpaceQualType(pNewType, DXIL::kTGSMAddrSpace); - pNewType = m_context->getLValueReferenceType(pNewType); - } + pNewType = m_context->getLValueReferenceType(pNewType); } } @@ -7193,18 +6973,12 @@ void HLSLExternalSource::CollectInfo(QualType type, ArTypeInfo *pTypeInfo) { // when retrieving multiple properties. pTypeInfo->ObjKind = GetTypeElementKind(type); pTypeInfo->ShapeKind = GetTypeObjectKind(type); - if (IsWaveMatrixBasicKind(pTypeInfo->ObjKind)) { - QualType elTy; - GetWaveMatrixTemplateValues(type, &elTy, &pTypeInfo->uRows, - &pTypeInfo->uCols); - pTypeInfo->EltKind = GetTypeElementKind(elTy); - pTypeInfo->EltTy = pTypeInfo->EltTy = GetStructuralForm(elTy).getTypePtr(); - } else { - GetRowsAndColsForAny(type, pTypeInfo->uRows, pTypeInfo->uCols); - pTypeInfo->EltKind = pTypeInfo->ObjKind; - pTypeInfo->EltTy = - GetTypeElementType(type)->getCanonicalTypeUnqualified()->getTypePtr(); - } + + GetRowsAndColsForAny(type, pTypeInfo->uRows, pTypeInfo->uCols); + pTypeInfo->EltKind = pTypeInfo->ObjKind; + pTypeInfo->EltTy = + GetTypeElementType(type)->getCanonicalTypeUnqualified()->getTypePtr(); + pTypeInfo->uTotalElts = pTypeInfo->uRows * pTypeInfo->uCols; } @@ -14337,13 +14111,6 @@ bool Sema::DiagnoseHLSLDecl(Declarator &D, DeclContext *DC, Expr *BitWidth, QualType eltQt(qt->getArrayElementTypeNoTypeQual(), 0); while (eltQt->isArrayType()) eltQt = QualType(eltQt->getArrayElementTypeNoTypeQual(), 0); - if (hlslSource->IsWaveMatrixType(eltQt)) { - StringRef typeName( - g_ArBasicTypeNames[hlslSource->GetTypeElementKind(eltQt)]); - Diag(D.getLocStart(), diag::err_hlsl_array_disallowed) - << typeName << /* declaration */ 1; - result = false; - } if (hlsl::IsObjectType(this, eltQt, &bDeprecatedEffectObject)) { bIsObject = true; } diff --git a/tools/clang/test/HLSLFileCheck/d3dreflect/rdat_mintarget/sm6x_wavemma.hlsl b/tools/clang/test/HLSLFileCheck/d3dreflect/rdat_mintarget/sm6x_wavemma.hlsl deleted file mode 100644 index 63fac06ea7..0000000000 --- a/tools/clang/test/HLSLFileCheck/d3dreflect/rdat_mintarget/sm6x_wavemma.hlsl +++ /dev/null @@ -1,47 +0,0 @@ -// RUN: %dxc -T lib_6_x %s | %D3DReflect %s | %FileCheck %s -check-prefixes=RDAT - -// Ensure min shader target incorporates optional features used - -// SM 6.9+ - -/////////////////////////////////////////////////////////////////////////////// -// ShaderFeatureInfo_WaveMMA (0x8000000) = 134217728 - -RWByteAddressBuffer BAB : register(u1, space0); - -// RDAT-LABEL: UnmangledName: "use_wavematrix" -// RDAT: FeatureInfo1: (WaveOps | WaveMMA) -// RDAT: FeatureInfo2: 0 -// RDAT: ShaderStageFlag: (Compute | Library) -// RDAT: MinShaderTarget: 0x60069 - -[noinline] export -void use_wavematrix() { - // Use WaveMatrix in a minimal way that survives dead code elimination. - WaveMatrixLeft wml; - wml.Fill(0); - wml.Store(BAB, 0, 1024, false); -} - -// RDAT-LABEL: UnmangledName: "call_use_wavematrix" -// RDAT: FeatureInfo1: (WaveOps | WaveMMA) -// RDAT: FeatureInfo2: 0 -// RDAT: ShaderStageFlag: (Compute | Library) -// RDAT: MinShaderTarget: 0x60069 - -[noinline] export -void call_use_wavematrix() { - use_wavematrix(); -} - -// RDAT-LABEL: UnmangledName: "wavematrix_compute" -// RDAT: FeatureInfo1: (WaveOps | WaveMMA) -// RDAT: FeatureInfo2: 0 -// RDAT: ShaderStageFlag: (Compute) -// RDAT: MinShaderTarget: 0x50069 - -[shader("compute")] -[numthreads(1,1,1)] -void wavematrix_compute(uint tidx : SV_GroupIndex) { - call_use_wavematrix(); -} diff --git a/tools/clang/test/HLSLFileCheck/hlsl/objects/WaveMatrix/WaveMatrix67.hlsl b/tools/clang/test/HLSLFileCheck/hlsl/objects/WaveMatrix/WaveMatrix67.hlsl deleted file mode 100644 index bcc61c3c1b..0000000000 --- a/tools/clang/test/HLSLFileCheck/hlsl/objects/WaveMatrix/WaveMatrix67.hlsl +++ /dev/null @@ -1,25 +0,0 @@ -// RUN: %dxc -E main -T cs_6_7 -select-validator internal %s | FileCheck %s -// RUN: %dxc -E main -T cs_6_8 -select-validator internal %s | FileCheck %s - -// CHECK-NOT: define void @main() - -// CHECK: error: validation errors -// CHECK: Function: main: error: Opcode WaveMatrix_Annotate not valid in shader model cs_6_{{[7-8]}} -// CHECK: Function: main: error: Opcode WaveMatrix_Annotate not valid in shader model cs_6_{{[7-8]}}. - -RWByteAddressBuffer rwbuf; - -[NumThreads(64,1,1)] -void main(uint3 gtid : SV_GroupThreadID, uint gidx : SV_GroupIndex) -{ - WaveMatrixLeft left; - WaveMatrixRight right; - -// CHECK: WaveMatrix67.hlsl:21:18: error: Opcode WaveMatrix_Depth not valid in shader model cs_6_{{[7-8]}}. -// CHECK: WaveMatrix67.hlsl:22:18: error: Opcode WaveMatrix_Depth not valid in shader model cs_6_{{[7-8]}}. - - rwbuf.Store(0, left.MatrixDepth()); - rwbuf.Store(4, right.MatrixDepth()); -} - -// CHECK: Validation failed. diff --git a/tools/clang/test/HLSLFileCheck/hlsl/objects/WaveMatrix/WaveMatrix_Add-limited.hlsl b/tools/clang/test/HLSLFileCheck/hlsl/objects/WaveMatrix/WaveMatrix_Add-limited.hlsl deleted file mode 100644 index 09f4f28bea..0000000000 --- a/tools/clang/test/HLSLFileCheck/hlsl/objects/WaveMatrix/WaveMatrix_Add-limited.hlsl +++ /dev/null @@ -1,56 +0,0 @@ -// RUN: %dxc -enable-16bit-types -T cs_6_9 -DADD_TY=WMLC %s | FileCheck %s -DADD_TY=2 -DCOMP=9 -DDIMM=16 -DDIMN=16 -check-prefix=CHKIR -// RUN: %dxc -enable-16bit-types -T cs_6_9 -DADD_TY=WMRR %s | FileCheck %s -DADD_TY=3 -DCOMP=9 -DDIMM=16 -DDIMN=16 -check-prefix=CHKIR -// RUN: %dxc -enable-16bit-types -T cs_6_9 -DADD_TY=WMA %s | FileCheck %s -DADD_TY=4 -DCOMP=9 -DDIMM=16 -DDIMN=16 -check-prefix=CHKIR -// RUN: %dxc -enable-16bit-types -T cs_6_9 -ast-dump -DADD_TY=WMLC %s | FileCheck %s -DADD_TY=WaveMatrixLeftColAcc -DCOMP=float -DDIMM=16 -DDIMN=16 -check-prefix=CHKAST -// RUN: %dxc -enable-16bit-types -T cs_6_9 -ast-dump -DADD_TY=WMRR %s | FileCheck %s -DADD_TY=WaveMatrixRightRowAcc -DCOMP=float -DDIMM=16 -DDIMN=16 -check-prefix=CHKAST -// RUN: %dxc -enable-16bit-types -T cs_6_9 -ast-dump -DADD_TY=WMA %s | FileCheck %s -DADD_TY=WaveMatrixAccumulator -DCOMP=float -DDIMM=16 -DDIMN=16 -check-prefix=CHKAST - -// CHECK: ; Note: shader requires additional functionality: -// CHECK: ; Wave level operations -// CHECK: ; Wave Matrix - -// CHECK: define void @main() - -#ifndef COMP -#define COMP float -#endif -#ifndef DIMM -#define DIMM 16 -#endif -#ifndef DIMN -#define DIMN 16 -#endif -#ifndef WAVESIZE -#define WAVESIZE -#endif -#ifndef NUMTHREADS -#define NUMTHREADS [NumThreads(64,1,1)] -#endif - -#define WMLC WaveMatrixLeftColAcc -#define WMRR WaveMatrixRightRowAcc -#define WMA WaveMatrixAccumulator - -WAVESIZE -NUMTHREADS -void main(uint3 gtid : SV_GroupThreadID, uint gidx : SV_GroupIndex) -{ -// CHKIR: %[[wma:.*]] = alloca %dx.types.waveMatrix, align 4 -// CHKIR: %[[wma_add:.*]] = alloca %dx.types.waveMatrix, align 4 -// CHKIR: call void @dx.op.waveMatrix_Annotate(i32 226, %dx.types.waveMatrix* nonnull %[[wma]], %dx.types.waveMatProps { i8 4, i8 [[COMP]], i32 [[DIMM]], i32 [[DIMN]] }) -// CHKIR: call void @dx.op.waveMatrix_Annotate(i32 226, %dx.types.waveMatrix* nonnull %[[wma_add]], %dx.types.waveMatProps { i8 [[ADD_TY]], i8 [[COMP]], i32 [[DIMM]], i32 [[DIMN]] }) - WMA acc; - ADD_TY wma_add; - -// CHKAST: CXXMemberCallExpr -// CHKAST-NEXT: MemberExpr -// CHKAST-SAME: .Add -// CHKAST-NEXT: DeclRefExpr -// CHKAST-SAME: 'acc' 'WaveMatrixAccumulator<[[COMP]], [[DIMM]], [[DIMN]]>' -// CHKAST-NEXT: -// CHKAST-NEXT: DeclRefExpr -// CHKAST-SAME: 'wma_add' '[[ADD_TY]]<[[COMP]], [[DIMM]], [[DIMN]]>' - -// CHKIR: call void @dx.op.waveMatrix_Accumulate(i32 237, %dx.types.waveMatrix* nonnull %[[wma]], %dx.types.waveMatrix* nonnull %[[wma_add]]) - acc.Add(wma_add); -} diff --git a/tools/clang/test/HLSLFileCheck/hlsl/objects/WaveMatrix/WaveMatrix_Depth.hlsl b/tools/clang/test/HLSLFileCheck/hlsl/objects/WaveMatrix/WaveMatrix_Depth.hlsl deleted file mode 100644 index 6411b2a98e..0000000000 --- a/tools/clang/test/HLSLFileCheck/hlsl/objects/WaveMatrix/WaveMatrix_Depth.hlsl +++ /dev/null @@ -1,51 +0,0 @@ -// RUN: %dxc -E main -T cs_6_9 -DCOMP=float -DDIMM=16 -DDIMN=16 %s | FileCheck %s -DCOMP=9 -DDIMM=16 -DDIMN=16 -// RUN: %dxc -E main -T cs_6_9 -DCOMP=half -DDIMM=16 -DDIMN=16 %s | FileCheck %s -DCOMP=9 -DDIMM=16 -DDIMN=16 -// RUN: %dxc -E main -T cs_6_9 -DCOMP=int8_t4_packed -DDIMM=16 -DDIMN=16 %s | FileCheck %s -DCOMP=17 -DDIMM=16 -DDIMN=16 -// RUN: %dxc -E main -T cs_6_9 -DCOMP=uint8_t4_packed -DDIMM=16 -DDIMN=16 %s | FileCheck %s -DCOMP=18 -DDIMM=16 -DDIMN=16 -// RUN: %dxc -E main -T cs_6_9 -DCOMP=float -DDIMM=64 -DDIMN=16 %s | FileCheck %s -DCOMP=9 -DDIMM=64 -DDIMN=16 -// RUN: %dxc -E main -T cs_6_9 -DCOMP=float -DDIMM=16 -DDIMN=64 %s | FileCheck %s -DCOMP=9 -DDIMM=16 -DDIMN=64 -// RUN: %dxc -E main -T cs_6_9 -DCOMP=float -DDIMM=64 -DDIMN=64 %s | FileCheck %s -DCOMP=9 -DDIMM=64 -DDIMN=64 - -// CHECK: ; Note: shader requires additional functionality: -// CHECK: ; Wave level operations -// CHECK: ; Wave Matrix - -// CHECK: define void @main() - -#ifndef COMP -#define COMP float -#endif -#ifndef DIMM -#define DIMM 16 -#endif -#ifndef DIMN -#define DIMN 16 -#endif -#ifndef WAVESIZE -#define WAVESIZE -#endif -#ifndef NUMTHREADS -#define NUMTHREADS [NumThreads(64,1,1)] -#endif - -#define WML WaveMatrixLeft -#define WMR WaveMatrixRight - -RWByteAddressBuffer rwbuf; - -WAVESIZE -NUMTHREADS -void main(uint3 gtid : SV_GroupThreadID, uint gidx : SV_GroupIndex) -{ -// CHECK: %[[wml:.*]] = alloca %dx.types.waveMatrix, align 4 -// CHECK: %[[wmr:.*]] = alloca %dx.types.waveMatrix, align 4 -// CHECK: call void @dx.op.waveMatrix_Annotate(i32 226, %dx.types.waveMatrix* nonnull %[[wml]], %dx.types.waveMatProps { i8 0, i8 [[COMP]], i32 [[DIMM]], i32 [[DIMN]] }) -// CHECK: call void @dx.op.waveMatrix_Annotate(i32 226, %dx.types.waveMatrix* nonnull %[[wmr]], %dx.types.waveMatProps { i8 1, i8 [[COMP]], i32 [[DIMM]], i32 [[DIMN]] }) - WML left; - WMR right; - -// CHECK: call i32 @dx.op.waveMatrix_Depth(i32 227, %dx.types.waveMatProps { i8 0, i8 [[COMP]], i32 [[DIMM]], i32 [[DIMN]] }) - rwbuf.Store(0, left.MatrixDepth()); -// CHECK: call i32 @dx.op.waveMatrix_Depth(i32 227, %dx.types.waveMatProps { i8 1, i8 [[COMP]], i32 [[DIMM]], i32 [[DIMN]] }) - rwbuf.Store(4, right.MatrixDepth()); -} diff --git a/tools/clang/test/HLSLFileCheck/hlsl/objects/WaveMatrix/WaveMatrix_Fill-acc.hlsl b/tools/clang/test/HLSLFileCheck/hlsl/objects/WaveMatrix/WaveMatrix_Fill-acc.hlsl deleted file mode 100644 index 08680257c8..0000000000 --- a/tools/clang/test/HLSLFileCheck/hlsl/objects/WaveMatrix/WaveMatrix_Fill-acc.hlsl +++ /dev/null @@ -1,54 +0,0 @@ -// RUN: %dxc -enable-16bit-types -T cs_6_9 -DCOMP=float -DDIMM=16 -DDIMN=16 %s | FileCheck %s -DCOMP=9 -DDIMM=16 -DDIMN=16 -DOLOAD=f32 -// RUN: %dxc -enable-16bit-types -T cs_6_9 -DCOMP=float16_t -DDIMM=16 -DDIMN=16 %s | FileCheck %s -DCOMP=8 -DDIMM=16 -DDIMN=16 -DOLOAD=f16 -// RUN: %dxc -enable-16bit-types -T cs_6_9 -DCOMP=int -DDIMM=16 -DDIMN=16 %s | FileCheck %s -DCOMP=4 -DDIMM=16 -DDIMN=16 -DOLOAD=i32 -// RUN: %dxc -enable-16bit-types -T cs_6_9 -DCOMP=float -DDIMM=64 -DDIMN=16 %s | FileCheck %s -DCOMP=9 -DDIMM=64 -DDIMN=16 -DOLOAD=f32 -// RUN: %dxc -enable-16bit-types -T cs_6_9 -DCOMP=float -DDIMM=16 -DDIMN=64 %s | FileCheck %s -DCOMP=9 -DDIMM=16 -DDIMN=64 -DOLOAD=f32 -// RUN: %dxc -enable-16bit-types -T cs_6_9 -DCOMP=float -DDIMM=64 -DDIMN=64 %s | FileCheck %s -DCOMP=9 -DDIMM=64 -DDIMN=64 -DOLOAD=f32 - -// CHECK: ; Note: shader requires additional functionality: -// CHECK: ; Wave level operations -// CHECK: ; Wave Matrix - -// CHECK: define void @main() - -#ifndef COMP -#define COMP float -#endif -#ifndef DIMM -#define DIMM 16 -#endif -#ifndef DIMN -#define DIMN 16 -#endif -#ifndef WAVESIZE -#define WAVESIZE -#endif -#ifndef NUMTHREADS -#define NUMTHREADS [NumThreads(64,1,1)] -#endif - -#define WMLC WaveMatrixLeftColAcc -#define WMRR WaveMatrixRightRowAcc -#define WMA WaveMatrixAccumulator - -WAVESIZE -NUMTHREADS -void main(uint3 gtid : SV_GroupThreadID, uint gidx : SV_GroupIndex) -{ -// CHECK: %[[wmlc:.*]] = alloca %dx.types.waveMatrix, align 4 -// CHECK: %[[wmrr:.*]] = alloca %dx.types.waveMatrix, align 4 -// CHECK: %[[wma:.*]] = alloca %dx.types.waveMatrix, align 4 -// CHECK: call void @dx.op.waveMatrix_Annotate(i32 226, %dx.types.waveMatrix* nonnull %[[wmlc]], %dx.types.waveMatProps { i8 2, i8 [[COMP]], i32 [[DIMM]], i32 [[DIMN]] }) -// CHECK: call void @dx.op.waveMatrix_Annotate(i32 226, %dx.types.waveMatrix* nonnull %[[wmrr]], %dx.types.waveMatProps { i8 3, i8 [[COMP]], i32 [[DIMM]], i32 [[DIMN]] }) -// CHECK: call void @dx.op.waveMatrix_Annotate(i32 226, %dx.types.waveMatrix* nonnull %[[wma]], %dx.types.waveMatProps { i8 4, i8 [[COMP]], i32 [[DIMM]], i32 [[DIMN]] }) - WMLC leftcol; - WMRR rightrow; - WMA acc; - -// CHECK: call void @dx.op.waveMatrix_Fill.[[OLOAD]](i32 228, %dx.types.waveMatrix* nonnull %[[wmlc]] - leftcol.Fill(1); -// CHECK: call void @dx.op.waveMatrix_Fill.[[OLOAD]](i32 228, %dx.types.waveMatrix* nonnull %[[wmrr]] - rightrow.Fill(2); -// CHECK: call void @dx.op.waveMatrix_Fill.[[OLOAD]](i32 228, %dx.types.waveMatrix* nonnull %[[wma]] - acc.Fill(3); -} diff --git a/tools/clang/test/HLSLFileCheck/hlsl/objects/WaveMatrix/WaveMatrix_Fill-in.hlsl b/tools/clang/test/HLSLFileCheck/hlsl/objects/WaveMatrix/WaveMatrix_Fill-in.hlsl deleted file mode 100644 index 1fb7f00733..0000000000 --- a/tools/clang/test/HLSLFileCheck/hlsl/objects/WaveMatrix/WaveMatrix_Fill-in.hlsl +++ /dev/null @@ -1,49 +0,0 @@ -// RUN: %dxc -enable-16bit-types -T cs_6_9 -DCOMP=float -DDIMM=16 -DDIMN=16 %s | FileCheck %s -DCOMP=9 -DDIMM=16 -DDIMN=16 -DOLOAD=f32 -// RUN: %dxc -enable-16bit-types -T cs_6_9 -DCOMP=float16_t -DDIMM=16 -DDIMN=16 %s | FileCheck %s -DCOMP=8 -DDIMM=16 -DDIMN=16 -DOLOAD=f16 -// RUN: %dxc -enable-16bit-types -T cs_6_9 -DCOMP=int8_t4_packed -DDIMM=16 -DDIMN=16 %s | FileCheck %s -DCOMP=17 -DDIMM=16 -DDIMN=16 -DOLOAD=i32 -// RUN: %dxc -enable-16bit-types -T cs_6_9 -DCOMP=uint8_t4_packed -DDIMM=16 -DDIMN=16 %s | FileCheck %s -DCOMP=18 -DDIMM=16 -DDIMN=16 -DOLOAD=i32 -// RUN: %dxc -enable-16bit-types -T cs_6_9 -DCOMP=float -DDIMM=64 -DDIMN=16 %s | FileCheck %s -DCOMP=9 -DDIMM=64 -DDIMN=16 -DOLOAD=f32 -// RUN: %dxc -enable-16bit-types -T cs_6_9 -DCOMP=float -DDIMM=16 -DDIMN=64 %s | FileCheck %s -DCOMP=9 -DDIMM=16 -DDIMN=64 -DOLOAD=f32 -// RUN: %dxc -enable-16bit-types -T cs_6_9 -DCOMP=float -DDIMM=64 -DDIMN=64 %s | FileCheck %s -DCOMP=9 -DDIMM=64 -DDIMN=64 -DOLOAD=f32 - -// CHECK: ; Note: shader requires additional functionality: -// CHECK: ; Wave level operations -// CHECK: ; Wave Matrix - -// CHECK: define void @main() - -#ifndef COMP -#define COMP float -#endif -#ifndef DIMM -#define DIMM 16 -#endif -#ifndef DIMN -#define DIMN 16 -#endif -#ifndef WAVESIZE -#define WAVESIZE -#endif -#ifndef NUMTHREADS -#define NUMTHREADS [NumThreads(64,1,1)] -#endif - -#define WML WaveMatrixLeft -#define WMR WaveMatrixRight - -WAVESIZE -NUMTHREADS -void main(uint3 gtid : SV_GroupThreadID, uint gidx : SV_GroupIndex) -{ -// CHECK: %[[wml:.*]] = alloca %dx.types.waveMatrix, align 4 -// CHECK: %[[wmr:.*]] = alloca %dx.types.waveMatrix, align 4 -// CHECK: call void @dx.op.waveMatrix_Annotate(i32 226, %dx.types.waveMatrix* nonnull %[[wml]], %dx.types.waveMatProps { i8 0, i8 [[COMP]], i32 [[DIMM]], i32 [[DIMN]] }) -// CHECK: call void @dx.op.waveMatrix_Annotate(i32 226, %dx.types.waveMatrix* nonnull %[[wmr]], %dx.types.waveMatProps { i8 1, i8 [[COMP]], i32 [[DIMM]], i32 [[DIMN]] }) - WML left; - WMR right; - -// CHECK: call void @dx.op.waveMatrix_Fill.[[OLOAD]](i32 228, %dx.types.waveMatrix* nonnull %[[wml]] -// CHECK: call void @dx.op.waveMatrix_Fill.[[OLOAD]](i32 228, %dx.types.waveMatrix* nonnull %[[wmr]] - left.Fill(1); - right.Fill(2); -} diff --git a/tools/clang/test/HLSLFileCheck/hlsl/objects/WaveMatrix/WaveMatrix_LoadStore-acc.hlsl b/tools/clang/test/HLSLFileCheck/hlsl/objects/WaveMatrix/WaveMatrix_LoadStore-acc.hlsl deleted file mode 100644 index bf1fe71f23..0000000000 --- a/tools/clang/test/HLSLFileCheck/hlsl/objects/WaveMatrix/WaveMatrix_LoadStore-acc.hlsl +++ /dev/null @@ -1,112 +0,0 @@ -// RUN: %dxc -enable-16bit-types -T cs_6_9 -DCOMP=float -DDIMM=16 -DDIMN=16 %s | FileCheck %s -DCOMP=9 -DDIMM=16 -DDIMN=16 -DOLOAD=f32 -// RUN: %dxc -enable-16bit-types -T cs_6_9 -DCOMP=float16_t -DDIMM=16 -DDIMN=16 %s | FileCheck %s -DCOMP=8 -DDIMM=16 -DDIMN=16 -DOLOAD=f16 -// RUN: %dxc -enable-16bit-types -T cs_6_9 -DCOMP=int -DDIMM=16 -DDIMN=16 %s | FileCheck %s -DCOMP=4 -DDIMM=16 -DDIMN=16 -DOLOAD=i32 -// RUN: %dxc -enable-16bit-types -T cs_6_9 -DCOMP=float -DDIMM=64 -DDIMN=16 %s | FileCheck %s -DCOMP=9 -DDIMM=64 -DDIMN=16 -DOLOAD=f32 -// RUN: %dxc -enable-16bit-types -T cs_6_9 -DCOMP=float -DDIMM=16 -DDIMN=64 %s | FileCheck %s -DCOMP=9 -DDIMM=16 -DDIMN=64 -DOLOAD=f32 -// RUN: %dxc -enable-16bit-types -T cs_6_9 -DCOMP=float -DDIMM=64 -DDIMN=64 %s | FileCheck %s -DCOMP=9 -DDIMM=64 -DDIMN=64 -DOLOAD=f32 - -// CHECK: ; Note: shader requires additional functionality: -// CHECK: ; Wave level operations -// CHECK: ; Wave Matrix - -// CHECK: define void @main() - -#ifndef COMP -#define COMP float -#endif -#ifndef DIMM -#define DIMM 16 -#endif -#ifndef DIMN -#define DIMN 16 -#endif -#ifndef STRIDE -#define STRIDE 64 -#endif -#ifndef WAVESIZE -#define WAVESIZE -#endif -#ifndef NUMTHREADS -#define NUMTHREADS [NumThreads(64,1,1)] -#endif - -#define WMLC WaveMatrixLeftColAcc -#define WMRR WaveMatrixRightRowAcc -#define WMA WaveMatrixAccumulator - -// Should be no addrspacecast from groupshared. -// CHECK-NOT: addrspacecast - -groupshared COMP ai512[512]; - -ByteAddressBuffer buf; -RWByteAddressBuffer rwbuf; - -WAVESIZE -NUMTHREADS -void main(uint3 gtid : SV_GroupThreadID, uint gidx : SV_GroupIndex) -{ -// CHECK: %[[wmlc:.*]] = alloca %dx.types.waveMatrix, align 4 -// CHECK: %[[wmrr:.*]] = alloca %dx.types.waveMatrix, align 4 -// CHECK: %[[wma:.*]] = alloca %dx.types.waveMatrix, align 4 -// CHECK: call void @dx.op.waveMatrix_Annotate(i32 226, %dx.types.waveMatrix* nonnull %[[wmlc]], %dx.types.waveMatProps { i8 2, i8 [[COMP]], i32 [[DIMM]], i32 [[DIMN]] }) -// CHECK: call void @dx.op.waveMatrix_Annotate(i32 226, %dx.types.waveMatrix* nonnull %[[wmrr]], %dx.types.waveMatProps { i8 3, i8 [[COMP]], i32 [[DIMM]], i32 [[DIMN]] }) -// CHECK: call void @dx.op.waveMatrix_Annotate(i32 226, %dx.types.waveMatrix* nonnull %[[wma]], %dx.types.waveMatProps { i8 4, i8 [[COMP]], i32 [[DIMM]], i32 [[DIMN]] }) - WMLC leftcol; - WMRR rightrow; - WMA acc; - - uint n = 0; -#define IDX() (n++*1024) - -// CHECK: call void @dx.op.waveMatrix_LoadRawBuf(i32 229, %dx.types.waveMatrix* nonnull %[[wmlc]], %dx.types.Handle %{{[^,]+}}, i32 0, i32 64, i8 0, i1 false) -// CHECK: call void @dx.op.waveMatrix_LoadRawBuf(i32 229, %dx.types.waveMatrix* nonnull %[[wmlc]], %dx.types.Handle %{{[^,]+}}, i32 1024, i32 64, i8 16, i1 false) -// CHECK: call void @dx.op.waveMatrix_LoadRawBuf(i32 229, %dx.types.waveMatrix* nonnull %[[wmlc]], %dx.types.Handle %{{[^,]+}}, i32 2048, i32 64, i8 0, i1 false) -// CHECK: call void @dx.op.waveMatrix_LoadRawBuf(i32 229, %dx.types.waveMatrix* nonnull %[[wmlc]], %dx.types.Handle %{{[^,]+}}, i32 3072, i32 64, i8 16, i1 false) - leftcol.Load(buf, IDX(), STRIDE); - leftcol.Load(buf, IDX(), STRIDE, 16); - leftcol.Load(rwbuf, IDX(), STRIDE); - leftcol.Load(rwbuf, IDX(), STRIDE, 16); -// CHECK: call void @dx.op.waveMatrix_StoreRawBuf(i32 231, %dx.types.waveMatrix* nonnull %[[wmlc]], %dx.types.Handle %{{[^,]+}}, i32 4096, i32 64, i8 0, i1 false) -// CHECK: call void @dx.op.waveMatrix_StoreRawBuf(i32 231, %dx.types.waveMatrix* nonnull %[[wmlc]], %dx.types.Handle %{{[^,]+}}, i32 5120, i32 64, i8 16, i1 false) - leftcol.Store(rwbuf, IDX(), STRIDE); - leftcol.Store(rwbuf, IDX(), STRIDE, 16); -// CHECK: call void @dx.op.waveMatrix_LoadGroupShared.[[OLOAD]](i32 230, %dx.types.waveMatrix* nonnull %[[wmlc]], {{.+}} addrspace(3)* {{.+}}, i32 0, i32 16, i1 false) -// CHECK: call void @dx.op.waveMatrix_StoreGroupShared.[[OLOAD]](i32 232, %dx.types.waveMatrix* nonnull %[[wmlc]], {{.+}} addrspace(3)* {{.+}}, i32 32, i32 16, i1 false) - leftcol.Load(ai512, 0, 16); - leftcol.Store(ai512, 32, 16); - -// CHECK: call void @dx.op.waveMatrix_LoadRawBuf(i32 229, %dx.types.waveMatrix* nonnull %[[wmrr]], %dx.types.Handle %{{[^,]+}}, i32 6144, i32 64, i8 0, i1 false) -// CHECK: call void @dx.op.waveMatrix_LoadRawBuf(i32 229, %dx.types.waveMatrix* nonnull %[[wmrr]], %dx.types.Handle %{{[^,]+}}, i32 7168, i32 64, i8 16, i1 false) -// CHECK: call void @dx.op.waveMatrix_LoadRawBuf(i32 229, %dx.types.waveMatrix* nonnull %[[wmrr]], %dx.types.Handle %{{[^,]+}}, i32 8192, i32 64, i8 0, i1 false) -// CHECK: call void @dx.op.waveMatrix_LoadRawBuf(i32 229, %dx.types.waveMatrix* nonnull %[[wmrr]], %dx.types.Handle %{{[^,]+}}, i32 9216, i32 64, i8 16, i1 false) - rightrow.Load(buf, IDX(), STRIDE); - rightrow.Load(buf, IDX(), STRIDE, 16); - rightrow.Load(rwbuf, IDX(), STRIDE); - rightrow.Load(rwbuf, IDX(), STRIDE, 16); -// CHECK: call void @dx.op.waveMatrix_StoreRawBuf(i32 231, %dx.types.waveMatrix* nonnull %[[wmrr]], %dx.types.Handle %{{[^,]+}}, i32 10240, i32 64, i8 0, i1 false) -// CHECK: call void @dx.op.waveMatrix_StoreRawBuf(i32 231, %dx.types.waveMatrix* nonnull %[[wmrr]], %dx.types.Handle %{{[^,]+}}, i32 11264, i32 64, i8 16, i1 false) - rightrow.Store(rwbuf, IDX(), STRIDE); - rightrow.Store(rwbuf, IDX(), STRIDE, 16); -// CHECK: call void @dx.op.waveMatrix_LoadGroupShared.[[OLOAD]](i32 230, %dx.types.waveMatrix* nonnull %[[wmrr]], {{.+}} addrspace(3)* {{.+}}, i32 48, i32 16, i1 false) -// CHECK: call void @dx.op.waveMatrix_StoreGroupShared.[[OLOAD]](i32 232, %dx.types.waveMatrix* nonnull %[[wmrr]], {{.+}} addrspace(3)* {{.+}}, i32 64, i32 16, i1 false) - rightrow.Load(ai512, 48, 16); - rightrow.Store(ai512, 64, 16); - -// CHECK: call void @dx.op.waveMatrix_LoadRawBuf(i32 229, %dx.types.waveMatrix* nonnull %[[wma]], %dx.types.Handle %{{[^,]+}}, i32 12288, i32 64, i8 0, i1 false) -// CHECK: call void @dx.op.waveMatrix_LoadRawBuf(i32 229, %dx.types.waveMatrix* nonnull %[[wma]], %dx.types.Handle %{{[^,]+}}, i32 13312, i32 64, i8 16, i1 true) -// CHECK: call void @dx.op.waveMatrix_LoadRawBuf(i32 229, %dx.types.waveMatrix* nonnull %[[wma]], %dx.types.Handle %{{[^,]+}}, i32 14336, i32 64, i8 0, i1 true) -// CHECK: call void @dx.op.waveMatrix_LoadRawBuf(i32 229, %dx.types.waveMatrix* nonnull %[[wma]], %dx.types.Handle %{{[^,]+}}, i32 15360, i32 64, i8 16, i1 false) - acc.Load(buf, IDX(), STRIDE, false); - acc.Load(buf, IDX(), STRIDE, true, 16); - acc.Load(rwbuf, IDX(), STRIDE, true); - acc.Load(rwbuf, IDX(), STRIDE, false, 16); -// CHECK: call void @dx.op.waveMatrix_StoreRawBuf(i32 231, %dx.types.waveMatrix* nonnull %[[wma]], %dx.types.Handle %{{[^,]+}}, i32 16384, i32 64, i8 0, i1 false) -// CHECK: call void @dx.op.waveMatrix_StoreRawBuf(i32 231, %dx.types.waveMatrix* nonnull %[[wma]], %dx.types.Handle %{{[^,]+}}, i32 17408, i32 64, i8 16, i1 true) - acc.Store(rwbuf, IDX(), STRIDE, false); - acc.Store(rwbuf, IDX(), STRIDE, true, 16); -// CHECK: call void @dx.op.waveMatrix_LoadGroupShared.[[OLOAD]](i32 230, %dx.types.waveMatrix* nonnull %[[wma]], {{.+}} addrspace(3)* {{.+}}, i32 80, i32 16, i1 false) -// CHECK: call void @dx.op.waveMatrix_StoreGroupShared.[[OLOAD]](i32 232, %dx.types.waveMatrix* nonnull %[[wma]], {{.+}} addrspace(3)* {{.+}}, i32 96, i32 16, i1 true) - acc.Load(ai512, 80, 16, false); - acc.Store(ai512, 96, 16, true); -} diff --git a/tools/clang/test/HLSLFileCheck/hlsl/objects/WaveMatrix/WaveMatrix_LoadStore-in.hlsl b/tools/clang/test/HLSLFileCheck/hlsl/objects/WaveMatrix/WaveMatrix_LoadStore-in.hlsl deleted file mode 100644 index 6d1968b5e0..0000000000 --- a/tools/clang/test/HLSLFileCheck/hlsl/objects/WaveMatrix/WaveMatrix_LoadStore-in.hlsl +++ /dev/null @@ -1,92 +0,0 @@ -// RUN: %dxc -enable-16bit-types -T cs_6_9 -DCOMP=float -DDIMM=16 -DDIMN=16 %s | FileCheck %s -DCOMP=9 -DDIMM=16 -DDIMN=16 -DOLOAD=f32 -// RUN: %dxc -enable-16bit-types -T cs_6_9 -DCOMP=float16_t -DDIMM=16 -DDIMN=16 %s | FileCheck %s -DCOMP=8 -DDIMM=16 -DDIMN=16 -DOLOAD=f16 -// RUN: %dxc -enable-16bit-types -T cs_6_9 -DCOMP=int8_t4_packed -DDIMM=16 -DDIMN=16 %s | FileCheck %s -DCOMP=17 -DDIMM=16 -DDIMN=16 -DOLOAD=i32 -// RUN: %dxc -enable-16bit-types -T cs_6_9 -DCOMP=uint8_t4_packed -DDIMM=16 -DDIMN=16 %s | FileCheck %s -DCOMP=18 -DDIMM=16 -DDIMN=16 -DOLOAD=i32 -// RUN: %dxc -enable-16bit-types -T cs_6_9 -DCOMP=float -DDIMM=64 -DDIMN=16 %s | FileCheck %s -DCOMP=9 -DDIMM=64 -DDIMN=16 -DOLOAD=f32 -// RUN: %dxc -enable-16bit-types -T cs_6_9 -DCOMP=float -DDIMM=16 -DDIMN=64 %s | FileCheck %s -DCOMP=9 -DDIMM=16 -DDIMN=64 -DOLOAD=f32 -// RUN: %dxc -enable-16bit-types -T cs_6_9 -DCOMP=float -DDIMM=64 -DDIMN=64 %s | FileCheck %s -DCOMP=9 -DDIMM=64 -DDIMN=64 -DOLOAD=f32 - -// CHECK: ; Note: shader requires additional functionality: -// CHECK: ; Wave level operations -// CHECK: ; Wave Matrix - -// CHECK: define void @main() - -#ifndef COMP -#define COMP float -#endif -#ifndef DIMM -#define DIMM 16 -#endif -#ifndef DIMN -#define DIMN 16 -#endif -#ifndef STRIDE -#define STRIDE 64 -#endif -#ifndef WAVESIZE -#define WAVESIZE -#endif -#ifndef NUMTHREADS -#define NUMTHREADS [NumThreads(64,1,1)] -#endif - -#define WML WaveMatrixLeft -#define WMR WaveMatrixRight - -// Should be no addrspacecast from groupshared. -// CHECK-NOT: addrspacecast - -groupshared COMP ai512[512]; - -ByteAddressBuffer buf; -RWByteAddressBuffer rwbuf; - -WAVESIZE -NUMTHREADS -void main(uint3 gtid : SV_GroupThreadID, uint gidx : SV_GroupIndex) -{ -// CHECK: %[[wml:.*]] = alloca %dx.types.waveMatrix, align 4 -// CHECK: %[[wmr:.*]] = alloca %dx.types.waveMatrix, align 4 -// CHECK: call void @dx.op.waveMatrix_Annotate(i32 226, %dx.types.waveMatrix* nonnull %[[wml]], %dx.types.waveMatProps { i8 0, i8 [[COMP]], i32 [[DIMM]], i32 [[DIMN]] }) -// CHECK: call void @dx.op.waveMatrix_Annotate(i32 226, %dx.types.waveMatrix* nonnull %[[wmr]], %dx.types.waveMatProps { i8 1, i8 [[COMP]], i32 [[DIMM]], i32 [[DIMN]] }) - WML left; - WMR right; - - uint n = 0; -#define IDX() (n++*1024) - -// CHECK: call void @dx.op.waveMatrix_LoadRawBuf(i32 229, %dx.types.waveMatrix* nonnull %[[wml]], %dx.types.Handle %{{[^,]+}}, i32 0, i32 64, i8 0, i1 false) -// CHECK: call void @dx.op.waveMatrix_LoadRawBuf(i32 229, %dx.types.waveMatrix* nonnull %[[wml]], %dx.types.Handle %{{[^,]+}}, i32 1024, i32 64, i8 16, i1 true) -// CHECK: call void @dx.op.waveMatrix_LoadRawBuf(i32 229, %dx.types.waveMatrix* nonnull %[[wml]], %dx.types.Handle %{{[^,]+}}, i32 2048, i32 64, i8 0, i1 true) -// CHECK: call void @dx.op.waveMatrix_LoadRawBuf(i32 229, %dx.types.waveMatrix* nonnull %[[wml]], %dx.types.Handle %{{[^,]+}}, i32 3072, i32 64, i8 16, i1 false) - left.Load(buf, IDX(), STRIDE, false); - left.Load(buf, IDX(), STRIDE, true, 16); - left.Load(rwbuf, IDX(), STRIDE, true); - left.Load(rwbuf, IDX(), STRIDE, false, 16); -// CHECK: call void @dx.op.waveMatrix_StoreRawBuf(i32 231, %dx.types.waveMatrix* nonnull %[[wml]], %dx.types.Handle %{{[^,]+}}, i32 4096, i32 64, i8 0, i1 false) -// CHECK: call void @dx.op.waveMatrix_StoreRawBuf(i32 231, %dx.types.waveMatrix* nonnull %[[wml]], %dx.types.Handle %{{[^,]+}}, i32 5120, i32 64, i8 16, i1 true) - left.Store(rwbuf, IDX(), STRIDE, false); - left.Store(rwbuf, IDX(), STRIDE, true, 16); -// CHECK: call void @dx.op.waveMatrix_LoadGroupShared.[[OLOAD]](i32 230, %dx.types.waveMatrix* nonnull %[[wml]], {{.+}} addrspace(3)* {{.+}}, i32 0, i32 16, i1 false) -// CHECK: call void @dx.op.waveMatrix_StoreGroupShared.[[OLOAD]](i32 232, %dx.types.waveMatrix* nonnull %[[wml]], {{.+}} addrspace(3)* {{.+}}, i32 32, i32 16, i1 true) - left.Load(ai512, 0, 16, false); - left.Store(ai512, 32, 16, true); - -// CHECK: call void @dx.op.waveMatrix_LoadRawBuf(i32 229, %dx.types.waveMatrix* nonnull %[[wmr]], %dx.types.Handle %{{[^,]+}}, i32 6144, i32 64, i8 0, i1 false) -// CHECK: call void @dx.op.waveMatrix_LoadRawBuf(i32 229, %dx.types.waveMatrix* nonnull %[[wmr]], %dx.types.Handle %{{[^,]+}}, i32 7168, i32 64, i8 16, i1 true) -// CHECK: call void @dx.op.waveMatrix_LoadRawBuf(i32 229, %dx.types.waveMatrix* nonnull %[[wmr]], %dx.types.Handle %{{[^,]+}}, i32 8192, i32 64, i8 0, i1 true) -// CHECK: call void @dx.op.waveMatrix_LoadRawBuf(i32 229, %dx.types.waveMatrix* nonnull %[[wmr]], %dx.types.Handle %{{[^,]+}}, i32 9216, i32 64, i8 16, i1 false) - right.Load(buf, IDX(), STRIDE, false); - right.Load(buf, IDX(), STRIDE, true, 16); - right.Load(rwbuf, IDX(), STRIDE, true); - right.Load(rwbuf, IDX(), STRIDE, false, 16); -// CHECK: call void @dx.op.waveMatrix_StoreRawBuf(i32 231, %dx.types.waveMatrix* nonnull %[[wmr]], %dx.types.Handle %{{[^,]+}}, i32 10240, i32 64, i8 0, i1 false) -// CHECK: call void @dx.op.waveMatrix_StoreRawBuf(i32 231, %dx.types.waveMatrix* nonnull %[[wmr]], %dx.types.Handle %{{[^,]+}}, i32 11264, i32 64, i8 16, i1 true) - right.Store(rwbuf, IDX(), STRIDE, false); - right.Store(rwbuf, IDX(), STRIDE, true, 16); -// CHECK: call void @dx.op.waveMatrix_LoadGroupShared.[[OLOAD]](i32 230, %dx.types.waveMatrix* nonnull %[[wmr]], {{.+}} addrspace(3)* {{.+}}, i32 48, i32 16, i1 false) -// CHECK: call void @dx.op.waveMatrix_StoreGroupShared.[[OLOAD]](i32 232, %dx.types.waveMatrix* nonnull %[[wmr]], {{.+}} addrspace(3)* {{.+}}, i32 64, i32 16, i1 true) - right.Load(ai512, 48, 16, false); - right.Store(ai512, 64, 16, true); -} diff --git a/tools/clang/test/HLSLFileCheck/hlsl/objects/WaveMatrix/WaveMatrix_Multiply-Add-acc.hlsl b/tools/clang/test/HLSLFileCheck/hlsl/objects/WaveMatrix/WaveMatrix_Multiply-Add-acc.hlsl deleted file mode 100644 index 3fbdccac4d..0000000000 --- a/tools/clang/test/HLSLFileCheck/hlsl/objects/WaveMatrix/WaveMatrix_Multiply-Add-acc.hlsl +++ /dev/null @@ -1,81 +0,0 @@ -// RUN: %dxc -enable-16bit-types -T cs_6_9 -DCOMP=float -DCOMP_IN=float16_t -DDIMM=16 -DDIMN=16 %s | FileCheck %s -DCOMP=9 -DCOMP_IN=8 -DDIMM=16 -DDIMN=16 -DOLOAD=f32 -// RUN: %dxc -enable-16bit-types -T cs_6_9 -DCOMP=float -DCOMP_IN=float -DDIMM=16 -DDIMN=16 %s | FileCheck %s -DCOMP=9 -DCOMP_IN=9 -DDIMM=16 -DDIMN=16 -DOLOAD=f32 -// RUN: %dxc -enable-16bit-types -T cs_6_9 -DCOMP=float16_t -DCOMP_IN=float16_t -DDIMM=16 -DDIMN=16 %s | FileCheck %s -DCOMP=8 -DCOMP_IN=8 -DDIMM=16 -DDIMN=16 -DOLOAD=f16 -// RUN: %dxc -enable-16bit-types -T cs_6_9 -DCOMP=int -DCOMP_IN=int8_t4_packed -DDIMM=16 -DDIMN=16 %s | FileCheck %s -DCOMP=4 -DCOMP_IN=17 -DDIMM=16 -DDIMN=16 -DOLOAD=i32 -// RUN: %dxc -enable-16bit-types -T cs_6_9 -DCOMP=int -DCOMP_IN=uint8_t4_packed -DDIMM=16 -DDIMN=16 %s | FileCheck %s -DCOMP=4 -DCOMP_IN=18 -DDIMM=16 -DDIMN=16 -DOLOAD=i32 -// RUN: %dxc -enable-16bit-types -T cs_6_9 -DCOMP=float -DCOMP_IN=float -DDIMM=64 -DDIMN=16 %s | FileCheck %s -DCOMP=9 -DCOMP_IN=9 -DDIMM=64 -DDIMN=16 -DOLOAD=f32 -// RUN: %dxc -enable-16bit-types -T cs_6_9 -DCOMP=float -DCOMP_IN=float -DDIMM=16 -DDIMN=64 %s | FileCheck %s -DCOMP=9 -DCOMP_IN=9 -DDIMM=16 -DDIMN=64 -DOLOAD=f32 -// RUN: %dxc -enable-16bit-types -T cs_6_9 -DCOMP=float -DCOMP_IN=float -DDIMM=64 -DDIMN=64 %s | FileCheck %s -DCOMP=9 -DCOMP_IN=9 -DDIMM=64 -DDIMN=64 -DOLOAD=f32 - -// CHECK: ; Note: shader requires additional functionality: -// CHECK: ; Wave level operations -// CHECK: ; Wave Matrix - -// CHECK: define void @main() - -#ifndef COMP -#define COMP float -#endif -#ifndef COMP_IN -#define COMP float -#endif -#ifndef DIMM -#define DIMM 16 -#endif -#ifndef DIMN -#define DIMN 16 -#endif -#ifndef STRIDE -#define STRIDE 64 -#endif -#ifndef WAVESIZE -#define WAVESIZE -#endif -#ifndef NUMTHREADS -#define NUMTHREADS [NumThreads(64,1,1)] -#endif - -#define WML WaveMatrixLeft -#define WMR WaveMatrixRight -#define WMLC WaveMatrixLeftColAcc -#define WMRR WaveMatrixRightRowAcc -#define WMA WaveMatrixAccumulator - -WAVESIZE -NUMTHREADS -void main(uint3 gtid : SV_GroupThreadID, uint gidx : SV_GroupIndex) -{ -// CHECK: %[[wml:.*]] = alloca %dx.types.waveMatrix, align 4 -// CHECK: %[[wmr:.*]] = alloca %dx.types.waveMatrix, align 4 -// CHECK: %[[wmlc:.*]] = alloca %dx.types.waveMatrix, align 4 -// CHECK: %[[wmrr:.*]] = alloca %dx.types.waveMatrix, align 4 -// CHECK: %[[wma:.*]] = alloca %dx.types.waveMatrix, align 4 -// CHECK: %[[wma2:.*]] = alloca %dx.types.waveMatrix, align 4 -// CHECK: call void @dx.op.waveMatrix_Annotate(i32 226, %dx.types.waveMatrix* nonnull %[[wml]], %dx.types.waveMatProps { i8 0, i8 [[COMP_IN]], i32 [[DIMM]], i32 [[DIMN]] }) -// CHECK: call void @dx.op.waveMatrix_Annotate(i32 226, %dx.types.waveMatrix* nonnull %[[wmr]], %dx.types.waveMatProps { i8 1, i8 [[COMP_IN]], i32 [[DIMM]], i32 [[DIMN]] }) -// CHECK: call void @dx.op.waveMatrix_Annotate(i32 226, %dx.types.waveMatrix* nonnull %[[wmlc]], %dx.types.waveMatProps { i8 2, i8 [[COMP]], i32 [[DIMM]], i32 [[DIMN]] }) -// CHECK: call void @dx.op.waveMatrix_Annotate(i32 226, %dx.types.waveMatrix* nonnull %[[wmrr]], %dx.types.waveMatProps { i8 3, i8 [[COMP]], i32 [[DIMM]], i32 [[DIMN]] }) -// CHECK: call void @dx.op.waveMatrix_Annotate(i32 226, %dx.types.waveMatrix* nonnull %[[wma]], %dx.types.waveMatProps { i8 4, i8 [[COMP]], i32 [[DIMM]], i32 [[DIMN]] }) -// CHECK: call void @dx.op.waveMatrix_Annotate(i32 226, %dx.types.waveMatrix* nonnull %[[wma2]], %dx.types.waveMatProps { i8 4, i8 [[COMP]], i32 [[DIMM]], i32 [[DIMN]] }) - WML left; - WMR right; - WMLC leftcol; - WMRR rightrow; - WMA acc; - WMA acc2; - -// CHECK: call void @dx.op.waveMatrix_Multiply(i32 233, %dx.types.waveMatrix* nonnull %[[wma]], %dx.types.waveMatrix* nonnull %[[wml]], %dx.types.waveMatrix* nonnull %[[wmr]]) -// CHECK: call void @dx.op.waveMatrix_Multiply(i32 234, %dx.types.waveMatrix* nonnull %[[wma]], %dx.types.waveMatrix* nonnull %[[wml]], %dx.types.waveMatrix* nonnull %[[wmr]]) - acc.Multiply(left, right); - acc.MultiplyAccumulate(left, right); - -// CHECK: call void @dx.op.waveMatrix_Accumulate(i32 237, %dx.types.waveMatrix* nonnull %[[wma]], %dx.types.waveMatrix* nonnull %[[wmlc]]) -// CHECK: call void @dx.op.waveMatrix_Accumulate(i32 237, %dx.types.waveMatrix* nonnull %[[wma]], %dx.types.waveMatrix* nonnull %[[wmrr]]) - acc.Add(leftcol); - acc.Add(rightrow); - -// CHECK: call void @dx.op.waveMatrix_Accumulate(i32 237, %dx.types.waveMatrix* nonnull %[[wma]], %dx.types.waveMatrix* nonnull %[[wma2]]) -// CHECK: call void @dx.op.waveMatrix_Accumulate(i32 237, %dx.types.waveMatrix* nonnull %[[wma]], %dx.types.waveMatrix* nonnull %[[wma]]) - acc.Add(acc2); - acc.Add(acc); -} diff --git a/tools/clang/test/HLSLFileCheck/hlsl/objects/WaveMatrix/WaveMatrix_ScalarOps-acc.hlsl b/tools/clang/test/HLSLFileCheck/hlsl/objects/WaveMatrix/WaveMatrix_ScalarOps-acc.hlsl deleted file mode 100644 index 253df66027..0000000000 --- a/tools/clang/test/HLSLFileCheck/hlsl/objects/WaveMatrix/WaveMatrix_ScalarOps-acc.hlsl +++ /dev/null @@ -1,83 +0,0 @@ -// RUN: %dxc -enable-16bit-types -T cs_6_9 -DCOMP=float -DDIMM=16 -DDIMN=16 %s | FileCheck %s -DCOMP=9 -DDIMM=16 -DDIMN=16 -DOLOAD=f32 -// RUN: %dxc -enable-16bit-types -T cs_6_9 -DCOMP=float16_t -DDIMM=16 -DDIMN=16 %s | FileCheck %s -DCOMP=8 -DDIMM=16 -DDIMN=16 -DOLOAD=f16 -// RUN: %dxc -enable-16bit-types -T cs_6_9 -DCOMP=int -DDIMM=16 -DDIMN=16 %s | FileCheck %s -DCOMP=4 -DDIMM=16 -DDIMN=16 -DOLOAD=i32 -// RUN: %dxc -enable-16bit-types -T cs_6_9 -DCOMP=float -DDIMM=64 -DDIMN=16 %s | FileCheck %s -DCOMP=9 -DDIMM=64 -DDIMN=16 -DOLOAD=f32 -// RUN: %dxc -enable-16bit-types -T cs_6_9 -DCOMP=float -DDIMM=16 -DDIMN=64 %s | FileCheck %s -DCOMP=9 -DDIMM=16 -DDIMN=64 -DOLOAD=f32 -// RUN: %dxc -enable-16bit-types -T cs_6_9 -DCOMP=float -DDIMM=64 -DDIMN=64 %s | FileCheck %s -DCOMP=9 -DDIMM=64 -DDIMN=64 -DOLOAD=f32 - -// CHECK: ; Note: shader requires additional functionality: -// CHECK: ; Wave level operations -// CHECK: ; Wave Matrix - -// CHECK: define void @main() - -#ifndef COMP -#define COMP float -#endif -#ifndef DIMM -#define DIMM 16 -#endif -#ifndef DIMN -#define DIMN 16 -#endif -#ifndef STRIDE -#define STRIDE 64 -#endif -#ifndef WAVESIZE -#define WAVESIZE -#endif -#ifndef NUMTHREADS -#define NUMTHREADS [NumThreads(64,1,1)] -#endif - -#define WMLC WaveMatrixLeftColAcc -#define WMRR WaveMatrixRightRowAcc -#define WMA WaveMatrixAccumulator - -ByteAddressBuffer buf; -RWByteAddressBuffer rwbuf; - -#define MAKE_TEST_SCALAR(typ) \ - void testScalar(typ mat) { \ - mat.ScalarMultiply(4); \ - mat.ScalarDivide(4); \ - mat.ScalarAdd(4); \ - mat.ScalarSubtract(4); \ - } - -MAKE_TEST_SCALAR(WMLC) -MAKE_TEST_SCALAR(WMRR) -MAKE_TEST_SCALAR(WMA) - -WAVESIZE -NUMTHREADS -void main(uint3 gtid : SV_GroupThreadID, uint gidx : SV_GroupIndex) -{ -// CHECK: %[[wmlc:.*]] = alloca %dx.types.waveMatrix, align 4 -// CHECK: %[[wmrr:.*]] = alloca %dx.types.waveMatrix, align 4 -// CHECK: %[[wma:.*]] = alloca %dx.types.waveMatrix, align 4 -// CHECK: call void @dx.op.waveMatrix_Annotate(i32 226, %dx.types.waveMatrix* nonnull %[[wmlc]], %dx.types.waveMatProps { i8 2, i8 [[COMP]], i32 [[DIMM]], i32 [[DIMN]] }) -// CHECK: call void @dx.op.waveMatrix_Annotate(i32 226, %dx.types.waveMatrix* nonnull %[[wmrr]], %dx.types.waveMatProps { i8 3, i8 [[COMP]], i32 [[DIMM]], i32 [[DIMN]] }) -// CHECK: call void @dx.op.waveMatrix_Annotate(i32 226, %dx.types.waveMatrix* nonnull %[[wma]], %dx.types.waveMatProps { i8 4, i8 [[COMP]], i32 [[DIMM]], i32 [[DIMN]] }) - WMLC leftcol; - WMRR rightrow; - WMA acc; - -// CHECK: call void @dx.op.waveMatrix_ScalarOp.[[OLOAD]](i32 235, %dx.types.waveMatrix* nonnull %[[wmlc]], i8 2, -// CHECK: call void @dx.op.waveMatrix_ScalarOp.[[OLOAD]](i32 235, %dx.types.waveMatrix* nonnull %[[wmlc]], i8 3, -// CHECK: call void @dx.op.waveMatrix_ScalarOp.[[OLOAD]](i32 235, %dx.types.waveMatrix* nonnull %[[wmlc]], i8 0, -// CHECK: call void @dx.op.waveMatrix_ScalarOp.[[OLOAD]](i32 235, %dx.types.waveMatrix* nonnull %[[wmlc]], i8 1, - testScalar(leftcol); - -// CHECK: call void @dx.op.waveMatrix_ScalarOp.[[OLOAD]](i32 235, %dx.types.waveMatrix* nonnull %[[wmrr]], i8 2, -// CHECK: call void @dx.op.waveMatrix_ScalarOp.[[OLOAD]](i32 235, %dx.types.waveMatrix* nonnull %[[wmrr]], i8 3, -// CHECK: call void @dx.op.waveMatrix_ScalarOp.[[OLOAD]](i32 235, %dx.types.waveMatrix* nonnull %[[wmrr]], i8 0, -// CHECK: call void @dx.op.waveMatrix_ScalarOp.[[OLOAD]](i32 235, %dx.types.waveMatrix* nonnull %[[wmrr]], i8 1, - testScalar(rightrow); - -// CHECK: call void @dx.op.waveMatrix_ScalarOp.[[OLOAD]](i32 235, %dx.types.waveMatrix* nonnull %[[wma]], i8 2, -// CHECK: call void @dx.op.waveMatrix_ScalarOp.[[OLOAD]](i32 235, %dx.types.waveMatrix* nonnull %[[wma]], i8 3, -// CHECK: call void @dx.op.waveMatrix_ScalarOp.[[OLOAD]](i32 235, %dx.types.waveMatrix* nonnull %[[wma]], i8 0, -// CHECK: call void @dx.op.waveMatrix_ScalarOp.[[OLOAD]](i32 235, %dx.types.waveMatrix* nonnull %[[wma]], i8 1, - testScalar(acc); -} diff --git a/tools/clang/test/HLSLFileCheck/hlsl/objects/WaveMatrix/WaveMatrix_SumAccumulate-acc.hlsl b/tools/clang/test/HLSLFileCheck/hlsl/objects/WaveMatrix/WaveMatrix_SumAccumulate-acc.hlsl deleted file mode 100644 index c606e04f78..0000000000 --- a/tools/clang/test/HLSLFileCheck/hlsl/objects/WaveMatrix/WaveMatrix_SumAccumulate-acc.hlsl +++ /dev/null @@ -1,64 +0,0 @@ -// RUN: %dxc -enable-16bit-types -T cs_6_9 -DCOMP=float -DCOMP_IN=float -DDIMM=16 -DDIMN=16 %s | FileCheck %s -DCOMP=9 -DCOMP_IN=9 -DDIMM=16 -DDIMN=16 -DOLOAD=f32 -// RUN: %dxc -enable-16bit-types -T cs_6_9 -DCOMP=float -DCOMP_IN=float16_t -DDIMM=16 -DDIMN=16 %s | FileCheck %s -DCOMP=9 -DCOMP_IN=8 -DDIMM=16 -DDIMN=16 -DOLOAD=f32 -// RUN: %dxc -enable-16bit-types -T cs_6_9 -DCOMP=float16_t -DCOMP_IN=float16_t -DDIMM=16 -DDIMN=16 %s | FileCheck %s -DCOMP=8 -DCOMP_IN=8 -DDIMM=16 -DDIMN=16 -DOLOAD=f16 -// RUN: %dxc -enable-16bit-types -T cs_6_9 -DCOMP=int -DCOMP_IN=int8_t4_packed -DDIMM=16 -DDIMN=16 %s | FileCheck %s -DCOMP=4 -DCOMP_IN=17 -DDIMM=16 -DDIMN=16 -DOLOAD=i32 -// RUN: %dxc -enable-16bit-types -T cs_6_9 -DCOMP=int -DCOMP_IN=uint8_t4_packed -DDIMM=16 -DDIMN=16 %s | FileCheck %s -DCOMP=4 -DCOMP_IN=18 -DDIMM=16 -DDIMN=16 -DOLOAD=i32 -// RUN: %dxc -enable-16bit-types -T cs_6_9 -DCOMP=float -DCOMP_IN=float -DDIMM=64 -DDIMN=16 %s | FileCheck %s -DCOMP=9 -DCOMP_IN=9 -DDIMM=64 -DDIMN=16 -DOLOAD=f32 -// RUN: %dxc -enable-16bit-types -T cs_6_9 -DCOMP=float -DCOMP_IN=float -DDIMM=16 -DDIMN=64 %s | FileCheck %s -DCOMP=9 -DCOMP_IN=9 -DDIMM=16 -DDIMN=64 -DOLOAD=f32 -// RUN: %dxc -enable-16bit-types -T cs_6_9 -DCOMP=float -DCOMP_IN=float -DDIMM=64 -DDIMN=64 %s | FileCheck %s -DCOMP=9 -DCOMP_IN=9 -DDIMM=64 -DDIMN=64 -DOLOAD=f32 - -// CHECK: ; Note: shader requires additional functionality: -// CHECK: ; Wave level operations -// CHECK: ; Wave Matrix - -// CHECK: define void @main() - -#ifndef COMP -#define COMP float -#endif -#ifndef COMP_IN -#define COMP float -#endif -#ifndef DIMM -#define DIMM 16 -#endif -#ifndef DIMN -#define DIMN 16 -#endif -#ifndef STRIDE -#define STRIDE 64 -#endif -#ifndef WAVESIZE -#define WAVESIZE -#endif -#ifndef NUMTHREADS -#define NUMTHREADS [NumThreads(64,1,1)] -#endif - -#define WML WaveMatrixLeft -#define WMR WaveMatrixRight -#define WMLC WaveMatrixLeftColAcc -#define WMRR WaveMatrixRightRowAcc - -WAVESIZE -NUMTHREADS -void main(uint3 gtid : SV_GroupThreadID, uint gidx : SV_GroupIndex) -{ -// CHECK: %[[wml:.*]] = alloca %dx.types.waveMatrix, align 4 -// CHECK: %[[wmr:.*]] = alloca %dx.types.waveMatrix, align 4 -// CHECK: %[[wmlc:.*]] = alloca %dx.types.waveMatrix, align 4 -// CHECK: %[[wmrr:.*]] = alloca %dx.types.waveMatrix, align 4 -// CHECK: call void @dx.op.waveMatrix_Annotate(i32 226, %dx.types.waveMatrix* nonnull %[[wml]], %dx.types.waveMatProps { i8 0, i8 [[COMP_IN]], i32 [[DIMM]], i32 [[DIMN]] }) -// CHECK: call void @dx.op.waveMatrix_Annotate(i32 226, %dx.types.waveMatrix* nonnull %[[wmr]], %dx.types.waveMatProps { i8 1, i8 [[COMP_IN]], i32 [[DIMM]], i32 [[DIMN]] }) -// CHECK: call void @dx.op.waveMatrix_Annotate(i32 226, %dx.types.waveMatrix* nonnull %[[wmlc]], %dx.types.waveMatProps { i8 2, i8 [[COMP]], i32 [[DIMM]], i32 [[DIMN]] }) -// CHECK: call void @dx.op.waveMatrix_Annotate(i32 226, %dx.types.waveMatrix* nonnull %[[wmrr]], %dx.types.waveMatProps { i8 3, i8 [[COMP]], i32 [[DIMM]], i32 [[DIMN]] }) - WML left; - WMR right; - WMLC leftcol; - WMRR rightrow; - -// CHECK: call void @dx.op.waveMatrix_Accumulate(i32 236, %dx.types.waveMatrix* nonnull %[[wmlc]], %dx.types.waveMatrix* nonnull %[[wml]]) -// CHECK: call void @dx.op.waveMatrix_Accumulate(i32 236, %dx.types.waveMatrix* nonnull %[[wmrr]], %dx.types.waveMatrix* nonnull %[[wmr]]) - leftcol.SumAccumulate(left); - rightrow.SumAccumulate(right); -} diff --git a/tools/clang/test/HLSLFileCheck/hlsl/workgraph/globallycoherent_record.hlsl b/tools/clang/test/HLSLFileCheck/hlsl/workgraph/globallycoherent_record.hlsl index fb95524926..e0f5a29291 100644 --- a/tools/clang/test/HLSLFileCheck/hlsl/workgraph/globallycoherent_record.hlsl +++ b/tools/clang/test/HLSLFileCheck/hlsl/workgraph/globallycoherent_record.hlsl @@ -20,13 +20,13 @@ // TBD: We should be annotating handles used inside called functions // xHLCHECK: define internal void @"\01?DoIt -// xHLCHECK: %[[FuncAnnInputRecHandle:[0-9]+]] = call %dx.types.NodeRecordHandle @"dx.hl.annotatenoderecordhandle..%dx.types.NodeRecordHandle (i32, %dx.types.NodeRecordHandle, %dx.types.NodeRecordInfo)"(i32 17, %dx.types.NodeRecordHandle %[[FuncInputRecHandle]], %dx.types.NodeRecordInfo { i32 613, i32 4 }) +// xHLCHECK: %[[FuncAnnInputRecHandle:[0-9]+]] = call %dx.types.NodeRecordHandle @"dx.hl.annotatenoderecordhandle..%dx.types.NodeRecordHandle (i32, %dx.types.NodeRecordHandle, %dx.types.NodeRecordInfo)"(i32 16, %dx.types.NodeRecordHandle %[[FuncInputRecHandle]], %dx.types.NodeRecordInfo { i32 613, i32 4 }) // HLCHECK: define void @firstNode // HLCHECK: %[[CreateInputHandle:[0-9]+]] = call %dx.types.NodeRecordHandle @"dx.hl.createnodeinputrecordhandle..%dx.types.NodeRecordHandle (i32, i32)"(i32 13, i32 0) // Check that NodeIOFlags is 613 = GloballyCoherent(512) | DispatchRecord(96) | ReadWrite(4) | Input(1) -// HLCHECK: %[[AnnInputRecHandle:[0-9]+]] = call %dx.types.NodeRecordHandle @"dx.hl.annotatenoderecordhandle..%dx.types.NodeRecordHandle (i32, %dx.types.NodeRecordHandle, %dx.types.NodeRecordInfo)"(i32 17, %dx.types.NodeRecordHandle %[[CreateInputHandle]], %dx.types.NodeRecordInfo { i32 613, i32 4 }) +// HLCHECK: %[[AnnInputRecHandle:[0-9]+]] = call %dx.types.NodeRecordHandle @"dx.hl.annotatenoderecordhandle..%dx.types.NodeRecordHandle (i32, %dx.types.NodeRecordHandle, %dx.types.NodeRecordInfo)"(i32 16, %dx.types.NodeRecordHandle %[[CreateInputHandle]], %dx.types.NodeRecordInfo { i32 613, i32 4 }) // ==== D3DReflect Checks ==== diff --git a/tools/clang/test/SemaHLSL/hlsl/objects/WaveMatrix/array_of_wave_matrix.hlsl b/tools/clang/test/SemaHLSL/hlsl/objects/WaveMatrix/array_of_wave_matrix.hlsl deleted file mode 100644 index 3e61715f74..0000000000 --- a/tools/clang/test/SemaHLSL/hlsl/objects/WaveMatrix/array_of_wave_matrix.hlsl +++ /dev/null @@ -1,30 +0,0 @@ -// RUN: %dxc -Tlib_6_8 -verify %s - -void foo() { - -WaveMatrixLeft left[2]; // expected-error {{declaration of type WaveMatrixLeft may not be an array}} -WaveMatrixRight right[2]; // expected-error {{declaration of type WaveMatrixRight may not be an array}} -WaveMatrixLeftColAcc leftCol[2]; // expected-error {{declaration of type WaveMatrixLeftColAcc may not be an array}} -WaveMatrixRightRowAcc rightRow[2]; // expected-error {{declaration of type WaveMatrixRightRowAcc may not be an array}} -WaveMatrixAccumulator acc[2]; // expected-error {{declaration of type WaveMatrixAccumulator may not be an array}} - -} - -void bar( -WaveMatrixLeft left[2], // expected-error {{declaration of type WaveMatrixLeft may not be an array}} -WaveMatrixRight right[2], // expected-error {{declaration of type WaveMatrixRight may not be an array}} -WaveMatrixLeftColAcc leftCol[2], // expected-error {{declaration of type WaveMatrixLeftColAcc may not be an array}} -WaveMatrixRightRowAcc rightRow[2], // expected-error {{declaration of type WaveMatrixRightRowAcc may not be an array}} -WaveMatrixAccumulator acc[2] // expected-error {{declaration of type WaveMatrixAccumulator may not be an array}} -) { - -} - -struct S { -WaveMatrixLeft left[2]; // expected-error {{declaration of type WaveMatrixLeft may not be an array}} -WaveMatrixRight right[2]; // expected-error {{declaration of type WaveMatrixRight may not be an array}} -WaveMatrixLeftColAcc leftCol[2]; // expected-error {{declaration of type WaveMatrixLeftColAcc may not be an array}} -WaveMatrixRightRowAcc rightRow[2]; // expected-error {{declaration of type WaveMatrixRightRowAcc may not be an array}} -WaveMatrixAccumulator acc[2]; // expected-error {{declaration of type WaveMatrixAccumulator may not be an array}} - -}; diff --git a/tools/clang/unittests/HLSLExec/ExecutionTest.cpp b/tools/clang/unittests/HLSLExec/ExecutionTest.cpp index 0479926606..f22e99e467 100644 --- a/tools/clang/unittests/HLSLExec/ExecutionTest.cpp +++ b/tools/clang/unittests/HLSLExec/ExecutionTest.cpp @@ -501,10 +501,6 @@ class ExecutionTest { L"Table:ShaderOpArithTable.xml#PackUnpackOpTable") END_TEST_METHOD() - TEST_METHOD(WaveMatrixLoadStoreTests); - TEST_METHOD(WaveMatrixScalarTests); - TEST_METHOD(WaveMatrixMathTests); - dxc::DxcDllSupport m_support; bool m_D3DInitCompleted = false; @@ -716,64 +712,6 @@ class ExecutionTest { template const wchar_t *BasicShaderModelTest_GetFormatString(); - CComPtr - WaveMatrixTestCommonSetup(std::vector &dimMs, std::vector &dimNs, - std::shared_ptr &shaderOpSet) { - WEX::TestExecution::SetVerifyOutput verifySettings( - WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); - - CComPtr pDevice; - D3D_SHADER_MODEL model = D3D_SHADER_MODEL_6_9; - - if (!CreateDevice(&pDevice, model)) { - return nullptr; - } - - if (!DoesDeviceSupportWaveMatrix(pDevice)) { - LogCommentFmt(L"WaveMatrix not supported on this device."); - WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped); - return nullptr; - } - - CComPtr pStream; - ReadHlslDataIntoNewStream(L"ShaderOpArith.xml", &pStream); - - shaderOpSet = std::make_shared(); - st::ParseShaderOpSetFromStream(pStream, shaderOpSet.get()); - - dimMs = {16, 64}; - dimNs = {16, 64}; - - std::wstring split; - - // Parse DimM - WEX::Common::String dimMList; - WEX::TestExecution::RuntimeParameters::TryGetValue(L"Wmma_DimM", dimMList); - - if (!dimMList.IsEmpty()) { - dimMs.clear(); - wstringstream ss((const wchar_t *)dimMList); - - while (std::getline(ss, split, L',')) { - dimMs.emplace_back(std::stoi(split)); - } - } - - // Parse DimN - WEX::Common::String dimNList; - WEX::TestExecution::RuntimeParameters::TryGetValue(L"Wmma_DimN", dimNList); - if (!dimNList.IsEmpty()) { - dimNs.clear(); - wstringstream ss((const wchar_t *)dimNList); - - while (std::getline(ss, split, L',')) { - dimNs.emplace_back(std::stoi(split)); - } - } - - return pDevice; - } - void CompileFromText(LPCSTR pText, LPCWSTR pEntryPoint, LPCWSTR pTargetProfile, ID3DBlob **ppBlob, LPCWSTR *pOptions = nullptr, int numOptions = 0) { @@ -1607,19 +1545,6 @@ class ExecutionTest { #endif } - bool DoesDeviceSupportWaveMatrix(ID3D12Device *pDevice) { -#if defined(NTDDI_WIN10_FE) && WDK_NTDDI_VERSION >= NTDDI_WIN10_FE - D3D12_FEATURE_DATA_D3D12_OPTIONS9 O9; - if (FAILED(pDevice->CheckFeatureSupport( - (D3D12_FEATURE)D3D12_FEATURE_D3D12_OPTIONS9, &O9, sizeof(O9)))) - return false; - return O9.WaveMMATier >= D3D12_WAVE_MMA_TIER_1_0; -#else - UNREFERENCED_PARAMETER(pDevice); - return false; -#endif - } - bool DoesDeviceSupportAdvancedTexOps(ID3D12Device *pDevice) { #if defined(NTDDI_WIN10_CU) && WDK_NTDDI_VERSION >= NTDDI_WIN10_CU D3D12_FEATURE_DATA_D3D12_OPTIONS14 O14; @@ -6379,16 +6304,6 @@ static TableParameter TertiaryUint16OpParameters[] = { {L"Validation.Tolerance", TableParameter::INT32, true}, }; -static TableParameter WaveMatrixOpParameters[] = { - {L"Validation.Type", TableParameter::STRING, true}, - {L"Validation.Tolerance", TableParameter::DOUBLE, true}, - {L"ShaderOp.Target", TableParameter::STRING, true}, - {L"LoadStoreShaderOp.Text", TableParameter::STRING, true}, - {L"ScalarShaderOp.Text", TableParameter::STRING, true}, - {L"MathShaderOp.Text", TableParameter::STRING, true}, - {L"ScalarValidation.Scalar", TableParameter::STRING_TABLE, true}, -}; - static TableParameter DotOpParameters[] = { {L"ShaderOp.Target", TableParameter::STRING, true}, {L"ShaderOp.Text", TableParameter::STRING, true}, @@ -9042,1459 +8957,6 @@ void LoadStoreMat(int M, int N, bool LEFT, int MEM_TYPE, uint32_t K, uint32_t k, } } -// Define WAVE_MMA types if building with SDK that does not support it yet. -// For now: gate this on D3D12_EXPERIMENTAL_WAVE_MATRIX define until we know -// the version and the define is removed. -// #if !defined(D3D12_SDK_VERSION) || (D3D12_SDK_VERSION < 613) -#if !defined(D3D12_EXPERIMENTAL_WAVE_MATRIX) -typedef enum D3D12_WAVE_MMA_INPUT_DATATYPE { - D3D12_WAVE_MMA_INPUT_DATATYPE_INVALID = 0, - D3D12_WAVE_MMA_INPUT_DATATYPE_BYTE = - (D3D12_WAVE_MMA_INPUT_DATATYPE_INVALID + 1), - D3D12_WAVE_MMA_INPUT_DATATYPE_FLOAT16 = - (D3D12_WAVE_MMA_INPUT_DATATYPE_BYTE + 1), - D3D12_WAVE_MMA_INPUT_DATATYPE_FLOAT = - (D3D12_WAVE_MMA_INPUT_DATATYPE_FLOAT16 + 1) -} D3D12_WAVE_MMA_INPUT_DATATYPE; - -typedef enum D3D12_WAVE_MMA_DIMENSION { - D3D12_WAVE_MMA_DIMENSION_INVALID = 0, - D3D12_WAVE_MMA_DIMENSION_16 = (D3D12_WAVE_MMA_DIMENSION_INVALID + 1), - D3D12_WAVE_MMA_DIMENSION_64 = (D3D12_WAVE_MMA_DIMENSION_16 + 1) -} D3D12_WAVE_MMA_DIMENSION; - -typedef enum D3D12_WAVE_MMA_ACCUM_DATATYPE { - D3D12_WAVE_MMA_ACCUM_DATATYPE_NONE = 0, - D3D12_WAVE_MMA_ACCUM_DATATYPE_INT32 = 0x1, - D3D12_WAVE_MMA_ACCUM_DATATYPE_FLOAT16 = 0x2, - D3D12_WAVE_MMA_ACCUM_DATATYPE_FLOAT = 0x4 -} D3D12_WAVE_MMA_ACCUM_DATATYPE; - -typedef struct D3D12_FEATURE_DATA_WAVE_MMA { - D3D12_WAVE_MMA_INPUT_DATATYPE InputDataType; - D3D12_WAVE_MMA_DIMENSION M; - D3D12_WAVE_MMA_DIMENSION N; - BOOL Supported; - UINT K; - D3D12_WAVE_MMA_ACCUM_DATATYPE AccumDataTypes; - UINT RequiredWaveLaneCountMin; - UINT RequiredWaveLaneCountMax; -} D3D12_FEATURE_DATA_WAVE_MMA; -#endif //! defined(D3D12_EXPERIMENTAL_WAVE_MATRIX) - -D3D12_FEATURE_DATA_WAVE_MMA checkWaveMMASupport(CComPtr pDevice, - std::string &dataTypeInShader, - int DIM_M, int DIM_N) { - D3D12_FEATURE_DATA_WAVE_MMA waveMmaSupport = {}; - - if (dataTypeInShader == "float" || dataTypeInShader == "float32_t") { - waveMmaSupport.InputDataType = D3D12_WAVE_MMA_INPUT_DATATYPE_FLOAT; - waveMmaSupport.AccumDataTypes = D3D12_WAVE_MMA_ACCUM_DATATYPE_FLOAT; - } else if (dataTypeInShader == "half" || dataTypeInShader == "float16_t") { - waveMmaSupport.InputDataType = D3D12_WAVE_MMA_INPUT_DATATYPE_FLOAT16; - waveMmaSupport.AccumDataTypes = D3D12_WAVE_MMA_ACCUM_DATATYPE_FLOAT16; - } else if (dataTypeInShader == "uint8_t4_packed" || - dataTypeInShader == "int8_t4_packed" || - dataTypeInShader == "uint" || dataTypeInShader == "int") { - waveMmaSupport.InputDataType = D3D12_WAVE_MMA_INPUT_DATATYPE_BYTE; - waveMmaSupport.AccumDataTypes = D3D12_WAVE_MMA_ACCUM_DATATYPE_INT32; - } - - switch (DIM_M) { - case 16: - waveMmaSupport.M = D3D12_WAVE_MMA_DIMENSION_16; - break; - case 64: - waveMmaSupport.M = D3D12_WAVE_MMA_DIMENSION_64; - break; - default: - DXASSERT_NOMSG(false); - } - switch (DIM_N) { - case 16: - waveMmaSupport.N = D3D12_WAVE_MMA_DIMENSION_16; - break; - case 64: - waveMmaSupport.N = D3D12_WAVE_MMA_DIMENSION_64; - break; - default: - DXASSERT_NOMSG(false); - } - - // Defaults, to be overwritten by the call - waveMmaSupport.K = 16; - waveMmaSupport.Supported = false; - waveMmaSupport.RequiredWaveLaneCountMin = 0; - waveMmaSupport.RequiredWaveLaneCountMax = 0; - - // In preview, D3D12_FEATURE_WAVE_MMA = 38, - uint32_t D3D12_FEATURE_WAVE_MMA = 38; - pDevice->CheckFeatureSupport((D3D12_FEATURE)D3D12_FEATURE_WAVE_MMA, - &waveMmaSupport, - sizeof(D3D12_FEATURE_DATA_WAVE_MMA)); - - int waveSize; - if (SUCCEEDED(WEX::TestExecution::RuntimeParameters::TryGetValue( - L"Wmma_ForceWaveSize", waveSize))) { - waveMmaSupport.RequiredWaveLaneCountMin = waveSize; - waveMmaSupport.RequiredWaveLaneCountMax = waveSize; - } - - D3D12_FEATURE_DATA_D3D12_OPTIONS1 O; - HRESULT hr = pDevice->CheckFeatureSupport( - (D3D12_FEATURE)D3D12_FEATURE_D3D12_OPTIONS1, &O, sizeof(O)); - VERIFY_SUCCEEDED(hr); - - if (waveMmaSupport.RequiredWaveLaneCountMin == 0) { - waveMmaSupport.RequiredWaveLaneCountMin = O.WaveLaneCountMin; - } - - if (waveMmaSupport.RequiredWaveLaneCountMax == 0) { - waveMmaSupport.RequiredWaveLaneCountMax = O.WaveLaneCountMax; - } - - int forceK; - if (SUCCEEDED(WEX::TestExecution::RuntimeParameters::TryGetValue( - L"Wmma_ForceK", forceK))) { - waveMmaSupport.K = forceK; - } - - return waveMmaSupport; -} - -template std::string TypeIdToHlsl() { - if (typeid(T) == typeid(float)) - return "float32_t"; - else if (typeid(T) == typeid(DirectX::PackedVector::HALF)) - return "float16_t"; - else if (typeid(T) == typeid(uint8_t)) - return "uint8_t4_packed"; - else if (typeid(T) == typeid(int8_t)) - return "int8_t4_packed"; - else if (typeid(T) == typeid(int32_t)) - return "int32_t"; - - DXASSERT_NOMSG(false); - return ""; -} - -template bool ArgContainsDataType(std::wstring argName) { - std::string dTypeName = TypeIdToHlsl(); - std::wstring wdTypeName(dTypeName.begin(), dTypeName.end()); - WEX::Common::String dataTypeList; - WEX::TestExecution::RuntimeParameters::TryGetValue(argName.c_str(), - dataTypeList); - - if (!dataTypeList.IsEmpty()) { - dataTypeList.ToLower(); - return dataTypeList.Find(wdTypeName.c_str()) != -1; - } - - // No datatype arg arg means it contains all datatypes - return true; -} - -template -void WaveMatrixLoadStoreTest(int DIM_M, int DIM_N, int MEM_TYPE, - CComPtr pDevice, - std::shared_ptr ShaderOpSet, - dxc::DxcDllSupport &support, - PCWSTR Validation_type, double tolerance) { - using namespace DirectX::PackedVector; - using namespace WMMA; - std::string dataTypeInShader = TypeIdToHlsl(); - - string typeAcc = TypeIdToHlsl(); - D3D12_FEATURE_DATA_WAVE_MMA waveMmaSupport = - checkWaveMMASupport(pDevice, dataTypeInShader, DIM_M, DIM_N); - - std::string groupName = std::string("WMMALoadStore/") + - memTypeStrs[MEM_TYPE] + "/M" + std::to_string(DIM_M) + - "/N" + std::to_string(DIM_N) + "/" + - TypeIdToHlsl() + "/Accum" + typeAcc; - WEX::Logging::Log::StartGroup(CA2W(groupName.c_str())); - - bool accTypeSupported = - (typeid(TYPE_ACC) == typeid(HALF) && - (waveMmaSupport.AccumDataTypes & - D3D12_WAVE_MMA_ACCUM_DATATYPE_FLOAT16)) || - (typeid(TYPE_ACC) == typeid(float) && - (waveMmaSupport.AccumDataTypes & D3D12_WAVE_MMA_ACCUM_DATATYPE_FLOAT)) || - (typeid(TYPE_ACC) == typeid(int32_t) && - (waveMmaSupport.AccumDataTypes & D3D12_WAVE_MMA_ACCUM_DATATYPE_INT32)); - - // We need to predict type acc as float16/32 with a template. So we try both - // and return early from the incorrect prediction. - if (!ArgContainsDataType(L"Wmma_Type") || - !ArgContainsDataType(L"Wmma_AccumType") || !accTypeSupported || - !waveMmaSupport.Supported) { - - WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped); - WEX::Logging::Log::EndGroup(CA2W(groupName.c_str())); - return; - } - - int disableFragmentTests = 0; - WEX::TestExecution::RuntimeParameters::TryGetValue( - L"Wmma_DisableFragmentTests", disableFragmentTests); - - int DIM_K = waveMmaSupport.K; // Get K dim - - // We store left/right matrices in the same array so we just assume a maximum - // size. This size applies to accumulators as well. - constexpr int NUM_ELEMENTS = 64 * 64; - - // Create zeroed arrays for expected results. make room for 2x the normal size - // in each expected result so we have room to test offsetted loads - std::array, TOTAL_LOAD_STORE_OUTPUTS> - expectedMatrices{{}}; - std::array, TOTAL_LOAD_STORE_OUTPUTS> - expectedRowCols{{}}; - std::array, - TOTAL_ACCUM_LOAD_STORE_OUTPUTS> - expectedAccumulatorMatrices{{}}; - - // Specify defines in the shader for datatype and size, (and wave size later) - std::stringstream argsStream; - argsStream << " -DDIM_M=" << DIM_M << " -DDIM_N=" << DIM_N - << " -DDIM_K=" << DIM_K << " -DDATATYPE=" << dataTypeInShader - << " -DELEMENTSIZE=" << sizeof(T) << " -enable-16bit-types" - << " -DTYPE_ACC=" << typeAcc << " -DFRAGS_ENABLE=" - << static_cast(disableFragmentTests == 0) - << " -DNUM_LANES=" << waveMmaSupport.RequiredWaveLaneCountMin; - - bool doLeftRightTest = true; - bool doAccumTest = true; - std::stringstream argsStream2; - std::string initialArgsString = argsStream.str(); - std::string argsString; - - // this callback is called when the test - // is creating the resource to run the test - auto callback = [&](LPCSTR Name, std::vector &Data, - st::ShaderOp *pShaderOp) { - if (0 == _stricmp(Name, "g_bufIn") && doLeftRightTest) { - std::fill(Data.begin(), Data.end(), (BYTE)0); - GenerateMatrix((T *)Data.data(), NUM_ELEMENTS * 2, 0, - NUM_ELEMENTS * 2); - - BYTE *src = (BYTE *)Data.data(); - uint32_t a = 4; // alignment OR if MEM_TYPE is groupshared, acts as - // additional storage offset in elements. - size_t s = 16 * sizeof(T); // start - uint32_t lStride = DIM_K * sizeof(T); - uint32_t rStride = DIM_N * sizeof(T); - uint32_t ltStride = DIM_M * sizeof(T); - uint32_t rtStride = DIM_K * sizeof(T); - uint32_t lStrideP4 = lStride + 4; - uint32_t rStrideP4 = rStride + 4; - uint32_t ltStrideP4 = ltStride + 4; - uint32_t rtStrideP4 = rtStride + 4; - - // Generate expected values - - // Load - LoadStoreMat(DIM_M, DIM_N, true, MEM_TYPE, DIM_K, DIM_K, s, lStride, 0, - false, src, - (BYTE *)expectedMatrices[LOAD_LEFT_START].data()); - LoadStoreMat(DIM_M, DIM_N, false, MEM_TYPE, DIM_K, DIM_K, s, rStride, - 0, false, src, - (BYTE *)expectedMatrices[LOAD_RIGHT_START].data()); - - LoadStoreMat(DIM_M, DIM_N, true, MEM_TYPE, DIM_K, DIM_K, 0, lStrideP4, - 0, false, src, - (BYTE *)expectedMatrices[LOAD_LEFT_STRIDE_P4].data()); - LoadStoreMat(DIM_M, DIM_N, false, MEM_TYPE, DIM_K, DIM_K, 0, rStrideP4, - 0, false, src, - (BYTE *)expectedMatrices[LOAD_RIGHT_STRIDE_P4].data()); - - LoadStoreMat(DIM_M, DIM_N, true, MEM_TYPE, DIM_K, DIM_K, 0, - lStride * 2, 0, false, src, - (BYTE *)expectedMatrices[LOAD_LEFT_STRIDE_X2].data()); - LoadStoreMat(DIM_M, DIM_N, false, MEM_TYPE, DIM_K, DIM_K, 0, - rStride * 2, 0, false, src, - (BYTE *)expectedMatrices[LOAD_RIGHT_STRIDE_X2].data()); - - LoadStoreMat(DIM_M, DIM_N, true, MEM_TYPE, DIM_K, DIM_K, 0, lStride, a, - false, src, - (BYTE *)expectedMatrices[LOAD_LEFT_ALIGNMENT].data()); - LoadStoreMat(DIM_M, DIM_N, false, MEM_TYPE, DIM_K, DIM_K, 0, rStride, - a, false, src, - (BYTE *)expectedMatrices[LOAD_RIGHT_ALIGNMENT].data()); - - LoadStoreMat(DIM_M, DIM_N, true, MEM_TYPE, DIM_K, DIM_K, 0, ltStride, - 0, true, src, - (BYTE *)expectedMatrices[LOAD_LEFT_TRANSPOSE].data()); - LoadStoreMat(DIM_M, DIM_N, false, MEM_TYPE, DIM_K, DIM_K, 0, rtStride, - 0, true, src, - (BYTE *)expectedMatrices[LOAD_RIGHT_TRANSPOSE].data()); - - LoadStoreMat(DIM_M, DIM_N, true, MEM_TYPE, DIM_K, DIM_K, s, ltStrideP4, - a, true, src, - (BYTE *)expectedMatrices[LOAD_LEFT_ALLPARAMS].data()); - LoadStoreMat(DIM_M, DIM_N, false, MEM_TYPE, DIM_K, DIM_K, s, - rtStrideP4, a, true, src, - (BYTE *)expectedMatrices[LOAD_RIGHT_ALLPARAMS].data()); - - // Store - LoadStoreMat( - DIM_M, DIM_N, true, MEM_TYPE, DIM_K, DIM_K, 0, lStrideP4, 0, false, - src, (BYTE *)expectedMatrices[STORE_LEFT_STRIDE_P4].data(), true); - LoadStoreMat( - DIM_M, DIM_N, false, MEM_TYPE, DIM_K, DIM_K, 0, rStrideP4, 0, false, - src, (BYTE *)expectedMatrices[STORE_RIGHT_STRIDE_P4].data(), true); - - LoadStoreMat( - DIM_M, DIM_N, true, MEM_TYPE, DIM_K, DIM_K, 0, lStride * 2, 0, false, - src, (BYTE *)expectedMatrices[STORE_LEFT_STRIDE_X2].data(), true); - LoadStoreMat( - DIM_M, DIM_N, false, MEM_TYPE, DIM_K, DIM_K, 0, rStride * 2, 0, false, - src, (BYTE *)expectedMatrices[STORE_RIGHT_STRIDE_X2].data(), true); - - LoadStoreMat( - DIM_M, DIM_N, true, MEM_TYPE, DIM_K, DIM_K, 0, lStride, a, false, src, - (BYTE *)expectedMatrices[STORE_LEFT_ALIGNMENT].data(), true); - LoadStoreMat( - DIM_M, DIM_N, false, MEM_TYPE, DIM_K, DIM_K, 0, rStride, a, false, - src, (BYTE *)expectedMatrices[STORE_RIGHT_ALIGNMENT].data(), true); - - LoadStoreMat( - DIM_M, DIM_N, true, MEM_TYPE, DIM_K, DIM_K, 0, ltStride, 0, true, src, - (BYTE *)expectedMatrices[STORE_LEFT_TRANSPOSE].data(), true); - LoadStoreMat( - DIM_M, DIM_N, false, MEM_TYPE, DIM_K, DIM_K, 0, rtStride, 0, true, - src, (BYTE *)expectedMatrices[STORE_RIGHT_TRANSPOSE].data(), true); - - LoadStoreMat( - DIM_M, DIM_N, true, MEM_TYPE, DIM_K, DIM_K, 0, ltStrideP4, a, true, - src, (BYTE *)expectedMatrices[STORE_LEFT_ALLPARAMS].data(), true); - LoadStoreMat( - DIM_M, DIM_N, false, MEM_TYPE, DIM_K, DIM_K, 0, rtStrideP4, a, true, - src, (BYTE *)expectedMatrices[STORE_RIGHT_ALLPARAMS].data(), true); - - } else if (0 == _stricmp(Name, "g_bufInAccum") && doAccumTest) { - std::fill(Data.begin(), Data.end(), (BYTE)0); - GenerateMatrix((TYPE_ACC *)Data.data(), NUM_ELEMENTS * 2, - static_cast(0), - static_cast(NUM_ELEMENTS * 2)); - - BYTE *src = (BYTE *)Data.data(); - uint32_t a = 4; // alignment OR if MEM_TYPE is groupshared, acts as - // additional storage offset in elements. - size_t s = 16 * sizeof(TYPE_ACC); // start - uint32_t aStride = DIM_N * sizeof(TYPE_ACC); - uint32_t atStride = DIM_M * sizeof(TYPE_ACC); - uint32_t aStrideP4 = aStride + 4; - uint32_t atStrideP4 = atStride + 4; - uint32_t elemStride = sizeof(TYPE_ACC); - uint32_t elemStrideP4 = sizeof(TYPE_ACC) + 4; - - if (disableFragmentTests == 0) { - LoadStoreRowCol( - DIM_M, DIM_N, true, MEM_TYPE, s, 0, elemStride, src, - (BYTE *)expectedRowCols[LOAD_LEFT_START].data()); - LoadStoreRowCol( - DIM_M, DIM_N, false, MEM_TYPE, s, 0, elemStride, src, - (BYTE *)expectedRowCols[LOAD_RIGHT_START].data()); - - LoadStoreRowCol( - DIM_M, DIM_N, true, MEM_TYPE, 0, 0, elemStrideP4, src, - (BYTE *)expectedRowCols[LOAD_LEFT_STRIDE_P4].data()); - LoadStoreRowCol( - DIM_M, DIM_N, false, MEM_TYPE, 0, 0, elemStrideP4, src, - (BYTE *)expectedRowCols[LOAD_RIGHT_STRIDE_P4].data()); - - LoadStoreRowCol( - DIM_M, DIM_N, true, MEM_TYPE, 0, 0, elemStride * 2, src, - (BYTE *)expectedRowCols[LOAD_LEFT_STRIDE_X2].data()); - LoadStoreRowCol( - DIM_M, DIM_N, false, MEM_TYPE, 0, 0, elemStride * 2, src, - (BYTE *)expectedRowCols[LOAD_RIGHT_STRIDE_X2].data()); - - LoadStoreRowCol( - DIM_M, DIM_N, true, MEM_TYPE, 0, a, elemStride, src, - (BYTE *)expectedRowCols[LOAD_LEFT_ALIGNMENT].data()); - LoadStoreRowCol( - DIM_M, DIM_N, false, MEM_TYPE, 0, a, elemStride, src, - (BYTE *)expectedRowCols[LOAD_RIGHT_ALIGNMENT].data()); - - LoadStoreRowCol( - DIM_M, DIM_N, true, MEM_TYPE, 0, 0, elemStride, src, - (BYTE *)expectedRowCols[LOAD_LEFT_TRANSPOSE].data()); - LoadStoreRowCol( - DIM_M, DIM_N, false, MEM_TYPE, 0, 0, elemStride, src, - (BYTE *)expectedRowCols[LOAD_RIGHT_TRANSPOSE].data()); - - LoadStoreRowCol( - DIM_M, DIM_N, true, MEM_TYPE, s, a, elemStrideP4, src, - (BYTE *)expectedRowCols[LOAD_LEFT_ALLPARAMS].data()); - LoadStoreRowCol( - DIM_M, DIM_N, false, MEM_TYPE, s, a, elemStrideP4, src, - (BYTE *)expectedRowCols[LOAD_RIGHT_ALLPARAMS].data()); - - LoadStoreRowCol( - DIM_M, DIM_N, true, MEM_TYPE, 0, 0, elemStrideP4, src, - (BYTE *)expectedRowCols[STORE_LEFT_STRIDE_P4].data(), true); - LoadStoreRowCol( - DIM_M, DIM_N, false, MEM_TYPE, 0, 0, elemStrideP4, src, - (BYTE *)expectedRowCols[STORE_RIGHT_STRIDE_P4].data(), true); - - LoadStoreRowCol( - DIM_M, DIM_N, true, MEM_TYPE, 0, 0, elemStride * 2, src, - (BYTE *)expectedRowCols[STORE_LEFT_STRIDE_X2].data(), true); - LoadStoreRowCol( - DIM_M, DIM_N, false, MEM_TYPE, 0, 0, elemStride * 2, src, - (BYTE *)expectedRowCols[STORE_RIGHT_STRIDE_X2].data(), true); - - LoadStoreRowCol( - DIM_M, DIM_N, true, MEM_TYPE, 0, a, elemStride, src, - (BYTE *)expectedRowCols[STORE_LEFT_ALIGNMENT].data(), true); - LoadStoreRowCol( - DIM_M, DIM_N, false, MEM_TYPE, 0, a, elemStride, src, - (BYTE *)expectedRowCols[STORE_RIGHT_ALIGNMENT].data(), true); - - LoadStoreRowCol( - DIM_M, DIM_N, true, MEM_TYPE, 0, 0, elemStride, src, - (BYTE *)expectedRowCols[STORE_LEFT_TRANSPOSE].data(), true); - LoadStoreRowCol( - DIM_M, DIM_N, false, MEM_TYPE, 0, 0, elemStride, src, - (BYTE *)expectedRowCols[STORE_RIGHT_TRANSPOSE].data(), true); - - LoadStoreRowCol( - DIM_M, DIM_N, true, MEM_TYPE, 0, a, elemStrideP4, src, - (BYTE *)expectedRowCols[STORE_LEFT_ALLPARAMS].data(), true); - LoadStoreRowCol( - DIM_M, DIM_N, false, MEM_TYPE, 0, a, elemStrideP4, src, - (BYTE *)expectedRowCols[STORE_RIGHT_ALLPARAMS].data(), true); - } - - // Accumulator - LoadStoreMat( - DIM_M, DIM_N, true, MEM_TYPE, DIM_N, DIM_N, s, aStride, 0, false, src, - (BYTE *)expectedAccumulatorMatrices[LOAD_START].data()); - LoadStoreMat( - DIM_M, DIM_N, true, MEM_TYPE, DIM_N, DIM_N, 0, aStrideP4, 0, false, - src, (BYTE *)expectedAccumulatorMatrices[LOAD_STRIDE_P4].data()); - LoadStoreMat( - DIM_M, DIM_N, true, MEM_TYPE, DIM_N, DIM_N, 0, aStride * 2, 0, false, - src, (BYTE *)expectedAccumulatorMatrices[LOAD_STRIDE_X2].data()); - LoadStoreMat( - DIM_M, DIM_N, true, MEM_TYPE, DIM_N, DIM_N, 0, aStride, a, false, src, - (BYTE *)expectedAccumulatorMatrices[LOAD_ALIGNMENT].data()); - LoadStoreMat( - DIM_M, DIM_N, true, MEM_TYPE, DIM_N, DIM_N, 0, atStride, 0, true, src, - (BYTE *)expectedAccumulatorMatrices[LOAD_TRANSPOSE].data()); - LoadStoreMat( - DIM_M, DIM_N, true, MEM_TYPE, DIM_N, DIM_N, s, atStrideP4, a, true, - src, (BYTE *)expectedAccumulatorMatrices[LOAD_ALLPARAMS].data()); - - LoadStoreMat( - DIM_M, DIM_N, true, MEM_TYPE, DIM_N, DIM_N, 0, aStrideP4, 0, false, - src, (BYTE *)expectedAccumulatorMatrices[STORE_STRIDE_P4].data(), - true); - LoadStoreMat( - DIM_M, DIM_N, true, MEM_TYPE, DIM_N, DIM_N, 0, aStride * 2, 0, false, - src, (BYTE *)expectedAccumulatorMatrices[STORE_STRIDE_X2].data(), - true); - LoadStoreMat( - DIM_M, DIM_N, true, MEM_TYPE, DIM_N, DIM_N, 0, aStride, a, false, src, - (BYTE *)expectedAccumulatorMatrices[STORE_ALIGNMENT].data(), true); - LoadStoreMat( - DIM_M, DIM_N, true, MEM_TYPE, DIM_N, DIM_N, 0, atStride, 0, true, src, - (BYTE *)expectedAccumulatorMatrices[STORE_TRANSPOSE].data(), true); - LoadStoreMat( - DIM_M, DIM_N, true, MEM_TYPE, DIM_N, DIM_N, 0, atStrideP4, a, true, - src, (BYTE *)expectedAccumulatorMatrices[STORE_ALLPARAMS].data(), - true); - } else { - std::fill(Data.begin(), Data.end(), (BYTE)0); - } - - argsStream2.str(""); - argsStream2 << initialArgsString; - - if (MEM_TYPE == GROUPSHARED) { - argsStream2 << " -DGROUPSHARED=1"; - if (doAccumTest) { - argsStream2 << " -DMAX_NUM_ELEMENTS=" << 2 * DIM_M * DIM_N; - } else if (doLeftRightTest) { - argsStream2 << " -DMAX_NUM_ELEMENTS=" - << std::max(2 * DIM_M * DIM_K, 2 * DIM_K * DIM_N); - } - } - - if (doAccumTest) { - argsStream2 << " -DTEST_LOAD_STORE_ACCUMULATOR=1"; - } - - if (doLeftRightTest) { - argsStream2 << " -DTEST_LOAD_STORE_LR=1"; - } - - argsString = argsStream2.str(); - pShaderOp->Shaders.at(0).Arguments = argsString.c_str(); - }; - - std::shared_ptr test; - std::shared_ptr test2; - - if (MEM_TYPE == GROUPSHARED) { - doLeftRightTest = true; - doAccumTest = false; - test = RunShaderOpTestAfterParse(pDevice, support, "WaveMatrixOp", callback, - ShaderOpSet); - - doLeftRightTest = false; - doAccumTest = true; - test2 = RunShaderOpTestAfterParse(pDevice, support, "WaveMatrixOp", - callback, ShaderOpSet); - } else { - // Non groupshared can test both at once - test = RunShaderOpTestAfterParse(pDevice, support, "WaveMatrixOp", callback, - ShaderOpSet); - test2 = test; - } - - // Get read back data for wave matrix - MappedData matrixData; - test->Test->GetReadBackData("g_bufOut", &matrixData); - T *readBackMatrixData = (T *)matrixData.data(); - T *readBackMatrixData2 = - readBackMatrixData + expectedMatrices.size() * expectedMatrices[0].size(); - - // Verify matrix depth function output is equal to K - MappedData matrixDepthData; - test->Test->GetReadBackData("g_bufOutMatrixDepth", &matrixDepthData); - uint32_t *readBackMatrixDepthData = (uint32_t *)matrixDepthData.data(); - VerifyOutputWithExpectedValueUInt(readBackMatrixDepthData[0], DIM_K, 0); - VerifyOutputWithExpectedValueUInt(readBackMatrixDepthData[1], DIM_K, 0); - - WEX::TestExecution::DisableVerifyExceptions dve; - // For left/right wave matrix results - for (size_t i = 0; i < expectedMatrices.size(); ++i) { - auto &expectedMatrix = expectedMatrices[i]; - std::string comment = std::string("Matrix/") + loadStoreEnumStrs[i] + ":"; - WEX::Logging::Log::Comment(CA2W(comment.c_str())); - - VerifyArrayWithExpectedValue(readBackMatrixData, expectedMatrix.data(), - expectedMatrix.size(), Validation_type, - tolerance); - readBackMatrixData += expectedMatrix.size(); - - if (MEM_TYPE != GROUPSHARED) { - VerifyArrayWithExpectedValue(readBackMatrixData2, expectedMatrix.data(), - expectedMatrix.size(), Validation_type, - tolerance); - readBackMatrixData2 += expectedMatrix.size(); - } - } - - // Get read back data for rows / cols - MappedData rowColData; - test2->Test->GetReadBackData("g_bufOutRowCol", &rowColData); - TYPE_ACC *readBackRowColData = (TYPE_ACC *)rowColData.data(); - - // get read back data for accumulators - MappedData accumulatorData; - test2->Test->GetReadBackData("g_bufOutAccumulator", &accumulatorData); - TYPE_ACC *readBackAccumulatorData = (TYPE_ACC *)accumulatorData.data(); - - // For verifying that both waves produce the same output given the same input - TYPE_ACC *readBackRowColData2 = - readBackRowColData + expectedRowCols.size() * expectedRowCols[0].size(); - TYPE_ACC *readBackAccumulatorData2 = - readBackAccumulatorData + expectedAccumulatorMatrices.size() * - expectedAccumulatorMatrices[0].size(); - - // For LeftColAcc/RightRowAcc results - for (size_t i = 0; i < expectedRowCols.size(); ++i) { - auto &expectedRowCol = expectedRowCols[i]; - std::string comment = std::string("RowCol/") + loadStoreEnumStrs[i] + ":"; - WEX::Logging::Log::Comment(CA2W(comment.c_str())); - - VerifyArrayWithExpectedValue(readBackRowColData, expectedRowCol.data(), - expectedRowCol.size(), Validation_type, - tolerance); - readBackRowColData += expectedRowCol.size(); - - if (MEM_TYPE != GROUPSHARED) { - VerifyArrayWithExpectedValue(readBackRowColData2, expectedRowCol.data(), - expectedRowCol.size(), Validation_type, - tolerance); - readBackRowColData2 += expectedRowCol.size(); - } - } - - // For accumulator results - for (size_t i = 0; i < expectedAccumulatorMatrices.size(); ++i) { - auto &expectedAccumulatorMatrix = expectedAccumulatorMatrices[i]; - std::string comment = - std::string("Accumulator/") + loadStoreEnumStrs[i] + ":"; - WEX::Logging::Log::Comment(CA2W(comment.c_str())); - - VerifyArrayWithExpectedValue( - readBackAccumulatorData, expectedAccumulatorMatrix.data(), - expectedAccumulatorMatrix.size(), Validation_type, tolerance); - readBackAccumulatorData += expectedAccumulatorMatrix.size(); - - if (MEM_TYPE != GROUPSHARED) { - VerifyArrayWithExpectedValue( - readBackAccumulatorData2, expectedAccumulatorMatrix.data(), - expectedAccumulatorMatrix.size(), Validation_type, tolerance); - readBackAccumulatorData2 += expectedAccumulatorMatrix.size(); - } - } - - WEX::Logging::Log::EndGroup(CA2W(groupName.c_str())); -} - -template -void WaveMatrixMathTest(int DIM_M, int DIM_N, CComPtr pDevice, - std::shared_ptr ShaderOpSet, - dxc::DxcDllSupport &support, PCWSTR Validation_type, - double tolerance) { - using namespace WMMA; - using namespace DirectX::PackedVector; - DXASSERT_NOMSG(sizeof(T) == sizeof(T2)); - - std::string dataTypeInShader1 = TypeIdToHlsl(); - std::string dataTypeInShader2 = TypeIdToHlsl(); - std::string typeAcc = TypeIdToHlsl(); - - std::string groupName = "WMMAMath/M" + std::to_string(DIM_M) + "/N" + - std::to_string(DIM_N) + "/" + dataTypeInShader1 + - "/" + dataTypeInShader2 + "/Accum" + typeAcc; - WEX::Logging::Log::StartGroup(CA2W(groupName.c_str())); - - D3D12_FEATURE_DATA_WAVE_MMA waveMmaSupport = - checkWaveMMASupport(pDevice, dataTypeInShader1, DIM_M, DIM_N); - D3D12_FEATURE_DATA_WAVE_MMA waveMmaSupport2 = - checkWaveMMASupport(pDevice, dataTypeInShader2, DIM_M, DIM_N); - - bool accTypeSupported = - (typeid(TYPE_ACC) == typeid(HALF) && - (waveMmaSupport.AccumDataTypes & - D3D12_WAVE_MMA_ACCUM_DATATYPE_FLOAT16)) || - (typeid(TYPE_ACC) == typeid(float) && - (waveMmaSupport.AccumDataTypes & D3D12_WAVE_MMA_ACCUM_DATATYPE_FLOAT)) || - (typeid(TYPE_ACC) == typeid(int32_t) && - (waveMmaSupport.AccumDataTypes & D3D12_WAVE_MMA_ACCUM_DATATYPE_INT32)); - - if (!ArgContainsDataType(L"Wmma_Type") || - !ArgContainsDataType(L"Wmma_Type") || - !ArgContainsDataType(L"Wmma_AccumType") || !accTypeSupported || - !waveMmaSupport.Supported || !waveMmaSupport2.Supported) { - - WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped); - WEX::Logging::Log::EndGroup(CA2W(groupName.c_str())); - return; - } - - int disableFragmentTests = 0; - WEX::TestExecution::RuntimeParameters::TryGetValue( - L"Wmma_DisableFragmentTests", disableFragmentTests); - - int DIM_K = waveMmaSupport.K; // Get K dim - - // We store left/right matrices in the same array so we just assume a maximum - // size. This size applies to accumulators as well. - int matrixBufferStrideInElements = 64 * 64; - - // We store rows/cols in the same array so we just assume an easy maximum - // size. - constexpr int numRowColElements = 64; - - std::stringstream argsStream; - argsStream << "-DDIM_M=" << DIM_M << " -DDIM_N=" << DIM_N - << " -DDIM_K=" << DIM_K << " -DDATATYPE=" << dataTypeInShader1 - << " -DDATATYPE2=" << dataTypeInShader2 - << " -DELEMENTSIZE=" << sizeof(T) << " -enable-16bit-types" - << " -DTYPE_ACC=" << typeAcc << " -DFRAGS_ENABLE=" - << static_cast(disableFragmentTests == 0) - << " -DNUM_LANES=" << waveMmaSupport.RequiredWaveLaneCountMin - << " -DMATRIX_BUFFER_STRIDE_IN_ELEMENTS=" - << matrixBufferStrideInElements; - - std::string arguments = argsStream.str(); - - // Create zeroed arrays for expected results. - std::array, NUM_MATRIX_OPS> expectedMatrices; - for (std::vector &vec : expectedMatrices) { - vec.resize(matrixBufferStrideInElements); - std::fill(vec.begin(), vec.end(), (TYPE_ACC)0); - } - - std::array, NUM_ROWCOL_OPS> expectedRowCols; - for (std::vector &vec : expectedRowCols) { - vec.resize(numRowColElements); - std::fill(vec.begin(), vec.end(), (TYPE_ACC)0); - } - - std::vector leftMatrix(DIM_M * DIM_K); - std::vector rightMatrix(DIM_K * DIM_N); - std::vector accumulatorMatrix(DIM_M * DIM_N); - std::vector leftCol(DIM_M); - std::vector rightRow(DIM_N); - - // Sum needs higher tolerance - double sumTolerance = 0.08; - LPCWSTR sumValidationType = Validation_type; - - if (typeid(TYPE_ACC) == typeid(DirectX::PackedVector::HALF) || - typeid(T) == typeid(DirectX::PackedVector::HALF)) { - // Tolerance and sum tolerance is much higher for FP16 - tolerance = 0.08; - double startUlp = 5; // Default tolerance - double mulAddLossyOperations = DIM_K * 2; - sumTolerance = startUlp * mulAddLossyOperations; - sumValidationType = L"ulp"; - } - - // Generate input data - if (typeid(TYPE_ACC) == typeid(DirectX::PackedVector::HALF)) { - GenerateMatrix(accumulatorMatrix.data(), accumulatorMatrix.size(), - -1.0, 2.0); - accumulatorMatrix[0] = static_cast( - ConvertFloat32ToFloat16(std::numeric_limits::infinity())); - accumulatorMatrix[1] = static_cast( - ConvertFloat32ToFloat16(std::numeric_limits::quiet_NaN())); - accumulatorMatrix[2] = static_cast(ConvertFloat32ToFloat16(-0.0f)); - accumulatorMatrix[3] = static_cast( - ConvertFloat32ToFloat16(std::numeric_limits::denorm_min())); - } else if (typeid(TYPE_ACC) == typeid(float)) { - GenerateMatrix(accumulatorMatrix.data(), accumulatorMatrix.size(), - -(float)accumulatorMatrix.size() / 2, - (float)accumulatorMatrix.size() / 2); - accumulatorMatrix[0] = - static_cast(std::numeric_limits::infinity()); - accumulatorMatrix[1] = - static_cast(std::numeric_limits::quiet_NaN()); - accumulatorMatrix[2] = static_cast(-0.0f); - accumulatorMatrix[3] = - static_cast(std::numeric_limits::denorm_min()); - } else { - GenerateMatrix(accumulatorMatrix.data(), accumulatorMatrix.size(), - -(float)accumulatorMatrix.size() / 2, - (float)accumulatorMatrix.size() / 2); - } - - // if T is half, T2 will be half - if (typeid(T) == typeid(DirectX::PackedVector::HALF)) { - DXASSERT_NOMSG(typeid(T2) == typeid(DirectX::PackedVector::HALF)); - GenerateMatrix(leftMatrix.data(), leftMatrix.size(), -1.0, 2.0); - GenerateMatrix(rightMatrix.data(), rightMatrix.size(), 3.0, -1.0); - - leftMatrix[0] = rightMatrix[0] = static_cast( - ConvertFloat32ToFloat16(std::numeric_limits::infinity())); - leftMatrix[1] = rightMatrix[1] = static_cast( - ConvertFloat32ToFloat16(std::numeric_limits::quiet_NaN())); - leftMatrix[2] = rightMatrix[2] = - static_cast(ConvertFloat32ToFloat16(-0.0f)); - leftMatrix[3] = rightMatrix[3] = static_cast( - ConvertFloat32ToFloat16(std::numeric_limits::denorm_min())); - } else if (typeid(TYPE_ACC) == typeid(float)) { - GenerateMatrix(leftMatrix.data(), leftMatrix.size(), - -(float)leftMatrix.size() / 2, - (float)leftMatrix.size() / 2); - GenerateMatrix(rightMatrix.data(), rightMatrix.size(), - (float)rightMatrix.size() / 2, - -(float)rightMatrix.size() / 2); - - if (typeid(T) == typeid(float)) { - DXASSERT_NOMSG(typeid(T) == typeid(T2)); - leftMatrix[0] = rightMatrix[0] = - static_cast(std::numeric_limits::infinity()); - leftMatrix[1] = rightMatrix[1] = - static_cast(std::numeric_limits::quiet_NaN()); - leftMatrix[2] = rightMatrix[2] = static_cast(-0.0f); - leftMatrix[3] = rightMatrix[3] = - static_cast(std::numeric_limits::denorm_min()); - } - } else { - GenerateMatrix(leftMatrix.data(), leftMatrix.size(), - -(float)leftMatrix.size() / 2, - (float)leftMatrix.size() / 2); - GenerateMatrix(rightMatrix.data(), rightMatrix.size(), - (float)rightMatrix.size(), -(float)rightMatrix.size()); - } - - // Get row/col test data from accum matrix - memcpy(leftCol.data(), accumulatorMatrix.data(), - leftCol.size() * sizeof(leftCol[0])); - memcpy(rightRow.data(), accumulatorMatrix.data(), - rightRow.size() * sizeof(rightRow[0])); - - // Generate MULTIPLY_ACCUMULATE initial value - if (typeid(TYPE_ACC) == typeid(DirectX::PackedVector::HALF)) { - FillMatrix(expectedMatrices[MULTIPLY_ACCUMULATE].data(), - accumulatorMatrix.size(), - ConvertFloat32ToFloat16(42.0f)); - } else { - FillMatrix(expectedMatrices[MULTIPLY_ACCUMULATE].data(), - accumulatorMatrix.size(), (TYPE_ACC)42); - } - - // Generate ADD_MATRIX initial value - if (typeid(TYPE_ACC) == typeid(DirectX::PackedVector::HALF)) { - FillMatrix(expectedMatrices[ADD_MATRIX].data(), - accumulatorMatrix.size(), - ConvertFloat32ToFloat16(42.0f)); - } else { - FillMatrix(expectedMatrices[ADD_MATRIX].data(), - accumulatorMatrix.size(), (TYPE_ACC)42); - } - - // Generate expected outputs - MatrixAddMatrix(DIM_M, DIM_N, accumulatorMatrix.data(), - expectedMatrices[ADD_MATRIX].data()); - MatrixMultiplyByMatrix(DIM_M, DIM_N, DIM_K, - leftMatrix.data(), rightMatrix.data(), - expectedMatrices[MULTIPLY].data()); - MatrixMultiplyAndAddMatrix( - DIM_M, DIM_N, DIM_K, leftMatrix.data(), rightMatrix.data(), - expectedMatrices[MULTIPLY_ACCUMULATE].data()); - MatrixAddColumn(DIM_M, DIM_N, leftCol.data(), - expectedMatrices[BROADCAST_ADD_LEFT_COL].data()); - MatrixAddRow(DIM_M, DIM_N, rightRow.data(), - expectedMatrices[BROADCAST_ADD_RIGHT_ROW].data()); - - // Copy left col into expected output (Note that the array is zeroed out in - // the beginning) - memcpy(expectedRowCols[LEFT_COL_SUMACCUMULATE].data(), leftCol.data(), - leftCol.size() * sizeof(leftCol[0])); - - // Sum accumulate the left input matrix onto the left col - MatrixSumColumns(DIM_M, DIM_K, - expectedRowCols[LEFT_COL_SUMACCUMULATE].data(), - leftMatrix.data()); - - // copy right row into expected output - memcpy(expectedRowCols[RIGHT_ROW_SUMACCUMULATE].data(), rightRow.data(), - rightRow.size() * sizeof(rightRow[0])); - - // Sum accumulate the right input matrix onto the right row - MatrixSumRows(DIM_N, DIM_K, - expectedRowCols[RIGHT_ROW_SUMACCUMULATE].data(), - rightMatrix.data()); - - std::shared_ptr test = RunShaderOpTestAfterParse( - pDevice, support, "WaveMatrixOpMath", - // this callback is called when the test - // is creating the resource to run the test - [&](LPCSTR Name, std::vector &Data, st::ShaderOp *pShaderOp) { - if (0 == _stricmp(Name, "g_bufInMatrices")) { - - T *pInMatrices = (T *)Data.data(); - - // Copy left matrix to buffer - memcpy(pInMatrices, leftMatrix.data(), - leftMatrix.size() * sizeof(leftMatrix[0])); - - pInMatrices += matrixBufferStrideInElements; - - // Copy right matrix to buffer - memcpy(pInMatrices, rightMatrix.data(), - rightMatrix.size() * sizeof(rightMatrix[0])); - - pInMatrices += matrixBufferStrideInElements; - - // Copy accumulator matrix to buffer - memcpy(pInMatrices, accumulatorMatrix.data(), - accumulatorMatrix.size() * sizeof(accumulatorMatrix[0])); - } else if (0 == _stricmp(Name, "g_bufOutMatrices")) { - TYPE_ACC *outMatrices = (TYPE_ACC *)Data.data(); - memset(outMatrices, 0, - expectedMatrices.size() * expectedMatrices[0].size() * - sizeof(expectedMatrices[0][0])); - } else if (0 == _stricmp(Name, "g_bufOutRowCols")) { - TYPE_ACC *outRowCols = (TYPE_ACC *)Data.data(); - memset(outRowCols, 0, - expectedRowCols.size() * expectedRowCols[0].size() * - sizeof(expectedRowCols[0][0])); - } - - // update compilation arguments - pShaderOp->Shaders.at(0).Arguments = arguments.c_str(); - }, - ShaderOpSet); - - MappedData outMatrixData; - test->Test->GetReadBackData("g_bufOutMatrices", &outMatrixData); - TYPE_ACC *readBackMatrixData = (TYPE_ACC *)outMatrixData.data(); - - MappedData outRowColData; - test->Test->GetReadBackData("g_bufOutRowCols", &outRowColData); - TYPE_ACC *readBackRowColData = (TYPE_ACC *)outRowColData.data(); - - WEX::TestExecution::DisableVerifyExceptions dve; - - for (uint32_t i = 0u; i < 2; ++i) { - std::string comment = - std::string("Matrix/") + mathOpEnumStrs[MULTIPLY] + ":"; - WEX::Logging::Log::Comment(CA2W(comment.c_str())); - VerifyArrayWithExpectedValue( - readBackMatrixData + (matrixBufferStrideInElements * MULTIPLY), - expectedMatrices[MULTIPLY].data(), matrixBufferStrideInElements, - sumValidationType, sumTolerance); - - comment = - std::string("Matrix/") + mathOpEnumStrs[MULTIPLY_ACCUMULATE] + ":"; - WEX::Logging::Log::Comment(CA2W(comment.c_str())); - VerifyArrayWithExpectedValue( - readBackMatrixData + - (matrixBufferStrideInElements * MULTIPLY_ACCUMULATE), - expectedMatrices[MULTIPLY_ACCUMULATE].data(), - matrixBufferStrideInElements, sumValidationType, sumTolerance); - - comment = std::string("Matrix/") + mathOpEnumStrs[ADD_MATRIX] + ":"; - WEX::Logging::Log::Comment(CA2W(comment.c_str())); - VerifyArrayWithExpectedValue( - readBackMatrixData + (matrixBufferStrideInElements * ADD_MATRIX), - expectedMatrices[ADD_MATRIX].data(), matrixBufferStrideInElements, - Validation_type, tolerance); - - if (disableFragmentTests == 0) { - comment = - std::string("RowCol/") + mathOpEnumStrs[BROADCAST_ADD_LEFT_COL] + ":"; - WEX::Logging::Log::Comment(CA2W(comment.c_str())); - VerifyArrayWithExpectedValue( - readBackMatrixData + - (matrixBufferStrideInElements * BROADCAST_ADD_LEFT_COL), - expectedMatrices[BROADCAST_ADD_LEFT_COL].data(), - matrixBufferStrideInElements, Validation_type, tolerance); - - comment = std::string("RowCol/") + - mathOpEnumStrs[BROADCAST_ADD_RIGHT_ROW] + ":"; - WEX::Logging::Log::Comment(CA2W(comment.c_str())); - VerifyArrayWithExpectedValue( - readBackMatrixData + - (matrixBufferStrideInElements * BROADCAST_ADD_RIGHT_ROW), - expectedMatrices[BROADCAST_ADD_RIGHT_ROW].data(), - matrixBufferStrideInElements, Validation_type, tolerance); - - comment = - std::string("RowCol/") + mathOpEnumStrs[LEFT_COL_SUMACCUMULATE] + ":"; - WEX::Logging::Log::Comment(CA2W(comment.c_str())); - VerifyArrayWithExpectedValue( - readBackRowColData + (numRowColElements * LEFT_COL_SUMACCUMULATE), - expectedRowCols[LEFT_COL_SUMACCUMULATE].data(), numRowColElements, - sumValidationType, sumTolerance); - - comment = std::string("RowCol/") + - mathOpEnumStrs[RIGHT_ROW_SUMACCUMULATE] + ":"; - WEX::Logging::Log::Comment(CA2W(comment.c_str())); - VerifyArrayWithExpectedValue( - readBackRowColData + (numRowColElements * RIGHT_ROW_SUMACCUMULATE), - expectedRowCols[RIGHT_ROW_SUMACCUMULATE].data(), numRowColElements, - sumValidationType, sumTolerance); - } - - // For verifying that both waves produce the same output given the same - // input - readBackMatrixData += expectedMatrices.size() * expectedMatrices[0].size(); - readBackRowColData += expectedRowCols.size() * expectedRowCols[0].size(); - } - - WEX::Logging::Log::EndGroup(CA2W(groupName.c_str())); -} - -template -void WaveMatrixScalarTest(int DIM_M, int DIM_N, CComPtr pDevice, - std::shared_ptr ShaderOpSet, - dxc::DxcDllSupport &support, - std::string dataTypeInShader, PCWSTR Validation_type, - double tolerance, std::vector &floatScalars) { - using namespace DirectX::PackedVector; - using namespace WMMA; - - std::string typeAcc = TypeIdToHlsl(); - std::string groupName = "WMMAScalar/M" + std::to_string(DIM_M) + "/N" + - std::to_string(DIM_N) + "/AB" + dataTypeInShader + - "/Accum" + typeAcc; - WEX::Logging::Log::StartGroup(CA2W(groupName.c_str())); - - D3D12_FEATURE_DATA_WAVE_MMA waveMmaSupport = - checkWaveMMASupport(pDevice, dataTypeInShader, DIM_M, DIM_N); - - bool accTypeSupported = - (typeid(T) == typeid(HALF) && (waveMmaSupport.AccumDataTypes & - D3D12_WAVE_MMA_ACCUM_DATATYPE_FLOAT16)) || - (typeid(T) == typeid(float) && - (waveMmaSupport.AccumDataTypes & D3D12_WAVE_MMA_ACCUM_DATATYPE_FLOAT)) || - (typeid(T) == typeid(int32_t) && - (waveMmaSupport.AccumDataTypes & D3D12_WAVE_MMA_ACCUM_DATATYPE_INT32)); - - if (!ArgContainsDataType(L"Wmma_AccumType") || !accTypeSupported || - !waveMmaSupport.Supported) { - - WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped); - WEX::Logging::Log::EndGroup(CA2W(groupName.c_str())); - return; - } - - int disableFragmentTests = 0; - WEX::TestExecution::RuntimeParameters::TryGetValue( - L"Wmma_DisableFragmentTests", disableFragmentTests); - - // Convert scalars to template type (This is not used in half test). - std::vector scalars(floatScalars.size()); - - for (size_t i = 0; i < scalars.size(); ++i) { - if (typeid(T) == typeid(HALF)) { - scalars[i] = static_cast(ConvertFloat32ToFloat16(floatScalars[i])); - } else if (typeid(T) != typeid(float) && - static_cast(floatScalars[i]) == 0) { - // avoid 0 - scalars[i] = static_cast(std::signbit(floatScalars[i]) * 2 - 1); - } else { - scalars[i] = static_cast(floatScalars[i]); - } - } - - std::stringstream argsStream; - argsStream << "-DDIM_M=" << DIM_M << " -DDIM_N=" << DIM_N - << " -enable-16bit-types" - << " -DFRAGS_ENABLE=" - << static_cast(disableFragmentTests == 0) - << " -DDIM_K=" << waveMmaSupport.K - << " -DNUM_LANES=" << waveMmaSupport.RequiredWaveLaneCountMin - << " -DTYPE_ACC=" << typeAcc; - - std::string arguments = argsStream.str(); - - // We store left/right matrices in the same array so we just assume a - // maximum size. This size applies to accumulators as well. - uint32_t numElements = DIM_M * DIM_N; - std::vector> matrices(SCALAR_NUM_OUTPUTS * scalars.size(), - std::vector(numElements, (T)0)); - std::vector> leftCols(SCALAR_NUM_OUTPUTS * scalars.size(), - std::vector(DIM_M, (T)0)); - std::vector> rightRows(SCALAR_NUM_OUTPUTS * scalars.size(), - std::vector(DIM_N, (T)0)); - - // Generate inputs - for (size_t i = 0; i < scalars.size(); ++i) { - for (size_t j = 0; j < SCALAR_NUM_OUTPUTS; ++j) { - size_t curr = i * SCALAR_NUM_OUTPUTS + j; - size_t start = curr; - DXASSERT_NOMSG(start < numElements); - size_t end = numElements - start; - - if (typeid(T) == typeid(DirectX::PackedVector::HALF)) { - GenerateMatrix(matrices[curr].data(), numElements, (float)start, - (float)end); - GenerateMatrix(leftCols[curr].data(), DIM_M, (float)start, - (float)end); - GenerateMatrix(rightRows[curr].data(), DIM_N, (float)start, - (float)end); - } else { - GenerateMatrix(matrices[curr].data(), numElements, (T)start, (T)end); - GenerateMatrix(leftCols[curr].data(), DIM_M, (T)start, (T)end); - GenerateMatrix(rightRows[curr].data(), DIM_N, (T)start, (T)end); - } - } - } - - if (typeid(T) == typeid(HALF)) { - tolerance = 3; - Validation_type = L"ulp"; - matrices[0][0] = leftCols[0][0] = rightRows[0][0] = - ConvertFloat32ToFloat16(std::numeric_limits::infinity()); - matrices[1][0] = leftCols[1][0] = rightRows[1][0] = - ConvertFloat32ToFloat16(-std::numeric_limits::infinity()); - matrices[2][0] = leftCols[2][0] = rightRows[2][0] = - ConvertFloat32ToFloat16(std::numeric_limits::quiet_NaN()); - matrices[3][0] = leftCols[3][0] = rightRows[3][0] = - ConvertFloat32ToFloat16(-0.0f); - matrices[4][0] = leftCols[4][0] = rightRows[4][0] = - ConvertFloat32ToFloat16(std::numeric_limits::denorm_min()); - } else if (typeid(T) == typeid(float)) { - matrices[0][0] = leftCols[0][0] = rightRows[0][0] = - (T)std::numeric_limits::infinity(); - matrices[1][0] = leftCols[1][0] = rightRows[1][0] = - (T)-std::numeric_limits::infinity(); - matrices[2][0] = leftCols[2][0] = rightRows[2][0] = - (T)std::numeric_limits::quiet_NaN(); - matrices[3][0] = leftCols[3][0] = rightRows[3][0] = (T)-0.0f; - matrices[4][0] = leftCols[4][0] = rightRows[4][0] = - std::numeric_limits::denorm_min(); - } - - std::shared_ptr test = RunShaderOpTestAfterParse( - pDevice, support, "WaveMatrixOpScalar", - [&](LPCSTR Name, std::vector &Data, st::ShaderOp *pShaderOp) { - if (0 == _stricmp(Name, "g_bufInScalar")) { - T *bufferScalars = (T *)Data.data(); - - for (size_t i = 0; i < scalars.size(); ++i) { - bufferScalars[i] = scalars[i]; - } - } else if (0 == _stricmp(Name, "g_bufInAccumulator")) { - // Copy input values to buffer - size_t mtxSize = matrices[0].size() * sizeof(*matrices[0].data()); - for (size_t i = 0; i < matrices.size(); ++i) { - memcpy(Data.data() + mtxSize * i, matrices[i].data(), mtxSize); - } - - // Process CPU side input values in place into expected values - for (size_t i = 0; i < scalars.size(); ++i) { - MatrixMultiplyByScalar( - DIM_M, DIM_N, scalars[i], - matrices[i * SCALAR_NUM_OUTPUTS + SCALAR_MUL].data()); - MatrixDivideByScalar( - DIM_M, DIM_N, scalars[i], - matrices[i * SCALAR_NUM_OUTPUTS + SCALAR_DIV].data()); - MatrixAddScalar( - DIM_M, DIM_N, scalars[i], - matrices[i * SCALAR_NUM_OUTPUTS + SCALAR_ADD].data()); - MatrixSubtractScalar( - DIM_M, DIM_N, scalars[i], - matrices[i * SCALAR_NUM_OUTPUTS + SCALAR_SUB].data()); - FillMatrix(matrices[i * SCALAR_NUM_OUTPUTS + SCALAR_FILL].data(), - DIM_M * DIM_N, scalars[i]); - } - } else if (0 == _stricmp(Name, "g_bufInLeftColAcc")) { - // Copy input values to buffer - size_t lcSize = leftCols[0].size() * sizeof(*leftCols[0].data()); - for (size_t i = 0; i < leftCols.size(); ++i) { - memcpy(Data.data() + lcSize * i, leftCols[i].data(), lcSize); - } - - // Process CPU side input values in place into expected values - for (size_t i = 0; i < scalars.size(); ++i) { - VectorMultiplyByScalar( - DIM_M, scalars[i], - leftCols[i * SCALAR_NUM_OUTPUTS + SCALAR_MUL].data()); - VectorDivideByScalar( - DIM_M, scalars[i], - leftCols[i * SCALAR_NUM_OUTPUTS + SCALAR_DIV].data()); - VectorAddScalar( - DIM_M, scalars[i], - leftCols[i * SCALAR_NUM_OUTPUTS + SCALAR_ADD].data()); - VectorSubtractScalar( - DIM_M, scalars[i], - leftCols[i * SCALAR_NUM_OUTPUTS + SCALAR_SUB].data()); - FillMatrix(leftCols[i * SCALAR_NUM_OUTPUTS + SCALAR_FILL].data(), - DIM_M, scalars[i]); - } - } else if (0 == _stricmp(Name, "g_bufInRightRowAcc")) { - // Copy input values to buffer - size_t rrSize = rightRows[0].size() * sizeof(*rightRows[0].data()); - for (size_t i = 0; i < rightRows.size(); ++i) { - memcpy(Data.data() + rrSize * i, rightRows[i].data(), rrSize); - } - - // Process CPU side input values in place into expected values - for (size_t i = 0; i < scalars.size(); ++i) { - VectorMultiplyByScalar( - DIM_N, scalars[i], - rightRows[i * SCALAR_NUM_OUTPUTS + SCALAR_MUL].data()); - VectorDivideByScalar( - DIM_N, scalars[i], - rightRows[i * SCALAR_NUM_OUTPUTS + SCALAR_DIV].data()); - VectorAddScalar( - DIM_N, scalars[i], - rightRows[i * SCALAR_NUM_OUTPUTS + SCALAR_ADD].data()); - VectorSubtractScalar( - DIM_N, scalars[i], - rightRows[i * SCALAR_NUM_OUTPUTS + SCALAR_SUB].data()); - FillMatrix( - rightRows[i * SCALAR_NUM_OUTPUTS + SCALAR_FILL].data(), DIM_N, - scalars[i]); - } - } else { - std::fill(Data.begin(), Data.end(), (BYTE)0); - } - - // update compilation arguments - pShaderOp->Shaders.at(0).Arguments = arguments.c_str(); - }, - ShaderOpSet); - - MappedData matricesData; - test->Test->GetReadBackData("g_bufOutAccumulator", &matricesData); - T *readBackMatrixData = (T *)matricesData.data(); - - MappedData leftColData; - test->Test->GetReadBackData("g_bufOutLeftColAcc", &leftColData); - T *readBackLeftColData = (T *)leftColData.data(); - - MappedData rightRowData; - test->Test->GetReadBackData("g_bufOutRightRowAcc", &rightRowData); - T *readBackRightRowData = (T *)rightRowData.data(); - - // For verifying that both waves produce the same output - T *readBackRightRowData2 = - readBackRightRowData + rightRows.size() * rightRows[0].size(); - T *readBackLeftColData2 = - readBackLeftColData + leftCols.size() * leftCols[0].size(); - T *readBackMatrixData2 = - readBackMatrixData + matrices.size() * matrices[0].size(); - - WEX::TestExecution::DisableVerifyExceptions dve; - for (size_t i = 0; i < matrices.size(); ++i) { - auto &expectedMatrix = matrices[i]; - std::string comment = - std::string("Matrix/") + scalarEnumStrs[i % SCALAR_NUM_OUTPUTS] + ":"; - WEX::Logging::Log::Comment(CA2W(comment.c_str())); - - VerifyArrayWithExpectedValue(readBackMatrixData, expectedMatrix.data(), - expectedMatrix.size(), Validation_type, - tolerance); - - readBackMatrixData += expectedMatrix.size(); - - VerifyArrayWithExpectedValue(readBackMatrixData2, expectedMatrix.data(), - expectedMatrix.size(), Validation_type, - tolerance); - - readBackMatrixData2 += expectedMatrix.size(); - } - - if (disableFragmentTests == 0) { - for (size_t i = 0; i < leftCols.size(); ++i) { - auto &expectedLeftColAcc = leftCols[i]; - std::string comment = std::string("LeftCol/") + - scalarEnumStrs[i % SCALAR_NUM_OUTPUTS] + ":"; - WEX::Logging::Log::Comment(CA2W(comment.c_str())); - - VerifyArrayWithExpectedValue( - readBackLeftColData, expectedLeftColAcc.data(), - expectedLeftColAcc.size(), Validation_type, tolerance); - - readBackLeftColData += expectedLeftColAcc.size(); - - VerifyArrayWithExpectedValue( - readBackLeftColData2, expectedLeftColAcc.data(), - expectedLeftColAcc.size(), Validation_type, tolerance); - - readBackLeftColData2 += expectedLeftColAcc.size(); - } - - for (size_t i = 0; i < rightRows.size(); ++i) { - auto &expectedRightRowAcc = rightRows[i]; - std::string comment = std::string("RightRow/") + - scalarEnumStrs[i % SCALAR_NUM_OUTPUTS] + ":"; - WEX::Logging::Log::Comment(CA2W(comment.c_str())); - - VerifyArrayWithExpectedValue( - readBackRightRowData, expectedRightRowAcc.data(), - expectedRightRowAcc.size(), Validation_type, tolerance); - - readBackRightRowData += expectedRightRowAcc.size(); - - VerifyArrayWithExpectedValue( - readBackRightRowData2, expectedRightRowAcc.data(), - expectedRightRowAcc.size(), Validation_type, tolerance); - - readBackRightRowData2 += expectedRightRowAcc.size(); - } - } - - WEX::Logging::Log::EndGroup(CA2W(groupName.c_str())); -} - -// TAEF Params and their defaults: -// /p:Wmma_DisableFragmentTests=0 (Disable all row/col fragment tests) -// /p:Wmma_DisableLoadStoreTests=0 -// /p:Wmma_DisableScalarTests=0 -// /p:Wmma_DisableMathTests=0 -// /p:"Wmma_ForceK=16" (Override K with specified value) -// /p:"Wmma_ForceWaveSize=0" (defaults to WaveLaneCountMin using -// CheckFeatureSupport) /p:"Wmma_MemType=buffer,groupshared" -// /p:"Wmma_DimM=16,64" -// /p:"Wmma_DimN=16,64" -// /p:"Wmma_Type=float32_t,float16_t,uint8_t4_packed,int8_t4_packed" -// /p:"Wmma_AccumType=float32_t,float16_t,int32_t" -TEST_F(ExecutionTest, WaveMatrixLoadStoreTests) { - using namespace WMMA; - using namespace DirectX::PackedVector; - - std::vector dimMs; - std::vector dimNs; - std::shared_ptr ShaderOpSet; - - CComPtr pDevice = - WaveMatrixTestCommonSetup(dimMs, dimNs, ShaderOpSet); - - if (pDevice == nullptr) { - return; - } - - // Check if the tests are enabled - int disableLoadStoreTests = 0; - WEX::TestExecution::RuntimeParameters::TryGetValue( - L"Wmma_DisableLoadStoreTests", disableLoadStoreTests); - - if (disableLoadStoreTests == 1) { - LogCommentFmt(L"Wave matrix load store tests are disabled, skipping."); - WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped); - return; - } - - // Parse mem types - std::vector memTypes = {BUFFER, GROUPSHARED}; - std::wstring split; - WEX::Common::String memTypeList; - WEX::TestExecution::RuntimeParameters::TryGetValue(L"Wmma_MemType", - memTypeList); - if (!memTypeList.IsEmpty()) { - memTypeList.ToLower(); - memTypes.clear(); - wstringstream ss((const wchar_t *)memTypeList); - - while (std::getline(ss, split, L',')) { - if (split == L"buffer") { - memTypes.emplace_back(BUFFER); - } else if (split == L"groupshared") { - memTypes.emplace_back(GROUPSHARED); - } else { - throw std::exception("Incorrect args given for mem type"); - } - } - } - - // Run matrix load store tests for supported types - PCWSTR validationType = L"epsilon"; - double tolerance = 0; // 0 tolerance for load store - - for (int dimM : dimMs) { - for (int dimN : dimNs) { - for (int memType : memTypes) { - WaveMatrixLoadStoreTest(dimM, dimN, memType, pDevice, - ShaderOpSet, m_support, - validationType, tolerance); - WaveMatrixLoadStoreTest(dimM, dimN, memType, pDevice, - ShaderOpSet, m_support, - validationType, tolerance); - WaveMatrixLoadStoreTest(dimM, dimN, memType, pDevice, - ShaderOpSet, m_support, - validationType, tolerance); - WaveMatrixLoadStoreTest(dimM, dimN, memType, pDevice, - ShaderOpSet, m_support, - validationType, tolerance); - WaveMatrixLoadStoreTest(dimM, dimN, memType, pDevice, - ShaderOpSet, m_support, - validationType, tolerance); - } - } - } -} - -TEST_F(ExecutionTest, WaveMatrixScalarTests) { - using namespace WMMA; - using namespace DirectX::PackedVector; - - std::vector dimMs; - std::vector dimNs; - std::shared_ptr ShaderOpSet; - CComPtr pDevice = - WaveMatrixTestCommonSetup(dimMs, dimNs, ShaderOpSet); - - if (pDevice == nullptr) { - return; - } - - // Check if the tests are enabled - int disableScalarTests = 0; - WEX::TestExecution::RuntimeParameters::TryGetValue( - L"Wmma_DisableScalarTests", disableScalarTests); - - if (disableScalarTests == 1) { - LogCommentFmt(L"Wave matrix scalar tests are disabled, skipping."); - WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped); - return; - } - - // Run the matrix scalar tests for supported types - PCWSTR validationType = L"epsilon"; - double tolerance = 0.008; - std::vector scalars = {-100.0f, 20.0f, -50.0f, -0.0f, 0.0f, 42.0f}; - - for (uint32_t dimM : dimMs) { - for (uint32_t dimN : dimNs) { - std::string hlslType = "float32_t"; - WaveMatrixScalarTest(dimM, dimN, pDevice, ShaderOpSet, m_support, - hlslType, validationType, tolerance, scalars); - - // hlslType is used for the CheckFeatureSupport query. - // Only one of the two below scalar tests will run, depending on the - // accumulator precision returned by CheckFeatureSupport. - hlslType = "float16_t"; - WaveMatrixScalarTest(dimM, dimN, pDevice, ShaderOpSet, m_support, - hlslType, validationType, tolerance, scalars); - WaveMatrixScalarTest(dimM, dimN, pDevice, ShaderOpSet, m_support, - hlslType, validationType, tolerance, scalars); - - hlslType = "uint8_t4_packed"; - WaveMatrixScalarTest(dimM, dimN, pDevice, ShaderOpSet, m_support, - hlslType, validationType, tolerance, - scalars); - - hlslType = "int8_t4_packed"; - WaveMatrixScalarTest(dimM, dimN, pDevice, ShaderOpSet, m_support, - hlslType, validationType, tolerance, - scalars); - } - } -} - -TEST_F(ExecutionTest, WaveMatrixMathTests) { - using namespace WMMA; - using namespace DirectX::PackedVector; - - std::vector dimMs; - std::vector dimNs; - std::shared_ptr ShaderOpSet; - CComPtr pDevice = - WaveMatrixTestCommonSetup(dimMs, dimNs, ShaderOpSet); - - if (pDevice == nullptr) { - return; - } - - // Check if the tests are enabled - int disableMathTests = 0; - WEX::TestExecution::RuntimeParameters::TryGetValue( - L"Wmma_DisableMathTests", disableMathTests); - - if (disableMathTests == 1) { - LogCommentFmt(L"Wave matrix math tests are disabled, skipping."); - WEX::Logging::Log::Result(WEX::Logging::TestResults::Skipped); - return; - } - - // Run the matrix math tests for supported types - PCWSTR validationType = L"epsilon"; - double tolerance = 0.008; - - for (uint32_t dimM : dimMs) { - for (uint32_t dimN : dimNs) { - WaveMatrixMathTest(dimM, dimN, pDevice, ShaderOpSet, - m_support, validationType, - tolerance); - WaveMatrixMathTest(dimM, dimN, pDevice, ShaderOpSet, - m_support, validationType, - tolerance); - WaveMatrixMathTest(dimM, dimN, pDevice, ShaderOpSet, - m_support, validationType, - tolerance); - WaveMatrixMathTest(dimM, dimN, pDevice, - ShaderOpSet, m_support, - validationType, tolerance); - WaveMatrixMathTest(dimM, dimN, pDevice, - ShaderOpSet, m_support, - validationType, tolerance); - WaveMatrixMathTest(dimM, dimN, pDevice, - ShaderOpSet, m_support, - validationType, tolerance); - WaveMatrixMathTest(dimM, dimN, pDevice, - ShaderOpSet, m_support, - validationType, tolerance); - } - } -} - TEST_F(ExecutionTest, DotTest) { WEX::TestExecution::SetVerifyOutput verifySettings( WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures); diff --git a/tools/clang/unittests/HLSLExec/ShaderOpArith.xml b/tools/clang/unittests/HLSLExec/ShaderOpArith.xml index bf9e68b210..e768f205f1 100644 --- a/tools/clang/unittests/HLSLExec/ShaderOpArith.xml +++ b/tools/clang/unittests/HLSLExec/ShaderOpArith.xml @@ -1292,650 +1292,6 @@ - - RootFlags(0), UAV(u0), UAV(u1), UAV(u2), UAV(u3), UAV(u4), UAV(u5) - - - - - - - - - - - - - - - - (j); - } - } - GroupMemoryBarrierWithGroupSync(); - } - - - void FillDest(uint start, uint threadX) - { - GroupMemoryBarrierWithGroupSync(); - if (threadX == 0) - { - for (uint i = 0; i < MAX_NUM_ELEMENTS; ++i) - { - g_bufOut.Store(start + i * sizeof(DATATYPE), gsharedArr[i]); - // Also clear output so we don't write garbage if the whole buffer is not filled - gsharedArr[i] = 0; - } - } - GroupMemoryBarrierWithGroupSync(); - } - - #elif TEST_LOAD_STORE_ACCUMULATOR - groupshared TYPE_ACC gsharedArrAccumulator[MAX_NUM_ELEMENTS]; - - void ClearGShared(uint threadX) - { - GroupMemoryBarrierWithGroupSync(); - if (threadX == 0) - { - for (uint i = 0; i < MAX_NUM_ELEMENTS; ++i) - { - gsharedArrAccumulator[i] = (TYPE_ACC)0; - } - } - GroupMemoryBarrierWithGroupSync(); - } - - void FillSource(uint threadX) - { - GroupMemoryBarrierWithGroupSync(); - if (threadX == 0) - { - uint j = 0; - for (uint i = 0; i < MAX_NUM_ELEMENTS; ++i, j += sizeof(TYPE_ACC)) - { - gsharedArrAccumulator[i] = g_bufInAccum.Load(j); - } - } - GroupMemoryBarrierWithGroupSync(); - } - - void FillDest(uint start, uint threadX) - { - GroupMemoryBarrierWithGroupSync(); - if (threadX == 0) - { - for (uint i = 0; i < MAX_NUM_ELEMENTS; ++i) - { - g_bufOutAccumulator.Store(start + i * sizeof(TYPE_ACC), gsharedArrAccumulator[i]); - gsharedArrAccumulator[i] = 0; - } - } - GroupMemoryBarrierWithGroupSync(); - } - - void FillDestRowCol(uint start, uint threadX) - { - GroupMemoryBarrierWithGroupSync(); - if (threadX == 0) - { - for (uint i = 0; i < MAX_NUM_ELEMENTS; ++i) - { - g_bufOutRowCol.Store(start + i * sizeof(TYPE_ACC), gsharedArrAccumulator[i]); - } - } - ClearGShared(threadX); - } - #endif - - #define LOAD_SOURCE gsharedArr - #define LOAD_SOURCE_ACCUM gsharedArrAccumulator - #define STORE_DEST gsharedArr - #define STORE_DEST_ROWCOL gsharedArrAccumulator - #define STORE_DEST_ACCUM gsharedArrAccumulator - - // Start/Stride/Offset are all given in bytes, and converted to array elements in the macros. - - #define TEST_LOAD_LEFT(mat, k, start, stride, alignment, transp, dest, destOffset) \ - mat.Load(gsharedArr, (start)/sizeof(DATATYPE), (stride)/sizeof(DATATYPE), transp); \ - mat.Store(g_bufOut, destOffset, lStride, false); - - #define TEST_LOAD_RIGHT(mat, k, start, stride, alignment, transp, dest, destOffset) \ - mat.Load(gsharedArr, (start)/sizeof(DATATYPE), (stride)/sizeof(DATATYPE), transp); \ - mat.Store(g_bufOut, destOffset, rStride, false); - - #define TEST_LOAD_LEFT_COL(mat, k, start, stride, alignment, dest, destOffset) \ - mat.Load(gsharedArrAccumulator, (start)/sizeof(TYPE_ACC), (stride)/sizeof(TYPE_ACC)); \ - mat.Store(g_bufOutRowCol, destOffset, 1 * sizeof(TYPE_ACC)); - - #define TEST_LOAD_RIGHT_ROW(mat, k, start, stride, alignment, dest, destOffset) \ - mat.Load(gsharedArrAccumulator, (start)/sizeof(TYPE_ACC), (stride)/sizeof(TYPE_ACC)); \ - mat.Store(g_bufOutRowCol, destOffset, 1 * sizeof(TYPE_ACC)); - - #define TEST_LOAD_ACCUMULATOR(mata, k, start, stride, alignment, transp, dest, destOffset) \ - mata.Load(gsharedArrAccumulator, (start)/sizeof(TYPE_ACC), (stride)/sizeof(TYPE_ACC), transp); \ - mata.Store(g_bufOutAccumulator, destOffset, aStride, false); - - #define TEST_STORE_LEFT(matl, k, stride, offset, transp, dest, destOffset) \ - matl.Load(g_bufIn, 0, lStride, false); \ - matl.Store(gsharedArr, (offset)/sizeof(DATATYPE), (stride)/sizeof(DATATYPE), transp); \ - FillDest(destOffset, groupThreadID.x); - - #define TEST_STORE_RIGHT(matr, k, stride, offset, transp, dest, destOffset) \ - matr.Load(g_bufIn, 0, rStride, false); \ - matr.Store(gsharedArr, (offset)/sizeof(DATATYPE), (stride)/sizeof(DATATYPE), transp); \ - FillDest(destOffset, groupThreadID.x); - - #define TEST_STORE_LEFT_COL(mat, k, stride, offset, dest, destOffset) \ - mat.Load(g_bufInAccum, 0, 1 * sizeof(TYPE_ACC)); \ - mat.Store(gsharedArrAccumulator, (offset)/sizeof(TYPE_ACC), (stride)/sizeof(TYPE_ACC)); \ - FillDestRowCol(destOffset, groupThreadID.x); - - #define TEST_STORE_RIGHT_ROW(mat, k, stride, offset, dest, destOffset) \ - mat.Load(g_bufInAccum, 0, 1 * sizeof(TYPE_ACC)); \ - mat.Store(gsharedArrAccumulator, (offset)/sizeof(TYPE_ACC), (stride)/sizeof(TYPE_ACC)); \ - FillDestRowCol(destOffset, groupThreadID.x); - - #define TEST_STORE_ACCUMULATOR(mata, k, stride, offset, transp, dest, destOffset) \ - mata.Load(g_bufInAccum, 0, aStride, false); \ - mata.Store(gsharedArrAccumulator, (offset)/sizeof(TYPE_ACC), (stride)/sizeof(TYPE_ACC), transp); \ - FillDest(destOffset, groupThreadID.x); - - #else - #define LOAD_SOURCE g_bufIn - #define LOAD_SOURCE_ACCUM g_bufInAccum - #define STORE_DEST g_bufOut - #define STORE_DEST_ROWCOL g_bufOutRowCol - #define STORE_DEST_ACCUM g_bufOutAccumulator - - void FillSource(uint threadX) {} // no-op - void FillDest(uint start, uint threadX) {} - void FillDestRowCol(uint start, uint threadX) {} - void ClearGShared(uint threadX) {} - - #define TEST_LOAD_LEFT(mat, k, start, stride, alignment, transp, dest, destOffset) \ - mat.Load(LOAD_SOURCE, start, stride, transp, alignment); \ - mat.Store(dest, destOffset, lStride, false); - - #define TEST_LOAD_RIGHT(mat, k, start, stride, alignment, transp, dest, destOffset) \ - mat.Load(LOAD_SOURCE, start, stride, transp, alignment); \ - mat.Store(dest, destOffset, rStride, false); - - #define TEST_LOAD_LEFT_COL(mat, k, start, stride, alignment, dest, destOffset) \ - mat.Load(LOAD_SOURCE_ACCUM, start, stride, alignment); \ - mat.Store(dest, destOffset, (int)sizeof(TYPE_ACC)); - - #define TEST_LOAD_RIGHT_ROW(mat, k, start, stride, alignment, dest, destOffset) \ - mat.Load(LOAD_SOURCE_ACCUM, start, stride, alignment); \ - mat.Store(dest, destOffset, (int)sizeof(TYPE_ACC)); - - #define TEST_LOAD_ACCUMULATOR(mata, k, start, stride, alignment, transp, dest, destOffset) \ - mata.Load(LOAD_SOURCE_ACCUM, start, stride, transp, alignment); \ - mata.Store(dest, destOffset, aStride, false); - - #define TEST_STORE_LEFT(matl, k, stride, alignment, transp, dest, destOffset) \ - matl.Load(LOAD_SOURCE, 0, lStride, false); \ - matl.Store(dest, destOffset, stride, transp, alignment); - - #define TEST_STORE_RIGHT(matr, k, stride, alignment, transp, dest, destOffset) \ - matr.Load(LOAD_SOURCE, 0, rStride, false); \ - matr.Store(dest, destOffset, stride, transp, alignment); - - #define TEST_STORE_LEFT_COL(mat, k, stride, alignment, dest, destOffset) \ - mat.Load(LOAD_SOURCE_ACCUM, 0, (int)sizeof(TYPE_ACC)); \ - mat.Store(dest, destOffset, stride, alignment); - - #define TEST_STORE_RIGHT_ROW(mat, k, stride, alignment, dest, destOffset) \ - mat.Load(LOAD_SOURCE_ACCUM, 0, (int)sizeof(TYPE_ACC)); \ - mat.Store(dest, destOffset, stride, alignment); - - #define TEST_STORE_ACCUMULATOR(mata, k, stride, alignment, transp, dest, destOffset) \ - mata.Load(LOAD_SOURCE_ACCUM, 0, aStride, false); \ - mata.Store(dest, destOffset, stride, transp, alignment); - - #endif // GROUPSHARED if/else - - [WaveSize(NUM_LANES)] - #ifdef GROUPSHARED - [numthreads(NUM_LANES,1,1)] - #else - [numthreads(NUM_LANES * 2,1,1)] - #endif - void main(uint3 groupThreadID : SV_GroupThreadID) - { - uint rowColSize = 64 * 64 * sizeof(TYPE_ACC); - uint size = 2 * 64 * 64 * ELEMENTSIZE; - - // Calculate strides and offsets in bytes. - uint s = 16 * ELEMENTSIZE; // start - uint lStride = (DIM_K * ELEMENTSIZE); - uint rStride = (DIM_N * ELEMENTSIZE); - uint ltStride = (DIM_M * ELEMENTSIZE); - uint rtStride = (DIM_K * ELEMENTSIZE); - uint a = 4; // Alignment. For groupshared, tests store offset. - - // For accumulator - uint sizeAcc = 2 * 64 * 64 * sizeof(TYPE_ACC); - uint s2 = 16 * sizeof(TYPE_ACC); // start - uint aStride = (DIM_N * sizeof(TYPE_ACC)); - uint atStride = (DIM_M * sizeof(TYPE_ACC)); - uint accElemStride = sizeof(TYPE_ACC); - - uint groupOffset = (groupThreadID.x/NUM_LANES) * 22; - - uint LOAD_LEFT_START = 0 + groupOffset; - uint LOAD_RIGHT_START = 1 + groupOffset; - uint LOAD_LEFT_STRIDE_P4 = 2 + groupOffset; - uint LOAD_RIGHT_STRIDE_P4 = 3 + groupOffset; - uint LOAD_LEFT_STRIDE_X2 = 4 + groupOffset; - uint LOAD_RIGHT_STRIDE_X2 = 5 + groupOffset; - uint LOAD_LEFT_ALIGNMENT = 6 + groupOffset; - uint LOAD_RIGHT_ALIGNMENT = 7 + groupOffset; - uint LOAD_LEFT_TRANSPOSE = 8 + groupOffset; - uint LOAD_RIGHT_TRANSPOSE = 9 + groupOffset; - uint LOAD_LEFT_ALLPARAMS = 10 + groupOffset; - uint LOAD_RIGHT_ALLPARAMS = 11 + groupOffset; - uint STORE_LEFT_STRIDE_P4 = 12 + groupOffset; - uint STORE_RIGHT_STRIDE_P4 = 13 + groupOffset; - uint STORE_LEFT_STRIDE_X2 = 14 + groupOffset; - uint STORE_RIGHT_STRIDE_X2 = 15 + groupOffset; - uint STORE_LEFT_ALIGNMENT = 16 + groupOffset; - uint STORE_RIGHT_ALIGNMENT = 17 + groupOffset; - uint STORE_LEFT_TRANSPOSE = 18 + groupOffset; - uint STORE_RIGHT_TRANSPOSE = 19 + groupOffset; - uint STORE_LEFT_ALLPARAMS = 20 + groupOffset; - uint STORE_RIGHT_ALLPARAMS = 21 + groupOffset; - -#if TEST_LOAD_STORE_LR - WaveMatrixLeft matLeft; - WaveMatrixRight matRight; - - if (groupThreadID.x == 0) - { - g_bufOutMatrixDepth.Store(0, matLeft.MatrixDepth()); - g_bufOutMatrixDepth.Store(0 + sizeof(uint), matRight.MatrixDepth()); - } - - ClearGShared(groupThreadID.x); - FillSource(groupThreadID.x); - - ///////////////////////// - // Left/Right Matrices // - ///////////////////////// - TEST_LOAD_LEFT(matLeft, DIM_K, s, lStride , 0, false, STORE_DEST, LOAD_LEFT_START * size); - TEST_LOAD_RIGHT(matRight, DIM_K, s, rStride , 0, false, STORE_DEST, LOAD_RIGHT_START * size); - - TEST_LOAD_LEFT(matLeft, DIM_K, 0, lStride + 4, 0, false, STORE_DEST, LOAD_LEFT_STRIDE_P4 * size); - TEST_LOAD_RIGHT(matRight, DIM_K, 0, rStride + 4, 0, false, STORE_DEST, LOAD_RIGHT_STRIDE_P4 * size); - - TEST_LOAD_LEFT(matLeft, DIM_K, 0, lStride * 2, 0, false, STORE_DEST, LOAD_LEFT_STRIDE_X2 * size); - TEST_LOAD_RIGHT(matRight, DIM_K, 0, rStride * 2, 0, false, STORE_DEST, LOAD_RIGHT_STRIDE_X2 * size); - - TEST_LOAD_LEFT(matLeft, DIM_K, 0, lStride , a, false, STORE_DEST, LOAD_LEFT_ALIGNMENT * size); - TEST_LOAD_RIGHT(matRight, DIM_K, 0, rStride , a, false, STORE_DEST, LOAD_RIGHT_ALIGNMENT * size); - - TEST_LOAD_LEFT(matLeft, DIM_K, 0, ltStride , 0, true , STORE_DEST, LOAD_LEFT_TRANSPOSE * size); - TEST_LOAD_RIGHT(matRight, DIM_K, 0, rtStride , 0, true , STORE_DEST, LOAD_RIGHT_TRANSPOSE * size); - - TEST_LOAD_LEFT(matLeft, DIM_K, s, ltStride + 4, a, true , STORE_DEST, LOAD_LEFT_ALLPARAMS * size); - TEST_LOAD_RIGHT(matRight, DIM_K, s, rtStride + 4, a, true , STORE_DEST, LOAD_RIGHT_ALLPARAMS * size); - - ClearGShared(groupThreadID.x); - - TEST_STORE_LEFT(matLeft, DIM_K, lStride + 4, 0, false, STORE_DEST, STORE_LEFT_STRIDE_P4 * size); - TEST_STORE_RIGHT(matRight, DIM_K, rStride + 4, 0, false, STORE_DEST, STORE_RIGHT_STRIDE_P4 * size); - - TEST_STORE_LEFT(matLeft, DIM_K, lStride * 2, 0, false, STORE_DEST, STORE_LEFT_STRIDE_X2 * size); - TEST_STORE_RIGHT(matRight, DIM_K, rStride * 2, 0, false, STORE_DEST, STORE_RIGHT_STRIDE_X2 * size); - - TEST_STORE_LEFT(matLeft, DIM_K, lStride , a, false, STORE_DEST, STORE_LEFT_ALIGNMENT * size); - TEST_STORE_RIGHT(matRight, DIM_K, rStride , a, false, STORE_DEST, STORE_RIGHT_ALIGNMENT * size); - - TEST_STORE_LEFT(matLeft, DIM_K, ltStride , 0, true , STORE_DEST, STORE_LEFT_TRANSPOSE * size); - TEST_STORE_RIGHT(matRight, DIM_K, rtStride , 0, true , STORE_DEST, STORE_RIGHT_TRANSPOSE * size); - - TEST_STORE_LEFT(matLeft, DIM_K, ltStride + 4, a, true , STORE_DEST, STORE_LEFT_ALLPARAMS * size); - TEST_STORE_RIGHT(matRight, DIM_K, rtStride + 4, a, true , STORE_DEST, STORE_RIGHT_ALLPARAMS * size); - -#endif -#if TEST_LOAD_STORE_ACCUMULATOR - /////////////////////// - // Accumulator Types // - /////////////////////// - WaveMatrixLeftColAcc matLeftColAcc; - WaveMatrixRightRowAcc matRightRowAcc; - WaveMatrixAccumulator matAccum; - #if FRAGS_ENABLE - ClearGShared(groupThreadID.x); - FillSource(groupThreadID.x); - - TEST_LOAD_LEFT_COL(matLeftColAcc, 1, s2, accElemStride , 0, STORE_DEST_ROWCOL, LOAD_LEFT_START * rowColSize); - TEST_LOAD_RIGHT_ROW(matRightRowAcc, 1, s2, accElemStride , 0, STORE_DEST_ROWCOL, LOAD_RIGHT_START * rowColSize); - - TEST_LOAD_LEFT_COL(matLeftColAcc, 1, 0, accElemStride + 4, 0, STORE_DEST_ROWCOL, LOAD_LEFT_STRIDE_P4 * rowColSize); - TEST_LOAD_RIGHT_ROW(matRightRowAcc, 1, 0, accElemStride + 4, 0, STORE_DEST_ROWCOL, LOAD_RIGHT_STRIDE_P4 * rowColSize); - - TEST_LOAD_LEFT_COL(matLeftColAcc, 1, 0, accElemStride * 2, 0, STORE_DEST_ROWCOL, LOAD_LEFT_STRIDE_X2 * rowColSize); - TEST_LOAD_RIGHT_ROW(matRightRowAcc, 1, 0, accElemStride * 2, 0, STORE_DEST_ROWCOL, LOAD_RIGHT_STRIDE_X2 * rowColSize); - - TEST_LOAD_LEFT_COL(matLeftColAcc, 1, 0, accElemStride , a, STORE_DEST_ROWCOL, LOAD_LEFT_ALIGNMENT * rowColSize); - TEST_LOAD_RIGHT_ROW(matRightRowAcc, 1, 0, accElemStride , a, STORE_DEST_ROWCOL, LOAD_RIGHT_ALIGNMENT * rowColSize); - - TEST_LOAD_LEFT_COL(matLeftColAcc, 1, 0, accElemStride , 0, STORE_DEST_ROWCOL, LOAD_LEFT_TRANSPOSE * rowColSize); - TEST_LOAD_RIGHT_ROW(matRightRowAcc, 1, 0, accElemStride , 0, STORE_DEST_ROWCOL, LOAD_RIGHT_TRANSPOSE * rowColSize); - - TEST_LOAD_LEFT_COL(matLeftColAcc, 1, s2, accElemStride + 4, a, STORE_DEST_ROWCOL, LOAD_LEFT_ALLPARAMS * rowColSize); - TEST_LOAD_RIGHT_ROW(matRightRowAcc, 1, s2, accElemStride + 4, a, STORE_DEST_ROWCOL, LOAD_RIGHT_ALLPARAMS * rowColSize); - - ClearGShared(groupThreadID.x); - - TEST_STORE_LEFT_COL(matLeftColAcc, 1, accElemStride + 4, 0, STORE_DEST_ROWCOL, STORE_LEFT_STRIDE_P4 * rowColSize); - TEST_STORE_RIGHT_ROW(matRightRowAcc, 1, accElemStride + 4, 0, STORE_DEST_ROWCOL, STORE_RIGHT_STRIDE_P4 * rowColSize); - - TEST_STORE_LEFT_COL(matLeftColAcc, 1, accElemStride * 2, 0, STORE_DEST_ROWCOL, STORE_LEFT_STRIDE_X2 * rowColSize); - TEST_STORE_RIGHT_ROW(matRightRowAcc, 1, accElemStride * 2, 0, STORE_DEST_ROWCOL, STORE_RIGHT_STRIDE_X2 * rowColSize); - - TEST_STORE_LEFT_COL(matLeftColAcc, 1, accElemStride , a, STORE_DEST_ROWCOL, STORE_LEFT_ALIGNMENT * rowColSize); - TEST_STORE_RIGHT_ROW(matRightRowAcc, 1, accElemStride , a, STORE_DEST_ROWCOL, STORE_RIGHT_ALIGNMENT * rowColSize); - - TEST_STORE_LEFT_COL(matLeftColAcc, 1, accElemStride , 0, STORE_DEST_ROWCOL, STORE_LEFT_TRANSPOSE * rowColSize); - TEST_STORE_RIGHT_ROW(matRightRowAcc, 1, accElemStride , 0, STORE_DEST_ROWCOL, STORE_RIGHT_TRANSPOSE * rowColSize); - - TEST_STORE_LEFT_COL(matLeftColAcc, 1, accElemStride + 4, a, STORE_DEST_ROWCOL, STORE_LEFT_ALLPARAMS * rowColSize); - TEST_STORE_RIGHT_ROW(matRightRowAcc, 1, accElemStride + 4, a, STORE_DEST_ROWCOL, STORE_RIGHT_ALLPARAMS * rowColSize); - #endif // #if FRAGS_ENABLE - - groupOffset = (groupThreadID.x/NUM_LANES) * 11; - uint LOAD_START = 0 + groupOffset; - uint LOAD_STRIDE_P4 = 1 + groupOffset; - uint LOAD_STRIDE_X2 = 2 + groupOffset; - uint LOAD_ALIGNMENT = 3 + groupOffset; - uint LOAD_TRANSPOSE = 4 + groupOffset; - uint LOAD_ALLPARAMS = 5 + groupOffset; - uint STORE_STRIDE_P4 = 6 + groupOffset; - uint STORE_STRIDE_X2 = 7 + groupOffset; - uint STORE_ALIGNMENT = 8 + groupOffset; - uint STORE_TRANSPOSE = 9 + groupOffset; - uint STORE_ALLPARAMS = 10 + groupOffset; - - ClearGShared(groupThreadID.x); - FillSource(groupThreadID.x); - - TEST_LOAD_ACCUMULATOR (matAccum, DIM_K, s2, aStride , 0, false, g_bufOutAccumulator, LOAD_START * sizeAcc); - TEST_LOAD_ACCUMULATOR (matAccum, DIM_K, 0 , aStride + 4, 0, false, g_bufOutAccumulator, LOAD_STRIDE_P4 * sizeAcc); - TEST_LOAD_ACCUMULATOR (matAccum, DIM_K, 0 , aStride * 2, 0, false, g_bufOutAccumulator, LOAD_STRIDE_X2 * sizeAcc); - TEST_LOAD_ACCUMULATOR (matAccum, DIM_K, 0 , aStride , a, false, g_bufOutAccumulator, LOAD_ALIGNMENT * sizeAcc); - TEST_LOAD_ACCUMULATOR (matAccum, DIM_K, 0 , atStride , 0, true , g_bufOutAccumulator, LOAD_TRANSPOSE * sizeAcc); - TEST_LOAD_ACCUMULATOR (matAccum, DIM_K, s2, atStride + 4, a, true , g_bufOutAccumulator, LOAD_ALLPARAMS * sizeAcc); - - ClearGShared(groupThreadID.x); - - TEST_STORE_ACCUMULATOR(matAccum, DIM_K, aStride + 4, 0, false, STORE_DEST_ACCUM, STORE_STRIDE_P4 * sizeAcc); - TEST_STORE_ACCUMULATOR(matAccum, DIM_K, aStride * 2, 0, false, STORE_DEST_ACCUM, STORE_STRIDE_X2 * sizeAcc); - TEST_STORE_ACCUMULATOR(matAccum, DIM_K, aStride , a, false, STORE_DEST_ACCUM, STORE_ALIGNMENT * sizeAcc); - TEST_STORE_ACCUMULATOR(matAccum, DIM_K, atStride , 0, true , STORE_DEST_ACCUM, STORE_TRANSPOSE * sizeAcc); - TEST_STORE_ACCUMULATOR(matAccum, DIM_K, atStride + 4, a, true , STORE_DEST_ACCUM, STORE_ALLPARAMS * sizeAcc); -#endif // #if TEST_LOAD_STORE_ACCUMULATOR - }; - ]]> - - - - - RootFlags(0), UAV(u0), UAV(u1), UAV(u2), UAV(u3), UAV(u4), UAV(u5), UAV(u6) - - - - - - - - - - - - - - - - - - leftCol; - WaveMatrixRightRowAcc rightRow; - WaveMatrixAccumulator accumulator; - - TYPE_ACC scalar = g_bufInScalar.Load(groupID.x * sizeof(TYPE_ACC)); - - const uint lStride = (uint)(DIM_K * sizeof(TYPE_ACC)); - const uint rStride = (uint)(DIM_N * sizeof(TYPE_ACC)); - const uint aStride = (uint)(DIM_N * sizeof(TYPE_ACC)); - - /////////// - // Accumulator - /////////// - - accumulator.Load(g_bufInAccumulator, scalarMulOffset * NUM_ACCUMULATOR_ELEMENTS, aStride, false); - accumulator.ScalarMultiply(scalar); - accumulator.Store(g_bufOutAccumulator, outScalarMulOffset * NUM_ACCUMULATOR_ELEMENTS, aStride, false); - - accumulator.Load(g_bufInAccumulator, scalarDivOffset * NUM_ACCUMULATOR_ELEMENTS, aStride, false); - accumulator.ScalarDivide(scalar); - accumulator.Store(g_bufOutAccumulator, outScalarDivOffset * NUM_ACCUMULATOR_ELEMENTS, aStride, false); - - accumulator.Load(g_bufInAccumulator, scalarAddOffset * NUM_ACCUMULATOR_ELEMENTS, aStride, false); - accumulator.ScalarAdd(scalar); - accumulator.Store(g_bufOutAccumulator, outScalarAddOffset * NUM_ACCUMULATOR_ELEMENTS, aStride, false); - - accumulator.Load(g_bufInAccumulator, scalarSubOffset * NUM_ACCUMULATOR_ELEMENTS, aStride, false); - accumulator.ScalarSubtract(scalar); - accumulator.Store(g_bufOutAccumulator, outScalarSubOffset * NUM_ACCUMULATOR_ELEMENTS, aStride, false); - - accumulator.Load(g_bufInAccumulator, scalarFillOffset * NUM_ACCUMULATOR_ELEMENTS, aStride, false); - accumulator.Fill(scalar); - accumulator.Store(g_bufOutAccumulator, outScalarFillOffset * NUM_ACCUMULATOR_ELEMENTS, aStride, false); - -#if FRAGS_ENABLE - - /////////// - // Left Col - /////////// - - // We load and store the left col transposed (as a row) to save space - leftCol.Load (g_bufInLeftColAcc, scalarMulOffset * DIM_M, (int)sizeof(TYPE_ACC)); - leftCol.ScalarMultiply(scalar); - leftCol.Store(g_bufOutLeftColAcc, outScalarMulOffset * DIM_M, (int)sizeof(TYPE_ACC)); - - leftCol.Load (g_bufInLeftColAcc, scalarDivOffset * DIM_M, (int)sizeof(TYPE_ACC)); - leftCol.ScalarDivide(scalar); - leftCol.Store(g_bufOutLeftColAcc, outScalarDivOffset * DIM_M, (int)sizeof(TYPE_ACC)); - - leftCol.Load (g_bufInLeftColAcc, scalarAddOffset * DIM_M, (int)sizeof(TYPE_ACC)); - leftCol.ScalarAdd(scalar); - leftCol.Store(g_bufOutLeftColAcc, outScalarAddOffset * DIM_M, (int)sizeof(TYPE_ACC)); - - leftCol.Load (g_bufInLeftColAcc, scalarSubOffset * DIM_M, (int)sizeof(TYPE_ACC)); - leftCol.ScalarSubtract(scalar); - leftCol.Store(g_bufOutLeftColAcc, outScalarSubOffset * DIM_M, (int)sizeof(TYPE_ACC)); - - leftCol.Load (g_bufInLeftColAcc, scalarFillOffset * DIM_M, (int)sizeof(TYPE_ACC)); - leftCol.Fill(scalar); - leftCol.Store(g_bufOutLeftColAcc, outScalarFillOffset * DIM_M, (int)sizeof(TYPE_ACC)); - - /////////// - // Right Row - /////////// - - rightRow.Load (g_bufInRightRowAcc, scalarMulOffset * DIM_N, (int)sizeof(TYPE_ACC)); - rightRow.ScalarMultiply(scalar); - rightRow.Store(g_bufOutRightRowAcc, outScalarMulOffset * DIM_N, (int)sizeof(TYPE_ACC)); - - rightRow.Load (g_bufInRightRowAcc, scalarDivOffset * DIM_N, (int)sizeof(TYPE_ACC)); - rightRow.ScalarDivide(scalar); - rightRow.Store(g_bufOutRightRowAcc, outScalarDivOffset * DIM_N, (int)sizeof(TYPE_ACC)); - - rightRow.Load (g_bufInRightRowAcc, scalarAddOffset * DIM_N, (int)sizeof(TYPE_ACC)); - rightRow.ScalarAdd(scalar); - rightRow.Store(g_bufOutRightRowAcc, outScalarAddOffset * DIM_N, (int)sizeof(TYPE_ACC)); - - rightRow.Load (g_bufInRightRowAcc, scalarSubOffset * DIM_N, (int)sizeof(TYPE_ACC)); - rightRow.ScalarSubtract(scalar); - rightRow.Store(g_bufOutRightRowAcc, outScalarSubOffset * DIM_N, (int)sizeof(TYPE_ACC)); - - rightRow.Load (g_bufInRightRowAcc, scalarFillOffset * DIM_N, (int)sizeof(TYPE_ACC)); - rightRow.Fill(scalar); - rightRow.Store(g_bufOutRightRowAcc, outScalarFillOffset * DIM_N, (int)sizeof(TYPE_ACC)); -#endif // #if FRAGS_ENABLE - }; - ]]> - - - - - RootFlags(0), UAV(u0), UAV(u1), UAV(u2) - - - - - - - - - - leftMatrix; - WaveMatrixRight rightMatrix; - WaveMatrixLeftColAcc leftCol; - WaveMatrixRightRowAcc rightRow; - WaveMatrixAccumulator accumulator; - WaveMatrixAccumulator outAccumulator; - - const uint lStride = (uint)(DIM_K * ELEMENTSIZE); - const uint rStride = (uint)(DIM_N * ELEMENTSIZE); - const uint aStride = (uint)(DIM_N * sizeof(TYPE_ACC)); - - leftMatrix.Load(g_bufInMatrices, 0, lStride, false); - rightMatrix.Load(g_bufInMatrices, MATRIX_BUFFER_STRIDE_IN_ELEMENTS * ELEMENTSIZE, rStride, false); - accumulator.Load(g_bufInMatrices, 2 * MATRIX_BUFFER_STRIDE_IN_ELEMENTS * ELEMENTSIZE, aStride, false); - - outAccumulator.Multiply(leftMatrix, rightMatrix); - outAccumulator.Store(g_bufOutMatrices, outMulMatrix * MATRIX_BUFFER_STRIDE_IN_ELEMENTS * sizeof(TYPE_ACC), aStride, false); - - outAccumulator.Fill(42); - outAccumulator.MultiplyAccumulate(leftMatrix, rightMatrix); - outAccumulator.Store(g_bufOutMatrices, outMulAccumulateMatrix * MATRIX_BUFFER_STRIDE_IN_ELEMENTS * sizeof(TYPE_ACC), aStride, false); - - outAccumulator.Fill(42); - outAccumulator.Add(accumulator); - outAccumulator.Store(g_bufOutMatrices, outAddMatrix * MATRIX_BUFFER_STRIDE_IN_ELEMENTS * sizeof(TYPE_ACC), aStride, false); -#if FRAGS_ENABLE - leftCol.Load(g_bufInMatrices, 2 * MATRIX_BUFFER_STRIDE_IN_ELEMENTS * ELEMENTSIZE, (int)sizeof(TYPE_ACC)); - rightRow.Load(g_bufInMatrices, 2 * MATRIX_BUFFER_STRIDE_IN_ELEMENTS * ELEMENTSIZE, (int)sizeof(TYPE_ACC)); - - outAccumulator.Fill(0); - outAccumulator.Add(leftCol); - outAccumulator.Store(g_bufOutMatrices, outBroadcastAddColMatrix * MATRIX_BUFFER_STRIDE_IN_ELEMENTS * sizeof(TYPE_ACC), aStride, false); - - outAccumulator.Fill(0); - outAccumulator.Add(rightRow); - outAccumulator.Store(g_bufOutMatrices, outBroadcastAddRowMatrix * MATRIX_BUFFER_STRIDE_IN_ELEMENTS * sizeof(TYPE_ACC), aStride, false); - - leftCol.SumAccumulate(leftMatrix); - rightRow.SumAccumulate(rightMatrix); - leftCol.Store(g_bufOutRowCols, outRowColOffset, (int)sizeof(TYPE_ACC)); - rightRow.Store(g_bufOutRowCols, outRowColOffset + 64 * sizeof(TYPE_ACC), (int)sizeof(TYPE_ACC)); -#endif //#if FRAGS_ENABLE - }; - ]]> - - - RootFlags(0), UAV(u0) diff --git a/utils/hct/gen_intrin_main.txt b/utils/hct/gen_intrin_main.txt index ad48a06977..40b5f6d96a 100644 --- a/utils/hct/gen_intrin_main.txt +++ b/utils/hct/gen_intrin_main.txt @@ -1089,125 +1089,6 @@ uint [[ro]] CommittedInstanceContributionToHitGroupIndex(); } namespace -namespace WaveMatrixLeftMethods { - -void [[amo]] Fill(in $match<1, -1> numeric value); -void [[]] Load(in ByteAddressBuffer buffer, in uint start, in uint stride, in bool bColMajor); -void [[]] Load(in ByteAddressBuffer buffer, in uint start, in uint stride, in bool bColMajor, in uint alignment); -void [[]] Load(in RWByteAddressBuffer buffer, in uint start, in uint stride, in bool bColMajor); -void [[]] Load(in RWByteAddressBuffer buffer, in uint start, in uint stride, in bool bColMajor, in uint alignment); -void [[]] Store(in ByteAddressBuffer buffer, in uint start, in uint stride, in bool bColMajor); -void [[]] Store(in ByteAddressBuffer buffer, in uint start, in uint stride, in bool bColMajor, in uint alignment); -void [[]] Store(in RWByteAddressBuffer buffer, in uint start, in uint stride, in bool bColMajor); -void [[]] Store(in RWByteAddressBuffer buffer, in uint start, in uint stride, in bool bColMajor, in uint alignment); - -// input/output must be an array in groupshared memory -void [[amo]] Load(in groupshared $match<1, -1> numeric[] input, in uint start, in uint stride, in bool bColMajor); -void [[amo]] Store(out groupshared $match<1, -1> numeric[] output, in uint start, in uint stride, in bool bColMajor); - -uint [[rn]] MatrixDepth(); - -} namespace - -namespace WaveMatrixRightMethods { - -void [[amo]] Fill(in $match<1, -1> numeric value); -void [[]] Load(in ByteAddressBuffer buffer, in uint start, in uint stride, in bool bColMajor); -void [[]] Load(in ByteAddressBuffer buffer, in uint start, in uint stride, in bool bColMajor, in uint alignment); -void [[]] Load(in RWByteAddressBuffer buffer, in uint start, in uint stride, in bool bColMajor); -void [[]] Load(in RWByteAddressBuffer buffer, in uint start, in uint stride, in bool bColMajor, in uint alignment); -void [[]] Store(in ByteAddressBuffer buffer, in uint start, in uint stride, in bool bColMajor); -void [[]] Store(in ByteAddressBuffer buffer, in uint start, in uint stride, in bool bColMajor, in uint alignment); -void [[]] Store(in RWByteAddressBuffer buffer, in uint start, in uint stride, in bool bColMajor); -void [[]] Store(in RWByteAddressBuffer buffer, in uint start, in uint stride, in bool bColMajor, in uint alignment); - -// input/output must be an array in groupshared memory -void [[amo]] Load(in groupshared $match<1, -1> numeric[] input, in uint start, in uint stride, in bool bColMajor); -void [[amo]] Store(out groupshared $match<1, -1> numeric[] output, in uint start, in uint stride, in bool bColMajor); - -uint [[rn]] MatrixDepth(); - -} namespace - -namespace WaveMatrixLeftColAccMethods { - -void [[amo]] Fill(in $match<1, -1> numeric value); -void [[]] Load(in ByteAddressBuffer buffer, in uint start, in uint elementStride); -void [[]] Load(in ByteAddressBuffer buffer, in uint start, in uint elementStride, in uint alignment); -void [[]] Load(in RWByteAddressBuffer buffer, in uint start, in uint elementStride); -void [[]] Load(in RWByteAddressBuffer buffer, in uint start, in uint elementStride, in uint alignment); -void [[]] Store(in ByteAddressBuffer buffer, in uint start, in uint elementStride); -void [[]] Store(in ByteAddressBuffer buffer, in uint start, in uint elementStride, in uint alignment); -void [[]] Store(in RWByteAddressBuffer buffer, in uint start, in uint elementStride); -void [[]] Store(in RWByteAddressBuffer buffer, in uint start, in uint elementStride, in uint alignment); - -// input/output must be an array in groupshared memory -void [[amo]] Load(in groupshared $match<1, -1> numeric[] input, in uint start, in uint elementStride); -void [[amo]] Store(out groupshared $match<1, -1> numeric[] output, in uint start, in uint elementStride); - -void [[amo]] ScalarMultiply(in $match<1, -1> numeric value); -void [[amo]] ScalarDivide(in $match<1, -1> numeric value); -void [[amo]] ScalarAdd(in $match<1, -1> numeric value); -void [[amo]] ScalarSubtract(in $match<1, -1> numeric value); - -void [[amo]] SumAccumulate(in WaveMatrixLeft mat); - -} namespace - -namespace WaveMatrixRightRowAccMethods { - -void [[amo]] Fill(in $match<1, -1> numeric value); -void [[]] Load(in ByteAddressBuffer buffer, in uint start, in uint elementStride); -void [[]] Load(in ByteAddressBuffer buffer, in uint start, in uint elementStride, in uint alignment); -void [[]] Load(in RWByteAddressBuffer buffer, in uint start, in uint elementStride); -void [[]] Load(in RWByteAddressBuffer buffer, in uint start, in uint elementStride, in uint alignment); -void [[]] Store(in ByteAddressBuffer buffer, in uint start, in uint elementStride); -void [[]] Store(in ByteAddressBuffer buffer, in uint start, in uint elementStride, in uint alignment); -void [[]] Store(in RWByteAddressBuffer buffer, in uint start, in uint elementStride); -void [[]] Store(in RWByteAddressBuffer buffer, in uint start, in uint elementStride, in uint alignment); - -// input/output must be an array in groupshared memory -void [[amo]] Load(in groupshared $match<1, -1> numeric[] input, in uint start, in uint elementStride); -void [[amo]] Store(out groupshared $match<1, -1> numeric[] output, in uint start, in uint elementStride); - -void [[amo]] ScalarMultiply(in $match<1, -1> numeric value); -void [[amo]] ScalarDivide(in $match<1, -1> numeric value); -void [[amo]] ScalarAdd(in $match<1, -1> numeric value); -void [[amo]] ScalarSubtract(in $match<1, -1> numeric value); - -void [[amo]] SumAccumulate(in WaveMatrixRight mat); - -} namespace - -namespace WaveMatrixAccumulatorMethods { - -void [[amo]] Fill(in $match<1, -1> numeric value); -void [[]] Load(in ByteAddressBuffer buffer, in uint start, in uint stride, in bool bColMajor); -void [[]] Load(in ByteAddressBuffer buffer, in uint start, in uint stride, in bool bColMajor, in uint alignment); -void [[]] Load(in RWByteAddressBuffer buffer, in uint start, in uint stride, in bool bColMajor); -void [[]] Load(in RWByteAddressBuffer buffer, in uint start, in uint stride, in bool bColMajor, in uint alignment); -void [[]] Store(in ByteAddressBuffer buffer, in uint start, in uint stride, in bool bColMajor); -void [[]] Store(in ByteAddressBuffer buffer, in uint start, in uint stride, in bool bColMajor, in uint alignment); -void [[]] Store(in RWByteAddressBuffer buffer, in uint start, in uint stride, in bool bColMajor); -void [[]] Store(in RWByteAddressBuffer buffer, in uint start, in uint stride, in bool bColMajor, in uint alignment); - -// input/output must be an array in groupshared memory -void [[amo]] Load(in groupshared $match<1, -1> numeric[] input, in uint start, in uint stride, in bool bColMajor); -void [[amo]] Store(out groupshared $match<1, -1> numeric[] output, in uint start, in uint stride, in bool bColMajor); - -void [[amo]] ScalarMultiply(in $match<1, -1> numeric value); -void [[amo]] ScalarDivide(in $match<1, -1> numeric value); -void [[amo]] ScalarAdd(in $match<1, -1> numeric value); -void [[amo]] ScalarSubtract(in $match<1, -1> numeric value); - -void [[amo]] Multiply(in WaveMatrixLeft matA, in WaveMatrixRight matB); -void [[amo]] MultiplyAccumulate(in WaveMatrixLeft matA, in WaveMatrixRight matB); -void [[amo]] Add(in WaveMatrixLeftColAcc broadcastedMatrix); -void [[amo]] Add(in WaveMatrixRightRowAcc broadcastedMatrix); -void [[amo]] Add(in WaveMatrixAccumulator fullMatrix); - -} namespace - // Work Graphs objects and methods // EmptyNodeInput diff --git a/utils/hct/hctdb.py b/utils/hct/hctdb.py index 8d285d6e8b..35343fe31a 100644 --- a/utils/hct/hctdb.py +++ b/utils/hct/hctdb.py @@ -638,18 +638,6 @@ def populate_categories_and_models(self): for i in "IsHelperLane".split(","): self.name_idx[i].category = "Helper Lanes" self.name_idx[i].shader_model = 6, 6 - for i in ( - "WaveMatrix_Annotate,WaveMatrix_Depth,WaveMatrix_Fill," - + "WaveMatrix_LoadRawBuf,WaveMatrix_LoadGroupShared,WaveMatrix_StoreRawBuf,WaveMatrix_StoreGroupShared," - + "WaveMatrix_Multiply,WaveMatrix_MultiplyAccumulate,WaveMatrix_ScalarOp," - + "WaveMatrix_SumAccumulate,WaveMatrix_Add" - ).split(","): - self.name_idx[i].category = "WaveMatrix" - self.name_idx[i].shader_model = 6, 9 - self.name_idx[i].shader_stages = ( - "library", - "compute", - ) for i in "QuadVote,TextureGatherRaw,SampleCmpLevel,TextureStoreSample".split( "," ): @@ -4990,280 +4978,125 @@ def UFI(name, **mappings): % next_op_idx ) - # WaveMatrix ops + # Reserved ops self.add_dxil_op( - "WaveMatrix_Annotate", + "Reserved0", next_op_idx, - "WaveMatrix_Annotate", - "Annotate a wave matrix pointer with the type information", + "Reserved", + "Reserved", "v", - "amo", - [ - db_dxil_param(0, "v", "", ""), - db_dxil_param(2, "waveMat", "waveMatrixPtr", "WaveMatrix pointer"), - db_dxil_param( - 3, - "waveMatProps", - "waveMatProps", - "constant WaveMatrix type info", - is_const=True, - ), - ], + "", + [retvoid_param], ) next_op_idx += 1 - self.add_dxil_op( - "WaveMatrix_Depth", + "Reserved1", next_op_idx, - "WaveMatrix_Depth", - "Returns depth (K) value for matrix of specified type", + "Reserved", + "Reserved", "v", - "rn", - [ - db_dxil_param(0, "i32", "", "depth (k) value"), - db_dxil_param( - 2, "waveMatProps", "waveMatProps", "constant WaveMatrix type info" - ), - ], + "", + [retvoid_param], ) next_op_idx += 1 - self.add_dxil_op( - "WaveMatrix_Fill", + "Reserved2", next_op_idx, - "WaveMatrix_Fill", - "Fill wave matrix with scalar value", - "hfi", - "amo", - [ - db_dxil_param(0, "v", "", ""), - db_dxil_param(2, "waveMat", "waveMatrixPtr", "WaveMatrix pointer"), - db_dxil_param(3, "$o", "value", "scalar value to fill matrix with"), - ], + "Reserved", + "Reserved", + "v", + "", + [retvoid_param], ) next_op_idx += 1 - self.add_dxil_op( - "WaveMatrix_LoadRawBuf", + "Reserved3", next_op_idx, - "WaveMatrix_LoadRawBuf", - "Load wave matrix from raw buffer", + "Reserved", + "Reserved", "v", "", - [ - db_dxil_param(0, "v", "", ""), - db_dxil_param(2, "waveMat", "waveMatrixPtr", "WaveMatrix pointer"), - db_dxil_param(3, "res", "rawBuf", "handle of raw buffer"), - db_dxil_param(4, "i32", "offsetInBytes", "offset in bytes"), - db_dxil_param(5, "i32", "strideInBytes", "stride in bytes"), - db_dxil_param( - 6, "i8", "alignmentInBytes", "alignment in bytes", is_const=True - ), - db_dxil_param( - 7, "i1", "colMajor", "memory is col-major", is_const=True - ), - ], + [retvoid_param], ) next_op_idx += 1 - self.add_dxil_op( - "WaveMatrix_LoadGroupShared", + "Reserved4", next_op_idx, - "WaveMatrix_LoadGroupShared", - "Load wave matrix from group shared array", - "hfi", - "amo", - [ - db_dxil_param(0, "v", "", ""), - db_dxil_param(2, "waveMat", "waveMatrixPtr", "WaveMatrix pointer"), - db_dxil_param( - 3, "$gsptr", "groupsharedPtr", "pointer to groupshared array" - ), - db_dxil_param(4, "i32", "startArrayIndex", "start array index"), - db_dxil_param(5, "i32", "strideInElements", "stride in elements"), - db_dxil_param( - 6, "i1", "colMajor", "memory is col-major", is_const=True - ), - ], + "Reserved", + "Reserved", + "v", + "", + [retvoid_param], ) next_op_idx += 1 - self.add_dxil_op( - "WaveMatrix_StoreRawBuf", + "Reserved5", next_op_idx, - "WaveMatrix_StoreRawBuf", - "Store wave matrix to raw buffer", + "Reserved", + "Reserved", "v", "", - [ - db_dxil_param(0, "v", "", ""), - db_dxil_param(2, "waveMat", "waveMatrixPtr", "WaveMatrix pointer"), - db_dxil_param(3, "res", "rawBuf", "handle of raw buffer"), - db_dxil_param(4, "i32", "offsetInBytes", "offset in bytes"), - db_dxil_param(5, "i32", "strideInBytes", "stride in bytes"), - db_dxil_param( - 6, "i8", "alignmentInBytes", "alignment in bytes", is_const=True - ), - db_dxil_param( - 7, "i1", "colMajor", "memory is col-major", is_const=True - ), - ], + [retvoid_param], ) next_op_idx += 1 - self.add_dxil_op( - "WaveMatrix_StoreGroupShared", + "Reserved6", next_op_idx, - "WaveMatrix_StoreGroupShared", - "Store wave matrix to group shared array", - "hfi", - "amo", - [ - db_dxil_param(0, "v", "", ""), - db_dxil_param(2, "waveMat", "waveMatrixPtr", "WaveMatrix pointer"), - db_dxil_param( - 3, "$gsptr", "groupsharedPtr", "pointer to groupshared array" - ), - db_dxil_param(4, "i32", "startArrayIndex", "start array index"), - db_dxil_param(5, "i32", "strideInElements", "stride in elements"), - db_dxil_param( - 6, "i1", "colMajor", "memory is col-major", is_const=True - ), - ], + "Reserved", + "Reserved", + "v", + "", + [retvoid_param], ) next_op_idx += 1 - self.add_dxil_op( - "WaveMatrix_Multiply", + "Reserved7", next_op_idx, - "WaveMatrix_Multiply", - "Mutiply left and right wave matrix and store in accumulator", + "Reserved", + "Reserved", "v", - "amo", - [ - db_dxil_param(0, "v", "", ""), - db_dxil_param( - 2, - "waveMat", - "waveMatrixAccumulator", - "pointer to WaveMatrixAccumulator", - ), - db_dxil_param( - 3, "waveMat", "waveMatrixLeft", "pointer to WaveMatrixLeft" - ), - db_dxil_param( - 4, "waveMat", "waveMatrixRight", "pointer to WaveMatrixRight" - ), - ], + "", + [retvoid_param], ) next_op_idx += 1 - self.add_dxil_op( - "WaveMatrix_MultiplyAccumulate", + "Reserved8", next_op_idx, - "WaveMatrix_Multiply", - "Mutiply left and right wave matrix and accumulate into accumulator", + "Reserved", + "Reserved", "v", - "amo", - [ - db_dxil_param(0, "v", "", ""), - db_dxil_param( - 2, - "waveMat", - "waveMatrixAccumulator", - "pointer to WaveMatrixAccumulator", - ), - db_dxil_param( - 3, "waveMat", "waveMatrixLeft", "pointer to WaveMatrixLeft" - ), - db_dxil_param( - 4, "waveMat", "waveMatrixRight", "pointer to WaveMatrixRight" - ), - ], + "", + [retvoid_param], ) next_op_idx += 1 - self.add_dxil_op( - "WaveMatrix_ScalarOp", + "Reserved9", next_op_idx, - "WaveMatrix_ScalarOp", - "Perform scalar operation on each element of wave matrix", - "hfi", - "amo", - [ - db_dxil_param(0, "v", "", ""), - db_dxil_param(2, "waveMat", "waveMatrixPtr", "WaveMatrix pointer"), - db_dxil_param( - 3, - "i8", - "op", - "operation", - enum_name="WaveMatrixScalarOpCode", - is_const=True, - ), - db_dxil_param(4, "$o", "value", "scalar value"), - ], + "Reserved", + "Reserved", + "v", + "", + [retvoid_param], ) next_op_idx += 1 - self.add_enum_type( - "WaveMatrixScalarOpCode", - "Operation for WaveMatrix_ScalarOp", - [ - (0, "Add", ""), - (1, "Subtract", ""), - (2, "Multiply", ""), - (3, "Divide", ""), - (4, "Invalid", ""), - ], - ) - self.add_dxil_op( - "WaveMatrix_SumAccumulate", + "Reserved10", next_op_idx, - "WaveMatrix_Accumulate", - "Sum rows or columns of an input matrix into an existing accumulator fragment matrix", + "Reserved", + "Reserved", "v", - "amo", - [ - db_dxil_param(0, "v", "", ""), - db_dxil_param( - 2, - "waveMat", - "waveMatrixFragment", - "pointer to WaveMatrixLeftColAcc or WaveMatrixRightRowAcc", - ), - db_dxil_param( - 3, - "waveMat", - "waveMatrixInput", - "pointer to WaveMatrixLeft or WaveMatrixRight", - ), - ], + "", + [retvoid_param], ) next_op_idx += 1 - self.add_dxil_op( - "WaveMatrix_Add", + "Reserved11", next_op_idx, - "WaveMatrix_Accumulate", - "Element-wise accumulate, or broadcast add of fragment into accumulator", + "Reserved", + "Reserved", "v", - "amo", - [ - db_dxil_param(0, "v", "", ""), - db_dxil_param( - 2, - "waveMat", - "waveMatrixAccumulator", - "pointer to WaveMatrixAccumulator", - ), - db_dxil_param( - 3, - "waveMat", - "waveMatrixAccumulatorOrFragment", - "pointer to Accumulator or WaveMatrixLeftColAcc or WaveMatrixRightRowAcc", - ), - ], + "", + [retvoid_param], ) next_op_idx += 1 @@ -6287,12 +6120,6 @@ def add_pass(name, type_name, doc, opts): add_pass( "resource-handle", "ResourceToHandle", "Lower resource into handle", [] ) - add_pass( - "hlsl-lower-wavematrix-type", - "LowerWaveMatType", - "Lower WaveMatrix types to dxil type", - [], - ) add_pass( "hlsl-passes-nopause", "NoPausePasses", @@ -8436,11 +8263,6 @@ def __init__(self, intrinsic_defs): "any_sampler": "LICOMPTYPE_ANY_SAMPLER", "ByteAddressBuffer": "LICOMPTYPE_BYTEADDRESSBUFFER", "RWByteAddressBuffer": "LICOMPTYPE_RWBYTEADDRESSBUFFER", - "WaveMatrixLeft": "LICOMPTYPE_WAVE_MATRIX_LEFT", - "WaveMatrixRight": "LICOMPTYPE_WAVE_MATRIX_RIGHT", - "WaveMatrixLeftColAcc": "LICOMPTYPE_WAVE_MATRIX_LEFT_COL_ACC", - "WaveMatrixRightRowAcc": "LICOMPTYPE_WAVE_MATRIX_RIGHT_ROW_ACC", - "WaveMatrixAccumulator": "LICOMPTYPE_WAVE_MATRIX_ACCUMULATOR", "NodeRecordOrUAV": "LICOMPTYPE_NODE_RECORD_OR_UAV", "AnyNodeOutputRecord": "LICOMPTYPE_ANY_NODE_OUTPUT_RECORD", "GroupNodeOutputRecords": "LICOMPTYPE_GROUP_NODE_OUTPUT_RECORDS", @@ -8497,7 +8319,7 @@ def load_intrinsics(self, intrinsic_defs): r"""( sampler\w* | string | (?:RW)?(?:Texture\w*|ByteAddressBuffer) | - WaveMatrix\w* | acceleration_struct | ray_desc | + acceleration_struct | ray_desc | Node\w* | RWNode\w* | EmptyNode\w* | AnyNodeOutput\w* | NodeOutputRecord\w* | GroupShared\w* $)""",