Skip to content

Commit

Permalink
Merge pull request #411 from Xilinx/matthias.fix_tosa_pow
Browse files Browse the repository at this point in the history
TorchToTosa: Correctly lower pow with broadcasting
  • Loading branch information
mgehre-amd authored Dec 12, 2024
2 parents c48183a + 0f604c8 commit 2341f20
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 9 deletions.
20 changes: 11 additions & 9 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1084,24 +1084,26 @@ LogicalResult ConvertAtenOp<AtenPowTensorTensorOp>::matchAndRewrite(
ConversionPatternRewriter &rewriter) const {

Value self = adaptor.getSelf();
auto selfTy = cast<RankedTensorType>(self.getType());
auto selfTy = dyn_cast<RankedTensorType>(self.getType());
auto outType =
dyn_cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
Value expTensor = adaptor.getExponent();
auto expTensorTy = dyn_cast<RankedTensorType>(expTensor.getType());

if (!selfTy)
if (!selfTy || !outType || !expTensorTy) {
return rewriter.notifyMatchFailure(
op, "Only ranked tensor types supported in TOSA Pow");
}

if (!isa<mlir::FloatType>(selfTy.getElementType()))
if (!isa<mlir::FloatType>(selfTy.getElementType())) {
return rewriter.notifyMatchFailure(
op, "Only floating-point datatype legalization supported");
}

auto outType =
cast<TensorType>(getTypeConverter()->convertType(op.getType()));

Value expTensor = adaptor.getExponent();
if (expTensor.getType() != selfTy) {
if (expTensorTy.getElementType() != selfTy.getElementType()) {
expTensor = rewriter.createOrFold<tosa::CastOp>(
op->getLoc(),
RankedTensorType::get(outType.getShape(), selfTy.getElementType()),
RankedTensorType::get(expTensorTy.getShape(), selfTy.getElementType()),
expTensor);
}

Expand Down
1 change: 1 addition & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1993,6 +1993,7 @@
"ElementwisePowTensorBroadcastModule_basic",
"ElementwisePowTensorBroadcastStaticModule_basic",
"ElementwisePowTensorModule_basic",
"ElementwisePowTensorStaticModule_basic",
"ElementwisePreluModule_basic",
"ElementwisePreluStaticModule_basic",
"ElementwiseRad2DegModule_basic",
Expand Down

0 comments on commit 2341f20

Please sign in to comment.