From 98759134841177c1e37864a1b84f1993a3e746c2 Mon Sep 17 00:00:00 2001 From: "Rickert, Jonas" Date: Tue, 7 Jan 2025 15:12:37 +0000 Subject: [PATCH] Skip over uninitialized DenseResourceAttrs in verifiers Elided elements are uninitialized DenseResourceAttrs, without these changes MLIR containing them can not be parsed, as the verifiers crash when encountering them. --- src/Dialect/ONNX/ONNXOps/Math/Scatter.cpp | 4 ++ src/Dialect/ONNX/ONNXOps/OpHelper.cpp | 12 ++++++ src/Dialect/ONNX/ONNXOps/OpHelper.hpp | 8 ++++ .../ONNX/ONNXOps/Sequence/SplitToSequence.cpp | 4 ++ .../ONNX/ONNXOps/Tensor/ConstantOfShape.cpp | 4 ++ .../ONNX/ONNXOps/Tensor/GatherElements.cpp | 11 +++++- src/Dialect/ONNX/ONNXOps/Tensor/GatherND.cpp | 4 ++ test/mlir/onnx/invalid.mlir | 37 +++++++++++++++++++ 8 files changed, 82 insertions(+), 2 deletions(-) diff --git a/src/Dialect/ONNX/ONNXOps/Math/Scatter.cpp b/src/Dialect/ONNX/ONNXOps/Math/Scatter.cpp index 189d855805..701e03721d 100644 --- a/src/Dialect/ONNX/ONNXOps/Math/Scatter.cpp +++ b/src/Dialect/ONNX/ONNXOps/Math/Scatter.cpp @@ -76,6 +76,10 @@ LogicalResult ONNXScatterElementsOp::verify() { if (dataDimAtAxis >= 0) { if (ElementsAttr valueAttribute = getElementAttributeFromONNXValue(indices)) { + if (isElementAttrUninitializedDenseResource(valueAttribute)) { + return success(); // Return success to allow the parsing of MLIR with + // elided attributes + } for (IntegerAttr value : valueAttribute.getValues()) { int64_t index = value.getInt(); if (index >= -dataDimAtAxis && index < dataDimAtAxis) diff --git a/src/Dialect/ONNX/ONNXOps/OpHelper.cpp b/src/Dialect/ONNX/ONNXOps/OpHelper.cpp index 23e6da5490..9a73e569b5 100644 --- a/src/Dialect/ONNX/ONNXOps/OpHelper.cpp +++ b/src/Dialect/ONNX/ONNXOps/OpHelper.cpp @@ -14,6 +14,7 @@ #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/DialectResourceBlobManager.h" #include "mlir/IR/TypeUtilities.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Path.h" @@ -896,4 +897,15 @@ std::string getNodeNameInPresenceOfOpt(Operation *op, bool useFileLine) { return "NOTSET"; } +//===----------------------------------------------------------------------===// +// Support for DenseElementsAttr. +//===----------------------------------------------------------------------===// + +bool isElementAttrUninitializedDenseResource(mlir::ElementsAttr elementsAttr) { + const auto denseResourceElementsAttr = + mlir::dyn_cast(elementsAttr); + return denseResourceElementsAttr && + !denseResourceElementsAttr.getRawHandle().getBlob(); +} + } // namespace onnx_mlir diff --git a/src/Dialect/ONNX/ONNXOps/OpHelper.hpp b/src/Dialect/ONNX/ONNXOps/OpHelper.hpp index e5ff545128..e5c146b329 100644 --- a/src/Dialect/ONNX/ONNXOps/OpHelper.hpp +++ b/src/Dialect/ONNX/ONNXOps/OpHelper.hpp @@ -363,6 +363,14 @@ bool isIdentityReshape(mlir::Value input, mlir::Value output, std::string getNodeNameInPresenceOfOpt( mlir::Operation *op, bool useFileLine = true); +//===----------------------------------------------------------------------===// +// Support for DenseElementsAttr. +//===----------------------------------------------------------------------===// + +/// Returns true if elementsAttr is a DenseResourceAttr with a blob that can not +/// be received +bool isElementAttrUninitializedDenseResource(mlir::ElementsAttr elementsAttr); + #include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp.inc" } // namespace onnx_mlir diff --git a/src/Dialect/ONNX/ONNXOps/Sequence/SplitToSequence.cpp b/src/Dialect/ONNX/ONNXOps/Sequence/SplitToSequence.cpp index 3a17990e56..38f922f765 100644 --- a/src/Dialect/ONNX/ONNXOps/Sequence/SplitToSequence.cpp +++ b/src/Dialect/ONNX/ONNXOps/Sequence/SplitToSequence.cpp @@ -58,6 +58,10 @@ LogicalResult ONNXSplitToSequenceOp::verify() { if (splitRank > 1) return emitOpError() << ": split has rank " << splitRank << " > 1"; if (ElementsAttr entries = getElementAttributeFromONNXValue(splitValue)) { + if (isElementAttrUninitializedDenseResource(entries)) { + return success(); // Return success to allow the parsing of MLIR with + // elided attributes + } if (splitRank == 0) { auto scalar = getScalarValue(entries, splitType); if (scalar <= 0) diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/ConstantOfShape.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/ConstantOfShape.cpp index 787fc9b75e..6058adfcdb 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/ConstantOfShape.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/ConstantOfShape.cpp @@ -70,6 +70,10 @@ LogicalResult ONNXConstantOfShapeOp::verify() { if (auto constantOp = getONNXConstantOp(input)) { ElementsAttr valueAttribute = mlir::cast(constantOp.getValueAttr()); + if (isElementAttrUninitializedDenseResource(valueAttribute)) { + return success(); // Return success to allow the parsing of MLIR with + // elided attributes + } // Get repeat values from valueAttribute. auto valueIt = valueAttribute.getValues().begin(); for (int i = 0; i < inputShape[0]; ++i) { diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/GatherElements.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/GatherElements.cpp index ce35ad81b3..dde8029994 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/GatherElements.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/GatherElements.cpp @@ -71,8 +71,13 @@ LogicalResult ONNXGatherElementsOp::verify() { // along axis of size s. ArrayRef dataShape = dataType.getShape(); const int64_t dataDimAtAxis = dataShape[axis]; - if (dataDimAtAxis >= 0) - if (ElementsAttr valueAttribute = getElementAttributeFromONNXValue(indices)) + if (dataDimAtAxis >= 0) { + if (ElementsAttr valueAttribute = + getElementAttributeFromONNXValue(indices)) { + if (isElementAttrUninitializedDenseResource(valueAttribute)) { + return success(); // Return success to allow the parsing of MLIR with + // elided attributes + } for (IntegerAttr value : valueAttribute.getValues()) { int64_t index = value.getInt(); if (index >= -dataDimAtAxis && index < dataDimAtAxis) @@ -83,6 +88,8 @@ LogicalResult ONNXGatherElementsOp::verify() { onnx_mlir::Diagnostic::Range( -dataDimAtAxis, dataDimAtAxis - 1)); } + } + } return success(); } diff --git a/src/Dialect/ONNX/ONNXOps/Tensor/GatherND.cpp b/src/Dialect/ONNX/ONNXOps/Tensor/GatherND.cpp index f5cf329cd0..b388607c12 100644 --- a/src/Dialect/ONNX/ONNXOps/Tensor/GatherND.cpp +++ b/src/Dialect/ONNX/ONNXOps/Tensor/GatherND.cpp @@ -144,6 +144,10 @@ LogicalResult ONNXGatherNDOp::verify() { // All values in 'indices' are expected to satisfy the inequality: // -data.shape[b + i] <= indices[...,i] <= (data.shape[b + i]-1)]. if (ElementsAttr valueAttribute = getElementAttributeFromONNXValue(indices)) { + if (isElementAttrUninitializedDenseResource(valueAttribute)) { + return success(); // Return success to allow the parsing of MLIR with + // elided attributes + } int flatIndex = 0; for (IntegerAttr value : valueAttribute.getValues()) { int64_t indexValue = value.getInt(); diff --git a/test/mlir/onnx/invalid.mlir b/test/mlir/onnx/invalid.mlir index 06dd882b6d..3fa25e883b 100644 --- a/test/mlir/onnx/invalid.mlir +++ b/test/mlir/onnx/invalid.mlir @@ -182,6 +182,15 @@ func.func @test_constantofshape_verifier_4() -> tensor<2xi64> { // ----- +func.func @test_constantofshape_elided() -> tensor<2xi64> { + // Tests that we do not crash on elided elements + %0 = onnx.Constant dense_resource<__elided__> : tensor<2xi64> + %1 = "onnx.ConstantOfShape"(%0) : (tensor<2xi64>) -> tensor<2xi64> + "onnx.Return"(%1) : (tensor<2xi64>) -> () +} + +// ----- + func.func @test_flatten_verifier_1(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> { // expected-error @+1 {{onnx.Flatten: 'axis' value is 5, accepted range is [-4, 4]}} %1 = "onnx.Flatten"(%arg0) {axis = 5 : si64} : (tensor<5x5x1x32xf32>) -> tensor<*xf32> @@ -214,6 +223,15 @@ func.func @test_gatherElements_verifier_2(%data: tensor<2x2xf32>, %indices: tens // ----- +func.func @test_gatherElements_verifier_elided(%data: tensor<12x14x1024xf32>) -> tensor<12x14x14xf32> { + // Tests that we do not crash on elided elements + %indices = onnx.Constant dense_resource<__elided__> : tensor<12x14x14xi64> + %1 = "onnx.GatherElements"(%data, %indices) {axis = -1 : si64} : (tensor<12x14x1024xf32>, tensor<12x14x14xi64>) -> tensor<12x14x14xf32> + "onnx.Return"(%1) : (tensor<12x14x14xf32>) -> () +} + +// ----- + func.func @test_hardmax_verifier_1(%arg0: tensor<2x2xf32>) -> tensor<*xf32> { // expected-error @+1 {{onnx.Hardmax: 'axis' value is 3, accepted range is [-2, 1]}} %1 = "onnx.Hardmax"(%arg0) {axis = 3: si64} : (tensor<2x2xf32>) -> tensor<*xf32> @@ -307,6 +325,16 @@ func.func @test_gatherND_verifier_6(%arg0 : tensor<3x4x4x4xf32>) -> tensor<*xf32 // expected-error @+2 {{onnx.GatherND: 'indices[0]' value is 3, accepted range is [-3, 2]}} %indices = "onnx.Constant"() {value = dense<[3,2,2]> : tensor<3xi64>} : () -> tensor<3x3x2xi64> %1 = "onnx.GatherND"(%arg0, %indices) : (tensor<3x4x4x4xf32>, tensor<3x3x2xi64>) -> tensor<*xf32> + "onnx.Return"(%1) : (tensor<*xf32>) -> () +} + +// ----- + +func.func @test_gatherND_verifier_elided(%arg0 : tensor<3x4x4x4xf32>) -> tensor<*xf32> { + // Test that we do not crash on elided elements + %indices = onnx.Constant dense_resource<__elided__> : tensor<3x3x2xi64> + %1 = "onnx.GatherND"(%arg0, %indices) : (tensor<3x4x4x4xf32>, tensor<3x3x2xi64>) -> tensor<*xf32> + "onnx.Return"(%1) : (tensor<*xf32>) -> () } // ----- @@ -580,6 +608,15 @@ func.func @test_splitToSequence_verifier_6(%arg0: tensor<2x2xf32>) -> !onnx.Seq< // ----- +func.func @test_splitToSequence_verifier_elided(%arg0: tensor<2x2xf32>) -> !onnx.Seq> { + // Tests that we do not crash on elided elements + %0 = onnx.Constant dense_resource<__elided__> : tensor + %1 = "onnx.SplitToSequence"(%arg0, %0) : (tensor<2x2xf32>, tensor) -> !onnx.Seq> + "onnx.Return"(%1) : (!onnx.Seq>) -> () +} + +// ----- + func.func @test_topK_verifier_1(%arg0: tensor<3x4xi64>, %arg1: tensor<1xi64>) -> (tensor<*xf32>, tensor<*xi64>) { // expected-error @+1 {{onnx.TopK: 'axis' value is 2, accepted range is [-2, 1]}} %1, %2 = "onnx.TopK"(%arg0, %arg1) {axis = 2 : si64, largest = 1 : si64, sorted = 1 : si64} : (tensor<3x4xi64>, tensor<1xi64>) -> (tensor<*xf32>, tensor<*xi64>)