Merge pull request #405 from Xilinx/bump_to_bc70c503
[AutoBump] Merge with bc70c50 (Sep 13) (51)
mgehre-amd authored Dec 10, 2024
2 parents 7fac461 + 08f6ba7 commit a9c318b
Showing 3 changed files with 39 additions and 52 deletions.
65 changes: 21 additions & 44 deletions lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -1290,29 +1290,6 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
     return createRemainderPayload(b, loc, converter, payloadArgs, remTensor,
                                   operands);
   }
-  if (auto fmod = dyn_cast<AtenFmodTensorOp>(op)) {
-    Type newResultType =
-        cast<RankedTensorType>(converter->convertType(fmod.getType()))
-            .getElementType();
-
-    Value self = convertScalarToDtype(b, loc, payloadArgs[0], newResultType);
-    Value other = convertScalarToDtype(b, loc, payloadArgs[1], newResultType);
-    Value result;
-
-    if (isa<mlir::FloatType>(newResultType)) {
-      Value n = b.create<arith::DivFOp>(loc, self, other);
-      n = b.create<math::TruncOp>(loc, n);
-      Value n_y = b.create<arith::MulFOp>(loc, n, other);
-      result = b.create<arith::SubFOp>(loc, self, n_y);
-    } else if (isa<mlir::IntegerType>(newResultType)) {
-      Value n = b.create<arith::DivSIOp>(loc, self, other);
-      Value n_y = b.create<arith::MulIOp>(loc, n, other);
-      result = b.create<arith::SubIOp>(loc, self, n_y);
-    } else {
-      fmod.emitError("Unsupported type encountered for AtenFmodTensorOp.");
-    }
-    return result;
-  }
   if (auto reciprocal = dyn_cast<AtenReciprocalOp>(op)) {
     Type dtype =
         cast<RankedTensorType>(converter->convertType(reciprocal.getType()))
@@ -1620,23 +1597,23 @@ class ConvertElementwiseOp : public ConversionPattern {
              AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenPowScalarOp,
              AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op,
              AtenLog10Op, AtenLog1pOp, AtenRsqrtOp, AtenDivScalarOp,
-             AtenRemainderScalarOp, AtenRemainderTensorOp, AtenFmodTensorOp,
-             AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp,
-             AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp,
-             AtenBitwiseXorTensorOp, AtenBitwiseLeftShiftTensorOp,
-             AtenBitwiseRightShiftTensorOp, Aten__Lshift__ScalarOp,
-             Aten__Rshift__ScalarOp, AtenGtScalarOp, AtenGeScalarOp,
-             AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp,
-             AtenCeilOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp,
-             AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp, AtenSubScalarOp,
-             AtenAddScalarOp, AtenThresholdOp, AtenThresholdBackwardOp,
-             AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp,
-             AtenNeScalarOp, AtenNegOp, AtenMaskedFillTensorOp, AtenLogicalOrOp,
-             AtenLogicalAndOp, AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp,
-             AtenTriuOp, AtenTrilOp, AtenBitwiseNotOp, AtenRoundOp,
-             AtenFillScalarOp, AtenFillTensorOp, AtenAtanOp, AtenAcosOp,
-             AtenAtanhOp, AtenAcoshOp, AtenAsinOp, AtenAsinhOp, AtenRealOp,
-             AtenImagOp, AtenDequantizeSelfOp, AtenDequantizeTensorOp,
+             AtenRemainderScalarOp, AtenRemainderTensorOp, AtenAbsOp,
+             AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenBitwiseAndScalarOp,
+             AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp,
+             AtenBitwiseLeftShiftTensorOp, AtenBitwiseRightShiftTensorOp,
+             Aten__Lshift__ScalarOp, Aten__Rshift__ScalarOp, AtenGtScalarOp,
+             AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp,
+             AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp, AtenGeTensorOp,
+             AtenEqTensorOp, AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp,
+             AtenSubScalarOp, AtenAddScalarOp, AtenThresholdOp,
+             AtenThresholdBackwardOp, AtenHardtanhBackwardOp, AtenCloneOp,
+             AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenNegOp,
+             AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp,
+             AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, AtenTriuOp,
+             AtenTrilOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp,
+             AtenFillTensorOp, AtenAtanOp, AtenAcosOp, AtenAtanhOp, AtenAcoshOp,
+             AtenAsinOp, AtenAsinhOp, AtenRealOp, AtenImagOp,
+             AtenDequantizeSelfOp, AtenDequantizeTensorOp,
              AtenQuantizePerTensorOp, AtenIscloseOp>(op))
       return rewriter.notifyMatchFailure(op, "not a supported elementwise op");

@@ -3393,10 +3370,10 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
       AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp,
       AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp, AtenAtanOp,
       AtenAcosOp, AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, AtenTriuOp,
-      AtenTrilOp, AtenRemainderScalarOp, AtenFmodTensorOp,
-      AtenRemainderTensorOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp,
-      AtenFillTensorOp, AtenRealOp, AtenImagOp, AtenDequantizeSelfOp,
-      AtenDequantizeTensorOp, AtenQuantizePerTensorOp, AtenIscloseOp>();
+      AtenTrilOp, AtenRemainderScalarOp, AtenRemainderTensorOp,
+      AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp,
+      AtenRealOp, AtenImagOp, AtenDequantizeSelfOp, AtenDequantizeTensorOp,
+      AtenQuantizePerTensorOp, AtenIscloseOp>();
   patterns.add<ConvertElementwiseOp>(typeConverter, context);
   target.addIllegalOp<AtenNllLossForwardOp>();
   patterns.add<ConvertAtenDetachOp>(typeConverter, context);
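Note on the deletion above: the removed payload lowered aten.fmod.Tensor as self - trunc(self / other) * other, i.e. the remainder of truncated division. As a reference sketch of those semantics (PyTorch, illustrative only, not part of this commit), contrasting them with torch.remainder's floored division:

import torch

a = torch.tensor([5.5, -5.5, 5.5, -5.5])
b = torch.tensor([2.0, 2.0, -2.0, -2.0])

# aten.fmod.Tensor: remainder of truncated division; the sign follows the dividend.
manual = a - torch.trunc(a / b) * b
assert torch.allclose(torch.fmod(a, b), manual)

print(torch.fmod(a, b))       # tensor([ 1.5000, -1.5000,  1.5000, -1.5000])
# torch.remainder uses floored division instead; the sign follows the divisor.
print(torch.remainder(a, b))  # tensor([ 1.5000,  0.5000, -0.5000, -1.5000])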
25 changes: 18 additions & 7 deletions lib/Conversion/TorchToStablehlo/Basic.cpp
@@ -516,13 +516,12 @@ class ConvertAtenCompareOp : public OpConversionPattern<AtenOpT> {
     if (!lhsTy) {
      return op.emitError("only Tensor types supported in StableHLO");
    }
+    bool isRhsScalar = false;
    if (!rhsTy) {
      rhs = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getOther(),
                                         rhs.getType());
-      // use lhs's element type as compute type
-      rhs =
-          hlo::promoteType(rewriter, op.getLoc(), rhs, lhsTy.getElementType());
      rhsTy = dyn_cast<RankedTensorType>(rhs.getType());
+      isRhsScalar = true;
    }

    auto outType = cast<RankedTensorType>(
@@ -537,16 +536,28 @@ class ConvertAtenCompareOp : public OpConversionPattern<AtenOpT> {
    }

    if (isa<mlir::IntegerType>(lhsElemTy) && isa<mlir::FloatType>(rhsElemTy)) {
-      lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, rhsElemTy);
+      // torch.lt(x_int, 1.1) use fp32 as compute type
+      // torch.lt(x_int, y_float) use y's float type as compute type
+      Type promoteTo = isRhsScalar ? rewriter.getF32Type() : rhsElemTy;
+      lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, promoteTo);
+      rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, promoteTo);
    } else if (isa<mlir::FloatType>(lhsElemTy) &&
               isa<mlir::IntegerType>(rhsElemTy)) {
+      // always use lhs's float type as compute type
      rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, lhsElemTy);
    } else {
-      if (lhsElemTy.getIntOrFloatBitWidth() >
-          rhsElemTy.getIntOrFloatBitWidth()) {
+      if (isRhsScalar) {
+        // torch.lt(x_float, 1.1) use x's float type as compute type
+        // torch.lt(x_int, 1) use x's int type as compute type
        rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, lhsElemTy);
      } else {
-        lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, rhsElemTy);
+        // torch.lt(x_float, y_float) use higher bitwidth as compute type
+        Type promoteTo = lhsElemTy.getIntOrFloatBitWidth() >
+                                 rhsElemTy.getIntOrFloatBitWidth()
+                             ? lhsElemTy
+                             : rhsElemTy;
+        lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, promoteTo);
+        rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, promoteTo);
      }
    }
    lhsElemTy = dyn_cast<RankedTensorType>(lhs.getType()).getElementType();
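Note on the promotion change above: when the right-hand side came from a Python scalar (isRhsScalar), an integer tensor compared against a float scalar is now computed in f32 rather than by truncating the scalar into the tensor's integer type. A reference sketch of the PyTorch behavior this mirrors (illustrative only, not part of the patch); the test dropped from the fail list below, ElementwiseIntTensorLtFloatScalarModule_basic, covers exactly the int-tensor vs. float-scalar case:

import torch

x = torch.tensor([1, 2], dtype=torch.int64)

# Comparing against a float scalar must not truncate 1.1 down to 1.
print(x < 1.1)                                 # tensor([ True, False])
print(x < torch.tensor(1.1).to(torch.int64))   # tensor([False, False]) -- wrong if truncated

# With a float tensor operand, promotion follows that tensor's float dtype instead.
y = torch.tensor([1.1, 1.1], dtype=torch.float64)
print(x < y)                                   # tensor([ True, False])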
1 change: 0 additions & 1 deletion projects/pt1/e2e_testing/xfail_sets.py
@@ -566,7 +566,6 @@
     "AtenPolarFloatModule_basic",
    "DiagonalWithStaticShapeModule_basic",
    "EinsumStaticDiagonalDimensionModule_basic",
-    "ElementwiseIntTensorLtFloatScalarModule_basic",
    "ElementwiseRemainderScalarModule_Bool_NegativeDivisor_basic",
    "ElementwiseRemainderScalarModule_Float_NegativeDividend_basic",
    "ElementwiseRemainderScalarModule_Float_NegativeDivisor_basic",
