diff --git a/docs/development.md b/docs/development.md index fe997447c319..56ae3dbf0728 100644 --- a/docs/development.md +++ b/docs/development.md @@ -53,42 +53,58 @@ Two setups are possible to build: in-tree and out-of-tree. The in-tree setup is The following command generates configuration files to build the project *in-tree*, that is, using llvm/llvm-project as the main build. This will build LLVM as well as torch-mlir and its subprojects. On Windows, use the "Developer PowerShell for Visual Studio" to ensure that the compiler and linker binaries are in the `PATH` variable. +This requires `lld`, `clang`, `ccache`, and other dependencies for building `libtorch` / `PyTorch` wheels from source. If you run into issues because of these, try the [simplified build command](#simplified-build). + ```shell cmake -GNinja -Bbuild \ + externals/llvm-project/llvm \ -DCMAKE_BUILD_TYPE=Release \ + -DLLVM_ENABLE_ASSERTIONS=ON \ -DPython3_FIND_VIRTUALENV=ONLY \ -DLLVM_ENABLE_PROJECTS=mlir \ -DLLVM_EXTERNAL_PROJECTS="torch-mlir" \ -DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR="$PWD" \ -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ -DLLVM_TARGETS_TO_BUILD=host \ - externals/llvm-project/llvm -``` -#### Flags that can reduce build time: -* Enabling clang on Linux -```shell - -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -``` -* Enabling ccache -```shell - -DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache -``` -* Enabling LLD (links in seconds compared to minutes) -```shell - -DCMAKE_EXE_LINKER_FLAGS_INIT="-fuse-ld=lld" -DCMAKE_MODULE_LINKER_FLAGS_INIT="-fuse-ld=lld" -DCMAKE_SHARED_LINKER_FLAGS_INIT="-fuse-ld=lld" -# Use --ld-path= instead of -fuse-ld=lld for clang > 13 -``` -* Enabling libtorch binary cache -By default we download the latest version of libtorch everytime you build so we are always on the latest version. Set `-DLIBTORCH_CACHE=ON` to -not download the latest version everytime. If libtorch gets out of date and you test against a newer PyTorch you may notice failures. -```shell - -DLIBTORCH_CACHE=ON -``` -* Enabling building libtorch as part of your build -By default we download the latest version of libtorch. We have an experimental path to build libtorch (and PyTorch wheels) from source. + `# use clang` \ + -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ \ + `# use ccache to cache build results` \ + -DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ + `# use LLD to link in seconds, rather than minutes` \ + `# if using clang <= 13, replace --ld-path=lld with -fuse-ld=lld` \ + -DCMAKE_EXE_LINKER_FLAGS_INIT="--ld-path=lld" \ + -DCMAKE_MODULE_LINKER_FLAGS_INIT="--ld-path=lld" \ + -DCMAKE_SHARED_LINKER_FLAGS_INIT="--ld-path=lld" \ + `# Enable the libtorch binary cache instead of downloading the latest libtorch every time.` \ + `# Testing against a mismatched version of libtorch may cause failures.` \ + -DLIBTORCH_CACHE=ON \ + `# Enable an experimental path to build libtorch (and PyTorch wheels) from source,` \ + `# instead of downloading them` \ + -DLIBTORCH_SRC_BUILD=ON \ + `# Set the variant of libtorch to build / link against. (shared|static and optionally cxxabi11)` \ + -DLIBTORCH_VARIANT=shared +``` + +#### Simplified build + +If you're running into issues with the above build command, consider using the following: + ```shell - -DLIBTORCH_SRC_BUILD=ON # Build Libtorch from source - -DLIBTORCH_VARIANT=shared # Set the variant of libtorch to build / link against. 
(`shared`|`static` and optionally `cxxabi11`) +cmake -GNinja -Bbuild \ + -DCMAKE_BUILD_TYPE=Release \ + -DPython3_FIND_VIRTUALENV=ONLY \ + -DLLVM_ENABLE_PROJECTS=mlir \ + -DLLVM_EXTERNAL_PROJECTS="torch-mlir" \ + -DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR="$PWD" \ + -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ + -DLLVM_TARGETS_TO_BUILD=host \ + externals/llvm-project/llvm ```
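+ +Once configuration finishes, a typical way to start the build (using the `build` directory passed to `-Bbuild` above) is: + +```shell +cmake --build build +```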
 #### Flags to enable MLIR debugging: diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h index 3230cc8b46a0..0de85f4eebe5 100644 --- a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h +++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h @@ -97,6 +97,31 @@ struct OpBinder { return success(); } + // Operand matches of different arities. + ParseResult tensorListOperand(Value &value0) { + if (op->getNumOperands() != 1) + return failure(); + value0 = op->getOperand(0); + auto tt = dyn_cast<Torch::ListType>(value0.getType()); + if (!tt) + return failure(); + if (!toValidTensorType(tt.getContainedType())) + return failure(); + return success(); + } + + ParseResult tensorListResultType(Torch::ListType &type0) { + if (op->getNumResults() != 1) + return failure(); + auto tt = dyn_cast<Torch::ListType>(op->getResult(0).getType()); + if (!tt) + return failure(); + if (!toValidTensorType(tt.getContainedType())) + return failure(); + type0 = tt; + return success(); + } + ParseResult tensorResultTypes(llvm::SmallVector<mlir::Type> &typeList) { for (auto result : op->getResults()) { auto t = toValidTensorType(result.getType()); diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchTypes.h b/include/torch-mlir/Dialect/Torch/IR/TorchTypes.h index c8d1c5051f28..163ed6300878 100644 --- a/include/torch-mlir/Dialect/Torch/IR/TorchTypes.h +++ b/include/torch-mlir/Dialect/Torch/IR/TorchTypes.h @@ -53,6 +53,9 @@ class BaseTensorType : public Type { /// convenient API. Type getOptionalDtype() const; + /// Get the raw optional sparse tensor encoding. + Attribute getOptionalSparsity() const; + /// Return true if this type has a list of sizes. bool hasSizes() const { return getOptionalSizes().has_value(); } @@ -93,6 +96,10 @@ class BaseTensorType : public Type { Type getWithSizesAndDtype(std::optional<ArrayRef<int64_t>> optionalSizes, Type optionalDtype) const; + Type getWithSizesAndDtypeAndSparsity( + std::optional<ArrayRef<int64_t>> optionalSizes, Type optionalDtype, + Attribute optionalSparsity) const; + /// Return a type with the same shape and dtype as this one, but with /// value semantics. ValueTensorType getWithValueSemantics() const; @@ -129,23 +136,31 @@ namespace Torch { inline std::optional<ArrayRef<int64_t>> BaseTensorType::getOptionalSizes() const { - if (auto tensor = dyn_cast<NonValueTensorType>()) + if (auto tensor = mlir::dyn_cast<NonValueTensorType>(*this)) return tensor.getOptionalSizes(); - if (auto tensor = dyn_cast<ValueTensorType>()) + if (auto tensor = mlir::dyn_cast<ValueTensorType>(*this)) return tensor.getOptionalSizes(); llvm_unreachable("not a BaseTensorType!"); } inline Type BaseTensorType::getOptionalDtype() const { - if (auto tensor = dyn_cast<NonValueTensorType>()) + if (auto tensor = mlir::dyn_cast<NonValueTensorType>(*this)) return tensor.getOptionalDtype(); - if (auto tensor = dyn_cast<ValueTensorType>()) + if (auto tensor = mlir::dyn_cast<ValueTensorType>(*this)) return tensor.getOptionalDtype(); llvm_unreachable("not a BaseTensorType!"); } +inline Attribute BaseTensorType::getOptionalSparsity() const { + if (auto tensor = mlir::dyn_cast<NonValueTensorType>(*this)) + return tensor.getOptionalSparsity(); + if (auto tensor = mlir::dyn_cast<ValueTensorType>(*this)) + return tensor.getOptionalSparsity(); + llvm_unreachable("not a BaseTensorType!"); +} + inline bool BaseTensorType::classof(Type type) { - return type.isa<NonValueTensorType, ValueTensorType>(); + return mlir::isa<NonValueTensorType, ValueTensorType>(type); } } // namespace Torch diff --git a/include/torch-mlir/Dialect/Torch/Utils/SparsityUtils.h b/include/torch-mlir/Dialect/Torch/Utils/SparsityUtils.h new file mode 100644 index 000000000000..e29054790e5c --- /dev/null +++ b/include/torch-mlir/Dialect/Torch/Utils/SparsityUtils.h @@ -0,0 +1,28 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// +#ifndef TORCHMLIR_DIALECT_TORCH_SPARSITY_UTILS_H +#define TORCHMLIR_DIALECT_TORCH_SPARSITY_UTILS_H + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Value.h" +#include "mlir/Support/LogicalResult.h" + +namespace mlir { +namespace torch { +namespace Torch { + +// Create a new SparseTensorEncodingAttr based on the provided `attr`, but with +// a new dense level inserted at `dim`. +FailureOr<Attribute> getSparsityWithDenseLTAtDim(Attribute attr, Value dim); + +} // namespace Torch +} // namespace torch +} // namespace mlir + +#endif // TORCHMLIR_DIALECT_TORCH_SPARSITY_UTILS_H diff --git a/include/torch-mlir/Dialect/Torch/Utils/Utils.h b/include/torch-mlir/Dialect/Torch/Utils/Utils.h index 44f977d5d0ed..829834959692 100644 --- a/include/torch-mlir/Dialect/Torch/Utils/Utils.h +++ b/include/torch-mlir/Dialect/Torch/Utils/Utils.h @@ -86,6 +86,9 @@ bool isBuiltInType(Type type); // std::nullopt is returned if the tensorRank can't be determined. std::optional<unsigned> getTensorRank(Value tensor); +// Helper function to get the number of elements in a tensor. 
+std::optional<int64_t> getTensorNumel(Value tensor); + bool isViewLikeOp(Operation *op); Value getConstantWithGivenDtypeAndValue(PatternRewriter &rewriter, Location loc, diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index 2615f7b7a36a..f22e05fc11bd 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -1266,6 +1266,83 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( } return failure(); }); + patterns.onOp( + "GlobalMaxPool", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) + return failure(); + + auto inputTensorType = operand.getType().cast<Torch::ValueTensorType>(); + if (!inputTensorType || !inputTensorType.hasSizes()) { + return rewriter.notifyMatchFailure( + binder.op, "Expected input type having sizes"); + } + ArrayRef<int64_t> inputShape = inputTensorType.getSizes(); + unsigned inputRank = inputShape.size(); + if (!resultType || !resultType.hasSizes()) { + return rewriter.notifyMatchFailure( + binder.op, "Expected result type having sizes"); + } + SmallVector<Value> cstKernel, cstPadding, cstStrides, cstDilations; + Value cstZero = rewriter.create<Torch::ConstantIntOp>( + binder.getLoc(), rewriter.getI64IntegerAttr(0)); + Value cstOne = rewriter.create<Torch::ConstantIntOp>( + binder.getLoc(), rewriter.getI64IntegerAttr(1)); + for (unsigned i = 2; i < inputRank; i++) { + if (inputShape[i] == Torch::kUnknownSize) { + Value dim = rewriter.create<Torch::ConstantIntOp>( + binder.getLoc(), rewriter.getI64IntegerAttr(i)); + Value inputDimSize = rewriter.create<Torch::AtenSizeIntOp>( + binder.getLoc(), operand, dim); + cstKernel.push_back(inputDimSize); + } else { + cstKernel.push_back(rewriter.create<Torch::ConstantIntOp>( + binder.getLoc(), rewriter.getI64IntegerAttr(inputShape[i]))); + } + cstPadding.push_back(cstZero); + cstDilations.push_back(cstOne); + cstStrides.push_back(cstOne); + } + Value kernelSizeList = rewriter.create<Torch::PrimListConstructOp>( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + cstKernel); + Value paddingList = rewriter.create<Torch::PrimListConstructOp>( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + cstPadding); + Value dilationsList = rewriter.create<Torch::PrimListConstructOp>( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + cstDilations); + Value stridesList = rewriter.create<Torch::PrimListConstructOp>( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + cstStrides); + Value cstCeilMode = + rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false); + + if (inputRank == 3) { + rewriter.replaceOpWithNewOp<Torch::AtenMaxPool1dOp>( + binder.op, resultType, operand, kernelSizeList, stridesList, + paddingList, dilationsList, cstCeilMode); + return success(); + } else if (inputRank == 4) { + rewriter.replaceOpWithNewOp<Torch::AtenMaxPool2dOp>( + binder.op, resultType, operand, kernelSizeList, stridesList, + paddingList, dilationsList, cstCeilMode); + return success(); + } else if (inputRank == 5) { + rewriter.replaceOpWithNewOp<Torch::AtenMaxPool3dOp>( + binder.op, resultType, operand, kernelSizeList, stridesList, + paddingList, dilationsList, cstCeilMode); + return success(); + } + return failure(); + }); patterns.onOp( "LayerNormalization", 17, [](OpBinder binder, ConversionPatternRewriter &rewriter) { diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 4852a397e7fe..03de04acffdd 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ 
b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -518,6 +518,44 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( cstStrReduction); return success(); }); + patterns.onOp( + "SequenceConstruct", 11, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + SmallVector<Value> operands; + Torch::ListType resultType; + + if (binder.tensorOperands(operands, binder.getNumOperands()) || + binder.tensorListResultType(resultType)) + return failure(); + + rewriter.replaceOpWithNewOp<Torch::PrimListConstructOp>( + binder.op, resultType, operands); + return success(); + }); + patterns.onOp( + "SequenceLength", 11, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + // onnx.SequenceLength takes a sequence(list) of tensors, and returns + // a zero rank tensor with the length. + Torch::ValueTensorType resultType; + Value x; + if (binder.tensorListOperand(x) || binder.tensorResultType(resultType)) + return failure(); + + Value cstFalse = + rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false); + Value none = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc()); + + Value len = rewriter.create<Torch::AtenLenTOp>( + binder.getLoc(), rewriter.getType<Torch::IntType>(), x); + + // AtenLenTOp returns a torch.int, so we have to + // put that in a tensor. + rewriter.replaceOpWithNewOp<Torch::AtenTensorIntOp>( + binder.op, resultType, len, none, none, cstFalse); + + return success(); + }); patterns.onOp( "Sigmoid", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; @@ -966,6 +1004,55 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( binder.op, resultType, data); return success(); }); + patterns.onOp( + "ReduceLogSumExp", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value data; + int64_t keepDims, noop_with_empty_axes; + if (binder.tensorOperandAtIndex(data, 0) || + binder.tensorResultType(resultType) || + binder.s64IntegerAttr(keepDims, "keepdims", 1) || + binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes", + 0)) + return failure(); + + // out = Log(reducesum(exp(data))) + Value castDType = rewriter.create<Torch::ConstantIntOp>( + binder.getLoc(), rewriter.getI64IntegerAttr(/*Float64Type*/ 7)); + Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc()); + Value constFalse = + rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false); + auto size = data.getType() + .dyn_cast<Torch::ValueTensorType>() + .getOptionalSizes(); + auto f64ResultType = rewriter.getType<Torch::ValueTensorType>( + size, rewriter.getF64Type()); + Value dataCast = rewriter.create<Torch::AtenToDtypeOp>( + binder.getLoc(), f64ResultType, data, castDType, + /*non_blocking=*/constFalse, /*copy=*/constFalse, + /*memory_format=*/noneVal); + Value dataExp = rewriter.create<Torch::AtenExpOp>( + binder.getLoc(), f64ResultType, dataCast); + auto f64ReduceType = rewriter.getType<Torch::ValueTensorType>( + resultType.getOptionalSizes(), rewriter.getF64Type()); + auto reducedSumBool = reducedSumImpl( + binder, rewriter, dataExp, f64ReduceType, + /*storeValue=*/data, keepDims, noop_with_empty_axes, true); + if (failed(reducedSumBool)) + return rewriter.notifyMatchFailure( + binder.op, + "Failed to perform sum operation on exp of operand"); + Value finalResult = rewriter.create<Torch::AtenLogOp>( + binder.getLoc(), f64ReduceType, data); + Value resultDtype = Torch::getDtypeIntValueForType( + rewriter, binder.getLoc(), resultType.getDtype()); + rewriter.replaceOpWithNewOp<Torch::AtenToDtypeOp>( + binder.op, resultType, finalResult, resultDtype, + /*non_blocking=*/constFalse, /*copy=*/constFalse, + /*memory_format=*/noneVal); + return success(); + }); patterns.onOp("ReduceSum", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; @@ -1621,8 +1708,7 @@ void 
mlir::torch::onnx_c::populateDefaultDomainQtoZ( }); patterns.onOp( - "Transpose", 13, - [](OpBinder binder, ConversionPatternRewriter &rewriter) { + "Transpose", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { auto loc = binder.getLoc(); Torch::ValueTensorType resultType; Value operand; diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index 63fc43bd7bb5..0fc27aed808d 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -1840,9 +1840,11 @@ class ConvertAtenSliceTensorOp : public OpConversionPattern<AtenSliceTensorOp> { op, adaptor, rewriter, resultShape, offsets, strides))) { return failure(); } - + SmallVector<int64_t> dynShape(resultType.getRank(), ShapedType::kDynamic); + auto sliceType = RankedTensorType::get( + dynShape, resultType.getElementType(), resultType.getEncoding()); Value result = rewriter.create<tensor::ExtractSliceOp>( - loc, input, offsets, resultShape, strides); + loc, sliceType, input, offsets, resultShape, strides); rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, result); return success(); diff --git a/lib/Conversion/TorchToStablehlo/GatherScatter.cpp b/lib/Conversion/TorchToStablehlo/GatherScatter.cpp index 00c022cc1067..5854b1b7d7fd 100644 --- a/lib/Conversion/TorchToStablehlo/GatherScatter.cpp +++ b/lib/Conversion/TorchToStablehlo/GatherScatter.cpp @@ -247,8 +247,7 @@ FailureOr<Value> broadcastAndConcatIndices(Operation *op, concatShape.push_back(indexTensors.size()); SmallVector<Value> broadcastedIndices; - Type indexElemTy = - cast<RankedTensorType>(indexTensors[0].getType()).getElementType(); + Type indexElemTy = rewriter.getI64Type(); RankedTensorType bcastIndexType = RankedTensorType::get(indicesShape, indexElemTy); for (auto indexTensor : indexTensors) { diff --git a/lib/Conversion/TorchToStablehlo/Reduction.cpp b/lib/Conversion/TorchToStablehlo/Reduction.cpp index 81a1a1f564d1..502a837ea0a0 100644 --- a/lib/Conversion/TorchToStablehlo/Reduction.cpp +++ b/lib/Conversion/TorchToStablehlo/Reduction.cpp @@ -53,7 +53,7 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy, } } - if (isa<AtenMaxOp, AtenMaxDimOp, AtenArgmaxOp>(op)) { + if (isa<AtenMaxOp, AtenMaxDimOp, AtenArgmaxOp, AtenAmaxOp>(op)) { if (isa<mlir::FloatType>(elementTy)) { auto constAttr = DenseElementsAttr::get( constType, @@ -121,6 +121,46 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy, return nullptr; } +static Value createReduceOpWithSingleRegionOp(Operation *op, Value input, + Type outTy, + ArrayRef<int64_t> dims, + PatternRewriter &rewriter) { + auto inputTy = dyn_cast<RankedTensorType>(input.getType()); + if (!inputTy) + return nullptr; + Value initValue = + createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter); + if (!initValue) + return nullptr; + + stablehlo::ReduceOp reduce = rewriter.create<stablehlo::ReduceOp>( + op->getLoc(), outTy, input, initValue, + rewriter.getDenseI64ArrayAttr(dims)); + + Block &block = reduce.getBody().emplaceBlock(); + auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType()); + block.addArgument(blockArgumentTy, op->getLoc()); + block.addArgument(blockArgumentTy, op->getLoc()); + auto *firstArgument = block.args_begin(); + auto secondArgument = block.args_rbegin(); + + { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(&block); + Value result; + if (isa<AtenAmaxOp, AtenMaxOp, AtenMaxDimOp>(op)) { + result = rewriter.create<stablehlo::MaxOp>( + op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument); + } else { + op->emitError("unimplemented lowering in " + "createReduceOpWithSingleRegionOp"); + return nullptr; + } + rewriter.create<stablehlo::ReturnOp>(op->getLoc(), result); + } + return reduce.getResults()[0]; +} + // Util 
for converting AtenArgmaxOp and AtenMaxDimOp static std::optional<ValueRange> getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input, @@ -371,35 +411,64 @@ LogicalResult ConvertAtenReductionOp<AtenMaxDimOp>::matchAndRewrite( op, "failed to get dimension sizes of the input"); } auto inputShapeVec = *inputShapeInfo; - auto stablehloReduceResults = getMaxInDim(rewriter, op, input, inputShapeVec, - dim, options.dimSizeIndexBits) - .value(); - if (keepDim) { - auto outShapeVec = inputShapeVec; - outShapeVec[dim] = rewriter.create<arith::ConstantOp>( - op->getLoc(), - rewriter.getIntegerAttr( - rewriter.getIntegerType(options.dimSizeIndexBits), 1)); - auto outShapeTensor = rewriter.create<tensor::FromElementsOp>( - op->getLoc(), outShapeVec); - - auto stablehloReduceValueResult = - rewriter.create<stablehlo::DynamicReshapeOp>( - op->getLoc(), valResultType, stablehloReduceResults[0], - outShapeTensor); - auto stablehloReduceIndexResult = - rewriter.create<stablehlo::DynamicReshapeOp>( - op->getLoc(), idxResultType, stablehloReduceResults[1], - outShapeTensor); - rewriter.replaceOp( - op, {stablehloReduceValueResult, stablehloReduceIndexResult}); + if (op.getResult(1).use_empty()) { + llvm::SmallVector<int64_t> outputShape(inputTy.getShape()); + outputShape.erase(outputShape.begin() + dim); + Value reduceResult = createReduceOpWithSingleRegionOp( + op, input, RankedTensorType::get(outputShape, inputElemTy), + ArrayRef<int64_t>{dim}, rewriter); + if (!reduceResult) + return failure(); + + if (keepDim) { + auto outShapeVec = inputShapeVec; + outShapeVec[dim] = rewriter.create<arith::ConstantOp>( + op->getLoc(), + rewriter.getIntegerAttr( + rewriter.getIntegerType(options.dimSizeIndexBits), 1)); + auto outShapeTensor = rewriter.create<tensor::FromElementsOp>( + op->getLoc(), outShapeVec); + + auto stablehloReduceValueResult = + rewriter.create<stablehlo::DynamicReshapeOp>( + op->getLoc(), valResultType, reduceResult, outShapeTensor); + rewriter.replaceOp(op, {stablehloReduceValueResult, Value()}); + return success(); + } + rewriter.replaceOp(op, {reduceResult, Value()}); + return success(); + } else { + auto stablehloReduceResults = + getMaxInDim(rewriter, op, input, inputShapeVec, dim, + options.dimSizeIndexBits) + .value(); + + if (keepDim) { + auto outShapeVec = inputShapeVec; + outShapeVec[dim] = rewriter.create<arith::ConstantOp>( + op->getLoc(), + rewriter.getIntegerAttr( + rewriter.getIntegerType(options.dimSizeIndexBits), 1)); + auto outShapeTensor = rewriter.create<tensor::FromElementsOp>( + op->getLoc(), outShapeVec); + + auto stablehloReduceValueResult = + rewriter.create<stablehlo::DynamicReshapeOp>( + op->getLoc(), valResultType, stablehloReduceResults[0], + outShapeTensor); + auto stablehloReduceIndexResult = + rewriter.create<stablehlo::DynamicReshapeOp>( + op->getLoc(), idxResultType, stablehloReduceResults[1], + outShapeTensor); + rewriter.replaceOp( + op, {stablehloReduceValueResult, stablehloReduceIndexResult}); + return success(); + } + rewriter.replaceOp(op, + {stablehloReduceResults[0], stablehloReduceResults[1]}); return success(); } - - rewriter.replaceOp(op, - {stablehloReduceResults[0], stablehloReduceResults[1]}); - return success(); } } // namespace @@ -692,11 +761,11 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( } } // namespace -// AtenMaxOp +// AtenAmaxOp namespace { template <> -LogicalResult ConvertAtenReductionOp<AtenMaxOp>::matchAndRewrite( - AtenMaxOp op, OpAdaptor adaptor, +LogicalResult ConvertAtenReductionOp<AtenAmaxOp>::matchAndRewrite( + AtenAmaxOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Value input = adaptor.getSelf(); auto inputTy = dyn_cast<RankedTensorType>(input.getType()); @@ -717,40 +786,102 @@ LogicalResult ConvertAtenReductionOp<AtenMaxOp>::matchAndRewrite( "AtenMaxOp to StableHLO"); } + bool keepDim = false; + if (!matchPattern(op.getKeepdim(), 
m_TorchConstantBool(&keepDim))) { + return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported"); + } + + SmallVector<int64_t> inputDims; SmallVector<int64_t> dims; + if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(inputDims))) { + return rewriter.notifyMatchFailure( + op, "non-const integer `dim` is not supported"); + } + for (auto d : inputDims) { + d = toPositiveDim(d, inputTy.getRank()); + // Drop invalid dims + if (isValidDim(d, inputTy.getRank())) { + dims.push_back(d); + } + } + llvm::sort(dims.begin(), dims.end()); + std::unordered_set<int64_t> dimsSet(dims.begin(), dims.end()); + SmallVector<int64_t> reduceResultShape; for (int64_t i = 0; i < inputTy.getRank(); i++) { - dims.push_back(i); + if (dimsSet.find(i) == dimsSet.end()) { + reduceResultShape.push_back(inputTy.getDimSize(i)); + } } - Value initValue = - createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter); - if (!initValue) + Value reduceResult = createReduceOpWithSingleRegionOp( + op, input, RankedTensorType::get(reduceResultShape, inputElemTy), dims, + rewriter); + if (!reduceResult) return failure(); - llvm::sort(dims.begin(), dims.end()); - auto stablehloReduceOp = rewriter.create<stablehlo::ReduceOp>( - op.getLoc(), RankedTensorType::get({}, inputElemTy), input, initValue, - rewriter.getDenseI64ArrayAttr(dims)); - Block &block = stablehloReduceOp.getBody().emplaceBlock(); - auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType()); + if (keepDim) { + const auto &options = getOptions(); + auto outShapeInfo = + hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); + if (failed(outShapeInfo)) { + return rewriter.notifyMatchFailure( + op, "failed to get dimension sizes of the input"); + } + auto outShapeVec = *outShapeInfo; + auto one = rewriter.create<arith::ConstantOp>( + op->getLoc(), + rewriter.getIntegerAttr( + rewriter.getIntegerType(options.dimSizeIndexBits), 1)); + for (int64_t i : dims) { + outShapeVec[i] = one; + } + auto outShapeTensor = rewriter.create<tensor::FromElementsOp>( + op->getLoc(), outShapeVec); + rewriter.replaceOpWithNewOp<stablehlo::DynamicReshapeOp>( + op, getTypeConverter()->convertType(op.getType()), reduceResult, + outShapeTensor); + return success(); + } + rewriter.replaceOp(op, reduceResult); + return success(); +} +} // namespace - block.addArgument(blockArgumentTy, op->getLoc()); - block.addArgument(blockArgumentTy, op->getLoc()); +// AtenMaxOp +namespace { +template <> +LogicalResult ConvertAtenReductionOp<AtenMaxOp>::matchAndRewrite( + AtenMaxOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Value input = adaptor.getSelf(); + auto inputTy = dyn_cast<RankedTensorType>(input.getType()); + if (!inputTy) { + return rewriter.notifyMatchFailure( + op, "only Tensor types supported in StableHLO"); + } + auto inputElemTy = inputTy.getElementType(); + if (!inputElemTy.isIntOrFloat()) { + return op.emitError( + "only floating-point or integer datatype legalization supported"); + } + // Currently, (u)int8 dtype is not supported + if (isa<mlir::IntegerType>(inputElemTy) && + inputElemTy.getIntOrFloatBitWidth() == 8) { + return rewriter.notifyMatchFailure( + op, "IntegerType with bitwidth 8 unsupported in conversion from " + "AtenMaxOp to StableHLO"); + } - auto *firstArgument = block.args_begin(); - auto secondArgument = block.args_rbegin(); + SmallVector<int64_t> dims = + llvm::to_vector(llvm::seq<int64_t>(0, inputTy.getRank())); - { - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(&block); - Value maxResult = rewriter.create<stablehlo::MaxOp>( - op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument); - rewriter.create<stablehlo::ReturnOp>(op->getLoc(), maxResult); - } + Value reduceResult = 
createReduceOpWithSingleRegionOp( + op, input, RankedTensorType::get({}, inputElemTy), dims, rewriter); + if (!reduceResult) + return failure(); rewriter.replaceOpWithNewOp<tensor::CastOp>( - op, getTypeConverter()->convertType(op.getType()), - stablehloReduceOp.getResults()); + op, getTypeConverter()->convertType(op.getType()), reduceResult); return success(); } } // namespace @@ -1205,6 +1336,7 @@ void mlir::torch::torch_to_stablehlo::populateReductionOpPatternsAndLegality( patterns.add<ConvertAtenReductionOp<AtenOp>>(typeConverter, context, options) INSERT_ATEN_REDUCTION_OP_PATTERN(AtenArgmaxOp); INSERT_ATEN_REDUCTION_OP_PATTERN(AtenMaxDimOp); + INSERT_ATEN_REDUCTION_OP_PATTERN(AtenAmaxOp); INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumDimIntListOp); INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumOp); INSERT_ATEN_REDUCTION_OP_PATTERN(AtenProdOp); diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 9ee6ca6a33c1..92c19363a60d 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -2596,7 +2596,8 @@ void AtenMaskedFillTensorOp::getCanonicalizationPatterns( OpFoldResult AtenCloneOp::fold(FoldAdaptor adaptor) { // note: memory_format would be ignored - if (llvm::dyn_cast<ValueTensorType>(getSelf().getType())) { + if (getSelf().getType() == getResult().getType() && + llvm::dyn_cast<ValueTensorType>(getSelf().getType())) { // self should have value semantics return getSelf(); } @@ -3584,17 +3585,17 @@ OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) { auto inType = dyn_cast<BaseTensorType>(getOperand(0).getType()); auto outType = dyn_cast<BaseTensorType>(getResult().getType()); + if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes() || + !inType.hasDtype() || !outType.hasDtype() || + inType.getDtype() != outType.getDtype()) + return nullptr; + if (start && end && step && step.getValue().getSExtValue() == 1 && start.getValue().getSExtValue() == 0 && end.getValue().getSExtValue() == std::numeric_limits<int64_t>::max() && inType == outType) return getOperand(0); - if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes() || - !inType.hasDtype() || !outType.hasDtype() || - inType.getDtype() != outType.getDtype()) - return nullptr; - if (inType.getSizes().size() != outType.getSizes().size() || !inType.areAllSizesKnown() || !outType.areAllSizesKnown()) return nullptr; @@ -4555,10 +4556,10 @@ OpFoldResult PrimMaxIntOp::fold(FoldAdaptor adaptor) { OpFoldResult PrimNumToTensorScalarOp::fold(FoldAdaptor adaptor) { Attribute a = adaptor.getA(); - auto resultTy = cast<ValueTensorType>(getType()); + auto resultTy = dyn_cast<ValueTensorType>(getType()); if (!a) return {}; - if (!resultTy.hasDtype() || !resultTy.hasSizes()) + if (!resultTy || !resultTy.hasDtype() || !resultTy.hasSizes()) return {}; auto dty = resultTy.getDtype(); diff --git a/lib/Dialect/Torch/IR/TorchTypes.cpp b/lib/Dialect/Torch/IR/TorchTypes.cpp index c162166cdd13..d1906d6989af 100644 --- a/lib/Dialect/Torch/IR/TorchTypes.cpp +++ b/lib/Dialect/Torch/IR/TorchTypes.cpp @@ -235,6 +235,18 @@ Type BaseTensorType::getWithSizesAndDtype( llvm_unreachable("not a BaseTensorType!"); } +Type BaseTensorType::getWithSizesAndDtypeAndSparsity( + std::optional<ArrayRef<int64_t>> optionalSizes, Type optionalDtype, + Attribute optionalSparsity) const { + if (mlir::isa<NonValueTensorType>(*this)) + return NonValueTensorType::get(getContext(), optionalSizes, optionalDtype, + optionalSparsity); + if (mlir::isa<ValueTensorType>(*this)) + return ValueTensorType::get(getContext(), optionalSizes, optionalDtype, + optionalSparsity); + llvm_unreachable("not a BaseTensorType!"); +} + ValueTensorType BaseTensorType::getWithValueSemantics() const { if (auto tensor = dyn_cast<NonValueTensorType>()) return tensor.getWithValueSemantics(); diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 315264333b3a..7de57fd9dd9d 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -71,10 +71,10 @@ static Type computeReductionType(PatternRewriter &rewriter, Operation *op, } } - Type resultType = tensorType.getWithSizesAndDtype( + Type resultType = tensorType.getWithSizesAndDtypeAndSparsity( !tensorType.hasSizes() ? std::optional<ArrayRef<int64_t>>() : llvm::ArrayRef(sizes), - tensorType.getOptionalDtype()); + tensorType.getOptionalDtype(), tensorType.getOptionalSparsity()); return resultType; } @@ -3371,6 +3371,106 @@ class DecomposeAtenMaskedFillScalarOp }; } // namespace +// Decompose aten.masked_scatter: +// def masked_scatter(self: Tensor, mask: Tensor, source: Tensor) -> Tensor: +// mask_int = mask + torch.zeros_like(self) +// prefix_sum = torch.cumsum(mask_int.flatten(), dim=0) +// mask_prefix = torch.clamp(prefix_sum - 1, min=0) +// mask = mask.to(torch.bool) +// source = source.flatten()[mask_prefix].reshape(mask.shape) +// return torch.where(mask, source, self)
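+// For example: self = [1, 2, 3, 4], mask = [False, True, False, True], +// source = [10, 20] gives [1, 10, 3, 20].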
+namespace { +class DecomposeAtenMaskedScatterOp + : public OpRewritePattern<AtenMaskedScatterOp> { +public: + using OpRewritePattern<AtenMaskedScatterOp>::OpRewritePattern; + LogicalResult matchAndRewrite(AtenMaskedScatterOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto context = op.getContext(); + Value mask = op.getMask(); + Value source = op.getSource(); + Value self = op.getSelf(); + + auto selfTy = cast<BaseTensorType>(self.getType()); + auto resTy = cast<BaseTensorType>(op.getType()); + auto sourceTy = cast<BaseTensorType>(source.getType()); + + if (!resTy || !resTy.hasDtype()) { + return rewriter.notifyMatchFailure(op, "result should have dtype"); + } + if (!selfTy || !selfTy.areAllSizesKnown()) + return rewriter.notifyMatchFailure( + op, "Unimplemented: no implementation for rankless tensor"); + if (!sourceTy || !sourceTy.areAllSizesKnown() || !sourceTy.hasDtype()) + return rewriter.notifyMatchFailure( + op, "Unimplemented: no implementation for rankless tensor"); + + int64_t selfNumel = getTensorNumel(self).value(); // as selfTy has sizes + int64_t sourceNumel = + getTensorNumel(source).value(); // as sourceTy has sizes + int64_t selfRank = selfTy.getSizes().size(); + int64_t sourceRank = sourceTy.getSizes().size(); + + Value constZero = rewriter.create<Torch::ConstantIntOp>( + loc, rewriter.getI64IntegerAttr(0)); + Value constOne = rewriter.create<Torch::ConstantIntOp>( + loc, rewriter.getI64IntegerAttr(1)); + Value constNone = rewriter.create<ConstantNoneOp>(loc); + Value selfLastDim = rewriter.create<Torch::ConstantIntOp>( + loc, rewriter.getI64IntegerAttr(selfRank - 1)); + Value sourceLastDim = rewriter.create<Torch::ConstantIntOp>( + loc, rewriter.getI64IntegerAttr(sourceRank - 1)); + + auto si64Type = IntegerType::get(context, 64, IntegerType::Signed); + auto int64Dtype = getDtypeIntValueForType( + rewriter, loc, + rewriter.getIntegerType(/*width=*/64, /*isSigned=*/true)); + auto selfIntType = selfTy.getWithSizesAndDtype(selfTy.getSizes(), si64Type); + + Value zerosLike = rewriter.create<Torch::AtenZerosLikeOp>( + loc, selfIntType, self, int64Dtype, constNone, constNone, constNone, + constNone); + Value maskInt = rewriter.create<Torch::AtenAddTensorOp>( + loc, selfIntType, mask, zerosLike, constOne); + + auto flattenMaskedType = selfTy.getWithSizesAndDtype( + /*optionalSizes=*/{selfNumel}, si64Type); + Value maskIntFlatten = rewriter.create<Torch::AtenFlattenUsingIntsOp>( + loc, flattenMaskedType, maskInt, constZero, selfLastDim); + Value prefixSum = rewriter.create<Torch::AtenCumsumOp>( + loc, flattenMaskedType, maskIntFlatten, 
/*dim=*/constZero, constNone); + Value prefixSumMinusOne = rewriter.create<Torch::AtenSubScalarOp>( + loc, flattenMaskedType, prefixSum, constOne, constOne); + Value maskPrefix = rewriter.create<Torch::AtenClampOp>( + loc, flattenMaskedType, prefixSumMinusOne, /*min=*/constZero, + /*max=*/constNone); + + auto sourceFlattenType = sourceTy.getWithSizesAndDtype( + /*optionalSizes=*/{sourceNumel}, sourceTy.getDtype()); + Value sourceFlatten = rewriter.create<Torch::AtenFlattenUsingIntsOp>( + loc, sourceFlattenType, source, constZero, sourceLastDim); + + auto selectSourceType = sourceTy.getWithSizesAndDtype( + /*optionalSizes=*/{selfNumel}, sourceTy.getDtype()); + Value selectSource = rewriter.create<Torch::AtenIndexSelectOp>( + loc, selectSourceType, sourceFlatten, constZero, maskPrefix); + + // Reshape normalized output back to the original input shape + auto selfShape = rewriter.create<AtenSizeOp>( + loc, Torch::ListType::get(IntType::get(context)), self); + Value sourceReshape = rewriter.create<Torch::AtenViewOp>( + loc, selfTy, selectSource, selfShape); + rewriter.replaceOpWithNewOp<Torch::AtenWhereSelfOp>(op, resTy, mask, + sourceReshape, self); + return success(); + } +}; +} // namespace + // Decompose aten._convolution-like to aten.convolution namespace { template @@ -7961,6 +8059,7 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal<DecomposeAtenMaskedScatterOp>(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal( diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index dbf60614c33a..b2cdc74e13ff 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -390,6 +390,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp<AtenMaskedScatterOp>(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); diff --git a/lib/Dialect/Torch/Utils/CMakeLists.txt b/lib/Dialect/Torch/Utils/CMakeLists.txt index 91088078891d..45b3e1b987aa 100644 --- a/lib/Dialect/Torch/Utils/CMakeLists.txt +++ b/lib/Dialect/Torch/Utils/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_dialect_library(TorchMLIRTorchUtils Utils.cpp + SparsityUtils.cpp TorchUpstream.cpp ADDITIONAL_HEADER_DIRS diff --git a/lib/Dialect/Torch/Utils/SparsityUtils.cpp b/lib/Dialect/Torch/Utils/SparsityUtils.cpp new file mode 100644 index 000000000000..b2f1ef2d5280 --- /dev/null +++ b/lib/Dialect/Torch/Utils/SparsityUtils.cpp @@ -0,0 +1,55 @@ +//===----------------------------------------------------------------------===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. 
+// +//===----------------------------------------------------------------------===// + +#include "torch-mlir/Dialect/Torch/Utils/SparsityUtils.h" +#include "mlir/Dialect/SparseTensor/IR/Enums.h" +#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinDialect.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" +#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" +#include "torch-mlir/Dialect/Torch/Utils/Utils.h" +#include "llvm/ADT/SmallVector.h" +#include <cstdint> + +using namespace mlir; +using namespace mlir::sparse_tensor; +using namespace mlir::torch; +using namespace mlir::torch::Torch; + +FailureOr<Attribute> Torch::getSparsityWithDenseLTAtDim(Attribute attr, + Value dim) { + if (!attr) + return Attribute(); + + auto enc = cast<SparseTensorEncodingAttr>(attr); + int64_t dimInt = 0; + int64_t rank = enc.getDimRank() + 1; + if (matchPattern(dim, m_TorchConstantInt(&dimInt))) { + dimInt = toPositiveDim(dimInt, rank); + if (!isValidDim(dimInt, rank)) { + return failure(); + } + if (!enc.isIdentity()) { + // TODO: support block sparsity and permutation (CSC). + return failure(); + } + auto denseLT = *LevelType::buildLvlType(LevelFormat::Dense, true, true); + SmallVector<LevelType> lvlTps = llvm::to_vector(enc.getLvlTypes()); + lvlTps.insert(lvlTps.begin() + dimInt, denseLT); + auto dim2Lvl = AffineMap::getMultiDimIdentityMap(rank, attr.getContext()); + return SparseTensorEncodingAttr::get( + enc.getContext(), lvlTps, dim2Lvl, AffineMap(), enc.getPosWidth(), + enc.getCrdWidth(), enc.getExplicitVal(), enc.getImplicitVal()); + } + // Do not know how to handle dynamic dimension. + return failure(); +} diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp index 9165134573b1..005f00836421 100644 --- a/lib/Dialect/Torch/Utils/Utils.cpp +++ b/lib/Dialect/Torch/Utils/Utils.cpp @@ -11,6 +11,7 @@ #include "mlir/IR/BuiltinDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" +#include "torch-mlir/Dialect/Torch/Utils/SparsityUtils.h" using namespace mlir; using namespace mlir::torch; @@ -251,6 +252,19 @@ std::optional<unsigned> Torch::getTensorRank(Value tensor) { return tensorType.getSizes().size(); } +std::optional<int64_t> Torch::getTensorNumel(Value tensor) { + BaseTensorType tensorType = cast<BaseTensorType>(tensor.getType()); + if (!tensorType.hasSizes()) + return std::nullopt; + int64_t numel = 1; + for (auto dim : tensorType.getSizes()) { + if (dim == ShapedType::kDynamic) + return ShapedType::kDynamic; + numel *= dim; + } + return numel; +} + bool Torch::isViewLikeOp(Operation *op) { // AtenContiguousOp might return a view, so this is conservatively // correct. 
We could potentially be more precise and identify the cases @@ -361,6 +375,11 @@ FailureOr<Value> Torch::unsqueezeTensor(PatternRewriter &rewriter, if (!inputType.hasSizes()) { return rewriter.notifyMatchFailure(op, "input tensor must have size"); } + FailureOr<Attribute> enc = + getSparsityWithDenseLTAtDim(inputType.getOptionalSparsity(), dim); + if (failed(enc)) { + return failure(); + } SmallVector<int64_t> unsqueezedShape; ArrayRef<int64_t> inputShape = inputType.getSizes(); @@ -377,8 +396,8 @@ FailureOr<Value> Torch::unsqueezeTensor(PatternRewriter &rewriter, } else { unsqueezedShape.resize(unsqueezedRank, kUnknownSize); } - Type unsqueezedType = inputType.getWithSizesAndDtype( - unsqueezedShape, inputType.getOptionalDtype()); + Type unsqueezedType = inputType.getWithSizesAndDtypeAndSparsity( + unsqueezedShape, inputType.getOptionalDtype(), enc.value()); Value unsqueezed = rewriter.create<AtenUnsqueezeOp>( op->getLoc(), unsqueezedType, input, dim); return unsqueezed; diff --git a/lib/Dialect/TorchConversion/Transforms/VerifyStablehloBackendContract.cpp b/lib/Dialect/TorchConversion/Transforms/VerifyStablehloBackendContract.cpp index 0c8cdf2fc54d..3ff6e4732db2 100644 --- a/lib/Dialect/TorchConversion/Transforms/VerifyStablehloBackendContract.cpp +++ b/lib/Dialect/TorchConversion/Transforms/VerifyStablehloBackendContract.cpp @@ -11,10 +11,9 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/OpDefinition.h" #include "mlir/Transforms/DialectConversion.h" #include "stablehlo/dialect/ChloOps.h" #include "stablehlo/dialect/StablehloOps.h" @@ -47,13 +46,21 @@ class VerifyStablehloBackendContractPass // Structural operations. target.addDynamicallyLegalOp<ModuleOp, func::FuncOp, func::ReturnOp>( opHasLegalTypes); - // Shape operations. - target.addDynamicallyLegalOp<shape::ShapeOfOp>(opHasLegalTypes); target.addLegalDialect<chlo::ChloDialect>(); target.addLegalDialect<stablehlo::StablehloDialect>(); target.addLegalDialect<tensor::TensorDialect>(); target.addLegalDialect<arith::ArithDialect>(); + target.addLegalDialect<shape::ShapeDialect>(); + target.addLegalDialect<mlir::math::MathDialect>(); + auto moduleOp = getOperation(); + RewritePatternSet patterns(context); + if (failed(applyFullConversion(moduleOp, target, std::move(patterns)))) { + emitError(moduleOp.getLoc()) + << "Module does not conform to the Stablehlo backend contract. " + "See dialect conversion legality information above."; + return signalPassFailure(); + } } }; } // namespace diff --git a/projects/pt1/e2e_testing/main.py b/projects/pt1/e2e_testing/main.py index fdddbc2e443f..3c0620f4b41f 100644 --- a/projects/pt1/e2e_testing/main.py +++ b/projects/pt1/e2e_testing/main.py @@ -28,9 +28,6 @@ from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import ( RefBackendLinalgOnTensorsBackend, ) -from torch_mlir_e2e_test.onnx_backends.linalg_on_tensors import ( - LinalgOnTensorsOnnxBackend, -) from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import ( LinalgOnTensorsTosaBackend, ) @@ -56,6 +53,8 @@ FX_IMPORTER_CRASHING_SET, FX_IMPORTER_STABLEHLO_XFAIL_SET, FX_IMPORTER_STABLEHLO_CRASHING_SET, + FX_IMPORTER_TOSA_XFAIL_SET, + ONNX_TOSA_XFAIL_SET, ) # Import tests to register them in the global registry. 
@@ -75,8 +74,10 @@ def _get_argparse(): "lazy_tensor_core", "torchdynamo", "onnx", + "onnx_tosa", "fx_importer", "fx_importer_stablehlo", + "fx_importer_tosa", ] parser = argparse.ArgumentParser(description="Run torchscript e2e tests.") parser.add_argument( @@ -96,6 +97,8 @@ def _get_argparse(): "onnx": export to the model via onnx and reimport using the torch-onnx-to-torch path. "fx_importer": run the model through the fx importer frontend and execute the graph using Linalg-on-Tensors. "fx_importer_stablehlo": run the model through the fx importer frontend and execute the graph using Stablehlo backend. +"fx_importer_tosa": run the model through the fx importer frontend and execute the graph using the TOSA backend. +"onnx_tosa": Import ONNX to Torch via the torch-onnx-to-torch path and execute the graph using the TOSA backend. """, ) parser.add_argument( @@ -180,14 +183,22 @@ def main(): config = FxImporterTestConfig(LinalgOnTensorsStablehloBackend(), "stablehlo") xfail_set = FX_IMPORTER_STABLEHLO_XFAIL_SET crashing_set = FX_IMPORTER_STABLEHLO_CRASHING_SET + elif args.config == "fx_importer_tosa": + config = FxImporterTestConfig(LinalgOnTensorsTosaBackend(), "tosa") + xfail_set = FX_IMPORTER_TOSA_XFAIL_SET + crashing_set = set() elif args.config == "torchdynamo": config = TorchDynamoTestConfig(RefBackendLinalgOnTensorsBackend()) xfail_set = TORCHDYNAMO_XFAIL_SET crashing_set = TORCHDYNAMO_CRASHING_SET elif args.config == "onnx": - config = OnnxBackendTestConfig(LinalgOnTensorsOnnxBackend()) + config = OnnxBackendTestConfig(RefBackendLinalgOnTensorsBackend()) xfail_set = ONNX_XFAIL_SET crashing_set = ONNX_CRASHING_SET + elif args.config == "onnx_tosa": + config = OnnxBackendTestConfig(LinalgOnTensorsTosaBackend(), output_type="tosa") + xfail_set = ONNX_TOSA_XFAIL_SET + crashing_set = set() do_not_attempt = set( args.crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed or [] diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 159f4e6c3604..454df6dee6f4 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1121,6 +1121,7 @@ "LinspaceTwoSizeModule_basic", "MaskedFillScalarFloatValueStaticModule_basic", "MaskedFillScalarIntValueStaticModule_basic", + "MaskedScatterStaticBasic_basic", "Matmul4dStatic_basic", "Matmul4dStaticBroadcast_basic", "Matmul_2d", @@ -2561,6 +2562,7 @@ "LinalgNormKeepDimComplexModule_basic", "LinalgVectorNormComplexModule_basic", "LogSoftmaxBackwardModule_basic", + "MaskedScatterStaticBasic_basic", "MaxPool1dCeilModeTrueModule_basic", "MaxPool1dEmptyStrideStaticModule_basic", "MaxPool1dModule_basic", @@ -2911,3 +2913,1728 @@ # For now, we are removing the test until this issue has been debugged. 
"QuantizedMLP_basic", } + +FX_IMPORTER_TOSA_XFAIL_SET = { + "AdaptiveAvgPool1dGeneralDynamicNoBatches_basic", + "AdaptiveAvgPool1dGeneralDynamic_basic", + "AdaptiveAvgPool1dStaticLargerOutput_basic", + "AdaptiveAvgPool2dDynamicNoBatch_basic", + "AdaptiveAvgPool2dDynamic_basic", + "AdaptiveAvgPool3dDynamicNoBatch_basic", + "AdaptiveAvgPool3dDynamic_basic", + "AdaptiveMaxPool1dDynamicNoBatch_basic", + "AdaptiveMaxPool1dDynamic_basic", + "AdaptiveMaxPool1dStatic_basic", + "AdaptiveMaxPool2dDynamicNoBatch_basic", + "AdaptiveMaxPool2dDynamicWithIndices_basic", + "AdaptiveMaxPool2dDynamic_basic", + "AdaptiveMaxPool2dStaticWithIndices_basic", + "AdaptiveMaxPool2dStatic_basic", + "AdaptiveMaxPool3dDynamicNoBatch_basic", + "AdaptiveMaxPool3dDynamicWithIndices_basic", + "AdaptiveMaxPool3dDynamic_basic", + "AdaptiveMaxPool3dStaticWithIndices_basic", + "AdaptiveMaxPool3dStatic_basic", + "AddIntModule_basic", + "Add_MixPModule_basic", + "AllBoolFalseModule_basic", + "AllBoolTrueModule_basic", + "AnyBoolFalseModule_basic", + "AnyBoolTrueModule_basic", + "ArangeStartOutViewModule_basic", + "ArgminIntModule_basic", + "ArgminIntModule_multiple_mins", + "ArgminModule_basic", + "ArgminModule_keepDim", + "ArgminModule_with_dim", + "AtenComplexImagModule_basic", + "AtenComplexRealModule_basic", + "AtenComplexViewModule_basic", + "AtenDiagEmbedDefaultDiag_basic", + "AtenDiagEmbedDimDiag_basic", + "AtenDiagEmbedNegOffsetDiag_basic", + "AtenDiagEmbedNonDefault4DDiag_basic", + "AtenDiagEmbedOffsetDiag_basic", + "AtenDiagEmbedRevDimDiag_basic", + "AtenEmbeddingBagStaticModule_basic", + "AtenEmbeddingBagSumExample_basic", + "AtenEyeMModuleInt2D_basic", + "AtenEyeModuleInt2D_basic", + "AtenFloatScalarModule_basic", + "AtenInstanceNormModule_basic", + "AtenIntBoolOpConstFalseModule_basic", + "AtenIntBoolOpConstTrueModule_basic", + "AtenIntBoolOpModule_basic", + "AtenIntTensorByteDtypeModule_basic", + "AtenIntTensorCharDtypeModule_basic", + "AtenItemFpOpModule_basic", + "AtenItemIntOpModule_basic", + "AtenLinalgCrossBroadcast_basic", + "AtenLinalgCrossCustomDim_basic", + "AtenLinalgCrossDynamic_basic", + "AtenLinalgCrossFloat_basic", + "AtenLinalgCrossInt_basic", + "AtenLinalgCrossNegativeDim_basic", + "AtenMatmulQMixedSigni8Transpose_basic", + "AtenMatmulQMixedSigni8_basic", + "AtenMatmulQint8MV_basic", + "AtenMatmulQint8VM_basic", + "AtenMatmulQint8VV_basic", + "AtenMatmulQint8_basic", + "AtenMmIntTypes_basic", + "AtenMmQMixedSigni8_basic", + "AtenMmQint8_basic", + "AtenMmQuint8_basic", + "AtenRealView128Module_basic", + "AtenRealView64Module_basic", + "AtenRoundFloatHalfToEvenModule_basic", + "AtenRoundFloatModule_basic", + "AtenSubFloatModule_basic", + "AtenTopKModule_basic", + "AtenTopKSmallestModule_basic", + "AtenTrilModule_basic", + "AtenTrilWithNegDiagonalModule_basic", + "AtenTrilWithPosDiagonalModule_basic", + "Aten_CastLongModule_basic", + "Aten_EmbeddingBagExample_basic", + "AvgPool1dFloatModule_basic", + "AvgPool1dIntModule_basic", + "AvgPool1dStaticModule_basic", + "AvgPool2dCeilModeTrueModule_basic", + "AvgPool2dDivisorOverrideModule_basic", + "AvgPool2dFloatModule_basic", + "AvgPool2dIntModule_basic", + "AvgPool2dStaticModule_basic", + "BernoulliFloatModule_basic", + "BernoulliModule_basic", + "BernoulliOnesModule_basic", + "BernoulliPModule_basic", + "BernoulliTensorModule_basic", + "BernoulliZerosModule_basic", + "BincountMinlengthModule_basic", + "BincountModule_basic", + "BincountStaticSizeModule_basic", + "BmmIntModule_basic", + "BoolFloatConstantModule_basic", + "BoolFloatFalseModule_basic", + 
"BoolFloatTrueModule_basic", + "BoolIntConstantModule_basic", + "BoolIntFalseModule_basic", + "BoolIntTrueModule_basic", + "BroadcastDynamicDimModule_basic", + "BroadcastToModule_basic", + "CeilFloatModule_basic", + "CollapseAllDimensionsModule_basic", + "CollapseFullDynamicModule_basic", + "CollapsePartialDynamicModule_basic", + "CollapseRank1DynamicModule_basic", + "CollapseStaticModule_basic", + "ConstantBoolParameterModule_basic", + "ContainsIntList_False", + "ContainsIntList_True", + "Conv1dModule_basic", + "Conv2dQInt8Module_basic", + "Conv2dWithPaddingDilationStrideStaticModule_grouped", + "Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier", + "Conv3dModule_basic", + "ConvTbcModule_basic", + "ConvTranspose2DQInt8_basic", + "Conv_Transpose2dModule_basic", + "ConvolutionBackwardModule2DPadded_basic", + "ConvolutionBackwardModule2DStatic_basic", + "ConvolutionBackwardModule2DStrided_basic", + "ConvolutionBackwardModule2D_basic", + "ConvolutionModule2DGroups_basic", + "ConvolutionModule2DTransposeNonUnitOutputPadding_basic", + "ConvolutionModule2DTransposeStridedStatic_basic", + "ConvolutionModule2DTransposeStrided_basic", + "ConvolutionModule2DTranspose_basic", + "CopyWithDifferentDTypesModule_basic", + "CosineSimilarityStaticBroadcastModule_basic", + "CrossEntropyLossModule_basic", + "CumsumInputDtypeInt32Module_basic", + "CumsumModule_basic", + "CumsumStaticModule_basic", + "CumsumStaticNegativeDimModule_basic", + "DiagonalModule_basic", + "DiagonalModule_nonsquare", + "DiagonalModule_transposed", + "DiagonalModule_with_dims", + "DiagonalModule_with_dims_and_offset", + "DiagonalModule_with_negative_dims", + "DiagonalModule_with_offset", + "DiagonalWithStaticShapeModule_basic", + "DivFloatModule_basic", + "DivIntModule_basic", + "DropoutTrainModule_basic", + "DropoutTrainStaticShapeModule_basic", + "ElementwiseAcosIntModule_basic", + "ElementwiseAcosModule_basic", + "ElementwiseAcoshIntModule_basic", + "ElementwiseAcoshModule_basic", + "ElementwiseAddScalar_NumToTensorFloat_Module_basic", + "ElementwiseAndScalarModule_basic", + "ElementwiseAndScalarStaticShapeModule_basic", + "ElementwiseAsinIntModule_basic", + "ElementwiseAsinModule_basic", + "ElementwiseAsinhIntModule_basic", + "ElementwiseAsinhModule_basic", + "ElementwiseAtan2FloatIntModule_basic", + "ElementwiseAtan2FloatIntStaticModule_basic", + "ElementwiseAtan2TensorFloatModule_basic", + "ElementwiseAtan2TensorFloatStaticModule_basic", + "ElementwiseAtan2TensorIntModule_basic", + "ElementwiseAtan2TensorIntStaticModule_basic", + "ElementwiseAtanTensorFloatModule_basic", + "ElementwiseAtanTensorIntModule_basic", + "ElementwiseAtanhIntModule_basic", + "ElementwiseAtanhModule_basic", + "ElementwiseAtenFloorDivideBroadcastModule_basic", + "ElementwiseAtenFloorDivideScalarModule_basic", + "ElementwiseAtenFloorDivideScalarNegativeModule_basic", + "ElementwiseAtenFloorDivideTensorNegativeModule_basic", + "ElementwiseAtenFloorDivideTensorPositiveModule_basic", + "ElementwiseAtenLogicalAndOpModule_basic", + "ElementwiseAtenLogicalAndOpPromoteBroadcastModule_basic", + "ElementwiseAtenLogicalAndOpPromoteBroadcastStaticShapeModule_basic", + "ElementwiseAtenLogicalNotOpModule_basic", + "ElementwiseAtenLogicalNotOpPromoteModule_basic", + "ElementwiseAtenLogicalXorOpModule_basic", + "ElementwiseAtenLogicalXorOpPromoteBroadcastModule_basic", + "ElementwiseAtenLogicalXorOpPromoteBroadcastStaticShapeModule_basic", + "ElementwiseBitwiseAndScalarInt32Module_basic", + "ElementwiseBitwiseAndScalarInt64Module_basic", + 
"ElementwiseBitwiseAndScalarInt8Module_basic", + "ElementwiseBitwiseLeftShiftInt32Module_basic", + "ElementwiseBitwiseLeftShiftInt64Module_basic", + "ElementwiseBitwiseLeftShiftInt8Module_basic", + "ElementwiseBitwiseRightShiftInt32Module_basic", + "ElementwiseBitwiseRightShiftInt64Module_basic", + "ElementwiseBitwiseRightShiftInt8Module_basic", + "ElementwiseClampMinTensorFloatModule_basic", + "ElementwiseClampMinTensorIntModule_basic", + "ElementwiseClampTensorFloatModule_basic", + "ElementwiseClampTensorIntModule_basic", + "ElementwiseCosIntModule_basic", + "ElementwiseCosModule_basic", + "ElementwiseCoshIntModule_basic", + "ElementwiseCoshModule_basic", + "ElementwiseDequantizePerChannelModule_basic", + "ElementwiseDequantizePerTensorModule_basic", + "ElementwiseDivScalarRoundingModeFloorIntStaticModule_basic", + "ElementwiseDivScalarRoundingModeFloorModule_basic", + "ElementwiseDivScalarRoundingModeFloorStaticModule_basic", + "ElementwiseDivScalarRoundingModeTruncIntStaticModule_basic", + "ElementwiseDivScalarRoundingModeTruncModule_basic", + "ElementwiseDivScalarRoundingModeTruncStaticModule_basic", + "ElementwiseDivTensorFloatModule_basic", + "ElementwiseDivTensorRoundingModeFloorIntStaticModule_basic", + "ElementwiseDivTensorRoundingModeFloorModule_basic", + "ElementwiseDivTensorRoundingModeFloorStaticModule_basic", + "ElementwiseDivTensorRoundingModeTruncIntStaticModule_basic", + "ElementwiseDivTensorRoundingModeTruncModule_basic", + "ElementwiseDivTensorRoundingModeTruncStaticModule_basic", + "ElementwiseErfIntModule_basic", + "ElementwiseErfModule_basic", + "ElementwiseExpIntModule_basic", + "ElementwiseExpm1IntModule_basic", + "ElementwiseExpm1Module_basic", + "ElementwiseFmodTensor_Float_basic", + "ElementwiseFmodTensor_Int_Float_basic", + "ElementwiseFmodTensor_Int_basic", + "ElementwiseGeFloatTensorModule_basic", + "ElementwiseGeIntTensorModule_basic", + "ElementwiseGeluApproximateTanhModule_basic", + "ElementwiseHardshrinkModule_basic", + "ElementwiseHardshrinkStaticModule_basic", + "ElementwiseIntTensorLtFloatScalarModule_basic", + "ElementwiseLeFloatIntScalarModule_basic", + "ElementwiseLeFloatScalarModule_basic", + "ElementwiseLeFloatTensorNanModule_basic", + "ElementwiseLeIntScalarModule_basic", + "ElementwiseLeMixedIntScalarModule_basic", + "ElementwiseLog10IntModule_basic", + "ElementwiseLog10Module_basic", + "ElementwiseLog1pModule_basic", + "ElementwiseLog2IntModule_basic", + "ElementwiseLogIntModule_basic", + "ElementwiseLogSigmoidModule_basic", + "ElementwiseLogitModule_basic", + "ElementwiseMishModule_basic", + "ElementwiseMulTensorComplexDiffModule_basic", + "ElementwiseMulTensorComplexModule_basic", + "ElementwiseMulTensorFloatModule_basic", + "ElementwisePowScalarModule_basic", + "ElementwisePowTensorBroadcastModule_basic", + "ElementwisePowTensorBroadcastStaticModule_basic", + "ElementwisePowTensorModule_basic", + "ElementwisePowTensorStaticModule_basic", + "ElementwiseQuantizePerTensorModule_basic", + "ElementwiseQuantizePerTensorUIntModule_basic", + "ElementwiseReciprocalIntModule_basic", + "ElementwiseRemainderScalarModule_Bool_basic", + "ElementwiseRemainderTensorModule_Float_basic", + "ElementwiseRemainderTensorModule_Int_Float_basic", + "ElementwiseRemainderTensorModule_Int_basic", + "ElementwiseRsqrtIntModule_basic", + "ElementwiseSigmoidIntModule_basic", + "ElementwiseSinIntModule_basic", + "ElementwiseSinModule_basic", + "ElementwiseSinhIntModule_basic", + "ElementwiseSinhModule_basic", + "ElementwiseTanIntModule_basic", + 
"ElementwiseTanModule_basic", + "ElementwiseTernaryModule_basic", + "ElementwiseToDtypeF32ToI64Module_basic", + "ElementwiseToDtypeI64ToUI8Module_basic", + "ElementwiseUnaryIntModule_basic", + "ElementwiseWhereScalarOtherModule_basic", + "ElementwiseWhereScalarOtherStaticModule_basic", + "ElementwiseWhereScalarSelfModule_basic", + "ElementwiseWhereScalarSelfStaticModule_basic", + "EmptyLikeMemoryFormatModule_basic", + "EmptyLikeModule_defaultDtype", + "EmptyLikeModule_falsePinMemory", + "EmptyLikeModule_float", + "EmptyLikeModule_int", + "EmptyModule_contiguous", + "EmptyModule_defaultDtype", + "EmptyModule_falsePinMemory", + "EmptyModule_float", + "EmptyModule_int", + "EmptyModule_uint8", + "EmptyStridedModule_basic", + "EmptyStridedSizeIntStrideModule_basic", + "EqIntModule_basic", + "ExpandModule_basic", + "ExponentialModule_basic", + "FakeQuantizePerTensorAffineDynamicShapeModule_basic", + "FakeQuantizePerTensorAffineModule_basic", + "FakeQuantizePerTensorAffineRoundToEvenModule_basic", + "Fill_TensorFloat32WithFloat32_basic", + "Fill_TensorFloat32WithFloat64_basic", + "Fill_TensorFloat32WithInt64_basic", + "Fill_TensorFloat64WithFloat32Static_basic", + "Fill_TensorFloat64WithFloat32_basic", + "Fill_TensorFloat64WithFloat64_basic", + "Fill_TensorFloat64WithInt64Static_basic", + "Fill_TensorFloat64WithInt64_basic", + "FlipModuleStaticShape_basic", + "FlipModule_basic", + "FlipNegativeIndexModule_basic", + "FloatImplicitModule_basic", + "FullLikeModuleInt2D_basic", + "FullLikeModuleInt3D_basic", + "FullModuleDefaultDtype_basic", + "FullModuleFalsePinMemory_basic", + "FullModuleFloat2D_basic", + "FullModuleFloat3D_basic", + "FullModuleInt2D_basic", + "FullModuleInt3D_basic", + "GeFloatIntModule_basic", + "GeFloatModule_basic", + "GeIntModule_basic", + "GridSamplerBasic1_basic", + "GridSamplerBasic2_basic", + "GridSamplerBasic3_basic", + "GridSamplerBasic4_basic", + "GroupNormModule_basic", + "GroupNormNoWeightAndBiasModule_basic", + "GtFloatIntModule_basic", + "GtIntModule_basic", + "HBC_basic", + "IndexPut1DFloatAccumulateModule_basic", + "IndexPut1DFloatNonAccumulateModule_basic", + "IndexPut1DIntAccumulateModule_basic", + "IndexPut1DIntNonAccumulateModule_basic", + "IndexPut2DFloatAccumulateModule_basic", + "IndexPut2DFloatNonAccumulateModule_basic", + "IndexPut2DIntAccumulateModule_basic", + "IndexPut2DIntNonAccumulateModule_basic", + "IndexPut3DFloatAccumulateModule_basic", + "IndexPut3DFloatNonAccumulateModule_basic", + "IndexPut3DIntAccumulateModule_basic", + "IndexPut3DIntNonAccumulateModule_basic", + "IndexPutHackedTwin1DFloatAccumulateModule_basic", + "IndexPutHackedTwin1DFloatNonAccumulateModule_basic", + "IndexPutHackedTwin1DIntAccumulateModule_basic", + "IndexPutHackedTwin1DIntNonAccumulateModule_basic", + "IndexPutHackedTwin2DFloatAccumulateModule_basic", + "IndexPutHackedTwin2DFloatNonAccumulateModule_basic", + "IndexPutHackedTwin2DIntAccumulateModule_basic", + "IndexPutHackedTwin2DIntNonAccumulateModule_basic", + "IndexPutHackedTwin3DFloatAccumulateModule_basic", + "IndexPutHackedTwin3DFloatNonAccumulateModule_basic", + "IndexPutHackedTwin3DIntAccumulateModule_basic", + "IndexPutHackedTwin3DIntNonAccumulateModule_basic", + "IndexPutImpl1DFloatAccumulateModule_basic", + "IndexPutImpl1DFloatNonAccumulateModule_basic", + "IndexPutImpl1DIntAccumulateModule_basic", + "IndexPutImpl1DIntNonAccumulateModule_basic", + "IndexPutImpl2DFloatAccumulateModule_basic", + "IndexPutImpl2DFloatNonAccumulateModule_basic", + "IndexPutImpl2DImplicitModule_basic", + 
"IndexPutImpl2DIndexModule_basic", + "IndexPutImpl2DNoneIndexStaticModule_basic", + "IndexPutImpl3DFloatAccumulateModule_basic", + "IndexPutImpl3DFloatNonAccumulateModule_basic", + "IndexPutImplIndexWithNoneModule_basic", + "IndexSelectDynamicIndexSizeModule_basic", + "IndexSelectDynamicInputSizeModule_basic", + "IndexSelectDynamicModulebasic", + "IndexSelectNegativeDimModule_basic", + "IndexSelectRank0IdxModule_basic", + "IndexSelectSingleIdxModule_basic", + "IndexSelectTwoIdxModule_basic", + "IndexSelectWholeDimensionModule_basic", + "IndexSelectWholeTensorModule_basic", + "IndexTensorDyanmicInputContiguousWithNoneModule_basic", + "IndexTensorDyanmicInputNonContiguousWithNoneModule_basic", + "IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims_basic", + "IndexTensorMultiInputContiguousCenter_basic", + "IndexTensorMultiInputContiguousOneDimDynamic_basic", + "IndexTensorMultiInputNonContiguousDynamic_basic", + "IndexTensorMultiInputNonContiguousMultipleStaticDims_basic", + "IndexTensorMultiInputNonContiguousOneDimDynamic_basic", + "IndexTensorMultiInputNonContiguous_basic", + "IndexTensorMultiInputOneDim_basic", + "IndexTensorMultiInputThreeIndexers_basic", + "IndexTensorMultiInput_basic", + "IndexTensorNegativeIndexModule_basic", + "IndexTensorSelectDimModule_basic", + "IndexTensorStaticContiguousWithNoneModule_basic", + "IndexTensorStaticNonContiguousWithNoneModule_basic", + "IntFloatModule_basic", + "IntImplicitModule_basic", + "IsFloatingPointFloat_True", + "IsFloatingPointInt_False", + "LayerNormLastDimModule_basic", + "LayerNormModule_basic", + "LayerNormNormalizeOverAllDimsModule_basic", + "LenStrModule_basic", + "LinalgNormKeepDimComplexModule_basic", + "LinalgVectorNormComplexModule_basic", + "LinspaceDtypeModule_basic", + "LinspaceEmptyModule_basic", + "LinspaceModule_basic", + "LinspaceOneSizeModule_basic", + "LinspaceTwoSizeModule_basic", + "LogSoftmaxIntModule_basic", + "MaskedFillTensorFloatValueModule_basic", + "MatmulBroadcastBatchDim_basic", + "MatmulStaticBroadcast_basic", + "MaxPool1dCeilModeTrueModule_basic", + "MaxPool1dModule_basic", + "MaxPool2dCeilModeTrueModule_basic", + "MaxPool2dModule_basic", + "MaxPool2dWithIndicesAllNegativeValuesModule_basic", + "MaxPool2dWithIndicesAllOnesModule_basic", + "MaxPool2dWithIndicesBackwardDynamic3DModule_basic", + "MaxPool2dWithIndicesBackwardDynamic4DModule_basic", + "MaxPool2dWithIndicesBackwardStatic3DModule_basic", + "MaxPool2dWithIndicesBackwardStatic4DModule_basic", + "MaxPool2dWithIndicesCeilModeTrueModule_basic", + "MaxPool2dWithIndicesFullSizeKernelModule_basic", + "MaxPool2dWithIndicesModule_basic", + "MaxPool2dWithIndicesNonDefaultDilationModule_basic", + "MaxPool2dWithIndicesNonDefaultPaddingModule_basic", + "MaxPool2dWithIndicesNonDefaultParamsModule_basic", + "MaxPool2dWithIndicesNonDefaultStrideModule_basic", + "MaxPool2dWithIndicesStaticModule_basic", + "MaxPool3dCeilModeTrueModule_basic", + "MaxPool3dEmptyStrideStaticModule_basic", + "MaxPool3dLargeDatadModule_basic", + "MaxPool3dModuleRandomSimple_basic", + "MaxPool3dModule_basic", + "MaxPool3dStaticCeilModeTrueModule_basic", + "MaxPool3dStaticModule_basic", + "MeanDimDtypeModule_basic", + "MeanDimEmptyDimModule_basic", + "MeanDimNoneDimModule_basic", + "MeanDtypeModule_basic", + "MseLossMeanReductionModule_basic", + "MseLossSumReductionWithDifferentElemTypeModule_basic", + "MulFloatModule_basic", + "MulIntModule_basic", + "NativeBatchNorm1DModule_basic", + "NativeBatchNorm2DModule_basic", + "NativeBatchNorm3DModule_basic", + 
"NativeBatchNormNoneWeightModule_basic", + "NativeDropoutTrainModule_basic", + "NativeDropoutTrainStaticShapeModule_basic", + "NativeGroupNormBackwardModule_basic", + "NativeGroupNormModule_basic", + "NativeLayerNormDynamicModule_basic", + "NativeLayerNormModule4D_basic", + "NativeLayerNormModule_basic", + "NeFloatIntModule_basic", + "NeIntModule_basic", + "NewEmptyModuleDefaultDtype_basic", + "NewEmptyModuleFalsePinMemory_basic", + "NewEmptyModuleFloat2D_basic", + "NewEmptyModuleFloat3D_basic", + "NewEmptyModuleInt2D_basic", + "NewEmptyModuleInt3D_basic", + "NewEmptyModuleLayoutIntDtype_basic", + "NewEmptyModuleNonDefaultFloatDtype_basic", + "NewEmptyModuleNonDefaultIntDtype_basic", + "NewEmptyStridedModuleDefaultDtype_basic", + "NewFullModuleInt2D_basic", + "NewFullModuleInt3D_basic", + "NllLossModuleBackward1DMeanWeight_basic", + "NllLossModuleBackward1DMean_basic", + "NllLossModuleBackward1DSumWeight_basic", + "NllLossModuleBackward1DSum_basic", + "NllLossModuleBackward1DWeight_basic", + "NllLossModuleBackward1D_basic", + "NllLossModuleBackwardMeanWeight_basic", + "NllLossModuleBackwardMean_basic", + "NllLossModuleBackwardSumWeight_basic", + "NllLossModuleBackwardSum_basic", + "NllLossModuleBackwardWeight_basic", + "NllLossModuleBackward_basic", + "NllLossModuleBackward_ignore_index", + "NllLossModule_1D_basic", + "NllLossModule_mean_basic", + "NllLossModule_sum_basic", + "NormScalarComplexModule_basic", + "NormScalarModule_basic", + "NormScalarOptDimKeepDimComplexModule_basic", + "NormalFunctionalModule_basic", + "NumToTensorFloatModule_basic", + "NumToTensorIntModule_basic", + "NumelModule_basic", + "NumelZeroRankModule_basic", + "OnesLikeModule_falsePinMemory", + "PixelShuffleModuleFullDynamic_basic", + "PixelShuffleModuleSpatiallyDynamic_basic", + "PixelShuffleModuleSpatiallyStatic_basic", + "PixelShuffleModuleStaticRank3Int64_basic", + "PixelShuffleModuleStaticRank4Float32_basic", + "PowIntFloatModule_basic", + "PrimMaxIntModule_basic", + "PrimMinIntDynamicModule_basic", + "PrimMinIntModule_basic", + "PrimsConvertElementTypeModule_basic", + "PrimsSqueezeEmptyDimensionsModule_basic", + "PrimsSqueezeModule_basic", + "PrimsViewOfModule_basic", + "PrimsViewOfZeroRankModule_basic", + "QuantizedBatchedInputSingleLayer_basic", + "QuantizedMLP_basic", + "QuantizedNoLayer_basic", + "QuantizedReluInt32_basic", + "QuantizedReluInt8_basic", + "QuantizedReluUint8_basic", + "QuantizedSingleLayer_basic", + "RandIntDtypeModule_basic", + "RandIntLowDtypeModule_basic", + "RandIntLowModule_basic", + "RandIntModule_basic", + "RandIntPinMemoryModule_basic", + "RandLikeDtypeModule_basic", + "RandLikeModule_basic", + "RandModule_basic", + "RandnDtypeDeviceModule_basic", + "RandnGeneratorF64Module_basic", + "RandnGeneratorModule_basic", + "RandnLikeDtypeModule_basic", + "RandnLikeModule_basic", + "RandnModule_basic", + "ReduceAllDimBool_basic", + "ReduceAllDimEmpty_basic", + "ReduceAllDimFloat_basic", + "ReduceAllDimInt_basic", + "ReduceAllFloatModule_basic", + "ReduceAllIntModule_basic", + "ReduceAnyFloatModule_basic", + "ReduceAnyIntModule_basic", + "ReduceFrobeniusNormComplexModule_basic", + "ReduceL1NormComplexModule_basic", + "ReduceL1NormWithDTypeModule_basic", + "ReduceL2NormComplexModule_basic", + "ReduceL3NormAllDimsModule_basic", + "ReduceL3NormKeepDimComplexModule_basic", + "ReduceL3NormKeepDimModule_basic", + "ReduceMaxAllDims_basic", + "ReduceMaxAlongDimNegative_basic", + "ReduceMaxAlongDimUnsignedInt_basic", + "ReduceMaxAlongDim_basic", + "ReduceMaxFloatModule_basic", + 
"ReduceMaxKeepDim_basic", + "ReduceMaxSignedIntModule_basic", + "ReduceMaxUnsignedIntModule_basic", + "ReduceMinAlongDimNegative_basic", + "ReduceMinAlongDimSignedInt_basic", + "ReduceMinAlongDimUnsignedInt_basic", + "ReduceMinAlongDim_basic", + "ReduceMinFloatModule_basic", + "ReduceMinKeepDimReturnBoth_basic", + "ReduceMinKeepDim_basic", + "ReduceMinSignedIntModule_basic", + "ReduceMinUnsignedIntModule_basic", + "ReduceProdDimIntFloatModule_basic", + "ReduceProdDtypeFloatModule_basic", + "ReduceProdDtypeIntModule_basic", + "ReduceProdElementTypeBoolModule_basic", + "ReduceProdFloatModule_basic", + "ReduceProdSignedIntModule_basic", + "ReduceProdUnsignedIntModule_basic", + "ReduceSumDimIntListDtypeFloatModule_basic", + "ReduceSumDimIntListDtypeIntModule_basic", + "ReduceSumDimIntListElementTypeBoolModule_basic", + "ReduceSumDimIntListEmptyDimModule_basic", + "ReduceSumDtypeFloatModule_basic", + "ReduceSumDtypeIntModule_basic", + "ReduceSumElementTypeBoolModule_basic", + "ReflectionPad1dModule2dInput_Right", + "ReflectionPad1dModule2dInput_basic", + "ReflectionPad1dModule3dInput_Left", + "ReflectionPad1dModule3dInput_basic", + "ReflectionPad2dModule_Bottom", + "ReflectionPad2dModule_Left", + "ReflectionPad2dModule_Right", + "ReflectionPad2dModule_Top", + "ReflectionPad2dModule_basic", + "ReplicationPad2dModule_basic", + "ReplicationPad2dModule_bottom0", + "ReplicationPad2dModule_left0", + "ReplicationPad2dModule_right0", + "ReplicationPad2dModule_top0", + "RollModule_basic", + "RsubInt0d_NumToTensor_Module_basic", + "RsubIntModule_basic", + "RsubIntModule_noalpha_basic", + "ScalarConstantTupleModule_basic", + "ScalarImplicitFloatModule_basic", + "ScalarImplicitIntModule_basic", + "ScaledDotProductAttentionDifferentModule_basic", + "ScatterReduceFloatMaxModule", + "ScatterReduceFloatMaxModuleIncludeSelf", + "ScatterReduceFloatMeanModule", + "ScatterReduceFloatMeanModuleIncludeSelf", + "ScatterReduceFloatMinModule", + "ScatterReduceFloatMinModuleIncludeSelf", + "ScatterReduceFloatProdModule", + "ScatterReduceFloatProdModuleIncludeSelf", + "ScatterReduceFloatSumModule", + "ScatterReduceFloatSumModuleIncludeSelf", + "ScatterReduceIntMaxModule", + "ScatterReduceIntMaxModuleIncludeSelf", + "ScatterReduceIntMeanModule", + "ScatterReduceIntMeanModuleIncludeSelf", + "ScatterReduceIntMinModule", + "ScatterReduceIntMinModuleIncludeSelf", + "ScatterReduceIntProdModule", + "ScatterReduceIntProdModuleIncludeSelf", + "ScatterReduceIntSumModule", + "ScatterReduceIntSumModuleIncludeSelf", + "ScatterSrcModule_basic", + "ScatterSrcStaticModule_basic", + "ScatterValueFloatModule_basic", + "ScatterValueIntModule_basic", + "SelectScattertModule_basic", + "SelectScattertStaticModule_basic", + "SliceCopyEndGreaterThanDimSize_Module_basic", + "SliceCopyNegative_Module_basic", + "SliceCopyNonZeroDim_Module_basic", + "SliceCopyStartGreaterThanDimSize_Module_basic", + "SliceCopy_Module_basic", + "SliceEndSleStartModule_basic", + "SliceOutOfLowerBoundEndIndexModule_basic", + "SliceOutOfLowerBoundStartIndexModule_basic", + "SliceScatterModule_basic", + "SliceScatterNegativeDimModule_basic", + "SliceScatterNegativeEndModule_basic", + "SliceScatterStaticModule_basic", + "SliceScatterStepVariationModule_basic", + "SliceScatterZeroDimModule_basic", + "SliceSizeTwoStepModule_basic", + "SoftmaxIntArgTypeF64Module_basic", + "SoftmaxIntNonNoneDtypeModule_basic", + "SoftplusModule_basic", + "SortIntListReverse_basic", + "SortIntList_basic", + "SortTensorDescending_basic", + "SortTensorInteger_basic", + 
"SortTensorNegativeDimension_basic", + "SortTensorSpecificDimension_basic", + "SortTensor_basic", + "SplitDimDynamicModule_basic", + "SplitDimStaticModule_basic", + "SqrtIntConstantModule_basic", + "SqrtIntModule_basic", + "StdBiasedModule_basic", + "StdCorrectionAllDimReduceModule_basic", + "StdCorrectionEmptyDimModule_basic", + "StdCorrectionKeepDimModule_basic", + "StdCorrectionLargeInputModule_basic", + "StdCorrectionModule_basic", + "StdCorrectionNoneModule_basic", + "StdCorrectionSingleDimReduceModule_basic", + "StdDimBiasedModule_basic", + "StdDimEmptyDimModule_basic", + "StdDimKeepDimFalseModule_basic", + "StdDimKeepDimTrueModule_basic", + "StdDimNoneDimModule_basic", + "StdUnbiasedModule_basic", + "SubFloatModule_basic", + "SubIntModule_basic", + "TModuleRank0_basic", + "TensorToBoolZeroRank_basic", + "TensorToBool_basic", + "TensorToFloatZeroRank_basic", + "TensorToFloat_basic", + "TensorToIntZeroRank_basic", + "TensorToInt_basic", + "TensorsConcatPromoteDTypeModule_basic", + "TensorsStackPromoteDTypeModule_basic", + "TestMultipleTensorAndPrimitiveTypesReturn_basic", + "Threshold1dIntModule_basic", + "Threshold2dIntModule_basic", + "Threshold3dIntModule_basic", + "ThresholdBackward1dFloatModule_basic", + "ThresholdBackward1dIntModule_basic", + "ThresholdBackward1dMixedModule_basic", + "ThresholdBackward2dFloatModule_basic", + "ThresholdBackward2dIntModule_basic", + "ThresholdBackward2dMixedModule_basic", + "ThresholdBackward3dFloatModule_basic", + "ThresholdBackward3dIntModule_basic", + "ThresholdBackward3dMixedModule_basic", + "ToCopyWithDTypeFalsePinMemoryModule_basic", + "ToCopyWithDTypeModule_basic", + "TorchPrimLoopForLikeModule_basic", + "TorchPrimLoopWhileLikeModule_basic", + "TraceModule_basic", + "TraceModule_empty", + "TraceModule_nonsquare", + "TraceSignedIntModule_basic", + "TraceUnsignedIntModule_basic", + "TraceUnsignedIntModule_empty", + "TypeConversionI1ToF64Module_basic", + "TypeConversionI1ToI32Module_basic", + "UnbindIntGetItem_Module_basic", + "UnbindIntListUnpack_Module_basic", + "UniformModule_basic", + "UniformNoCorrelationModule_basic", + "UniformStaticShapeModule_basic", + "UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule_basic", + "UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic", + "UpSampleNearest2dBackwardScalesNone_basic", + "UpSampleNearest2dBackward_basic", + "UpSampleNearest2dDynamicFactor_basic", + "UpSampleNearest2dDynamicSize_basic", + "UpSampleNearest2dStaticFactor_basic", + "UpSampleNearest2dStaticSize_basic", + "UpSampleNearest2d_basic", + "VarBiasedModule_basic", + "VarCorrectionAllDimReduceModule_basic", + "VarCorrectionEmptyDimModule_basic", + "VarCorrectionKeepDimModule_basic", + "VarCorrectionLargeInputModule_basic", + "VarCorrectionModule_basic", + "VarCorrectionNoneModule_basic", + "VarCorrectionSingleDimReduceModule_basic", + "VarDimAllDimReduceModule_basic", + "VarDimBiasedModule_basic", + "VarDimEmptyDimModule_basic", + "VarDimModule_basic", + "VarDimMultiDimModule_basic", + "VarDimNegativeModule_basic", + "VarDimNoneDimModule_basic", + "VarDimSingleDimModule_basic", + "VarDimUnbiasedModule_basic", + "VarMeanBiasedModule_basic", + "VarMeanCorrectionModule_basic", + "VarMeanCorrectionNoneModule_basic", + "VarMeanDimBiasedModule_basic", + "VarMeanDimModule_basic", + "VarMeanUnbiasedModule_basic", + "VarUnbiasedModule_basic", + "ViewCollapseDynamicWithAtenSizeIntModule_basic", + "ViewSizeFromOtherTensor_basic", + "ZeroFloat32Module_basic", + "ZeroInt32Module_basic", + "ZeroInt64Module_basic", + "ZerosLikeModule_falsePinMemory", 
+} + +ONNX_TOSA_XFAIL_SET = { + "AdaptiveAvgPool1dGeneralDynamicNoBatches_basic", + "AdaptiveAvgPool1dGeneralDynamic_basic", + "AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule_basic", + "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic", + "AdaptiveAvgPool1dStaticEvenMultiple_basic", + "AdaptiveAvgPool1dStaticLargerOutput_basic", + "AdaptiveAvgPool1dUnitOutputSizeDynamicModule_basic", + "AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic", + "AdaptiveAvgPool2dDynamicNoBatch_basic", + "AdaptiveAvgPool2dDynamic_basic", + "AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic", + "AdaptiveAvgPool2dOutputSizeDivisibleByInputDynamicModule_basic", + "AdaptiveAvgPool2dUnitOutputSizeDynamicModule_basic", + "AdaptiveAvgPool3dDynamicNoBatch_basic", + "AdaptiveAvgPool3dDynamic_basic", + "AdaptiveMaxPool1dDynamicNoBatch_basic", + "AdaptiveMaxPool1dDynamic_basic", + "AdaptiveMaxPool1dStatic_basic", + "AdaptiveMaxPool2dDynamicNoBatch_basic", + "AdaptiveMaxPool2dDynamicWithIndices_basic", + "AdaptiveMaxPool2dDynamic_basic", + "AdaptiveMaxPool2dStaticWithIndices_basic", + "AdaptiveMaxPool2dStatic_basic", + "AdaptiveMaxPool3dDynamicNoBatch_basic", + "AdaptiveMaxPool3dDynamicWithIndices_basic", + "AdaptiveMaxPool3dDynamic_basic", + "AdaptiveMaxPool3dStaticWithIndices_basic", + "AdaptiveMaxPool3dStatic_basic", + "AddCDivModule_basic", + "AddIntModule_basic", + "AddSizeIntModule_basic", + "AddSizeIntNegDimModule_basic", + "Add_MixPModule_basic", + "Add_Module_basic", + "AddmmModuleFloat_basic", + "AddmmModule_broadcastable", + "AddmmModule_differentRankBroadcastable", + "AllBoolFalseModule_basic", + "AllBoolTrueModule_basic", + "AnyBoolFalseModule_basic", + "AnyBoolTrueModule_basic", + "ArangeStartOutDtypeModule_basic", + "ArangeStartOutViewModule_basic", + "ArgmaxIntModule_basic", + "ArgmaxIntModule_multiple_maxs", + "ArgmaxModule_basic", + "ArgmaxModule_with_dim", + "ArgminIntModule_basic", + "ArgminIntModule_multiple_mins", + "ArgminModule_basic", + "ArgminModule_keepDim", + "ArgminModule_with_dim", + "AtenComplex64Module_basic", + "AtenComplexImagModule_basic", + "AtenComplexRealModule_basic", + "AtenComplexViewModule_basic", + "AtenDiagEmbedDefaultDiag_basic", + "AtenDiagEmbedDimDiag_basic", + "AtenDiagEmbedNegOffsetDiag_basic", + "AtenDiagEmbedNonDefault4DDiag_basic", + "AtenDiagEmbedOffsetDiag_basic", + "AtenDiagEmbedRevDimDiag_basic", + "AtenEmbeddingBagStaticModule_basic", + "AtenEmbeddingBagSumExample_basic", + "AtenFloatScalarModule_basic", + "AtenIntBoolOpConstFalseModule_basic", + "AtenIntBoolOpConstTrueModule_basic", + "AtenIntBoolOpModule_basic", + "AtenIntTensorByteDtypeModule_basic", + "AtenIntTensorCharDtypeModule_basic", + "AtenItemFpOpModule_basic", + "AtenItemIntOpModule_basic", + "AtenLinalgCrossDynamic_basic", + "AtenMatmulQMixedSigni8Transpose_basic", + "AtenMatmulQMixedSigni8_basic", + "AtenMatmulQint8MV_basic", + "AtenMatmulQint8VM_basic", + "AtenMatmulQint8VV_basic", + "AtenMatmulQint8_basic", + "AtenMmFloatTypes_basic", + "AtenMmIntTypes_basic", + "AtenMmQMixedSigni8_basic", + "AtenMmQint8_basic", + "AtenMmQuint8_basic", + "AtenRealView128Module_basic", + "AtenRealView64Module_basic", + "AtenRoundFloatHalfToEvenModule_basic", + "AtenRoundFloatModule_basic", + "AtenSubFloatModule_basic", + "AtenTopKModule_basic", + "AtenTopKSmallestModule_basic", + "AtenTrilModule_basic", + "AtenTrilWithNegDiagonalModule_basic", + "AtenTrilWithPosDiagonalModule_basic", + "AtenTriuModule_basic", + "AtenTriuWithNegDiagonalModule_basic", + "AtenTriuWithPosDiagonalModule_basic", + 
"Aten_CastLongModule_basic", + "Aten_EmbeddingBagExample_basic", + "AvgPool1dFloatModule_basic", + "AvgPool1dIntModule_basic", + "AvgPool1dStaticModule_basic", + "AvgPool2dCeilModeTrueModule_basic", + "AvgPool2dDivisorOverrideModule_basic", + "AvgPool2dFloatModule_basic", + "AvgPool2dIntModule_basic", + "AvgPool2dStaticModule_basic", + "AvgPool2dWithoutPadModule_basic", + "BatchMlpLayerModule_basic", + "BernoulliFloatModule_basic", + "BernoulliModule_basic", + "BernoulliOnesModule_basic", + "BernoulliPModule_basic", + "BernoulliTensorModule_basic", + "BernoulliZerosModule_basic", + "BincountMinlengthModule_basic", + "BincountModule_basic", + "BincountStaticSizeModule_basic", + "BmmIntModule_basic", + "BoolFloatConstantModule_basic", + "BoolFloatFalseModule_basic", + "BoolFloatTrueModule_basic", + "BoolIntConstantModule_basic", + "BoolIntFalseModule_basic", + "BoolIntTrueModule_basic", + "BoolTensorHandleSignless_basic", + "BroadcastDynamicDimModule_basic", + "BroadcastToModule_basic", + "BucketizeTensorFloatModule_basic", + "BucketizeTensorModule_basic", + "BucketizeTensorOutInt32RightModule_basic", + "BucketizeTensorStaticFloatModule_basic", + "BucketizeTensorStaticModule_basic", + "CeilFloatModule_basic", + "ChunkListUnpackDynamic_Module_basic", + "ChunkListUnpackUnevenDynamic_Module_basic", + "ChunkListUnpackUneven_Module_basic", + "ChunkListUnpack_Module_basic", + "CollapseAllDimensionsModule_basic", + "CollapseFullDynamicModule_basic", + "CollapsePartialDynamicModule_basic", + "CollapseRank1DynamicModule_basic", + "CollapseStaticModule_basic", + "ConstantBoolParameterModule_basic", + "ConstantPad2dStaticModule_basic", + "ConstantPadNdModule_basic", + "ConstantPadNdPartialStaticModule_basic", + "ConstantPadNdStaticModule_basic", + "ContainsIntList_False", + "ContainsIntList_True", + "Conv1dModule_basic", + "Conv2dBiasNoPaddingModule_basic", + "Conv2dModule_basic", + "Conv2dNoPaddingModule_basic", + "Conv2dQInt8Module_basic", + "Conv2dWithPaddingDilationStrideModule_basic", + "Conv2dWithPaddingDilationStrideStaticModule_grouped", + "Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier", + "Conv2dWithPaddingModule_basic", + "Conv3dModule_basic", + "ConvTbcModule_basic", + "ConvTranspose2DQInt8_basic", + "Conv_Transpose2dModule_basic", + "Convolution2DModule_basic", + "Convolution2DStridedModule_basic", + "ConvolutionBackwardModule2DPadded_basic", + "ConvolutionBackwardModule2DStatic_basic", + "ConvolutionBackwardModule2DStrided_basic", + "ConvolutionBackwardModule2D_basic", + "ConvolutionModule2DGroups_basic", + "ConvolutionModule2DTransposeNonUnitOutputPadding_basic", + "ConvolutionModule2DTransposeStridedStatic_basic", + "ConvolutionModule2DTransposeStrided_basic", + "ConvolutionModule2DTranspose_basic", + "CopyModule_basic", + "CopyWithDifferentDTypesAndSizesModule_basic", + "CopyWithDifferentDTypesModule_basic", + "CopyWithDifferentSizesModule_basic", + "CosineSimilarityStaticBroadcastModule_basic", + "CrossEntropyLossModule_basic", + "CrossEntropyLossNoReductionModule_basic", + "CumsumInputDtypeInt32Module_basic", + "CumsumModule_basic", + "CumsumStaticModule_basic", + "CumsumStaticNegativeDimModule_basic", + "DiagonalModule_basic", + "DiagonalModule_nonsquare", + "DiagonalModule_transposed", + "DiagonalModule_with_dims", + "DiagonalModule_with_dims_and_offset", + "DiagonalModule_with_negative_dims", + "DiagonalModule_with_offset", + "DiagonalWithStaticShapeModule_basic", + "DivFloatModule_basic", + "DivIntModule_basic", + "DropoutTrainModule_basic", + 
"DropoutTrainStaticShapeModule_basic", + "ElementwiseAcosIntModule_basic", + "ElementwiseAcosModule_basic", + "ElementwiseAcoshIntModule_basic", + "ElementwiseAcoshModule_basic", + "ElementwiseAddScalarInt64Module_basic", + "ElementwiseAddScalarIntModule_basic", + "ElementwiseAndScalarModule_basic", + "ElementwiseAndScalarStaticShapeModule_basic", + "ElementwiseAsinIntModule_basic", + "ElementwiseAsinModule_basic", + "ElementwiseAsinhIntModule_basic", + "ElementwiseAsinhModule_basic", + "ElementwiseAtan2FloatIntModule_basic", + "ElementwiseAtan2FloatIntStaticModule_basic", + "ElementwiseAtan2TensorFloatModule_basic", + "ElementwiseAtan2TensorFloatStaticModule_basic", + "ElementwiseAtan2TensorIntModule_basic", + "ElementwiseAtan2TensorIntStaticModule_basic", + "ElementwiseAtanTensorFloatModule_basic", + "ElementwiseAtanTensorIntModule_basic", + "ElementwiseAtanhIntModule_basic", + "ElementwiseAtanhModule_basic", + "ElementwiseAtenDivIntScalarModule_basic", + "ElementwiseAtenFloorDivideBroadcastModule_basic", + "ElementwiseAtenFloorDivideScalarModule_basic", + "ElementwiseAtenFloorDivideScalarNegativeModule_basic", + "ElementwiseAtenFloorDivideTensorNegativeModule_basic", + "ElementwiseAtenFloorDivideTensorPositiveModule_basic", + "ElementwiseAtenIsinfOpModule_basic", + "ElementwiseAtenIsneginfOpModule_basic", + "ElementwiseAtenIsposinfOpModule_basic", + "ElementwiseAtenLogicalAndOpModule_basic", + "ElementwiseAtenLogicalAndOpPromoteBroadcastModule_basic", + "ElementwiseAtenLogicalAndOpPromoteBroadcastStaticShapeModule_basic", + "ElementwiseAtenLogicalNotOpPromoteModule_basic", + "ElementwiseAtenLogicalOrOpBrodcastModule_basic", + "ElementwiseAtenLogicalOrOpDiffArgs1Module_basic", + "ElementwiseAtenLogicalOrOpDiffArgs2Module_basic", + "ElementwiseAtenLogicalOrOpDiffArgs3Module_basic", + "ElementwiseAtenLogicalOrOpNegativeModule_basic", + "ElementwiseAtenLogicalOrOpRandomFloatModule_basic", + "ElementwiseAtenLogicalOrOpRandomModule_basic", + "ElementwiseAtenLogicalXorOpModule_basic", + "ElementwiseAtenLogicalXorOpPromoteBroadcastModule_basic", + "ElementwiseAtenLogicalXorOpPromoteBroadcastStaticShapeModule_basic", + "ElementwiseBitwiseAndModule_basic", + "ElementwiseBitwiseAndScalarInt32Module_basic", + "ElementwiseBitwiseAndScalarInt64Module_basic", + "ElementwiseBitwiseAndScalarInt8Module_basic", + "ElementwiseBitwiseAndStaticShapeModule_basic", + "ElementwiseBitwiseLeftShiftInt32Module_basic", + "ElementwiseBitwiseLeftShiftInt64Module_basic", + "ElementwiseBitwiseLeftShiftInt8Module_basic", + "ElementwiseBitwiseNotInt32Module_basic", + "ElementwiseBitwiseNotInt64Module_basic", + "ElementwiseBitwiseOrModule_basic", + "ElementwiseBitwiseOrStaticShapeModule_basic", + "ElementwiseBitwiseRightShiftInt32Module_basic", + "ElementwiseBitwiseRightShiftInt64Module_basic", + "ElementwiseBitwiseRightShiftInt8Module_basic", + "ElementwiseBitwiseXorModule_basic", + "ElementwiseBitwiseXorStaticShapeModule_basic", + "ElementwiseClampMaxModule_basic", + "ElementwiseClampMinModule_basic", + "ElementwiseClampMinTensorFloatModule_basic", + "ElementwiseClampMinTensorIntModule_basic", + "ElementwiseClampModule_basic", + "ElementwiseClampTensorFloatModule_basic", + "ElementwiseClampTensorInt8Module_basic", + "ElementwiseClampTensorIntModule_basic", + "ElementwiseCosIntModule_basic", + "ElementwiseCosModule_basic", + "ElementwiseCoshIntModule_basic", + "ElementwiseCoshModule_basic", + "ElementwiseDequantizePerChannelModule_basic", + "ElementwiseDequantizePerTensorModule_basic", + 
"ElementwiseDivScalarRoundingModeFloorIntStaticModule_basic", + "ElementwiseDivScalarRoundingModeTruncModule_basic", + "ElementwiseDivScalarRoundingModeTruncStaticModule_basic", + "ElementwiseDivTensorFloatModule_basic", + "ElementwiseDivTensorIntegerModule_basic", + "ElementwiseDivTensorRoundingModeFloorIntStaticModule_basic", + "ElementwiseDivTensorRoundingModeFloorModule_basic", + "ElementwiseDivTensorRoundingModeFloorStaticModule_basic", + "ElementwiseDivTensorRoundingModeTruncIntStaticModule_basic", + "ElementwiseDivTensorRoundingModeTruncModule_basic", + "ElementwiseDivTensorRoundingModeTruncStaticModule_basic", + "ElementwiseDivTensorUnsignedIntegerModule_basic", + "ElementwiseEluNonDefaultModule_basic", + "ElementwiseEqBoolScalarModule_basic", + "ElementwiseEqDiffWidthScalarModule_basic", + "ElementwiseErfIntModule_basic", + "ElementwiseErfModule_basic", + "ElementwiseExpIntModule_basic", + "ElementwiseExpm1IntModule_basic", + "ElementwiseExpm1Module_basic", + "ElementwiseFlattenBroadcastModule_basic", + "ElementwiseFmodTensor_Float_basic", + "ElementwiseFmodTensor_Int_Float_basic", + "ElementwiseFmodTensor_Int_basic", + "ElementwiseGeFloatIntScalarModule_basic", + "ElementwiseGeFloatScalarModule_basic", + "ElementwiseGeFloatTensorModule_basic", + "ElementwiseGeIntScalarModule_basic", + "ElementwiseGeIntTensorModule_basic", + "ElementwiseGeMixedIntScalarModule_basic", + "ElementwiseGeluModule_basic", + "ElementwiseGtMixed2ScalarModule_basic", + "ElementwiseIntTensorLtFloatScalarModule_basic", + "ElementwiseIsinfModule_basic", + "ElementwiseLeFloatTensorNanModule_basic", + "ElementwiseLeMixedIntScalarModule_basic", + "ElementwiseLog10IntModule_basic", + "ElementwiseLog2IntModule_basic", + "ElementwiseLogIntModule_basic", + "ElementwiseLtDiffWidthScalarModule_basic", + "ElementwiseMishModule_basic", + "ElementwiseMulScalarModule_basic", + "ElementwiseMulTensorComplexDiffModule_basic", + "ElementwiseMulTensorComplexModule_basic", + "ElementwiseMulTensorFloatModule_basic", + "ElementwiseMulTensorIntModule_basic", + "ElementwiseNanToNumModule_Basic", + "ElementwiseOrTensorModule_basic", + "ElementwiseOrTensorStaticShapeModule_basic", + "ElementwisePowModule_basic", + "ElementwisePowScalarModule_basic", + "ElementwisePowTensorBroadcastModule_basic", + "ElementwisePowTensorBroadcastStaticModule_basic", + "ElementwisePowTensorModule_basic", + "ElementwisePowTensorStaticModule_basic", + "ElementwiseQuantizePerTensorModule_basic", + "ElementwiseQuantizePerTensorUIntModule_basic", + "ElementwiseReciprocalIntModule_basic", + "ElementwiseRelu6Module_basic", + "ElementwiseRemainderScalarModule_Bool_basic", + "ElementwiseRemainderScalarModule_Int_Float_basic", + "ElementwiseRemainderScalarModule_Int_basic", + "ElementwiseRemainderTensorModule_Int_Float_basic", + "ElementwiseRemainderTensorModule_Int_basic", + "ElementwiseRsqrtIntModule_basic", + "ElementwiseSgnModule_basic", + "ElementwiseSigmoidIntModule_basic", + "ElementwiseSinIntModule_basic", + "ElementwiseSinModule_basic", + "ElementwiseSinhIntModule_basic", + "ElementwiseSinhModule_basic", + "ElementwiseSqrtIntModule_basic", + "ElementwiseSubScalarIntModule_basic", + "ElementwiseTanIntModule_basic", + "ElementwiseTanModule_basic", + "ElementwiseTernaryModule_basic", + "ElementwiseToDtypeF32ToI64Module_basic", + "ElementwiseToDtypeI64ToI8Module_basic", + "ElementwiseToDtypeI64ToUI8Module_basic", + "ElementwiseTruncIntModule_basic", + "ElementwiseTruncModule_basic", + "ElementwiseUnaryIntModule_basic", + 
"ElementwiseUnsqueezeNegDimsModule_basic", + "ElementwiseWhereScalarOtherModule_basic", + "ElementwiseWhereScalarOtherStaticModule_basic", + "ElementwiseWhereScalarSelfModule_basic", + "ElementwiseWhereScalarSelfStaticModule_basic", + "ElementwiseWhereSelfModule_basic", + "EmbeddingModule1DIndices_basic", + "EmbeddingModuleF16_basic", + "EmbeddingModuleI32Static_basic", + "EmbeddingModuleI32_basic", + "EmbeddingModuleI64_basic", + "EmptyLikeMemoryFormatModule_basic", + "EmptyLikeModule_defaultDtype", + "EmptyLikeModule_falsePinMemory", + "EmptyLikeModule_float", + "EmptyLikeModule_int", + "EmptyStridedModule_basic", + "EmptyStridedSizeIntStrideModule_basic", + "EqIntModule_basic", + "ExpandAsFloatModule_basic", + "ExpandAsIntModule_basic", + "ExpandModule_basic", + "ExponentialModule_basic", + "FakeQuantizePerTensorAffineDynamicShapeModule_basic", + "FakeQuantizePerTensorAffineModule_basic", + "FakeQuantizePerTensorAffineRoundToEvenModule_basic", + "Fill_TensorFloat32WithFloat32_basic", + "Fill_TensorFloat32WithFloat64_basic", + "Fill_TensorFloat32WithInt64_basic", + "Fill_TensorFloat64WithFloat32_basic", + "Fill_TensorFloat64WithFloat64_basic", + "Fill_TensorFloat64WithInt64_basic", + "FlattenDynamicModuleCollapseAll_basic", + "FlattenDynamicModule_basic", + "FlattenRank0Module_basic", + "FlipModuleStaticShape_basic", + "FlipModule_basic", + "FlipNegativeIndexModule_basic", + "FloatImplicitModule_basic", + "FullLikeModuleDefaultDtype_basic", + "FullLikeModuleFalsePinMemory_basic", + "FullLikeModuleFloat2D_basic", + "FullLikeModuleFloat3D_basic", + "FullLikeModuleInt2D_basic", + "FullLikeModuleInt3D_basic", + "Gather2DInputModdule_basic", + "GatherModule_basic", + "GatherNegativeDimModule_basic", + "GatherRandomIndexModule_basic", + "GeFloatIntModule_basic", + "GeFloatModule_basic", + "GeIntModule_basic", + "GeluBackwardModule_basic", + "GridSamplerBasic1_basic", + "GridSamplerBasic2_basic", + "GridSamplerBasic3_basic", + "GridSamplerBasic4_basic", + "GtFloatIntModule_basic", + "GtIntModule_basic", + "HBC_basic", + "HardTanhIntModule_basic", + "HardTanhModule_basic", + "HardsigmoidModule_basic", + "HardsigmoidRandomModule_basic", + "HardtanhBackward_basic", + "IndexPut1DFloatAccumulateModule_basic", + "IndexPut1DFloatNonAccumulateModule_basic", + "IndexPut1DIntAccumulateModule_basic", + "IndexPut1DIntNonAccumulateModule_basic", + "IndexPut2DFloatAccumulateModule_basic", + "IndexPut2DFloatNonAccumulateModule_basic", + "IndexPut2DIntAccumulateModule_basic", + "IndexPut2DIntNonAccumulateModule_basic", + "IndexPut3DFloatAccumulateModule_basic", + "IndexPut3DFloatNonAccumulateModule_basic", + "IndexPut3DIntAccumulateModule_basic", + "IndexPut3DIntNonAccumulateModule_basic", + "IndexPutHackedTwin1DFloatAccumulateModule_basic", + "IndexPutHackedTwin1DFloatNonAccumulateModule_basic", + "IndexPutHackedTwin1DIntAccumulateModule_basic", + "IndexPutHackedTwin1DIntNonAccumulateModule_basic", + "IndexPutHackedTwin2DFloatAccumulateModule_basic", + "IndexPutHackedTwin2DFloatNonAccumulateModule_basic", + "IndexPutHackedTwin2DIntAccumulateModule_basic", + "IndexPutHackedTwin2DIntNonAccumulateModule_basic", + "IndexPutHackedTwin3DFloatAccumulateModule_basic", + "IndexPutHackedTwin3DFloatNonAccumulateModule_basic", + "IndexPutHackedTwin3DIntAccumulateModule_basic", + "IndexPutHackedTwin3DIntNonAccumulateModule_basic", + "IndexPutImpl1DFloatAccumulateModule_basic", + "IndexPutImpl1DFloatNonAccumulateModule_basic", + "IndexPutImpl1DIntAccumulateModule_basic", + "IndexPutImpl1DIntNonAccumulateModule_basic", + 
"IndexPutImpl2DFloatAccumulateModule_basic", + "IndexPutImpl2DFloatNonAccumulateModule_basic", + "IndexPutImpl2DImplicitModule_basic", + "IndexPutImpl2DIndexModule_basic", + "IndexPutImpl2DNoneIndexStaticModule_basic", + "IndexPutImpl3DFloatAccumulateModule_basic", + "IndexPutImpl3DFloatNonAccumulateModule_basic", + "IndexPutImplIndexWithNoneModule_basic", + "IndexSelectDynamicIndexSizeModule_basic", + "IndexSelectDynamicInputSizeModule_basic", + "IndexSelectDynamicModulebasic", + "IndexSelectNegativeDimModule_basic", + "IndexSelectRank0IdxModule_basic", + "IndexSelectSingleIdxModule_basic", + "IndexSelectTwoIdxModule_basic", + "IndexSelectWholeDimensionModule_basic", + "IndexSelectWholeTensorModule_basic", + "IndexTensorDyanmicInputContiguousWithNoneModule_basic", + "IndexTensorDyanmicInputNonContiguousWithNoneModule_basic", + "IndexTensorHackedTwinModule3dInput_basic", + "IndexTensorHackedTwinModule_basic", + "IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims_basic", + "IndexTensorModule3dInput_basic", + "IndexTensorModule_basic", + "IndexTensorMultiIndexStaticModule_basic", + "IndexTensorMultiInputContiguousCenter_basic", + "IndexTensorMultiInputContiguousOneDimDynamic_basic", + "IndexTensorMultiInputNonContiguousDynamic_basic", + "IndexTensorMultiInputNonContiguousMultipleStaticDims_basic", + "IndexTensorMultiInputNonContiguousOneDimDynamic_basic", + "IndexTensorMultiInputNonContiguous_basic", + "IndexTensorMultiInputOneDim_basic", + "IndexTensorMultiInputThreeIndexers_basic", + "IndexTensorMultiInput_basic", + "IndexTensorNegativeIndexModule_basic", + "IndexTensorSelectDimModule_basic", + "IndexTensorStaticContiguousWithNoneModule_basic", + "IndexTensorStaticModule_basic", + "IndexTensorStaticNonContiguousWithNoneModule_basic", + "IntFloatModule_basic", + "IntImplicitModule_basic", + "IouOfModule_basic", + "IsFloatingPointFloat_True", + "IsFloatingPointInt_False", + "IscloseStaticModuleTrue_basic", + "IscloseStaticModule_basic", + "LeakyReluBackwardModule_basic", + "LeakyReluBackwardStaticModule_basic", + "LenStrModule_basic", + "LiftFreshCopyModule_basic", + "LinalgNormKeepDimComplexModule_basic", + "LinalgNormModule_basic", + "LinalgVectorNormComplexModule_basic", + "LinalgVectorNormKeepDimModule_basic", + "LinalgVectorNormModule_basic", + "LogSoftmaxBackwardModule_basic", + "LogSoftmaxIntModule_basic", + "MaskedFillTensorFloatValueModule_basic", + "MatmulBroadcastBatchDim_basic", + "MatmulSingleDynamicBatchDim_basic", + "Matmul_2d", + "Matmul_4d", + "Matmul_matvec", + "Matmul_vecmat", + "MaxPool1dCeilModeTrueModule_basic", + "MaxPool1dEmptyStrideStaticModule_basic", + "MaxPool1dModule_basic", + "MaxPool1dStaticCeilModeTrueModule_basic", + "MaxPool1dStaticModule_basic", + "MaxPool2dCeilModeTrueModule_basic", + "MaxPool2dModule_basic", + "MaxPool2dWithIndicesAllNegativeValuesModule_basic", + "MaxPool2dWithIndicesAllOnesModule_basic", + "MaxPool2dWithIndicesBackwardDynamic3DModule_basic", + "MaxPool2dWithIndicesBackwardDynamic4DModule_basic", + "MaxPool2dWithIndicesBackwardStatic3DModule_basic", + "MaxPool2dWithIndicesBackwardStatic4DModule_basic", + "MaxPool2dWithIndicesCeilModeTrueModule_basic", + "MaxPool2dWithIndicesFullSizeKernelModule_basic", + "MaxPool2dWithIndicesModule_basic", + "MaxPool2dWithIndicesNonDefaultDilationModule_basic", + "MaxPool2dWithIndicesNonDefaultPaddingModule_basic", + "MaxPool2dWithIndicesNonDefaultParamsModule_basic", + "MaxPool2dWithIndicesNonDefaultStrideModule_basic", + "MaxPool2dWithIndicesStaticModule_basic", + 
"MaxPool3dCeilModeTrueModule_basic", + "MaxPool3dEmptyStrideStaticModule_basic", + "MaxPool3dLargeDatadModule_basic", + "MaxPool3dModuleRandomSimple_basic", + "MaxPool3dModule_basic", + "MaxPool3dStaticCeilModeTrueModule_basic", + "MaxPool3dStaticModule_basic", + "MeanDimAllReduceKeepdimModule_basic", + "MeanDimAllReduceModule_basic", + "MeanDimDtypeModule_basic", + "MeanDimEmptyDimModule_basic", + "MeanDimKeepdimModule_basic", + "MeanDimModule_basic", + "MeanDimNegativeModule_basic", + "MeanDimNoneDimModule_basic", + "MeanDtypeModule_basic", + "MeanDynamicSizesModule_basic", + "MeanModule_basic", + "Mlp1LayerModule_basic", + "Mlp2LayerModuleNoBias_basic", + "Mlp2LayerModule_basic", + "MmModule_basic", + "MmModule_chained", + "MmTanhModule_basic", + "MobilenetV3Module_basic", + "MoveDimIntNegativeIndexModule_basic", + "MseLossMeanReductionModule_basic", + "MseLossSumReductionWithDifferentElemTypeModule_basic", + "MulFloatModule_basic", + "MulIntModule_basic", + "Mv_basic", + "NarrowHorizontalTest2_basic", + "NarrowHorizontalTest_basic", + "NarrowTensorHorizontalModule_basic", + "NarrowTensorVerticalModule_basic", + "NarrowVerticalTest2_basic", + "NarrowVerticalTest_basic", + "NativeBatchNorm1DModule_basic", + "NativeBatchNorm2DModule_basic", + "NativeBatchNorm3DModule_basic", + "NativeBatchNormNoneWeightModule_basic", + "NativeDropoutEvalFloatModule_basic", + "NativeDropoutTrainModule_basic", + "NativeDropoutTrainStaticShapeModule_basic", + "NativeGroupNormBackwardModule_basic", + "NativeGroupNormModule_basic", + "NativeLayerNormDynamicModule_basic", + "NativeLayerNormModule4D_basic", + "NativeLayerNormModule_basic", + "NeFloatIntModule_basic", + "NeIntModule_basic", + "NewEmptyStridedModuleDefaultDtype_basic", + "NllLossModuleBackward1DMeanWeight_basic", + "NllLossModuleBackward1DMean_basic", + "NllLossModuleBackward1DSumWeight_basic", + "NllLossModuleBackward1DSum_basic", + "NllLossModuleBackward1DWeight_basic", + "NllLossModuleBackward1D_basic", + "NllLossModuleBackwardMeanWeight_basic", + "NllLossModuleBackwardMean_basic", + "NllLossModuleBackwardSumWeight_basic", + "NllLossModuleBackwardSum_basic", + "NllLossModuleBackwardWeight_basic", + "NllLossModuleBackward_basic", + "NllLossModuleBackward_ignore_index", + "NllLossModule_1D_basic", + "NllLossModule_basic", + "NllLossModule_ignore_index_out_of_bounds_basic", + "NllLossModule_mean_basic", + "NllLossModule_sum_basic", + "NormScalarComplexModule_basic", + "NormScalarModule_basic", + "NormScalarOptDimKeepDimComplexModule_basic", + "NormScalarOptDimKeepDimModule_basic", + "NormScalarOptDimModule_basic", + "NormalFunctionalModule_basic", + "NormalizeModule_basic", + "NumToTensorFloatModule_basic", + "NumToTensorIntModule_basic", + "NumelModule_basic", + "NumelZeroRankModule_basic", + "OneHotModule_basic", + "OnesLikeModule_defaultDtype", + "OnesLikeModule_falsePinMemory", + "OnesLikeModule_float", + "OnesLikeModule_int", + "PadModule_basic", + "PadWithNoneValModule_basic", + "PermuteNegativeIndexModule_basic", + "PixelShuffleModuleFullDynamic_basic", + "PixelShuffleModuleSpatiallyDynamic_basic", + "PixelShuffleModuleSpatiallyStatic_basic", + "PixelShuffleModuleStaticRank3Int64_basic", + "PixelShuffleModuleStaticRank4Float32_basic", + "PowIntFloatModule_basic", + "PrimMaxIntModule_basic", + "PrimMinIntDynamicModule_basic", + "PrimMinIntModule_basic", + "PrimsConvertElementTypeModule_basic", + "PrimsIotaModule_basic", + "PrimsSqueezeEmptyDimensionsModule_basic", + "PrimsSqueezeModule_basic", + "PrimsViewOfModule_basic", + 
"PrimsViewOfZeroRankModule_basic", + "QuantizedBatchedInputSingleLayer_basic", + "QuantizedMLP_basic", + "QuantizedNoLayer_basic", + "QuantizedReluInt32_basic", + "QuantizedReluInt8_basic", + "QuantizedReluUint8_basic", + "QuantizedSingleLayer_basic", + "RandIntDtypeModule_basic", + "RandIntLowDtypeModule_basic", + "RandIntLowModule_basic", + "RandIntModule_basic", + "RandIntPinMemoryModule_basic", + "RandLikeDtypeModule_basic", + "RandLikeModule_basic", + "RandModule_basic", + "RandnDtypeDeviceModule_basic", + "RandnGeneratorF64Module_basic", + "RandnGeneratorModule_basic", + "RandnLikeDtypeModule_basic", + "RandnLikeModule_basic", + "RandnModule_basic", + "ReduceAllBoolModule_basic", + "ReduceAllDimBool_basic", + "ReduceAllDimEmpty_basic", + "ReduceAllDimFloat_basic", + "ReduceAllDimInt_basic", + "ReduceAllFloatModule_basic", + "ReduceAllIntModule_basic", + "ReduceAmaxKeepDim_basic", + "ReduceAmaxMultiDim_basic", + "ReduceAmaxOutOfOrderDim_basic", + "ReduceAmaxSingleDim_basic", + "ReduceAnyBoolModule_basic", + "ReduceAnyFloatModule_basic", + "ReduceAnyIntModule_basic", + "ReduceFrobeniusNormComplexModule_basic", + "ReduceL1NormComplexModule_basic", + "ReduceL1NormModule_basic", + "ReduceL1NormWithDTypeModule_basic", + "ReduceL2NormComplexModule_basic", + "ReduceL2NormModule_basic", + "ReduceL3NormAllDimsModule_basic", + "ReduceL3NormKeepDimComplexModule_basic", + "ReduceL3NormKeepDimModule_basic", + "ReduceLN3NormModule_basic", + "ReduceMaxAllDims_basic", + "ReduceMaxAlongDimNegative_basic", + "ReduceMaxAlongDimSignedInt_basic", + "ReduceMaxAlongDimUnsignedInt_basic", + "ReduceMaxAlongDim_basic", + "ReduceMaxFloatModule_basic", + "ReduceMaxKeepDimReturnBoth_basic", + "ReduceMaxKeepDim_basic", + "ReduceMaxNegativeDim_basic", + "ReduceMaxSignedIntModule_basic", + "ReduceMaxUnsignedIntModule_basic", + "ReduceMinAlongDimNegative_basic", + "ReduceMinAlongDimSignedInt_basic", + "ReduceMinAlongDimUnsignedInt_basic", + "ReduceMinAlongDim_basic", + "ReduceMinFloatModule_basic", + "ReduceMinKeepDimReturnBoth_basic", + "ReduceMinKeepDim_basic", + "ReduceMinSignedIntModule_basic", + "ReduceMinUnsignedIntModule_basic", + "ReduceProdDimIntFloatModule_basic", + "ReduceProdDtypeFloatModule_basic", + "ReduceProdDtypeIntModule_basic", + "ReduceProdElementTypeBoolModule_basic", + "ReduceProdFloatModule_basic", + "ReduceProdSignedIntModule_basic", + "ReduceProdUnsignedIntModule_basic", + "ReduceSumDimIntListDtypeFloatModule_basic", + "ReduceSumDimIntListDtypeIntModule_basic", + "ReduceSumDimIntListElementTypeBoolModule_basic", + "ReduceSumDimIntListEmptyDimModule_basic", + "ReduceSumDtypeFloatModule_basic", + "ReduceSumDtypeIntModule_basic", + "ReduceSumElementTypeBoolModule_basic", + "ReduceSumFloatModule_basic", + "ReduceSumSignedIntModule_basic", + "ReduceSumUnsignedIntModule_basic", + "ReflectionPad1dModule2dInput_Right", + "ReflectionPad1dModule2dInput_basic", + "ReflectionPad1dModule3dInput_Left", + "ReflectionPad1dModule3dInput_basic", + "ReflectionPad2dModule_Bottom", + "ReflectionPad2dModule_Left", + "ReflectionPad2dModule_Right", + "ReflectionPad2dModule_Top", + "ReflectionPad2dModule_basic", + "RepeatModule_basic", + "ReplicationPad2dModule_basic", + "ReplicationPad2dModule_bottom0", + "ReplicationPad2dModule_left0", + "ReplicationPad2dModule_right0", + "ReplicationPad2dModule_top0", + "ResNet18Module_basic", + "ResNet18StaticModule_basic", + "ReshapeAliasCollapseModule_basic", + "ReshapeAliasExpandModule_basic", + "ReshapeCollapseModule_basic", + "ReshapeDynamicModule_basic", + 
"ReshapeExpandModule_basic", + "RollModule_basic", + "RsubIntModule_noalpha_basic", + "ScalarConstantTupleModule_basic", + "ScalarImplicitFloatModule_basic", + "ScalarImplicitIntModule_basic", + "ScaledDotProductAttentionSameModule_basic", + "ScatterReduceFloatMaxModule", + "ScatterReduceFloatMaxModuleIncludeSelf", + "ScatterReduceFloatMeanModule", + "ScatterReduceFloatMeanModuleIncludeSelf", + "ScatterReduceFloatMinModule", + "ScatterReduceFloatMinModuleIncludeSelf", + "ScatterReduceFloatProdModule", + "ScatterReduceFloatProdModuleIncludeSelf", + "ScatterReduceFloatSumModule", + "ScatterReduceFloatSumModuleIncludeSelf", + "ScatterReduceIntMaxModule", + "ScatterReduceIntMaxModuleIncludeSelf", + "ScatterReduceIntMeanModule", + "ScatterReduceIntMeanModuleIncludeSelf", + "ScatterReduceIntMinModule", + "ScatterReduceIntMinModuleIncludeSelf", + "ScatterReduceIntProdModule", + "ScatterReduceIntProdModuleIncludeSelf", + "ScatterReduceIntSumModule", + "ScatterReduceIntSumModuleIncludeSelf", + "ScatterSrcModule_basic", + "ScatterSrcStaticModule_basic", + "ScatterValueFloatModule_basic", + "ScatterValueIntModule_basic", + "SelectIntModule_basic", + "SelectIntNegativeDimAndIndexStaticModule_basic", + "SelectScattertModule_basic", + "SelectScattertStaticModule_basic", + "SliceCopyEndGreaterThanDimSize_Module_basic", + "SliceCopyNegative_Module_basic", + "SliceCopyNonZeroDim_Module_basic", + "SliceCopy_Module_basic", + "SliceEndSleStartModule_basic", + "SliceModule_basic", + "SliceNegIdxModule_basic", + "SliceOutOfLowerBoundEndIndexModule_basic", + "SliceOutOfLowerBoundStartIndexModule_basic", + "SliceOutOfUpperBoundIndexModule_basic", + "SliceScatterModule_basic", + "SliceScatterNegativeDimModule_basic", + "SliceScatterNegativeEndModule_basic", + "SliceScatterStaticModule_basic", + "SliceScatterStepVariationModule_basic", + "SliceScatterZeroDimModule_basic", + "SliceSingleIdxModule_basic", + "SliceSizeTwoStepModule_basic", + "SliceStartEqEndModule_basic", + "SoftmaxBackwardModule_basic", + "SoftmaxIntArgTypeF64Module_basic", + "SoftmaxIntModule_basic", + "SoftmaxIntNegDimModule_basic", + "SoftmaxIntNonNoneDtypeModule_basic", + "SoftplusModule_basic", + "SortIntListReverse_basic", + "SortIntList_basic", + "SortTensorDescending_basic", + "SortTensorInteger_basic", + "SortTensorNegativeDimension_basic", + "SortTensorSpecificDimension_basic", + "SortTensor_basic", + "SplitDimDynamicModule_basic", + "SplitDimStaticModule_basic", + "SplitWithSizes_Module_basic", + "SqrtIntConstantModule_basic", + "SqrtIntModule_basic", + "SqueezeDimModule_dynamic", + "SqueezeDimModule_negDim", + "StdBiasedModule_basic", + "StdCorrectionAllDimReduceModule_basic", + "StdCorrectionEmptyDimModule_basic", + "StdCorrectionKeepDimModule_basic", + "StdCorrectionLargeInputModule_basic", + "StdCorrectionModule_basic", + "StdCorrectionNoneModule_basic", + "StdCorrectionSingleDimReduceModule_basic", + "StdDimBiasedModule_basic", + "StdDimEmptyDimModule_basic", + "StdDimKeepDimFalseModule_basic", + "StdDimKeepDimTrueModule_basic", + "StdDimNoneDimModule_basic", + "StdUnbiasedModule_basic", + "SubFloatModule_basic", + "SubIntModule_basic", + "TanhBackward_basic", + "TensorToBoolZeroRank_basic", + "TensorToBool_basic", + "TensorToFloatZeroRank_basic", + "TensorToFloat_basic", + "TensorToIntZeroRank_basic", + "TensorToInt_basic", + "TensorsConcatModule_basic", + "TensorsConcatNegativeDimModule_basic", + "TensorsConcatPromoteDTypeModule_basic", + "TensorsStackModule_basic", + "TensorsStackNegativeDimModule_basic", + 
"TensorsStackPromoteDTypeModule_basic", + "TensorsStackSingleElementListModule_basic", + "TestMultipleTensorAndPrimitiveTypesReturn_basic", + "Threshold1dFloatModule_basic", + "Threshold1dIntI32Module_basic", + "Threshold1dIntModule_basic", + "Threshold2dFloatModule_basic", + "Threshold2dIntModule_basic", + "Threshold3dFloatModule_basic", + "Threshold3dIntModule_basic", + "ThresholdBackward1dFloatModule_basic", + "ThresholdBackward1dIntModule_basic", + "ThresholdBackward1dMixedModule_basic", + "ThresholdBackward2dFloatModule_basic", + "ThresholdBackward2dIntModule_basic", + "ThresholdBackward2dMixedModule_basic", + "ThresholdBackward3dFloatModule_basic", + "ThresholdBackward3dIntModule_basic", + "ThresholdBackward3dMixedModule_basic", + "TileBigDimsSizeModule_basic", + "TileSmallDimsSizeModule_basic", + "ToCopyBoolDTypeStaticModule_basic", + "ToCopyModule_basic", + "ToCopyWithDTypeFalsePinMemoryModule_basic", + "ToCopyWithDTypeModule_basic", + "ToDtypeLayoutCPUModule_basic", + "ToDtypeLayoutNoneModule_basic", + "ToDtypeLayoutStridedModule_basic", + "TorchPrimLoopForLikeModule_basic", + "TorchPrimLoopWhileLikeModule_basic", + "TraceModule_basic", + "TraceModule_empty", + "TraceModule_nonsquare", + "TraceSignedIntModule_basic", + "TraceUnsignedIntModule_basic", + "TraceUnsignedIntModule_empty", + "TriuBroadcastModule_basic", + "TriuModule_basic", + "TupleModule_basic", + "TypeAsDifferentModule_basic", + "TypeConversionF32ToF64Module_basic", + "TypeConversionF64ToF32Module_basic", + "TypeConversionI1ToF32Module_basic", + "TypeConversionI1ToF64Module_basic", + "TypeConversionI1ToI32Module_basic", + "TypeConversionI1ToI64Module_basic", + "TypeConversionI32ToI64Module_basic", + "TypeConversionI64ToI32Module_basic", + "TypePromotionDifferentCategoryModule_basic", + "TypePromotionSameCategoryDifferentWidthModule_basic", + "TypePromotionZeroRankHigherCategoryModule_basic", + "UnflattenIntNegativeOneDimStaticModule_basic", + "UnflattenIntNegativeOneSizeStaticModule_basic", + "UnflattenIntStaticModule_basic", + "UnflattenStaticModule_basic", + "UniformModule_basic", + "UniformNoCorrelationModule_basic", + "UniformStaticShapeModule_basic", + "UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule_basic", + "UnsafeView1DFoldModule_basic", + "UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic", + "UnsafeViewCollapseModule_basic", + "UnsafeViewDynamicExpandModule_basic", + "UnsafeViewDynamicExpandWithAtenSizeIntModule_basic", + "UnsafeViewExpandModule_basic", + "UpSampleNearest2dBackwardScalesNone_basic", + "UpSampleNearest2dBackward_basic", + "UpSampleNearest2dDynamicFactor_basic", + "UpSampleNearest2dDynamicSize_basic", + "UpSampleNearest2dStaticFactor_basic", + "UpSampleNearest2dStaticSize_basic", + "UpSampleNearest2d_basic", + "VarBiasedModule_basic", + "VarCorrectionAllDimReduceModule_basic", + "VarCorrectionEmptyDimModule_basic", + "VarCorrectionKeepDimModule_basic", + "VarCorrectionLargeInputModule_basic", + "VarCorrectionModule_basic", + "VarCorrectionNoneModule_basic", + "VarCorrectionSingleDimReduceModule_basic", + "VarDimAllDimReduceModule_basic", + "VarDimBiasedModule_basic", + "VarDimEmptyDimModule_basic", + "VarDimModule_basic", + "VarDimMultiDimModule_basic", + "VarDimNegativeModule_basic", + "VarDimNoneDimModule_basic", + "VarDimSingleDimModule_basic", + "VarDimUnbiasedModule_basic", + "VarMeanBiasedModule_basic", + "VarMeanCorrectionModule_basic", + "VarMeanCorrectionNoneModule_basic", + "VarMeanDimBiasedModule_basic", + "VarMeanDimModule_basic", + "VarMeanUnbiasedModule_basic", + 
"VarUnbiasedModule_basic", + "View1DFoldModule_basic", + "ViewCollapseDynamicWithAtenSizeIntModule_basic", + "ViewCollapseModule_basic", + "ViewDynamicExpandCollapseModule_basic", + "ViewDynamicExpandCollapseWithAtenIntModule_basic", + "ViewDynamicExpandCollapseWithParallelUnknownDimModule_basic", + "ViewDynamicExpandModule_basic", + "ViewDynamicExpandWithAtenSizeIntModule_basic", + "ViewExpandDynamicDimModule_basic", + "ViewFlattenAndExpandModule_basic", + "ViewNoChange1dModule_basic", + "ViewNoChange2dModule_basic", + "ViewNoChange3dModule_basic", + "ViewSizeDimFollowedByCollapsedOnesModule_basic", + "ViewSizeDimFollowedByExpandedOnesModule_basic", + "ViewSizeDimLedAndFollowedByCollapsedOnesModule_basic", + "ViewSizeDimLedAndFollowedByExpandedOnesModule_basic", + "ViewSizeDimLedByCollapsedOnesModule_basic", + "ViewSizeDimLedByExpandedOnesModule_basic", + "ViewSizeFromOtherTensor_basic", + "ZeroFloat32Module_basic", + "ZeroInt32Module_basic", + "ZeroInt64Module_basic", + "ZerosLikeModule_defaultDtype", + "ZerosLikeModule_falsePinMemory", + "ZerosLikeModule_float", + "ZerosLikeModule_int", + "_Convolution2DAllFalseModule_basic", + "_Convolution2DBenchmarkModule_basic", + "_Convolution2DCudnnModule_basic", + "_Convolution2DDeterministicModule_basic", + "_Convolution2DTF32Module_basic", + "_ConvolutionDeprecated2DAllFalseModule_basic", + "_ConvolutionDeprecated2DBenchmarkModule_basic", + "_ConvolutionDeprecated2DCudnnModule_basic", + "_ConvolutionDeprecated2DDeterministicModule_basic", + "_LogSoftmaxModule_basic", + "_SoftmaxModule_basic", +} diff --git a/projects/pt1/python/torch_mlir/torchscript.py b/projects/pt1/python/torch_mlir/torchscript.py index 3a8a37348a96..8693b6f9b20f 100644 --- a/projects/pt1/python/torch_mlir/torchscript.py +++ b/projects/pt1/python/torch_mlir/torchscript.py @@ -217,7 +217,7 @@ def _get_for_tracing( "aten.adaptive_avg_pool2d", "aten.unflatten.int", ], - OutputType.STABLEHLO: [], + OutputType.STABLEHLO: ["aten.amax"], } diff --git a/projects/pt1/python/torch_mlir_e2e_test/configs/onnx_backend.py b/projects/pt1/python/torch_mlir_e2e_test/configs/onnx_backend.py index 5402c7243e00..de39475b0dbb 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/configs/onnx_backend.py +++ b/projects/pt1/python/torch_mlir_e2e_test/configs/onnx_backend.py @@ -11,7 +11,6 @@ import torch import torch_mlir -from torch_mlir_e2e_test.onnx_backends.abc import OnnxBackend from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem from torch_mlir_e2e_test.utils import convert_annotations_to_placeholders from .utils import ( @@ -22,6 +21,20 @@ from torch_mlir.extras import onnx_importer from torch_mlir.dialects import torch as torch_d from torch_mlir.ir import Context, Module +from torch_mlir.compiler_utils import ( + OutputType, + run_pipeline_with_repro_report, + lower_mlir_module, +) + +# The pipeline of func.func passes that lower the ONNX backend contract to the +# Linalg-on-Tensors backend contract accepted by RefBackend or another user +# defined backend. 
+ONNX_TO_TORCH_FUNC_PIPELINE = ",".join(
+    [
+        "convert-torch-onnx-to-torch",
+    ]
+)


 def import_onnx(contents):
@@ -71,6 +84,33 @@ def convert_onnx(model, inputs):
     return import_onnx(buffer)

+
+def _module_lowering(
+    verbose,
+    output_type,
+    torch_mod,
+):
+    # Lower from ONNX to Torch
+    run_pipeline_with_repro_report(
+        torch_mod,
+        f"builtin.module(func.func({ONNX_TO_TORCH_FUNC_PIPELINE}))",
+        "Lowering ONNX backend contract to Torch backend contract",
+    )
+
+    backend_legal_ops = [
+        "aten.flatten.using_ints",
+        "aten.adaptive_avg_pool1d",
+        "aten.unflatten.int",
+    ]
+    option_string = "{backend-legal-ops=" + ",".join(backend_legal_ops) + "}"
+    run_pipeline_with_repro_report(
+        torch_mod,
+        f"builtin.module(torch-lower-to-backend-contract{option_string})",
+        "Lowering Torch IR -> Torch Backend IR",
+    )
+
+    return lower_mlir_module(verbose, output_type, torch_mod)
+
+
 class OnnxBackendTestConfig(TestConfig):
     """Base class for TestConfig's that are implemented with ONNX.
@@ -78,15 +118,24 @@ class OnnxBackendTestConfig(TestConfig):
     reaching the ONNX abstraction level.
     """

-    def __init__(self, backend: OnnxBackend, use_make_fx: bool = False):
+    def __init__(
+        self,
+        backend,
+        use_make_fx: bool = False,
+        output_type="linalg-on-tensors",
+    ):
         super().__init__()
         self.backend = backend
         self.use_make_fx = use_make_fx
+        self.output_type = output_type

-    def compile(self, program: torch.nn.Module) -> Any:
+    def compile(self, program: torch.nn.Module, verbose: bool = False) -> Any:
         example_args = convert_annotations_to_placeholders(program.forward)
         onnx_module = convert_onnx(program, example_args)
-        compiled_module = self.backend.compile(onnx_module)
+        backend_module = _module_lowering(
+            verbose, OutputType.get(self.output_type), onnx_module
+        )
+        compiled_module = self.backend.compile(backend_module)
         return compiled_module

     def run(self, artifact: Any, trace: Trace) -> Trace:
diff --git a/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py b/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py
index 8935a2a060fd..0179dd369893 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py
@@ -138,8 +138,6 @@ def invoke(*args):
     "builtin.module("
     + ",".join(
         [
-            "func.func(refback-generalize-tensor-pad)",
-            "func.func(refback-generalize-tensor-concat)",
             # Apply some optimizations. It would be great if MLIR had more useful
             # optimizations that worked out of the box here.
             # Note: When measured, this doesn't seem to actually help that much
@@ -157,6 +155,10 @@ def invoke(*args):
             "sparse-storage-specifier-to-llvm",
             # Buffer deallocation pass does not know how to handle realloc.
             "func.func(expand-realloc)",
+            # Generalize pad and concat after the sparse compiler, as they are
+            # handled differently when the operations involve sparse operands.
+            "func.func(refback-generalize-tensor-pad)",
+            "func.func(refback-generalize-tensor-concat)",
             # Bufferize.
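+            # (Note: bufferization rewrites the tensor-level program into
+            # explicit memref buffers; the scf and tm-tensor passes below each
+            # bufferize the ops of their own dialect.)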
"func.func(scf-bufferize)", "func.func(tm-tensor-bufferize)", diff --git a/projects/pt1/python/torch_mlir_e2e_test/onnx_backends/__init__.py b/projects/pt1/python/torch_mlir_e2e_test/onnx_backends/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/projects/pt1/python/torch_mlir_e2e_test/onnx_backends/abc.py b/projects/pt1/python/torch_mlir_e2e_test/onnx_backends/abc.py deleted file mode 100644 index 7e12f8b15d7d..000000000000 --- a/projects/pt1/python/torch_mlir_e2e_test/onnx_backends/abc.py +++ /dev/null @@ -1,50 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -# Also available under a BSD-style license. See LICENSE. - -import abc -from typing import TypeVar - -import torch - -from torch_mlir.ir import Module - -# A type shared between the result of `OnnxBackend.compile` and the -# input to `OnnxBackend.load`. Each backend will likely have a -# different definition of this type. -CompiledArtifact = TypeVar("CompiledArtifact") - -# A wrapper around a backend-specific loaded program representation -# that uniformly translates the `x.method(...)` interface expected of -# Torch modules into appropriate lower-level operations. -Invoker = TypeVar("Invoker") - - -class OnnxBackend(abc.ABC): - """The interface to an ONNX backend. - - Backends are recommended to raise meaningful exceptions in case of error, - ideally with easy reproduction instructions. - """ - - @abc.abstractmethod - def compile(self, module: Module) -> CompiledArtifact: - """Compile the provided MLIR module into a compiled artifact. - - The module adheres to the ONNX backend contract - (see the VerifyOnnxBackendContract pass). - - The compiled artifact can be any type, but must be correctly - interpreted by the `load` method. - """ - - @abc.abstractmethod - def load(self, artifact: CompiledArtifact) -> Invoker: - """Load the compiled artifact into a uniformly invokable form. - - The compiled artifact is the result of a previous call to `compile`. - - See the description of `Invoker` for the requirements on the returned - type. - """ diff --git a/projects/pt1/python/torch_mlir_e2e_test/onnx_backends/linalg_on_tensors.py b/projects/pt1/python/torch_mlir_e2e_test/onnx_backends/linalg_on_tensors.py deleted file mode 100644 index 30129c7510ef..000000000000 --- a/projects/pt1/python/torch_mlir_e2e_test/onnx_backends/linalg_on_tensors.py +++ /dev/null @@ -1,80 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -# Also available under a BSD-style license. See LICENSE. - - -from torch_mlir.compiler_utils import ( - run_pipeline_with_repro_report, - lower_mlir_module, - OutputType, -) -from torch_mlir.ir import * -from torch_mlir.passmanager import * - -from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import ( - RefBackendLinalgOnTensorsBackend, -) - -from .abc import OnnxBackend - -__all__ = [ - "LinalgOnTensorsOnnxBackend", -] - -# The pipeline of func.func passes that lower the ONNX backend contract to the -# Linalg-on-Tensors backend contract accepted by RefBackend. -ONNX_TO_TORCH_FUNC_PIPELINE = ",".join( - [ - "convert-torch-onnx-to-torch", - ] -) - - -class LinalgOnTensorsOnnxBackend(OnnxBackend): - """Main entry-point for the linalg-on-tensors based ONNX backend. 
diff --git a/projects/pt1/python/torch_mlir_e2e_test/onnx_backends/linalg_on_tensors.py b/projects/pt1/python/torch_mlir_e2e_test/onnx_backends/linalg_on_tensors.py
deleted file mode 100644
index 30129c7510ef..000000000000
--- a/projects/pt1/python/torch_mlir_e2e_test/onnx_backends/linalg_on_tensors.py
+++ /dev/null
@@ -1,80 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-# Also available under a BSD-style license. See LICENSE.
-
-
-from torch_mlir.compiler_utils import (
-    run_pipeline_with_repro_report,
-    lower_mlir_module,
-    OutputType,
-)
-from torch_mlir.ir import *
-from torch_mlir.passmanager import *
-
-from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import (
-    RefBackendLinalgOnTensorsBackend,
-)
-
-from .abc import OnnxBackend
-
-__all__ = [
-    "LinalgOnTensorsOnnxBackend",
-]
-
-# The pipeline of func.func passes that lower the ONNX backend contract to the
-# Linalg-on-Tensors backend contract accepted by RefBackend.
-ONNX_TO_TORCH_FUNC_PIPELINE = ",".join(
-    [
-        "convert-torch-onnx-to-torch",
-    ]
-)
-
-
-class LinalgOnTensorsOnnxBackend(OnnxBackend):
-    """Main entry-point for the linalg-on-tensors based ONNX backend.
-
-    This currently uses the linalg-on-tensors RefBackend for actual execution.
-    """
-
-    def __init__(self):
-        super().__init__()
-        self.refbackend = RefBackendLinalgOnTensorsBackend()
-
-    def compile(self, imported_module: Module):
-        """Compiles an imported module that satisfied the ONNX backend contract.
-
-        Args:
-          imported_module: The MLIR module consisting of ONNX operations wrapped by
-          torch.operator.
-        Returns:
-          An opaque, backend specific compiled artifact object that can be
-          passed to `load`.
-        """
-        run_pipeline_with_repro_report(
-            imported_module,
-            f"builtin.module(func.func({ONNX_TO_TORCH_FUNC_PIPELINE}))",
-            "Lowering Onnx backend contract to Linalg-on-Tensors backend contract",
-        )
-
-        backend_legal_ops = [
-            "aten.flatten.using_ints",
-            "aten.adaptive_avg_pool1d",
-            "aten.unflatten.int",
-        ]
-        option_string = "{backend-legal-ops=" + ",".join(backend_legal_ops) + "}"
-        run_pipeline_with_repro_report(
-            imported_module,
-            f"builtin.module(torch-lower-to-backend-contract{option_string})",
-            "Lowering TorchFX IR -> Torch Backend IR",
-        )
-
-        imported_module = lower_mlir_module(
-            False, OutputType.LINALG_ON_TENSORS, imported_module
-        )
-        compiled_module = self.refbackend.compile(imported_module)
-        return compiled_module
-
-    def load(self, module):
-        """Loads a compiled artifact into the runtime."""
-        return self.refbackend.load(module)
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/scatter.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/scatter.py
index a9ce270c2533..9c473a7934cb 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/scatter.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/scatter.py
@@ -12,6 +12,31 @@
 # ==============================================================================


+class MaskedScatterStaticBasic(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([4, 4], torch.float32, True),
+            ([4, 4], torch.bool, True),
+            ([8, 8], torch.float32, True),
+        ]
+    )
+    def forward(self, x, mask, y):
+        return torch.masked_scatter(x, mask, y)
+
+
+@register_test_case(module_factory=lambda: MaskedScatterStaticBasic())
+def MaskedScatterStaticBasic_basic(module, tu: TestUtils):
+    x = torch.rand(4, 4)
+    mask = torch.rand(4, 4) > 0.5
+    y = torch.rand(8, 8)
+    module.forward(x, mask, y)
+
+
 class IndexPutImpl1DFloatNonAccumulateModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
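Note on the new MaskedScatterStaticBasic test above: torch.masked_scatter fills the True positions of `mask` with consecutive elements taken from the source tensor in row-major order, so the 8x8 source only needs to supply at least mask.sum() elements. A small sketch of the semantics, independent of the e2e harness:

    import torch

    x = torch.zeros(2, 2)
    mask = torch.tensor([[True, False], [False, True]])
    src = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
    # True positions are filled from src.flatten() in order:
    print(torch.masked_scatter(x, mask, src))
    # tensor([[1., 0.],
    #         [0., 2.]])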
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
index c0f93864f9ee..4214d3f222a1 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
@@ -743,6 +743,42 @@ func.func @test_globalaveragepool_precomputed(%arg0: !torch.vtensor<[1,1,3,3],f3

 // -----

+// CHECK-LABEL: @test_globalmaxpool
+func.func @test_globalmaxpool(%arg0: !torch.vtensor<[1,3,5,5],f32>) -> !torch.vtensor<[1,3,1,1],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 1 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[C0:.*]] = torch.constant.int 0
+  // CHECK: %[[C1:.*]] = torch.constant.int 1
+  // CHECK: %[[C5:.*]] = torch.constant.int 5
+  // CHECK: %[[C5_0:.*]] = torch.constant.int 5
+  // CHECK: %[[KERNELSIZE:.*]] = torch.prim.ListConstruct %[[C5]], %[[C5_0]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[DILATION:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[FALSE:.*]] = torch.constant.bool false
+  // CHECK: torch.aten.max_pool2d %arg0, %[[KERNELSIZE]], %[[STRIDE]], %[[PADDING]], %[[DILATION]], %[[FALSE]] : !torch.vtensor<[1,3,5,5],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[1,3,1,1],f32>
+  %0 = torch.operator "onnx.GlobalMaxPool"(%arg0) : (!torch.vtensor<[1,3,5,5],f32>) -> !torch.vtensor<[1,3,1,1],f32>
+  return %0 : !torch.vtensor<[1,3,1,1],f32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_globalmaxpool_precomputed
+func.func @test_globalmaxpool_precomputed(%arg0: !torch.vtensor<[1,1,3,3],f32>) -> !torch.vtensor<[1,1,1,1],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 1 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[C0:.*]] = torch.constant.int 0
+  // CHECK: %[[C1:.*]] = torch.constant.int 1
+  // CHECK: %[[C3:.*]] = torch.constant.int 3
+  // CHECK: %[[C3_0:.*]] = torch.constant.int 3
+  // CHECK: %[[KERNELSIZE:.*]] = torch.prim.ListConstruct %[[C3]], %[[C3_0]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[DILATION:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[FALSE:.*]] = torch.constant.bool false
+  // CHECK: torch.aten.max_pool2d %arg0, %[[KERNELSIZE]], %[[STRIDE]], %[[PADDING]], %[[DILATION]], %[[FALSE]] : !torch.vtensor<[1,1,3,3],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[1,1,1,1],f32>
+  %0 = torch.operator "onnx.GlobalMaxPool"(%arg0) : (!torch.vtensor<[1,1,3,3],f32>) -> !torch.vtensor<[1,1,1,1],f32>
+  return %0 : !torch.vtensor<[1,1,1,1],f32>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @test_max_example
 func.func @test_max_example(%arg0: !torch.vtensor<[3],f32>, %arg1: !torch.vtensor<[3],f32>, %arg2: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: torch.aten.maximum %arg0, %arg1 : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
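The GlobalMaxPool tests above pin down the lowering strategy: onnx.GlobalMaxPool on an N x C x H x W input becomes aten.max_pool2d with kernel_size equal to the full spatial extent, stride 1, zero padding, and dilation 1, so each channel collapses to its single maximum. In PyTorch terms, a sketch of the equivalence (not code from this patch):

    import torch
    import torch.nn.functional as F

    x = torch.rand(1, 3, 5, 5)
    pooled = F.max_pool2d(x, kernel_size=x.shape[2:])  # kernel = (H, W)
    assert pooled.shape == (1, 3, 1, 1)
    assert torch.equal(pooled, x.amax(dim=(2, 3), keepdim=True))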
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
index c517ffe31dcf..1b629f4b98eb 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir
@@ -911,6 +911,103 @@ func.func @test_reduce_log_sum_do_not_keepdims_example(%arg0:!torch.vtensor<[3,2

 // -----

+// CHECK-LABEL: func.func @test_reduce_log_sum_exp_default_axes_keepdims_example
+func.func @test_reduce_log_sum_exp_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[INT7:.+]] = torch.constant.int 7
+  // CHECK: %[[NONE_0:.+]] = torch.constant.none
+  // CHECK: %[[FALSE:.+]] = torch.constant.bool false
+  // CHECK: %[[CAST:.+]] = torch.aten.to.dtype %arg0, %[[INT7]], %[[FALSE]], %[[FALSE]], %[[NONE_0]] : !torch.vtensor<[3,2,2],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2,2],f64>
+  // CHECK: %[[EXP:.+]] = torch.aten.exp %[[CAST]] : !torch.vtensor<[3,2,2],f64> -> !torch.vtensor<[3,2,2],f64>
+  // CHECK: %[[INT0:.+]] = torch.constant.int 0
+  // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct : () -> !torch.list<int>
+  // CHECK: %[[TRUE:.+]] = torch.constant.bool true
+  // CHECK: %[[NONE_1:.+]] = torch.constant.none
+  // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[EXP]], %[[DIMS]], %[[TRUE]], %[[NONE_1]] : !torch.vtensor<[3,2,2],f64>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f64>
+  // CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[1,1,1],f64> -> !torch.vtensor<[1,1,1],f64>
+  // CHECK: %[[INT6:.+]] = torch.constant.int 6
+  // CHECK: %[[CASTLIKE:.+]] = torch.aten.to.dtype %[[LOG]], %[[INT6]], %[[FALSE]], %[[FALSE]], %[[NONE_0]] : !torch.vtensor<[1,1,1],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32>
+  // CHECK: return %[[CASTLIKE]] : !torch.vtensor<[1,1,1],f32>
+  %0 = torch.operator "onnx.ReduceLogSumExp"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32>
+  return %0 : !torch.vtensor<[1,1,1],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_reduce_log_sum_exp_do_not_keepdims_example_expanded
+func.func @test_reduce_log_sum_exp_do_not_keepdims_example_expanded(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[INT7:.+]] = torch.constant.int 7
+  // CHECK: %[[NONE_0:.+]] = torch.constant.none
+  // CHECK: %[[FALSE_0:.+]] = torch.constant.bool false
+  // CHECK: %[[CAST:.+]] = torch.aten.to.dtype %arg0, %[[INT7]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_0]] : !torch.vtensor<[3,2,2],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2,2],f64>
+  // CHECK: %[[EXP:.+]] = torch.aten.exp %[[CAST]] : !torch.vtensor<[3,2,2],f64> -> !torch.vtensor<[3,2,2],f64>
+  // CHECK: %[[INT0:.+]] = torch.constant.int 0
+  // CHECK: %[[INT0_0:.+]] = torch.constant.int 0
+  // CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
+  // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int
+  // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list<int>
+  // CHECK: %[[FALSE_1:.+]] = torch.constant.bool false
+  // CHECK: %[[NONE_1:.+]] = torch.constant.none
+  // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[EXP]], %[[DIMS]], %[[FALSE_1]], %[[NONE_1]] : !torch.vtensor<[3,2,2],f64>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,2],f64>
+  // CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[3,2],f64> -> !torch.vtensor<[3,2],f64>
+  // CHECK: %[[INT6:.+]] = torch.constant.int 6
+  // CHECK: %[[CASTLIKE:.+]] = torch.aten.to.dtype %[[LOG]], %[[INT6]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_0]] : !torch.vtensor<[3,2],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2],f32>
+  // CHECK: return %[[CASTLIKE]] : !torch.vtensor<[3,2],f32>
+  %0 = torch.operator "onnx.ReduceLogSumExp"(%arg0, %arg1) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32>
+  return %0 : !torch.vtensor<[3,2],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_reduce_log_sum_exp_keep_dims_example
+func.func @test_reduce_log_sum_exp_keep_dims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[INT7:.+]] = torch.constant.int 7
+  // CHECK: %[[NONE_0:.+]] = torch.constant.none
+  // CHECK: %[[FALSE:.+]] = torch.constant.bool false
+  // CHECK: %[[CAST:.+]] = torch.aten.to.dtype %arg0, %[[INT7]], %[[FALSE]], %[[FALSE]], %[[NONE_0]] : !torch.vtensor<[3,2,2],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2,2],f64>
+  // CHECK: %[[EXP:.+]] = torch.aten.exp %[[CAST]] : !torch.vtensor<[3,2,2],f64> -> !torch.vtensor<[3,2,2],f64>
+  // CHECK: %[[INT0:.+]] = torch.constant.int 0
+  // CHECK: %[[INT0_0:.+]] = torch.constant.int 0
+  // CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
+  // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int
+  // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list<int>
+  // CHECK: %[[TRUE:.+]] = torch.constant.bool true
+  // CHECK: %[[NONE_1:.+]] = torch.constant.none
+  // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[EXP]], %[[DIMS]], %[[TRUE]], %[[NONE_1]] : !torch.vtensor<[3,2,2],f64>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,2,1],f64>
+  // CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[3,2,1],f64> -> !torch.vtensor<[3,2,1],f64>
+  // CHECK: %[[INT6:.+]] = torch.constant.int 6
+  // CHECK: %[[CASTLIKE:.+]] = torch.aten.to.dtype %[[LOG]], %[[INT6]], %[[FALSE]], %[[FALSE]], %[[NONE_0]] : !torch.vtensor<[3,2,1],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2,1],f32>
+  // CHECK: return %[[CASTLIKE]] : !torch.vtensor<[3,2,1],f32>
+  %0 = torch.operator "onnx.ReduceLogSumExp"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32>
+  return %0 : !torch.vtensor<[3,2,1],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_reduce_log_sum_exp_keep_dims_int_input_example
+func.func @test_reduce_log_sum_exp_keep_dims_int_input_example(%arg0: !torch.vtensor<[3,2,2],si64>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[INT7:.+]] = torch.constant.int 7
+  // CHECK: %[[NONE_0:.+]] = torch.constant.none
+  // CHECK: %[[FALSE:.+]] = torch.constant.bool false
+  // CHECK: %[[CAST:.+]] = torch.aten.to.dtype %arg0, %[[INT7]], %[[FALSE]], %[[FALSE]], %[[NONE_0]] : !torch.vtensor<[3,2,2],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2,2],f64>
+  // CHECK: %[[EXP:.+]] = torch.aten.exp %[[CAST]] : !torch.vtensor<[3,2,2],f64> -> !torch.vtensor<[3,2,2],f64>
+  // CHECK: %[[INT0:.+]] = torch.constant.int 0
+  // CHECK: %[[INT0_0:.+]] = torch.constant.int 0
+  // CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
+  // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int
+  // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list<int>
+  // CHECK: %[[TRUE:.+]] = torch.constant.bool true
+  // CHECK: %[[NONE_1:.+]] = torch.constant.none
+  // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[EXP]], %[[DIMS]], %[[TRUE]], %[[NONE_1]] : !torch.vtensor<[3,2,2],f64>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[3,2,1],f64>
+  // CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[3,2,1],f64> -> !torch.vtensor<[3,2,1],f64>
+  // CHECK: %[[INT6:.+]] = torch.constant.int 6
+  // CHECK: %[[CASTLIKE:.+]] = torch.aten.to.dtype %[[LOG]], %[[INT6]], %[[FALSE]], %[[FALSE]], %[[NONE_0]] : !torch.vtensor<[3,2,1],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2,1],f32>
+  // CHECK: return %[[CASTLIKE]] : !torch.vtensor<[3,2,1],f32>
+  %0 = torch.operator "onnx.ReduceLogSumExp"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32>
+  return %0 : !torch.vtensor<[3,2,1],f32>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @test_reduce_sum_default_axes_keepdims_example
 func.func @test_reduce_sum_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: %[[INT0:.+]] = torch.constant.int 0
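All four ReduceLogSumExp tests above check the same decomposition: cast the input to f64, exponentiate, sum over the requested axes (honoring keepdims), take the log, and cast back to the original dtype; the f64 round-trip buys headroom against overflow in the intermediate exp/sum. The expected math, sketched with plain PyTorch (not code from this patch):

    import torch

    x = torch.rand(3, 2, 2)
    ref = torch.logsumexp(x, dim=1, keepdim=True)
    # Decomposition the lowering emits: cast up, exp, sum, log, cast back.
    dec = x.to(torch.float64).exp().sum(dim=1, keepdim=True).log().to(torch.float32)
    assert torch.allclose(ref, dec)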
@@ -2012,6 +2109,45 @@ func.func @test_random_uniform_like(%arg0: !torch.vtensor<[10],f32>) -> !torch.v

 // -----

+// CHECK-LABEL: func.func @test_sequence_construct_3
+module {
+  func.func @test_sequence_construct_3(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[2,3,4],f32>, %arg2: !torch.vtensor<[2,3,4],f32>) -> !torch.list<vtensor<[2,3,4],f32>> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+// CHECK: %[[SEQ:.+]] = torch.prim.ListConstruct %arg0, %arg1, %arg2 : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list<vtensor<[2,3,4],f32>>
+// CHECK: return %[[SEQ]] : !torch.list<vtensor<[2,3,4],f32>>
+    %0 = torch.operator "onnx.SequenceConstruct"(%arg0, %arg1, %arg2) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,3,4],f32>) -> !torch.list<vtensor<[2,3,4],f32>>
+    return %0 : !torch.list<vtensor<[2,3,4],f32>>
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_sequence_construct_1
+module {
+  func.func @test_sequence_construct_1(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.list<vtensor<[2,3,4],f32>> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+// CHECK: %[[SEQ:.+]] = torch.prim.ListConstruct %arg0 : (!torch.vtensor<[2,3,4],f32>) -> !torch.list<vtensor<[2,3,4],f32>>
+// CHECK: return %[[SEQ]] : !torch.list<vtensor<[2,3,4],f32>>
+    %0 = torch.operator "onnx.SequenceConstruct"(%arg0) : (!torch.vtensor<[2,3,4],f32>) -> !torch.list<vtensor<[2,3,4],f32>>
+    return %0 : !torch.list<vtensor<[2,3,4],f32>>
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_sequence_length
+module {
+  func.func @test_sequence_length(%arg0: !torch.list>) -> !torch.vtensor<[],si64> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+// CHECK: %[[FALSE:.+]] = torch.constant.bool false
+// CHECK: %[[NONE:.+]] = torch.constant.none
+// CHECK: %[[LEN:.+]] = torch.aten.len.t %arg0 : !torch.list> -> !torch.int
+// CHECK: %[[LEN_AS_TEN:.+]] = torch.aten.tensor.int %[[LEN]], %[[NONE]], %[[NONE]], %[[FALSE]] : !torch.int, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[],si64>
+// CHECK: return %[[LEN_AS_TEN]] : !torch.vtensor<[],si64>
+    %0 = torch.operator "onnx.SequenceLength"(%arg0) : (!torch.list>) -> !torch.vtensor<[],si64>
+    return %0 : !torch.vtensor<[],si64>
+  }
+}
+
+// -----
+
 // CHECK-LABEL: func.func @test_sce_mean_3d
 func.func @test_sce_mean_3d(%arg0: !torch.vtensor<[3,5,2],f32>, %arg1: !torch.vtensor<[3,2],si64>) -> !torch.vtensor<[],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: %[[NONE:.+]] = torch.constant.none
diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir
index e7605f661698..180b6aac5dd3 100644
--- a/test/Dialect/Torch/canonicalize.mlir
+++ b/test/Dialect/Torch/canonicalize.mlir
@@ -3015,3 +3015,14 @@ func.func @torch.aten.max_pool2d_with_indices$canonicalize(%arg0: !torch.vtensor
   %result0, %result1 = torch.aten.max_pool2d_with_indices %arg0, %29, %30, %31, %31, %false : !torch.vtensor<[10,64,112,112],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool -> !torch.vtensor<[10,64,56,56],f32>, !torch.vtensor<[10,64,56,56],si64>
   return %result0 : !torch.vtensor<[10,64,56,56],f32>
 }
+
+// -----
+
+// CHECK-LABEL: @torch.aten.clone$no_fold(
+func.func @torch.aten.clone$no_fold(%arg0: !torch.vtensor<[1,2,50,4],f32>) -> (!torch.tensor) {
+  // CHECK: %{{.*}} = torch.aten.clone %{{.*}}, %{{.*}} : !torch.vtensor<[1,2,50,4],f32>, !torch.none -> !torch.vtensor
+  %none = torch.constant.none
+  %0 = torch.aten.clone %arg0, %none : !torch.vtensor<[1,2,50,4],f32>, !torch.none -> !torch.vtensor
+  %1 = torch.copy.to_tensor %0 : !torch.tensor
+  return %1 : !torch.tensor
+}
diff --git a/test/python/fx_importer/sparse_test.py b/test/python/fx_importer/sparse_test.py
index 474fe2bfddbc..87d2e3d96d0e 100644
--- a/test/python/fx_importer/sparse_test.py
+++ b/test/python/fx_importer/sparse_test.py
@@ -134,6 +134,16 @@ def sparse_export(
         # elif opname == "_to_dense":
         #     # hack (assumes we never really want the to_dense for now)
         #     node.meta["sparsity"] = node.args[0].meta.get("sparsity", None)
+        elif opname == "select" and node.args[0].meta.get("sparsity", None):
+            dim = len(node.meta.get("val").shape)
+            node.meta["sparsity"] = SparsityMeta(
+                torch.sparse_coo, 0, dim, 0, None, torch.int64, torch.int64
+            )
+        elif opname == "stack" and node.args[0][0].meta.get("sparsity", None):
+            dim = len(node.meta.get("val").shape)
+            node.meta["sparsity"] = SparsityMeta(
+                torch.sparse_coo, 0, dim - 1, 1, None, torch.int64, torch.int64
+            )
     return prog
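The new `select`/`stack` branches above propagate COO sparsity metadata through the FX graph. For intuition about the buffers involved (and the positions/indices/values arrays the updated coo3 CHECK lines below match against), a COO tensor is just a coordinate matrix plus a value vector; a small sketch, not code from this patch:

    import torch

    a = torch.tensor([[0.0, 1.5], [2.0, 0.0]]).to_sparse_coo().coalesce()
    print(a.indices())  # 2 x nnz matrix of coordinates
    # tensor([[0, 1],
    #         [1, 0]])
    print(a.values())   # nnz values: tensor([1.5000, 2.0000])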
@@ -459,6 +469,11 @@ def forward(self, x):
 # CHECK:   values=tensor([ 0., 0., 1., 2., 3., 1000.]),
 # CHECK:   size=(10, 20, 30), nnz=6, dtype=torch.float64, layout=torch.sparse_coo)
 # CHECK: torch.mlir
+# CHECK:   [0 6]
+# CHECK:   [0 1 1 4 9 9]
+# CHECK:   [ 0 1 1 5 19 19]
+# CHECK:   [ 0 1 3 6 28 29]
+# CHECK:   [ 0. 0. 1. 2. 3. 1000.]
 #
 def test_sparse_coo3():
     class COO3Net(torch.nn.Module):
@@ -481,11 +496,15 @@ def forward(self, x):

     # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
     res1 = net(sparse_input)
-    # TODO: make coo3 work
-    # res2 = sparse_jit(net, sparse_input)
+    res2 = sparse_jit(net, sparse_input)
     print("torch.sparse")
     print(res1)
     print("torch.mlir")
+    print(res2[0])
+    print(res2[1])
+    print(res2[2])
+    print(res2[3])
+    print(res2[4])


 @run
@@ -574,8 +593,8 @@ def forward(self, X):
         for t in range(T):
             mem = mem * self.decay + X[..., t]
             spike = self.act(mem - self.thresh)
-            mem = mem * (1.0 - spike)
             spike = spike.to_sparse().to_dense()  # prop hack
+            mem = mem * (1.0 - spike)
             spike_pot.append(spike)
         spike_pot = torch.stack(spike_pot, dim=-1)
         return spike_pot
@@ -621,3 +640,47 @@ def forward(self, X):
     print(res1)
     print("torch.mlir")
     print(res2)
+
+
+@run
+#
+# CHECK-LABEL: test_sparse_feature_scaling
+# CHECK: func.func @main(
+# CHECK-SAME:   %[[A:.*]]: !torch.vtensor<[4,4],f32>) -> !torch.vtensor<[4,4],f32> {
+#       ... more IR ...
+# CHECK:   %[[D:.*]] = torch.operator "torch.aten._to_sparse"
+# CHECK:   %[[R:.*]] = torch.aten.mm %[[D]], %[[A]]
+# CHECK:   return %[[R]] : !torch.vtensor<[4,4],f32>
+# CHECK: }
+#
+# CHECK: torch.sparse
+# CHECK:   tensor({{\[}}[0.3342, 0.5173, 0.0596, 0.0889],
+# CHECK:     [0.1321, 0.2724, 0.2105, 0.3851],
+# CHECK:     [0.2478, 0.3439, 0.1898, 0.2185],
+# CHECK:     [0.0222, 0.1683, 0.2928, 0.5167]{{\]}})
+# CHECK: torch.mlir
+#
+def test_sparse_feature_scaling():
+    class Scale(nn.Module):
+        def forward(self, F):
+            sum_vector = torch.sum(F, dim=1)
+            reciprocal_vector = 1 / sum_vector
+            reciprocal_vector[reciprocal_vector == float("inf")] = 0
+            scaling_diagonal = torch.diag(reciprocal_vector).to_sparse()
+            return scaling_diagonal @ F
+
+    net = Scale()
+
+    # Get a random (but reproducible) features input.
+    torch.manual_seed(0)
+    f = torch.rand(4, 4)
+    m = export_and_import(net, f)
+    print(m)
+
+    # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit.
+    res1 = net(f)
+    # TODO: make this work
+    # res2 = sparse_jit(net, f)
+    print("torch.sparse")
+    print(res1)
+    print("torch.mlir")
diff --git a/utils/bazel/torch-mlir-overlay/BUILD.bazel b/utils/bazel/torch-mlir-overlay/BUILD.bazel
index e62780ff9634..d21d1acad337 100644
--- a/utils/bazel/torch-mlir-overlay/BUILD.bazel
+++ b/utils/bazel/torch-mlir-overlay/BUILD.bazel
@@ -90,6 +90,7 @@ gentbl_cc_library(
 cc_library(
     name = "TorchMLIRTorchDialectUtils",
     srcs = [
+        "lib/Dialect/Torch/Utils/SparsityUtils.cpp",
         "lib/Dialect/Torch/Utils/TorchUpstream.cpp",
         "lib/Dialect/Torch/Utils/Utils.cpp",
     ],
@@ -97,6 +98,7 @@ cc_library(
         "include/torch-mlir/Dialect/Torch/IR/TorchOps.h",
         "include/torch-mlir/Dialect/Torch/IR/TorchTraits.h",
         "include/torch-mlir/Dialect/Torch/IR/TorchTypes.h",
+        "include/torch-mlir/Dialect/Torch/Utils/SparsityUtils.h",
         "include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h",
         "include/torch-mlir/Dialect/Torch/Utils/Utils.h",
     ],
@@ -108,6 +110,8 @@ cc_library(
         "@llvm-project//mlir:ControlFlowInterfaces",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:InferTypeOpInterface",
+        "@llvm-project//mlir:SparseTensorDialect",
+        "@llvm-project//mlir:SparseTensorEnums",
     ],
 )
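For reference, the row scaling that test_sparse_feature_scaling verifies can be written densely: each row of F is divided by its own sum, with empty rows mapped to zero rather than inf. A sketch of the same arithmetic without the sparse diagonal (illustrative only, not part of the test):

    import torch

    torch.manual_seed(0)
    f = torch.rand(4, 4)
    s = f.sum(dim=1)
    r = torch.where(s == 0, torch.zeros_like(s), 1.0 / s)
    scaled = torch.diag(r) @ f  # each row of `scaled` now sums to 1 (or 0)
    assert torch.allclose(scaled.sum(dim=1), torch.ones(4))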