diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index dabc0c2d9b36..1f4bc82e7f4c 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -1290,29 +1290,6 @@ static Value createLinalgPayloadCalculationForElementwiseOp( return createRemainderPayload(b, loc, converter, payloadArgs, remTensor, operands); } - if (auto fmod = dyn_cast(op)) { - Type newResultType = - cast(converter->convertType(fmod.getType())) - .getElementType(); - - Value self = convertScalarToDtype(b, loc, payloadArgs[0], newResultType); - Value other = convertScalarToDtype(b, loc, payloadArgs[1], newResultType); - Value result; - - if (isa(newResultType)) { - Value n = b.create(loc, self, other); - n = b.create(loc, n); - Value n_y = b.create(loc, n, other); - result = b.create(loc, self, n_y); - } else if (isa(newResultType)) { - Value n = b.create(loc, self, other); - Value n_y = b.create(loc, n, other); - result = b.create(loc, self, n_y); - } else { - fmod.emitError("Unsupported type encountered for AtenFmodTensorOp."); - } - return result; - } if (auto reciprocal = dyn_cast(op)) { Type dtype = cast(converter->convertType(reciprocal.getType())) @@ -1620,23 +1597,23 @@ class ConvertElementwiseOp : public ConversionPattern { AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenPowScalarOp, AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op, AtenLog10Op, AtenLog1pOp, AtenRsqrtOp, AtenDivScalarOp, - AtenRemainderScalarOp, AtenRemainderTensorOp, AtenFmodTensorOp, - AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp, - AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp, - AtenBitwiseXorTensorOp, AtenBitwiseLeftShiftTensorOp, - AtenBitwiseRightShiftTensorOp, Aten__Lshift__ScalarOp, - Aten__Rshift__ScalarOp, AtenGtScalarOp, AtenGeScalarOp, - AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp, - AtenCeilOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp, - AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp, AtenSubScalarOp, - AtenAddScalarOp, AtenThresholdOp, AtenThresholdBackwardOp, - AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp, - AtenNeScalarOp, AtenNegOp, AtenMaskedFillTensorOp, AtenLogicalOrOp, - AtenLogicalAndOp, AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, - AtenTriuOp, AtenTrilOp, AtenBitwiseNotOp, AtenRoundOp, - AtenFillScalarOp, AtenFillTensorOp, AtenAtanOp, AtenAcosOp, - AtenAtanhOp, AtenAcoshOp, AtenAsinOp, AtenAsinhOp, AtenRealOp, - AtenImagOp, AtenDequantizeSelfOp, AtenDequantizeTensorOp, + AtenRemainderScalarOp, AtenRemainderTensorOp, AtenAbsOp, + AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenBitwiseAndScalarOp, + AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp, + AtenBitwiseLeftShiftTensorOp, AtenBitwiseRightShiftTensorOp, + Aten__Lshift__ScalarOp, Aten__Rshift__ScalarOp, AtenGtScalarOp, + AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, + AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp, AtenGeTensorOp, + AtenEqTensorOp, AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp, + AtenSubScalarOp, AtenAddScalarOp, AtenThresholdOp, + AtenThresholdBackwardOp, AtenHardtanhBackwardOp, AtenCloneOp, + AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenNegOp, + AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp, + AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, AtenTriuOp, + AtenTrilOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, + AtenFillTensorOp, AtenAtanOp, AtenAcosOp, AtenAtanhOp, AtenAcoshOp, + AtenAsinOp, AtenAsinhOp, AtenRealOp, AtenImagOp, + AtenDequantizeSelfOp, AtenDequantizeTensorOp, AtenQuantizePerTensorOp, AtenIscloseOp>(op)) return rewriter.notifyMatchFailure(op, "not a supported elementwise op"); @@ -3393,10 +3370,10 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality( AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp, AtenAtanOp, AtenAcosOp, AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, AtenTriuOp, - AtenTrilOp, AtenRemainderScalarOp, AtenFmodTensorOp, - AtenRemainderTensorOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, - AtenFillTensorOp, AtenRealOp, AtenImagOp, AtenDequantizeSelfOp, - AtenDequantizeTensorOp, AtenQuantizePerTensorOp, AtenIscloseOp>(); + AtenTrilOp, AtenRemainderScalarOp, AtenRemainderTensorOp, + AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp, + AtenRealOp, AtenImagOp, AtenDequantizeSelfOp, AtenDequantizeTensorOp, + AtenQuantizePerTensorOp, AtenIscloseOp>(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); diff --git a/lib/Conversion/TorchToStablehlo/Basic.cpp b/lib/Conversion/TorchToStablehlo/Basic.cpp index 1f21a1afe8d6..ab4e284f8b2d 100644 --- a/lib/Conversion/TorchToStablehlo/Basic.cpp +++ b/lib/Conversion/TorchToStablehlo/Basic.cpp @@ -516,13 +516,12 @@ class ConvertAtenCompareOp : public OpConversionPattern { if (!lhsTy) { return op.emitError("only Tensor types supported in StableHLO"); } + bool isRhsScalar = false; if (!rhsTy) { rhs = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getOther(), rhs.getType()); - // use lhs's element type as compute type - rhs = - hlo::promoteType(rewriter, op.getLoc(), rhs, lhsTy.getElementType()); rhsTy = dyn_cast(rhs.getType()); + isRhsScalar = true; } auto outType = cast( @@ -537,16 +536,28 @@ class ConvertAtenCompareOp : public OpConversionPattern { } if (isa(lhsElemTy) && isa(rhsElemTy)) { - lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, rhsElemTy); + // torch.lt(x_int, 1.1) use fp32 as compute type + // torch.lt(x_int, y_float) use y's float type as compute type + Type promoteTo = isRhsScalar ? rewriter.getF32Type() : rhsElemTy; + lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, promoteTo); + rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, promoteTo); } else if (isa(lhsElemTy) && isa(rhsElemTy)) { + // always use lhs's float type as compute type rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, lhsElemTy); } else { - if (lhsElemTy.getIntOrFloatBitWidth() > - rhsElemTy.getIntOrFloatBitWidth()) { + if (isRhsScalar) { + // torch.lt(x_float, 1.1) use x's float type as compute type + // torch.lt(x_int, 1) use x's int type as compute type rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, lhsElemTy); } else { - lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, rhsElemTy); + // torch.lt(x_float, y_float) use higher bitwidth as compute type + Type promoteTo = lhsElemTy.getIntOrFloatBitWidth() > + rhsElemTy.getIntOrFloatBitWidth() + ? lhsElemTy + : rhsElemTy; + lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, promoteTo); + rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, promoteTo); } } lhsElemTy = dyn_cast(lhs.getType()).getElementType(); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index d31fab14d149..eedb09cfbd34 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -566,7 +566,6 @@ "AtenPolarFloatModule_basic", "DiagonalWithStaticShapeModule_basic", "EinsumStaticDiagonalDimensionModule_basic", - "ElementwiseIntTensorLtFloatScalarModule_basic", "ElementwiseRemainderScalarModule_Bool_NegativeDivisor_basic", "ElementwiseRemainderScalarModule_Float_NegativeDividend_basic", "ElementwiseRemainderScalarModule_Float_NegativeDivisor_basic",