Merge pull request #182 from Xilinx/feature/backport_ea1_ops
Merge main to release
mgehre-amd authored Jun 13, 2024
2 parents 03c5d50 + d7a881c commit f00e686
Showing 17 changed files with 1,259 additions and 367 deletions.
4 changes: 4 additions & 0 deletions include/torch-mlir/Conversion/TorchToLinalg/Utils.h
@@ -97,6 +97,10 @@ getBackendTypeForScalarType(MLIRContext *context,

bool isUnsignedTorchType(Type type);

LogicalResult permuteTensor(Operation *op, PatternRewriter &rewriter,
Location loc, SmallVector<int64_t> dimensions,
Value input, Value &result);

} // namespace torch_to_linalg
} // namespace torch
} // namespace mlir
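
The new permuteTensor helper hoists the transpose-via-linalg.generic logic that ConvertAtenPermuteOp previously inlined (see the DataMovement.cpp hunk below), so other lowerings can share it. A minimal caller sketch, assuming an enclosing conversion pattern where op, rewriter, and a ranked tensor value input are in scope; the permutation {0, 2, 1} is an illustrative choice, not taken from this commit:

// Sketch only: swap the last two dimensions of a rank-3 tensor.
SmallVector<int64_t> perm{0, 2, 1};
Value transposed;
if (failed(torch_to_linalg::permuteTensor(op, rewriter, op->getLoc(), perm,
                                          input, transposed)))
  return rewriter.notifyMatchFailure(op, "failed to permute tensor");
// On success, `transposed` holds the permuted tensor value.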
50 changes: 50 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -7181,6 +7181,31 @@ def Torch_Aten_AdaptiveAvgPool3dBackwardOp : Torch_Op<"aten._adaptive_avg_pool3d
}];
}

def Torch_AtenAdaptiveMaxPool1dOp : Torch_Op<"aten.adaptive_max_pool1d", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::adaptive_max_pool1d : (Tensor, int[]) -> (Tensor, Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchListOfTorchIntType:$output_size
);
let results = (outs
AnyTorchTensorType:$result0,
AnyTorchTensorType:$result1
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenAdaptiveMaxPool1dOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 2);
}
void AtenAdaptiveMaxPool1dOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 2);
}
}];
}

def Torch_AtenAdaptiveMaxPool2dOp : Torch_Op<"aten.adaptive_max_pool2d", [
AllowsTypeRefinement,
HasValueSemantics,
@@ -7206,6 +7231,31 @@ def Torch_AtenAdaptiveMaxPool2dOp : Torch_Op<"aten.adaptive_max_pool2d", [
}];
}

def Torch_AtenAdaptiveMaxPool3dOp : Torch_Op<"aten.adaptive_max_pool3d", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::adaptive_max_pool3d : (Tensor, int[]) -> (Tensor, Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchListOfTorchIntType:$output_size
);
let results = (outs
AnyTorchTensorType:$result0,
AnyTorchTensorType:$result1
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenAdaptiveMaxPool3dOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 2);
}
void AtenAdaptiveMaxPool3dOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 2);
}
}];
}
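
Both new op definitions follow the standard generated pattern: two operands and two results, matching the 2, 2 arguments passed to parseDefaultTorchOp and printDefaultTorchOp. A hedged sketch of how a conversion might build the 1-d variant through the generated ODS builder; valuesType, indicesType, self, and outputSize are assumed names for values already in scope:

// Illustrative only: generated builders take result types first, then
// operands. The op yields the pooled values and the argmax indices.
auto pool = rewriter.create<Torch::AtenAdaptiveMaxPool1dOp>(
    loc, valuesType, indicesType, self, outputSize);
Value values = pool.getResult0();  // pooled values tensor
Value indices = pool.getResult1(); // indices tensor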

def Torch_AtenTopkOp : Torch_Op<"aten.topk", [
AllowsTypeRefinement,
HasValueSemantics,
5 changes: 3 additions & 2 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
@@ -306,7 +306,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
return success();
});
patterns.onOp(
"AveragePool", 19,
"AveragePool", 11,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
std::string autoPad;
SmallVector<int64_t> dilation;
@@ -357,7 +357,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
binder.op,
"padding list size does not match twice the number of axes");
}
-        if (binder.s64IntegerArrayAttr(strides, "strides", {1})) {
+        if (binder.s64IntegerArrayAttr(
+                strides, "strides", llvm::SmallVector<int64_t>(rank - 2, 1))) {
return failure();
}
if (strides.size() != 1 && strides.size() != rank - 2) {
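
With this change, a missing strides attribute on AveragePool defaults to one unit stride per spatial dimension instead of the fixed singleton list {1}. A standalone illustration of the SmallVector fill constructor used for the default, with rank = 4 chosen as an assumed example:

#include <cassert>
#include <cstdint>
#include "llvm/ADT/SmallVector.h"

int main() {
  int64_t rank = 4; // assumed example: a rank-4 (e.g. NCHW) input
  // SmallVector<int64_t>(n, v) builds a vector holding n copies of v,
  // giving one unit stride per spatial dimension here.
  llvm::SmallVector<int64_t> strides(rank - 2, 1);
  assert(strides.size() == 2 && strides[0] == 1 && strides[1] == 1);
  return 0;
}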
13 changes: 2 additions & 11 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
@@ -970,17 +970,8 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
}

if (!constantValue) {
-          auto dataTensorType = data.getType().cast<Torch::ValueTensorType>();
-          if (dataTensorType.getDtype().isa<IntegerType>())
-            constantValue = rewriter.create<Torch::ConstantIntOp>(
-                loc, rewriter.getI64IntegerAttr(0));
-          if (dataTensorType.getDtype().isa<FloatType>())
-            constantValue = rewriter.create<Torch::ConstantFloatOp>(
-                loc, rewriter.getF64FloatAttr(0.0f));
-
-          if (!constantValue)
-            return rewriter.notifyMatchFailure(
-                binder.op, "expected integer or float data tensor");
+          constantValue = rewriter.create<Torch::ConstantFloatOp>(
+              loc, rewriter.getF64FloatAttr(0.0f));
}

// Extract all the values of 1-D pad tensor and create a list of all
55 changes: 7 additions & 48 deletions lib/Conversion/TorchToLinalg/DataMovement.cpp
@@ -1457,56 +1457,15 @@ class ConvertAtenPermuteOp : public OpConversionPattern<AtenPermuteOp> {
return rewriter.notifyMatchFailure(op, "all dimensions must be constant");

Value inVector = adaptor.getSelf();
-    auto inType = inVector.getType().cast<RankedTensorType>();
-    int64_t inputRank = inType.getRank();
-    auto outType = getTypeConverter()
-                       ->convertType(op->getResult(0).getType())
-                       .cast<RankedTensorType>();
-    Type elementType = inType.getElementType();
-
-    // Check if the dimensions are a valid constants.
-    int64_t numDimensions = dimensions.size();
-    if (inputRank != numDimensions)
+    Value result;
+    if (failed(torch_to_linalg::permuteTensor(op, rewriter, op->getLoc(),
+                                              dimensions, inVector, result)))
       return rewriter.notifyMatchFailure(
-          op, "size of `dims` must be equal to the rank of the input");
-    for (unsigned i = 0; i < numDimensions; i++) {
-      if (dimensions[i] < 0)
-        dimensions[i] = toPositiveDim(dimensions[i], inputRank);
-      if (!isValidDim(dimensions[i], inputRank))
-        return rewriter.notifyMatchFailure(op, "dimension out of range");
-    }
-
-    Location loc = op.getLoc();
-
-    SmallVector<Value> outputDims;
-    for (unsigned i = 0; i < inputRank; i++)
-      outputDims.push_back(getDimOp(rewriter, loc, inVector, dimensions[i]));
+          op, "failed to perform permutation of tensor");

-    Value outVector = rewriter.create<tensor::EmptyOp>(
-        loc, getAsOpFoldResult(outputDims), elementType);
-    SmallVector<AffineExpr> idExprs;
-    SmallVector<AffineExpr> swapExprs;
-    for (unsigned i = 0; i < inputRank; i++)
-      idExprs.push_back(getAffineDimExpr(i, rewriter.getContext()));
-    for (unsigned i = 0; i < inputRank; i++)
-      swapExprs.push_back(idExprs[dimensions[i]]);
-
-    AffineMap inputMap =
-        AffineMap::get(inputRank, /*symbolCount=*/0, idExprs, op->getContext());
-    AffineMap outputMap = AffineMap::get(inputRank, /*symbolCount=*/0,
-                                         swapExprs, op->getContext());
-    SmallVector<AffineMap> indexingMaps{inputMap, outputMap};
-    SmallVector<utils::IteratorType> iteratorTypes(
-        inputRank, utils::IteratorType::parallel);
-    auto transpose = rewriter
-                         .create<linalg::GenericOp>(
-                             loc, outVector.getType(), inVector, outVector,
-                             indexingMaps, iteratorTypes,
-                             [](OpBuilder &b, Location loc, ValueRange args) {
-                               b.create<linalg::YieldOp>(loc, args[0]);
-                             })
-                         .getResult(0);
-    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outType, transpose);
+    auto outType = cast<RankedTensorType>(
+        getTypeConverter()->convertType(op->getResult(0).getType()));
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outType, result);
return success();
}
};
