From 7b94ced39af3b43029b165b30107b8a813735717 Mon Sep 17 00:00:00 2001
From: Yuanqiang Liu
Date: Fri, 13 Sep 2024 18:48:41 +0800
Subject: [PATCH 1/2] [Stablehlo] fix aten compare ops' promote rules (#3709)

Follow-up to the previous PR (https://github.com/llvm/torch-mlir/pull/3702):
fix the promote rules for aten compare ops when the operands mix int/float
tensors and scalars.
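The compute-type rules this patch implements are meant to line up with
PyTorch's own promotion behavior, so they can be sanity-checked against
`torch.result_type`. A minimal sketch (illustrative only, not part of the
patch; the expected dtypes follow the comments in the diff below):

```python
import torch

x_int = torch.ones(3, dtype=torch.int32)
y_f16 = torch.ones(3, dtype=torch.float16)
y_f64 = torch.ones(3, dtype=torch.float64)

# torch.lt(x_int, 1.1): int tensor vs. float scalar -> compute in fp32
assert torch.result_type(x_int, 1.1) == torch.float32
# torch.lt(x_int, y_float): int tensor vs. float tensor -> y's float type
assert torch.result_type(x_int, y_f16) == torch.float16
# torch.lt(x_float, 1): float tensor vs. int scalar -> x's float type
assert torch.result_type(y_f16, 1) == torch.float16
# torch.lt(x_int, 1): int tensor vs. int scalar -> x's int type
assert torch.result_type(x_int, 1) == torch.int32
# torch.lt(x_float, y_float): two float tensors -> the wider bitwidth
assert torch.result_type(y_f16, y_f64) == torch.float64
```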
---
 lib/Conversion/TorchToStablehlo/Basic.cpp | 25 ++++++++++++++++++-------
 projects/pt1/e2e_testing/xfail_sets.py   |  1 -
 2 files changed, 18 insertions(+), 8 deletions(-)

diff --git a/lib/Conversion/TorchToStablehlo/Basic.cpp b/lib/Conversion/TorchToStablehlo/Basic.cpp
index 1f21a1afe8d6..ab4e284f8b2d 100644
--- a/lib/Conversion/TorchToStablehlo/Basic.cpp
+++ b/lib/Conversion/TorchToStablehlo/Basic.cpp
@@ -516,13 +516,12 @@ class ConvertAtenCompareOp : public OpConversionPattern<AtenOpT> {
     if (!lhsTy) {
       return op.emitError("only Tensor types supported in StableHLO");
     }
+    bool isRhsScalar = false;
     if (!rhsTy) {
       rhs = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getOther(),
                                          rhs.getType());
-      // use lhs's element type as the compute type
-      rhs =
-          hlo::promoteType(rewriter, op.getLoc(), rhs, lhsTy.getElementType());
       rhsTy = dyn_cast<RankedTensorType>(rhs.getType());
+      isRhsScalar = true;
     }
 
     auto outType = cast<RankedTensorType>(
@@ -537,16 +536,28 @@
     }
 
     if (isa<mlir::IntegerType>(lhsElemTy) && isa<mlir::FloatType>(rhsElemTy)) {
-      lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, rhsElemTy);
+      // torch.lt(x_int, 1.1) uses fp32 as the compute type
+      // torch.lt(x_int, y_float) uses y's float type as the compute type
+      Type promoteTo = isRhsScalar ? rewriter.getF32Type() : rhsElemTy;
+      lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, promoteTo);
+      rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, promoteTo);
     } else if (isa<mlir::FloatType>(lhsElemTy) &&
                isa<mlir::IntegerType>(rhsElemTy)) {
+      // always use lhs's float type as the compute type
       rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, lhsElemTy);
     } else {
-      if (lhsElemTy.getIntOrFloatBitWidth() >
-          rhsElemTy.getIntOrFloatBitWidth()) {
+      if (isRhsScalar) {
+        // torch.lt(x_float, 1.1) uses x's float type as the compute type
+        // torch.lt(x_int, 1) uses x's int type as the compute type
         rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, lhsElemTy);
       } else {
-        lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, rhsElemTy);
+        // torch.lt(x_float, y_float) uses the higher bitwidth as the compute type
+        Type promoteTo = lhsElemTy.getIntOrFloatBitWidth() >
+                                 rhsElemTy.getIntOrFloatBitWidth()
+                             ? lhsElemTy
+                             : rhsElemTy;
+        lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, promoteTo);
+        rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, promoteTo);
       }
     }
     lhsElemTy = dyn_cast<RankedTensorType>(lhs.getType()).getElementType();
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 918cbae63d36..c99ef4d96874 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -528,7 +528,6 @@
     "AtenPolarFloatModule_basic",
     "DiagonalWithStaticShapeModule_basic",
     "EinsumStaticDiagonalDimensionModule_basic",
-    "ElementwiseIntTensorLtFloatScalarModule_basic",
     "ElementwiseRemainderScalarModule_Bool_NegativeDivisor_basic",
     "ElementwiseRemainderScalarModule_Float_NegativeDividend_basic",
     "ElementwiseRemainderScalarModule_Float_NegativeDivisor_basic",

From bc70c503739ce1776eac86886e463a9b3dc8cd52 Mon Sep 17 00:00:00 2001
From: Srinath Avadhanula
Date: Fri, 13 Sep 2024 12:39:58 -0400
Subject: [PATCH 2/2] Delete unnecessary linalg conversion for aten.fmod (#3707)

Follow-up cleanup for [this PR](https://github.com/llvm/torch-mlir/pull/3689),
which introduced a decomposition for `aten.fmod.Tensor`. This means that the
linalg lowering for this operator is no longer needed. Thanks to
@vivekkhandelwal1 for pointing this out.
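For reference, the deleted payload below computes fmod as
`self - trunc(self / other) * other`, which is the same math the
decomposition is expected to produce. A minimal sketch of that identity
(illustrative only; `fmod_decomposed` is a hypothetical helper, not the
actual decomposition code):

```python
import torch

def fmod_decomposed(self: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
    # Truncating the quotient toward zero makes the result keep the
    # dividend's sign, matching fmod (unlike remainder, which follows
    # the divisor's sign).
    return self - torch.div(self, other, rounding_mode="trunc") * other

a = torch.tensor([5.3, -5.3, 7.0])
b = torch.tensor([2.0, 2.0, 3.0])
assert torch.allclose(fmod_decomposed(a, b), torch.fmod(a, b))
```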
---------

Co-authored-by: Srinath Avadhanula
---
 .../TorchToLinalg/Uncategorized.cpp | 65 ++++++-------------
 1 file changed, 21 insertions(+), 44 deletions(-)

diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
index cf4e2b4f07f0..4688ffc7808a 100644
--- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp
+++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -1282,29 +1282,6 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     return createRemainderPayload(b, loc, converter, payloadArgs, remTensor,
                                   operands);
   }
-  if (auto fmod = dyn_cast<AtenFmodTensorOp>(op)) {
-    Type newResultType =
-        cast<RankedTensorType>(converter->convertType(fmod.getType()))
-            .getElementType();
-
-    Value self = convertScalarToDtype(b, loc, payloadArgs[0], newResultType);
-    Value other = convertScalarToDtype(b, loc, payloadArgs[1], newResultType);
-    Value result;
-
-    if (isa<mlir::FloatType>(newResultType)) {
-      Value n = b.create<arith::DivFOp>(loc, self, other);
-      n = b.create<math::TruncOp>(loc, n);
-      Value n_y = b.create<arith::MulFOp>(loc, n, other);
-      result = b.create<arith::SubFOp>(loc, self, n_y);
-    } else if (isa<mlir::IntegerType>(newResultType)) {
-      Value n = b.create<arith::DivSIOp>(loc, self, other);
-      Value n_y = b.create<arith::MulIOp>(loc, n, other);
-      result = b.create<arith::SubIOp>(loc, self, n_y);
-    } else {
-      fmod.emitError("Unsupported type encountered for AtenFmodTensorOp.");
-    }
-    return result;
-  }
   if (auto reciprocal = dyn_cast<AtenReciprocalOp>(op)) {
     Type dtype =
         cast<RankedTensorType>(converter->convertType(reciprocal.getType()))
             .getElementType();
@@ -1612,23 +1589,23 @@ class ConvertElementwiseOp : public ConversionPattern {
             AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenPowScalarOp,
             AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op,
             AtenLog10Op, AtenLog1pOp, AtenRsqrtOp, AtenDivScalarOp,
-            AtenRemainderScalarOp, AtenRemainderTensorOp, AtenFmodTensorOp,
-            AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp,
-            AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp,
-            AtenBitwiseXorTensorOp, AtenBitwiseLeftShiftTensorOp,
-            AtenBitwiseRightShiftTensorOp, Aten__Lshift__ScalarOp,
-            Aten__Rshift__ScalarOp, AtenGtScalarOp, AtenGeScalarOp,
-            AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp,
-            AtenCeilOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp,
-            AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp, AtenSubScalarOp,
-            AtenAddScalarOp, AtenThresholdOp, AtenThresholdBackwardOp,
-            AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp,
-            AtenNeScalarOp, AtenNegOp, AtenMaskedFillTensorOp, AtenLogicalOrOp,
-            AtenLogicalAndOp, AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp,
-            AtenTriuOp, AtenTrilOp, AtenBitwiseNotOp, AtenRoundOp,
-            AtenFillScalarOp, AtenFillTensorOp, AtenAtanOp, AtenAcosOp,
-            AtenAtanhOp, AtenAcoshOp, AtenAsinOp, AtenAsinhOp, AtenRealOp,
-            AtenImagOp, AtenDequantizeSelfOp, AtenDequantizeTensorOp,
+            AtenRemainderScalarOp, AtenRemainderTensorOp, AtenAbsOp,
+            AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenBitwiseAndScalarOp,
+            AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp,
+            AtenBitwiseLeftShiftTensorOp, AtenBitwiseRightShiftTensorOp,
+            Aten__Lshift__ScalarOp, Aten__Rshift__ScalarOp, AtenGtScalarOp,
+            AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp,
+            AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp, AtenGeTensorOp,
+            AtenEqTensorOp, AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp,
+            AtenSubScalarOp, AtenAddScalarOp, AtenThresholdOp,
+            AtenThresholdBackwardOp, AtenHardtanhBackwardOp, AtenCloneOp,
+            AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenNegOp,
+            AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp,
+            AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, AtenTriuOp,
+            AtenTrilOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp,
+            AtenFillTensorOp, AtenAtanOp, AtenAcosOp, AtenAtanhOp, AtenAcoshOp,
+            AtenAsinOp, AtenAsinhOp, AtenRealOp, AtenImagOp,
+            AtenDequantizeSelfOp, AtenDequantizeTensorOp,
             AtenQuantizePerTensorOp, AtenIscloseOp>(op))
       return rewriter.notifyMatchFailure(op, "not a supported elementwise op");
@@ -3385,10 +3362,10 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
       AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp,
       AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp, AtenAtanOp,
       AtenAcosOp, AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, AtenTriuOp,
-      AtenTrilOp, AtenRemainderScalarOp, AtenFmodTensorOp,
-      AtenRemainderTensorOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp,
-      AtenFillTensorOp, AtenRealOp, AtenImagOp, AtenDequantizeSelfOp,
-      AtenDequantizeTensorOp, AtenQuantizePerTensorOp, AtenIscloseOp>();
+      AtenTrilOp, AtenRemainderScalarOp, AtenRemainderTensorOp,
+      AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp,
+      AtenRealOp, AtenImagOp, AtenDequantizeSelfOp, AtenDequantizeTensorOp,
+      AtenQuantizePerTensorOp, AtenIscloseOp>();
   patterns.add<ConvertElementwiseOp>(typeConverter, context);
   target.addIllegalOp();
   patterns.add(typeConverter, context);