diff --git a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp
index b0a8d5e90fc9..ea37606aefe9 100644
--- a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp
+++ b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp
@@ -675,12 +675,12 @@ static Value collapseAndMoveBatchDims(Location loc, Value values, int64_t batch,
   return b.create<AtenViewOp>(loc, valuesTy, values, outDimsList);
 }
 
-class ConvertAten_IndexPutImplOp
-    : public OpConversionPattern<Aten_IndexPutImplOp> {
+class ConvertAtenIndexPutHackedTwinOp
+    : public OpConversionPattern<AtenIndexPutHackedTwinOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(Aten_IndexPutImplOp op, OpAdaptor adaptor,
+  matchAndRewrite(AtenIndexPutHackedTwinOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
       return failure();
@@ -699,17 +699,6 @@ class ConvertAten_IndexPutImplOp
       return rewriter.notifyMatchFailure(
           op, "unimplemented: the values tensor type must have sizes.");
 
-    // The unsafe should be either `False` or `none`.
-    if (!op.getUnsafe().getType().isa<Torch::NoneType>()) {
-      bool unsafe;
-      if (!matchPattern(op.getUnsafe(), m_TorchConstantBool(&unsafe)))
-        return rewriter.notifyMatchFailure(
-            op, "unimplemented: unsafe must be a constant");
-      else if (unsafe)
-        return rewriter.notifyMatchFailure(
-            op, "unimplemented: unsafe is expected to be false");
-    }
-
     // The accumulate should be a torch constant of boolean type.
     bool accumulate;
     if (!matchPattern(op.getAccumulate(), m_TorchConstantBool(&accumulate)))
@@ -1624,8 +1613,8 @@ class ConvertTorchToTMTensor
     RewritePatternSet patterns(context);
     target.addIllegalOp<AtenBincountOp>();
     patterns.add<ConvertAtenBincountOp>(typeConverter, context);
-    target.addIllegalOp<Aten_IndexPutImplOp>();
-    patterns.add<ConvertAten_IndexPutImplOp>(typeConverter, context);
+    target.addIllegalOp<AtenIndexPutHackedTwinOp>();
+    patterns.add<ConvertAtenIndexPutHackedTwinOp>(typeConverter, context);
     target.addIllegalOp<AtenMaxPool2dWithIndicesBackwardOp>();
     patterns.add<ConvertAtenMaxPool2dWithIndicesBackwardOp>(typeConverter,
                                                             context);
diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index 467cd227ad92..a570d4a93e0e 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -3970,8 +3970,8 @@ class SimplifyAten_IndexPutImplOp
 // Handle Aten_IndexPutImplOp on 1d tensors
 template <>
-LogicalResult ConvertAtenOp<Aten_IndexPutImplOp>::matchAndRewrite(
-    Aten_IndexPutImplOp op, OpAdaptor adaptor,
+LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
+    AtenIndexPutHackedTwinOp op, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
   // TOSA scatter:
   //
   //   Copy the values_in tensor to the values_out tensor.
@@ -6227,7 +6227,7 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
     INSERT_ATENOP_PATTERN(AtenSliceTensorOp);
     INSERT_ATENOP_PATTERN(AtenBroadcastToOp);
     INSERT_ATENOP_PATTERN(AtenGatherOp);
-    INSERT_ATENOP_PATTERN(Aten_IndexPutImplOp);
+    INSERT_ATENOP_PATTERN(AtenIndexPutHackedTwinOp);
    INSERT_ATENOP_PATTERN(AtenIndexTensorHackedTwinOp);
     INSERT_ATENOP_PATTERN(AtenAbsOp);
     INSERT_ATENOP_PATTERN(AtenWhereSelfOp);
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 83ad93c5e879..315264333b3a 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -5523,23 +5523,6 @@ class DecomposeAtenNewFullOp : public OpRewritePattern<AtenNewFullOp> {
 };
 } // namespace
 
-namespace {
-// Decompose `aten.index_put` op into `valsem.aten.indexPutImpl` op.
-class DecomposeAtenIndexPutOp : public OpRewritePattern<AtenIndexPutOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(AtenIndexPutOp op,
-                                PatternRewriter &rewriter) const override {
-    Value cstFalse = rewriter.create<ConstantBoolOp>(op.getLoc(), false);
-    rewriter.replaceOpWithNewOp<Aten_IndexPutImplOp>(
-        op, op.getType(), op.getSelf(), op.getIndices(), op.getValues(),
-        op.getAccumulate(),
-        /*unsafe=*/cstFalse);
-    return success();
-  }
-};
-} // namespace
-
 namespace {
 class DecomposeAtenExpandAsOp : public OpRewritePattern<AtenExpandAsOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -5635,61 +5618,6 @@ class DecomposeAtenNewEmptyOp : public OpRewritePattern<AtenNewEmptyOp> {
 };
 } // namespace
 
-namespace {
-// Decompose `aten.index_put.hacked_twin` op into `valsem.aten.indexPutImpl`
-// op.
-class DecomposeAtenIndexPutHackedTwinOp
-    : public OpRewritePattern<AtenIndexPutHackedTwinOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(AtenIndexPutHackedTwinOp op,
-                                PatternRewriter &rewriter) const override {
-    Value cstFalse = rewriter.create<ConstantBoolOp>(op.getLoc(), false);
-    rewriter.replaceOpWithNewOp<Aten_IndexPutImplOp>(
-        op, op.getType(), op.getSelf(), op.getIndices(), op.getValues(),
-        op.getAccumulate(),
-        /*unsafe=*/cstFalse);
-    return success();
-  }
-};
-} // namespace
-
-namespace {
-// Decompose `aten._index_put_impl_.hacked_twin` op into `aten._index_put_impl`
-// op.
-class DecomposeAten_IndexPutImpl_HackedTwinOp
-    : public OpRewritePattern<Aten_IndexPutImpl_HackedTwinOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(Aten_IndexPutImpl_HackedTwinOp op,
-                                PatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<Aten_IndexPutImplOp>(
-        op, op.getType(), op.getSelf(), op.getIndices(), op.getValues(),
-        op.getAccumulate(), op.getUnsafe());
-    return success();
-  }
-};
-} // namespace
-
-namespace {
-// Decompose `aten._unsafe_index_put.hacked_twin` op into `aten._index_put_impl`
-// op.
-class DecomposeAten_UnsafeIndexPutHackedTwinOp
-    : public OpRewritePattern<Aten_UnsafeIndexPutHackedTwinOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(Aten_UnsafeIndexPutHackedTwinOp op,
-                                PatternRewriter &rewriter) const override {
-    Value cstFalse = rewriter.create<ConstantBoolOp>(op.getLoc(), false);
-    rewriter.replaceOpWithNewOp<Aten_IndexPutImplOp>(
-        op, op.getType(), op.getSelf(), op.getIndices(), op.getValues(),
-        op.getAccumulate(),
-        /*unsafe=*/cstFalse);
-    return success();
-  }
-};
-} // namespace
-
 namespace {
 // Decompose `aten.pad` op into `aten.constant_pad_nd` op.
 class DecomposeAtenPadOp : public OpRewritePattern<AtenPadOp> {
@@ -7514,65 +7442,138 @@ class DecomposeAtenMaxPool2dWithIndicesOp
 };
 } // namespace
 
