Skip to content

Commit

Permalink
Merge pull request #257 from Xilinx/jrickert.verifier.elided
Browse files Browse the repository at this point in the history
Skip over uninitialized DenseResourceAttrs in verifiers
  • Loading branch information
jorickert authored Jan 7, 2025
2 parents dd59f30 + 9875913 commit 03c16db
Show file tree
Hide file tree
Showing 8 changed files with 82 additions and 2 deletions.
4 changes: 4 additions & 0 deletions src/Dialect/ONNX/ONNXOps/Math/Scatter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ LogicalResult ONNXScatterElementsOp::verify() {
if (dataDimAtAxis >= 0) {
if (ElementsAttr valueAttribute =
getElementAttributeFromONNXValue(indices)) {
if (isElementAttrUninitializedDenseResource(valueAttribute)) {
return success(); // Return success to allow the parsing of MLIR with
// elided attributes
}
for (IntegerAttr value : valueAttribute.getValues<IntegerAttr>()) {
int64_t index = value.getInt();
if (index >= -dataDimAtAxis && index < dataDimAtAxis)
Expand Down
12 changes: 12 additions & 0 deletions src/Dialect/ONNX/ONNXOps/OpHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectResourceBlobManager.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Path.h"
Expand Down Expand Up @@ -896,4 +897,15 @@ std::string getNodeNameInPresenceOfOpt(Operation *op, bool useFileLine) {
return "NOTSET";
}

//===----------------------------------------------------------------------===//
// Support for DenseElementsAttr.
//===----------------------------------------------------------------------===//

bool isElementAttrUninitializedDenseResource(mlir::ElementsAttr elementsAttr) {
const auto denseResourceElementsAttr =
mlir::dyn_cast<DenseResourceElementsAttr>(elementsAttr);
return denseResourceElementsAttr &&
!denseResourceElementsAttr.getRawHandle().getBlob();
}

} // namespace onnx_mlir
8 changes: 8 additions & 0 deletions src/Dialect/ONNX/ONNXOps/OpHelper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,14 @@ bool isIdentityReshape(mlir::Value input, mlir::Value output,
std::string getNodeNameInPresenceOfOpt(
mlir::Operation *op, bool useFileLine = true);

//===----------------------------------------------------------------------===//
// Support for DenseElementsAttr.
//===----------------------------------------------------------------------===//

/// Returns true if elementsAttr is a DenseResourceAttr with a blob that can not
/// be received
bool isElementAttrUninitializedDenseResource(mlir::ElementsAttr elementsAttr);

#include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp.inc"

} // namespace onnx_mlir
Expand Down
4 changes: 4 additions & 0 deletions src/Dialect/ONNX/ONNXOps/Sequence/SplitToSequence.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ LogicalResult ONNXSplitToSequenceOp::verify() {
if (splitRank > 1)
return emitOpError() << ": split has rank " << splitRank << " > 1";
if (ElementsAttr entries = getElementAttributeFromONNXValue(splitValue)) {
if (isElementAttrUninitializedDenseResource(entries)) {
return success(); // Return success to allow the parsing of MLIR with
// elided attributes
}
if (splitRank == 0) {
auto scalar = getScalarValue<int64_t>(entries, splitType);
if (scalar <= 0)
Expand Down
4 changes: 4 additions & 0 deletions src/Dialect/ONNX/ONNXOps/Tensor/ConstantOfShape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,10 @@ LogicalResult ONNXConstantOfShapeOp::verify() {
if (auto constantOp = getONNXConstantOp(input)) {
ElementsAttr valueAttribute =
mlir::cast<ElementsAttr>(constantOp.getValueAttr());
if (isElementAttrUninitializedDenseResource(valueAttribute)) {
return success(); // Return success to allow the parsing of MLIR with
// elided attributes
}
// Get repeat values from valueAttribute.
auto valueIt = valueAttribute.getValues<IntegerAttr>().begin();
for (int i = 0; i < inputShape[0]; ++i) {
Expand Down
11 changes: 9 additions & 2 deletions src/Dialect/ONNX/ONNXOps/Tensor/GatherElements.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,13 @@ LogicalResult ONNXGatherElementsOp::verify() {
// along axis of size s.
ArrayRef<int64_t> dataShape = dataType.getShape();
const int64_t dataDimAtAxis = dataShape[axis];
if (dataDimAtAxis >= 0)
if (ElementsAttr valueAttribute = getElementAttributeFromONNXValue(indices))
if (dataDimAtAxis >= 0) {
if (ElementsAttr valueAttribute =
getElementAttributeFromONNXValue(indices)) {
if (isElementAttrUninitializedDenseResource(valueAttribute)) {
return success(); // Return success to allow the parsing of MLIR with
// elided attributes
}
for (IntegerAttr value : valueAttribute.getValues<IntegerAttr>()) {
int64_t index = value.getInt();
if (index >= -dataDimAtAxis && index < dataDimAtAxis)
Expand All @@ -83,6 +88,8 @@ LogicalResult ONNXGatherElementsOp::verify() {
onnx_mlir::Diagnostic::Range<int64_t>(
-dataDimAtAxis, dataDimAtAxis - 1));
}
}
}

return success();
}
Expand Down
4 changes: 4 additions & 0 deletions src/Dialect/ONNX/ONNXOps/Tensor/GatherND.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,10 @@ LogicalResult ONNXGatherNDOp::verify() {
// All values in 'indices' are expected to satisfy the inequality:
// -data.shape[b + i] <= indices[...,i] <= (data.shape[b + i]-1)].
if (ElementsAttr valueAttribute = getElementAttributeFromONNXValue(indices)) {
if (isElementAttrUninitializedDenseResource(valueAttribute)) {
return success(); // Return success to allow the parsing of MLIR with
// elided attributes
}
int flatIndex = 0;
for (IntegerAttr value : valueAttribute.getValues<IntegerAttr>()) {
int64_t indexValue = value.getInt();
Expand Down
37 changes: 37 additions & 0 deletions test/mlir/onnx/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,15 @@ func.func @test_constantofshape_verifier_4() -> tensor<2xi64> {

// -----

func.func @test_constantofshape_elided() -> tensor<2xi64> {
// Tests that we do not crash on elided elements
%0 = onnx.Constant dense_resource<__elided__> : tensor<2xi64>
%1 = "onnx.ConstantOfShape"(%0) : (tensor<2xi64>) -> tensor<2xi64>
"onnx.Return"(%1) : (tensor<2xi64>) -> ()
}

// -----

func.func @test_flatten_verifier_1(%arg0 : tensor<5x5x1x32xf32>) -> tensor<*xf32> {
// expected-error @+1 {{onnx.Flatten: 'axis' value is 5, accepted range is [-4, 4]}}
%1 = "onnx.Flatten"(%arg0) {axis = 5 : si64} : (tensor<5x5x1x32xf32>) -> tensor<*xf32>
Expand Down Expand Up @@ -214,6 +223,15 @@ func.func @test_gatherElements_verifier_2(%data: tensor<2x2xf32>, %indices: tens

// -----

func.func @test_gatherElements_verifier_elided(%data: tensor<12x14x1024xf32>) -> tensor<12x14x14xf32> {
// Tests that we do not crash on elided elements
%indices = onnx.Constant dense_resource<__elided__> : tensor<12x14x14xi64>
%1 = "onnx.GatherElements"(%data, %indices) {axis = -1 : si64} : (tensor<12x14x1024xf32>, tensor<12x14x14xi64>) -> tensor<12x14x14xf32>
"onnx.Return"(%1) : (tensor<12x14x14xf32>) -> ()
}

// -----

func.func @test_hardmax_verifier_1(%arg0: tensor<2x2xf32>) -> tensor<*xf32> {
// expected-error @+1 {{onnx.Hardmax: 'axis' value is 3, accepted range is [-2, 1]}}
%1 = "onnx.Hardmax"(%arg0) {axis = 3: si64} : (tensor<2x2xf32>) -> tensor<*xf32>
Expand Down Expand Up @@ -307,6 +325,16 @@ func.func @test_gatherND_verifier_6(%arg0 : tensor<3x4x4x4xf32>) -> tensor<*xf32
// expected-error @+2 {{onnx.GatherND: 'indices[0]' value is 3, accepted range is [-3, 2]}}
%indices = "onnx.Constant"() {value = dense<[3,2,2]> : tensor<3xi64>} : () -> tensor<3x3x2xi64>
%1 = "onnx.GatherND"(%arg0, %indices) : (tensor<3x4x4x4xf32>, tensor<3x3x2xi64>) -> tensor<*xf32>
"onnx.Return"(%1) : (tensor<*xf32>) -> ()
}

// -----

func.func @test_gatherND_verifier_elided(%arg0 : tensor<3x4x4x4xf32>) -> tensor<*xf32> {
// Test that we do not crash on elided elements
%indices = onnx.Constant dense_resource<__elided__> : tensor<3x3x2xi64>
%1 = "onnx.GatherND"(%arg0, %indices) : (tensor<3x4x4x4xf32>, tensor<3x3x2xi64>) -> tensor<*xf32>
"onnx.Return"(%1) : (tensor<*xf32>) -> ()
}

// -----
Expand Down Expand Up @@ -580,6 +608,15 @@ func.func @test_splitToSequence_verifier_6(%arg0: tensor<2x2xf32>) -> !onnx.Seq<

// -----

func.func @test_splitToSequence_verifier_elided(%arg0: tensor<2x2xf32>) -> !onnx.Seq<tensor<*xf32>> {
// Tests that we do not crash on elided elements
%0 = onnx.Constant dense_resource<__elided__> : tensor<i64>
%1 = "onnx.SplitToSequence"(%arg0, %0) : (tensor<2x2xf32>, tensor<i64>) -> !onnx.Seq<tensor<*xf32>>
"onnx.Return"(%1) : (!onnx.Seq<tensor<*xf32>>) -> ()
}

// -----

func.func @test_topK_verifier_1(%arg0: tensor<3x4xi64>, %arg1: tensor<1xi64>) -> (tensor<*xf32>, tensor<*xi64>) {
// expected-error @+1 {{onnx.TopK: 'axis' value is 2, accepted range is [-2, 1]}}
%1, %2 = "onnx.TopK"(%arg0, %arg1) {axis = 2 : si64, largest = 1 : si64, sorted = 1 : si64} : (tensor<3x4xi64>, tensor<1xi64>) -> (tensor<*xf32>, tensor<*xi64>)
Expand Down

0 comments on commit 03c16db

Please sign in to comment.