Merge pull request #418 from Xilinx/bump_to_3f79a298
[AutoBump] Merge with fixes of 3f79a29 (Sep 20) (57)
mgehre-amd authored Dec 17, 2024
2 parents 45037b6 + 5456f2d commit 9a76fb9
Showing 3 changed files with 324 additions and 150 deletions.
lib/Conversion/TorchToTosa/TorchToTosa.cpp — 207 changes: 108 additions & 99 deletions
@@ -105,9 +105,18 @@ class ConvertAtenBinaryOp : public OpConversionPattern<AtenOpT> {
         OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
             op.getType()));
 
-    auto binaryOp =
-        tosa::createBinaryOpAndCast<TosaOpT>(rewriter, op, outTy, lhs, rhs);
-    rewriter.replaceOp(op, binaryOp.getResult());
+    Value binaryOp;
+
+    // TOSA ArithmeticRightShiftOp has a round parameter.
+    if constexpr (std::is_same<AtenOpT, AtenBitwiseRightShiftTensorOp>()) {
+      binaryOp = rewriter.create<TosaOpT>(op->getLoc(), outTy, lhs, rhs,
+                                          /*round=*/false);
+    } else {
+      binaryOp =
+          tosa::createBinaryOpAndCast<TosaOpT>(rewriter, op, outTy, lhs, rhs);
+    }
+
+    rewriter.replaceOp(op, binaryOp);
     return success();
   }
 };
@@ -354,6 +363,7 @@ class ConvertAtenCompareOp : public OpConversionPattern<AtenOpT> {
     // For bitwise operators, only integer datatype legalization is supported
     constexpr bool isBitwiseOp =
         std::is_same<AtenOpT, AtenBitwiseAndTensorOp>() ||
+        std::is_same<AtenOpT, AtenBitwiseAndScalarOp>() ||
         std::is_same<AtenOpT, AtenBitwiseOrTensorOp>() ||
         std::is_same<AtenOpT, AtenBitwiseXorTensorOp>();
     if (isa<mlir::FloatType>(lhsElemTy) && isBitwiseOp) {
@@ -375,8 +385,7 @@ class ConvertAtenCompareOp : public OpConversionPattern<AtenOpT> {
     constexpr auto swapLhsRhs = (std::is_same<AtenOpT, AtenLtTensorOp>() ||
                                  std::is_same<AtenOpT, AtenLtScalarOp>() ||
                                  std::is_same<AtenOpT, AtenLeTensorOp>() ||
-                                 std::is_same<AtenOpT, AtenLeScalarOp>() ||
-                                 std::is_same<AtenOpT, AtenLeTensorOp>());
+                                 std::is_same<AtenOpT, AtenLeScalarOp>());
 
     // Promote lhs and rhs dtypes for bitwise operators.
     TensorType resultTy = cast<TensorType>(
@@ -692,39 +701,30 @@ class ConvertAtenOp : public OpConversionPattern<AtenOpT> {
                   ConversionPatternRewriter &rewriter) const override;
 };
 
-template <>
-LogicalResult ConvertAtenOp<AtenTanhOp>::matchAndRewrite(
-    AtenTanhOp op, OpAdaptor adaptor,
-    ConversionPatternRewriter &rewriter) const {
-  Value self = adaptor.getSelf();
-  auto selfTy = cast<TensorType>(self.getType());
-  if (selfTy && isa<mlir::FloatType>(selfTy.getElementType())) {
-    rewriter.replaceOpWithNewOp<tosa::TanhOp>(
-        op, getTypeConverter()->convertType(op.getType()), self);
-    return success();
-  }
-  // Sigmoid legalization in TOSA for quantized element-type uses specialized
-  // tosa.table construct.
-  return rewriter.notifyMatchFailure(
-      op, "Only floating-point datatype legalization currently supported");
-}
-
-template <>
-LogicalResult ConvertAtenOp<AtenSigmoidOp>::matchAndRewrite(
-    AtenSigmoidOp op, OpAdaptor adaptor,
-    ConversionPatternRewriter &rewriter) const {
-  Value self = adaptor.getSelf();
-  auto selfTy = cast<TensorType>(self.getType());
-  if (selfTy && isa<mlir::FloatType>(selfTy.getElementType())) {
-    rewriter.replaceOpWithNewOp<tosa::SigmoidOp>(
-        op, getTypeConverter()->convertType(op.getType()), self);
-    return success();
-  }
-  // Sigmoid legalization in TOSA for quantized element-type uses
-  // specialized tosa.table construct.
-  return rewriter.notifyMatchFailure(
-      op, "Only floating-point datatype legalization currently supported");
-}
+template <typename AtenOpT, typename TosaOpT>
+class ConvertAtenActivationFunctionOp : public OpConversionPattern<AtenOpT> {
+public:
+  using OpConversionPattern<AtenOpT>::OpConversionPattern;
+  using OpAdaptor = typename AtenOpT::Adaptor;
+  LogicalResult
+  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Value self = adaptor.getSelf();
+    auto selfTy = cast<TensorType>(self.getType());
+
+    if (!selfTy)
+      return rewriter.notifyMatchFailure(op, "Only Tensor types supported");
+
+    if (!isa<mlir::FloatType>(selfTy.getElementType()))
+      return rewriter.notifyMatchFailure(
+          op, "Only floating-point datatype legalization currently supported");
+
+    rewriter.replaceOpWithNewOp<TosaOpT>(
+        op, this->getTypeConverter()->convertType(op.getType()), self);
+
+    return success();
+  }
+};
 
 template <>
 LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
@@ -1209,73 +1209,63 @@ class ConvertAtenSqueezeAllDimsOp : public ConvertAtenSqueezeOp<AtenOpT> {
   }
 };
 
-template <>
-LogicalResult ConvertAtenOp<AtenPowScalarOp>::matchAndRewrite(
-    AtenPowScalarOp op, OpAdaptor adaptor,
-    ConversionPatternRewriter &rewriter) const {
-
-  Value exp = adaptor.getExponent();
-  auto expTy = dyn_cast<RankedTensorType>(exp.getType());
-
-  if (!expTy)
-    return rewriter.notifyMatchFailure(
-        op, "Only ranked tensor types supported in TOSA Pow");
-
-  if (!isa<mlir::FloatType>(expTy.getElementType()))
-    return rewriter.notifyMatchFailure(
-        op, "Only floating-point datatype legalization supported");
-
-  Value selfTensor;
-  Value selfScalar = op.getSelf();
-  if (failed(torchScalarToTosaTensor(rewriter, op, selfScalar, selfTensor,
-                                     expTy.getElementType(), {})))
-    return rewriter.notifyMatchFailure(
-        op, "Currently only scalar constants are supported for "
-            "conversion in TOSA Pow operation");
-
-  auto outType =
-      cast<TensorType>(getTypeConverter()->convertType(op.getType()));
-
-  auto powOp = tosa::createBinaryOpAndCast<tosa::PowOp>(rewriter, op, outType,
-                                                        selfTensor, exp);
-  rewriter.replaceOp(op, powOp.getResult());
-
-  return success();
-}
-
-template <>
-LogicalResult ConvertAtenOp<AtenPowTensorScalarOp>::matchAndRewrite(
-    AtenPowTensorScalarOp op, OpAdaptor adaptor,
-    ConversionPatternRewriter &rewriter) const {
-
-  Value self = adaptor.getSelf();
-  auto selfTy = cast<RankedTensorType>(self.getType());
-
-  if (!selfTy)
-    return rewriter.notifyMatchFailure(
-        op, "Only ranked tensor types supported in TOSA Pow");
-
-  if (!isa<mlir::FloatType>(selfTy.getElementType()))
-    return rewriter.notifyMatchFailure(
-        op, "Only floating-point datatype legalization supported");
-
-  auto outType =
-      cast<TensorType>(getTypeConverter()->convertType(op.getType()));
-
-  Value expTensor;
-  Value expScalar = op.getExponent();
-  if (failed(torchScalarToTosaTensor(rewriter, op, expScalar, expTensor,
-                                     outType.getElementType(), {})))
-    return rewriter.notifyMatchFailure(
-        op, "Currently only scalar constants are supported for "
-            "conversion in TOSA Pow operation");
-
-  auto powOp = tosa::createBinaryOpAndCast<tosa::PowOp>(rewriter, op, outType,
-                                                        self, expTensor);
-  rewriter.replaceOp(op, powOp.getResult());
-
-  return success();
-}
+template <typename AtenOpT>
+class ConvertAtenPowOp : public OpConversionPattern<AtenOpT> {
+public:
+  using OpConversionPattern<AtenOpT>::OpConversionPattern;
+  using OpAdaptor = typename AtenOpT::Adaptor;
+  LogicalResult
+  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    auto outType =
+        cast<TensorType>(this->getTypeConverter()->convertType(op.getType()));
+
+    Value selfTensor;
+    if constexpr (std::is_same<AtenOpT, AtenPowScalarOp>()) {
+      Value selfScalar = op.getSelf();
+      if (failed(torchScalarToTosaTensor(rewriter, op, selfScalar, selfTensor,
+                                         outType.getElementType(), {})))
+        return rewriter.notifyMatchFailure(
+            op, "Currently only scalar constants are supported for "
+                "conversion in TOSA PowScalar operation");
+    } else {
+      selfTensor = adaptor.getSelf();
+      auto selfTy = cast<RankedTensorType>(selfTensor.getType());
+
+      if (!selfTy)
+        return rewriter.notifyMatchFailure(
+            op, "Only ranked tensor types supported in TOSA Pow");
+
+      if (!isa<mlir::FloatType>(selfTy.getElementType()))
+        return rewriter.notifyMatchFailure(
+            op, "Only floating-point datatype legalization supported");
+    }
+
+    Value expTensor;
+    if constexpr (std::is_same<AtenOpT, AtenPowTensorScalarOp>()) {
+      Value expScalar = op.getExponent();
+      if (failed(torchScalarToTosaTensor(rewriter, op, expScalar, expTensor,
+                                         outType.getElementType(), {})))
+        return rewriter.notifyMatchFailure(
+            op, "Currently only scalar constants are supported for "
+                "conversion in TOSA Pow operation");
+    } else {
+      expTensor = adaptor.getExponent();
+      auto expTy = cast<RankedTensorType>(expTensor.getType());
+
+      if (!expTy)
+        return rewriter.notifyMatchFailure(
+            op, "Only ranked tensor types supported in TOSA Pow");
+    }
+
+    auto powOp = tosa::createBinaryOpAndCast<tosa::PowOp>(
+        rewriter, op, outType, selfTensor, expTensor);
+    rewriter.replaceOp(op, powOp.getResult());
+
+    return success();
+  }
+};
 
 template <>
 LogicalResult ConvertAtenOp<AtenPowTensorTensorOp>::matchAndRewrite(
@@ -6721,6 +6711,10 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
     INSERT_BINARY_PATTERN(AtenLogicalAndOp, tosa::LogicalAndOp)
     INSERT_BINARY_PATTERN(AtenLogicalOrOp, tosa::LogicalOrOp)
     INSERT_BINARY_PATTERN(AtenLogicalXorOp, tosa::LogicalXorOp)
+    INSERT_BINARY_PATTERN(AtenBitwiseLeftShiftTensorOp,
+                          tosa::LogicalLeftShiftOp)
+    INSERT_BINARY_PATTERN(AtenBitwiseRightShiftTensorOp,
+                          tosa::ArithmeticRightShiftOp)
 #undef INSERT_BINARY_PATTERN
 
 #define INSERT_BINARY_ADDSUB_PATTERN(AtenOp, TosaOp)                           \
@@ -6744,11 +6738,14 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
     INSERT_BINARY_COMPARE_PATTERN(AtenGtScalarOp, tosa::GreaterOp)
     INSERT_BINARY_COMPARE_PATTERN(AtenLtTensorOp, tosa::GreaterOp)
     INSERT_BINARY_COMPARE_PATTERN(AtenLtScalarOp, tosa::GreaterOp)
+    INSERT_BINARY_COMPARE_PATTERN(AtenLeTensorOp, tosa::GreaterEqualOp)
+    INSERT_BINARY_COMPARE_PATTERN(AtenLeScalarOp, tosa::GreaterEqualOp)
     INSERT_BINARY_COMPARE_PATTERN(AtenEqTensorOp, tosa::EqualOp)
     INSERT_BINARY_COMPARE_PATTERN(AtenEqScalarOp, tosa::EqualOp)
     INSERT_BINARY_COMPARE_PATTERN(AtenNeTensorOp, tosa::EqualOp)
     INSERT_BINARY_COMPARE_PATTERN(AtenNeScalarOp, tosa::EqualOp)
     INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseAndTensorOp, tosa::BitwiseAndOp)
+    INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseAndScalarOp, tosa::BitwiseAndOp)
     INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseOrTensorOp, tosa::BitwiseOrOp)
     INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseXorTensorOp, tosa::BitwiseXorOp)
 #undef INSERT_BINARY_COMPARE_PATTERN
@@ -6889,18 +6886,30 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
     INSERT_MASKED_FILL_PATTERN(AtenMaskedFillTensorOp);
 #undef INSERT_MASKED_FILL_PATTERN
 
+#define INSERT_POW_OP_PATTERN(AtenOp)                                          \
+  target.addIllegalOp<AtenOp>();                                               \
+  patterns.add<ConvertAtenPowOp<AtenOp>>(typeConverter, context);
+    INSERT_POW_OP_PATTERN(AtenPowTensorScalarOp);
+    INSERT_POW_OP_PATTERN(AtenPowTensorTensorOp);
+    INSERT_POW_OP_PATTERN(AtenPowScalarOp);
+#undef INSERT_POW_OP_PATTERN
+
+#define INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenOp, TosaOp)                  \
+  target.addIllegalOp<AtenOp>();                                               \
+  patterns.add<ConvertAtenActivationFunctionOp<AtenOp, TosaOp>>(typeConverter, \
+                                                                context);
+    INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenTanhOp, tosa::TanhOp);
+    INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenSigmoidOp, tosa::SigmoidOp);
+    INSERT_ACTIVATION_FUNCTION_OP_PATTERN(AtenErfOp, tosa::ErfOp);
+#undef INSERT_ACTIVATION_FUNCTION_OP_PATTERN
+
 #define INSERT_ATENOP_PATTERN(AtenOp)                                          \
   target.addIllegalOp<AtenOp>();                                               \
   patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context);
-    INSERT_ATENOP_PATTERN(AtenTanhOp);
     INSERT_ATENOP_PATTERN(AtenHardtanhBackwardOp);
-    INSERT_ATENOP_PATTERN(AtenSigmoidOp);
     INSERT_ATENOP_PATTERN(AtenReluOp);
     INSERT_ATENOP_PATTERN(AtenLeakyReluOp);
     INSERT_ATENOP_PATTERN(AtenArgmaxOp);
-    INSERT_ATENOP_PATTERN(AtenPowScalarOp);
-    INSERT_ATENOP_PATTERN(AtenPowTensorScalarOp);
-    INSERT_ATENOP_PATTERN(AtenPowTensorTensorOp);
     INSERT_ATENOP_PATTERN(AtenRsubScalarOp);
     INSERT_ATENOP_PATTERN(AtenConvolutionOp);
     INSERT_ATENOP_PATTERN(ValueTensorLiteralOp);
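
For reference, a minimal sketch of the lowering the new AtenBitwiseRightShiftTensorOp pattern is expected to produce, assuming standard Torch and TOSA assembly syntax; the tensor shapes and value names here are illustrative assumptions, not part of this commit.

    // Torch dialect input (hypothetical shapes/names):
    %res = torch.aten.bitwise_right_shift.Tensor %self, %other
        : !torch.vtensor<[4],si32>, !torch.vtensor<[4],si32> -> !torch.vtensor<[4],si32>

    // Expected TOSA output; ConvertAtenBinaryOp sets /*round=*/false:
    %res = tosa.arithmetic_right_shift %self, %other {round = false}
        : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>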