-// AtenIndexTensorOp
+// Torch ops related to indexing tensors, e.g., AtenIndexTensor, AtenIndexPut.
 namespace {
-// The goal of this pattern is to eliminate none index in aten.index.Tensor's
-// `indices` param for the ease of various backend. The detailed steps are:
-// 1. reorder input tensor so that the non-none index appears at adjacent
-// positions.
-// 2. manually generate index tensor with some ops like iota, to replace the
-// none index in `indices`
-// 3. replace the old aten.index.Tensor with a new
-// aten.index.Tensor_hacked_twin.
-class DecomposeAtenIndexTensorOp : public OpRewritePattern<AtenIndexTensorOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  // TODO: It might be better to use aten.view op instead of mulitple
-  // aten.unsqueeze. But currently, torch-to-linalg pass has limited support for
-  // view on dynamic shapes, such as [?] -> [?,1,1,1]. Using aten.view op will
-  // cause relevant e2e tests fail.
-  static FailureOr<Value>
-  unsqueezeTensorAtTrailingDim(Operation *op, PatternRewriter &rewriter,
-                               Value input, int count) {
-    Location loc = op->getLoc();
-    Value constMinusOne = rewriter.create<Torch::ConstantIntOp>(
-        loc, rewriter.getI64IntegerAttr(-1));
-    Value result = input;
-    while (count--) {
-      auto unsqzTensorInfo =
-          unsqueezeTensor(rewriter, op, result, /*dim=*/constMinusOne);
-      if (failed(unsqzTensorInfo)) {
-        return failure();
-      }
-      result = *unsqzTensorInfo;
+// unsqueeze is more easily optimized than a generic view, and we prefer to
+// enjoy ops with more structure than less in compositions.
+static FailureOr<Value> unsqueezeTensorAtTrailingDim(Operation *op,
+                                                     PatternRewriter &rewriter,
+                                                     Value input, int count) {
+  Location loc = op->getLoc();
+  Value constMinusOne = rewriter.create<Torch::ConstantIntOp>(
+      loc, rewriter.getI64IntegerAttr(-1));
+  Value result = input;
+  while (count--) {
+    auto unsqzTensorInfo =
+        unsqueezeTensor(rewriter, op, result, /*dim=*/constMinusOne);
+    if (failed(unsqzTensorInfo)) {
+      return failure();
     }
-    return result;
+
+    result = *unsqzTensorInfo;
   }
+  return result;
+}
 
-  static Value createIndexToReplaceNone(Operation *op,
-                                        PatternRewriter &rewriter, Value input,
-                                        int dimInt, int64_t dimSize) {
-    Location loc = op->getLoc();
-    MLIRContext *context = op->getContext();
-    Value none = rewriter.create<ConstantNoneOp>(loc);
-    auto int64Dtype = getDtypeIntValueForType(
-        rewriter, loc,
-        rewriter.getIntegerType(/*width=*/64, /*isSigned=*/true));
+static Value createIndexToReplaceNone(Operation *op, PatternRewriter &rewriter,
+                                      Value input, int dimInt,
+                                      int64_t dimSize) {
+  Location loc = op->getLoc();
+  MLIRContext *context = op->getContext();
+  Value none = rewriter.create<ConstantNoneOp>(loc);
+  auto int64Dtype = getDtypeIntValueForType(
+      rewriter, loc, rewriter.getIntegerType(/*width=*/64, /*isSigned=*/true));
+
+  auto resultType = ValueTensorType::get(
+      context, {dimSize},
+      rewriter.getIntegerType(/*width=*/64, /*isSigned=*/true));
+  auto dim = rewriter.create<Torch::ConstantIntOp>(
+      loc, rewriter.getI64IntegerAttr(dimInt));
+  auto end = rewriter.create<AtenSizeIntOp>(loc, input, dim);
+  auto v = rewriter.create<AtenArangeOp>(
+      loc, resultType, /*end=*/end, /*dtype=*/int64Dtype, /*layout=*/none,
+      /*device=*/none, /*pin_memory=*/none);
+  return v;
+}
 
-    auto resultType = ValueTensorType::get(
-        context, {dimSize},
-        rewriter.getIntegerType(/*width=*/64, /*isSigned=*/true));
-    auto dim = rewriter.create<Torch::ConstantIntOp>(
-        loc, rewriter.getI64IntegerAttr(dimInt));
-    auto end = rewriter.create<AtenSizeIntOp>(loc, input, dim);
-    auto v = rewriter.create<AtenArangeOp>(
-        loc, resultType, /*end=*/end, /*dtype=*/int64Dtype, /*layout=*/none,
-        /*device=*/none, /*pin_memory=*/none);
-    return v;
-  }
+static FailureOr<Value> createNewIndices(Operation *op,
+                                         PatternRewriter &rewriter, Value input,
+                                         llvm::ArrayRef<Value> oldIndices,
+                                         llvm::ArrayRef<int64_t> newToOldDimMap,
+                                         llvm::ArrayRef<bool> oldIndexUsed) {
+  Location loc = op->getLoc();
+  MLIRContext *context = op->getContext();
+
+  auto inputType = input.getType().cast<BaseTensorType>();
+  if (!inputType.hasSizes()) {
+    return failure();
+  }
+  auto inputSizes = inputType.getSizes();
+  int64_t inputRank = inputSizes.size();
+
+  int64_t maxIndexRank = 0;
+  for (auto index : oldIndices) {
+    auto indexType = index.getType().dyn_cast<BaseTensorType>();
+    if (!indexType) // None index
+      continue;
+    if (!indexType.hasSizes())
+      return failure();
+    int64_t indexRank = indexType.getSizes().size();
+    maxIndexRank = maxIndexRank > indexRank ? maxIndexRank : indexRank;
+  }
+
+  // manually generate new indices.
+  SmallVector<Value> listElements(inputRank);
+
+  int64_t noneIndexCnt = 0;
+  int64_t i;
+  // handle trailing none indices.
+  for (i = inputRank - 1; i >= 0; --i) {
+    int64_t oldI = newToOldDimMap[i];
+    if (oldIndexUsed[oldI])
+      break;
+    Value v = createIndexToReplaceNone(op, rewriter, input, i, inputSizes[i]);
+    auto vInfo = unsqueezeTensorAtTrailingDim(op, rewriter, v, noneIndexCnt);
+    if (failed(vInfo)) {
+      return failure();
+    }
+    listElements[i] = *vInfo;
+    noneIndexCnt++;
+  }
+  // handle non-none index in between.
+  for (; i >= 0; --i) {
+    int64_t oldI = newToOldDimMap[i];
+    if (!oldIndexUsed[oldI])
+      break;
+    auto vInfo = unsqueezeTensorAtTrailingDim(op, rewriter, oldIndices[oldI],
+                                              noneIndexCnt);
+    if (failed(vInfo)) {
+      return failure();
+    }
+    listElements[i] = *vInfo;
+  }
+
+  // handle possible leading none indices.
+  for (; i >= 0; --i) {
+    int64_t oldI = newToOldDimMap[i];
+    if (oldIndexUsed[oldI]) {
+      return failure();
+    }
+    Value v = createIndexToReplaceNone(op, rewriter, input, i, inputSizes[i]);
+    auto vInfo = unsqueezeTensorAtTrailingDim(op, rewriter, v,
+                                              noneIndexCnt + maxIndexRank);
+    if (failed(vInfo)) {
+      return failure();
+    }
+    listElements[i] = *vInfo;
+    noneIndexCnt++;
+  }
+
+  auto listElemType = ValueTensorType::get(context, std::nullopt, nullptr);
+  Value newIndexList = rewriter.create<PrimListConstructOp>(
+      loc, Torch::ListType::get(listElemType), listElements);
+
+  return newIndexList;
+}
+
+// The goal of this pattern is to eliminate `None` index in aten.index.Tensor's
+// `indices` param and transform it to aten.index.Tensor_hacked_twin, for the
+// ease of various backends.
+class DecomposeAtenIndexTensorOp : public OpRewritePattern<AtenIndexTensorOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
   LogicalResult matchAndRewrite(AtenIndexTensorOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
@@ -7590,12 +7591,6 @@ class DecomposeAtenIndexTensorOp : public OpRewritePattern<AtenIndexTensorOp> {
     }
     auto inputSizes = inputType.getSizes();
     int64_t inputRank = inputSizes.size();
-    auto outputType = cast<BaseTensorType>(op.getType());
-    if (!outputType.hasSizes()) {
-      return rewriter.notifyMatchFailure(
-          op, "only output with shape information is supported");
-    }
-    auto outputRank = outputType.getSizes().size();
 
     auto isTensor = [](Value v) {
       return v.getType().isa<BaseTensorType>();
@@ -7603,19 +7598,15 @@ class DecomposeAtenIndexTensorOp : public OpRewritePattern<AtenIndexTensorOp> {
 
     // directly replace aten.index.Tensor with aten.index.Tensor_hacked_twin
     if (llvm::all_of(indices, isTensor)) {
-      if (indices.size() == 0) {
-        return rewriter.notifyMatchFailure(
-            op, "the indices is empty, it should be folded as a nop");
-      }
       // By default, we regard the first index type as the list element type.
       auto indexElemType = indices[0]
                                .getType()
                                .template cast<BaseTensorType>()
                                .getWithSizesAndDtype(std::nullopt, nullptr);
-      auto newIndex = rewriter.create<PrimListConstructOp>(
+      auto newIndices = rewriter.create<PrimListConstructOp>(
           loc, Torch::ListType::get(indexElemType), indices);
-      rewriter.replaceOpWithNewOp<AtenIndexTensorHackedTwinOp>(op, op.getType(),
-                                                               input, newIndex);
+      rewriter.replaceOpWithNewOp<AtenIndexTensorHackedTwinOp>(
+          op, op.getType(), input, newIndices);
       return success();
     }
 
@@ -7623,6 +7614,7 @@ class DecomposeAtenIndexTensorOp : public OpRewritePattern<AtenIndexTensorOp> {
         llvm::to_vector(llvm::map_range(indices, isTensor));
     for (int64_t i = indices.size(); i < inputRank; ++i)
       indexUsed.emplace_back(false);
+
     bool indexIsConsecutive = true;
     int64_t firstUsedIndex = -1;
     for (size_t i = 0; i < indices.size(); ++i) {
@@ -7634,17 +7626,15 @@ class DecomposeAtenIndexTensorOp : public OpRewritePattern<AtenIndexTensorOp> {
       }
     }
 
-    // use aten.permute to reorder the input
     Value newInput;
-    // `dims` stores the mapping from new index to the old index of input
-    // tensor.
-    SmallVector<int64_t> dims;
+    SmallVector<int64_t> newToOldDimMap;
+    // permute input to make the non-none indices consecutive.
     if (!indexIsConsecutive) {
       SmallVector<Value> dimValues;
       SmallVector<int64_t> permutedSizes;
       for (int i = 0; i < inputRank; i++) {
         if (indexUsed[i]) {
-          dims.emplace_back(i);
+          newToOldDimMap.emplace_back(i);
           dimValues.emplace_back(rewriter.create<Torch::ConstantIntOp>(
               loc, rewriter.getI64IntegerAttr(i)));
           permutedSizes.emplace_back(inputSizes[i]);
@@ -7652,7 +7642,7 @@ class DecomposeAtenIndexTensorOp : public OpRewritePattern<AtenIndexTensorOp> {
       for (int i = 0; i < inputRank; i++) {
         if (!indexUsed[i]) {
-          dims.emplace_back(i);
+          newToOldDimMap.emplace_back(i);
           dimValues.emplace_back(rewriter.create<Torch::ConstantIntOp>(
               loc, rewriter.getI64IntegerAttr(i)));
           permutedSizes.emplace_back(inputSizes[i]);
@@ -7668,66 +7658,100 @@ class DecomposeAtenIndexTensorOp : public OpRewritePattern<AtenIndexTensorOp> {
     } else {
       newInput = input;
       for (int i = 0; i < inputRank; i++) {
-        dims.emplace_back(i);
+        newToOldDimMap.emplace_back(i);
       }
     }
 
-    // manually generate new indices.
-    SmallVector<Value> listElements(inputRank);
+    auto newIndicesInfo = createNewIndices(op, rewriter, newInput, indices,
+                                           newToOldDimMap, indexUsed);
+    if (failed(newIndicesInfo)) {
+      return rewriter.notifyMatchFailure(op, "failed to replace `None` index");
+    }
+    rewriter.replaceOpWithNewOp<AtenIndexTensorHackedTwinOp>(
+        op, op.getType(), newInput, *newIndicesInfo);
+    return success();
+  }
+};
 
-    int64_t trailingDimCnt = 0;
-    int64_t i;
-    // handle trailing none index.
-    for (i = inputRank - 1; i >= 0; --i) {
-      int64_t oldI = dims[i];
-      if (indexUsed[oldI])
-        break;
-      Value v =
-          createIndexToReplaceNone(op, rewriter, newInput, i, inputSizes[oldI]);
-      auto vInfo =
-          unsqueezeTensorAtTrailingDim(op, rewriter, v, trailingDimCnt);
-      if (failed(vInfo)) {
-        return rewriter.notifyMatchFailure(op, "failed to unsqueeze tensor");
-      }
-      listElements[i] = *vInfo;
-      trailingDimCnt++;
+
+// The goal of this pattern is to eliminate `None` index in aten.index_put-like
+// ops' `indices` param and transform it to aten.index_put.hacked_twin, for the
+// ease of various backends.
+template <typename AtenIndexPutLikeOpT>
+class DecomposeAtenIndexPutLikeOp
+    : public OpRewritePattern<AtenIndexPutLikeOpT> {
+public:
+  using OpRewritePattern<AtenIndexPutLikeOpT>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AtenIndexPutLikeOpT op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    SmallVector<Value> indices;
+    if (!getListConstructElements(op.getIndices(), indices))
+      return rewriter.notifyMatchFailure(op,
+                                         "failed to get elements of `indices`");
+
+    auto input = op.getSelf();
+    auto inputType = input.getType().template cast<BaseTensorType>();
+    if (!inputType.hasSizes()) {
+      return rewriter.notifyMatchFailure(
+          op, "only input with shape information is supported");
+    }
+    auto inputSizes = inputType.getSizes();
+    int64_t inputRank = inputSizes.size();
+
+    auto isTensor = [](Value v) {
+      return v.getType().isa<BaseTensorType>();
+    };
+
+    // directly replace current op with aten.index_put.hacked_twin
+    if (llvm::all_of(indices, isTensor)) {
+      // By default, we regard the first index type as the list element type.
+      auto indexElemType = indices[0]
+                               .getType()
+                               .template cast<BaseTensorType>()
+                               .getWithSizesAndDtype(std::nullopt, nullptr);
+      auto newIndex = rewriter.create<PrimListConstructOp>(
+          loc, Torch::ListType::get(indexElemType), indices);
+      rewriter.replaceOpWithNewOp<AtenIndexPutHackedTwinOp>(
+          op, op.getType(), input, newIndex, op.getValues(),
+          op.getAccumulate());
+      return success();
     }
-    // handle non-none index in between.
-    for (; i >= 0; --i) {
-      int64_t oldI = dims[i];
-      if (!indexUsed[oldI])
-        break;
-      auto vInfo = unsqueezeTensorAtTrailingDim(op, rewriter, indices[oldI],
-                                                trailingDimCnt);
-      if (failed(vInfo)) {
-        return rewriter.notifyMatchFailure(op, "failed to unsqueeze tensor");
-      }
-      listElements[i] = *vInfo;
+
+    SmallVector<bool> indexUsed =
+        llvm::to_vector(llvm::map_range(indices, isTensor));
+    for (int64_t i = indices.size(); i < inputRank; ++i)
+      indexUsed.emplace_back(false);
+
+    // check if non-None index is consecutive
+    bool indexIsConsecutive = true;
+    int64_t firstUsedIndex = -1;
+    for (size_t i = 0; i < indices.size(); ++i) {
+      if (indexUsed[i] && firstUsedIndex == -1) {
+        firstUsedIndex = i;
+      } else if (indexUsed[i] && !indexUsed[i - 1]) {
+        indexIsConsecutive = false;
+        break;
+      }
     }
+    if (!indexIsConsecutive) {
+      return rewriter.notifyMatchFailure(
+          op, "non consecutive indices is not supported");
+    }
 
-    // handle possible leading none dimensions.
-    for (; i >= 0; --i) {
-      int64_t oldI = dims[i];
-      if (indexUsed[oldI]) {
-        return rewriter.notifyMatchFailure(
-            op, "the indices are still unconsecutive after reordering input "
-                "tensor");
-      }
-      Value v =
-          createIndexToReplaceNone(op, rewriter, newInput, i, inputSizes[oldI]);
-      auto vInfo =
-          unsqueezeTensorAtTrailingDim(op, rewriter, v, outputRank - 1 - i);
-      if (failed(vInfo)) {
-        return rewriter.notifyMatchFailure(op, "failed to unsqueeze tensor");
-      }
-      listElements[i] = *vInfo;
+    SmallVector<int64_t> newToOldDimMap;
+    for (int i = 0; i < inputRank; i++) {
+      newToOldDimMap.emplace_back(i);
     }
 
-    auto listElemType = ValueTensorType::get(context, std::nullopt, nullptr);
-    auto newIndexList = rewriter.create<PrimListConstructOp>(
-        loc, Torch::ListType::get(listElemType), listElements);
-    rewriter.replaceOpWithNewOp<AtenIndexTensorHackedTwinOp>(
-        op, op.getType(), newInput, newIndexList);
+    auto newIndicesInfo = createNewIndices(op, rewriter, input, indices,
+                                           newToOldDimMap, indexUsed);
+    if (failed(newIndicesInfo)) {
+      return rewriter.notifyMatchFailure(op, "failed to replace `None` index");
+    }
+    rewriter.replaceOpWithNewOp<AtenIndexPutHackedTwinOp>(
+        op, op.getType(), input, *newIndicesInfo, op.getValues(),
+        op.getAccumulate());
     return success();
   }
 };
@@ -8020,18 +8044,19 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
-    addPatternIfTargetOpIsIllegal<DecomposeAtenIndexPutOp>(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
-    addPatternIfTargetOpIsIllegal<DecomposeAtenIndexPutHackedTwinOp>(patterns);
-    addPatternIfTargetOpIsIllegal<DecomposeAten_IndexPutImpl_HackedTwinOp>(
-        patterns);
-    addPatternIfTargetOpIsIllegal<DecomposeAten_UnsafeIndexPutHackedTwinOp>(
+    addPatternIfTargetOpIsIllegal<DecomposeAtenIndexTensorOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenIndexPutLikeOp<AtenIndexPutOp>>(
         patterns);
+    addPatternIfTargetOpIsIllegal<
+        DecomposeAtenIndexPutLikeOp<Aten_IndexPutImplOp>>(patterns);
+    addPatternIfTargetOpIsIllegal<
+        DecomposeAtenIndexPutLikeOp<Aten_UnsafeIndexPutHackedTwinOp>>(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
@@ -8106,7 +8131,6 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
-    addPatternIfTargetOpIsIllegal<DecomposeAtenIndexTensorOp>(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     addPatternIfTargetOpIsIllegal(patterns);
     // More specific conv ops
diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
index cceee7e82dd1..7eedb2c6053e 100644
--- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
+++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -466,13 +466,14 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
-  target.addIllegalOp<AtenIndexPutOp>();
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
-  target.addIllegalOp<AtenIndexPutHackedTwinOp>();
+  target.addIllegalOp<AtenIndexPutOp>();
+  target.addIllegalOp<Aten_IndexPutImplOp>();
+  target.addIllegalOp<Aten_UnsafeIndexPutHackedTwinOp>();
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
@@ -500,7 +501,6 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
-  target.addIllegalOp<Aten_IndexPutImpl_HackedTwinOp>();
   target.addIllegalOp();
   target.addIllegalOp();
   target.addIllegalOp();
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 70de9c28c200..dbf1b6959605 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -1831,8 +1831,6 @@
     "IndexPutHackedTwin1DIntNonAccumulateModule_basic",
     "IndexPutImpl1DFloatNonAccumulateModule_basic",
     "IndexPutImpl1DIntNonAccumulateModule_basic",
-    "IndexPutImpl2DNoneIndexBroadcastStaticModule_basic",
-    "IndexPutImpl2DNoneIndexStaticModule_basic",
     "IndexTensorModule3dInputStatic_basic",
     "IndexTensorMultiIndexStaticModule_basic",
     "IndexTensorStaticModule_basic",
@@ -2128,6 +2126,8 @@
     "ElementwisePreluModule_basic",
     "ElementwisePreluStaticModule_basic",
     "ElementwiseLogSigmoidModule_basic",
+    "IndexPutImpl1DFloatNonAccumulateModule_basic",
+    "IndexPutImpl1DIntNonAccumulateModule_basic",
     # It appears that you're trying to get value out of a tracing tensor
     "PrimListUnpackNumMismatchModule_basic",
     # RuntimeError: shape '[2, -1, 6]' is invalid for input of size 210