Skip to content

Commit

Permalink
Merge pull request #171 from Xilinx/jose.fix_avg_pool_3d
Browse files Browse the repository at this point in the history
Fixes avg pool 3d
  • Loading branch information
josel-amd authored Jun 10, 2024
2 parents 4fceed9 + 862fcca commit 0f85717
Show file tree
Hide file tree
Showing 11 changed files with 1,169 additions and 342 deletions.
4 changes: 4 additions & 0 deletions include/torch-mlir/Conversion/TorchToLinalg/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,10 @@ getBackendTypeForScalarType(MLIRContext *context,

bool isUnsignedTorchType(Type type);

LogicalResult permuteTensor(Operation *op, PatternRewriter &rewriter,
Location loc, SmallVector<int64_t> dimensions,
Value input, Value &result);

} // namespace torch_to_linalg
} // namespace torch
} // namespace mlir
50 changes: 50 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -7181,6 +7181,31 @@ def Torch_Aten_AdaptiveAvgPool3dBackwardOp : Torch_Op<"aten._adaptive_avg_pool3d
}];
}

def Torch_AtenAdaptiveMaxPool1dOp : Torch_Op<"aten.adaptive_max_pool1d", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::adaptive_max_pool1d : (Tensor, int[]) -> (Tensor, Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchListOfTorchIntType:$output_size
);
let results = (outs
AnyTorchTensorType:$result0,
AnyTorchTensorType:$result1
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenAdaptiveMaxPool1dOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 2);
}
void AtenAdaptiveMaxPool1dOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 2);
}
}];
}

def Torch_AtenAdaptiveMaxPool2dOp : Torch_Op<"aten.adaptive_max_pool2d", [
AllowsTypeRefinement,
HasValueSemantics,
Expand All @@ -7206,6 +7231,31 @@ def Torch_AtenAdaptiveMaxPool2dOp : Torch_Op<"aten.adaptive_max_pool2d", [
}];
}

def Torch_AtenAdaptiveMaxPool3dOp : Torch_Op<"aten.adaptive_max_pool3d", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::adaptive_max_pool3d : (Tensor, int[]) -> (Tensor, Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchListOfTorchIntType:$output_size
);
let results = (outs
AnyTorchTensorType:$result0,
AnyTorchTensorType:$result1
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenAdaptiveMaxPool3dOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 2);
}
void AtenAdaptiveMaxPool3dOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 2);
}
}];
}

def Torch_AtenTopkOp : Torch_Op<"aten.topk", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down
55 changes: 7 additions & 48 deletions lib/Conversion/TorchToLinalg/DataMovement.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1457,56 +1457,15 @@ class ConvertAtenPermuteOp : public OpConversionPattern<AtenPermuteOp> {
return rewriter.notifyMatchFailure(op, "all dimensions must be constant");

Value inVector = adaptor.getSelf();
auto inType = inVector.getType().cast<RankedTensorType>();
int64_t inputRank = inType.getRank();
auto outType = getTypeConverter()
->convertType(op->getResult(0).getType())
.cast<RankedTensorType>();
Type elementType = inType.getElementType();

// Check if the dimensions are a valid constants.
int64_t numDimensions = dimensions.size();
if (inputRank != numDimensions)
Value result;
if (failed(torch_to_linalg::permuteTensor(op, rewriter, op->getLoc(),
dimensions, inVector, result)))
return rewriter.notifyMatchFailure(
op, "size of `dims` must be equal to the rank of the input");
for (unsigned i = 0; i < numDimensions; i++) {
if (dimensions[i] < 0)
dimensions[i] = toPositiveDim(dimensions[i], inputRank);
if (!isValidDim(dimensions[i], inputRank))
return rewriter.notifyMatchFailure(op, "dimension out of range");
}

Location loc = op.getLoc();

SmallVector<Value> outputDims;
for (unsigned i = 0; i < inputRank; i++)
outputDims.push_back(getDimOp(rewriter, loc, inVector, dimensions[i]));
op, "failed to perform permutation of tensor");

Value outVector = rewriter.create<tensor::EmptyOp>(
loc, getAsOpFoldResult(outputDims), elementType);
SmallVector<AffineExpr> idExprs;
SmallVector<AffineExpr> swapExprs;
for (unsigned i = 0; i < inputRank; i++)
idExprs.push_back(getAffineDimExpr(i, rewriter.getContext()));
for (unsigned i = 0; i < inputRank; i++)
swapExprs.push_back(idExprs[dimensions[i]]);

AffineMap inputMap =
AffineMap::get(inputRank, /*symbolCount=*/0, idExprs, op->getContext());
AffineMap outputMap = AffineMap::get(inputRank, /*symbolCount=*/0,
swapExprs, op->getContext());
SmallVector<AffineMap> indexingMaps{inputMap, outputMap};
SmallVector<utils::IteratorType> iteratorTypes(
inputRank, utils::IteratorType::parallel);
auto transpose = rewriter
.create<linalg::GenericOp>(
loc, outVector.getType(), inVector, outVector,
indexingMaps, iteratorTypes,
[](OpBuilder &b, Location loc, ValueRange args) {
b.create<linalg::YieldOp>(loc, args[0]);
})
.getResult(0);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outType, transpose);
auto outType = cast<RankedTensorType>(
getTypeConverter()->convertType(op->getResult(0).getType()));
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outType, result);
return success();
}
};
Expand Down
Loading

0 comments on commit 0f85717

Please sign in to comment.