From 1886f74a63b22b8ad62e38c9c6d8e14401cae0e9 Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Mon, 26 Aug 2024 13:15:26 +0200 Subject: [PATCH 01/46] Fix merge --- .../ArithToEmitC/arith-to-emitc-unsupported.mlir | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir index 31c39a91a8b686..3be3941b3ef043 100644 --- a/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir +++ b/mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir @@ -81,6 +81,14 @@ func.func @arith_cast_fptoui_i1(%arg0: f32) -> i1 { // ----- +func.func @arith_extsi_i1_to_i32(%arg0: i1) { + // expected-error @+1 {{failed to legalize operation 'arith.extsi'}} + %idx = arith.extsi %arg0 : i1 to i32 + return +} + +// ----- + func.func @arith_negf_tensor(%arg0: tensor<5xf32>) -> tensor<5xf32> { // expected-error @+1 {{failed to legalize operation 'arith.negf'}} %n = arith.negf %arg0 : tensor<5xf32> From dec1017e3b23053f102b6ff57d3e88c98dc04ac5 Mon Sep 17 00:00:00 2001 From: Corentin Ferry Date: Mon, 9 Sep 2024 09:01:16 +0200 Subject: [PATCH 02/46] [FXML-4791] Add printout of references with emitc.reference attr (#316) --- mlir/include/mlir/Dialect/EmitC/IR/EmitC.h | 3 + mlir/include/mlir/Dialect/EmitC/IR/EmitC.td | 10 +- .../Dialect/EmitC/IR/FunctionOpAssembly.h | 43 +++ mlir/lib/Dialect/EmitC/IR/CMakeLists.txt | 1 + mlir/lib/Dialect/EmitC/IR/EmitC.cpp | 31 +- .../Dialect/EmitC/IR/FunctionOpAssembly.cpp | 310 ++++++++++++++++++ mlir/lib/Target/Cpp/TranslateToCpp.cpp | 116 ++++++- mlir/test/Dialect/EmitC/func.mlir | 27 ++ mlir/test/Dialect/EmitC/invalid_ops.mlir | 12 +- mlir/test/Dialect/EmitC/ops.mlir | 1 + mlir/test/Target/Cpp/common-cpp.mlir | 10 + mlir/test/Target/Cpp/declare_func.mlir | 14 + mlir/test/Target/Cpp/func.mlir | 9 + mlir/test/Target/Cpp/global.mlir | 3 + mlir/test/Target/Cpp/invalid.mlir | 8 + 15 files changed, 574 insertions(+), 24 deletions(-) create mode 100644 mlir/include/mlir/Dialect/EmitC/IR/FunctionOpAssembly.h create mode 100644 mlir/lib/Dialect/EmitC/IR/FunctionOpAssembly.cpp create mode 100644 mlir/test/Dialect/EmitC/func.mlir diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h index 87a4078f280f65..0c595a6b109caa 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h @@ -47,6 +47,9 @@ bool isSupportedFloatType(mlir::Type type); /// Determines whether \p type is a emitc.size_t/ssize_t type. bool isPointerWideType(mlir::Type type); +/// Give the name of the EmitC reference attribute. +StringRef getReferenceAttributeName(); + } // namespace emitc } // namespace mlir diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td index 452302c565139c..0c945ab2c40304 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td @@ -285,9 +285,11 @@ def EmitC_CastOp : EmitC_Op<"cast", ``` }]; - let arguments = (ins EmitCType:$source); + let arguments = (ins EmitCType:$source, + UnitAttr:$reference); let results = (outs EmitCType:$dest); - let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)"; + let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest) (`ref` $reference^)?"; + let hasVerifier = 1; } def EmitC_CmpOp : EmitC_BinaryOp<"cmp", [CExpression]> { @@ -1050,7 +1052,8 @@ def EmitC_GlobalOp : EmitC_Op<"global", [Symbol]> { OptionalAttr:$initial_value, UnitAttr:$extern_specifier, UnitAttr:$static_specifier, - UnitAttr:$const_specifier); + UnitAttr:$const_specifier, + UnitAttr:$reference); let assemblyFormat = [{ (`extern` $extern_specifier^)? @@ -1058,6 +1061,7 @@ def EmitC_GlobalOp : EmitC_Op<"global", [Symbol]> { (`const` $const_specifier^)? $sym_name `:` custom($type, $initial_value) + (`ref` $reference^)? attr-dict }]; diff --git a/mlir/include/mlir/Dialect/EmitC/IR/FunctionOpAssembly.h b/mlir/include/mlir/Dialect/EmitC/IR/FunctionOpAssembly.h new file mode 100644 index 00000000000000..22567c97a21ad7 --- /dev/null +++ b/mlir/include/mlir/Dialect/EmitC/IR/FunctionOpAssembly.h @@ -0,0 +1,43 @@ +//===---------- FunctionOpAssembly.h - Parser for `emitc.func` op ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_INCLUDE_MLIR_DIALECT_EMITC_IR_FUNCTIONOPASSEMBLY_H +#define MLIR_INCLUDE_MLIR_DIALECT_EMITC_IR_FUNCTIONOPASSEMBLY_H + +#include "mlir/IR/OperationSupport.h" +#include "mlir/Interfaces/FunctionImplementation.h" +#include "mlir/Support/LogicalResult.h" + +#include "mlir/IR/Builders.h" + +namespace mlir::emitc { + +class FuncOp; + +ParseResult +parseFunctionSignature(OpAsmParser &parser, bool allowVariadic, + SmallVectorImpl &arguments, + bool &isVariadic, SmallVectorImpl &resultTypes, + SmallVectorImpl &resultAttrs); + +ParseResult +parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, + StringAttr typeAttrName, + function_interface_impl::FuncTypeBuilder funcTypeBuilder, + StringAttr argAttrsName, StringAttr resAttrsName); + +void printFunctionSignature(OpAsmPrinter &p, FuncOp op, ArrayRef argTypes, + bool isVariadic, ArrayRef resultTypes); + +void printFunctionOp(OpAsmPrinter &p, FuncOp op, bool isVariadic, + StringRef typeAttrName, StringAttr argAttrsName, + StringAttr resAttrsName); + +} // namespace mlir::emitc + +#endif // MLIR_INCLUDE_MLIR_DIALECT_EMITC_IR_FUNCTIONOPASSEMBLY_H diff --git a/mlir/lib/Dialect/EmitC/IR/CMakeLists.txt b/mlir/lib/Dialect/EmitC/IR/CMakeLists.txt index 4cc54201d2745d..e1bef7f6851cb2 100644 --- a/mlir/lib/Dialect/EmitC/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/EmitC/IR/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_dialect_library(MLIREmitCDialect EmitC.cpp + FunctionOpAssembly.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/EmitC diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index 03f96704ab4f6d..aa2495bc42ba03 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -8,6 +8,7 @@ #include "mlir/Dialect/EmitC/IR/EmitC.h" #include "mlir/Dialect/EmitC/IR/EmitCTraits.h" +#include "mlir/Dialect/EmitC/IR/FunctionOpAssembly.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" @@ -122,6 +123,8 @@ bool mlir::emitc::isPointerWideType(Type type) { type); } +StringRef mlir::emitc::getReferenceAttributeName() { return "emitc.reference"; } + /// Check that the type of the initial value is compatible with the operations /// result type. static LogicalResult verifyInitializationAttribute(Operation *op, @@ -232,6 +235,13 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { emitc::isSupportedFloatType(output) || isa(output))); } +LogicalResult CastOp::verify() { + if (getReference()) + return emitOpError("cast of value type must not bear a reference"); + + return success(); +} + //===----------------------------------------------------------------------===// // CallOpaqueOp //===----------------------------------------------------------------------===// @@ -518,16 +528,15 @@ ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { function_interface_impl::VariadicFlag, std::string &) { return builder.getFunctionType(argTypes, results); }; - return function_interface_impl::parseFunctionOp( - parser, result, /*allowVariadic=*/false, - getFunctionTypeAttrName(result.name), buildFuncType, - getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); + return parseFunctionOp(parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType, + getArgAttrsAttrName(result.name), + getResAttrsAttrName(result.name)); } void FuncOp::print(OpAsmPrinter &p) { - function_interface_impl::printFunctionOp( - p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), - getArgAttrsAttrName(), getResAttrsAttrName()); + printFunctionOp(p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); } LogicalResult FuncOp::verify() { @@ -1029,6 +1038,12 @@ LogicalResult GlobalOp::verify() { } if (getInitialValue().has_value()) { Attribute initValue = getInitialValue().value(); + if (getReference() && !isa(initValue)) { + return emitOpError("global reference initial value must be an opaque " + "attribute, got ") + << initValue; + } + // Check that the type of the initial value is compatible with the type of // the global variable. if (auto elementsAttr = llvm::dyn_cast(initValue)) { @@ -1057,6 +1072,8 @@ LogicalResult GlobalOp::verify() { "or opaque attribute, but got ") << initValue; } + } else if (getReference()) { + return emitOpError("global reference must be initialized"); } if (getStaticSpecifier() && getExternSpecifier()) { return emitOpError("cannot have both static and extern specifiers"); diff --git a/mlir/lib/Dialect/EmitC/IR/FunctionOpAssembly.cpp b/mlir/lib/Dialect/EmitC/IR/FunctionOpAssembly.cpp new file mode 100644 index 00000000000000..0db97a5890868c --- /dev/null +++ b/mlir/lib/Dialect/EmitC/IR/FunctionOpAssembly.cpp @@ -0,0 +1,310 @@ +//===--------- FunctionOpAssembly.cpp - Parser for `emitc.func` op --------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// This function printer/parser are copies of those in +// Interfaces/FunctionImplementation.cpp, except that they print out arguments +// followed by "ref" if they bear the emitc.reference attribute. +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/Interfaces/FunctionImplementation.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "llvm/ADT/SmallVector.h" + +#include "mlir/Dialect/EmitC/IR/FunctionOpAssembly.h" + +using namespace mlir; + +namespace mlir::emitc { + +static ParseResult +parseFunctionArgumentList(OpAsmParser &parser, bool allowVariadic, + SmallVectorImpl &arguments, + bool &isVariadic) { + + // Parse the function arguments. The argument list either has to consistently + // have ssa-id's followed by types, or just be a type list. It isn't ok to + // sometimes have SSA ID's and sometimes not. + isVariadic = false; + + return parser.parseCommaSeparatedList( + OpAsmParser::Delimiter::Paren, [&]() -> ParseResult { + // Ellipsis must be at end of the list. + if (isVariadic) + return parser.emitError( + parser.getCurrentLocation(), + "variadic arguments must be in the end of the argument list"); + + // Handle ellipsis as a special case. + if (allowVariadic && succeeded(parser.parseOptionalEllipsis())) { + // This is a variadic designator. + isVariadic = true; + return success(); // Stop parsing arguments. + } + // Parse argument name if present. + OpAsmParser::Argument argument; + auto argPresent = parser.parseOptionalArgument( + argument, /*allowType=*/true, /*allowAttrs=*/true); + if (argPresent.has_value()) { + if (failed(argPresent.value())) + return failure(); // Present but malformed. + + // Reject this if the preceding argument was missing a name. + if (!arguments.empty() && arguments.back().ssaName.name.empty()) + return parser.emitError(argument.ssaName.location, + "expected type instead of SSA identifier"); + if (succeeded(parser.parseOptionalKeyword("ref"))) { + llvm::ArrayRef origAttrs; + if (!argument.attrs.empty()) + origAttrs = argument.attrs.getValue(); + + SmallVector attrs(origAttrs); + attrs.push_back(NamedAttribute( + StringAttr::get(parser.getContext(), + emitc::getReferenceAttributeName()), + UnitAttr::get(parser.getContext()))); + argument.attrs = DictionaryAttr::get(parser.getContext(), attrs); + } + } else { + argument.ssaName.location = parser.getCurrentLocation(); + // Otherwise we just have a type list without SSA names. Reject + // this if the preceding argument had a name. + if (!arguments.empty() && !arguments.back().ssaName.name.empty()) + return parser.emitError(argument.ssaName.location, + "expected SSA identifier"); + + NamedAttrList attrs; + if (parser.parseType(argument.type) || + parser.parseOptionalAttrDict(attrs) || + parser.parseOptionalLocationSpecifier(argument.sourceLoc)) + return failure(); + if (succeeded(parser.parseOptionalKeyword("ref"))) { + // Add attribute to argument + attrs.push_back(NamedAttribute( + StringAttr::get(parser.getContext(), + emitc::getReferenceAttributeName()), + UnitAttr::get(parser.getContext()))); + } + argument.attrs = attrs.getDictionary(parser.getContext()); + } + arguments.push_back(argument); + return success(); + }); +} + +/// Parse a function result. +/// +/// function-result ::= type | `(` type attribute-dict? `)` +/// +static ParseResult +parseFunctionResult(OpAsmParser &parser, SmallVectorImpl &resultTypes, + SmallVectorImpl &resultAttrs) { + + bool hasLParen = succeeded(parser.parseOptionalLParen()); + + if (hasLParen) { + // Special case for an empty set of parens. + if (succeeded(parser.parseOptionalRParen())) + return success(); + } + + // Parse a single type. + Type ty; + if (parser.parseType(ty)) + return failure(); + resultTypes.push_back(ty); + resultAttrs.emplace_back(); + + // There can be no attribute without parentheses (they would be confused with + // the function body) + if (!hasLParen) + return success(); + + // Parse result attributes if any. + NamedAttrList attrs; + if (succeeded(parser.parseOptionalAttrDict(attrs))) + resultAttrs.back() = attrs.getDictionary(parser.getContext()); + + return parser.parseRParen(); +} + +ParseResult +parseFunctionSignature(OpAsmParser &parser, bool allowVariadic, + SmallVectorImpl &arguments, + bool &isVariadic, SmallVectorImpl &resultTypes, + SmallVectorImpl &resultAttrs) { + if (parseFunctionArgumentList(parser, allowVariadic, arguments, isVariadic)) + return failure(); + if (succeeded(parser.parseOptionalArrow())) + return parseFunctionResult(parser, resultTypes, resultAttrs); + return success(); +} + +ParseResult +parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, + StringAttr typeAttrName, + function_interface_impl::FuncTypeBuilder funcTypeBuilder, + StringAttr argAttrsName, StringAttr resAttrsName) { + SmallVector entryArgs; + SmallVector resultAttrs; + SmallVector resultTypes; + auto &builder = parser.getBuilder(); + + // Parse visibility. + (void)impl::parseOptionalVisibilityKeyword(parser, result.attributes); + + // Parse the name as a symbol. + StringAttr nameAttr; + if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), + result.attributes)) + return failure(); + + // Parse the function signature. + SMLoc signatureLocation = parser.getCurrentLocation(); + bool isVariadic = false; + if (parseFunctionSignature(parser, allowVariadic, entryArgs, isVariadic, + resultTypes, resultAttrs)) + return failure(); + + std::string errorMessage; + SmallVector argTypes; + argTypes.reserve(entryArgs.size()); + for (auto &arg : entryArgs) + argTypes.push_back(arg.type); + Type type = funcTypeBuilder(builder, argTypes, resultTypes, + function_interface_impl::VariadicFlag(isVariadic), + errorMessage); + if (!type) { + return parser.emitError(signatureLocation) + << "failed to construct function type" + << (errorMessage.empty() ? "" : ": ") << errorMessage; + } + result.addAttribute(typeAttrName, TypeAttr::get(type)); + + // If function attributes are present, parse them. + NamedAttrList parsedAttributes; + SMLoc attributeDictLocation = parser.getCurrentLocation(); + if (parser.parseOptionalAttrDictWithKeyword(parsedAttributes)) + return failure(); + + // Disallow attributes that are inferred from elsewhere in the attribute + // dictionary. + for (StringRef disallowed : + {SymbolTable::getVisibilityAttrName(), SymbolTable::getSymbolAttrName(), + typeAttrName.getValue()}) { + if (parsedAttributes.get(disallowed)) + return parser.emitError(attributeDictLocation, "'") + << disallowed + << "' is an inferred attribute and should not be specified in the " + "explicit attribute dictionary"; + } + result.attributes.append(parsedAttributes); + + // Add the attributes to the function arguments. + assert(resultAttrs.size() == resultTypes.size()); + function_interface_impl::addArgAndResultAttrs( + builder, result, entryArgs, resultAttrs, argAttrsName, resAttrsName); + + // Parse the optional function body. The printer will not print the body if + // its empty, so disallow parsing of empty body in the parser. + auto *body = result.addRegion(); + SMLoc loc = parser.getCurrentLocation(); + OptionalParseResult parseResult = + parser.parseOptionalRegion(*body, entryArgs, + /*enableNameShadowing=*/false); + if (parseResult.has_value()) { + if (failed(*parseResult)) + return failure(); + // Function body was parsed, make sure its not empty. + if (body->empty()) + return parser.emitError(loc, "expected non-empty function body"); + } + return success(); +} + +void printFunctionSignature(OpAsmPrinter &p, FuncOp op, ArrayRef argTypes, + bool isVariadic, ArrayRef resultTypes) { + Region &body = op->getRegion(0); + bool isExternal = body.empty(); + + p << '('; + ArrayAttr argAttrs = op.getArgAttrsAttr(); + for (unsigned i = 0, e = argTypes.size(); i < e; ++i) { + if (i > 0) + p << ", "; + + // Exclude reference attribute if there is to replace it by ref + SmallVector attrs; + if (argAttrs) { + for (auto attr : llvm::cast(argAttrs[i]).getValue()) { + if (attr.getName() != emitc::getReferenceAttributeName()) + attrs.push_back(attr); + } + } + + if (!isExternal) { + p.printRegionArgument(body.getArgument(i), attrs); + } else { + p.printType(argTypes[i]); + if (argAttrs) + p.printOptionalAttrDict(attrs); + } + } + + if (isVariadic) { + if (!argTypes.empty()) + p << ", "; + p << "..."; + } + + p << ')'; + + if (!resultTypes.empty()) { + assert(resultTypes.size() == 1); + p.getStream() << " -> "; + auto resultAttrs = op.getResAttrsAttr(); + p.printType(resultTypes[0]); + if (resultAttrs) + p.printOptionalAttrDict( + llvm::cast(resultAttrs[0]).getValue()); + } +} + +void printFunctionOp(OpAsmPrinter &p, FuncOp op, bool isVariadic, + StringRef typeAttrName, StringAttr argAttrsName, + StringAttr resAttrsName) { + // Print the operation and the function name. + auto funcName = + op->getAttrOfType(SymbolTable::getSymbolAttrName()) + .getValue(); + p << ' '; + + StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName(); + if (auto visibility = op->getAttrOfType(visibilityAttrName)) + p << visibility.getValue() << ' '; + p.printSymbolName(funcName); + + ArrayRef argTypes = op.getArgumentTypes(); + ArrayRef resultTypes = op.getResultTypes(); + printFunctionSignature(p, op, argTypes, isVariadic, resultTypes); + function_interface_impl::printFunctionAttributes( + p, op, {visibilityAttrName, typeAttrName, argAttrsName, resAttrsName}); + // Print the body if this is not an external function. + Region &body = op->getRegion(0); + if (!body.empty()) { + p << ' '; + p.printRegion(body, /*printEntryBlockArgs=*/false, + /*printBlockTerminators=*/true); + } +} + +} // namespace mlir::emitc diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp index f61de4a420a649..021bad1f6a7cc1 100644 --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -9,6 +9,7 @@ #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/EmitC/IR/EmitC.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" @@ -122,6 +123,9 @@ struct CppEmitter { /// Emits operation 'op' with/without training semicolon or returns failure. LogicalResult emitOperation(Operation &op, bool trailingSemicolon); + /// Emits a reference to type 'type' or returns failure. + LogicalResult emitReferenceToType(Location loc, Type type); + /// Emits type 'type' or returns failure. LogicalResult emitType(Location loc, Type type); @@ -143,8 +147,8 @@ struct CppEmitter { bool trailingSemicolon); /// Emits a declaration of a variable with the given type and name. - LogicalResult emitVariableDeclaration(Location loc, Type type, - StringRef name); + LogicalResult emitVariableDeclaration(Location loc, Type type, StringRef name, + bool isReference); /// Emits the variable declaration and assignment prefix for 'op'. /// - emits separate variable followed by std::tie for multi-valued operation; @@ -726,8 +730,14 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::CastOp castOp) { if (failed(emitter.emitAssignPrefix(op))) return failure(); os << "("; - if (failed(emitter.emitType(op.getLoc(), op.getResult(0).getType()))) - return failure(); + if (castOp.getReference()) { + if (failed(emitter.emitReferenceToType(op.getLoc(), + op.getResult(0).getType()))) + return failure(); + } else { + if (failed(emitter.emitType(op.getLoc(), op.getResult(0).getType()))) + return failure(); + } os << ") "; return emitter.emitOperand(castOp.getOperand()); } @@ -914,26 +924,73 @@ static LogicalResult printOperation(CppEmitter &emitter, ModuleOp moduleOp) { return success(); } +template static LogicalResult printFunctionArgs(CppEmitter &emitter, - Operation *functionOp, + FuncOpClass functionOp, ArrayRef arguments) { raw_indented_ostream &os = emitter.ostream(); + return (interleaveCommaWithError( + llvm::enumerate(arguments), os, [&](auto arg) -> LogicalResult { + bool hasReference = + functionOp.template getArgAttrOfType( + arg.index(), emitc::getReferenceAttributeName()) != nullptr; + if (hasReference) + return emitter.emitReferenceToType(functionOp->getLoc(), arg.value()); + return emitter.emitType(functionOp->getLoc(), arg.value()); + })); +} + +static LogicalResult printFunctionArgs(CppEmitter &emitter, + Operation *functionOp, + ArrayRef arguments) { + if (auto emitCDialectFunc = dyn_cast(functionOp)) { + return printFunctionArgs(emitter, emitCDialectFunc, arguments); + } + if (auto funcDialectFunc = dyn_cast(functionOp)) { + return printFunctionArgs(emitter, funcDialectFunc, arguments); + } + + raw_indented_ostream &os = emitter.ostream(); return ( interleaveCommaWithError(arguments, os, [&](Type arg) -> LogicalResult { return emitter.emitType(functionOp->getLoc(), arg); })); } +template static LogicalResult printFunctionArgs(CppEmitter &emitter, - Operation *functionOp, + FuncOpClass functionOp, Region::BlockArgListType arguments) { raw_indented_ostream &os = emitter.ostream(); return (interleaveCommaWithError( arguments, os, [&](BlockArgument arg) -> LogicalResult { + bool hasReference = functionOp.template getArgAttrOfType( + arg.getArgNumber(), + emitc::getReferenceAttributeName()) != nullptr; return emitter.emitVariableDeclaration( - functionOp->getLoc(), arg.getType(), emitter.getOrCreateName(arg)); + functionOp->getLoc(), arg.getType(), emitter.getOrCreateName(arg), + hasReference); + })); +} + +static LogicalResult printFunctionArgs(CppEmitter &emitter, + Operation *functionOp, + Region::BlockArgListType arguments) { + if (auto emitCDialectFunc = dyn_cast(functionOp)) { + return printFunctionArgs(emitter, emitCDialectFunc, arguments); + } + if (auto funcDialectFunc = dyn_cast(functionOp)) { + return printFunctionArgs(emitter, funcDialectFunc, arguments); + } + + raw_indented_ostream &os = emitter.ostream(); + return (interleaveCommaWithError( + arguments, os, [&](BlockArgument arg) -> LogicalResult { + return emitter.emitVariableDeclaration( + functionOp->getLoc(), arg.getType(), emitter.getOrCreateName(arg), + /*isReference=*/false); })); } @@ -1401,9 +1458,18 @@ LogicalResult CppEmitter::emitVariableDeclaration(OpResult result, return result.getDefiningOp()->emitError( "result variable for the operation already declared"); } + Operation *definingOp = result.getDefiningOp(); + bool isReference = false; + // List all ops that can produce references here + if (auto castOp = llvm::dyn_cast(definingOp)) { + isReference = castOp.getReference(); + } + if (auto globalOp = llvm::dyn_cast(definingOp)) { + isReference = globalOp.getReference(); + } if (failed(emitVariableDeclaration(result.getOwner()->getLoc(), - result.getType(), - getOrCreateName(result)))) + result.getType(), getOrCreateName(result), + isReference))) return failure(); if (trailingSemicolon) os << ";\n"; @@ -1419,7 +1485,7 @@ LogicalResult CppEmitter::emitGlobalVariable(GlobalOp op) { os << "const "; if (failed(emitVariableDeclaration(op->getLoc(), op.getType(), - op.getSymName()))) { + op.getSymName(), op.getReference()))) { return failure(); } @@ -1525,11 +1591,17 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) { } LogicalResult CppEmitter::emitVariableDeclaration(Location loc, Type type, - StringRef name) { + StringRef name, + bool isReference) { if (auto arrType = dyn_cast(type)) { if (failed(emitType(loc, arrType.getElementType()))) return failure(); - os << " " << name; + os << " "; + if (isReference) + os << "(&"; + os << name; + if (isReference) + os << ")"; for (auto dim : arrType.getShape()) { os << "[" << dim << "]"; } @@ -1537,7 +1609,25 @@ LogicalResult CppEmitter::emitVariableDeclaration(Location loc, Type type, } if (failed(emitType(loc, type))) return failure(); - os << " " << name; + os << " "; + if (isReference) + os << "&"; + os << name; + return success(); +} + +LogicalResult CppEmitter::emitReferenceToType(Location loc, Type type) { + if (auto aType = dyn_cast(type)) { + if (failed(emitType(loc, aType.getElementType()))) + return failure(); + os << " (&)"; + for (auto dim : aType.getShape()) + os << "[" << dim << "]"; + return success(); + } + if (failed(emitType(loc, type))) + return failure(); + os << " &"; return success(); } diff --git a/mlir/test/Dialect/EmitC/func.mlir b/mlir/test/Dialect/EmitC/func.mlir new file mode 100644 index 00000000000000..c047958f44b457 --- /dev/null +++ b/mlir/test/Dialect/EmitC/func.mlir @@ -0,0 +1,27 @@ +// RUN: mlir-opt %s -split-input-file + +// CHECK: emitc.func @f +// CHECK-SAME: %{{[^:]*}}: i32 ref +emitc.func @f(%x: i32 {emitc.reference}) { + emitc.return +} + +// ----- + +// CHECK: emitc.func @f +// CHECK-SAME: %{{[^:]*}}: i32 ref +emitc.func @f(%x: i32 ref) { + emitc.return +} + +// ----- + +// CHECK: emitc.func @f +// CHECK-SAME: i32 ref +emitc.func @f(i32 ref) + +// ----- + +// CHECK: emitc.func @f +// CHECK-SAME: i32 ref +emitc.func @f(i32 {emitc.reference}) diff --git a/mlir/test/Dialect/EmitC/invalid_ops.mlir b/mlir/test/Dialect/EmitC/invalid_ops.mlir index 8cd8bdca4df336..aa2c969b05cc23 100644 --- a/mlir/test/Dialect/EmitC/invalid_ops.mlir +++ b/mlir/test/Dialect/EmitC/invalid_ops.mlir @@ -317,7 +317,7 @@ func.func @test_expression_multiple_results(%arg0: i32) -> i32 { // ----- -// expected-error @+1 {{'emitc.func' op requires zero or exactly one result, but has 2}} +// expected-error @+1 {{expected ')'}} emitc.func @multiple_results(%0: i32) -> (i32, i32) { emitc.return %0 : i32 } @@ -450,3 +450,13 @@ func.func @use_global() { %0 = emitc.get_global @myglobal : f32 return } + +// ----- + +// expected-error @+1 {{'emitc.global' op global reference initial value must be an opaque attribute, got dense<128>}} +emitc.global const @myref : !emitc.array<2xi16> = dense<128> ref + +// ----- + +// expected-error @+1 {{'emitc.global' op global reference must be initialized}} +emitc.global const @myref : !emitc.array<2xi16> ref diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir index 6cfacca6446cbb..7b11c230e9a9dd 100644 --- a/mlir/test/Dialect/EmitC/ops.mlir +++ b/mlir/test/Dialect/EmitC/ops.mlir @@ -242,6 +242,7 @@ emitc.global extern @external_linkage : i32 emitc.global static @internal_linkage : i32 emitc.global @myglobal : !emitc.array<2xf32> = dense<4.000000e+00> emitc.global const @myconstant : !emitc.array<2xi16> = dense<2> +emitc.global const @myref : !emitc.array<2xi16> = #emitc.opaque<"myconstant"> ref func.func @use_global(%i: index) -> f32 { %0 = emitc.get_global @myglobal : !emitc.array<2xf32> diff --git a/mlir/test/Target/Cpp/common-cpp.mlir b/mlir/test/Target/Cpp/common-cpp.mlir index 0e24bdd19993f0..a638263cf1350c 100644 --- a/mlir/test/Target/Cpp/common-cpp.mlir +++ b/mlir/test/Target/Cpp/common-cpp.mlir @@ -94,3 +94,13 @@ func.func @apply(%arg0: i32) -> !emitc.ptr { func.func @array_type(%arg0: !emitc.array<3xi32>, %arg1: !emitc.array<10x20xf32>) { return } + +// CHECK: void arg_references(int32_t (&v1)[3], float (&v2)[10][20], int32_t &v3) +func.func @arg_references(%arg0: !emitc.array<3xi32> {emitc.reference}, %arg1: !emitc.array<10x20xf32> {emitc.reference}, %arg2: i32 {emitc.reference}) { + return +} + +// CHECK: void emitc_arg_references(int32_t (&v1)[3], float (&v2)[10][20], int32_t &v3) +emitc.func @emitc_arg_references(%arg0: !emitc.array<3xi32> ref, %arg1: !emitc.array<10x20xf32> ref, %arg2: i32 ref) { + emitc.return +} diff --git a/mlir/test/Target/Cpp/declare_func.mlir b/mlir/test/Target/Cpp/declare_func.mlir index 00680d71824ae0..40f73d659b368c 100644 --- a/mlir/test/Target/Cpp/declare_func.mlir +++ b/mlir/test/Target/Cpp/declare_func.mlir @@ -22,3 +22,17 @@ emitc.declare_func @array_arg emitc.func @array_arg(%arg0: !emitc.array<3xi32>) { emitc.return } + +// CHECK: void reference_scalar_arg(int32_t &[[V2:[^ ]*]]); +emitc.declare_func @reference_scalar_arg +// CHECK: void reference_scalar_arg(int32_t &[[V2:[^ ]*]]) { +emitc.func @reference_scalar_arg(%arg0: i32 ref) { + emitc.return +} + +// CHECK: void reference_array_arg(int32_t (&[[V2:[^ ]*]])[3]); +emitc.declare_func @reference_array_arg +// CHECK: void reference_array_arg(int32_t (&[[V2:[^ ]*]])[3]) { +emitc.func @reference_array_arg(%arg0: !emitc.array<3xi32> ref) { + emitc.return +} diff --git a/mlir/test/Target/Cpp/func.mlir b/mlir/test/Target/Cpp/func.mlir index 9c9ea55bfc4e1a..8adb6a8adbf2fb 100644 --- a/mlir/test/Target/Cpp/func.mlir +++ b/mlir/test/Target/Cpp/func.mlir @@ -43,3 +43,12 @@ emitc.func private @extern_func(i32) attributes {specifiers = ["extern"]} emitc.func private @array_arg(!emitc.array<3xi32>) attributes {specifiers = ["extern"]} // CPP-DEFAULT: extern void array_arg(int32_t[3]); + +emitc.func private @reference_scalar_arg(i32 ref) attributes {specifiers = ["extern"]} +// CPP-DEFAULT: extern void reference_scalar_arg(int32_t &); + +emitc.func private @reference_array_arg(!emitc.array<3xi32> ref) attributes {specifiers = ["extern"]} +// CPP-DEFAULT: extern void reference_array_arg(int32_t (&)[3]); + +emitc.func private @reference_multi_arg(!emitc.array<3xi32> ref, !emitc.array<3xi32>, i32 ref) attributes {specifiers = ["extern"]} +// CPP-DEFAULT: extern void reference_multi_arg(int32_t (&)[3], int32_t[3], int32_t &); diff --git a/mlir/test/Target/Cpp/global.mlir b/mlir/test/Target/Cpp/global.mlir index f0d92e862ae322..059c991fd7839f 100644 --- a/mlir/test/Target/Cpp/global.mlir +++ b/mlir/test/Target/Cpp/global.mlir @@ -36,3 +36,6 @@ func.func @use_global(%i: index) -> f32 { // CHECK-SAME: (size_t [[V1:.*]]) // CHECK: return myglobal[[[V1]]]; } + +emitc.global @ref : i32 = #emitc.opaque<"myglobal_int"> ref +// CHECK: int32_t &ref = myglobal_int; diff --git a/mlir/test/Target/Cpp/invalid.mlir b/mlir/test/Target/Cpp/invalid.mlir index 513371a09cde1d..b7373a03c638d9 100644 --- a/mlir/test/Target/Cpp/invalid.mlir +++ b/mlir/test/Target/Cpp/invalid.mlir @@ -85,3 +85,11 @@ func.func @ptr_to_array() { %v = "emitc.variable"(){value = #emitc.opaque<"NULL">} : () -> !emitc.ptr> return } + +// ----- + +func.func @cast_ref(%arg0 : i32) { + // expected-error@+1 {{'emitc.cast' op cast of value type must not bear a reference}} + %1 = emitc.cast %arg0 : i32 to i32 ref + return +} From dfb5921d93bc48247d34a947585eb77d541ef795 Mon Sep 17 00:00:00 2001 From: Corentin Ferry Date: Mon, 9 Sep 2024 09:15:10 +0200 Subject: [PATCH 03/46] [FXML-4791] Lower memref expand/collapse to EmitC (#313) --- mlir/include/mlir/Dialect/EmitC/IR/EmitC.td | 3 +- .../MemRefToEmitC/MemRefToEmitC.cpp | 67 ++++++++++++++++++- mlir/lib/Dialect/EmitC/IR/EmitC.cpp | 21 +++++- .../MemRefToEmitC/memref-to-emitc-failed.mlir | 18 +++++ .../MemRefToEmitC/memref-to-emitc.mlir | 19 ++++++ mlir/test/Dialect/EmitC/invalid_ops.mlir | 2 +- mlir/test/Dialect/EmitC/ops.mlir | 5 ++ mlir/test/Target/Cpp/cast.mlir | 9 +++ 8 files changed, 139 insertions(+), 5 deletions(-) diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td index 0c945ab2c40304..bbb539c2b3f2a7 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td @@ -265,8 +265,7 @@ def EmitC_CallOpaqueOp : EmitC_Op<"call_opaque", [CExpression]> { def EmitC_CastOp : EmitC_Op<"cast", [CExpression, - DeclareOpInterfaceMethods, - SameOperandsAndResultShape]> { + DeclareOpInterfaceMethods]> { let summary = "Cast operation"; let description = [{ The `cast` operation performs an explicit type conversion and is emitted diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp index e0c421741b3055..f6ce553dd899a0 100644 --- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp +++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp @@ -15,6 +15,8 @@ #include "mlir/Dialect/EmitC/IR/EmitC.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/DialectConversion.h" @@ -166,6 +168,68 @@ struct ConvertStore final : public OpConversionPattern { return success(); } }; + +struct ConvertCollapseShape final + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::CollapseShapeOp op, OpAdaptor operands, + ConversionPatternRewriter &rewriter) const override { + auto arrayValue = dyn_cast>(operands.getSrc()); + if (!arrayValue) { + return rewriter.notifyMatchFailure(op.getLoc(), "expected array type"); + } + + auto resultTy = getTypeConverter()->convertType(op.getType()); + if (!resultTy) { + return rewriter.notifyMatchFailure(op.getLoc(), + "cannot convert result type"); + } + + // Do not generate casts between arrays with dynamic shapes + if (!arrayValue.getType().hasStaticShape()) + return rewriter.notifyMatchFailure(op.getLoc(), + "dynamic shapes not supported"); + auto newCastOp = rewriter.create(op->getLoc(), resultTy, + operands.getSrc()); + newCastOp.setReference(true); + rewriter.replaceOp(op, newCastOp); + return success(); + } +}; + +struct ConvertExpandShape final + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::ExpandShapeOp op, OpAdaptor operands, + ConversionPatternRewriter &rewriter) const override { + auto arrayValue = dyn_cast>(operands.getSrc()); + if (!arrayValue) { + return rewriter.notifyMatchFailure(op.getLoc(), "expected array type"); + } + + auto resultTy = getTypeConverter()->convertType(op.getType()); + if (!resultTy) { + return rewriter.notifyMatchFailure(op.getLoc(), + "cannot convert result type"); + } + + // Do not generate casts between arrays with dynamic shapes + if (!arrayValue.getType().hasStaticShape()) + return rewriter.notifyMatchFailure(op.getLoc(), + "dynamic shapes not supported"); + + auto newCastOp = rewriter.create(op->getLoc(), resultTy, + operands.getSrc()); + newCastOp.setReference(true); + rewriter.replaceOp(op, newCastOp); + return success(); + } +}; + } // namespace void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) { @@ -187,5 +251,6 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) { void mlir::populateMemRefToEmitCConversionPatterns(RewritePatternSet &patterns, TypeConverter &converter) { patterns.add(converter, patterns.getContext()); + ConvertStore, ConvertCollapseShape, ConvertExpandShape>( + converter, patterns.getContext()); } diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index aa2495bc42ba03..3f994344ffeee6 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -228,6 +228,15 @@ LogicalResult emitc::AssignOp::verify() { bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { Type input = inputs.front(), output = outputs.front(); + // Cast to array is only possible from an array + if (isa(input) != isa(output)) + return false; + + // Arrays can be casted to arrays by reference. + if (isa(input) && isa(output)) + return true; + + // Scalars return ( (emitc::isIntegerIndexOrOpaqueType(input) || emitc::isSupportedFloatType(input) || isa(input)) && @@ -236,7 +245,15 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { } LogicalResult CastOp::verify() { - if (getReference()) + bool isReference = getReference(); + + if (isa(getDest().getType())) { + if (!isReference) + return emitOpError("cast of array must bear a reference"); + return success(); + } + + if (isReference) return emitOpError("cast of value type must not bear a reference"); return success(); @@ -954,6 +971,8 @@ LogicalResult emitc::ArrayType::verify( for (int64_t dim : shape) { if (dim < 0) return emitError() << "dimensions must have non-negative size"; + if (dim == ShapedType::kDynamic) + return emitError() << "dimensions must have static size"; } if (!elementType) diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir index 89dafa7529ed53..4df7bac0b55806 100644 --- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir +++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc-failed.mlir @@ -43,3 +43,21 @@ func.func @zero_rank() { // expected-error@+1 {{failed to legalize operation 'memref.global'}} memref.global "nested" constant @nested_global : memref<3x7xf32> + +// ----- + +// CHECK-LABEL: memref_expand_dyn_shape +func.func @memref_expand_dyn_shape(%arg: memref, %size: index) -> memref { + // expected-error@+1 {{failed to legalize operation 'memref.expand_shape'}} + %0 = memref.expand_shape %arg [[0, 1]] output_shape [%size, 5] : memref into memref + return %0 : memref +} + +// ----- + +// CHECK-LABEL: memref_collapse_dyn_shape +func.func @memref_collapse_dyn_shape(%arg: memref) -> memref { + // expected-error@+1 {{failed to legalize operation 'memref.collapse_shape'}} + %0 = memref.collapse_shape %arg [[0, 1]] : memref into memref + return %0 : memref +} diff --git a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir index ffb0e10d80893a..96e4486f5a8191 100644 --- a/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir +++ b/mlir/test/Conversion/MemRefToEmitC/memref-to-emitc.mlir @@ -73,3 +73,22 @@ func.func @memref_index_values(%i: index, %j: index) -> index { // CHECK: return %[[CAST_RET]] : index return %1 : index } + +// ----- + +// CHECK-LABEL: memref_expand_shape +func.func @memref_expand_shape(%arg: memref<10xi32>) -> memref<2x5xi32> { + // CHECK: emitc.cast %{{[^ ]*}} : !emitc.array<10xi32> to !emitc.array<2x5xi32> ref + %0 = memref.expand_shape %arg [[0, 1]] output_shape [2, 5] : memref<10xi32> into memref<2x5xi32> + return %0 : memref<2x5xi32> +} + + +// ----- + +// CHECK-LABEL: memref_collapse_shape +func.func @memref_collapse_shape(%arg: memref<2x5xi32>) -> memref<10xi32> { + // CHECK: emitc.cast %{{[^ ]*}} : !emitc.array<2x5xi32> to !emitc.array<10xi32> ref + %0 = memref.collapse_shape %arg [[0, 1]] : memref<2x5xi32> into memref<10xi32> + return %0 : memref<10xi32> +} diff --git a/mlir/test/Dialect/EmitC/invalid_ops.mlir b/mlir/test/Dialect/EmitC/invalid_ops.mlir index aa2c969b05cc23..31e065155f0922 100644 --- a/mlir/test/Dialect/EmitC/invalid_ops.mlir +++ b/mlir/test/Dialect/EmitC/invalid_ops.mlir @@ -138,7 +138,7 @@ func.func @cast_tensor(%arg : tensor) { // ----- func.func @cast_array(%arg : !emitc.array<4xf32>) { - // expected-error @+1 {{'emitc.cast' op operand type '!emitc.array<4xf32>' and result type '!emitc.array<4xf32>' are cast incompatible}} + // expected-error @+1 {{'emitc.cast' op cast of array must bear a reference}} %1 = emitc.cast %arg: !emitc.array<4xf32> to !emitc.array<4xf32> return } diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir index 7b11c230e9a9dd..482b08a0b68687 100644 --- a/mlir/test/Dialect/EmitC/ops.mlir +++ b/mlir/test/Dialect/EmitC/ops.mlir @@ -39,6 +39,11 @@ func.func @cast(%arg0: i32) { return } +func.func @cast_array(%arg : !emitc.array<4xf32>) { + %1 = emitc.cast %arg: !emitc.array<4xf32> to !emitc.array<4xf32> ref + return +} + func.func @c() { %1 = "emitc.constant"(){value = 42 : i32} : () -> i32 %2 = "emitc.constant"(){value = 42 : index} : () -> !emitc.size_t diff --git a/mlir/test/Target/Cpp/cast.mlir b/mlir/test/Target/Cpp/cast.mlir index 7254f84e237f40..c4d26ebdcdec9a 100644 --- a/mlir/test/Target/Cpp/cast.mlir +++ b/mlir/test/Target/Cpp/cast.mlir @@ -28,3 +28,12 @@ func.func @cast_ptr(%arg0 : !emitc.ptr>) { %1 = emitc.cast %arg0 : !emitc.ptr> to !emitc.ptr return } + +// CHECK-LABEL: void cast_array +func.func @cast_array(%arg0: !emitc.array<10xi32>) { + // CHECK-NEXT: int32_t (&[[V1:[^ ]*]])[2][5] = (int32_t (&)[2][5]) [[V0:[^ ]*]] + %1 = emitc.cast %arg0 : !emitc.array<10xi32> to !emitc.array<2x5xi32> ref + // CHECK-NEXT: int32_t (&[[V2:[^ ]*]])[10] = (int32_t (&)[10]) [[V1]] + %2 = emitc.cast %1 : !emitc.array<2x5xi32> to !emitc.array<10xi32> ref + return +} From fa9c1f81cb469dd398d88f19decbaf5f896bcf43 Mon Sep 17 00:00:00 2001 From: Corentin Ferry Date: Mon, 9 Sep 2024 10:55:30 +0200 Subject: [PATCH 04/46] Hotfix to function assembly printer (#331) --- mlir/lib/Dialect/EmitC/IR/FunctionOpAssembly.cpp | 9 ++++++++- mlir/test/Dialect/EmitC/func.mlir | 2 +- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/EmitC/IR/FunctionOpAssembly.cpp b/mlir/lib/Dialect/EmitC/IR/FunctionOpAssembly.cpp index 0db97a5890868c..5d3375f26d588f 100644 --- a/mlir/lib/Dialect/EmitC/IR/FunctionOpAssembly.cpp +++ b/mlir/lib/Dialect/EmitC/IR/FunctionOpAssembly.cpp @@ -244,10 +244,14 @@ void printFunctionSignature(OpAsmPrinter &p, FuncOp op, ArrayRef argTypes, // Exclude reference attribute if there is to replace it by ref SmallVector attrs; + bool isReference = false; if (argAttrs) { for (auto attr : llvm::cast(argAttrs[i]).getValue()) { - if (attr.getName() != emitc::getReferenceAttributeName()) + if (attr.getName() != emitc::getReferenceAttributeName()) { attrs.push_back(attr); + } else { + isReference = true; + } } } @@ -258,6 +262,9 @@ void printFunctionSignature(OpAsmPrinter &p, FuncOp op, ArrayRef argTypes, if (argAttrs) p.printOptionalAttrDict(attrs); } + + if (isReference) + p << " ref"; } if (isVariadic) { diff --git a/mlir/test/Dialect/EmitC/func.mlir b/mlir/test/Dialect/EmitC/func.mlir index c047958f44b457..c7486bc493c315 100644 --- a/mlir/test/Dialect/EmitC/func.mlir +++ b/mlir/test/Dialect/EmitC/func.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -split-input-file +// RUN: mlir-opt %s -split-input-file | FileCheck %s // CHECK: emitc.func @f // CHECK-SAME: %{{[^:]*}}: i32 ref From 276cbead8c909b5fcd72fa9ea5bd398dad30a74e Mon Sep 17 00:00:00 2001 From: Jose Lopes Date: Wed, 18 Sep 2024 10:50:06 +0100 Subject: [PATCH 05/46] Relaxe affine --- mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 8 +++++--- mlir/test/Dialect/Affine/invalid.mlir | 2 -- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp index 6baed92d208fb3..e5ceb122507e16 100644 --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -499,11 +499,13 @@ ParseResult mlir::affine::parseDimAndSymbolList( template static LogicalResult verifyDimAndSymbolIdentifiers(OpTy &op, Operation::operand_range operands, - unsigned numDims) { + unsigned numDims, + bool allowNonAffineDimOperands = false) { unsigned opIt = 0; for (auto operand : operands) { if (opIt++ < numDims) { - if (!isValidDim(operand, getAffineScope(op))) + if (!isValidDim(operand, getAffineScope(op)) && + !(allowNonAffineDimOperands && operand.getType().isIndex())) return op.emitOpError("operand cannot be used as a dimension id"); } else if (!isValidSymbol(operand, getAffineScope(op))) { return op.emitOpError("operand cannot be used as a symbol"); @@ -2804,7 +2806,7 @@ LogicalResult AffineIfOp::verify() { // Verify that the operands are valid dimension/symbols. if (failed(verifyDimAndSymbolIdentifiers(*this, getOperands(), - condition.getNumDims()))) + condition.getNumDims(), true))) return failure(); return success(); diff --git a/mlir/test/Dialect/Affine/invalid.mlir b/mlir/test/Dialect/Affine/invalid.mlir index 60f13102f55156..709d197bd4cf89 100644 --- a/mlir/test/Dialect/Affine/invalid.mlir +++ b/mlir/test/Dialect/Affine/invalid.mlir @@ -97,7 +97,6 @@ func.func @affine_if_invalid_dim(%arg : index) { affine.for %n0 = 0 to 7 { %dim = arith.addi %arg, %arg : index - // expected-error@+1 {{operand cannot be used as a dimension id}} affine.if #set0(%dim)[%n0] {} } return @@ -109,7 +108,6 @@ func.func @affine_if_invalid_dim(%arg : index) { func.func @affine_if_invalid_sym() { affine.for %i0 = 0 to 7 { - // expected-error@+1 {{operand cannot be used as a symbol}} affine.if #set0(%i0)[%i0] {} } return From c391a4958a29a198204836e4f1aa4a3628cede16 Mon Sep 17 00:00:00 2001 From: Jose Lopes Date: Wed, 18 Sep 2024 13:40:56 +0100 Subject: [PATCH 06/46] Address comments --- mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 3 ++- mlir/test/Dialect/Affine/invalid.mlir | 4 +++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp index e5ceb122507e16..f37b9ce7ab3443 100644 --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -2806,7 +2806,8 @@ LogicalResult AffineIfOp::verify() { // Verify that the operands are valid dimension/symbols. if (failed(verifyDimAndSymbolIdentifiers(*this, getOperands(), - condition.getNumDims(), true))) + condition.getNumDims(), + /*allowNonAffineDimOperands=*/true))) return failure(); return success(); diff --git a/mlir/test/Dialect/Affine/invalid.mlir b/mlir/test/Dialect/Affine/invalid.mlir index 709d197bd4cf89..c1e4f6328bc270 100644 --- a/mlir/test/Dialect/Affine/invalid.mlir +++ b/mlir/test/Dialect/Affine/invalid.mlir @@ -96,7 +96,8 @@ func.func @affine_for_upper_bound_invalid_sym() { func.func @affine_if_invalid_dim(%arg : index) { affine.for %n0 = 0 to 7 { %dim = arith.addi %arg, %arg : index - + // Non-affine operand %dim has been made legal as input to affine.if. + // expected-error@+1 {{operand cannot be used as a symbol}} affine.if #set0(%dim)[%n0] {} } return @@ -108,6 +109,7 @@ func.func @affine_if_invalid_dim(%arg : index) { func.func @affine_if_invalid_sym() { affine.for %i0 = 0 to 7 { + // expected-error@+1 {{operand cannot be used as a symbol}} affine.if #set0(%i0)[%i0] {} } return From fe35b4b27d34e575ecb9f85555ecf636ad26dc48 Mon Sep 17 00:00:00 2001 From: Corentin Ferry Date: Thu, 26 Sep 2024 10:20:42 +0100 Subject: [PATCH 07/46] For lowering: use adaptor operands --- mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp index 171a9b7c92dc55..71c566eb80a2d0 100644 --- a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp +++ b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp @@ -113,7 +113,7 @@ ForLowering::matchAndRewrite(ForOp forOp, OpAdaptor adaptor, return rewriter.notifyMatchFailure(forOp, "create variables for results failed"); - assignValues(forOp.getInits(), resultVariables, rewriter, loc); + assignValues(adaptor.getInitArgs(), resultVariables, rewriter, loc); emitc::ForOp loweredFor = rewriter.create( loc, adaptor.getLowerBound(), adaptor.getUpperBound(), adaptor.getStep()); From 48759d351affa824619aeecfc1b3883c1ce75f6f Mon Sep 17 00:00:00 2001 From: Corentin Ferry Date: Fri, 27 Sep 2024 12:47:26 +0100 Subject: [PATCH 08/46] Dyanmic shape check before negative size --- mlir/lib/Dialect/EmitC/IR/EmitC.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index 3f994344ffeee6..038003246f629a 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -969,10 +969,10 @@ LogicalResult emitc::ArrayType::verify( return emitError() << "shape must not be empty"; for (int64_t dim : shape) { - if (dim < 0) - return emitError() << "dimensions must have non-negative size"; if (dim == ShapedType::kDynamic) return emitError() << "dimensions must have static size"; + if (dim < 0) + return emitError() << "dimensions must have non-negative size"; } if (!elementType) From 93496a98eef23db84f52bf83ba339733e8d17eee Mon Sep 17 00:00:00 2001 From: Corentin Ferry Date: Fri, 27 Sep 2024 16:14:16 +0100 Subject: [PATCH 09/46] Pass to lower UB to EmitC --- mlir/include/mlir/Conversion/Passes.h | 2 +- mlir/include/mlir/Conversion/Passes.td | 12 +++ .../mlir/Conversion/UBToEmitC/UBToEmitC.h | 25 ++++++ mlir/lib/Conversion/CMakeLists.txt | 1 + mlir/lib/Conversion/UBToEmitC/CMakeLists.txt | 17 ++++ mlir/lib/Conversion/UBToEmitC/UBToEmitC.cpp | 85 +++++++++++++++++++ .../convert-ub-to-emitc-unsupported.mlir | 15 ++++ .../UBToEmitC/convert-ub-to-emitc.mlir | 12 +++ 8 files changed, 168 insertions(+), 1 deletion(-) create mode 100644 mlir/include/mlir/Conversion/UBToEmitC/UBToEmitC.h create mode 100644 mlir/lib/Conversion/UBToEmitC/CMakeLists.txt create mode 100644 mlir/lib/Conversion/UBToEmitC/UBToEmitC.cpp create mode 100644 mlir/test/Conversion/UBToEmitC/convert-ub-to-emitc-unsupported.mlir create mode 100644 mlir/test/Conversion/UBToEmitC/convert-ub-to-emitc.mlir diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h index ac985f7c04c7f5..b14607fc3c3ce7 100644 --- a/mlir/include/mlir/Conversion/Passes.h +++ b/mlir/include/mlir/Conversion/Passes.h @@ -12,7 +12,6 @@ #include "mlir/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.h" #include "mlir/Conversion/AffineToStandard/AffineToStandard.h" #include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h" -#include "mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h" #include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h" #include "mlir/Conversion/ArithToEmitC/ArithToEmitCPass.h" #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" @@ -71,6 +70,7 @@ #include "mlir/Conversion/TosaToMLProgram/TosaToMLProgram.h" #include "mlir/Conversion/TosaToSCF/TosaToSCF.h" #include "mlir/Conversion/TosaToTensor/TosaToTensor.h" +#include "mlir/Conversion/UBToEmitC/UBToEmitC.h" #include "mlir/Conversion/UBToLLVM/UBToLLVM.h" #include "mlir/Conversion/UBToSPIRV/UBToSPIRV.h" #include "mlir/Conversion/VectorToArmSME/VectorToArmSME.h" diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index b40e9179fc926f..022e2b470ce700 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -1210,6 +1210,18 @@ def TosaToTensor : Pass<"tosa-to-tensor"> { let constructor = "tosa::createTosaToTensor()"; } +//===----------------------------------------------------------------------===// +// UBToEmitC +//===----------------------------------------------------------------------===// + +def ConvertUBToEmitC : Pass<"convert-ub-to-emitc"> { + let summary = "Convert UB dialect to EmitC dialect"; + let description = [{ + This pass converts supported UB ops to EmitC dialect. + }]; + let dependentDialects = ["emitc::EmitCDialect"]; +} + //===----------------------------------------------------------------------===// // UBToLLVM //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Conversion/UBToEmitC/UBToEmitC.h b/mlir/include/mlir/Conversion/UBToEmitC/UBToEmitC.h new file mode 100644 index 00000000000000..32d0e689e6f6d6 --- /dev/null +++ b/mlir/include/mlir/Conversion/UBToEmitC/UBToEmitC.h @@ -0,0 +1,25 @@ +//===- UBToEmitC.h - UB to EmitC dialect conversion -------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_UBTOEMITC_UBTOEMITC_H +#define MLIR_CONVERSION_UBTOEMITC_UBTOEMITC_H + +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { +#define GEN_PASS_DECL_CONVERTUBTOEMITC +#include "mlir/Conversion/Passes.h.inc" + +namespace ub { +void populateUBToEmitCConversionPatterns(TypeConverter &converter, + RewritePatternSet &patterns); +} // namespace ub +} // namespace mlir + +#endif // MLIR_CONVERSION_UBTOEMITC_UBTOEMITC_H \ No newline at end of file diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt index e107738a4c50c0..a7698b663ba2ee 100644 --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -60,6 +60,7 @@ add_subdirectory(TosaToLinalg) add_subdirectory(TosaToMLProgram) add_subdirectory(TosaToSCF) add_subdirectory(TosaToTensor) +add_subdirectory(UBToEmitC) add_subdirectory(UBToLLVM) add_subdirectory(UBToSPIRV) add_subdirectory(VectorToArmSME) diff --git a/mlir/lib/Conversion/UBToEmitC/CMakeLists.txt b/mlir/lib/Conversion/UBToEmitC/CMakeLists.txt new file mode 100644 index 00000000000000..5fe43e100d9855 --- /dev/null +++ b/mlir/lib/Conversion/UBToEmitC/CMakeLists.txt @@ -0,0 +1,17 @@ +add_mlir_conversion_library(MLIRUBToEmitC + UBToEmitC.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/UBToEmitC + + DEPENDS + MLIRConversionPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRLLVMCommonConversion + MLIREmitCDialect + MLIRUBDialect + ) diff --git a/mlir/lib/Conversion/UBToEmitC/UBToEmitC.cpp b/mlir/lib/Conversion/UBToEmitC/UBToEmitC.cpp new file mode 100644 index 00000000000000..fea605736bffc3 --- /dev/null +++ b/mlir/lib/Conversion/UBToEmitC/UBToEmitC.cpp @@ -0,0 +1,85 @@ +//===- UBToEmitC.cpp - UB to EmitC dialect conversion ---------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/UBToEmitC/UBToEmitC.h" + +#include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/Dialect/EmitC/Transforms/TypeConversions.h" +#include "mlir/Dialect/UB/IR/UBOps.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { +#define GEN_PASS_DEF_CONVERTUBTOEMITC +#include "mlir/Conversion/Passes.h.inc" +} // namespace mlir + +using namespace mlir; + +namespace { +struct PoisonOpLowering : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(ub::PoisonOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + const TypeConverter *converter = getTypeConverter(); + Type convertedType = converter->convertType(op.getType()); + + if (!convertedType) + return rewriter.notifyMatchFailure(op.getLoc(), "type conversion failed"); + + Attribute value; + if (emitc::isIntegerIndexOrOpaqueType(convertedType)) { + value = IntegerAttr::get((emitc::isPointerWideType(convertedType) + ? IndexType::get(getContext()) + : convertedType), + 0); + } else if (emitc::isSupportedFloatType(convertedType)) { + value = FloatAttr::get(convertedType, 0); + } else { + return rewriter.notifyMatchFailure( + op.getLoc(), "only scalar poison values can be lowered"); + } + + // Any constant will be fine to lower a poison op + rewriter.replaceOpWithNewOp(op, convertedType, value); + return success(); + } +}; +} // namespace + +void ub::populateUBToEmitCConversionPatterns(TypeConverter &converter, + RewritePatternSet &patterns) { + MLIRContext *ctx = patterns.getContext(); + patterns.add(converter, ctx); +} + +struct ConvertUBToEmitC : public impl::ConvertUBToEmitCBase { + using Base::Base; + + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + TypeConverter converter; + converter.addConversion([](Type t) { return t; }); + populateEmitCSizeTTypeConversions(converter); + + ConversionTarget target(getContext()); + target.addLegalDialect(); + target.addIllegalDialect(); + + mlir::ub::populateUBToEmitCConversionPatterns(converter, patterns); + + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) + signalPassFailure(); + } +}; diff --git a/mlir/test/Conversion/UBToEmitC/convert-ub-to-emitc-unsupported.mlir b/mlir/test/Conversion/UBToEmitC/convert-ub-to-emitc-unsupported.mlir new file mode 100644 index 00000000000000..e2556e36af20f8 --- /dev/null +++ b/mlir/test/Conversion/UBToEmitC/convert-ub-to-emitc-unsupported.mlir @@ -0,0 +1,15 @@ +// RUN: mlir-opt -convert-ub-to-emitc -split-input-file -verify-diagnostics %s + +func.func @poison_memref() { + // expected-error @+1 {{failed to legalize operation 'ub.poison'}} + %0 = ub.poison : memref + return +} + +// ----- + +func.func @poison_tensor() { + // expected-error @+1 {{failed to legalize operation 'ub.poison'}} + %1 = ub.poison : tensor + return +} \ No newline at end of file diff --git a/mlir/test/Conversion/UBToEmitC/convert-ub-to-emitc.mlir b/mlir/test/Conversion/UBToEmitC/convert-ub-to-emitc.mlir new file mode 100644 index 00000000000000..1d30ee9933be8f --- /dev/null +++ b/mlir/test/Conversion/UBToEmitC/convert-ub-to-emitc.mlir @@ -0,0 +1,12 @@ +// RUN: mlir-opt -convert-ub-to-emitc %s | FileCheck %s + +// CHECK-LABEL: func.func @poison +func.func @poison() { + // CHECK: "emitc.variable"(){{.*}}() -> i32 + %0 = ub.poison : i32 + // CHECK: "emitc.variable"(){{.*}}() -> f32 + %1 = ub.poison : f32 + // CHECK: "emitc.variable"(){{.*}}() -> !emitc.size_t + %2 = ub.poison : index + return +} \ No newline at end of file From 4d12bc8653ee65eeb170550a7443ea3a61a7dd14 Mon Sep 17 00:00:00 2001 From: Corentin Ferry Date: Mon, 30 Sep 2024 10:22:13 +0100 Subject: [PATCH 10/46] Lower to unitilalized variable --- mlir/lib/Conversion/UBToEmitC/UBToEmitC.cpp | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/mlir/lib/Conversion/UBToEmitC/UBToEmitC.cpp b/mlir/lib/Conversion/UBToEmitC/UBToEmitC.cpp index fea605736bffc3..29f073322cdf2c 100644 --- a/mlir/lib/Conversion/UBToEmitC/UBToEmitC.cpp +++ b/mlir/lib/Conversion/UBToEmitC/UBToEmitC.cpp @@ -37,21 +37,15 @@ struct PoisonOpLowering : public OpConversionPattern { if (!convertedType) return rewriter.notifyMatchFailure(op.getLoc(), "type conversion failed"); - Attribute value; - if (emitc::isIntegerIndexOrOpaqueType(convertedType)) { - value = IntegerAttr::get((emitc::isPointerWideType(convertedType) - ? IndexType::get(getContext()) - : convertedType), - 0); - } else if (emitc::isSupportedFloatType(convertedType)) { - value = FloatAttr::get(convertedType, 0); - } else { + if (!(emitc::isIntegerIndexOrOpaqueType(convertedType) || + emitc::isSupportedFloatType(convertedType))) { return rewriter.notifyMatchFailure( op.getLoc(), "only scalar poison values can be lowered"); } // Any constant will be fine to lower a poison op - rewriter.replaceOpWithNewOp(op, convertedType, value); + rewriter.replaceOpWithNewOp( + op, convertedType, emitc::OpaqueAttr::get(op->getContext(), "")); return success(); } }; From 8c1994b99a4963cd335248157224d49be5b0fc2f Mon Sep 17 00:00:00 2001 From: Corentin Ferry Date: Mon, 30 Sep 2024 12:47:45 +0200 Subject: [PATCH 11/46] Add missing newlines --- mlir/include/mlir/Conversion/UBToEmitC/UBToEmitC.h | 2 +- .../Conversion/UBToEmitC/convert-ub-to-emitc-unsupported.mlir | 2 +- mlir/test/Conversion/UBToEmitC/convert-ub-to-emitc.mlir | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mlir/include/mlir/Conversion/UBToEmitC/UBToEmitC.h b/mlir/include/mlir/Conversion/UBToEmitC/UBToEmitC.h index 32d0e689e6f6d6..9d208a0275e904 100644 --- a/mlir/include/mlir/Conversion/UBToEmitC/UBToEmitC.h +++ b/mlir/include/mlir/Conversion/UBToEmitC/UBToEmitC.h @@ -22,4 +22,4 @@ void populateUBToEmitCConversionPatterns(TypeConverter &converter, } // namespace ub } // namespace mlir -#endif // MLIR_CONVERSION_UBTOEMITC_UBTOEMITC_H \ No newline at end of file +#endif // MLIR_CONVERSION_UBTOEMITC_UBTOEMITC_H diff --git a/mlir/test/Conversion/UBToEmitC/convert-ub-to-emitc-unsupported.mlir b/mlir/test/Conversion/UBToEmitC/convert-ub-to-emitc-unsupported.mlir index e2556e36af20f8..9254ace00fb97d 100644 --- a/mlir/test/Conversion/UBToEmitC/convert-ub-to-emitc-unsupported.mlir +++ b/mlir/test/Conversion/UBToEmitC/convert-ub-to-emitc-unsupported.mlir @@ -12,4 +12,4 @@ func.func @poison_tensor() { // expected-error @+1 {{failed to legalize operation 'ub.poison'}} %1 = ub.poison : tensor return -} \ No newline at end of file +} diff --git a/mlir/test/Conversion/UBToEmitC/convert-ub-to-emitc.mlir b/mlir/test/Conversion/UBToEmitC/convert-ub-to-emitc.mlir index 1d30ee9933be8f..a6945e7e803718 100644 --- a/mlir/test/Conversion/UBToEmitC/convert-ub-to-emitc.mlir +++ b/mlir/test/Conversion/UBToEmitC/convert-ub-to-emitc.mlir @@ -9,4 +9,4 @@ func.func @poison() { // CHECK: "emitc.variable"(){{.*}}() -> !emitc.size_t %2 = ub.poison : index return -} \ No newline at end of file +} From ca88be2ef13d4cb2d4c7da0d45674816fb54097e Mon Sep 17 00:00:00 2001 From: Corentin Ferry Date: Mon, 30 Sep 2024 14:53:18 +0100 Subject: [PATCH 12/46] Add test for vector --- .../UBToEmitC/convert-ub-to-emitc-unsupported.mlir | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/mlir/test/Conversion/UBToEmitC/convert-ub-to-emitc-unsupported.mlir b/mlir/test/Conversion/UBToEmitC/convert-ub-to-emitc-unsupported.mlir index 9254ace00fb97d..684f6cec022a4f 100644 --- a/mlir/test/Conversion/UBToEmitC/convert-ub-to-emitc-unsupported.mlir +++ b/mlir/test/Conversion/UBToEmitC/convert-ub-to-emitc-unsupported.mlir @@ -13,3 +13,11 @@ func.func @poison_tensor() { %1 = ub.poison : tensor return } + +// ----- + +func.func @poison_vector() { + // expected-error @+1 {{failed to legalize operation 'ub.poison'}} + %1 = "ub.poison"() {value = #ub.poison} : () -> vector<4xi64> + return +} From 578a79ccedcd4b88d4774fef29b20b174e3a9d95 Mon Sep 17 00:00:00 2001 From: Corentin Ferry Date: Mon, 30 Sep 2024 16:03:30 +0100 Subject: [PATCH 13/46] Nit --- mlir/test/Conversion/UBToEmitC/convert-ub-to-emitc.mlir | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mlir/test/Conversion/UBToEmitC/convert-ub-to-emitc.mlir b/mlir/test/Conversion/UBToEmitC/convert-ub-to-emitc.mlir index a6945e7e803718..b57f984354e53f 100644 --- a/mlir/test/Conversion/UBToEmitC/convert-ub-to-emitc.mlir +++ b/mlir/test/Conversion/UBToEmitC/convert-ub-to-emitc.mlir @@ -2,11 +2,11 @@ // CHECK-LABEL: func.func @poison func.func @poison() { - // CHECK: "emitc.variable"(){{.*}}() -> i32 + // CHECK: "emitc.variable"{{.*}} -> i32 %0 = ub.poison : i32 - // CHECK: "emitc.variable"(){{.*}}() -> f32 + // CHECK: "emitc.variable"{{.*}} -> f32 %1 = ub.poison : f32 - // CHECK: "emitc.variable"(){{.*}}() -> !emitc.size_t + // CHECK: "emitc.variable"{{.*}} -> !emitc.size_t %2 = ub.poison : index return } From e48815fdb767a6129664b244a84bbfd2f0ef7b10 Mon Sep 17 00:00:00 2001 From: Corentin Ferry Date: Tue, 1 Oct 2024 08:27:13 +0100 Subject: [PATCH 14/46] Add option to initialize UB.poison variables --- mlir/include/mlir/Conversion/Passes.td | 5 +++ .../mlir/Conversion/UBToEmitC/UBToEmitC.h | 3 +- mlir/lib/Conversion/UBToEmitC/UBToEmitC.cpp | 34 +++++++++++++++---- .../convert-ub-to-emitc-no-init.mlir | 12 +++++++ .../UBToEmitC/convert-ub-to-emitc.mlir | 6 ++-- 5 files changed, 50 insertions(+), 10 deletions(-) create mode 100644 mlir/test/Conversion/UBToEmitC/convert-ub-to-emitc-no-init.mlir diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 022e2b470ce700..e60060502173e9 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -1220,6 +1220,11 @@ def ConvertUBToEmitC : Pass<"convert-ub-to-emitc"> { This pass converts supported UB ops to EmitC dialect. }]; let dependentDialects = ["emitc::EmitCDialect"]; + let options = [ + Option<"noInitialization", "no-initialization", "bool", + /*default=*/"false", + "Do not initialize the generated variables">, + ]; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Conversion/UBToEmitC/UBToEmitC.h b/mlir/include/mlir/Conversion/UBToEmitC/UBToEmitC.h index 9d208a0275e904..64a37c9304afc1 100644 --- a/mlir/include/mlir/Conversion/UBToEmitC/UBToEmitC.h +++ b/mlir/include/mlir/Conversion/UBToEmitC/UBToEmitC.h @@ -18,7 +18,8 @@ namespace mlir { namespace ub { void populateUBToEmitCConversionPatterns(TypeConverter &converter, - RewritePatternSet &patterns); + RewritePatternSet &patterns, + bool noInitialization); } // namespace ub } // namespace mlir diff --git a/mlir/lib/Conversion/UBToEmitC/UBToEmitC.cpp b/mlir/lib/Conversion/UBToEmitC/UBToEmitC.cpp index 29f073322cdf2c..4d0f3359a2500f 100644 --- a/mlir/lib/Conversion/UBToEmitC/UBToEmitC.cpp +++ b/mlir/lib/Conversion/UBToEmitC/UBToEmitC.cpp @@ -26,7 +26,13 @@ using namespace mlir; namespace { struct PoisonOpLowering : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; + bool noInitialization; + +public: + PoisonOpLowering(const TypeConverter &converter, MLIRContext *context, + bool noInitialization) + : OpConversionPattern(converter, context), + noInitialization(noInitialization) {} LogicalResult matchAndRewrite(ub::PoisonOp op, OpAdaptor adaptor, @@ -43,18 +49,33 @@ struct PoisonOpLowering : public OpConversionPattern { op.getLoc(), "only scalar poison values can be lowered"); } + Attribute value; + + if (noInitialization) { + value = emitc::OpaqueAttr::get(op->getContext(), ""); + } + if (!noInitialization && emitc::isIntegerIndexOrOpaqueType(convertedType)) { + value = IntegerAttr::get((emitc::isPointerWideType(convertedType)) + ? IndexType::get(op.getContext()) + : convertedType, + 42); + } + if (!noInitialization && emitc::isSupportedFloatType(convertedType)) { + value = FloatAttr::get(convertedType, 42.0f); + } + // Any constant will be fine to lower a poison op - rewriter.replaceOpWithNewOp( - op, convertedType, emitc::OpaqueAttr::get(op->getContext(), "")); + rewriter.replaceOpWithNewOp(op, convertedType, value); return success(); } }; } // namespace void ub::populateUBToEmitCConversionPatterns(TypeConverter &converter, - RewritePatternSet &patterns) { + RewritePatternSet &patterns, + bool noInitialization) { MLIRContext *ctx = patterns.getContext(); - patterns.add(converter, ctx); + patterns.add(converter, ctx, noInitialization); } struct ConvertUBToEmitC : public impl::ConvertUBToEmitCBase { @@ -70,7 +91,8 @@ struct ConvertUBToEmitC : public impl::ConvertUBToEmitCBase { target.addLegalDialect(); target.addIllegalDialect(); - mlir::ub::populateUBToEmitCConversionPatterns(converter, patterns); + mlir::ub::populateUBToEmitCConversionPatterns(converter, patterns, + noInitialization); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) diff --git a/mlir/test/Conversion/UBToEmitC/convert-ub-to-emitc-no-init.mlir b/mlir/test/Conversion/UBToEmitC/convert-ub-to-emitc-no-init.mlir new file mode 100644 index 00000000000000..24582254ee332f --- /dev/null +++ b/mlir/test/Conversion/UBToEmitC/convert-ub-to-emitc-no-init.mlir @@ -0,0 +1,12 @@ +// RUN: mlir-opt -p 'builtin.module(convert-ub-to-emitc{no-initialization})' %s | FileCheck %s + +// CHECK-LABEL: func.func @poison +func.func @poison() { + // CHECK: "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> i32 + %0 = ub.poison : i32 + // CHECK: "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> f32 + %1 = ub.poison : f32 + // CHECK: "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.size_t + %2 = ub.poison : index + return +} diff --git a/mlir/test/Conversion/UBToEmitC/convert-ub-to-emitc.mlir b/mlir/test/Conversion/UBToEmitC/convert-ub-to-emitc.mlir index b57f984354e53f..105fc9ddc18994 100644 --- a/mlir/test/Conversion/UBToEmitC/convert-ub-to-emitc.mlir +++ b/mlir/test/Conversion/UBToEmitC/convert-ub-to-emitc.mlir @@ -2,11 +2,11 @@ // CHECK-LABEL: func.func @poison func.func @poison() { - // CHECK: "emitc.variable"{{.*}} -> i32 + // CHECK: "emitc.variable"() <{value = 42 : i32}> : () -> i32 %0 = ub.poison : i32 - // CHECK: "emitc.variable"{{.*}} -> f32 + // CHECK: "emitc.variable"() <{value = 4.200000e+01 : f32}> : () -> f32 %1 = ub.poison : f32 - // CHECK: "emitc.variable"{{.*}} -> !emitc.size_t + // CHECK: "emitc.variable"() <{value = 42 : index}> : () -> !emitc.size_t %2 = ub.poison : index return } From 3bc24d38cef557e4c18cffc228a644bff517d4ae Mon Sep 17 00:00:00 2001 From: Corentin Ferry Date: Tue, 1 Oct 2024 09:30:03 +0100 Subject: [PATCH 15/46] Use else-if --- mlir/lib/Conversion/UBToEmitC/UBToEmitC.cpp | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Conversion/UBToEmitC/UBToEmitC.cpp b/mlir/lib/Conversion/UBToEmitC/UBToEmitC.cpp index 4d0f3359a2500f..5d7439eec2f14c 100644 --- a/mlir/lib/Conversion/UBToEmitC/UBToEmitC.cpp +++ b/mlir/lib/Conversion/UBToEmitC/UBToEmitC.cpp @@ -50,17 +50,14 @@ struct PoisonOpLowering : public OpConversionPattern { } Attribute value; - if (noInitialization) { value = emitc::OpaqueAttr::get(op->getContext(), ""); - } - if (!noInitialization && emitc::isIntegerIndexOrOpaqueType(convertedType)) { + } else if (emitc::isIntegerIndexOrOpaqueType(convertedType)) { value = IntegerAttr::get((emitc::isPointerWideType(convertedType)) ? IndexType::get(op.getContext()) : convertedType, 42); - } - if (!noInitialization && emitc::isSupportedFloatType(convertedType)) { + } else if (emitc::isSupportedFloatType(convertedType)) { value = FloatAttr::get(convertedType, 42.0f); } From 3db9abb906f344e400fe106943ce7e24df3911b6 Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Tue, 1 Oct 2024 21:37:04 +0200 Subject: [PATCH 16/46] normalize-memrefs: Normalize memref.alloca The pass was only handling memref.alloc, and this extends it to also handle memref.alloca. --- mlir/include/mlir/Dialect/Affine/Utils.h | 8 ++++++- mlir/lib/Dialect/Affine/Utils/Utils.cpp | 21 ++++++++++++------- .../MemRef/Transforms/NormalizeMemRefs.cpp | 16 ++++++++++++++ .../Dialect/MemRef/normalize-memrefs.mlir | 14 +++++++++++++ 4 files changed, 51 insertions(+), 8 deletions(-) diff --git a/mlir/include/mlir/Dialect/Affine/Utils.h b/mlir/include/mlir/Dialect/Affine/Utils.h index 7f25db029781c8..cc5001fc59bd99 100644 --- a/mlir/include/mlir/Dialect/Affine/Utils.h +++ b/mlir/include/mlir/Dialect/Affine/Utils.h @@ -31,6 +31,7 @@ class FuncOp; namespace memref { class AllocOp; +class AllocaOp; } // namespace memref struct LogicalResult; @@ -247,7 +248,12 @@ LogicalResult replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef, /// Rewrites the memref defined by this alloc op to have an identity layout map /// and updates all its indexing uses. Returns failure if any of its uses /// escape (while leaving the IR in a valid state). -LogicalResult normalizeMemRef(memref::AllocOp *op); +template +LogicalResult normalizeMemRef(AllocLikeOp *op); +extern template LogicalResult +normalizeMemRef(memref::AllocaOp *op); +extern template LogicalResult +normalizeMemRef(memref::AllocOp *op); /// Normalizes `memrefType` so that the affine layout map of the memref is /// transformed to an identity map with a new shape being computed for the diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp index f46381403bc522..ef3081f75608a8 100644 --- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp @@ -1639,9 +1639,10 @@ static AffineExpr createDimSizeExprForTiledLayout(AffineExpr oldMapOutput, /// %c4 = arith.constant 4 : index /// %1 = affine.apply #map1(%c4, %0) /// %2 = affine.apply #map2(%c4, %0) +template static void createNewDynamicSizes(MemRefType oldMemRefType, MemRefType newMemRefType, AffineMap map, - memref::AllocOp *allocOp, OpBuilder b, + AllocLikeOp *allocOp, OpBuilder b, SmallVectorImpl &newDynamicSizes) { // Create new input for AffineApplyOp. SmallVector inAffineApply; @@ -1688,7 +1689,8 @@ static void createNewDynamicSizes(MemRefType oldMemRefType, } // TODO: Currently works for static memrefs with a single layout map. -LogicalResult mlir::affine::normalizeMemRef(memref::AllocOp *allocOp) { +template +LogicalResult mlir::affine::normalizeMemRef(AllocLikeOp *allocOp) { MemRefType memrefType = allocOp->getType(); OpBuilder b(*allocOp); @@ -1704,7 +1706,7 @@ LogicalResult mlir::affine::normalizeMemRef(memref::AllocOp *allocOp) { SmallVector symbolOperands(allocOp->getSymbolOperands()); AffineMap layoutMap = memrefType.getLayout().getAffineMap(); - memref::AllocOp newAlloc; + AllocLikeOp newAlloc; // Check if `layoutMap` is a tiled layout. Only single layout map is // supported for normalizing dynamic memrefs. SmallVector> tileSizePos; @@ -1716,11 +1718,11 @@ LogicalResult mlir::affine::normalizeMemRef(memref::AllocOp *allocOp) { newDynamicSizes); // Add the new dynamic sizes in new AllocOp. newAlloc = - b.create(allocOp->getLoc(), newMemRefType, - newDynamicSizes, allocOp->getAlignmentAttr()); + b.create(allocOp->getLoc(), newMemRefType, newDynamicSizes, + allocOp->getAlignmentAttr()); } else { - newAlloc = b.create(allocOp->getLoc(), newMemRefType, - allocOp->getAlignmentAttr()); + newAlloc = b.create(allocOp->getLoc(), newMemRefType, + allocOp->getAlignmentAttr()); } // Replace all uses of the old memref. if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newAlloc, @@ -1745,6 +1747,11 @@ LogicalResult mlir::affine::normalizeMemRef(memref::AllocOp *allocOp) { return success(); } +template LogicalResult +mlir::affine::normalizeMemRef(memref::AllocaOp *op); +template LogicalResult +mlir::affine::normalizeMemRef(memref::AllocOp *op); + MemRefType mlir::affine::normalizeMemRefType(MemRefType memrefType) { unsigned rank = memrefType.getRank(); if (rank == 0) diff --git a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp index 33772ccb7dd9d3..e8968a07ab884e 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp @@ -174,6 +174,17 @@ bool NormalizeMemRefs::areMemRefsNormalizable(func::FuncOp funcOp) { .wasInterrupted()) return false; + if (funcOp + .walk([&](memref::AllocaOp allocaOp) -> WalkResult { + Value oldMemRef = allocaOp.getResult(); + if (!allocaOp.getType().getLayout().isIdentity() && + !isMemRefNormalizable(oldMemRef.getUsers())) + return WalkResult::interrupt(); + return WalkResult::advance(); + }) + .wasInterrupted()) + return false; + if (funcOp .walk([&](func::CallOp callOp) -> WalkResult { for (unsigned resIndex : @@ -347,6 +358,11 @@ void NormalizeMemRefs::normalizeFuncOpMemRefs(func::FuncOp funcOp, for (memref::AllocOp allocOp : allocOps) (void)normalizeMemRef(&allocOp); + SmallVector allocaOps; + funcOp.walk([&](memref::AllocaOp op) { allocaOps.push_back(op); }); + for (memref::AllocaOp allocaOp : allocaOps) + (void)normalizeMemRef(&allocaOp); + // We use this OpBuilder to create new memref layout later. OpBuilder b(funcOp); diff --git a/mlir/test/Dialect/MemRef/normalize-memrefs.mlir b/mlir/test/Dialect/MemRef/normalize-memrefs.mlir index c7af033a22a2c6..a89d1c2144b5d3 100644 --- a/mlir/test/Dialect/MemRef/normalize-memrefs.mlir +++ b/mlir/test/Dialect/MemRef/normalize-memrefs.mlir @@ -27,6 +27,20 @@ func.func @permute() { // CHECK-NEXT: memref.dealloc [[MEM]] // CHECK-NEXT: return +// CHECK-LABEL: func @alloca +func.func @alloca(%idx : index) { + // CHECK-NEXT: memref.alloca() : memref<65xf32> + %A = memref.alloca() : memref<64xf32, affine_map<(d0) -> (d0 + 1)>> + // CHECK-NEXT: affine.load %{{.*}}[symbol(%arg0) + 1] : memref<65xf32> + affine.load %A[%idx] : memref<64xf32, affine_map<(d0) -> (d0 + 1)>> + affine.for %i = 0 to 64 { + %1 = affine.load %A[%i] : memref<64xf32, affine_map<(d0) -> (d0 + 1)>> + "prevent.dce"(%1) : (f32) -> () + // CHECK: %{{.*}} = affine.load %{{.*}}[%arg{{.*}} + 1] : memref<65xf32> + } + return +} + // CHECK-LABEL: func @shift func.func @shift(%idx : index) { // CHECK-NEXT: memref.alloc() : memref<65xf32> From 09ddec3edec3a97a6ade0c46746bfa2addcf2cf6 Mon Sep 17 00:00:00 2001 From: josel-amd <166385423+josel-amd@users.noreply.github.com> Date: Wed, 2 Oct 2024 09:22:02 +0200 Subject: [PATCH 17/46] Add bounds to the affine.if regions (#380) Add bounds to the affine.if regions --- .../mlir/Dialect/Affine/IR/AffineOps.td | 3 +- mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 7 ++ .../Dialect/Affine/control-flow-sink.mlir | 88 +++++++++++++++++++ 3 files changed, 97 insertions(+), 1 deletion(-) create mode 100644 mlir/test/Dialect/Affine/control-flow-sink.mlir diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td index dbec741cf1b1f3..d23b4707d03318 100644 --- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td +++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td @@ -336,7 +336,8 @@ def AffineForOp : Affine_Op<"for", def AffineIfOp : Affine_Op<"if", [ImplicitAffineTerminator, RecursivelySpeculatable, RecursiveMemoryEffects, NoRegionArguments, - DeclareOpInterfaceMethods + DeclareOpInterfaceMethods ]> { let summary = "if-then-else operation"; let description = [{ diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp index 83a36b9efecc2c..bee15d85aa5a7e 100644 --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -2765,6 +2765,13 @@ struct AlwaysTrueOrFalseIf : public OpRewritePattern { }; } // namespace +void AffineIfOp::getRegionInvocationBounds( + ArrayRef operands, + SmallVectorImpl &invocationBounds) { + // Non-constant condition. Each region may be executed 0 or 1 times. + invocationBounds.assign(getNumRegions(), {0, 1}); +} + /// AffineIfOp has two regions -- `then` and `else`. The flow of data should be /// as follows: AffineIfOp -> `then`/`else` -> AffineIfOp void AffineIfOp::getSuccessorRegions( diff --git a/mlir/test/Dialect/Affine/control-flow-sink.mlir b/mlir/test/Dialect/Affine/control-flow-sink.mlir new file mode 100644 index 00000000000000..2b1a2b3e74d067 --- /dev/null +++ b/mlir/test/Dialect/Affine/control-flow-sink.mlir @@ -0,0 +1,88 @@ +// RUN: mlir-opt -split-input-file -control-flow-sink %s | FileCheck %s + +#set = affine_set<(d0) : (-d0 + 3 >= 0)> +#map = affine_map<(d0) -> (d0)> + +func.func @test_affine_if_sink(%arg1: tensor<4xf32>) -> tensor<4xf32> { + %0 = tensor.empty() : tensor<4xf32> + %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} + ins(%arg1: tensor<4xf32>) outs(%0: tensor<4xf32>) { + ^bb0(%in: f32, %out: f32): + %index = linalg.index 0 : index + %const0 = arith.constant 0.0 : f32 + %add = arith.addf %in, %in: f32 + %4 = affine.if #set(%index) -> f32 { + affine.yield %add : f32 + } else { + affine.yield %const0 : f32 + } + linalg.yield %4 : f32 + } -> (tensor<4xf32>) + return %1: tensor<4xf32> +} + +// CHECK-LABEL: affine.if +// CHECK-NEXT: %[[ADD:.*]] = arith.addf +// CHECK-NEXT: affine.yield %[[ADD]] : f32 +// CHECK-NEXT: } else { +// CHECK-NEXT: %[[ZERO:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK-NEXT: affine.yield %[[ZERO]] : f32 +// CHECK-NEXT: } + +// ----- + +#set = affine_set<(d0) : (-d0 + 3 >= 0)> +#map = affine_map<(d0) -> (d0)> + +func.func @test_affine_if_sink_with_loop_independenct_code(%arg0: f32, %arg1: tensor<4xf32>) -> tensor<4xf32> { + %const0 = arith.constant 0.0 : f32 + %const1 = arith.constant 1.0 : f32 + %0 = tensor.empty() : tensor<4xf32> + %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} + ins(%arg1: tensor<4xf32>) outs(%0: tensor<4xf32>) { + ^bb0(%in: f32, %out: f32): + %index = linalg.index 0 : index + %4 = affine.if #set(%index) -> f32 { + affine.yield %const1 : f32 + } else { + affine.yield %const0 : f32 + } + linalg.yield %4 : f32 + } -> (tensor<4xf32>) + return %1: tensor<4xf32> +} + +// CHECK-LABEL: affine.if +// CHECK-NEXT: %[[C1:.*]] = arith.constant 1.0 +// CHECK-NEXT: affine.yield %[[C1]] : f32 +// CHECK-NEXT: } else { +// CHECK-NEXT: %[[C0:.*]] = arith.constant 0.0 +// CHECK-NEXT: affine.yield %[[C0]] : f32 +// CHECK-NEXT: } + + +// ----- + +func.func private @external(f32) -> () + +#map = affine_map<(d0) -> (d0)> + +func.func @affine_if_no_else(%arg0: f32, %arg1: tensor<4xf32>) -> tensor<4xf32> { + %const1 = arith.constant 1.0 : f32 + %0 = tensor.empty() : tensor<4xf32> + %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} + ins(%arg1: tensor<4xf32>) outs(%0: tensor<4xf32>) { + ^bb0(%in: f32, %out: f32): + %index = linalg.index 0 : index + affine.if affine_set<(d0) : (-d0 + 3 >= 0)>(%index) { + func.call @external(%const1) : (f32) -> () + } + linalg.yield %in : f32 + } -> (tensor<4xf32>) + return %1: tensor<4xf32> +} + +// CHECK-LABEL: affine.if +// CHECK-NEXT: %[[C1:.*]] = arith.constant 1.0 +// CHECK-NEXT: func.call @external(%[[C1]]) : (f32) -> () +// CHECK-NEXT: } From f139673565e310cceebeac8081a7193cf874d2ee Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Wed, 2 Oct 2024 10:45:17 +0200 Subject: [PATCH 18/46] Remove explicit SmallVector size --- mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp index e8968a07ab884e..fc04069afc7c43 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp @@ -358,7 +358,7 @@ void NormalizeMemRefs::normalizeFuncOpMemRefs(func::FuncOp funcOp, for (memref::AllocOp allocOp : allocOps) (void)normalizeMemRef(&allocOp); - SmallVector allocaOps; + SmallVector allocaOps; funcOp.walk([&](memref::AllocaOp op) { allocaOps.push_back(op); }); for (memref::AllocaOp allocaOp : allocaOps) (void)normalizeMemRef(&allocaOp); From 719df3f87602d08c75eb9903c2d47912be31af3b Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Fri, 4 Oct 2024 15:02:48 +0200 Subject: [PATCH 19/46] Update documentation --- .../MemRef/Transforms/NormalizeMemRefs.cpp | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp index fc04069afc7c43..08b853fe65b857 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp @@ -151,11 +151,11 @@ void NormalizeMemRefs::setCalleesAndCallersNonNormalizable( }); } -/// Check whether all the uses of AllocOps, CallOps and function arguments of a -/// function are either of dereferencing type or are uses in: DeallocOp, CallOp -/// or ReturnOp. Only if these constraints are satisfied will the function -/// become a candidate for normalization. When the uses of a memref are -/// non-normalizable and the memref map layout is trivial (identity), we can +/// Check whether all the uses of AllocOps, AllocaOps, CallOps and function +/// arguments of a function are either of dereferencing type or are uses in: +/// DeallocOp, CallOp or ReturnOp. Only if these constraints are satisfied will +/// the function become a candidate for normalization. When the uses of a memref +/// are non-normalizable and the memref map layout is trivial (identity), we can /// still label the entire function as normalizable. We assume external /// functions to be normalizable. bool NormalizeMemRefs::areMemRefsNormalizable(func::FuncOp funcOp) { @@ -346,13 +346,13 @@ void NormalizeMemRefs::updateFunctionSignature(func::FuncOp funcOp, } /// Normalizes the memrefs within a function which includes those arising as a -/// result of AllocOps, CallOps and function's argument. The ModuleOp argument -/// is used to help update function's signature after normalization. +/// result of AllocOps, AllocaOps, CallOps and function's argument. The ModuleOp +/// argument is used to help update function's signature after normalization. void NormalizeMemRefs::normalizeFuncOpMemRefs(func::FuncOp funcOp, ModuleOp moduleOp) { // Turn memrefs' non-identity layouts maps into ones with identity. Collect - // alloc ops first and then process since normalizeMemRef replaces/erases ops - // during memref rewriting. + // alloc/alloca ops first and then process since normalizeMemRef + // replaces/erases ops during memref rewriting. SmallVector allocOps; funcOp.walk([&](memref::AllocOp op) { allocOps.push_back(op); }); for (memref::AllocOp allocOp : allocOps) From 9fd8d27036fd550d98f271a88713ee144d301e89 Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Mon, 7 Oct 2024 21:20:50 +0200 Subject: [PATCH 20/46] emitc: Add fmtArgs to verbatim --- mlir/include/mlir/Dialect/EmitC/IR/EmitC.h | 2 + mlir/include/mlir/Dialect/EmitC/IR/EmitC.td | 25 ++++++- mlir/lib/Dialect/EmitC/IR/EmitC.cpp | 72 +++++++++++++++++++++ mlir/lib/Target/Cpp/TranslateToCpp.cpp | 16 ++++- mlir/test/Dialect/EmitC/invalid_ops.mlir | 48 ++++++++++++++ mlir/test/Dialect/EmitC/ops.mlir | 10 +++ mlir/test/Target/Cpp/verbatim.mlir | 28 +++++++- 7 files changed, 196 insertions(+), 5 deletions(-) diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h index 0c595a6b109caa..bc82f58a7ee95c 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h @@ -27,6 +27,8 @@ #include "mlir/Dialect/EmitC/IR/EmitCDialect.h.inc" #include "mlir/Dialect/EmitC/IR/EmitCEnums.h.inc" +#include + namespace mlir { namespace emitc { void buildTerminatedBody(OpBuilder &builder, Location loc); diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td index 1ce1fe8bbf87e5..5f3bec5637b458 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td @@ -1157,10 +1157,31 @@ def EmitC_VerbatimOp : EmitC_Op<"verbatim"> { } #endif ``` + + If the `emitc.verbatim` op has operands, then the `value` is interpreted as + format string, where `{}` is a placeholder for an operand in their order. + For example, `emitc.verbatim "#pragma my src={} dst={}" %src, %dest : i32, i32` + would be emitted as `#pragma my src=a dst=b` if `%src` became `a` and + `%dest` as `b` in the C code. + `{{` in the format string is interpreted as a single `{` and doesn't introduce + a placeholder. }]; - let arguments = (ins StrAttr:$value); - let assemblyFormat = "$value attr-dict"; + let extraClassDeclaration = [{ + // Either a literal string, or an placeholder for the fmtArgs. + struct Placeholder {}; + using ReplacementItem = std::variant; + + FailureOr> parseFormatString(); + }]; + + let arguments = (ins StrAttr:$value, + Variadic:$fmtArgs); + + let builders = [OpBuilder<(ins "::mlir::StringAttr":$value), [{ build($_builder, $_state, value, {}); }] >]; + let builders = [OpBuilder<(ins "::llvm::StringRef":$value), [{ build($_builder, $_state, value, {}); }] >]; + let hasVerifier = 1; + let assemblyFormat = "$value ($fmtArgs^ `:` type($fmtArgs))? attr-dict"; } def EmitC_AssignOp : EmitC_Op<"assign", []> { diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index 67b9695057a9e2..7bc40b4f555cc6 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -909,6 +909,78 @@ LogicalResult emitc::SubscriptOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// VerbatimOp +//===----------------------------------------------------------------------===// + +LogicalResult emitc::VerbatimOp::verify() { + FailureOr> fmt = parseFormatString(); + if (failed(fmt)) + return failure(); + + size_t numPlaceholders = llvm::count_if(*fmt, [](ReplacementItem &item) { + return std::holds_alternative(item); + }); + + if (numPlaceholders != getFmtArgs().size()) { + return emitOpError() + << "requires operands for each placeholder in the format string"; + } + return success(); +} + +/// Parse a format string and return a list of its parts. +/// A part is either a StringRef that has to be printed as-is, or +/// a Placeholder which requires printing the next operand of the VerbatimOp. +/// In the format string, all `{}` are replaced by Placeholders, except if the +/// `{` is escaped by `{{` - then it doesn't start a placeholder. +FailureOr> +emitc::VerbatimOp::parseFormatString() { + SmallVector items; + + // If there are not operands, the format string is not interpreted. + if (getFmtArgs().empty()) { + items.push_back(getValue()); + return items; + } + + StringRef toParse = getValue(); + while (!toParse.empty()) { + size_t idx = toParse.find('{'); + if (idx == StringRef::npos) { + // No '{' + items.push_back(toParse); + break; + } + if (idx > 0) { + // Take all chars excluding the '{'. + items.push_back(toParse.take_front(idx)); + toParse = toParse.drop_front(idx); + continue; + } + if (toParse.size() < 2) { + // '{' is last character + items.push_back(toParse); + break; + } + // toParse contains at least two characters and starts with `{`. + char nextChar = toParse[1]; + if (nextChar == '{') { + // Double '{{' -> '{' (escaping). + items.push_back(toParse.take_front(1)); + toParse = toParse.drop_front(2); + continue; + } + if (nextChar == '}') { + items.push_back(Placeholder{}); + toParse = toParse.drop_front(2); + continue; + } + return emitOpError() << "expected '}' after unescaped '{'"; + } + return items; +} + //===----------------------------------------------------------------------===// // EmitC Enums //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp index d59de9f7580c24..c8a5cb5fbd3f61 100644 --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -512,7 +512,21 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::VerbatimOp verbatimOp) { raw_ostream &os = emitter.ostream(); - os << verbatimOp.getValue(); + FailureOr> items = + verbatimOp.parseFormatString(); + if (failed(items)) + return failure(); + + auto fmtArg = verbatimOp.getFmtArgs().begin(); + + for (emitc::VerbatimOp::ReplacementItem &item : *items) { + if (auto *str = std::get_if(&item)) { + os << *str; + } else { + if (failed(emitter.emitOperand(*fmtArg++))) + return failure(); + } + } return success(); } diff --git a/mlir/test/Dialect/EmitC/invalid_ops.mlir b/mlir/test/Dialect/EmitC/invalid_ops.mlir index a801ad4fdc2ebc..4d52dd39b02489 100644 --- a/mlir/test/Dialect/EmitC/invalid_ops.mlir +++ b/mlir/test/Dialect/EmitC/invalid_ops.mlir @@ -476,3 +476,51 @@ emitc.global const @myref : !emitc.array<2xi16> = dense<128> ref // expected-error @+1 {{'emitc.global' op global reference must be initialized}} emitc.global const @myref : !emitc.array<2xi16> ref + +// ----- + +func.func @test_verbatim(%arg0 : !emitc.ptr, %arg1 : i32) { + // expected-error @+1 {{'emitc.verbatim' op requires operands for each placeholder in the format string}} + emitc.verbatim "" %arg0, %arg1 : !emitc.ptr, i32 + return +} + +// ----- + +func.func @test_verbatim(%arg0 : !emitc.ptr, %arg1 : i32) { + // expected-error @+1 {{'emitc.verbatim' op requires operands for each placeholder in the format string}} + emitc.verbatim "abc" %arg0, %arg1 : !emitc.ptr, i32 + return +} + +// ----- + +func.func @test_verbatim(%arg0 : !emitc.ptr, %arg1 : i32) { + // expected-error @+1 {{'emitc.verbatim' op requires operands for each placeholder in the format string}} + emitc.verbatim "{}" %arg0, %arg1 : !emitc.ptr, i32 + return +} + +// ----- + +func.func @test_verbatim(%arg0 : !emitc.ptr, %arg1 : i32) { + // expected-error @+1 {{'emitc.verbatim' op requires operands for each placeholder in the format string}} + emitc.verbatim "{} {} {}" %arg0, %arg1 : !emitc.ptr, i32 + return +} + +// ----- + +func.func @test_verbatim(%arg0 : !emitc.ptr, %arg1 : i32) { + // expected-error @+1 {{'emitc.verbatim' op expected '}' after unescaped '{'}} + emitc.verbatim "{ " %arg0, %arg1 : !emitc.ptr, i32 + return +} + +// ----- + +func.func @test_verbatim(%arg0 : !emitc.ptr, %arg1 : i32) { + // expected-error @+1 {{'emitc.verbatim' op expected '}' after unescaped '{'}} + emitc.verbatim "{a} " %arg0, %arg1 : !emitc.ptr, i32 + return +} diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir index 912335fefd00be..23b5421d860ff3 100644 --- a/mlir/test/Dialect/EmitC/ops.mlir +++ b/mlir/test/Dialect/EmitC/ops.mlir @@ -240,6 +240,16 @@ emitc.verbatim "#endif // __cplusplus" emitc.verbatim "typedef int32_t i32;" emitc.verbatim "typedef float f32;" +// The value is not interpreted as format string if there are no operands. +emitc.verbatim "{} { }" + +func.func @test_verbatim(%arg0 : !emitc.ptr, %arg1 : i32) { + emitc.verbatim "{} + {};" %arg0, %arg1 : !emitc.ptr, i32 + + // Trailing '{' are ok and don't start a placeholder. + emitc.verbatim "{} + {} {" %arg0, %arg1 : !emitc.ptr, i32 + return +} emitc.global @uninit : i32 emitc.global @myglobal_int : i32 = 4 diff --git a/mlir/test/Target/Cpp/verbatim.mlir b/mlir/test/Target/Cpp/verbatim.mlir index 10465dd781a81d..41e39dcb6ca900 100644 --- a/mlir/test/Target/Cpp/verbatim.mlir +++ b/mlir/test/Target/Cpp/verbatim.mlir @@ -1,5 +1,5 @@ -// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s -// RUN: mlir-translate -mlir-to-cpp -declare-variables-at-top %s | FileCheck %s +// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s --match-full-lines +// RUN: mlir-translate -mlir-to-cpp -declare-variables-at-top %s | FileCheck %s --match-full-lines emitc.verbatim "#ifdef __cplusplus" @@ -19,3 +19,27 @@ emitc.verbatim "typedef int32_t i32;" // CHECK-NEXT: typedef int32_t i32; emitc.verbatim "typedef float f32;" // CHECK-NEXT: typedef float f32; + +emitc.func @func(%arg: f32) { + // CHECK: void func(float [[V0:[^ ]*]]) { + %a = "emitc.variable"(){value = #emitc.opaque<"">} : () -> !emitc.array<3x7xi32> + // CHECK: int32_t [[A:[^ ]*]][3][7]; + + emitc.verbatim "{}" %arg : f32 + // CHECK: [[V0]] + + emitc.verbatim "{} {{a" %arg : f32 + // CHECK-NEXT: [[V0]] {a + + emitc.verbatim "#pragma my var={} property" %arg : f32 + // CHECK-NEXT: #pragma my var=[[V0]] property + + // Trailing '{' are printed as-is. + emitc.verbatim "#pragma my var={} {" %arg : f32 + // CHECK-NEXT: #pragma my var=[[V0]] { + + emitc.verbatim "#pragma my2 var={} property" %a : !emitc.array<3x7xi32> + // CHECK-NEXT: #pragma my2 var=[[A]] property + + emitc.return +} From 30b5f72411738accffa6ae58b179392ea0f5e512 Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Tue, 8 Oct 2024 13:33:22 +0200 Subject: [PATCH 21/46] Don't print semicolon after emitc.verbatim within emitc.for There was code to suppress printing semicolons after emitc.verbatim in the function that emits a function body, but then it would still print `#pragma;` when the `emitc.verbatim "#pragma"` was within a loop body. I moved that code into the general printOperation() function, so it applies to emitc.verbatim independent of what the parent op is. --- mlir/lib/Target/Cpp/TranslateToCpp.cpp | 16 ++++++---------- mlir/test/Target/Cpp/verbatim.mlir | 12 ++++++++++++ 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp index c8a5cb5fbd3f61..73256451ef1487 100644 --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -1063,16 +1063,7 @@ static LogicalResult printFunctionBody(CppEmitter &emitter, return failure(); } for (Operation &op : block.getOperations()) { - // When generating code for an emitc.if or cf.cond_br op no semicolon - // needs to be printed after the closing brace. - // When generating code for an emitc.for and emitc.verbatim op, printing a - // trailing semicolon is handled within the printOperation function. - bool trailingSemicolon = - !isa(op); - - if (failed(emitter.emitOperation( - op, /*trailingSemicolon=*/trailingSemicolon))) + if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/true))) return failure(); } } @@ -1630,6 +1621,11 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) { shouldBeInlined(cast(op)))) return success(); + if (isa(op)) { + trailingSemicolon = false; + } + os << (trailingSemicolon ? ";\n" : "\n"); return success(); diff --git a/mlir/test/Target/Cpp/verbatim.mlir b/mlir/test/Target/Cpp/verbatim.mlir index 41e39dcb6ca900..1522cc32e79a52 100644 --- a/mlir/test/Target/Cpp/verbatim.mlir +++ b/mlir/test/Target/Cpp/verbatim.mlir @@ -40,6 +40,18 @@ emitc.func @func(%arg: f32) { emitc.verbatim "#pragma my2 var={} property" %a : !emitc.array<3x7xi32> // CHECK-NEXT: #pragma my2 var=[[A]] property + emitc.return +} +// Check that no semicolon is printed after verbatim within emitc.for +emitc.func @in_loop(%arg: f32) { + %start = emitc.literal "0" : !emitc.size_t + %stop = emitc.literal "10" : !emitc.size_t + %step = emitc.literal "1" : !emitc.size_t + emitc.for %iter = %start to %stop step %step { + emitc.verbatim "#pragma" + // CHECK: #pragma + emitc.yield + } emitc.return } From edbe32102b17a6507423bc8c548621711cb81e5b Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Wed, 9 Oct 2024 09:25:01 +0200 Subject: [PATCH 22/46] Review comment --- mlir/include/mlir/Dialect/EmitC/IR/EmitC.td | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td index 5f3bec5637b458..8d434d88e910a9 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td @@ -1162,7 +1162,7 @@ def EmitC_VerbatimOp : EmitC_Op<"verbatim"> { format string, where `{}` is a placeholder for an operand in their order. For example, `emitc.verbatim "#pragma my src={} dst={}" %src, %dest : i32, i32` would be emitted as `#pragma my src=a dst=b` if `%src` became `a` and - `%dest` as `b` in the C code. + `%dest` became `b` in the C code. `{{` in the format string is interpreted as a single `{` and doesn't introduce a placeholder. }]; From 95521c7eca582ace4626f38b1e58f047e6686b98 Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Thu, 10 Oct 2024 12:49:19 +0200 Subject: [PATCH 23/46] Fix verbatim parsing to be unambiguous With the previous parsing, it would interpret ``` emitc.verbatim "#endif // PL_USE_XRT" %4 = "emitc.constant"() <{value = 1 : i32}> : () -> i32 ``` as if ``` emitc.verbatim "#endif // PL_USE_XRT" %4 = "emitc.constant"() <{value = 1 : i32}> : () -> i32 ``` and then complain that it expected a `:` after the `%4`. Fix this by introducing a `args` keyword to distinguish the case where the veratim has args from the case where the next operation starts. --- mlir/include/mlir/Dialect/EmitC/IR/EmitC.td | 2 +- mlir/test/Dialect/EmitC/invalid_ops.mlir | 12 ++++++------ mlir/test/Dialect/EmitC/ops.mlir | 9 +++++++-- mlir/test/Target/Cpp/verbatim.mlir | 10 +++++----- 4 files changed, 19 insertions(+), 14 deletions(-) diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td index 8d434d88e910a9..0de8787ba1dc8f 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td @@ -1181,7 +1181,7 @@ def EmitC_VerbatimOp : EmitC_Op<"verbatim"> { let builders = [OpBuilder<(ins "::mlir::StringAttr":$value), [{ build($_builder, $_state, value, {}); }] >]; let builders = [OpBuilder<(ins "::llvm::StringRef":$value), [{ build($_builder, $_state, value, {}); }] >]; let hasVerifier = 1; - let assemblyFormat = "$value ($fmtArgs^ `:` type($fmtArgs))? attr-dict"; + let assemblyFormat = "$value (`args` $fmtArgs^ `:` type($fmtArgs))? attr-dict"; } def EmitC_AssignOp : EmitC_Op<"assign", []> { diff --git a/mlir/test/Dialect/EmitC/invalid_ops.mlir b/mlir/test/Dialect/EmitC/invalid_ops.mlir index 4d52dd39b02489..3b4c6046a08c5a 100644 --- a/mlir/test/Dialect/EmitC/invalid_ops.mlir +++ b/mlir/test/Dialect/EmitC/invalid_ops.mlir @@ -481,7 +481,7 @@ emitc.global const @myref : !emitc.array<2xi16> ref func.func @test_verbatim(%arg0 : !emitc.ptr, %arg1 : i32) { // expected-error @+1 {{'emitc.verbatim' op requires operands for each placeholder in the format string}} - emitc.verbatim "" %arg0, %arg1 : !emitc.ptr, i32 + emitc.verbatim "" args %arg0, %arg1 : !emitc.ptr, i32 return } @@ -489,7 +489,7 @@ func.func @test_verbatim(%arg0 : !emitc.ptr, %arg1 : i32) { func.func @test_verbatim(%arg0 : !emitc.ptr, %arg1 : i32) { // expected-error @+1 {{'emitc.verbatim' op requires operands for each placeholder in the format string}} - emitc.verbatim "abc" %arg0, %arg1 : !emitc.ptr, i32 + emitc.verbatim "abc" args %arg0, %arg1 : !emitc.ptr, i32 return } @@ -497,7 +497,7 @@ func.func @test_verbatim(%arg0 : !emitc.ptr, %arg1 : i32) { func.func @test_verbatim(%arg0 : !emitc.ptr, %arg1 : i32) { // expected-error @+1 {{'emitc.verbatim' op requires operands for each placeholder in the format string}} - emitc.verbatim "{}" %arg0, %arg1 : !emitc.ptr, i32 + emitc.verbatim "{}" args %arg0, %arg1 : !emitc.ptr, i32 return } @@ -505,7 +505,7 @@ func.func @test_verbatim(%arg0 : !emitc.ptr, %arg1 : i32) { func.func @test_verbatim(%arg0 : !emitc.ptr, %arg1 : i32) { // expected-error @+1 {{'emitc.verbatim' op requires operands for each placeholder in the format string}} - emitc.verbatim "{} {} {}" %arg0, %arg1 : !emitc.ptr, i32 + emitc.verbatim "{} {} {}" args %arg0, %arg1 : !emitc.ptr, i32 return } @@ -513,7 +513,7 @@ func.func @test_verbatim(%arg0 : !emitc.ptr, %arg1 : i32) { func.func @test_verbatim(%arg0 : !emitc.ptr, %arg1 : i32) { // expected-error @+1 {{'emitc.verbatim' op expected '}' after unescaped '{'}} - emitc.verbatim "{ " %arg0, %arg1 : !emitc.ptr, i32 + emitc.verbatim "{ " args %arg0, %arg1 : !emitc.ptr, i32 return } @@ -521,6 +521,6 @@ func.func @test_verbatim(%arg0 : !emitc.ptr, %arg1 : i32) { func.func @test_verbatim(%arg0 : !emitc.ptr, %arg1 : i32) { // expected-error @+1 {{'emitc.verbatim' op expected '}' after unescaped '{'}} - emitc.verbatim "{a} " %arg0, %arg1 : !emitc.ptr, i32 + emitc.verbatim "{a} " args %arg0, %arg1 : !emitc.ptr, i32 return } diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir index 23b5421d860ff3..4e86642c2a3a95 100644 --- a/mlir/test/Dialect/EmitC/ops.mlir +++ b/mlir/test/Dialect/EmitC/ops.mlir @@ -244,10 +244,15 @@ emitc.verbatim "typedef float f32;" emitc.verbatim "{} { }" func.func @test_verbatim(%arg0 : !emitc.ptr, %arg1 : i32) { - emitc.verbatim "{} + {};" %arg0, %arg1 : !emitc.ptr, i32 + emitc.verbatim "{} + {};" args %arg0, %arg1 : !emitc.ptr, i32 // Trailing '{' are ok and don't start a placeholder. - emitc.verbatim "{} + {} {" %arg0, %arg1 : !emitc.ptr, i32 + emitc.verbatim "{} + {} {" args %arg0, %arg1 : !emitc.ptr, i32 + + // Check there is no ambiguity whether %a is the argument to the emitc.verbatim op. + emitc.verbatim "a" + %a = "emitc.constant"(){value = 42 : i32} : () -> i32 + return } diff --git a/mlir/test/Target/Cpp/verbatim.mlir b/mlir/test/Target/Cpp/verbatim.mlir index 1522cc32e79a52..bc687fbb0e31bf 100644 --- a/mlir/test/Target/Cpp/verbatim.mlir +++ b/mlir/test/Target/Cpp/verbatim.mlir @@ -25,20 +25,20 @@ emitc.func @func(%arg: f32) { %a = "emitc.variable"(){value = #emitc.opaque<"">} : () -> !emitc.array<3x7xi32> // CHECK: int32_t [[A:[^ ]*]][3][7]; - emitc.verbatim "{}" %arg : f32 + emitc.verbatim "{}" args %arg : f32 // CHECK: [[V0]] - emitc.verbatim "{} {{a" %arg : f32 + emitc.verbatim "{} {{a" args %arg : f32 // CHECK-NEXT: [[V0]] {a - emitc.verbatim "#pragma my var={} property" %arg : f32 + emitc.verbatim "#pragma my var={} property" args %arg : f32 // CHECK-NEXT: #pragma my var=[[V0]] property // Trailing '{' are printed as-is. - emitc.verbatim "#pragma my var={} {" %arg : f32 + emitc.verbatim "#pragma my var={} {" args %arg : f32 // CHECK-NEXT: #pragma my var=[[V0]] { - emitc.verbatim "#pragma my2 var={} property" %a : !emitc.array<3x7xi32> + emitc.verbatim "#pragma my2 var={} property" args %a : !emitc.array<3x7xi32> // CHECK-NEXT: #pragma my2 var=[[A]] property emitc.return } From 9489ae85dc9095c19dd6aae1d0b4f68e7588f10f Mon Sep 17 00:00:00 2001 From: Tiago Trevisan Jost Date: Tue, 15 Oct 2024 00:49:41 +0200 Subject: [PATCH 24/46] feat: implement constant folding for tosa.slice --- .../Dialect/Tosa/Transforms/TosaFolders.cpp | 128 ++++++++++++++++++ mlir/test/Dialect/Tosa/constant-slice.mlir | 40 ++++++ 2 files changed, 168 insertions(+) create mode 100644 mlir/test/Dialect/Tosa/constant-slice.mlir diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp index 8564efa52960c6..c76818ad6d3f48 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp @@ -1688,6 +1688,133 @@ struct TosaFoldConstantPad : public TosaFoldConstantBase { } }; +template +void sliceArray(ShapedType inputType, RangeT inputValues, + llvm::ArrayRef startValues, ShapedType outputType, + SmallVector &outputValues) { + + auto outputShape = outputType.getShape(); + auto inputShape = inputType.getShape(); + + int64_t rank = inputType.getRank(); + + // Implements the logic from + // https://www.mlplatform.org/tosa/tosa_spec.html#_slice + for (size_t outIndex = 0, e = outputValues.size(); outIndex < e; ++outIndex) { + auto indexInTarget = offsetToIndex(outputShape, outIndex); + + for (int64_t i = 0; i < rank; ++i) { + indexInTarget[i] = indexInTarget[i] + startValues[i]; + } + + auto inputIndexOffset = indexToOffset(inputShape, indexInTarget); + outputValues[outIndex] = inputValues[inputIndexOffset]; + } +} + +template +DenseElementsAttr sliceType(ElementsAttr attr, ShapedType inputType, + llvm::ArrayRef start, + ShapedType outputType) { + + auto inputValues = attr.getValues(); + SmallVector outputValues(outputType.getNumElements(), + *std::begin(inputValues)); + sliceArray(inputType, inputValues, start, outputType, outputValues); + return DenseElementsAttr::get(outputType, + llvm::ArrayRef(outputValues)); +} + +template +DenseElementsAttr sliceTypeRaw(ElementsAttr attr, ShapedType inputType, + llvm::ArrayRef start, + ShapedType outputType) { + + ArrayRef inputValues = + cast(attr).getNonSplatRawData(); + + SmallVector outputValues; + outputValues.resize_for_overwrite(outputType.getNumElements()); + sliceArray(inputType, inputValues, start, outputType, outputValues); + + ArrayRef rawOutputValues(reinterpret_cast(outputValues.data()), + outputValues.size() * sizeof(BaseType)); + return DenseElementsAttr::getFromRawBuffer(outputType, rawOutputValues); +} + +DenseElementsAttr slice(ShapedType inputType, ElementsAttr inputValues, + llvm::ArrayRef start, ShapedType outputType) { + + auto baseType = inputType.getElementType(); + + if (inputValues.isSplat()) { + if (isa(baseType)) + return DenseElementsAttr::get(outputType, + inputValues.getSplatValue()); + return DenseElementsAttr::get(outputType, + inputValues.getSplatValue()); + } + + // Handle possible integer types + if (auto intType = dyn_cast(baseType)) { + switch (intType.getWidth()) { + case 1: + // i1 has special alignment which is not handled by sliceTypeRaw. + return sliceType(inputValues, inputType, start, outputType); + case 8: + return sliceTypeRaw(inputValues, inputType, start, outputType); + case 16: + return sliceTypeRaw(inputValues, inputType, start, outputType); + case 32: + return sliceTypeRaw(inputValues, inputType, start, outputType); + case 64: + return sliceTypeRaw(inputValues, inputType, start, outputType); + default: + return sliceType(inputValues, inputType, start, outputType); + } + } + + // Handle possible float types + if (baseType.isF32()) { + return sliceTypeRaw(inputValues, inputType, start, outputType); + } + if (baseType.isF64()) { + return sliceTypeRaw(inputValues, inputType, start, outputType); + } + if (baseType.isBF16()) { + return sliceTypeRaw(inputValues, inputType, start, outputType); + } + return sliceType(inputValues, inputType, start, outputType); +} + +struct TosaFoldConstantSlice : public TosaFoldConstantBase { + using TosaFoldConstantBase::TosaFoldConstantBase; + + LogicalResult matchAndRewrite(tosa::SliceOp op, + PatternRewriter &rewriter) const override { + auto outputType = cast(op.getType()); + // TOSA doesn't support quantized types. + if (!outputType.getElementType().isIntOrIndexOrFloat()) + return failure(); + + auto start = op.getStart(); + auto input = op.getInput(); + ElementsAttr inputValues; + if (!matchPattern(input, m_Constant(&inputValues))) + return failure(); + + // Only fold op with multiple users if foldSplatOrSingleUseOnly is false. + if (!llvm::hasSingleElement(input.getDefiningOp()->getUsers()) && + foldSplatOrSingleUseOnly) + return failure(); + + auto resultAttr = slice(input.getType(), inputValues, start, outputType); + rewriter.replaceOpWithNewOp(op, outputType, resultAttr); + + return success(); + } +}; + template void tileArray(ShapedType inputType, RangeT inputValues, ShapedType outputType, SmallVector &outputValues) { @@ -1991,6 +2118,7 @@ void mlir::tosa::populateTosaFoldConstantPatterns( patterns.add(ctx, options.foldSplatOrSingleUseOnly); patterns.add(ctx, options.foldSplatOrSingleUseOnly); patterns.add(ctx, options.foldSplatOrSingleUseOnly); + patterns.add(ctx, options.foldSplatOrSingleUseOnly); patterns.add(ctx, options.foldSplatOrSingleUseOnly); if (options.enableTileFolding) patterns.add(ctx, options.foldSplatOrSingleUseOnly); diff --git a/mlir/test/Dialect/Tosa/constant-slice.mlir b/mlir/test/Dialect/Tosa/constant-slice.mlir new file mode 100644 index 00000000000000..7ffa6c11d70505 --- /dev/null +++ b/mlir/test/Dialect/Tosa/constant-slice.mlir @@ -0,0 +1,40 @@ +// RUN: mlir-opt --split-input-file --tosa-layerwise-constant-fold %s | FileCheck %s + +// CHECK-LABEL: @slice_int8 +func.func @slice_int8() -> (tensor<1x1xi8>) { + // CHECK: "tosa.const"() <{value = dense<3> + %0 = "tosa.const"() {value = dense<[[3, 4], [5, 6]]> : tensor<2x2xi8>} : () -> tensor<2x2xi8> + %1 = "tosa.slice"(%0){size = array, start = array} : (tensor<2x2xi8>) -> tensor<1x1xi8> + return %1 : tensor<1x1xi8> +} + +func.func @slice_int16() -> (tensor<2x1xi16>) { + // CHECK: "tosa.const"() <{value = dense<{{\[\[}}3], [5]]> + %0 = "tosa.const"() {value = dense<[[3, 4], [5, 6]]> : tensor<2x2xi16>} : () -> tensor<2x2xi16> + %1 = "tosa.slice"(%0){size = array, start = array} : (tensor<2x2xi16>) -> tensor<2x1xi16> + return %1 : tensor<2x1xi16> +} + +// CHECK-LABEL: @slice_int32 +func.func @slice_int32() -> (tensor<2x1xi32>) { + // CHECK: "tosa.const"() <{value = dense<{{\[\[}}4], [6]]> + %0 = "tosa.const"() {value = dense<[[3, 4], [5, 6]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32> + %1 = "tosa.slice"(%0){size = array, start = array} : (tensor<2x2xi32>) -> tensor<2x1xi32> + return %1 : tensor<2x1xi32> +} + +// CHECK-LABEL: @slice_int32_default_value +func.func @slice_int32_default_value() -> (tensor<3x1xi32>) { + // CHECK: "tosa.const"() <{value = dense<{{\[\[}}3], [6], [9]]> + %0 = "tosa.const"() {value = dense<[[3, 4, 5], [6, 7, 8], [9, 10, 11]]> : tensor<3x3xi32>} : () -> tensor<3x3xi32> + %1 = "tosa.slice"(%0){size = array, start = array} : (tensor<3x3xi32>) -> tensor<3x1xi32> + return %1 : tensor<3x1xi32> +} + +// CHECK-LABEL: @slice_bf16_default_value +func.func @slice_bf16_default_value() -> (tensor<3x2xbf16>) { + // CHECK: "tosa.const"() <{value = dense<{{\[\[}}4.000000e+00, 5.000000e+00], [7.000000e+00, 8.000000e+00], [1.000000e+01, 1.100000e+01]]> + %0 = "tosa.const"() {value = dense<[[3.0, 4.0, 5.0], [6.0, 7.0, 8.0], [9.0, 10.0, 11.0]]> : tensor<3x3xbf16>} : () -> tensor<3x3xbf16> + %1 = "tosa.slice"(%0){size = array, start = array} : (tensor<3x3xbf16>) -> tensor<3x2xbf16> + return %1 : tensor<3x2xbf16> +} From 83bdfafd635fd17038f2c308acb55d40a97381d2 Mon Sep 17 00:00:00 2001 From: Tiago Trevisan Jost Date: Tue, 15 Oct 2024 15:13:11 +0200 Subject: [PATCH 25/46] test: add more LIT tests for tosa.slice folding. --- .../Tosa/constant-slice-multi-user.mlir | 13 ++ mlir/test/Dialect/Tosa/constant-slice.mlir | 132 ++++++++++++++++++ 2 files changed, 145 insertions(+) create mode 100644 mlir/test/Dialect/Tosa/constant-slice-multi-user.mlir diff --git a/mlir/test/Dialect/Tosa/constant-slice-multi-user.mlir b/mlir/test/Dialect/Tosa/constant-slice-multi-user.mlir new file mode 100644 index 00000000000000..575ccdf6bed386 --- /dev/null +++ b/mlir/test/Dialect/Tosa/constant-slice-multi-user.mlir @@ -0,0 +1,13 @@ +// RUN: mlir-opt --split-input-file --tosa-layerwise-constant-fold="fold-splat-or-single-use-only=0" %s | FileCheck %s +// RUN: mlir-opt --split-input-file --tosa-layerwise-constant-fold="fold-splat-or-single-use-only=1" %s | FileCheck %s --check-prefix=ONLY-SINGLE-USE-CHECK + +// CHECK-LABEL: @slice_bf16 +func.func @slice_bf16() -> (tensor<3x3xbf16>, tensor<3x2xbf16>) { + // CHECK-DAG: "tosa.const"() <{value = dense<{{\[\[}}3.000000e+00, 4.000000e+00, 5.000000e+00], [6.000000e+00, 7.000000e+00, 8.000000e+00], [9.000000e+00, 1.000000e+01, 1.100000e+01]]> + // CHECK-DAG: "tosa.const"() <{value = dense<{{\[\[}}4.000000e+00, 5.000000e+00], [7.000000e+00, 8.000000e+00], [1.000000e+01, 1.100000e+01]]> + // ONLY-SINGLE-USE-CHECK: tosa.slice + %0 = "tosa.const"() {value = dense<[[3.0, 4.0, 5.0], [6.0, 7.0, 8.0], [9.0, 10.0, 11.0]]> : tensor<3x3xbf16>} : () -> tensor<3x3xbf16> + %1 = "tosa.slice"(%0){size = array, start = array} : (tensor<3x3xbf16>) -> tensor<3x2xbf16> + return %0, %1 : tensor<3x3xbf16>, tensor<3x2xbf16> +} + diff --git a/mlir/test/Dialect/Tosa/constant-slice.mlir b/mlir/test/Dialect/Tosa/constant-slice.mlir index 7ffa6c11d70505..067cc29b8c8372 100644 --- a/mlir/test/Dialect/Tosa/constant-slice.mlir +++ b/mlir/test/Dialect/Tosa/constant-slice.mlir @@ -38,3 +38,135 @@ func.func @slice_bf16_default_value() -> (tensor<3x2xbf16>) { %1 = "tosa.slice"(%0){size = array, start = array} : (tensor<3x3xbf16>) -> tensor<3x2xbf16> return %1 : tensor<3x2xbf16> } + +// ----- + +// Following tests are all done with the following tensor, and different configurations: +// [[[1.0 , 2.25 , 3.50 , 4.75], +// [ 5.0 , 6.25 , 7.50 , 8.75]], +// [[ 13.0 , 14.25 , 15.50 , 16.75 ], +// [ 17.0 , 18.25 , 19.50 , 20.75]], +// [[-1.0 , -2.25 , -3.50 , -4.75], +// [ -5.0 , -6.25 , -7.50 , -8.75]], +// [[ -13.0 , -14.25 , -15.50 , -16.75 ], +// [ -17.0 , -18.25 , -19.50 , -20.75]]] + +// Should produce +// 1.0, 2.25, 3.50, 4.75, +// 13.0, 14.25, 15.50, 16.75, +// -1.0, -2.25, -3.50, -4.75, +// -13.0, -14.25, -15.50, -16.75 +func.func @slice_bf16_dim_1_start_zero() -> (tensor<4x1x4xbf16>) { +// CHECK-LABEL: @slice_bf16_dim_1_start_zero +// CHECK: 1.000000e+00, 2.250000e+00, 3.500000e+00, 4.750000e+00 +// CHECK-SAME: 1.300000e+01, 1.425000e+01, 1.550000e+01, 1.675000e+01 +// CHECK-SAME: -1.000000e+00, -2.250000e+00, -3.500000e+00, -4.750000e+00 +// CHECK-SAME: -1.300000e+01, -1.425000e+01, -1.550000e+01, -1.675000e+01 + %0 = "tosa.const"() {value = dense<[[[1.0, 2.25, 3.50, 4.75], [ 5.0, 6.25, 7.50, 8.75]], [[ 13.0, 14.25, 15.50, 16.75 ], [ 17.0, 18.25, 19.50, 20.75]], [[-1.0, -2.25, -3.50, -4.75], [ -5.0, -6.25, -7.50, -8.75]], [[ -13.0, -14.25, -15.50, -16.75 ], [ -17.0, -18.25, -19.50, -20.75]]]> : tensor<4x2x4xbf16>} : () -> tensor<4x2x4xbf16> + %1 = "tosa.slice"(%0){size = array, start = array} : (tensor<4x2x4xbf16>) -> tensor<4x1x4xbf16> + return %1 : tensor<4x1x4xbf16> +} + +// Should produce +// 1.0, 2.25, 3.50, 4.75, +// 13.0, 14.25, 15.50, 16.75, +// -1.0, -2.25, -3.50, -4.75, +// -13.0, -14.25, -15.50, -16.75 +func.func @slice_f16_dim_1_start_zero() -> (tensor<4x1x4xf16>) { +// CHECK-LABEL: @slice_f16_dim_1_start_zero +// CHECK: 1.000000e+00, 2.250000e+00, 3.500000e+00, 4.750000e+00 +// CHECK-SAME: 1.300000e+01, 1.425000e+01, 1.550000e+01, 1.675000e+01 +// CHECK-SAME: -1.000000e+00, -2.250000e+00, -3.500000e+00, -4.750000e+00 +// CHECK-SAME: -1.300000e+01, -1.425000e+01, -1.550000e+01, -1.675000e+01 + %0 = "tosa.const"() {value = dense<[[[1.0, 2.25, 3.50, 4.75], [ 5.0, 6.25, 7.50, 8.75]], [[ 13.0, 14.25, 15.50, 16.75 ], [ 17.0, 18.25, 19.50, 20.75]], [[-1.0, -2.25, -3.50, -4.75], [ -5.0, -6.25, -7.50, -8.75]], [[ -13.0, -14.25, -15.50, -16.75 ], [ -17.0, -18.25, -19.50, -20.75]]]> : tensor<4x2x4xf16>} : () -> tensor<4x2x4xf16> + %1 = "tosa.slice"(%0){size = array, start = array} : (tensor<4x2x4xf16>) -> tensor<4x1x4xf16> + return %1 : tensor<4x1x4xf16> +} + +// Should produce +// 5.0, 6.25, 7.50, 8.75 +// 17.0, 18.25, 19.50, 20.75 +// -5.0, -6.25, -7.50, -8.75 +// -17.0, -18.25, -19.50, -20.75 +func.func @slice_bf16_start_dim_1_start_one() -> (tensor<4x1x4xbf16>) { +// CHECK-LABEL: @slice_bf16_start_dim_1_start_one +// CHECK: 5.000000e+00, 6.250000e+00, 7.500000e+00, 8.750000e+00 +// CHECK-SAME: 1.700000e+01, 1.825000e+01, 1.950000e+01, 2.075000e+01 +// CHECK-SAME: -5.000000e+00, -6.250000e+00, -7.500000e+00, -8.750000e+00 +// CHECK-SAME: -1.700000e+01, -1.825000e+01, -1.950000e+01, -2.075000e+01 + %0 = "tosa.const"() {value = dense<[[[1.0, 2.25, 3.50, 4.75], [ 5.0, 6.25, 7.50, 8.75]], [[ 13.0, 14.25, 15.50, 16.75 ], [ 17.0, 18.25, 19.50, 20.75]], [[-1.0, -2.25, -3.50, -4.75], [ -5.0, -6.25, -7.50, -8.75]], [[ -13.0, -14.25, -15.50, -16.75 ], [ -17.0, -18.25, -19.50, -20.75]]]> : tensor<4x2x4xbf16>} : () -> tensor<4x2x4xbf16> + %1 = "tosa.slice"(%0){size = array, start = array} : (tensor<4x2x4xbf16>) -> tensor<4x1x4xbf16> + return %1 : tensor<4x1x4xbf16> +} + +// Should produce +// 5.0, 6.25, 7.50, 8.75 +// 17.0, 18.25, 19.50, 20.75 +// -5.0, -6.25, -7.50, -8.75 +// -17.0, -18.25, -19.50, -20.75 +func.func @slice_f16_start_dim_1_start_one() -> (tensor<4x1x4xf16>) { +// CHECK-LABEL: @slice_f16_start_dim_1_start_one +// CHECK: 5.000000e+00, 6.250000e+00, 7.500000e+00, 8.750000e+00 +// CHECK-SAME: 1.700000e+01, 1.825000e+01, 1.950000e+01, 2.075000e+01 +// CHECK-SAME: -5.000000e+00, -6.250000e+00, -7.500000e+00, -8.750000e+00 +// CHECK-SAME: -1.700000e+01, -1.825000e+01, -1.950000e+01, -2.075000e+01 + %0 = "tosa.const"() {value = dense<[[[1.0, 2.25, 3.50, 4.75], [ 5.0, 6.25, 7.50, 8.75]], [[ 13.0, 14.25, 15.50, 16.75 ], [ 17.0, 18.25, 19.50, 20.75]], [[-1.0, -2.25, -3.50, -4.75], [ -5.0, -6.25, -7.50, -8.75]], [[ -13.0, -14.25, -15.50, -16.75 ], [ -17.0, -18.25, -19.50, -20.75]]]> : tensor<4x2x4xf16>} : () -> tensor<4x2x4xf16> + %1 = "tosa.slice"(%0){size = array, start = array} : (tensor<4x2x4xf16>) -> tensor<4x1x4xf16> + return %1 : tensor<4x1x4xf16> +} + +// Should produce +// 1.0, 2.25, 3.50 +// 13.0, 14.25, 15.50 +// -1.0, -2.25, -3.50 +func.func @slice_bf16_start_zero_multiple_dims() -> (tensor<3x1x3xbf16>) { +// CHECK-LABEL: @slice_bf16_start_zero_multiple_dims +// CHECK: 1.000000e+00, 2.250000e+00, 3.500000e+00 +// CHECK-SAME: 1.300000e+01, 1.425000e+01, 1.550000e+01 +// CHECK-SAME: -1.000000e+00, -2.250000e+00, -3.500000e+00 + %0 = "tosa.const"() {value = dense<[[[1.0, 2.25, 3.50, 4.75], [ 5.0, 6.25, 7.50, 8.75]], [[ 13.0, 14.25, 15.50, 16.75 ], [ 17.0, 18.25, 19.50, 20.75]], [[-1.0, -2.25, -3.50, -4.75], [ -5.0, -6.25, -7.50, -8.75]], [[ -13.0, -14.25, -15.50, -16.75 ], [ -17.0, -18.25, -19.50, -20.75]]]> : tensor<4x2x4xbf16>} : () -> tensor<4x2x4xbf16> + %1 = "tosa.slice"(%0){size = array, start = array} : (tensor<4x2x4xbf16>) -> tensor<3x1x3xbf16> + return %1 : tensor<3x1x3xbf16> +} + +// Should produce +// 1.0, 2.25, 3.50 +// 13.0, 14.25, 15.50 +// -1.0, -2.25, -3.50 +func.func @slice_f16_start_zero_multiple_dims() -> (tensor<3x1x3xf16>) { +// CHECK-LABEL: @slice_f16_start_zero_multiple_dims +// CHECK: 1.000000e+00, 2.250000e+00, 3.500000e+00 +// CHECK-SAME: 1.300000e+01, 1.425000e+01, 1.550000e+01 +// CHECK-SAME: -1.000000e+00, -2.250000e+00, -3.500000e+00 + %0 = "tosa.const"() {value = dense<[[[1.0, 2.25, 3.50, 4.75], [ 5.0, 6.25, 7.50, 8.75]], [[ 13.0, 14.25, 15.50, 16.75 ], [ 17.0, 18.25, 19.50, 20.75]], [[-1.0, -2.25, -3.50, -4.75], [ -5.0, -6.25, -7.50, -8.75]], [[ -13.0, -14.25, -15.50, -16.75 ], [ -17.0, -18.25, -19.50, -20.75]]]> : tensor<4x2x4xf16>} : () -> tensor<4x2x4xf16> + %1 = "tosa.slice"(%0){size = array, start = array} : (tensor<4x2x4xf16>) -> tensor<3x1x3xf16> + return %1 : tensor<3x1x3xf16> +} + +// Produces +// 18.25, 19.50, 20.75 +// -6.25, -7.50, -8.75 +// -18.25, -19.50, -20.75 +func.func @slice_bf16_start_non_zero_multiple_dims() -> (tensor<3x1x3xbf16>) { +// CHECK-LABEL: @slice_bf16_start_non_zero_multiple_dims +// CHECK: 1.825000e+01, 1.950000e+01, 2.075000e+01 +// CHECK-SAME: -6.250000e+00, -7.500000e+00, -8.750000e+00 +// CHECK-SAME: -1.825000e+01, -1.950000e+01, -2.075000e+01 + %0 = "tosa.const"() {value = dense<[[[1.0, 2.25, 3.50, 4.75], [ 5.0, 6.25, 7.50, 8.75]], [[ 13.0, 14.25, 15.50, 16.75 ], [ 17.0, 18.25, 19.50, 20.75]], [[-1.0, -2.25, -3.50, -4.75], [ -5.0, -6.25, -7.50, -8.75]], [[ -13.0, -14.25, -15.50, -16.75 ], [ -17.0, -18.25, -19.50, -20.75]]]> : tensor<4x2x4xbf16>} : () -> tensor<4x2x4xbf16> + %1 = "tosa.slice"(%0){size = array, start = array} : (tensor<4x2x4xbf16>) -> tensor<3x1x3xbf16> + return %1 : tensor<3x1x3xbf16> +} + +// Produces +// 18.25, 19.50, 20.75 +// -6.25, -7.50, -8.75 +// -18.25, -19.50, -20.75 +func.func @slice_f16_start_non_zero_multiple_dims() -> (tensor<3x1x3xf16>) { +// CHECK-LABEL: @slice_f16_start_non_zero_multiple_dims +// CHECK: 1.825000e+01, 1.950000e+01, 2.075000e+01 +// CHECK-SAME: -6.250000e+00, -7.500000e+00, -8.750000e+00 +// CHECK-SAME: -1.825000e+01, -1.950000e+01, -2.075000e+01 + %0 = "tosa.const"() {value = dense<[[[1.0, 2.25, 3.50, 4.75], [ 5.0, 6.25, 7.50, 8.75]], [[ 13.0, 14.25, 15.50, 16.75 ], [ 17.0, 18.25, 19.50, 20.75]], [[-1.0, -2.25, -3.50, -4.75], [ -5.0, -6.25, -7.50, -8.75]], [[ -13.0, -14.25, -15.50, -16.75 ], [ -17.0, -18.25, -19.50, -20.75]]]> : tensor<4x2x4xf16>} : () -> tensor<4x2x4xf16> + %1 = "tosa.slice"(%0){size = array, start = array} : (tensor<4x2x4xf16>) -> tensor<3x1x3xf16> + return %1 : tensor<3x1x3xf16> +} \ No newline at end of file From 3cd352bfa6b16a296a09fa009608c6bd271bc444 Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Fri, 18 Oct 2024 08:05:58 -0700 Subject: [PATCH 26/46] TosaToLinalg: Prefer to emit identity maps (#386) When deciding whether to emit a map like `#map = affine_map<(d0, d1, d2, d3) -> (0, d1, d2, d3)>` or `#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>` for and operand of a linalg.generic when lowering element wise TOSA ops, prefer the latter unless broadcasting of the operand is really needed. This helps later transformations which often require the affine map to be a projected permuatation, which only the latter is. --- .../Conversion/TosaToLinalg/TosaToLinalg.cpp | 10 ++++++++-- .../TosaToLinalg/tosa-to-linalg.mlir | 20 +++++++++++++++++++ 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp index 18a92c70e24153..34b6a8d6b10a7c 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -925,8 +925,14 @@ emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc, auto shape = cast(operand.getType()).getShape(); SmallVector affineExprs; for (auto it : llvm::enumerate(shape)) { - auto affineExpr = it.value() == 1 ? rewriter.getAffineConstantExpr(0) - : rewriter.getAffineDimExpr(it.index()); + // Prefer producting identity maps whenever possible (i.e. no broadcasting + // needed) because some transforms (like reshape folding) + // do not support affine constant exprs. + bool requiresBroadcast = + (it.value() == 1 && resultType.getDimSize(it.index()) != 1); + auto affineExpr = requiresBroadcast + ? rewriter.getAffineConstantExpr(0) + : rewriter.getAffineDimExpr(it.index()); affineExprs.push_back(affineExpr); } return AffineMap::get(rank, 0, affineExprs, rewriter.getContext()); diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir index d4e5d5ee92408b..1bf038cc69aef1 100644 --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -250,6 +250,26 @@ func.func @test_add_1d_broadcast_static_to_static(%arg0: tensor<1xf32>, %arg1: t // ----- +// CHECK: #[[$MAP:.+]] = affine_map<(d0) -> (d0)> +// CHECK-LABEL: @test_add_1d_matching_no_broadcast +// CHECK-SAME: %[[ARG0:[0-9a-zA-Z_]*]]: +// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]: +func.func @test_add_1d_matching_no_broadcast(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> { + + // CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<1xf32> + // CHECK: %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP]], #[[$MAP]], #[[$MAP]]], iterator_types = ["parallel"]} ins(%[[ARG0]], %[[ARG1]] : tensor<1xf32>, tensor<1xf32>) outs(%[[VAL_0]] : tensor<1xf32>) { + // CHECK: ^bb0(%[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: f32, %[[VAL_3:.*]]: f32): + // CHECK: %[[VAL_4:.*]] = arith.addf %[[VAL_1]], %[[VAL_2]] : f32 + // CHECK: linalg.yield %[[VAL_4]] : f32 + // CHECK: } -> tensor<1xf32> + %0 = tosa.add %arg0, %arg1 : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + + // CHECK: return %[[RESULT]] : tensor<1xf32> + return %0 : tensor<1xf32> +} + +// ----- + // CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0)> // CHECK-LABEL: @test_add_1d_matching_static // CHECK-SAME: %[[ARG0:[0-9a-zA-Z_]*]]: From 2015abf98f34f27d436d9ad943b4031982a6e07a Mon Sep 17 00:00:00 2001 From: josel-amd <166385423+josel-amd@users.noreply.github.com> Date: Fri, 18 Oct 2024 17:06:26 +0200 Subject: [PATCH 27/46] Fix attr aliasing on region args (#389) * Fix for aliasing the region args * Add test case * Add empty line --- mlir/lib/IR/AsmPrinter.cpp | 1 + mlir/test/IR/test-region-attr-aliasing.mlir | 11 +++++++++++ 2 files changed, 12 insertions(+) create mode 100644 mlir/test/IR/test-region-attr-aliasing.mlir diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 5d0eb7670b3bd4..b180f082fb9f7e 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -779,6 +779,7 @@ class DummyAliasOperationPrinter : private OpAsmPrinter { void printRegionArgument(BlockArgument arg, ArrayRef argAttrs, bool omitType) override { printType(arg.getType()); + printOptionalAttrDict(argAttrs); // Visit the argument location. if (printerFlags.shouldPrintDebugInfo()) // TODO: Allow deferring argument locations. diff --git a/mlir/test/IR/test-region-attr-aliasing.mlir b/mlir/test/IR/test-region-attr-aliasing.mlir new file mode 100644 index 00000000000000..ed1d2bd823471d --- /dev/null +++ b/mlir/test/IR/test-region-attr-aliasing.mlir @@ -0,0 +1,11 @@ +// RUN: mlir-opt %s | FileCheck %s + +#map = affine_map<(d0) -> (d0)> +// CHECK: {builtin.test = #map} +func.func @test_attr_alias_on_region_attr(%arg0: memref<2xf32> {builtin.test = #map}) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 0 : index + %2 = memref.load %arg0[%c0] : memref<2xf32> + memref.store %2, %arg0[%c1] : memref<2xf32> + return +} From 4b36487cc776194587f55644481dd734fcfed505 Mon Sep 17 00:00:00 2001 From: josel-amd <166385423+josel-amd@users.noreply.github.com> Date: Tue, 22 Oct 2024 10:02:35 +0200 Subject: [PATCH 28/46] Copy attributes from the original operation (SCF::ForOp) into the lowered version (Emitc::ForOp) (#390) --- mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp | 3 +++ mlir/test/Conversion/SCFToEmitC/for.mlir | 9 +++++++++ 2 files changed, 12 insertions(+) diff --git a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp index 71c566eb80a2d0..51490c79ce4904 100644 --- a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp +++ b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp @@ -118,6 +118,9 @@ ForLowering::matchAndRewrite(ForOp forOp, OpAdaptor adaptor, emitc::ForOp loweredFor = rewriter.create( loc, adaptor.getLowerBound(), adaptor.getUpperBound(), adaptor.getStep()); + // Propagate any attributes from the ODS forOp to the lowered emitc::for op. + loweredFor->setAttrs(forOp->getAttrs()); + Block *loweredBody = loweredFor.getBody(); // Erase the auto-generated terminator for the lowered for op. diff --git a/mlir/test/Conversion/SCFToEmitC/for.mlir b/mlir/test/Conversion/SCFToEmitC/for.mlir index 071e968e7ce171..b422aaa4545d9b 100644 --- a/mlir/test/Conversion/SCFToEmitC/for.mlir +++ b/mlir/test/Conversion/SCFToEmitC/for.mlir @@ -98,3 +98,12 @@ func.func @nested_for_yield(%arg0 : index, %arg1 : index, %arg2 : index) -> f32 // CHECK-NEXT: } // CHECK-NEXT: return %[[VAL_4]] : f32 // CHECK-NEXT: } + +func.func @loop_with_attr(%arg0 : index, %arg1 : index, %arg2 : index) { + scf.for %i0 = %arg0 to %arg1 step %arg2 { + %c1 = arith.constant 1 : index + } {test.value = 5 : index} + return +} +// CHECK-LABEL: func.func @loop_with_attr +// CHECK: {test.value = 5 : index} From ad4697caa85268496056753ad3a145f051af78dc Mon Sep 17 00:00:00 2001 From: josel-amd <166385423+josel-amd@users.noreply.github.com> Date: Mon, 28 Oct 2024 16:47:11 +0100 Subject: [PATCH 29/46] OpaqueType with format strings (#391) OpaqueType: Use format string --- mlir/include/mlir/Dialect/EmitC/IR/EmitC.h | 4 + mlir/include/mlir/Dialect/EmitC/IR/EmitC.td | 6 +- .../mlir/Dialect/EmitC/IR/EmitCTypes.td | 11 +- mlir/lib/Dialect/EmitC/IR/EmitC.cpp | 154 ++++++++++++------ mlir/lib/Target/Cpp/TranslateToCpp.cpp | 21 ++- mlir/test/Dialect/EmitC/invalid_types.mlir | 28 ++++ mlir/test/Dialect/EmitC/types.mlir | 6 + mlir/test/Target/Cpp/types.mlir | 10 ++ 8 files changed, 182 insertions(+), 58 deletions(-) diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h index bc82f58a7ee95c..d9cca43081c98c 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h @@ -52,6 +52,10 @@ bool isPointerWideType(mlir::Type type); /// Give the name of the EmitC reference attribute. StringRef getReferenceAttributeName(); +// Either a literal string, or an placeholder for the fmtArgs. +struct Placeholder {}; +using ReplacementItem = std::variant; + } // namespace emitc } // namespace mlir diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td index 0de8787ba1dc8f..1a1b58e3cbf386 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td @@ -1168,11 +1168,7 @@ def EmitC_VerbatimOp : EmitC_Op<"verbatim"> { }]; let extraClassDeclaration = [{ - // Either a literal string, or an placeholder for the fmtArgs. - struct Placeholder {}; - using ReplacementItem = std::variant; - - FailureOr> parseFormatString(); + FailureOr> parseFormatString(); }]; let arguments = (ins StrAttr:$value, diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td index 79f6d34fc91b13..0fbacd440a91de 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitCTypes.td @@ -99,9 +99,16 @@ def EmitC_OpaqueType : EmitC_Type<"Opaque", "opaque"> { ``` }]; - let parameters = (ins StringRefParameter<"the opaque value">:$value); - let assemblyFormat = "`<` $value `>`"; + let parameters = (ins StringRefParameter<"the opaque value">:$value, + OptionalArrayRefParameter<"Type">:$fmtArgs); + let assemblyFormat = "`<` $value (`,` custom($fmtArgs)^)? `>`"; let genVerifyDecl = 1; + + let builders = [TypeBuilder<(ins "::llvm::StringRef":$value), [{ return $_get($_ctxt, value, SmallVector{}); }] >]; + + let extraClassDeclaration = [{ + FailureOr> parseFormatString(); + }]; } def EmitC_PointerType : EmitC_Type<"Pointer", "ptr"> { diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index 7bc40b4f555cc6..ee44f524c91428 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -20,6 +20,7 @@ #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Casting.h" +#include "llvm/Support/FormatVariadic.h" using namespace mlir; using namespace mlir::emitc; @@ -154,6 +155,64 @@ static LogicalResult verifyInitializationAttribute(Operation *op, return success(); } +/// Parse a format string and return a list of its parts. +/// A part is either a StringRef that has to be printed as-is, or +/// a Placeholder which requires printing the next operand of the VerbatimOp. +/// In the format string, all `{}` are replaced by Placeholders, except if the +/// `{` is escaped by `{{` - then it doesn't start a placeholder. +template +FailureOr> +parseFormatString(StringRef toParse, ArgType fmtArgs, + std::optional> + emitError = {}) { + SmallVector items; + + // If there are not operands, the format string is not interpreted. + if (fmtArgs.empty()) { + items.push_back(toParse); + return items; + } + + while (!toParse.empty()) { + size_t idx = toParse.find('{'); + if (idx == StringRef::npos) { + // No '{' + items.push_back(toParse); + break; + } + if (idx > 0) { + // Take all chars excluding the '{'. + items.push_back(toParse.take_front(idx)); + toParse = toParse.drop_front(idx); + continue; + } + if (toParse.size() < 2) { + // '{' is last character + items.push_back(toParse); + break; + } + // toParse contains at least two characters and starts with `{`. + char nextChar = toParse[1]; + if (nextChar == '{') { + // Double '{{' -> '{' (escaping). + items.push_back(toParse.take_front(1)); + toParse = toParse.drop_front(2); + continue; + } + if (nextChar == '}') { + items.push_back(Placeholder{}); + toParse = toParse.drop_front(2); + continue; + } + + if (emitError.has_value()) { + return (*emitError)() << "expected '}' after unescaped '{'"; + } + return failure(); + } + return items; +} + //===----------------------------------------------------------------------===// // AddOp //===----------------------------------------------------------------------===// @@ -914,7 +973,11 @@ LogicalResult emitc::SubscriptOp::verify() { //===----------------------------------------------------------------------===// LogicalResult emitc::VerbatimOp::verify() { - FailureOr> fmt = parseFormatString(); + auto errorCallback = [&]() -> InFlightDiagnostic { + return this->emitOpError(); + }; + FailureOr> fmt = + ::parseFormatString(getValue(), getFmtArgs(), errorCallback); if (failed(fmt)) return failure(); @@ -929,56 +992,29 @@ LogicalResult emitc::VerbatimOp::verify() { return success(); } -/// Parse a format string and return a list of its parts. -/// A part is either a StringRef that has to be printed as-is, or -/// a Placeholder which requires printing the next operand of the VerbatimOp. -/// In the format string, all `{}` are replaced by Placeholders, except if the -/// `{` is escaped by `{{` - then it doesn't start a placeholder. -FailureOr> -emitc::VerbatimOp::parseFormatString() { - SmallVector items; +static ParseResult parseVariadicTypeFmtArgs(AsmParser &p, + SmallVector ¶ms) { + Type type; + if (p.parseType(type)) + return failure(); - // If there are not operands, the format string is not interpreted. - if (getFmtArgs().empty()) { - items.push_back(getValue()); - return items; + params.push_back(type); + while (succeeded(p.parseOptionalComma())) { + if (p.parseType(type)) + return failure(); + params.push_back(type); } - StringRef toParse = getValue(); - while (!toParse.empty()) { - size_t idx = toParse.find('{'); - if (idx == StringRef::npos) { - // No '{' - items.push_back(toParse); - break; - } - if (idx > 0) { - // Take all chars excluding the '{'. - items.push_back(toParse.take_front(idx)); - toParse = toParse.drop_front(idx); - continue; - } - if (toParse.size() < 2) { - // '{' is last character - items.push_back(toParse); - break; - } - // toParse contains at least two characters and starts with `{`. - char nextChar = toParse[1]; - if (nextChar == '{') { - // Double '{{' -> '{' (escaping). - items.push_back(toParse.take_front(1)); - toParse = toParse.drop_front(2); - continue; - } - if (nextChar == '}') { - items.push_back(Placeholder{}); - toParse = toParse.drop_front(2); - continue; - } - return emitOpError() << "expected '}' after unescaped '{'"; - } - return items; + return success(); +} + +static void printVariadicTypeFmtArgs(AsmPrinter &p, ArrayRef params) { + llvm::interleaveComma(params, p, [&](Type type) { p.printType(type); }); +} + +FailureOr> emitc::VerbatimOp::parseFormatString() { + // Error checking is done in verify. + return ::parseFormatString(getValue(), getFmtArgs()); } //===----------------------------------------------------------------------===// @@ -1072,7 +1108,7 @@ emitc::ArrayType::cloneWith(std::optional> shape, LogicalResult mlir::emitc::OpaqueType::verify( llvm::function_ref emitError, - llvm::StringRef value) { + llvm::StringRef value, ArrayRef fmtArgs) { if (value.empty()) { return emitError() << "expected non empty string in !emitc.opaque type"; } @@ -1080,9 +1116,29 @@ LogicalResult mlir::emitc::OpaqueType::verify( return emitError() << "pointer not allowed as outer type with " "!emitc.opaque, use !emitc.ptr instead"; } + + FailureOr> fmt = + ::parseFormatString(value, fmtArgs, emitError); + if (failed(fmt)) + return failure(); + + size_t numPlaceholders = llvm::count_if(*fmt, [](ReplacementItem &item) { + return std::holds_alternative(item); + }); + + if (numPlaceholders != fmtArgs.size()) { + return emitError() + << "requires operands for each placeholder in the format string"; + } + return success(); } +FailureOr> emitc::OpaqueType::parseFormatString() { + // Error checking is done in verify. + return ::parseFormatString(getValue(), getFmtArgs()); +} + //===----------------------------------------------------------------------===// // GlobalOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp index 73256451ef1487..bc73d415e6e8c8 100644 --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -512,14 +512,14 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::VerbatimOp verbatimOp) { raw_ostream &os = emitter.ostream(); - FailureOr> items = + FailureOr> items = verbatimOp.parseFormatString(); if (failed(items)) return failure(); auto fmtArg = verbatimOp.getFmtArgs().begin(); - for (emitc::VerbatimOp::ReplacementItem &item : *items) { + for (ReplacementItem &item : *items) { if (auto *str = std::get_if(&item)) { os << *str; } else { @@ -1728,6 +1728,23 @@ LogicalResult CppEmitter::emitType(Location loc, Type type) { if (auto tType = dyn_cast(type)) return emitTupleType(loc, tType.getTypes()); if (auto oType = dyn_cast(type)) { + FailureOr> items = oType.parseFormatString(); + if (failed(items)) + return failure(); + + auto fmtArg = oType.getFmtArgs().begin(); + for (ReplacementItem &item : *items) { + if (auto *str = std::get_if(&item)) { + os << *str; + } else { + if (failed(emitType(loc, *fmtArg++))) { + return failure(); + } + } + } + + return success(); + os << oType.getValue(); return success(); } diff --git a/mlir/test/Dialect/EmitC/invalid_types.mlir b/mlir/test/Dialect/EmitC/invalid_types.mlir index ee59d90bf7f617..616f0480a19d91 100644 --- a/mlir/test/Dialect/EmitC/invalid_types.mlir +++ b/mlir/test/Dialect/EmitC/invalid_types.mlir @@ -14,6 +14,34 @@ func.func @illegal_opaque_type_2() { // ----- +// expected-error @+1 {{expected non-function type}} +func.func @illegal_opaque_type(%arg0: !emitc.opaque<"{}, {}", "string">) { + return +} + +// ----- + +// expected-error @+1 {{requires operands for each placeholder in the format string}} +func.func @illegal_opaque_type(%arg0: !emitc.opaque<"a", f32>) { + return +} + +// ----- + + // expected-error @+1 {{requires operands for each placeholder in the format string}} +func.func @illegal_opaque_type(%arg0: !emitc.opaque<"{}, {}", f32>) { + return +} + +// ----- + +// expected-error @+1 {{expected '}' after unescaped '{'}} +func.func @illegal_opaque_type(%arg0: !emitc.opaque<"{ ", i32>) { + return +} + +// ----- + func.func @illegal_array_missing_spec( // expected-error @+1 {{expected non-function type}} %arg0: !emitc.array<>) { diff --git a/mlir/test/Dialect/EmitC/types.mlir b/mlir/test/Dialect/EmitC/types.mlir index b53976eff84cad..eca23c75263ee1 100644 --- a/mlir/test/Dialect/EmitC/types.mlir +++ b/mlir/test/Dialect/EmitC/types.mlir @@ -38,6 +38,12 @@ func.func @opaque_types() { emitc.call_opaque "f"() {template_args = [!emitc.opaque<"std::vector">]} : () -> () // CHECK-NEXT: !emitc.opaque<"SmallVector"> emitc.call_opaque "f"() {template_args = [!emitc.opaque<"SmallVector">]} : () -> () + // CHECK-NEXT: !emitc.opaque<"{}", i32> + emitc.call_opaque "f"() {template_args = [!emitc>]} : () -> () + // CHECK-NEXT: !emitc.opaque<"{}, {}", i32, f32>] + emitc.call_opaque "f"() {template_args = [!emitc>]} : () -> () + // CHECK-NEXT: !emitc.opaque<"{}" + emitc.call_opaque "f"() {template_args = [!emitc>]} : () -> () return } diff --git a/mlir/test/Target/Cpp/types.mlir b/mlir/test/Target/Cpp/types.mlir index deda383b3b0a72..336dfacaa183a4 100644 --- a/mlir/test/Target/Cpp/types.mlir +++ b/mlir/test/Target/Cpp/types.mlir @@ -12,6 +12,16 @@ func.func @opaque_types() { emitc.call_opaque "f"() {template_args = [!emitc>]} : () -> () // CHECK-NEXT: f>(); emitc.call_opaque "f"() {template_args = [!emitc.opaque<"std::vector">]} : () -> () + // CHECK: f() + emitc.call_opaque "f"() {template_args = [!emitc>]} : () -> () + // CHECK: f(); + emitc.call_opaque "f"() {template_args = [!emitc>]} : () -> () + // CHECK: f(); + emitc.call_opaque "f"() {template_args = [!emitc>]} : () -> () + // CHECK: f(); + emitc.call_opaque "f"() {template_args = [!emitc> >>]} : () -> () + // CHECK: f,int32_t>>(); + emitc.call_opaque "f"() {template_args = [!emitc", !emitc", f32>>, i32>>]} : () -> () return } From a64ebcc26cb43c6043f219dba4b643585b0cf15a Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Tue, 5 Nov 2024 09:27:49 +0100 Subject: [PATCH 30/46] Add emitc.tu --- mlir/include/mlir/Dialect/EmitC/IR/EmitC.td | 43 +++++++++++++++++++ mlir/include/mlir/Target/Cpp/CppEmitter.h | 4 +- mlir/lib/Target/Cpp/TranslateRegistration.cpp | 8 +++- mlir/lib/Target/Cpp/TranslateToCpp.cpp | 37 +++++++++++++--- mlir/test/Target/Cpp/tu.mlir | 29 +++++++++++++ 5 files changed, 112 insertions(+), 9 deletions(-) create mode 100644 mlir/test/Target/Cpp/tu.mlir diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td index 1a1b58e3cbf386..78c420997dac65 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td @@ -21,7 +21,9 @@ include "mlir/Interfaces/CastInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/FunctionInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/IR/OpAsmInterface.td" include "mlir/IR/RegionKindInterface.td" +include "mlir/IR/BuiltinAttributes.td" //===----------------------------------------------------------------------===// // EmitC op definitions @@ -55,6 +57,47 @@ def IntegerIndexOrOpaqueType : Type; def FloatIntegerIndexOrOpaqueType : AnyTypeOf<[EmitCFloatType, IntegerIndexOrOpaqueType]>; +def EmitC_TranslationUnitOp : EmitC_Op<"tu", + [IsolatedFromAbove, NoRegionArguments, SymbolTable, + OpAsmOpInterface + ] # GraphRegionNoTerminator.traits> { + let summary = "A translation unit container operation"; + let description = [{ + A `tu` represents a translation unit that can be emitted + into a single C++ file. + + `mlir-translate` emits only the translation unit selected via + the `-translation-unit-id=id` flag. By default, no translation units are + emitted. + + Example: + + ```mlir + emitc.tu "main" { + emitc.func @func_one() { + emitc.return + } + } + ``` + }]; + + let arguments = (ins Builtin_StringAttr:$id); + let regions = (region SizedRegion<1>:$bodyRegion); + + let assemblyFormat = "$id attr-dict-with-keyword $bodyRegion"; + let extraClassDeclaration = [{ + //===------------------------------------------------------------------===// + // OpAsmOpInterface Methods + //===------------------------------------------------------------------===// + + /// EmitC ops in the body of the translation_unit can omit their 'emitc.' + /// prefix in the assembly. + static ::llvm::StringRef getDefaultDialect() { + return "emitc"; + } + }]; +} + def EmitC_AddOp : EmitC_BinaryOp<"add", [CExpression]> { let summary = "Addition operation"; let description = [{ diff --git a/mlir/include/mlir/Target/Cpp/CppEmitter.h b/mlir/include/mlir/Target/Cpp/CppEmitter.h index 99d8696cc8e077..d76cfc9107332e 100644 --- a/mlir/include/mlir/Target/Cpp/CppEmitter.h +++ b/mlir/include/mlir/Target/Cpp/CppEmitter.h @@ -14,6 +14,7 @@ #define MLIR_TARGET_CPP_CPPEMITTER_H #include "mlir/Support/LLVM.h" +#include "llvm/ADT/StringRef.h" namespace mlir { class Operation; @@ -24,7 +25,8 @@ namespace emitc { /// 'declareVariablesAtTop' enforces that all variables for op results and block /// arguments are declared at the beginning of the function. LogicalResult translateToCpp(Operation *op, raw_ostream &os, - bool declareVariablesAtTop = false); + bool declareVariablesAtTop = false, + StringRef onlyTu = ""); } // namespace emitc } // namespace mlir diff --git a/mlir/lib/Target/Cpp/TranslateRegistration.cpp b/mlir/lib/Target/Cpp/TranslateRegistration.cpp index 1aa98834a73f49..7e2bc9ad012b38 100644 --- a/mlir/lib/Target/Cpp/TranslateRegistration.cpp +++ b/mlir/lib/Target/Cpp/TranslateRegistration.cpp @@ -29,12 +29,18 @@ void registerToCppTranslation() { llvm::cl::desc("Declare variables at top when emitting C/C++"), llvm::cl::init(false)); + static llvm::cl::opt onlyTu( + "translation-unit-id", + llvm::cl::desc("Only emit the translation unit with the matching id"), + llvm::cl::init("")); + TranslateFromMLIRRegistration reg( "mlir-to-cpp", "translate from mlir to cpp", [](Operation *op, raw_ostream &output) { return emitc::translateToCpp( op, output, - /*declareVariablesAtTop=*/declareVariablesAtTop); + /*declareVariablesAtTop=*/declareVariablesAtTop, + /*onlyTu=*/onlyTu); }, [](DialectRegistry ®istry) { // clang-format off diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp index bc73d415e6e8c8..60ad20bf7a0926 100644 --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -115,7 +115,8 @@ static FailureOr getOperatorPrecedence(Operation *operation) { namespace { /// Emitter that uses dialect specific emitters to emit C++ code. struct CppEmitter { - explicit CppEmitter(raw_ostream &os, bool declareVariablesAtTop); + explicit CppEmitter(raw_ostream &os, bool declareVariablesAtTop, + StringRef onlyTu); /// Emits attribute or returns failure. LogicalResult emitAttribute(Location loc, Attribute attr); @@ -231,6 +232,9 @@ struct CppEmitter { /// be declared at the beginning of a function. bool shouldDeclareVariablesAtTop() { return declareVariablesAtTop; }; + /// Returns whether this translation unit should be emitted + bool shouldEmitTu(TranslationUnitOp tu) { return tu.getId() == onlyTu; } + /// Get expression currently being emitted. ExpressionOp getEmittedExpression() { return emittedExpression; } @@ -258,6 +262,9 @@ struct CppEmitter { /// includes results from ops located in nested regions. bool declareVariablesAtTop; + /// Only emit translation units whos id matches this value. + std::string onlyTu; + /// Map from value to name of C++ variable that contain the name. ValueMapper valueMapper; @@ -936,6 +943,19 @@ static LogicalResult printOperation(CppEmitter &emitter, ModuleOp moduleOp) { return success(); } +static LogicalResult printOperation(CppEmitter &emitter, TranslationUnitOp tu) { + if (!emitter.shouldEmitTu(tu)) + return success(); + + CppEmitter::Scope scope(emitter); + + for (Operation &op : tu) { + if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/false))) + return failure(); + } + return success(); +} + template static LogicalResult printFunctionArgs(CppEmitter &emitter, FuncOpClass functionOp, @@ -1177,8 +1197,10 @@ static LogicalResult printOperation(CppEmitter &emitter, return success(); } -CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop) - : os(os), declareVariablesAtTop(declareVariablesAtTop) { +CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop, + StringRef onlyTu) + : os(os), declareVariablesAtTop(declareVariablesAtTop), + onlyTu(onlyTu.str()) { valueInScopeCount.push(0); labelInScopeCount.push(0); } @@ -1580,8 +1602,8 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) { emitc::GlobalOp, emitc::IfOp, emitc::IncludeOp, emitc::LogicalAndOp, emitc::LogicalNotOp, emitc::LogicalOrOp, emitc::MulOp, emitc::RemOp, emitc::ReturnOp, emitc::SubOp, - emitc::UnaryMinusOp, emitc::UnaryPlusOp, emitc::VariableOp, - emitc::VerbatimOp>( + emitc::TranslationUnitOp, emitc::UnaryMinusOp, + emitc::UnaryPlusOp, emitc::VariableOp, emitc::VerbatimOp>( [&](auto op) { return printOperation(*this, op); }) // Func ops. .Case( @@ -1791,7 +1813,8 @@ LogicalResult CppEmitter::emitTupleType(Location loc, ArrayRef types) { } LogicalResult emitc::translateToCpp(Operation *op, raw_ostream &os, - bool declareVariablesAtTop) { - CppEmitter emitter(os, declareVariablesAtTop); + bool declareVariablesAtTop, + StringRef onlyTu) { + CppEmitter emitter(os, declareVariablesAtTop, onlyTu); return emitter.emitOperation(*op, /*trailingSemicolon=*/false); } diff --git a/mlir/test/Target/Cpp/tu.mlir b/mlir/test/Target/Cpp/tu.mlir new file mode 100644 index 00000000000000..ca10e0263a64fc --- /dev/null +++ b/mlir/test/Target/Cpp/tu.mlir @@ -0,0 +1,29 @@ +// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s --check-prefix NO-FILTER +// RUN: mlir-translate -mlir-to-cpp -translation-unit-id=non-existing %s | FileCheck %s --check-prefix NON-EXISTING +// RUN: mlir-translate -mlir-to-cpp -translation-unit-id=tu_one %s | FileCheck %s --check-prefix TU-ONE +// RUN: mlir-translate -mlir-to-cpp -translation-unit-id=tu_two %s | FileCheck %s --check-prefix TU-TWO + + +// NO-FILTER-NOT: func_one +// NO-FILTER-NOT: func_two + +// NON-EXISTING-NOT: func_one +// NON-EXISTING-NOT: func_two + +// TU-ONE: func_one +// TU-ONE-NOT: func_two + +// TU-TWO-NOT: func_one +// TU-TWO: func_two + +emitc.tu "tu_one" { + emitc.func @func_one(%arg: f32) { + emitc.return + } +} + +emitc.tu "tu_two" { + emitc.func @func_two(%arg: f32) { + emitc.return + } +} From 22fcd3e1d2c4c623fb46d0254b1d67f606891e30 Mon Sep 17 00:00:00 2001 From: Ian Wood <75152913+IanWood1@users.noreply.github.com> Date: Wed, 14 Aug 2024 13:58:35 -0700 Subject: [PATCH 31/46] [mlir] Add bubbling patterns for non intersecting reshapes (#103401) Refactored @Max191's PR https://github.com/llvm/llvm-project/pull/94637 to move it to `Tensor` From the original PR >This PR adds fusion by expansion patterns to push a tensor.expand_shape up through a tensor.collapse_shape with non-intersecting reassociations. Sometimes parallel collapse_shape ops like this can block propagation of expand_shape ops, so this allows them to pass through each other. I'm not sure if I put the code/tests in the right places, so let me know where those go if they aren't. cc @MaheshRavishankar @hanhanW --------- Co-authored-by: Max Dawkins --- .../Dialect/Tensor/Transforms/Transforms.h | 4 + .../Linalg/Transforms/ElementwiseOpFusion.cpp | 2 + .../Tensor/Transforms/ReshapePatterns.cpp | 75 +++++++++++++++++++ mlir/test/Dialect/Tensor/bubble-reshapes.mlir | 47 ++++++++++++ .../Dialect/Tensor/TestTensorTransforms.cpp | 13 ++++ 5 files changed, 141 insertions(+) create mode 100644 mlir/test/Dialect/Tensor/bubble-reshapes.mlir diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h index 7f983b8b3cfd06..ae695e0326ca1a 100644 --- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h @@ -67,6 +67,10 @@ void populateDropRedundantInsertSliceRankExpansionPatterns( /// `tensor.collapse_shape` into other ops. void populateReassociativeReshapeFoldingPatterns(RewritePatternSet &patterns); +/// Populates `patterns` with patterns that bubble up `tensor.expand_shape` +/// through `tensor.collapse_shape` ops. +void populateBubbleUpExpandShapePatterns(RewritePatternSet &patterns); + /// Populates `patterns` with patterns that fold tensor.empty with its /// consumers. /// diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp index e73df61c964341..9f1b6fdc55df3b 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" +#include "mlir/Dialect/Tensor/Transforms/Transforms.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Matchers.h" @@ -2144,6 +2145,7 @@ struct LinalgElementwiseOpFusionPass // Add elementwise op fusion patterns. populateElementwiseOpsFusionPatterns(patterns, defaultControlFn); populateFoldReshapeOpsByExpansionPatterns(patterns, defaultControlFn); + tensor::populateBubbleUpExpandShapePatterns(patterns); // General canonicalization patterns. affine::AffineApplyOp::getCanonicalizationPatterns(patterns, context); diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp index be0d71866a095e..5edd7a02bc42b1 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp @@ -140,6 +140,76 @@ struct FoldPaddingExpandIntoInsert : public OpRewritePattern { return success(); } }; + +/// Pattern to bubble up a tensor.expand_shape op through a producer +/// tensor.collapse_shape op that has non intersecting reassociations. +struct BubbleUpExpandThroughParallelCollapse + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp, + PatternRewriter &rewriter) const override { + auto collapseOp = + expandOp.getSrc().getDefiningOp(); + if (!collapseOp) + return failure(); + auto expandReInds = expandOp.getReassociationIndices(); + auto collapseReInds = collapseOp.getReassociationIndices(); + + // Reshapes are parallel to each other if none of the reassociation indices + // have greater than 1 index for both reshapes. + for (auto [expandReassociation, collapseReassociation] : + llvm::zip_equal(expandReInds, collapseReInds)) { + if (collapseReassociation.size() != 1 && expandReassociation.size() != 1) + return failure(); + } + + // Compute new reassociation indices and expanded/collaped shapes. + SmallVector newExpandReInds, newCollapseReInds; + Location loc = expandOp->getLoc(); + SmallVector collapseSizes = + tensor::getMixedSizes(rewriter, loc, collapseOp.getSrc()); + SmallVector expandSizes(getMixedValues( + expandOp.getStaticOutputShape(), expandOp.getOutputShape(), rewriter)); + SmallVector newExpandSizes; + int64_t index = 0, expandIndex = 0, collapseIndex = 0; + for (auto [idx, collapseReassociation] : llvm::enumerate(collapseReInds)) { + if (collapseReassociation.size() != 1) { + ReassociationIndices newCollapseReassociation; + for (size_t i = 0; i < collapseReassociation.size(); ++i) { + newCollapseReassociation.push_back(index); + newExpandReInds.push_back({index++}); + newExpandSizes.push_back(collapseSizes[collapseIndex++]); + } + newCollapseReInds.push_back(newCollapseReassociation); + expandIndex++; + continue; + } + ReassociationIndices newExpandReassociation; + auto expandReassociation = expandReInds[idx]; + for (size_t i = 0; i < expandReassociation.size(); ++i) { + newExpandReassociation.push_back(index); + newCollapseReInds.push_back({index++}); + newExpandSizes.push_back(expandSizes[expandIndex++]); + } + newExpandReInds.push_back(newExpandReassociation); + collapseIndex++; + } + + // Swap reshape order. + SmallVector dynamicSizes; + SmallVector staticSizes; + dispatchIndexOpFoldResults(newExpandSizes, dynamicSizes, staticSizes); + auto expandResultType = expandOp.getResultType().clone(staticSizes); + auto newExpand = rewriter.create( + loc, expandResultType, collapseOp.getSrc(), newExpandReInds, + newExpandSizes); + rewriter.replaceOpWithNewOp( + expandOp, newExpand.getResult(), newCollapseReInds); + return success(); + } +}; + } // namespace void mlir::tensor::populateReassociativeReshapeFoldingPatterns( @@ -152,3 +222,8 @@ void mlir::tensor::populateReassociativeReshapeFoldingPatterns( FoldPaddingExpandIntoInsert>( patterns.getContext()); } + +void mlir::tensor::populateBubbleUpExpandShapePatterns( + RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); +} diff --git a/mlir/test/Dialect/Tensor/bubble-reshapes.mlir b/mlir/test/Dialect/Tensor/bubble-reshapes.mlir new file mode 100644 index 00000000000000..cf6b12852bcd39 --- /dev/null +++ b/mlir/test/Dialect/Tensor/bubble-reshapes.mlir @@ -0,0 +1,47 @@ +// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-expand-shape-bubbling %s | FileCheck %s + +func.func @bubble_parallel_reshapes(%arg0: tensor, %s0: index, %s1: index, %s2: index, %s3: index) -> tensor { + %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3]] : tensor into tensor + %expand = tensor.expand_shape %collapse [[0], [1], [2, 3]] + output_shape [%s0, %s1, %s2, %s3] : tensor into tensor + return %expand : tensor +} +// CHECK: func @bubble_parallel_reshapes +// CHECK-SAME: %[[ARG0:.+]]: tensor +// CHECK-SAME: %[[S0:.+]]: index, %[[S1:.+]]: index, %[[S2:.+]]: index, %[[S3:.+]]: index +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index +// CHECK-DAG: %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor +// CHECK-DAG: %[[DIM2:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor +// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1], [2], [3, 4]] +// CHECK-SAME: output_shape [%[[S0]], %[[DIM1]], %[[DIM2]], %[[S2]], %[[S3]]] : tensor into tensor +// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[EXPAND]] {{\[}}[0], [1, 2], [3], [4]] : tensor into tensor +// CHECK: return %[[COLLAPSE]] + +// ----- + +func.func @no_bubble_full_intersecting_reshapes(%arg0: tensor, %s0: index, %s1: index, %s2: index, %s3: index) -> tensor { + %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3]] : tensor into tensor + %expand = tensor.expand_shape %collapse [[0], [1, 2], [3]] + output_shape [%s0, %s1, %s2, %s3] : tensor into tensor + return %expand : tensor +} +// CHECK: func @no_bubble_full_intersecting_reshapes +// CHECK-SAME: %[[ARG0:.+]]: tensor +// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2], [3]] +// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[COLLAPSE]] {{\[}}[0], [1, 2], [3]] +// CHECK: return %[[EXPAND]] + +// ----- + +func.func @no_bubble_partial_intersecting_reshapes(%arg0: tensor, %s0: index, %s1: index, %s2: index, %s3: index) -> tensor { + %collapse = tensor.collapse_shape %arg0 [[0, 1, 2], [3]] : tensor into tensor + %expand = tensor.expand_shape %collapse [[0, 1], [2, 3]] + output_shape [%s0, %s1, %s2, %s3] : tensor into tensor + return %expand : tensor +} +// CHECK: func @no_bubble_partial_intersecting_reshapes +// CHECK-SAME: %[[ARG0:.+]]: tensor +// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]] +// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[COLLAPSE]] {{\[}}[0, 1], [2, 3]] +// CHECK: return %[[EXPAND]] diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp index ae4f77f5873e2b..34de600132f5de 100644 --- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp +++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp @@ -72,6 +72,11 @@ struct TestTensorTransforms llvm::cl::desc("Test folding of expand_shape/collapse_shape"), llvm::cl::init(false)}; + Option testBubbleUpExpandShapePatterns{ + *this, "test-expand-shape-bubbling", + llvm::cl::desc("Test folding of expand_shape/collapse_shape"), + llvm::cl::init(false)}; + Option testFoldIntoPackAndUnpack{ *this, "test-fold-into-pack-and-unpack", llvm::cl::desc("Test folding ops into tensor.pack and tensor.unpack"), @@ -102,6 +107,12 @@ static void applyReassociativeReshapeFoldingPatterns(Operation *rootOp) { (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns)); } +static void applyBubbleUpExpandShapePatterns(Operation *rootOp) { + RewritePatternSet patterns(rootOp->getContext()); + tensor::populateBubbleUpExpandShapePatterns(patterns); + (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns)); +} + static void applyFoldIntoPackAndUnpackPatterns(Operation *rootOp) { RewritePatternSet patterns(rootOp->getContext()); tensor::populateFoldIntoPackAndUnpackPatterns(patterns); @@ -386,6 +397,8 @@ void TestTensorTransforms::runOnOperation() { applyDropRedundantInsertSliceRankExpansionPatterns(rootOp); if (testReassociativeReshapeFolding) applyReassociativeReshapeFoldingPatterns(rootOp); + if (testBubbleUpExpandShapePatterns) + applyBubbleUpExpandShapePatterns(rootOp); if (testFoldIntoPackAndUnpack) applyFoldIntoPackAndUnpackPatterns(rootOp); if (testRewriteExtractSliceWithTiledCollapseShape) { From 7a11b633b3c5de6e56d56c03f794738323a8cd82 Mon Sep 17 00:00:00 2001 From: Yun-Fly Date: Fri, 23 Aug 2024 10:07:17 +0800 Subject: [PATCH 32/46] [mlir][tensor] Add consumer fusion for `tensor.pack` op. (#103715) Add missing `getIterationDomainTileFromOperandTile` and `getTiledImplementationFromOperandTile` to `tensor.pack` and enable fusing it as a consumer. NOTE that, it only expects perfect tiling scenario without padding semantic currently. --- .../Tensor/IR/TensorTilingInterfaceImpl.cpp | 114 ++++++++++++++++++ .../tile-and-fuse-consumer.mlir | 59 +++++++++ 2 files changed, 173 insertions(+) diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp index 361340a4e62f2d..dec678de6d1c27 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp @@ -246,6 +246,120 @@ struct PackOpTiling return failure(); return tilingResult.value(); } + + /// Method to return the position of iteration domain tile computed by the + /// tiled operation. In current `tensor.pack` context, the `resultOffsets` and + /// `resultSizes` only cover outer dimensions. + LogicalResult getIterationDomainTileFromOperandTile( + Operation *op, OpBuilder &b, unsigned operandNumber, + ArrayRef offsets, ArrayRef sizes, + SmallVectorImpl &resultOffsets, + SmallVectorImpl &resultSizes) const { + if (operandNumber != 0) + return failure(); + + auto packOp = cast(op); + // It is not trivial to infer dest tile from source tile if `packOp` has + // padding semantic. + if (packOp.getPaddingValue()) + return failure(); + + Location loc = packOp.getLoc(); + + SmallVector outerDimOffsets, outerDimSizes; + DenseMap dimAndTileMapping = + packOp.getDimAndTileMapping(); + for (auto dim : packOp.getOuterDimsPerm()) { + if (dimAndTileMapping.count(dim)) { + FailureOr cstSize = + ValueBoundsConstraintSet::computeConstantBound( + presburger::BoundType::UB, sizes[dim], + /*stopCondition=*/nullptr, /*closedUB=*/true); + std::optional cstInnerSize = + getConstantIntValue(dimAndTileMapping[dim]); + // Currently fusing `packOp` as consumer only expects perfect tiling + // scenario because even if without padding semantic, the `packOp` may + // also yield incomplete tiles. E.g. tensor<30xf32> -> tensor<5x6xf32>, + // where the `tileSize` from operand of `packOp` is 5, which is not + // exactly divided by `innerTile`(=6) of `packOp`. As the result: + // 1. the first slice is extracted from (0) to (4) and inserted into + // (0,0)~(0,4) at first row. + // 2. the second slice is extracted from (5) to (9) and SHOULD BE + // respectively inserted into two rows with different length, including + // first row: (0,5) and second row (1,0)~(1,3). It is hard to coordinate + // them, thus adding below constraint to bypass them temporarily. In + // another word, we can only support tiling with consumer if the tile + // size for the producer is a multiple of the inner tile size for the + // packed dimensions at this moment. + if (failed(cstSize) || !cstInnerSize || *cstSize % *cstInnerSize != 0) { + return failure(); + } + + using AV = affine::AffineValueExpr; + affine::AffineBuilder ab(b, loc); + AffineExpr dim0, sym; + bindDims(b.getContext(), dim0); + bindSymbols(b.getContext(), sym); + auto avOffset = AV(dim0).bind(offsets[dim]); + auto avSize = AV(dim0).bind(sizes[dim]); + auto avTileSize = AV(sym).bind(dimAndTileMapping[dim]); + outerDimOffsets.push_back(ab.floor(avOffset, avTileSize)); + outerDimSizes.push_back(ab.ceil(avSize, avTileSize)); + } else { + outerDimOffsets.push_back(offsets[dim]); + outerDimSizes.push_back(sizes[dim]); + } + } + + resultOffsets = outerDimOffsets; + resultSizes = outerDimSizes; + return success(); + } + + /// Method to return the tiled implementation of tensor.pack as a consumer. + FailureOr getTiledImplementationFromOperandTile( + Operation *op, OpBuilder &b, unsigned operandNumber, + ArrayRef offsets, ArrayRef sizes) const { + if (operandNumber != 0) + return failure(); + + auto packOp = cast(op); + Location loc = packOp.getLoc(); + + int64_t inputRank = packOp.getSourceRank(); + auto oneAttr = b.getI64IntegerAttr(1); + SmallVector strides(inputRank, oneAttr); + + SmallVector tiledOperands; + tiledOperands.push_back(b.create(loc, packOp.getSource(), + offsets, sizes, strides)); + + SmallVector outerDimOffsets, outerDimSizes; + if (failed(getIterationDomainTileFromOperandTile( + op, b, /*operandNumber=*/0, offsets, sizes, outerDimOffsets, + outerDimSizes))) + return failure(); + + SmallVector outputOffsets, outputSizes; + if (failed(getResultTilePosition(op, b, 0, outerDimOffsets, outerDimSizes, + outputOffsets, outputSizes))) + return failure(); + + strides.append(packOp.getDestRank() - inputRank, oneAttr); + auto extractSlice = b.create( + loc, packOp.getDest(), outputOffsets, outputSizes, strides); + tiledOperands.push_back(extractSlice); + + assert(!packOp.getPaddingValue() && "Expect no padding semantic"); + for (auto tile : packOp.getInnerTiles()) + tiledOperands.push_back(tile); + + Operation *tiledPackOp = b.create( + loc, TypeRange{extractSlice.getType()}, tiledOperands, op->getAttrs()); + + return TilingResult{{tiledPackOp}, + SmallVector(tiledPackOp->getResults())}; + } }; struct UnpackTileDimInfo { diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir index 400b558e37fcda..741dfbfb1cd5c2 100644 --- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir +++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir @@ -315,3 +315,62 @@ module attributes {transform.with_named_sequence} { // CHECK: } // CHECK: } // CHECK: return %[[FINAL_RESULT]]#1 : + +// ----- + +#map = affine_map<(d0, d1) -> (d0, d1)> +module { + func.func @fuse_pack_consumer_into_scf_forall(%arg0: tensor<32x32xf32>, %arg1: tensor<32x32xf32>, %arg2: tensor<64x32xf32>) -> tensor<4x32x16xf32> { + %c4 = arith.constant 4 : index + %c64 = arith.constant 64 : index + %c0 = arith.constant 0 : index + %1 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %arg2) -> (tensor<64x32xf32>) { + %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x32xf32> to tensor<32x32xf32> + %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) { + ^bb0(%in: f32, %in_16: f32, %out: f32): + %13 = arith.mulf %in, %in_16 : f32 + %14 = arith.addf %out, %13 : f32 + linalg.yield %14 : f32 + } -> tensor<32x32xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %3 into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x32xf32> + } + } + %output = tensor.empty() : tensor<4x32x16xf32> + %pack = tensor.pack %1 outer_dims_perm = [0, 1] inner_dims_pos = [0] inner_tiles = [16] into %output : tensor<64x32xf32> -> tensor<4x32x16xf32> + return %pack : tensor<4x32x16xf32> + } +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %a, %b = transform.test.fuse_consumer %slice_op + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} +// CHECK: #[[PACK_RESULT_MAP:.*]] = affine_map<(d0) -> (d0 floordiv 16)> +// CHECK: func.func @fuse_pack_consumer_into_scf_forall( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<32x32xf32> +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<64x32xf32>) +// CHECK: %[[OUT_INIT:.*]] = tensor.empty() : tensor<4x32x16xf32> +// CHECK: %[[FINAL_RESULT:.*]]:2 = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) in (2, 2) +// CHECK-SAME: shared_outs(%[[FIRST_OUT_ARG:.*]] = %[[ARG2]], %[[PACK_OUT_ARG:.*]] = %[[OUT_INIT]]) +// CHECK-SAME: { +// CHECK: %[[GENERIC_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: %[[GENERIC_OUT:.*]] = linalg.generic +// CHECK-SAME: outs(%[[GENERIC_OUT_SLICE]] : +// CHECK: %[[PACK_RESULT_OFFSET:.*]] = affine.apply #[[PACK_RESULT_MAP]](%[[IV1]]) +// CHECK: %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][%[[PACK_RESULT_OFFSET]], %[[IV2]], 0] [2, 32, 16] [1, 1, 1] +// CHECK: %[[TILED_PACK_OUT:.*]] = tensor.pack %[[GENERIC_OUT]] +// CHECK-SAME: outer_dims_perm = [0, 1] inner_dims_pos = [0] inner_tiles = [16] +// CHECK-SAME: into %[[TILED_PACK_DEST]] +// CHECK: scf.forall.in_parallel { +// CHECK: tensor.parallel_insert_slice %[[TILED_PACK_OUT]] into %[[PACK_OUT_ARG]][%[[PACK_RESULT_OFFSET]], %[[IV2]], 0] [2, 32, 16] [1, 1, 1] +// CHECK: tensor.parallel_insert_slice %[[GENERIC_OUT]] into %[[FIRST_OUT_ARG]][%[[IV1]], %[[IV2]]] [32, 32] [1, 1] +// CHECK: } +// CHECK: } +// CHECK: return %[[FINAL_RESULT]]#1 : From 7489677fb64aee05a8e0984a28d2901f0b2297ba Mon Sep 17 00:00:00 2001 From: Hsiangkai Wang Date: Fri, 16 Aug 2024 16:22:02 +0100 Subject: [PATCH 33/46] [mlir][linalg] Implement TilingInterface for winograd operators (#96184) In order to support arbitrary size input data of conv2d, implement TilingInterface for winograd operations. Before converting winograd operations into nested loops with matrix multiply, tile the input of conv2d into the supported size first. Add a transform operation structured.decompose_winograd_op to decompose winograd operations. Before applying the transform op, use tile_using_for to tile the input data into supported size. The test case shows how to tile and decompose winograd operations. --- .../mlir/Dialect/Linalg/IR/LinalgOps.td | 141 ++++++- .../Linalg/TransformOps/LinalgTransformOps.td | 37 ++ .../Dialect/Linalg/Transforms/Transforms.h | 57 +++ mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 397 ++++++++++++++++-- .../TransformOps/LinalgTransformOps.cpp | 41 ++ .../Linalg/Transforms/WinogradConv2D.cpp | 25 +- .../transform-tile-and-winograd-rewrite.mlir | 292 +++++++++++++ .../Linalg/transform-tile-winograd.mlir | 380 +++++++++++++++++ 8 files changed, 1330 insertions(+), 40 deletions(-) create mode 100644 mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir create mode 100644 mlir/test/Dialect/Linalg/transform-tile-winograd.mlir diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td index a9007c8db3078e..5b6a90f806bedd 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td @@ -154,8 +154,13 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax", let hasVerifier = 1; } -def Linalg_WinogradFilterTransformOp : - Linalg_Op<"winograd_filter_transform", [AllElementTypesMatch<["filter", "output"]>]> { +def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform", + [AllElementTypesMatch<["filter", "output"]>, + DeclareOpInterfaceMethods]> { let summary = "Winograd filter transform operator"; let description = [{ Winograd Conv2D algorithm will convert linalg Conv2D operator into batched @@ -190,11 +195,42 @@ def Linalg_WinogradFilterTransformOp : `outs` `(` $output `:` type($output) `)` `->` type($result) }]; + let extraClassDeclaration = [{ + ShapedType getFilterOperandType() { + return cast(getFilter().getType()); + } + ShapedType getOutputOperandType() { + return cast(getOutput().getType()); + } + int64_t getFilterOperandRank() { + return getFilterOperandType().getRank(); + } + int64_t getOutputOperandRank() { + return getOutputOperandType().getRank(); + } + int64_t getFilterFDim() { + return 0; + } + int64_t getFilterHDim() { + return 1; + } + int64_t getFilterWDim() { + return 2; + } + int64_t getFilterCDim() { + return 3; + } + }]; let hasVerifier = 1; } -def Linalg_WinogradInputTransformOp : - Linalg_Op<"winograd_input_transform", [AllElementTypesMatch<["input", "output"]>]> { +def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform", + [AllElementTypesMatch<["input", "output"]>, + DeclareOpInterfaceMethods]> { let summary = "Winograd input transform operator"; let description = [{ Winograd Conv2D algorithm will convert linalg Conv2D operator into batched @@ -229,11 +265,60 @@ def Linalg_WinogradInputTransformOp : `outs` `(` $output `:` type($output) `)` `->` type($result) }]; + let extraClassDeclaration = [{ + ShapedType getInputOperandType() { + return cast(getInput().getType()); + } + ShapedType getOutputOperandType() { + return cast(getOutput().getType()); + } + int64_t getInputOperandRank() { + return getInputOperandType().getRank(); + } + int64_t getOutputOperandRank() { + return getOutputOperandType().getRank(); + } + int64_t getInputNDim() { + return 0; + } + int64_t getInputHDim() { + return 1; + } + int64_t getInputWDim() { + return 2; + } + int64_t getInputCDim() { + return 3; + } + int64_t getOutputAlphaHDim() { + return 0; + } + int64_t getOutputAlphaWDim() { + return 1; + } + int64_t getOutputTileHDim() { + return 2; + } + int64_t getOutputTileWDim() { + return 3; + } + int64_t getOutputNDim() { + return 4; + } + int64_t getOutputCDim() { + return 5; + } + }]; let hasVerifier = 1; } -def Linalg_WinogradOutputTransformOp : - Linalg_Op<"winograd_output_transform", [AllElementTypesMatch<["value", "output"]>]> { +def Linalg_WinogradOutputTransformOp : Linalg_Op<"winograd_output_transform", + [AllElementTypesMatch<["value", "output"]>, + DeclareOpInterfaceMethods]> { let summary = "Winograd output transform operator"; let description = [{ Winograd Conv2D algorithm will convert linalg Conv2D operator into batched @@ -268,6 +353,50 @@ def Linalg_WinogradOutputTransformOp : `outs` `(` $output `:` type($output) `)` `->` type($result) }]; + let extraClassDeclaration = [{ + ShapedType getValueOperandType() { + return cast(getValue().getType()); + } + ShapedType getOutputOperandType() { + return cast(getOutput().getType()); + } + int64_t getValueOperandRank() { + return getValueOperandType().getRank(); + } + int64_t getOutputOperandRank() { + return getOutputOperandType().getRank(); + } + int64_t getValueAlphaHDim() { + return 0; + } + int64_t getValueAlphaWDim() { + return 1; + } + int64_t getValueTileHDim() { + return 2; + } + int64_t getValueTileWDim() { + return 3; + } + int64_t getValueNDim() { + return 4; + } + int64_t getValueFDim() { + return 5; + } + int64_t getOutputNDim() { + return 0; + } + int64_t getOutputHDim() { + return 1; + } + int64_t getOutputWDim() { + return 2; + } + int64_t getOutputFDim() { + return 3; + } + }]; let hasVerifier = 1; } diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index ecc86999006db6..106f0d79d9792d 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -2697,4 +2697,41 @@ def WinogradConv2DOp : Op { + let description = [{ + Decompose winograd operations. It will convert filter, input and output + transform operations into a combination of scf, tensor, and linalg + equivalent operations. Before applying this transform operations, users + need to tile winograd transform operations into supported sizes. + + #### Return modes: + + This operation fails if `target` is unsupported. Otherwise, the operation + succeeds and returns a handle of the sequence that replaces the original + operations. + }]; + + let arguments = (ins TransformHandleTypeInterface:$target); + let results = (outs TransformHandleTypeInterface:$transformed); + + let assemblyFormat = + "$target attr-dict `:` functional-type($target, results)"; + + let builders = [ + OpBuilder<(ins "Value":$target)> + ]; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, + ::mlir::Operation *target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + #endif // LINALG_TRANSFORM_OPS diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index 477ef7bfafb181..861e14d22d9625 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -1316,6 +1316,63 @@ FailureOr winogradConv2D(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp op, int64_t m, int64_t r); +/// Rewrite linalg.winograd_filter_transform. The data layout of the filter is +/// FHWC. The transformation matrix is 2-dimension. We need to extract H x W +/// from FHWC first. We generate 2 levels of loops to iterate on F and C. After +/// the rewriting, we get +/// +/// scf.for %f = lo_f to hi_f step 1 +/// scf.for %c = lo_c to hi_c step 1 +/// %extracted = extract filter from filter +/// %ret = linalg.matmul G, %extracted +/// %ret = linalg.matmul %ret, GT +/// %inserted = insert %ret into filter +FailureOr +decomposeWinogradFilterTransformOp(RewriterBase &rewriter, + linalg::WinogradFilterTransformOp op); + +/// Rewrite linalg.winograd_input_transform. The data layout of the input is +/// NHWC. The transformation matrix is 2-dimension. We need to extract H x W +/// from NHWC first. We generate 4 levels of loops to iterate on N, C, tileH, +/// and tileW. After the rewriting, we get +/// +/// scf.for %h = 0 to tileH step 1 +/// scf.for %w = 0 to tileW step 1 +/// scf.for %n = 0 to N step 1 +/// scf.for %c = 0 to C step 1 +/// %extracted = extract %extracted from +/// %input +/// at [%n, (%h x m), (%w x m), %c] +/// %ret = linalg.matmul BT, %extracted +/// %ret = linalg.matmul %ret, B +/// %inserted = insert %ret into +/// %output +/// at [0, 0, %h, %w, %n, %c] +FailureOr +decomposeWinogradInputTransformOp(RewriterBase &rewriter, + linalg::WinogradInputTransformOp op); + +/// Rewrite linalg.winograd_output_transform. The data layout of the output is +/// HWNF. The transformation matrix is 2-dimension. We need to extract H x W +/// from HWNF first. We generate 4 levels of loops to iterate on N, F, tileH, +/// and tileW. After the transformation, we get +/// +/// scf.for %h = 0 to tileH step 1 +/// scf.for %w = 0 to tileW step 1 +/// scf.for %n = 0 to N step 1 +/// scf.for %f = 0 to F step 1 +/// %extracted = extract %extracted from +/// %input +/// at [0, 0, %h, %w, %n, %f] +/// %ret = linalg.matmul AT, %extracted +/// %ret = linalg.matmul %ret, A +/// %inserted = insert %ret into +/// output +/// at [%n, (%h x m), (%w x m), %f] +FailureOr +decomposeWinogradOutputTransformOp(RewriterBase &rewriter, + linalg::WinogradOutputTransformOp op); + //===----------------------------------------------------------------------===// // Rewrite patterns wrapping transformations. // TODO: every single such pattern should be a close to noop wrapper around a diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index a101552e419bc8..775ed8f37344ed 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -2855,8 +2855,8 @@ FailureOr> SoftmaxOp::decomposeOperation(OpBuilder &b) { LogicalResult WinogradFilterTransformOp::verify() { auto filterType = cast(getFilter().getType()); ArrayRef filterShape = filterType.getShape(); - int64_t filterH = filterShape[1]; - int64_t filterW = filterShape[2]; + int64_t filterH = filterShape[getFilterHDim()]; + int64_t filterW = filterShape[getFilterWDim()]; int64_t r = getR(); int64_t m = getM(); @@ -2870,8 +2870,8 @@ LogicalResult WinogradFilterTransformOp::verify() { SmallVector expectedOutputShape; expectedOutputShape.push_back(filterH == r ? m + r - 1 : 1); expectedOutputShape.push_back(filterW == r ? m + r - 1 : 1); - expectedOutputShape.push_back(filterShape[3]); - expectedOutputShape.push_back(filterShape[0]); + expectedOutputShape.push_back(filterShape[getFilterCDim()]); + expectedOutputShape.push_back(filterShape[getFilterFDim()]); auto outputType = cast(getOutput().getType()); ArrayRef outputShape = outputType.getShape(); @@ -2881,6 +2881,103 @@ LogicalResult WinogradFilterTransformOp::verify() { return success(); } +SmallVector +WinogradFilterTransformOp::getIterationDomain(OpBuilder &builder) { + Location loc = getLoc(); + IntegerAttr zeroAttr = builder.getIndexAttr(0); + IntegerAttr oneAttr = builder.getIndexAttr(1); + Value filter = getFilter(); + int64_t filterRank = getFilterOperandRank(); + SmallVector loopBounds(filterRank); + for (unsigned dim = 0; dim < filterRank; ++dim) { + loopBounds[dim].offset = zeroAttr; + loopBounds[dim].size = getDimValue(builder, loc, filter, dim); + loopBounds[dim].stride = oneAttr; + } + return loopBounds; +} + +SmallVector +WinogradFilterTransformOp::getLoopIteratorTypes() { + int64_t filterRank = getFilterOperandRank(); + SmallVector iteratorTypes(filterRank, + utils::IteratorType::parallel); + return iteratorTypes; +} + +LogicalResult WinogradFilterTransformOp::getResultTilePosition( + OpBuilder &builder, unsigned resultNumber, ArrayRef offsets, + ArrayRef sizes, SmallVector &resultOffsets, + SmallVector &resultSizes) { + IntegerAttr zeroAttr = builder.getI64IntegerAttr(0); + ShapedType filterType = getFilterOperandType(); + ArrayRef filterShape = filterType.getShape(); + int64_t filterH = filterShape[getFilterHDim()]; + int64_t filterW = filterShape[getFilterWDim()]; + int64_t m = getM(); + int64_t r = getR(); + int64_t alpha = m + r - 1; + int64_t alphaH = filterH != 1 ? alpha : 1; + int64_t alphaW = filterW != 1 ? alpha : 1; + IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH); + IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW); + + resultOffsets.append( + {zeroAttr, zeroAttr, offsets[getFilterCDim()], offsets[getFilterFDim()]}); + resultSizes.append( + {alphaHAttr, alphaWAttr, sizes[getFilterCDim()], sizes[getFilterFDim()]}); + + return success(); +} + +/// Implement tiling for winograd_filter_transform +/// The input of winograd_filter_transform is (F, KH, KW, C). +/// The output of winograd_filter_transform is (alphaH, alphaW, C, F) +/// Users can specify the tile sizes of F and C. +/// `offsets` are the values for the offsets of F, KH, KW, C for one tile. +/// `sizes` are the values for the sizes of F, KH, KW, C for one tile. +FailureOr WinogradFilterTransformOp::getTiledImplementation( + OpBuilder &builder, ArrayRef offsets, + ArrayRef sizes) { + IntegerAttr oneAttr = builder.getI64IntegerAttr(1); + IntegerAttr zeroAttr = builder.getI64IntegerAttr(0); + ShapedType filterType = getFilterOperandType(); + ArrayRef filterShape = filterType.getShape(); + int64_t filterH = filterShape[getFilterHDim()]; + int64_t filterW = filterShape[getFilterWDim()]; + IntegerAttr filterHAttr = builder.getI64IntegerAttr(filterH); + IntegerAttr filterWAttr = builder.getI64IntegerAttr(filterW); + SmallVector tiledOperands; + SmallVector sliceOffsets, sliceSizes; + + sliceOffsets.append( + {offsets[getFilterFDim()], zeroAttr, zeroAttr, offsets[getFilterCDim()]}); + sliceSizes.append({sizes[getFilterFDim()], filterHAttr, filterWAttr, + sizes[getFilterCDim()]}); + int64_t filterRank = getFilterOperandRank(); + SmallVector filterStrides(filterRank, oneAttr); + Location loc = getLoc(); + tiledOperands.emplace_back(builder.create( + loc, getFilter(), sliceOffsets, sliceSizes, filterStrides)); + + SmallVector resultOffsets, resultSizes; + if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets, + resultSizes))) + return failure(); + + int64_t outputRank = getOutputOperandRank(); + SmallVector outputStrides(outputRank, oneAttr); + tiledOperands.emplace_back(builder.create( + loc, getOutput(), resultOffsets, resultSizes, outputStrides)); + + SmallVector resultTypes; + resultTypes.push_back(tiledOperands[1].getType()); + Operation *tiledOp = + mlir::clone(builder, getOperation(), resultTypes, tiledOperands); + + return TilingResult{{tiledOp}, SmallVector(tiledOp->getResults())}; +} + //===----------------------------------------------------------------------===// // WinogradInputTransformOp //===----------------------------------------------------------------------===// @@ -2888,8 +2985,8 @@ LogicalResult WinogradFilterTransformOp::verify() { LogicalResult WinogradInputTransformOp::verify() { auto inputType = cast(getInput().getType()); ArrayRef inputShape = inputType.getShape(); - int64_t inputH = inputShape[1]; - int64_t inputW = inputShape[2]; + int64_t inputH = inputShape[getInputHDim()]; + int64_t inputW = inputShape[getInputWDim()]; int m = getM(); int r = getR(); int64_t tileSize = m + r - 1; @@ -2898,21 +2995,23 @@ LogicalResult WinogradInputTransformOp::verify() { SmallVector expectedOutputShape(6, inputH); if (ShapedType::isDynamic(inputH)) { - expectedOutputShape[0] = tileSize; - expectedOutputShape[2] = ShapedType::kDynamic; + expectedOutputShape[getOutputAlphaHDim()] = tileSize; + expectedOutputShape[getOutputTileHDim()] = ShapedType::kDynamic; } else { - expectedOutputShape[0] = leftTransform ? tileSize : 1; - expectedOutputShape[2] = leftTransform ? (inputH - (r - 1)) / m : 1; + expectedOutputShape[getOutputAlphaHDim()] = leftTransform ? tileSize : 1; + expectedOutputShape[getOutputTileHDim()] = + leftTransform ? (inputH - (r - 1)) / m : 1; } if (ShapedType::isDynamic(inputW)) { - expectedOutputShape[1] = tileSize; - expectedOutputShape[3] = ShapedType::kDynamic; + expectedOutputShape[getOutputAlphaWDim()] = tileSize; + expectedOutputShape[getOutputTileWDim()] = ShapedType::kDynamic; } else { - expectedOutputShape[1] = rightTransform ? tileSize : 1; - expectedOutputShape[3] = rightTransform ? (inputW - (r - 1)) / m : 1; + expectedOutputShape[getOutputAlphaWDim()] = rightTransform ? tileSize : 1; + expectedOutputShape[getOutputTileWDim()] = + rightTransform ? (inputW - (r - 1)) / m : 1; } - expectedOutputShape[4] = inputShape[0]; - expectedOutputShape[5] = inputShape[3]; + expectedOutputShape[getOutputNDim()] = inputShape[getInputNDim()]; + expectedOutputShape[getOutputCDim()] = inputShape[getInputCDim()]; auto outputType = cast(getOutput().getType()); ArrayRef outputShape = outputType.getShape(); @@ -2922,6 +3021,130 @@ LogicalResult WinogradInputTransformOp::verify() { return success(); } +SmallVector +WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) { + Location loc = getLoc(); + IntegerAttr zeroAttr = builder.getIndexAttr(0); + IntegerAttr oneAttr = builder.getIndexAttr(1); + Value output = getOutput(); + int64_t outputRank = getOutputOperandRank(); + SmallVector loopBounds(outputRank); + for (unsigned dim = 0; dim < outputRank; ++dim) { + loopBounds[dim].offset = zeroAttr; + // alphaH, alphaW, tileH, tileW, N, C + loopBounds[dim].size = getDimValue(builder, loc, output, dim); + loopBounds[dim].stride = oneAttr; + } + return loopBounds; +} + +SmallVector +WinogradInputTransformOp::getLoopIteratorTypes() { + int64_t outputRank = getOutputOperandRank(); + SmallVector iteratorTypes(outputRank, + utils::IteratorType::parallel); + return iteratorTypes; +} + +LogicalResult WinogradInputTransformOp::getResultTilePosition( + OpBuilder &builder, unsigned resultNumber, ArrayRef offsets, + ArrayRef sizes, SmallVector &resultOffsets, + SmallVector &resultSizes) { + IntegerAttr zeroAttr = builder.getI64IntegerAttr(0); + ShapedType inputType = getInputOperandType(); + ArrayRef inputShape = inputType.getShape(); + int64_t inputH = inputShape[getInputHDim()]; + int64_t inputW = inputShape[getInputWDim()]; + int64_t m = getM(); + int64_t r = getR(); + int64_t alpha = m + r - 1; + int64_t alphaH = inputH != 1 ? alpha : 1; + int64_t alphaW = inputW != 1 ? alpha : 1; + IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH); + IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW); + + resultOffsets.append({zeroAttr, zeroAttr, offsets[getOutputTileHDim()], + offsets[getOutputTileWDim()], offsets[getOutputNDim()], + offsets[getOutputCDim()]}); + resultSizes.append({alphaHAttr, alphaWAttr, sizes[getOutputTileHDim()], + sizes[getOutputTileWDim()], sizes[getOutputNDim()], + sizes[getOutputCDim()]}); + + return success(); +} + +/// Implement tiling for winograd_input_transform +/// The input of winograd_input_transform is (N, H, W, C). +/// The output of winograd_input_transform is (alphaH, alphaW, tileH, tileW, N, +/// C) Users can specify the tile sizes of tileH, tileW, N, and C. `offsets` are +/// the values for the offsets of tileH, tileW, N, C for one tile. `sizes` are +/// the values for the sizes of tileH, tileW, N, C for one tile. +FailureOr +WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder, + ArrayRef offsets, + ArrayRef sizes) { + IntegerAttr oneAttr = builder.getI64IntegerAttr(1); + IntegerAttr zeroAttr = builder.getI64IntegerAttr(0); + ShapedType inputType = getInputOperandType(); + ArrayRef inputShape = inputType.getShape(); + int64_t inputH = inputShape[getInputHDim()]; + int64_t inputW = inputShape[getInputWDim()]; + int64_t m = getM(); + int64_t r = getR(); + + Location loc = getLoc(); + MLIRContext *context = builder.getContext(); + auto offsetAffineMap = + AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context); + Value mappedOffsetH = affine::makeComposedAffineApply( + builder, loc, offsetAffineMap, offsets[getOutputTileHDim()]); + Value mappedOffsetW = affine::makeComposedAffineApply( + builder, loc, offsetAffineMap, offsets[getOutputTileWDim()]); + auto sizeAffineMap = AffineMap::get( + 1, 0, {builder.getAffineDimExpr(0) * m + (r - 1)}, context); + Value mappedSizeH = affine::makeComposedAffineApply( + builder, loc, sizeAffineMap, sizes[getOutputTileHDim()]); + Value mappedSizeW = affine::makeComposedAffineApply( + builder, loc, sizeAffineMap, sizes[getOutputTileWDim()]); + + SmallVector tiledOperands; + SmallVector sliceOffsets, sliceSizes; + + OpFoldResult offsetH = + inputH != 1 ? OpFoldResult(mappedOffsetH) : OpFoldResult(zeroAttr); + OpFoldResult offsetW = + inputW != 1 ? OpFoldResult(mappedOffsetW) : OpFoldResult(zeroAttr); + sliceOffsets.append( + {offsets[getOutputNDim()], offsetH, offsetW, offsets[getOutputCDim()]}); + OpFoldResult sizeH = + inputH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr); + OpFoldResult sizeW = + inputW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr); + sliceSizes.append( + {sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]}); + int64_t inputRank = getInputOperandRank(); + SmallVector inputStrides(inputRank, oneAttr); + tiledOperands.emplace_back(builder.create( + loc, getInput(), sliceOffsets, sliceSizes, inputStrides)); + + SmallVector resultOffsets, resultSizes; + if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets, + resultSizes))) + return failure(); + + int64_t outputRank = getOutputOperandRank(); + SmallVector outputStrides(outputRank, oneAttr); + tiledOperands.emplace_back(builder.create( + loc, getOutput(), resultOffsets, resultSizes, outputStrides)); + + SmallVector resultTypes; + resultTypes.push_back(tiledOperands[1].getType()); + Operation *tiledOp = + mlir::clone(builder, getOperation(), resultTypes, tiledOperands); + + return TilingResult{{tiledOp}, SmallVector(tiledOp->getResults())}; +} + //===----------------------------------------------------------------------===// // WinogradOutputTransformOp //===----------------------------------------------------------------------===// @@ -2929,32 +3152,34 @@ LogicalResult WinogradInputTransformOp::verify() { LogicalResult WinogradOutputTransformOp::verify() { auto valueType = cast(getValue().getType()); ArrayRef valueShape = valueType.getShape(); - int64_t valueH = valueShape[0]; - int64_t valueW = valueShape[1]; - int64_t valueTileH = valueShape[2]; - int64_t valueTileW = valueShape[3]; + int64_t valueH = valueShape[getValueAlphaHDim()]; + int64_t valueW = valueShape[getValueAlphaWDim()]; + int64_t valueTileH = valueShape[getValueTileHDim()]; + int64_t valueTileW = valueShape[getValueTileWDim()]; int m = getM(); int r = getR(); bool leftTransform = valueH != 1; bool rightTransform = valueW != 1; - SmallVector expectedOutputShape(4, valueH); + int64_t outputRank = getOutputOperandRank(); + SmallVector expectedOutputShape(outputRank, valueH); if (ShapedType::isDynamic(valueH) || ShapedType::isDynamic(valueTileH)) { - expectedOutputShape[1] = ShapedType::kDynamic; + expectedOutputShape[getOutputHDim()] = ShapedType::kDynamic; } else { if (valueH != (leftTransform ? m + r - 1 : 1)) return emitOpError("expect input height equals to input tile size"); - expectedOutputShape[1] = (leftTransform ? m : 1) * valueTileH; + expectedOutputShape[getOutputHDim()] = (leftTransform ? m : 1) * valueTileH; } if (ShapedType::isDynamic(valueW) || ShapedType::isDynamic(valueTileW)) { - expectedOutputShape[2] = ShapedType::kDynamic; + expectedOutputShape[getOutputWDim()] = ShapedType::kDynamic; } else { if (valueW != (rightTransform ? m + r - 1 : 1)) return emitOpError("expect input width equals to input tile size"); - expectedOutputShape[2] = (rightTransform ? m : 1) * valueTileW; + expectedOutputShape[getOutputWDim()] = + (rightTransform ? m : 1) * valueTileW; } - expectedOutputShape[0] = valueShape[4]; - expectedOutputShape[3] = valueShape[5]; + expectedOutputShape[getOutputNDim()] = valueShape[getValueNDim()]; + expectedOutputShape[getOutputFDim()] = valueShape[getValueFDim()]; auto outputType = cast(getOutput().getType()); ArrayRef outputShape = outputType.getShape(); @@ -2964,6 +3189,124 @@ LogicalResult WinogradOutputTransformOp::verify() { return success(); } +SmallVector +WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) { + Location loc = getLoc(); + IntegerAttr zeroAttr = builder.getIndexAttr(0); + IntegerAttr oneAttr = builder.getIndexAttr(1); + Value value = getValue(); + int64_t valueRank = getValueOperandRank(); + SmallVector loopBounds(valueRank); + for (unsigned dim = 0; dim < valueRank; ++dim) { + loopBounds[dim].offset = zeroAttr; + // alphaH, alphaW, tileH, tileW, N, F + loopBounds[dim].size = getDimValue(builder, loc, value, dim); + loopBounds[dim].stride = oneAttr; + } + return loopBounds; +} + +SmallVector +WinogradOutputTransformOp::getLoopIteratorTypes() { + int64_t valueRank = getValueOperandRank(); + SmallVector iteratorTypes(valueRank, + utils::IteratorType::parallel); + return iteratorTypes; +} + +LogicalResult WinogradOutputTransformOp::getResultTilePosition( + OpBuilder &builder, unsigned resultNumber, ArrayRef offsets, + ArrayRef sizes, SmallVector &resultOffsets, + SmallVector &resultSizes) { + int64_t m = getM(); + + Location loc = getLoc(); + MLIRContext *context = builder.getContext(); + auto affineMap = + AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context); + + Value mappedOffsetH = affine::makeComposedAffineApply( + builder, loc, affineMap, offsets[getValueTileHDim()]); + Value mappedOffsetW = affine::makeComposedAffineApply( + builder, loc, affineMap, offsets[getValueTileWDim()]); + Value mappedSizeH = affine::makeComposedAffineApply( + builder, loc, affineMap, sizes[getValueTileHDim()]); + Value mappedSizeW = affine::makeComposedAffineApply( + builder, loc, affineMap, sizes[getValueTileWDim()]); + + ShapedType valueType = getValueOperandType(); + ArrayRef valueShape = valueType.getShape(); + int64_t valueH = valueShape[0]; + int64_t valueW = valueShape[1]; + IntegerAttr oneAttr = builder.getI64IntegerAttr(1); + IntegerAttr zeroAttr = builder.getI64IntegerAttr(0); + OpFoldResult offsetH = + valueH != 1 ? OpFoldResult(mappedOffsetH) : OpFoldResult(zeroAttr); + OpFoldResult offsetW = + valueW != 1 ? OpFoldResult(mappedOffsetW) : OpFoldResult(zeroAttr); + OpFoldResult sizeH = + valueH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr); + OpFoldResult sizeW = + valueW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr); + + resultOffsets.append( + {offsets[getValueNDim()], offsetH, offsetW, offsets[getValueFDim()]}); + resultSizes.append( + {sizes[getValueNDim()], sizeH, sizeW, sizes[getValueFDim()]}); + return success(); +} + +/// Implement tiling for winograd_output_transform +/// The input of winograd_output_transform is (alphaH, alphaW, tileH, tileW, N, +/// F). The output of winograd_output_transform is (N, H, W, F) Users can +/// specify the tile sizes of tileH, tileW, N, and F. `offsets` are the values +/// for the offsets of tileH, tileW, N, F for one tile. `sizes` are the values +/// for the sizes of tileH, tileW, N, F for one tile. +FailureOr WinogradOutputTransformOp::getTiledImplementation( + OpBuilder &builder, ArrayRef offsets, + ArrayRef sizes) { + IntegerAttr oneAttr = builder.getI64IntegerAttr(1); + IntegerAttr zeroAttr = builder.getI64IntegerAttr(0); + Location loc = getLoc(); + SmallVector tiledOperands; + SmallVector sliceOffsets, sliceSizes; + + ShapedType valueType = getValueOperandType(); + ArrayRef valueShape = valueType.getShape(); + int64_t alphaH = valueShape[getValueAlphaHDim()]; + int64_t alphaW = valueShape[getValueAlphaWDim()]; + IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH); + IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW); + + sliceOffsets.append({zeroAttr, zeroAttr, offsets[getValueTileHDim()], + offsets[getValueTileWDim()], offsets[getValueNDim()], + offsets[getValueFDim()]}); + sliceSizes.append({alphaHAttr, alphaWAttr, sizes[getValueTileHDim()], + sizes[getValueTileWDim()], sizes[getValueNDim()], + sizes[getValueFDim()]}); + int64_t valueRank = getValueOperandRank(); + SmallVector sliceStrides(valueRank, oneAttr); + tiledOperands.emplace_back(builder.create( + loc, getValue(), sliceOffsets, sliceSizes, sliceStrides)); + + SmallVector resultOffsets, resultSizes; + if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets, + resultSizes))) + return failure(); + + int64_t outputRank = getOutputOperandRank(); + SmallVector strides(outputRank, oneAttr); + tiledOperands.emplace_back(builder.create( + loc, getOutput(), resultOffsets, resultSizes, strides)); + + SmallVector resultTypes; + resultTypes.push_back(tiledOperands[1].getType()); + Operation *tiledOp = + mlir::clone(builder, getOperation(), resultTypes, tiledOperands); + + return TilingResult{{tiledOp}, SmallVector(tiledOp->getResults())}; +} + //===----------------------------------------------------------------------===// // LinalgDialect //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index 48b3abbeee7010..fbf4e29024f7c2 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -3851,6 +3851,47 @@ DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne( return DiagnosedSilenceableFailure::success(); } +DiagnosedSilenceableFailure transform::DecomposeWinogradOp::applyToOne( + transform::TransformRewriter &rewriter, Operation *target, + transform::ApplyToEachResultList &results, + transform::TransformState &state) { + rewriter.setInsertionPoint(target); + FailureOr maybeTransformed = failure(); + bool supported = + TypeSwitch(target) + .Case([&](linalg::WinogradFilterTransformOp op) { + maybeTransformed = decomposeWinogradFilterTransformOp(rewriter, op); + return true; + }) + .Case([&](linalg::WinogradInputTransformOp op) { + maybeTransformed = decomposeWinogradInputTransformOp(rewriter, op); + return true; + }) + .Case([&](linalg::WinogradOutputTransformOp op) { + maybeTransformed = decomposeWinogradOutputTransformOp(rewriter, op); + return true; + }) + .Default([&](Operation *op) { return false; }); + + if (!supported) { + DiagnosedSilenceableFailure diag = + emitSilenceableError() + << "this operation is not supported to decompose into other operations"; + diag.attachNote(target->getLoc()) << "target op"; + return diag; + } + + if (supported && failed(maybeTransformed)) { + DiagnosedSilenceableFailure diag = + emitSilenceableError() << "decompose Winograd operations failed"; + diag.attachNote(target->getLoc()) << "target op"; + return diag; + } + + results.push_back(*maybeTransformed); + return DiagnosedSilenceableFailure::success(); +} + #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc" #define GET_OP_CLASSES diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp index c6c770e2781ff0..b65b18699a15aa 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp @@ -490,8 +490,6 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input, Type elementType = inputType.getElementType(); auto inputShape = inputType.getShape(); // N, H, W, C int64_t inputN = inputShape[0]; - int64_t inputH = inputShape[1]; - int64_t inputW = inputShape[2]; int64_t inputC = inputShape[3]; auto valueType = cast(retValue.getType()); auto valueShape = valueType.getShape(); // alphaH, alphaW, HTile, WTile, N, C @@ -500,11 +498,6 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input, int64_t alphaH = leftTransform ? m + r - 1 : 1; int64_t alphaW = rightTransform ? m + r - 1 : 1; - if ((inputH != (tileH * m) + (r - 1)) && inputH != 1) - return Value(); - if ((inputW != (tileW * m) + (r - 1)) && inputW != 1) - return Value(); - auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs, ValueRange args) -> scf::ValueVector { Value tileHIter = ivs[0]; @@ -1169,6 +1162,24 @@ FailureOr winogradConv2D(RewriterBase &rewriter, return winogradConv2DHelper(rewriter, op, m, r); } +FailureOr +decomposeWinogradFilterTransformOp(RewriterBase &rewriter, + linalg::WinogradFilterTransformOp op) { + return decomposeWinogradFilterTransformHelper(rewriter, op); +} + +FailureOr +decomposeWinogradInputTransformOp(RewriterBase &rewriter, + linalg::WinogradInputTransformOp op) { + return decomposeWinogradInputTransformHelper(rewriter, op); +} + +FailureOr +decomposeWinogradOutputTransformOp(RewriterBase &rewriter, + linalg::WinogradOutputTransformOp op) { + return decomposeWinogradOutputTransformHelper(rewriter, op); +} + void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m, int64_t r) { MLIRContext *context = patterns.getContext(); diff --git a/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir new file mode 100644 index 00000000000000..6bb3fb1423edc6 --- /dev/null +++ b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir @@ -0,0 +1,292 @@ +// RUN: mlir-opt %s -transform-interpreter -canonicalize --split-input-file | FileCheck %s + +func.func @conv2d(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> { + %0 = tensor.empty() : tensor<6x6x5x2xf32> + %1 = linalg.winograd_filter_transform m(4) r(3) ins(%arg1 : tensor<2x3x3x5xf32>) outs(%0 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32> + %2 = tensor.empty() : tensor<6x6x2x2x2x5xf32> + %3 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x10x10x5xf32>) outs(%2 : tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32> + %collapsed = tensor.collapse_shape %1 [[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32> + %collapsed_0 = tensor.collapse_shape %3 [[0, 1], [2, 3, 4], [5]] : tensor<6x6x2x2x2x5xf32> into tensor<36x8x5xf32> + %4 = tensor.empty() : tensor<36x8x2xf32> + %5 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<36x8x5xf32>, tensor<36x5x2xf32>) outs(%4 : tensor<36x8x2xf32>) -> tensor<36x8x2xf32> + %expanded = tensor.expand_shape %5 [[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 2, 2, 2, 2] : tensor<36x8x2xf32> into tensor<6x6x2x2x2x2xf32> + %6 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<6x6x2x2x2x2xf32>) outs(%arg2 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> + return %6 : tensor<2x8x8x2xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.winograd_filter_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %2 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %3, %loop3:2 = transform.structured.tile_using_for %2 tile_sizes [0, 0, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + %4 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %5, %loop5:2 = transform.structured.tile_using_for %4 tile_sizes [0, 0, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + %7 = transform.structured.decompose_winograd_op %0 : (!transform.any_op) -> (!transform.any_op) + %8 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %3 : (!transform.any_op) -> !transform.any_op + %9 = transform.structured.decompose_winograd_op %8 : (!transform.any_op) -> (!transform.any_op) + %10 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %5 : (!transform.any_op) -> !transform.any_op + %11 = transform.structured.decompose_winograd_op %10 : (!transform.any_op) -> (!transform.any_op) + transform.yield + } +} + +// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0 * 4)> +// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> ()> +// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-LABEL: func.func @conv2d +// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> { +// CHECK: %[[CST:.*]] = arith.constant 1.024000e+03 : f32 +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[C5:.*]] = arith.constant 5 : index +// CHECK: %[[C2:.*]] = arith.constant 2 : index +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[S0:.*]] = tensor.empty() +// CHECK: %[[S1:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S0]]) +// CHECK: %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) +// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 3, 3, 1] [1, 1, 1, 1] +// CHECK: %[[S11:.*]] = linalg.matmul +// CHECK: %[[S13:.*]] = linalg.matmul +// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S13]] into %[[ARG6]][0, 0, %[[ARG5]], %[[ARG3]]] [6, 6, 1, 1] [1, 1, 1, 1] +// CHECK: scf.yield %[[INSERTED_SLICE]] +// CHECK: scf.yield %[[S9]] +// CHECK: %[[S2:.*]] = tensor.empty() : tensor<6x6x2x2x2x5xf32> +// CHECK: %[[S3:.*]] = tensor.empty() : tensor<6x6x2x2x2x5xf32> +// CHECK: %[[S4:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S3]]) +// CHECK: %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) +// CHECK: %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG3]]) +// CHECK: %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG5]]) +// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, %[[S10]], %[[S11]], 0] [2, 6, 6, 5] [1, 1, 1, 1] +// CHECK: %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[S2]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1] +// CHECK: %[[S12:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_7]]) +// CHECK: %[[S13:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]]) +// CHECK: %[[EXTRACTED_SLICE_8:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 6, 6, 1] [1, 1, 1, 1] +// CHECK: %[[S15:.*]] = linalg.matmul +// CHECK: %[[S17:.*]] = linalg.matmul +// CHECK: %[[INSERTED_SLICE_9:.*]] = tensor.insert_slice %[[S17]] into %[[ARG10]][0, 0, 0, 0, %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] +// CHECK: scf.yield %[[INSERTED_SLICE_9]] +// CHECK: scf.yield %[[S13]] +// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[ARG6]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1] +// CHECK: scf.yield %[[INSERTED_SLICE]] +// CHECK: scf.yield %[[S9]] +// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]] +// CHECK: %[[COLLAPSED_6:.*]] = tensor.collapse_shape %[[S4]] {{\[}}[0, 1], [2, 3, 4], [5]] +// CHECK: %[[S6:.*]] = linalg.batch_matmul +// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 2, 2, 2, 2] +// CHECK: %[[S7:.*]] = tensor.empty() : tensor<2x8x8x2xf32> +// CHECK: %[[S8:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S7]]) +// CHECK: %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) +// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 2] [1, 1, 1, 1, 1, 1] +// CHECK: %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG3]]) +// CHECK: %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG5]]) +// CHECK: %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[ARG2]][0, %[[S10]], %[[S11]], 0] [2, 4, 4, 2] [1, 1, 1, 1] +// CHECK: %[[S12:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_7]]) +// CHECK: %[[S15:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]]) +// CHECK: %[[EXTRACTED_SLICE_8:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0, %[[ARG7]], %[[ARG9]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] +// CHECK: %[[S17:.*]] = linalg.matmul +// CHECK: %[[S19:.*]] = linalg.matmul +// CHECK: %[[S20:.*]] = tensor.empty() +// CHECK: %[[S21:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]] : f32) outs(%[[S20]] : tensor<4x4xf32>) { +// CHECK: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32): +// CHECK: linalg.yield %[[IN]] : f32 +// CHECK: } -> tensor<4x4xf32> +// CHECK: %[[S22:.*]] = linalg.mul ins(%[[S21]], %[[S19]] : tensor<4x4xf32>, tensor<4x4xf32>) outs(%[[S20]] : tensor<4x4xf32>) -> tensor<4x4xf32> +// CHECK: %[[INSERTED_SLICE_9:.*]] = tensor.insert_slice %[[S22]] into %[[ARG10]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 4, 4, 1] [1, 1, 1, 1] +// CHECK: scf.yield %[[INSERTED_SLICE_9]] +// CHECK: scf.yield %[[S15]] +// CHECK: %[[S13:.*]] = affine.apply #[[$MAP0]](%[[ARG3]]) +// CHECK: %[[S14:.*]] = affine.apply #[[$MAP0]](%[[ARG5]]) +// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[ARG6]][0, %[[S13]], %[[S14]], 0] [2, 4, 4, 2] [1, 1, 1, 1] +// CHECK: scf.yield %[[INSERTED_SLICE]] +// CHECK: scf.yield %[[S9]] + +// ----- + +func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> { + %cst = arith.constant 0.000000e+00 : f32 + %0 = tensor.empty() : tensor<6x6x5x2xf32> + %1 = linalg.winograd_filter_transform m(4) r(3) ins(%arg1 : tensor<2x3x3x5xf32>) outs(%0 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32> + %padded = tensor.pad %arg0 low[0, 0, 0, 0] high[0, 3, 3, 0] { + ^bb0(%arg4: index, %arg5: index, %arg6: index, %arg7: index): + tensor.yield %cst : f32 + } : tensor<2x11x11x5xf32> to tensor<2x14x14x5xf32> + %2 = tensor.empty() : tensor<6x6x3x3x2x5xf32> + %3 = linalg.winograd_input_transform m(4) r(3) ins(%padded : tensor<2x14x14x5xf32>) outs(%2 : tensor<6x6x3x3x2x5xf32>) -> tensor<6x6x3x3x2x5xf32> + %collapsed = tensor.collapse_shape %1 [[0, 1], [2], [3]] : tensor<6x6x5x2xf32> into tensor<36x5x2xf32> + %collapsed_0 = tensor.collapse_shape %3 [[0, 1], [2, 3, 4], [5]] : tensor<6x6x3x3x2x5xf32> into tensor<36x18x5xf32> + %4 = tensor.empty() : tensor<36x18x2xf32> + %5 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<36x18x5xf32>, tensor<36x5x2xf32>) outs(%4 : tensor<36x18x2xf32>) -> tensor<36x18x2xf32> + %expanded = tensor.expand_shape %5 [[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 3, 3, 2, 2] : tensor<36x18x2xf32> into tensor<6x6x3x3x2x2xf32> + %padded_1 = tensor.pad %arg2 low[0, 0, 0, 0] high[0, 3, 3, 0] { + ^bb0(%arg4: index, %arg5: index, %arg6: index, %arg7: index): + tensor.yield %cst : f32 + } : tensor<2x9x9x2xf32> to tensor<2x12x12x2xf32> + %6 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<6x6x3x3x2x2xf32>) outs(%padded_1 : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32> + %extracted_slice = tensor.extract_slice %6[0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32> + return %extracted_slice : tensor<2x9x9x2xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.winograd_filter_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %2 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %3, %loop3:2 = transform.structured.tile_using_for %2 tile_sizes [0, 0, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + %4 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %5, %loop5:2 = transform.structured.tile_using_for %4 tile_sizes [0, 0, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + %7 = transform.structured.decompose_winograd_op %0 : (!transform.any_op) -> (!transform.any_op) + %8 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %3 : (!transform.any_op) -> !transform.any_op + %9 = transform.structured.decompose_winograd_op %8 : (!transform.any_op) -> (!transform.any_op) + %10 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %5 : (!transform.any_op) -> !transform.any_op + %11 = transform.structured.decompose_winograd_op %10 : (!transform.any_op) -> (!transform.any_op) + transform.yield + } +} + +// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0 * 4)> +// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> ()> +// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-LABEL: func.func @conv2d_unaligned +// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x11x11x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> { +// CHECK: %[[CST:.*]] = arith.constant 1.024000e+03 : f32 +// CHECK: %[[C3:.*]] = arith.constant 3 : index +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[C5:.*]] = arith.constant 5 : index +// CHECK: %[[C2:.*]] = arith.constant 2 : index +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[S0:.*]] = tensor.empty() +// CHECK: %[[S1:.*]] = scf.for %[[ARG4:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S0]]) +// CHECK: %[[S9:.*]] = scf.for %[[ARG6:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]]) +// CHECK: %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG4]], 0, 0, %[[ARG6]]] [1, 3, 3, 1] [1, 1, 1, 1] +// CHECK: %[[S11:.*]] = linalg.matmul +// CHECK: %[[S13:.*]] = linalg.matmul +// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S13]] into %[[ARG7]][0, 0, %[[ARG6]], %[[ARG4]]] [6, 6, 1, 1] [1, 1, 1, 1] +// CHECK: scf.yield %[[INSERTED_SLICE]] : tensor<6x6x5x2xf32> +// CHECK: scf.yield %[[S9]] : tensor<6x6x5x2xf32> +// CHECK: %[[PADDED:.*]] = tensor.pad %[[ARG0]] low[0, 0, 0, 0] high[0, 3, 3, 0] +// CHECK: %[[S2:.*]] = tensor.empty() : tensor<6x6x3x3x2x5xf32> +// CHECK: %[[S3:.*]] = tensor.empty() : tensor<6x6x3x3x2x5xf32> +// CHECK: %[[S4:.*]] = scf.for %[[ARG4:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S3]]) +// CHECK: %[[S9:.*]] = scf.for %[[ARG6:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]]) +// CHECK: %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG4]]) +// CHECK: %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG6]]) +// CHECK: %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[PADDED]][0, %[[S10]], %[[S11]], 0] [2, 6, 6, 5] [1, 1, 1, 1] +// CHECK: %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[S2]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1] +// CHECK: %[[S12:.*]] = scf.for %[[ARG8:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG9:.*]] = %[[EXTRACTED_SLICE_10]]) +// CHECK: %[[S13:.*]] = scf.for %[[ARG10:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG11:.*]] = %[[ARG9]]) +// CHECK: %[[EXTRACTED_SLICE_11:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE_9]][%[[ARG8]], 0, 0, %[[ARG10]]] [1, 6, 6, 1] [1, 1, 1, 1] +// CHECK: %[[S15:.*]] = linalg.matmul +// CHECK: %[[S17:.*]] = linalg.matmul +// CHECK: %[[INSERTED_SLICE_12:.*]] = tensor.insert_slice %[[S17]] into %[[ARG11]][0, 0, 0, 0, %[[ARG8]], %[[ARG10]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] +// CHECK: scf.yield %[[INSERTED_SLICE_12]] : tensor<6x6x1x1x2x5xf32> +// CHECK: scf.yield %[[S13]] : tensor<6x6x1x1x2x5xf32> +// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[ARG7]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1] +// CHECK: scf.yield %[[INSERTED_SLICE]] +// CHECK: scf.yield %[[S9]] +// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]] +// CHECK: %[[COLLAPSED_7:.*]] = tensor.collapse_shape %[[S4]] {{\[}}[0, 1], [2, 3, 4], [5]] +// CHECK: %[[S6:.*]] = linalg.batch_matmul +// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 6, 3, 3, 2, 2] +// CHECK: %[[PADDED_8:.*]] = tensor.pad %[[ARG2]] low[0, 0, 0, 0] high[0, 3, 3, 0] +// CHECK: %[[S7:.*]] = tensor.empty() : tensor<2x12x12x2xf32> +// CHECK: %[[S8:.*]] = scf.for %[[ARG4:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S7]]) +// CHECK: %[[S9:.*]] = scf.for %[[ARG6:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]]) +// CHECK: %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 2] [1, 1, 1, 1, 1, 1] +// CHECK: %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG4]]) +// CHECK: %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG6]]) +// CHECK: %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[PADDED_8]][0, %[[S10]], %[[S11]], 0] [2, 4, 4, 2] [1, 1, 1, 1] +// CHECK: %[[S12:.*]] = scf.for %[[ARG8:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG9:.*]] = %[[EXTRACTED_SLICE_10]]) +// CHECK: %[[S15:.*]] = scf.for %[[ARG10:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG11:.*]] = %[[ARG9]]) +// CHECK: %[[EXTRACTED_SLICE_11:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE_9]][0, 0, 0, 0, %[[ARG8]], %[[ARG10]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] +// CHECK: %[[S17:.*]] = linalg.matmul +// CHECK: %[[S19:.*]] = linalg.matmul +// CHECK: %[[S20:.*]] = tensor.empty() : tensor<4x4xf32> +// CHECK: %[[S21:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%[[CST]] : f32) outs(%[[S20]] : tensor<4x4xf32>) { +// CHECK: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32): +// CHECK: linalg.yield %[[IN]] : f32 +// CHECK: } -> tensor<4x4xf32> +// CHECK: %[[S22:.*]] = linalg.mul ins(%[[S21]], %[[S19]] : tensor<4x4xf32>, tensor<4x4xf32>) outs(%[[S20]] : tensor<4x4xf32>) -> tensor<4x4xf32> +// CHECK: %[[INSERTED_SLICE_12:.*]] = tensor.insert_slice %[[S22]] into %[[ARG11]][%[[ARG8]], 0, 0, %[[ARG10]]] [1, 4, 4, 1] [1, 1, 1, 1] +// CHECK: scf.yield %[[INSERTED_SLICE_12]] +// CHECK: scf.yield %[[S15]] : tensor<2x4x4x2xf32> +// CHECK: %[[S13:.*]] = affine.apply #[[$MAP0]](%[[ARG4]]) +// CHECK: %[[S14:.*]] = affine.apply #[[$MAP0]](%[[ARG6]]) +// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[ARG7]][0, %[[S13]], %[[S14]], 0] [2, 4, 4, 2] [1, 1, 1, 1] +// CHECK: scf.yield %[[INSERTED_SLICE]] +// CHECK: scf.yield %[[S9]] +// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S8]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] +// CHECK: return %[[EXTRACTED_SLICE]] + +// ----- + +func.func @conv2d_mx1_rx1(%arg0: tensor<2x6x1x5xf32>, %arg1: tensor<2x3x1x5xf32>, %arg2: tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32> { + %0 = tensor.empty() : tensor<6x1x5x2xf32> + %1 = linalg.winograd_filter_transform m(4) r(3) ins(%arg1 : tensor<2x3x1x5xf32>) outs(%0 : tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32> + %2 = tensor.empty() : tensor<6x1x1x1x2x5xf32> + %3 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x6x1x5xf32>) outs(%2 : tensor<6x1x1x1x2x5xf32>) -> tensor<6x1x1x1x2x5xf32> + %collapsed = tensor.collapse_shape %1 [[0, 1], [2], [3]] : tensor<6x1x5x2xf32> into tensor<6x5x2xf32> + %collapsed_0 = tensor.collapse_shape %3 [[0, 1], [2, 3, 4], [5]] : tensor<6x1x1x1x2x5xf32> into tensor<6x2x5xf32> + %4 = tensor.empty() : tensor<6x2x2xf32> + %5 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<6x2x5xf32>, tensor<6x5x2xf32>) outs(%4 : tensor<6x2x2xf32>) -> tensor<6x2x2xf32> + %expanded = tensor.expand_shape %5 [[0, 1], [2, 3, 4], [5]] output_shape [6, 1, 1, 1, 2, 2] : tensor<6x2x2xf32> into tensor<6x1x1x1x2x2xf32> + %6 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<6x1x1x1x2x2xf32>) outs(%arg2 : tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32> + return %6 : tensor<2x4x1x2xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.winograd_filter_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %2 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %3, %loop3:2 = transform.structured.tile_using_for %2 tile_sizes [0, 0, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + %4 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %5, %loop5:2 = transform.structured.tile_using_for %4 tile_sizes [0, 0, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + %7 = transform.structured.decompose_winograd_op %0 : (!transform.any_op) -> (!transform.any_op) + %8 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %3 : (!transform.any_op) -> !transform.any_op + %9 = transform.structured.decompose_winograd_op %8 : (!transform.any_op) -> (!transform.any_op) + %10 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %5 : (!transform.any_op) -> !transform.any_op + %11 = transform.structured.decompose_winograd_op %10 : (!transform.any_op) -> (!transform.any_op) + transform.yield + } +} + +// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1) -> ()> +// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-LABEL: func.func @conv2d_mx1_rx1 +// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x6x1x5xf32>, %[[ARG1:.*]]: tensor<2x3x1x5xf32>, %[[ARG2:.*]]: tensor<2x4x1x2xf32>) -> tensor<2x4x1x2xf32> { +// CHECK: %[[CST:.*]] = arith.constant 3.200000e+01 : f32 +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[C5:.*]] = arith.constant 5 : index +// CHECK: %[[C2:.*]] = arith.constant 2 : index +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[S0:.*]] = tensor.empty() : tensor<6x1x5x2xf32> +// CHECK: %[[S1:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S0]]) +// CHECK: %[[S7:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) +// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 3, 1, 1] [1, 1, 1, 1] +// CHECK: %[[S9:.*]] = linalg.matmul +// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S9]] into %[[ARG6]][0, 0, %[[ARG5]], %[[ARG3]]] [6, 1, 1, 1] [1, 1, 1, 1] +// CHECK: scf.yield %[[INSERTED_SLICE]] +// CHECK: scf.yield %[[S7]] +// CHECK: %[[S2:.*]] = tensor.empty() : tensor<6x1x1x1x2x5xf32> +// CHECK: %[[S3:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S2]]) +// CHECK: %[[S7:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) +// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 6, 1, 1] [1, 1, 1, 1] +// CHECK: %[[S9:.*]] = linalg.matmul +// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S9]] into %[[ARG6]][0, 0, 0, 0, %[[ARG3]], %[[ARG5]]] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] +// CHECK: scf.yield %[[INSERTED_SLICE]] +// CHECK: scf.yield %[[S7]] +// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]] +// CHECK: %[[COLLAPSED_3:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2, 3, 4], [5]] +// CHECK: %[[S5:.*]] = linalg.batch_matmul +// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[S5]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 1, 1, 1, 2, 2] +// CHECK: %[[S6:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[ARG2]]) +// CHECK: %[[S7:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]]) +// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, 0, 0, %[[ARG3]], %[[ARG5]]] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] +// CHECK: %[[S9:.*]] = linalg.matmul +// CHECK: %[[S10:.*]] = tensor.empty() : tensor<4x1xf32> +// CHECK: %[[S11:.*]] = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel"]} ins(%[[CST]] : f32) outs(%[[S10]] : tensor<4x1xf32>) { +// CHECK: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32): +// CHECK: linalg.yield %[[IN]] : f32 +// CHECK: } -> tensor<4x1xf32> +// CHECK: %[[S12:.*]] = linalg.mul ins(%[[S11]], %[[S9]] : tensor<4x1xf32>, tensor<4x1xf32>) outs(%[[S10]] : tensor<4x1xf32>) -> tensor<4x1xf32> +// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S12]] into %[[ARG6]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 4, 1, 1] [1, 1, 1, 1] +// CHECK: scf.yield %[[INSERTED_SLICE]] +// CHECK: scf.yield %[[S7]] +// CHECK: return %[[S6]] diff --git a/mlir/test/Dialect/Linalg/transform-tile-winograd.mlir b/mlir/test/Dialect/Linalg/transform-tile-winograd.mlir new file mode 100644 index 00000000000000..21522a2083b463 --- /dev/null +++ b/mlir/test/Dialect/Linalg/transform-tile-winograd.mlir @@ -0,0 +1,380 @@ +// RUN: mlir-opt %s -transform-interpreter --split-input-file | FileCheck %s + +func.func @tile_winograd_filter(%arg0: tensor<2x3x3x5xf32>, %arg1: tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32> { + %0 = linalg.winograd_filter_transform m(4) r(3) ins(%arg0 : tensor<2x3x3x5xf32>) outs(%arg1 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32> + return %0 : tensor<6x6x5x2xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.winograd_filter_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1, %loop1:2 = transform.structured.tile_using_for %0 tile_sizes [1, 0, 0, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + transform.yield + } +} + +// CHECK-LABEL: func.func @tile_winograd_filter( +// CHECK-SAME: %[[ARG0:.*]]: tensor<2x3x3x5xf32>, %[[ARG1:.*]]: tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32> { +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C0_0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[C1_1:.*]] = arith.constant 1 : index +// CHECK: %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C1]] +// CHECK: %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C5]] step %[[C1_1]] +// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG2]], 0, 0, %[[ARG4]]] [1, 3, 3, 1] [1, 1, 1, 1] : tensor<2x3x3x5xf32> to tensor<1x3x3x1xf32> +// CHECK: %[[EXTRACTED_SLICE_2:.*]] = tensor.extract_slice %[[ARG1]][0, 0, %[[ARG4]], %[[ARG2]]] [6, 6, 1, 1] [1, 1, 1, 1] : tensor<6x6x5x2xf32> to tensor<6x6x1x1xf32> +// CHECK: %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<1x3x3x1xf32>) outs(%[[EXTRACTED_SLICE_2]] : tensor<6x6x1x1xf32>) -> tensor<6x6x1x1xf32> + +// ----- + +func.func @tile_winograd_filter(%arg0: tensor<2x3x3x5xf32>, %arg1: tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32> { + %0 = linalg.winograd_filter_transform m(4) r(3) ins(%arg0 : tensor<2x3x3x5xf32>) outs(%arg1 : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32> + return %0 : tensor<6x6x5x2xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.winograd_filter_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1, %loop1:2 = transform.structured.tile_using_for %0 tile_sizes [1, 0, 0, 2] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + transform.yield + } +} + +// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (-d0 + 5, 2)> +// CHECK-LABEL: func.func @tile_winograd_filter( +// CHECK-SAME: %[[ARG0:.*]]: tensor<2x3x3x5xf32>, %[[ARG1:.*]]: tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32> { +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C0_0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[C2_1:.*]] = arith.constant 2 : index +// CHECK: %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C1]] +// CHECK: %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C5]] step %[[C2_1]] +// CHECK: %[[C5_2:.*]] = arith.constant 5 : index +// CHECK: %[[S3:.*]] = affine.min #[[$MAP0]](%[[ARG4]]) +// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG2]], 0, 0, %[[ARG4]]] [1, 3, 3, %[[S3]]] [1, 1, 1, 1] : tensor<2x3x3x5xf32> to tensor<1x3x3x?xf32> +// CHECK: %[[EXTRACTED_SLICE_3:.*]] = tensor.extract_slice %[[ARG1]][0, 0, %[[ARG4]], %[[ARG2]]] [6, 6, %[[S3]], 1] [1, 1, 1, 1] : tensor<6x6x5x2xf32> to tensor<6x6x?x1xf32> +// CHECK: %[[S4:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<1x3x3x?xf32>) outs(%[[EXTRACTED_SLICE_3]] : tensor<6x6x?x1xf32>) -> tensor<6x6x?x1xf32> + +// ----- + +func.func @tile_winograd_filter(%arg0: tensor<2x3x1x5xf32>, %arg1: tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32> { + %0 = linalg.winograd_filter_transform m(4) r(3) ins(%arg0 : tensor<2x3x1x5xf32>) outs(%arg1 : tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32> + return %0 : tensor<6x1x5x2xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.winograd_filter_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1, %loop1:2 = transform.structured.tile_using_for %0 tile_sizes [1, 0, 0, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + transform.yield + } +} + +// CHECK-LABEL: func.func @tile_winograd_filter( +// CHECK-SAME: %[[ARG0:.*]]: tensor<2x3x1x5xf32>, %[[ARG1:.*]]: tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32> { +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C0_0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[C1_1:.*]] = arith.constant 1 : index +// CHECK: %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C1]] +// CHECK: %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C5]] step %[[C1_1]] +// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG2]], 0, 0, %[[ARG4]]] [1, 3, 1, 1] [1, 1, 1, 1] : tensor<2x3x1x5xf32> to tensor<1x3x1x1xf32> +// CHECK: %[[EXTRACTED_SLICE_2:.*]] = tensor.extract_slice %[[ARG1]][0, 0, %[[ARG4]], %[[ARG2]]] [6, 1, 1, 1] [1, 1, 1, 1] : tensor<6x1x5x2xf32> to tensor<6x1x1x1xf32> +// CHECK: %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<1x3x1x1xf32>) outs(%[[EXTRACTED_SLICE_2]] : tensor<6x1x1x1xf32>) -> tensor<6x1x1x1xf32> + +// ----- + +func.func @tile_winograd_input(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32> { + %0 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x10x10x5xf32>) outs(%arg1 : tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32> + return %0 : tensor<6x6x2x2x2x5xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1, %loop3:2 = transform.structured.tile_using_for %0 tile_sizes [0, 0, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + transform.yield + } +} + +// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0 * 4)> +// CHECK: #[[$MAP1:.+]] = affine_map<() -> (6)> +// CHECK-LABEL: func.func @tile_winograd_input( +// CHECK-SAME: %[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32> { +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C0_0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[C2_1:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[C1_2:.*]] = arith.constant 1 : index +// CHECK: %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C1]] +// CHECK: %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C2_1]] step %[[C1_2]] +// CHECK: %[[S3:.*]] = affine.apply #[[$MAP0]](%[[ARG2]]) +// CHECK: %[[S4:.*]] = affine.apply #[[$MAP0]](%[[ARG4]]) +// CHECK: %[[S5:.*]] = affine.apply #[[$MAP1]]() +// CHECK: %[[S6:.*]] = affine.apply #[[$MAP1]]() +// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, %[[S3]], %[[S4]], 0] [2, %[[S5]], %[[S6]], 5] [1, 1, 1, 1] : tensor<2x10x10x5xf32> to tensor<2x?x?x5xf32> +// CHECK: %[[EXTRACTED_SLICE_5:.*]] = tensor.extract_slice %[[ARG1]][0, 0, %[[ARG2]], %[[ARG4]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1] : tensor<6x6x2x2x2x5xf32> to tensor<6x6x1x1x2x5xf32> +// CHECK: %[[S7:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<2x?x?x5xf32>) outs(%[[EXTRACTED_SLICE_5]] : tensor<6x6x1x1x2x5xf32>) -> tensor<6x6x1x1x2x5xf32> + +// ----- + +func.func @tile_winograd_input(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32> { + %0 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x10x10x5xf32>) outs(%arg1 : tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32> + return %0 : tensor<6x6x2x2x2x5xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1, %loop3:4 = transform.structured.tile_using_for %0 tile_sizes [0, 0, 1, 1, 1, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) + transform.yield + } +} + +// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0 * 4)> +// CHECK: #[[$MAP1:.+]] = affine_map<() -> (6)> +// CHECK-LABEL: func.func @tile_winograd_input( +// CHECK-SAME: %[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32> { +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C0_0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C0_3:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C0_6:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[C2_1:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[C2_4:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[C1_2:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[C1_5:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[C1_7:.*]] = arith.constant 1 : index +// CHECK: %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C1]] +// CHECK: %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C2_1]] step %[[C1_2]] +// CHECK: %[[S3:.*]] = scf.for %[[ARG6:.*]] = %[[C0_3]] to %[[C2_4]] step %[[C1_5]] +// CHECK: %[[S4:.*]] = scf.for %[[ARG8:.*]] = %[[C0_6]] to %[[C5]] step %[[C1_7]] +// CHECK: %[[S5:.*]] = affine.apply #[[$MAP0]](%[[ARG2]]) +// CHECK: %[[S6:.*]] = affine.apply #[[$MAP0]](%[[ARG4]]) +// CHECK: %[[S7:.*]] = affine.apply #[[$MAP1]]() +// CHECK: %[[S8:.*]] = affine.apply #[[$MAP1]]() +// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG6]], %[[S5]], %[[S6]], %[[ARG8]]] [1, %[[S7]], %[[S8]], 1] [1, 1, 1, 1] : tensor<2x10x10x5xf32> to tensor<1x?x?x1xf32> +// CHECK: %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[ARG1]][0, 0, %[[ARG2]], %[[ARG4]], %[[ARG6]], %[[ARG8]]] [6, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x2x2x2x5xf32> to tensor<6x6x1x1x1x1xf32> +// CHECK: %[[S9:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<1x?x?x1xf32>) outs(%[[EXTRACTED_SLICE_10]] : tensor<6x6x1x1x1x1xf32>) -> tensor<6x6x1x1x1x1xf32> + +// ----- + +func.func @tile_winograd_input(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32> { + %0 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x10x10x5xf32>) outs(%arg1 : tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32> + return %0 : tensor<6x6x2x2x2x5xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1, %loop3:4 = transform.structured.tile_using_for %0 tile_sizes [0, 0, 2, 2, 2, 2] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) + transform.yield + } +} + +// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (-d0 + 5, 2)> +// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> (d0 * 4)> +// CHECK: #[[$MAP2:.+]] = affine_map<() -> (10)> +// CHECK-LABEL: func.func @tile_winograd_input( +// CHECK-SAME: %[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32> { +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C0_1:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C0_4:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C0_7:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[C2_2:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[C2_5:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index +// CHECK-DAG: %[[C2_0:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[C2_3:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[C2_6:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[C2_8:.*]] = arith.constant 2 : index +// CHECK: %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C2_0]] +// CHECK: %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_1]] to %[[C2_2]] step %[[C2_3]] +// CHECK: %[[S3:.*]] = scf.for %[[ARG6:.*]] = %[[C0_4]] to %[[C2_5]] step %[[C2_6]] +// CHECK: %[[S4:.*]] = scf.for %[[ARG8:.*]] = %[[C0_7]] to %[[C5]] step %[[C2_8]] +// CHECK: %[[S5:.*]] = affine.min #[[$MAP0]](%[[ARG8]]) +// CHECK: %[[S6:.*]] = affine.apply #[[$MAP1]](%[[ARG2]]) +// CHECK: %[[S7:.*]] = affine.apply #[[$MAP1]](%[[ARG4]]) +// CHECK: %[[S8:.*]] = affine.apply #[[$MAP2]]() +// CHECK: %[[S9:.*]] = affine.apply #[[$MAP2]]() +// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG6]], %[[S6]], %[[S7]], %[[ARG8]]] [2, %[[S8]], %[[S9]], %[[S5]]] [1, 1, 1, 1] : tensor<2x10x10x5xf32> to tensor<2x?x?x?xf32> +// CHECK: %[[EXTRACTED_SLICE_12:.*]] = tensor.extract_slice %[[ARG1]][0, 0, %[[ARG2]], %[[ARG4]], %[[ARG6]], %[[ARG8]]] [6, 6, 2, 2, 2, %[[S5]]] [1, 1, 1, 1, 1, 1] : tensor<6x6x2x2x2x5xf32> to tensor<6x6x2x2x2x?xf32> +// CHECK: %[[S10:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<2x?x?x?xf32>) outs(%[[EXTRACTED_SLICE_12]] : tensor<6x6x2x2x2x?xf32>) -> tensor<6x6x2x2x2x?xf32> + +// ----- + +func.func @tile_winograd_input(%arg0: tensor<2x1x10x5xf32>, %arg1: tensor<1x6x1x2x2x5xf32>) -> tensor<1x6x1x2x2x5xf32> { + %0 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x1x10x5xf32>) outs(%arg1 : tensor<1x6x1x2x2x5xf32>) -> tensor<1x6x1x2x2x5xf32> + return %0 : tensor<1x6x1x2x2x5xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1, %loop3:4 = transform.structured.tile_using_for %0 tile_sizes [0, 0, 1, 1, 1, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) + transform.yield + } +} + +// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0 * 4)> +// CHECK: #[[$MAP1:.+]] = affine_map<() -> (6)> +// CHECK-LABEL: func.func @tile_winograd_input( +// CHECK-SAME: %[[ARG0:.*]]: tensor<2x1x10x5xf32>, %[[ARG1:.*]]: tensor<1x6x1x2x2x5xf32>) -> tensor<1x6x1x2x2x5xf32> { +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C0_1:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C0_3:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C0_6:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[C2_4:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index +// CHECK-DAG: %[[C1_0:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[C1_2:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[C1_5:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[C1_7:.*]] = arith.constant 1 : index +// CHECK: %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C1]] step %[[C1_0]] +// CHECK: %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_1]] to %[[C2]] step %[[C1_2]] +// CHECK: %[[S3:.*]] = scf.for %[[ARG6:.*]] = %[[C0_3]] to %[[C2_4]] step %[[C1_5]] +// CHECK: %[[S4:.*]] = scf.for %[[ARG8:.*]] = %[[C0_6]] to %[[C5]] step %[[C1_7]] +// CHECK: %[[S5:.*]] = affine.apply #[[$MAP0]](%[[ARG2]]) +// CHECK: %[[S6:.*]] = affine.apply #[[$MAP0]](%[[ARG4]]) +// CHECK: %[[S7:.*]] = affine.apply #[[$MAP1]]() +// CHECK: %[[S8:.*]] = affine.apply #[[$MAP1]]() +// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[ARG6]], 0, %[[S6]], %[[ARG8]]] [1, 1, %[[S8]], 1] [1, 1, 1, 1] : tensor<2x1x10x5xf32> to tensor<1x1x?x1xf32> +// CHECK: %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[ARG1]][0, 0, %[[ARG2]], %[[ARG4]], %[[ARG6]], %[[ARG8]]] [1, 6, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<1x6x1x2x2x5xf32> to tensor<1x6x1x1x1x1xf32> +// CHECK: %[[S9:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<1x1x?x1xf32>) outs(%[[EXTRACTED_SLICE_10]] : tensor<1x6x1x1x1x1xf32>) -> tensor<1x6x1x1x1x1xf32> + +// ----- + +func.func @tile_winograd_output(%arg0 : tensor<6x6x2x2x2x2xf32>, %arg1: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> { + %0 = linalg.winograd_output_transform m(4) r(3) ins(%arg0 : tensor<6x6x2x2x2x2xf32>) outs(%arg1 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> + return %0 : tensor<2x8x8x2xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1, %loop1:2 = transform.structured.tile_using_for %0 tile_sizes [0, 0, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + transform.yield + } +} + +// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0 * 4)> +// CHECK: #[[$MAP1:.+]] = affine_map<() -> (4)> +// CHECK-LABEL: func.func @tile_winograd_output( +// CHECK-SAME: %[[ARG0:.*]]: tensor<6x6x2x2x2x2xf32>, %[[ARG1:.*]]: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> { +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C0_0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[C2_1:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[C1_2:.*]] = arith.constant 1 : index +// CHECK: %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C1]] +// CHECK: %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C2_1]] step %[[C1_2]] +// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, 0, %[[ARG2]], %[[ARG4]], 0, 0] [6, 6, 1, 1, 2, 2] [1, 1, 1, 1, 1, 1] : tensor<6x6x2x2x2x2xf32> to tensor<6x6x1x1x2x2xf32> +// CHECK: %[[S3:.*]] = affine.apply #[[$MAP0]](%[[ARG2]]) +// CHECK: %[[S4:.*]] = affine.apply #[[$MAP0]](%[[ARG4]]) +// CHECK: %[[S5:.*]] = affine.apply #[[$MAP1]]() +// CHECK: %[[S6:.*]] = affine.apply #[[$MAP1]]() +// CHECK: %[[EXTRACTED_SLICE_5:.*]] = tensor.extract_slice %[[ARG1]][0, %[[S3]], %[[S4]], 0] [2, %[[S5]], %[[S6]], 2] [1, 1, 1, 1] : tensor<2x8x8x2xf32> to tensor<2x?x?x2xf32> + +// ----- + +func.func @tile_winograd_output(%arg0 : tensor<6x6x2x2x3x5xf32>, %arg1: tensor<3x8x8x5xf32>) -> tensor<3x8x8x5xf32> { + %0 = linalg.winograd_output_transform m(4) r(3) ins(%arg0 : tensor<6x6x2x2x3x5xf32>) outs(%arg1 : tensor<3x8x8x5xf32>) -> tensor<3x8x8x5xf32> + return %0 : tensor<3x8x8x5xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1, %loop1:4 = transform.structured.tile_using_for %0 tile_sizes [0, 0, 2, 2, 2, 2] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) + transform.yield + } +} + +// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (-d0 + 3, 2)> +// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> (-d0 + 5, 2)> +// CHECK: #[[$MAP2:.+]] = affine_map<(d0) -> (d0 * 4)> +// CHECK: #[[$MAP3:.+]] = affine_map<() -> (8)> +// CHECK-LABEL: func.func @tile_winograd_output( +// CHECK-SAME: %[[ARG0:.*]]: tensor<6x6x2x2x3x5xf32>, %[[ARG1:.*]]: tensor<3x8x8x5xf32>) -> tensor<3x8x8x5xf32> { +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C0_1:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C0_4:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C0_6:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[C2_2:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index +// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index +// CHECK-DAG: %[[C2_0:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[C2_3:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[C2_5:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[C2_7:.*]] = arith.constant 2 : index +// CHECK: %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C2_0]] +// CHECK: %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_1]] to %[[C2_2]] step %[[C2_3]] +// CHECK: %[[S3:.*]] = scf.for %[[ARG6:.*]] = %[[C0_4]] to %[[C3]] step %[[C2_5]] +// CHECK: %[[S4:.*]] = scf.for %[[ARG8:.*]] = %[[C0_6]] to %[[C5]] step %[[C2_7]] +// CHECK: %[[C3_8:.*]] = arith.constant 3 : index +// CHECK: %[[S5:.*]] = affine.min #[[$MAP0]](%[[ARG6]]) +// CHECK: %[[C5_9:.*]] = arith.constant 5 : index +// CHECK: %[[S6:.*]] = affine.min #[[$MAP1]](%[[ARG8]]) +// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, 0, %[[ARG2]], %[[ARG4]], %[[ARG6]], %[[ARG8]]] [6, 6, 2, 2, %[[S5]], %[[S6]]] [1, 1, 1, 1, 1, 1] : tensor<6x6x2x2x3x5xf32> to tensor<6x6x2x2x?x?xf32> +// CHECK: %[[S7:.*]] = affine.apply #[[$MAP2]](%[[ARG2]]) +// CHECK: %[[S8:.*]] = affine.apply #[[$MAP2]](%[[ARG4]]) +// CHECK: %[[S9:.*]] = affine.apply #[[$MAP3]]() +// CHECK: %[[S10:.*]] = affine.apply #[[$MAP3]]() +// CHECK: %[[EXTRACTED_SLICE_12:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG6]], %[[S7]], %[[S8]], %[[ARG8]]] [%[[S5]], %[[S9]], %[[S10]], %[[S6]]] [1, 1, 1, 1] : tensor<3x8x8x5xf32> to tensor + +// ----- + +func.func @tile_winograd_output(%arg0 : tensor<6x1x2x1x3x5xf32>, %arg1: tensor<3x8x1x5xf32>) -> tensor<3x8x1x5xf32> { + %0 = linalg.winograd_output_transform m(4) r(3) ins(%arg0 : tensor<6x1x2x1x3x5xf32>) outs(%arg1 : tensor<3x8x1x5xf32>) -> tensor<3x8x1x5xf32> + return %0 : tensor<3x8x1x5xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1, %loop1:4 = transform.structured.tile_using_for %0 tile_sizes [0, 0, 1, 1, 1, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) + transform.yield + } +} + +// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0 * 4)> +// CHECK: #[[$MAP1:.+]] = affine_map<() -> (4)> +// CHECK-LABEL: func.func @tile_winograd_output( +// CHECK-SAME: %[[ARG0:.*]]: tensor<6x1x2x1x3x5xf32>, %[[ARG1:.*]]: tensor<3x8x1x5xf32>) -> tensor<3x8x1x5xf32> { +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C0_0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C0_3:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C0_5:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[C1_1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index +// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[C1_2:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[C1_4:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[C1_6:.*]] = arith.constant 1 : index +// CHECK: %[[S1:.*]] = scf.for %[[ARG2:.*]] = %[[C0]] to %[[C2]] step %[[C1]] +// CHECK: %[[S2:.*]] = scf.for %[[ARG4:.*]] = %[[C0_0]] to %[[C1_1]] step %[[C1_2]] +// CHECK: %[[S3:.*]] = scf.for %[[ARG6:.*]] = %[[C0_3]] to %[[C3]] step %[[C1_4]] +// CHECK: %[[S4:.*]] = scf.for %[[ARG8:.*]] = %[[C0_5]] to %[[C5]] step %[[C1_6]] +// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, 0, %[[ARG2]], %[[ARG4]], %[[ARG6]], %[[ARG8]]] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x1x2x1x3x5xf32> to tensor<6x1x1x1x1x1xf32> +// CHECK: %[[S5:.*]] = affine.apply #[[$MAP0]](%[[ARG2]]) +// CHECK: %[[S6:.*]] = affine.apply #[[$MAP0]](%[[ARG4]]) +// CHECK: %[[S7:.*]] = affine.apply #[[$MAP1]]() +// CHECK: %[[S8:.*]] = affine.apply #[[$MAP1]]() +// CHECK: %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG6]], %[[S5]], 0, %[[ARG8]]] [1, %[[S7]], 1, 1] [1, 1, 1, 1] : tensor<3x8x1x5xf32> to tensor<1x?x1x1xf32> +// CHECK: %[[S9:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXTRACTED_SLICE]] : tensor<6x1x1x1x1x1xf32>) outs(%[[EXTRACTED_SLICE_9]] : tensor<1x?x1x1xf32>) -> tensor<1x?x1x1xf32> From d70a2f839e1649ef298201f2429e28491680a20b Mon Sep 17 00:00:00 2001 From: MaheshRavishankar <1663364+MaheshRavishankar@users.noreply.github.com> Date: Wed, 11 Sep 2024 22:15:43 -0700 Subject: [PATCH 34/46] [mlir][TilingInterface] Avoid looking at operands for getting slices to continue tile + fuse. (#107882) Current implementation of `scf::tileConsumerAndFuseProducerUsingSCF` looks at operands of tiled/tiled+fused operations to see if they are produced by `extract_slice` operations to populate the worklist used to continue fusion. This implicit assumption does not always work. Instead make the implementations of `getTiledImplementation` return the slices to use to continue fusion. This is a breaking change - To continue to get the same behavior of `scf::tileConsumerAndFuseProducerUsingSCF`, change all out-of-tree implementation of `TilingInterface::getTiledImplementation` to return the slices to continue fusion on. All in-tree implementations have been adapted to this. - This change touches parts that required a simplification to the `ControlFn` in `scf::SCFTileAndFuseOptions`. It now returns a `std::optional` object that should be `std::nullopt` if fusion is not to be performed. Signed-off-by: MaheshRavishankar --- .../include/mlir/Dialect/Linalg/Utils/Utils.h | 11 ++- .../SCF/Transforms/TileUsingInterface.h | 33 ++++--- .../include/mlir/Interfaces/TilingInterface.h | 7 +- mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 82 ++++++++++------ .../Linalg/Transforms/TilingInterfaceImpl.cpp | 26 +++++- mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 20 ++-- .../SCF/Transforms/TileUsingInterface.cpp | 93 +++++++++++-------- .../Tensor/IR/TensorTilingInterfaceImpl.cpp | 71 ++++++++------ .../tile-and-fuse-using-interface.mlir | 45 +++++++++ .../TestTilingInterfaceTransformOps.cpp | 12 ++- 10 files changed, 271 insertions(+), 129 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h index 65a1a8b42e1495..f1df49ce3eaa36 100644 --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -178,11 +178,12 @@ computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp, /// at offsets `lbs` and with sizes `subShapeSizes`. `omitPartialTileCheck` /// controls whether to omit the partial/boundary tile condition check in /// cases where we statically know that it is unnecessary. -Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile, - ArrayRef tileSizes, AffineMap map, - ArrayRef lbs, ArrayRef ubs, - ArrayRef subShapeSizes, - bool omitPartialTileCheck); +Operation *makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile, + ArrayRef tileSizes, AffineMap map, + ArrayRef lbs, + ArrayRef ubs, + ArrayRef subShapeSizes, + bool omitPartialTileCheck); /// Creates extract_slice/subview ops for all `valuesToTile` of the given /// `linalgOp` with `builder`, assuming `linalgOp` is being fused into a loop diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h index 1f21af6d6a29ac..77c812cde71533 100644 --- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h +++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h @@ -106,6 +106,9 @@ struct SCFTilingResult { /// Values to use as replacements for the untiled op. Is the same size as the /// number of results of the untiled op. SmallVector replacements; + /// Slices generated after tiling that can be used for fusing with the tiled + /// producer. + SmallVector generatedSlices; }; /// Method to tile an op that implements the `TilingInterface` using @@ -129,18 +132,22 @@ struct SCFTileAndFuseOptions { /// 2) the producer value that is to be fused /// 3) a boolean value set to `true` if the fusion is from /// a destination operand. - /// It retuns two booleans - /// - returns `true` if the fusion should be done through the candidate slice - /// - returns `true` if a replacement for the fused producer needs to be - /// yielded from within the tiled loop. Note that it is valid to return - /// `true` only if the slice fused is disjoint across all iterations of the - /// tiled loop. It is up to the caller to ensure that this is true for the - /// fused producers. - using ControlFnTy = std::function( + /// The control function returns an `std::optiona`. + /// If the return value is `std::nullopt`, that implies no fusion + /// is to be performed along that slice. + struct ControlFnResult { + /// Set to true if the loop nest has to return a replacement value + /// for the fused producer. + bool yieldProducerReplacement = false; + }; + using ControlFnTy = std::function( tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer, bool isDestinationOperand)>; - ControlFnTy fusionControlFn = [](tensor::ExtractSliceOp, OpResult, bool) { - return std::make_tuple(true, false); + /// The default control function implements greedy fusion without yielding + /// a replacement for any of the fused results. + ControlFnTy fusionControlFn = [](tensor::ExtractSliceOp, OpResult, + bool) -> std::optional { + return ControlFnResult{}; }; SCFTileAndFuseOptions &setFusionControlFn(ControlFnTy controlFn) { fusionControlFn = controlFn; @@ -156,6 +163,7 @@ struct SCFFuseProducerOfSliceResult { OpResult origProducer; // Original untiled producer. Value tiledAndFusedProducer; // Tile and fused producer value. SmallVector tiledOps; + SmallVector generatedSlices; }; std::optional tileAndFuseProducerOfSlice(RewriterBase &rewriter, @@ -215,7 +223,10 @@ tileAndFuseProducerOfSlice(RewriterBase &rewriter, /// /// The @param `yieldResultNumber` decides which result would be yield. If not /// given, yield all `opResult` of fused producer. -LogicalResult yieldReplacementForFusedProducer( +/// +/// The method returns the list of new slices added during the process (which +/// can be used to fuse along). +FailureOr> yieldReplacementForFusedProducer( RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp, scf::SCFFuseProducerOfSliceResult fusedProducerInfo, MutableArrayRef loops, diff --git a/mlir/include/mlir/Interfaces/TilingInterface.h b/mlir/include/mlir/Interfaces/TilingInterface.h index 2f51496d1b110a..b33aa1489c3116 100644 --- a/mlir/include/mlir/Interfaces/TilingInterface.h +++ b/mlir/include/mlir/Interfaces/TilingInterface.h @@ -25,12 +25,15 @@ namespace mlir { /// Container for result values of tiling. /// - `tiledOps` contains operations created by the tiling implementation that -/// are returned to the caller for further transformations. +/// are returned to the caller for further transformations. /// - `tiledValues` contains the tiled value corresponding to the result of the -/// untiled operation. +/// untiled operation. +/// - `generatedSlices` contains the list of slices that are generated during +/// tiling. These slices can be used for fusing producers. struct TilingResult { SmallVector tiledOps; SmallVector tiledValues; + SmallVector generatedSlices; }; /// Container for the result of merge operation of tiling. diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 775ed8f37344ed..d452ff72b68aa3 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -66,20 +66,20 @@ static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v, /// Returns a memref.subview or a tensor.extract_slice based on the type of the /// `source`. -static Value getSlice(OpBuilder &b, Location loc, Value source, - ArrayRef offsets, - ArrayRef sizes, - ArrayRef strides) { - return TypeSwitch(source.getType()) - .Case([&](RankedTensorType t) -> Value { +static Operation *getSlice(OpBuilder &b, Location loc, Value source, + ArrayRef offsets, + ArrayRef sizes, + ArrayRef strides) { + return TypeSwitch(source.getType()) + .Case([&](RankedTensorType t) -> Operation * { return b.create(loc, source, offsets, sizes, strides); }) - .Case([&](MemRefType type) -> Value { + .Case([&](MemRefType type) -> Operation * { return b.create(loc, source, offsets, sizes, strides); }) - .Default([&](Type t) { return nullptr; }); + .Default([&](Type t) -> Operation * { return nullptr; }); } //===----------------------------------------------------------------------===// @@ -2599,10 +2599,18 @@ SoftmaxOp::getTiledImplementation(OpBuilder &builder, auto oneAttr = builder.getI64IntegerAttr(1); SmallVector strides(rank, oneAttr); SmallVector tiledOperands; - tiledOperands.emplace_back( - getSlice(builder, getLoc(), getInput(), offsets, sizes, strides)); - tiledOperands.emplace_back( - getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides)); + Operation *inputSlice = + getSlice(builder, getLoc(), getInput(), offsets, sizes, strides); + if (!inputSlice) { + return emitOpError("failed to compute input slice"); + } + tiledOperands.emplace_back(inputSlice->getResult(0)); + Operation *outputSlice = + getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides); + if (!outputSlice) { + return emitOpError("failed to compute output slice"); + } + tiledOperands.emplace_back(outputSlice->getResult(0)); SmallVector resultTypes; if (hasPureTensorSemantics()) @@ -2610,7 +2618,10 @@ SoftmaxOp::getTiledImplementation(OpBuilder &builder, Operation *tiledOp = mlir::clone(builder, getOperation(), resultTypes, tiledOperands); - return TilingResult{{tiledOp}, SmallVector(tiledOp->getResults())}; + return TilingResult{ + {tiledOp}, + SmallVector(tiledOp->getResults()), + llvm::to_vector(ArrayRef{inputSlice, outputSlice})}; } LogicalResult SoftmaxOp::getResultTilePosition( @@ -2957,8 +2968,9 @@ FailureOr WinogradFilterTransformOp::getTiledImplementation( int64_t filterRank = getFilterOperandRank(); SmallVector filterStrides(filterRank, oneAttr); Location loc = getLoc(); - tiledOperands.emplace_back(builder.create( - loc, getFilter(), sliceOffsets, sliceSizes, filterStrides)); + auto filterSlice = builder.create( + loc, getFilter(), sliceOffsets, sliceSizes, filterStrides); + tiledOperands.emplace_back(filterSlice); SmallVector resultOffsets, resultSizes; if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets, @@ -2967,15 +2979,19 @@ FailureOr WinogradFilterTransformOp::getTiledImplementation( int64_t outputRank = getOutputOperandRank(); SmallVector outputStrides(outputRank, oneAttr); - tiledOperands.emplace_back(builder.create( - loc, getOutput(), resultOffsets, resultSizes, outputStrides)); + auto outputSlice = builder.create( + loc, getOutput(), resultOffsets, resultSizes, outputStrides); + tiledOperands.emplace_back(outputSlice); SmallVector resultTypes; resultTypes.push_back(tiledOperands[1].getType()); Operation *tiledOp = mlir::clone(builder, getOperation(), resultTypes, tiledOperands); - return TilingResult{{tiledOp}, SmallVector(tiledOp->getResults())}; + return TilingResult{ + {tiledOp}, + SmallVector(tiledOp->getResults()), + llvm::to_vector(ArrayRef{filterSlice, outputSlice})}; } //===----------------------------------------------------------------------===// @@ -3124,8 +3140,9 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder, {sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]}); int64_t inputRank = getInputOperandRank(); SmallVector inputStrides(inputRank, oneAttr); - tiledOperands.emplace_back(builder.create( - loc, getInput(), sliceOffsets, sliceSizes, inputStrides)); + auto inputSlice = builder.create( + loc, getInput(), sliceOffsets, sliceSizes, inputStrides); + tiledOperands.emplace_back(inputSlice); SmallVector resultOffsets, resultSizes; if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets, @@ -3134,15 +3151,19 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder, int64_t outputRank = getOutputOperandRank(); SmallVector outputStrides(outputRank, oneAttr); - tiledOperands.emplace_back(builder.create( - loc, getOutput(), resultOffsets, resultSizes, outputStrides)); + auto outputSlice = builder.create( + loc, getOutput(), resultOffsets, resultSizes, outputStrides); + tiledOperands.emplace_back(outputSlice); SmallVector resultTypes; resultTypes.push_back(tiledOperands[1].getType()); Operation *tiledOp = mlir::clone(builder, getOperation(), resultTypes, tiledOperands); - return TilingResult{{tiledOp}, SmallVector(tiledOp->getResults())}; + return TilingResult{ + {tiledOp}, + SmallVector(tiledOp->getResults()), + llvm::to_vector(ArrayRef{inputSlice, outputSlice})}; } //===----------------------------------------------------------------------===// @@ -3286,8 +3307,9 @@ FailureOr WinogradOutputTransformOp::getTiledImplementation( sizes[getValueFDim()]}); int64_t valueRank = getValueOperandRank(); SmallVector sliceStrides(valueRank, oneAttr); - tiledOperands.emplace_back(builder.create( - loc, getValue(), sliceOffsets, sliceSizes, sliceStrides)); + auto valueSlice = builder.create( + loc, getValue(), sliceOffsets, sliceSizes, sliceStrides); + tiledOperands.emplace_back(valueSlice); SmallVector resultOffsets, resultSizes; if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets, @@ -3296,15 +3318,19 @@ FailureOr WinogradOutputTransformOp::getTiledImplementation( int64_t outputRank = getOutputOperandRank(); SmallVector strides(outputRank, oneAttr); - tiledOperands.emplace_back(builder.create( - loc, getOutput(), resultOffsets, resultSizes, strides)); + auto outputSlice = builder.create( + loc, getOutput(), resultOffsets, resultSizes, strides); + tiledOperands.emplace_back(outputSlice); SmallVector resultTypes; resultTypes.push_back(tiledOperands[1].getType()); Operation *tiledOp = mlir::clone(builder, getOperation(), resultTypes, tiledOperands); - return TilingResult{{tiledOp}, SmallVector(tiledOp->getResults())}; + return TilingResult{ + {tiledOp}, + SmallVector(tiledOp->getResults()), + llvm::to_vector(ArrayRef{valueSlice, outputSlice})}; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp index fbff91a94219cc..f86715a94b268a 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp @@ -120,8 +120,16 @@ struct LinalgOpTilingInterface Location loc = op->getLoc(); LinalgOp linalgOp = cast(op); SmallVector valuesToTile = linalgOp->getOperands(); - SmallVector tiledOperands = makeTiledShapes( + SmallVector tiledOperands = makeTiledShapes( b, loc, linalgOp, valuesToTile, offsets, sizes, {}, true); + SmallVector generatedSlices = llvm::map_to_vector( + llvm::make_filter_range( + tiledOperands, + [](Value v) -> bool { + return isa_and_nonnull( + v.getDefiningOp()); + }), + [](Value v) -> Operation * { return v.getDefiningOp(); }); SmallVector resultTensorTypes = getTensorOutputTypes(linalgOp, tiledOperands); @@ -129,7 +137,8 @@ struct LinalgOpTilingInterface Operation *tiledOp = clone(b, linalgOp, resultTensorTypes, tiledOperands); offsetIndices(b, cast(tiledOp), offsets); - return TilingResult{{tiledOp}, SmallVector(tiledOp->getResults())}; + return TilingResult{ + {tiledOp}, SmallVector(tiledOp->getResults()), generatedSlices}; } /// Utility to fetch the offsets and sizes when applied as per the indexing @@ -260,7 +269,8 @@ struct LinalgOpTilingInterface return TilingResult{ tilingResult->tiledOps, - SmallVector{tilingResult->tiledValues[resultNumber]}}; + SmallVector{tilingResult->tiledValues[resultNumber]}, + tilingResult->generatedSlices}; } /// Method to generate the tiled implementation of an operation from the tile @@ -406,8 +416,12 @@ struct LinalgOpPartialReductionInterface } // Step 2a: Extract a slice of the input operands. - SmallVector tiledInputs = makeTiledShapes( + SmallVector tiledInputs = makeTiledShapes( b, loc, linalgOp, linalgOp.getDpsInputs(), offsets, sizes, {}, true); + SmallVector generatedSlices = llvm::map_to_vector( + llvm::make_filter_range( + tiledInputs, [](Value v) -> bool { return v.getDefiningOp(); }), + [](Value v) -> Operation * { return v.getDefiningOp(); }); // Step 2b: Extract a slice of the init operands. SmallVector tiledInits; @@ -424,6 +438,7 @@ struct LinalgOpPartialReductionInterface auto extractSlice = b.create( loc, valueToTile, initOffset, initSizes, initStride); tiledInits.push_back(extractSlice); + generatedSlices.push_back(extractSlice); } // Update the indexing maps. @@ -453,7 +468,8 @@ struct LinalgOpPartialReductionInterface return TilingResult{ {genericOp.getOperation()}, llvm::map_to_vector(genericOp->getResults(), - [](OpResult r) -> Value { return r; })}; + [](OpResult r) -> Value { return r; }), + generatedSlices}; } FailureOr mergeReductions(Operation *op, OpBuilder &b, diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp index fa0598dd96885c..6a3f2fc5fbc496 100644 --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -565,9 +565,9 @@ void GenerateLoopNest::doit( assert(ivs.size() == iteratorTypes.size() && "did not generate enough loops"); } -static Value materializeTiledShape(OpBuilder &builder, Location loc, - Value valueToTile, - const SliceParameters &sliceParams) { +static Operation *materializeTiledShape(OpBuilder &builder, Location loc, + Value valueToTile, + const SliceParameters &sliceParams) { auto shapedType = dyn_cast(valueToTile.getType()); auto *sliceOp = TypeSwitch(shapedType) .Case([&](MemRefType) { @@ -583,14 +583,15 @@ static Value materializeTiledShape(OpBuilder &builder, Location loc, .Default([](ShapedType) -> Operation * { llvm_unreachable("Unexpected shaped type"); }); - return sliceOp->getResult(0); + return sliceOp; } -Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile, - ArrayRef tileSizes, AffineMap map, - ArrayRef lbs, ArrayRef ubs, - ArrayRef subShapeSizes, - bool omitPartialTileCheck) { +Operation *makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile, + ArrayRef tileSizes, AffineMap map, + ArrayRef lbs, + ArrayRef ubs, + ArrayRef subShapeSizes, + bool omitPartialTileCheck) { SliceParameters sliceParams = computeSliceParameters(builder, loc, valueToTile, tileSizes, map, lbs, ubs, subShapeSizes, omitPartialTileCheck); @@ -841,6 +842,7 @@ SmallVector makeTiledShapes(OpBuilder &builder, Location loc, tiledShapes.push_back( sliceParams.has_value() ? materializeTiledShape(builder, loc, valueToTile, *sliceParams) + ->getResult(0) : valueToTile); } return tiledShapes; diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index e404c01010a325..3729300588422e 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -854,7 +854,8 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op, if (llvm::all_of(tileSizes, isZeroIndex)) { tiledResults.append(clonedOp->result_begin(), clonedOp->result_end()); tilingResult = - TilingResult{/*tiledOps=*/{clonedOp}, clonedOp->getResults()}; + TilingResult{/*tiledOps=*/{clonedOp}, clonedOp->getResults(), + /*generatedSlices=*/{}}; return success(); } @@ -910,12 +911,14 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op, // op. if (loops.empty()) { return scf::SCFTilingResult{tilingResult->tiledOps, loops, - tilingResult->tiledValues}; + tilingResult->tiledValues, + tilingResult->generatedSlices}; } SmallVector replacements = llvm::map_to_vector( loops.front()->getResults(), [](OpResult r) -> Value { return r; }); - return scf::SCFTilingResult{tilingResult->tiledOps, loops, replacements}; + return scf::SCFTilingResult{tilingResult->tiledOps, loops, replacements, + tilingResult->generatedSlices}; } FailureOr @@ -1180,13 +1183,13 @@ mlir::scf::tileAndFuseProducerOfSlice( ->getOpOperands()[destinationInitArg.value()->getOperandNumber()] .set(origDestinationTensors[resultNumber]); } - return scf::SCFFuseProducerOfSliceResult{fusableProducer, - tileAndFuseResult->tiledValues[0], - tileAndFuseResult->tiledOps}; + return scf::SCFFuseProducerOfSliceResult{ + fusableProducer, tileAndFuseResult->tiledValues[0], + tileAndFuseResult->tiledOps, tileAndFuseResult->generatedSlices}; } /// Reconstruct the fused producer from within the tiled-and-fused code. -LogicalResult mlir::scf::yieldReplacementForFusedProducer( +FailureOr> mlir::scf::yieldReplacementForFusedProducer( RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp, scf::SCFFuseProducerOfSliceResult fusedProducerInfo, MutableArrayRef loops, @@ -1214,6 +1217,7 @@ LogicalResult mlir::scf::yieldReplacementForFusedProducer( } } + SmallVector generatedSlices; YieldTiledValuesFn newYieldValuesFn = [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/, ValueRange newRegionIterArgs, SmallVector &tiledResult, @@ -1284,6 +1288,7 @@ LogicalResult mlir::scf::yieldReplacementForFusedProducer( loc, newRegionArg, offsetList[index], sizesList[index], SmallVector(offsetList[index].size(), rewriter.getIndexAttr(1))); + generatedSlices.push_back(destSlice); unsigned resultNumber = initNumberList[index]; rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() { tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice); @@ -1303,8 +1308,11 @@ LogicalResult mlir::scf::yieldReplacementForFusedProducer( return success(); }; - return addInitOperandsToLoopNest(rewriter, loops, initValueList, - newYieldValuesFn); + if (failed(addInitOperandsToLoopNest(rewriter, loops, initValueList, + newYieldValuesFn))) { + return failure(); + } + return generatedSlices; } /// Implementation of tile consumer and fuse producer greedily. @@ -1358,52 +1366,62 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF( // operations. If the producers of the source of the `tensor.extract_slice` // can be tiled such that the tiled value is generated in-place, that // effectively tiles + fuses the operations. - auto addCandidateSlices = [](Operation *fusedOp, - std::deque &candidates) { - for (Value operand : fusedOp->getOperands()) - if (auto sliceOp = operand.getDefiningOp()) - candidates.push_back(sliceOp); + struct WorklistItem { + tensor::ExtractSliceOp candidateSlice; + SCFTileAndFuseOptions::ControlFnResult controlFnResult; + }; + std::deque worklist; + auto addCandidateSlices = [&worklist, &options, + &loops](ArrayRef candidates) { + for (auto candidate : candidates) { + auto sliceOp = dyn_cast(candidate); + if (!sliceOp || sliceOp.use_empty()) + continue; + + auto [fusableProducer, destinationInitArg] = + getUntiledProducerFromSliceSource(&sliceOp.getSourceMutable(), loops); + if (!fusableProducer) + continue; + std::optional controlFnResult = + options.fusionControlFn(sliceOp, fusableProducer, + destinationInitArg.has_value()); + if (!controlFnResult) + continue; + worklist.emplace_back(WorklistItem{sliceOp, controlFnResult.value()}); + } }; - std::deque candidates; - addCandidateSlices(tiledAndFusedOps.back(), candidates); + addCandidateSlices(tilingResult->generatedSlices); OpBuilder::InsertionGuard g(rewriter); - while (!candidates.empty()) { + while (!worklist.empty()) { // Traverse the slices in BFS fashion. - tensor::ExtractSliceOp candidateSliceOp = candidates.front(); - candidates.pop_front(); - - // Find the original producer of the slice. - auto [fusableProducer, destinationInitArg] = - getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable(), - loops); - if (!fusableProducer) - continue; - - auto [fuseSlice, yieldReplacement] = options.fusionControlFn( - candidateSliceOp, fusableProducer, destinationInitArg.has_value()); - if (!fuseSlice) - continue; + WorklistItem worklistItem = worklist.front(); + worklist.pop_front(); // The operands of the fused producer might themselved be slices of // values produced by operations that implement the `TilingInterface`. // Add these operations to the worklist. std::optional fusedResult = - tileAndFuseProducerOfSlice(rewriter, candidateSliceOp, loops); + tileAndFuseProducerOfSlice(rewriter, worklistItem.candidateSlice, + loops); if (!fusedResult) continue; - if (yieldReplacement) { + if (worklistItem.controlFnResult.yieldProducerReplacement) { // Reconstruct and yield all opResult of fusableProducerOp by default. The // caller can specific which one to yield by designating optional argument // named `yieldResultNumber` of `yieldReplacementForFusedProducer`. - Operation *fusableProducerOp = fusableProducer.getOwner(); - if (failed(yieldReplacementForFusedProducer( - rewriter, candidateSliceOp, fusedResult.value(), loops))) { + Operation *fusableProducerOp = fusedResult->origProducer.getOwner(); + FailureOr> newSlices = + yieldReplacementForFusedProducer(rewriter, + worklistItem.candidateSlice, + fusedResult.value(), loops); + if (failed(newSlices)) { return rewriter.notifyMatchFailure( fusableProducerOp, "failed to replacement value for this " "operation from within the tiled loop"); } + addCandidateSlices(newSlices.value()); for (auto [index, result] : llvm::enumerate(fusableProducerOp->getResults())) { origValToResultNumber[result] = loops.front()->getNumResults() - @@ -1411,12 +1429,11 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF( index; } } - + addCandidateSlices(fusedResult->generatedSlices); if (Operation *tiledAndFusedOp = fusedResult->tiledAndFusedProducer.getDefiningOp()) { fusedProducers.insert(fusedResult->origProducer.getDefiningOp()); tiledAndFusedOps.insert(tiledAndFusedOp); - addCandidateSlices(tiledAndFusedOp, candidates); } } diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp index dec678de6d1c27..34eec4cee052e3 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp @@ -170,8 +170,9 @@ struct PackOpTiling SmallVector strides(inputRank, oneAttr); SmallVector tiledOperands; - tiledOperands.push_back(b.create( - loc, packOp.getSource(), inputIndices, inputSizes, strides)); + auto sourceSlice = b.create( + loc, packOp.getSource(), inputIndices, inputSizes, strides); + tiledOperands.push_back(sourceSlice); SmallVector outputOffsets, outputSizes; if (failed(getResultTilePosition(op, b, 0, offsets, sizes, outputOffsets, @@ -179,9 +180,9 @@ struct PackOpTiling return {}; strides.append(packOp.getDestRank() - inputRank, oneAttr); - auto extractSlice = b.create( + auto outSlice = b.create( loc, packOp.getDest(), outputOffsets, outputSizes, strides); - tiledOperands.push_back(extractSlice); + tiledOperands.push_back(outSlice); if (auto val = packOp.getPaddingValue()) tiledOperands.push_back(val); @@ -189,10 +190,12 @@ struct PackOpTiling tiledOperands.push_back(tile); Operation *tiledPackOp = b.create( - loc, TypeRange{extractSlice.getType()}, tiledOperands, op->getAttrs()); + loc, TypeRange{outSlice.getType()}, tiledOperands, op->getAttrs()); - return TilingResult{{tiledPackOp}, - SmallVector(tiledPackOp->getResults())}; + return TilingResult{ + {tiledPackOp}, + SmallVector(tiledPackOp->getResults()), + llvm::to_vector(ArrayRef{sourceSlice, outSlice})}; } LogicalResult @@ -331,8 +334,9 @@ struct PackOpTiling SmallVector strides(inputRank, oneAttr); SmallVector tiledOperands; - tiledOperands.push_back(b.create(loc, packOp.getSource(), - offsets, sizes, strides)); + auto sourceSlice = b.create(loc, packOp.getSource(), + offsets, sizes, strides); + tiledOperands.push_back(sourceSlice); SmallVector outerDimOffsets, outerDimSizes; if (failed(getIterationDomainTileFromOperandTile( @@ -346,19 +350,21 @@ struct PackOpTiling return failure(); strides.append(packOp.getDestRank() - inputRank, oneAttr); - auto extractSlice = b.create( + auto outSlice = b.create( loc, packOp.getDest(), outputOffsets, outputSizes, strides); - tiledOperands.push_back(extractSlice); + tiledOperands.push_back(outSlice); assert(!packOp.getPaddingValue() && "Expect no padding semantic"); for (auto tile : packOp.getInnerTiles()) tiledOperands.push_back(tile); Operation *tiledPackOp = b.create( - loc, TypeRange{extractSlice.getType()}, tiledOperands, op->getAttrs()); + loc, TypeRange{outSlice.getType()}, tiledOperands, op->getAttrs()); - return TilingResult{{tiledPackOp}, - SmallVector(tiledPackOp->getResults())}; + return TilingResult{ + {tiledPackOp}, + SmallVector(tiledPackOp->getResults()), + llvm::to_vector(ArrayRef{sourceSlice, outSlice})}; } }; @@ -537,9 +543,12 @@ struct UnPackOpTiling SmallVector destStrides(destRank, oneAttr); Value sliceDest; + SmallVector generatedSlices; if (isPerfectTilingCase) { - sliceDest = b.create(loc, unpackOp.getDest(), offsets, - sizes, destStrides); + auto destSliceOp = b.create(loc, unpackOp.getDest(), + offsets, sizes, destStrides); + sliceDest = destSliceOp; + generatedSlices.push_back(destSliceOp); } else { sliceDest = b.create(loc, destExpandedSizes, unpackOp.getDestType().getElementType()); @@ -554,12 +563,15 @@ struct UnPackOpTiling if (isPerfectTilingCase) return TilingResult{{tiledUnpackOp}, - SmallVector(tiledUnpackOp->getResults())}; + SmallVector(tiledUnpackOp->getResults()), + generatedSlices}; auto extractSlice = b.create(loc, tiledUnpackOp->getResult(0), resultOffsetsFromDest, sizes, destStrides); - return TilingResult{{tiledUnpackOp}, {extractSlice.getResult()}}; + generatedSlices.push_back(extractSlice); + return TilingResult{ + {tiledUnpackOp}, {extractSlice.getResult()}, generatedSlices}; } LogicalResult @@ -680,7 +692,9 @@ struct UnPackOpTiling tiledOperands, op->getAttrs()); return TilingResult{{tiledUnPackOp}, - SmallVector(tiledUnPackOp->getResults())}; + SmallVector(tiledUnPackOp->getResults()), + llvm::to_vector(ArrayRef{ + extractSourceSlice, extractDestSlice})}; } }; @@ -850,7 +864,7 @@ FailureOr tensor::bubbleUpPadSlice(OpBuilder &b, // the result shape of the new SliceOp has a zero dimension. auto createPadOfExtractSlice = [&]() { // Create pad(extract_slice(x)). - Value newSliceOp = b.create( + auto newSliceOp = b.create( loc, padOp.getSource(), newOffsets, newLengths, newStrides); auto newPadOp = b.create( loc, Type(), newSliceOp, newLows, newHighs, @@ -862,14 +876,16 @@ FailureOr tensor::bubbleUpPadSlice(OpBuilder &b, padOp.getRegion().cloneInto(&newPadOp.getRegion(), bvm); // Cast result and return. - return newPadOp; + return std::make_tuple(newPadOp, newSliceOp); }; // Rewrite extract_slice(pad(x)) into a GenerateOp it is statically known that // the original data source x is not used. if (hasZeroLen) { Operation *generateOp = createGenerateOp(); - return TilingResult{{generateOp}, {castResult(generateOp->getResult(0))}}; + return TilingResult{{generateOp}, + {castResult(generateOp->getResult(0))}, + /*generatedSlices=*/{}}; } // If there are dynamic dimensions: Generate an scf.if check to avoid @@ -877,6 +893,7 @@ FailureOr tensor::bubbleUpPadSlice(OpBuilder &b, if (generateZeroSliceGuard && dynHasZeroLenCond) { Operation *thenOp; Operation *elseOp; + Operation *sliceOp; auto result = b.create( loc, dynHasZeroLenCond, /*thenBuilder=*/ @@ -886,14 +903,16 @@ FailureOr tensor::bubbleUpPadSlice(OpBuilder &b, }, /*elseBuilder=*/ [&](OpBuilder &b, Location loc) { - elseOp = createPadOfExtractSlice(); + std::tie(elseOp, sliceOp) = createPadOfExtractSlice(); b.create(loc, castResult(elseOp->getResult(0))); }); - return TilingResult{{elseOp}, SmallVector(result->getResults())}; + return TilingResult{ + {elseOp}, SmallVector(result->getResults()), {sliceOp}}; } - Operation *newPadOp = createPadOfExtractSlice(); - return TilingResult{{newPadOp}, {castResult(newPadOp->getResult(0))}}; + auto [newPadOp, sliceOp] = createPadOfExtractSlice(); + return TilingResult{ + {newPadOp}, {castResult(newPadOp->getResult(0))}, {sliceOp}}; } void mlir::tensor::registerTilingInterfaceExternalModels( diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir index d1aed593f45451..3ea1929e4ed785 100644 --- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir +++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir @@ -542,3 +542,48 @@ module attributes {transform.with_named_sequence} { // CHECK-DAG: %[[INSERTSLICE:.+]] = tensor.insert_slice %[[GENERIC2]] into %[[ITERARG0]][%[[IV]], 0] // CHECK: scf.yield %[[INSERTSLICE]] // CHECK: return %[[RESULT]] + +// ----- + +func.func @pad_producer_fusion(%arg0 : tensor<10xf32>) -> tensor<16xf32> { + %0 = tensor.empty() : tensor<10xf32> + %1 = linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], + iterator_types = ["parallel"]} + ins(%arg0 : tensor<10xf32>) outs(%0 : tensor<10xf32>) { + ^bb0(%b0 : f32, %b1 : f32): + %2 = arith.addf %b0, %b0: f32 + linalg.yield %2 : f32 + } -> tensor<10xf32> + %cst = arith.constant 0.0 : f32 + %2 = tensor.pad %1 low[4] high[2] { + ^bb0(%arg1 : index): + tensor.yield %cst : f32 + } : tensor<10xf32> to tensor<16xf32> + return %2 : tensor<16xf32> +} +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %generic = transform.structured.match ops{["linalg.generic"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %pad = transform.structured.match ops{["tensor.pad"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %a, %b = transform.structured.fuse %pad [8] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} +// CHECK-LABEL: func @pad_producer_fusion +// CHECK-SAME: %[[ARG0:.+]]: tensor<10xf32> +// CHECK: %[[FOR_RESULT:.+]] = scf.for +// CHECK: %[[IF_RESULT:.+]] = scf.if +// CHECK: else +// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]] +// CHECK: %[[GENERIC:.+]] = linalg.generic +// CHECK-SAME: ins(%[[SLICE]] : +// CHECK: %[[PAD:.+]] = tensor.pad %[[GENERIC]] +// CHECK: %[[CAST:.+]] = tensor.cast %[[PAD]] +// CHECK: scf.yield %[[CAST]] +// CHECK: %[[INSERT_SLICE:.+]] = tensor.insert_slice %[[IF_RESULT]] +// CHECK: scf.yield %[[INSERT_SLICE]] +// CHECK: return %[[FOR_RESULT]] diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp index 7aa7b58433f36c..b6da47977cb4cf 100644 --- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp +++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp @@ -91,11 +91,13 @@ applyTileAndFuseToAll(RewriterBase &rewriter, Operation *transformOp, scf::SCFTileAndFuseOptions::ControlFnTy controlFn = [&](tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer, - bool isDestinationOperand) { - Operation *owner = originalProducer.getOwner(); - bool yieldProducerReplacement = yieldReplacementsFor.contains(owner); - return std::make_tuple(true, yieldProducerReplacement); - }; + bool isDestinationOperand) + -> std::optional { + Operation *owner = originalProducer.getOwner(); + bool yieldProducerReplacement = yieldReplacementsFor.contains(owner); + return scf::SCFTileAndFuseOptions::ControlFnResult{ + yieldProducerReplacement}; + }; tileAndFuseOptions.setFusionControlFn(controlFn); rewriter.setInsertionPoint(target); From a8317e1f17bad0640a2d07795521b8eee60c9829 Mon Sep 17 00:00:00 2001 From: Quinn Dawkins Date: Fri, 4 Oct 2024 14:42:55 -0400 Subject: [PATCH 35/46] [mlir] Add option for a cleanup pattern set to SCF tiling helper (#109554) The SCF helper for tiling an operation implementing the TilingInterface and greedily fusing consumers requires an uninterrupted chain of operations implementing the tiling interface to succeed. There can be cases with intermediate ops that don't implement the interface but have producers that could be fused if various canonicalization/simplification patterns could run in between fusion steps. This adds an option to SCFTileAndFuseOptions for a pattern set to run between fusion steps to the ops that result from fusion/tiling. Removed and newly inserted slices are tracked for continued fusion applications. See this RFC for more discussion: https://discourse.llvm.org/t/rfc-split-fusion-portions-of-the-tilinginterface-into-a-new-interface/81155 --- .../Linalg/TransformOps/LinalgTransformOps.td | 9 +- .../SCF/Transforms/TileUsingInterface.h | 6 + .../TransformOps/LinalgTransformOps.cpp | 9 + .../SCF/Transforms/TileUsingInterface.cpp | 156 +++++++++++++++--- .../Dialect/Linalg/transform-op-fuse.mlir | 100 +++++++++++ 5 files changed, 252 insertions(+), 28 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index 106f0d79d9792d..2338f7d2da7298 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -284,18 +284,23 @@ def FuseOp : Op:$tile_sizes, - DefaultValuedAttr:$tile_interchange); + DefaultValuedAttr:$tile_interchange, + DefaultValuedAttr:$apply_cleanup); let results = (outs TransformHandleTypeInterface:$transformed, Variadic:$loops); let assemblyFormat = [{ $target ($tile_sizes^)? (`interchange` $tile_interchange^)? - attr-dict `:` functional-type(operands, results) + (`apply_cleanup` `=` $apply_cleanup^)? attr-dict + `:` functional-type(operands, results) }]; let hasVerifier = 1; } diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h index 77c812cde71533..9f5f9f3fca97ad 100644 --- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h +++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h @@ -15,6 +15,7 @@ #include "mlir/Interfaces/LoopLikeInterface.h" #include "mlir/Interfaces/TilingInterface.h" #include "mlir/Interfaces/ViewLikeInterface.h" +#include "mlir/Rewrite/FrozenRewritePatternSet.h" #include @@ -153,6 +154,11 @@ struct SCFTileAndFuseOptions { fusionControlFn = controlFn; return *this; } + + /// An optional set of rewrite patterns to apply to the results of tiling + /// before fusion. This will track deleted and newly inserted + /// `tensor.extract_slice` ops and update the worklist. + std::optional cleanupPatterns = std::nullopt; }; /// Fuse the producer of the source of `candidateSliceOp` by computing the diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index fbf4e29024f7c2..8cf60da2e89b11 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -557,6 +557,15 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter, tilingOptions = tilingOptions.setTileSizes(tileSizesOfr); scf::SCFTileAndFuseOptions tileAndFuseOptions; tileAndFuseOptions.tilingOptions = tilingOptions; + + if (getApplyCleanup()) { + MLIRContext *context = rewriter.getContext(); + RewritePatternSet patterns(context); + tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, context); + tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns); + tileAndFuseOptions.cleanupPatterns = std::move(patterns); + } + LogicalResult result = applyTilingToAll( rewriter, getOperation(), state.getPayloadOps(getTarget()), tileSizes.size() - llvm::count(tileSizes, 0), transformResults, diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index 3729300588422e..bb0d90dbba4a01 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -24,6 +24,8 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/DestinationStyleOpInterface.h" #include "mlir/Interfaces/TilingInterface.h" +#include "mlir/Rewrite/FrozenRewritePatternSet.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" #include @@ -1315,6 +1317,104 @@ FailureOr> mlir::scf::yieldReplacementForFusedProducer( return generatedSlices; } +namespace { + +//===----------------------------------------------------------------------===// +// SliceTrackingListener +//===----------------------------------------------------------------------===// + +/// This class is a listener for tracking the insertion and removal of +/// `tensor.extract_slice` ops in a worklist. This can be used in a greedy +/// fusion algorithm to apply cleanup patterns in between fusion steps. +class SliceTrackingListener : public RewriterBase::Listener { +public: + explicit SliceTrackingListener( + std::optional patterns); + SliceTrackingListener() = default; + + /// Adds the given list of operations to the worklist, and if present, applies + /// the list of `patterns` to the newly added operations. This only processes + /// the given operations and any newly inserted ones by the pattern set. + LogicalResult insertAndApplyPatterns(ArrayRef newOps); + + /// Add to the new operation worklist if it is an extract_slice. + void notifyOperationInserted(Operation *op, + OpBuilder::InsertPoint previous) override; + + /// Shared helper for operation removal from the worklist. + void removeOp(Operation *op); + + /// Remove the operation from the worklist. + void notifyOperationErased(Operation *op) override; + + /// Remove the operation from the worklist. + void notifyOperationReplaced(Operation *op, ValueRange replacement) override; + + /// The worklist for this transformation keeps track of the slices to visit + /// next for fusion. + std::deque worklist; + +private: + /// Optional pattern set to apply when adding new operations to the worklist. + std::optional patterns = std::nullopt; +}; + +SliceTrackingListener::SliceTrackingListener( + std::optional p) { + patterns = std::move(p); +} + +LogicalResult +SliceTrackingListener::insertAndApplyPatterns(ArrayRef ops) { + for (Operation *op : ops) { + if (auto slice = dyn_cast(op)) + worklist.push_back(slice); + } + + if (!patterns) + return success(); + + GreedyRewriteConfig config; + config.listener = this; + config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps; + return applyOpPatternsAndFold(ops, patterns.value(), config); +} + +void SliceTrackingListener::notifyOperationInserted( + Operation *op, OpBuilder::InsertPoint previous) { + auto slice = dyn_cast(op); + if (!slice) + return; + worklist.push_back(slice); +} + +// Scan the worklist for the given op and remove it if present. The expectation +// is for the worklist to be small and for removal to be relatively rare. +void SliceTrackingListener::removeOp(Operation *op) { + if (!isa(op)) + return; + auto iter = worklist.begin(); + while (iter != worklist.end()) { + if (*iter == op) + break; + iter++; + } + if (iter == worklist.end()) + return; + + worklist.erase(iter); +} + +void SliceTrackingListener::notifyOperationErased(Operation *op) { + removeOp(op); +} + +void SliceTrackingListener::notifyOperationReplaced(Operation *op, + ValueRange replacement) { + removeOp(op); +} +} // namespace + /// Implementation of tile consumer and fuse producer greedily. FailureOr mlir::scf::tileConsumerAndFuseProducersUsingSCF( @@ -1370,33 +1470,32 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF( tensor::ExtractSliceOp candidateSlice; SCFTileAndFuseOptions::ControlFnResult controlFnResult; }; - std::deque worklist; - auto addCandidateSlices = [&worklist, &options, - &loops](ArrayRef candidates) { - for (auto candidate : candidates) { - auto sliceOp = dyn_cast(candidate); - if (!sliceOp || sliceOp.use_empty()) - continue; - auto [fusableProducer, destinationInitArg] = - getUntiledProducerFromSliceSource(&sliceOp.getSourceMutable(), loops); - if (!fusableProducer) - continue; - std::optional controlFnResult = - options.fusionControlFn(sliceOp, fusableProducer, - destinationInitArg.has_value()); - if (!controlFnResult) - continue; - worklist.emplace_back(WorklistItem{sliceOp, controlFnResult.value()}); - } - }; + SliceTrackingListener sliceTracker = + SliceTrackingListener(options.cleanupPatterns); - addCandidateSlices(tilingResult->generatedSlices); + if (failed( + sliceTracker.insertAndApplyPatterns(tilingResult->generatedSlices))) { + return rewriter.notifyMatchFailure(consumer, "cleanup patterns failed"); + } OpBuilder::InsertionGuard g(rewriter); - while (!worklist.empty()) { - // Traverse the slices in BFS fashion. - WorklistItem worklistItem = worklist.front(); - worklist.pop_front(); + while (!sliceTracker.worklist.empty()) { + auto candidateSlice = sliceTracker.worklist.front(); + sliceTracker.worklist.pop_front(); + + auto [fusableProducer, destinationInitArg] = + getUntiledProducerFromSliceSource(&candidateSlice.getSourceMutable(), + loops); + if (!fusableProducer) + continue; + + std::optional controlFnResult = + options.fusionControlFn(candidateSlice, fusableProducer, + destinationInitArg.has_value()); + if (!controlFnResult) + continue; + + WorklistItem worklistItem = {candidateSlice, controlFnResult.value()}; // The operands of the fused producer might themselved be slices of // values produced by operations that implement the `TilingInterface`. @@ -1407,6 +1506,8 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF( if (!fusedResult) continue; + SmallVector worklistCandidates = fusedResult->generatedSlices; + if (worklistItem.controlFnResult.yieldProducerReplacement) { // Reconstruct and yield all opResult of fusableProducerOp by default. The // caller can specific which one to yield by designating optional argument @@ -1421,7 +1522,7 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF( fusableProducerOp, "failed to replacement value for this " "operation from within the tiled loop"); } - addCandidateSlices(newSlices.value()); + worklistCandidates.append(newSlices.value()); for (auto [index, result] : llvm::enumerate(fusableProducerOp->getResults())) { origValToResultNumber[result] = loops.front()->getNumResults() - @@ -1429,12 +1530,15 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF( index; } } - addCandidateSlices(fusedResult->generatedSlices); if (Operation *tiledAndFusedOp = fusedResult->tiledAndFusedProducer.getDefiningOp()) { fusedProducers.insert(fusedResult->origProducer.getDefiningOp()); tiledAndFusedOps.insert(tiledAndFusedOp); } + + if (failed(sliceTracker.insertAndApplyPatterns(worklistCandidates))) { + return rewriter.notifyMatchFailure(consumer, "cleanup patterns failed"); + } } DenseMap replacements; diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir index 3a023deb1132f3..ac1ca9319d3354 100644 --- a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir @@ -178,3 +178,103 @@ module attributes {transform.with_named_sequence} { transform.yield } } + +// ----- + +// CHECK-LABEL: func.func @fuse_through_slice +func.func @fuse_through_slice(%arg0: tensor, %arg1: tensor) -> tensor { + + // CHECK: %[[RES:.*]] = scf.for + // CHECK: scf.for + // CHECK: linalg.elemwise_unary + // CHECK: linalg.elemwise_binary + // CHECK: return %[[RES]] + %0 = linalg.elemwise_unary ins(%arg0 : tensor) + outs(%arg0: tensor) -> tensor + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %dim0 = tensor.dim %arg1, %c0 : tensor + %dim1 = tensor.dim %arg1, %c1 : tensor + %1 = tensor.extract_slice %0 [1, 1] [%dim0, %dim1] [1, 1] : tensor to tensor + %2 = linalg.elemwise_binary ins(%1, %arg1 : tensor, tensor) + outs(%arg1: tensor) -> tensor + return %2 : tensor +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.elemwise_binary"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1], apply_cleanup = true} + : (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op) + transform.yield + } +} + +// ----- + +// CHECK-LABEL: func.func @fuse_through_slice_and_cast_chain +func.func @fuse_through_slice_and_cast_chain(%arg0: tensor<100x100xf32>, %arg1: tensor) -> tensor { + + // CHECK: %[[RES:.*]] = scf.for + // CHECK: scf.for + // CHECK: linalg.elemwise_unary + // CHECK: linalg.elemwise_binary + // CHECK: return %[[RES]] + %0 = linalg.elemwise_unary ins(%arg0 : tensor<100x100xf32>) + outs(%arg0: tensor<100x100xf32>) -> tensor<100x100xf32> + %1 = tensor.cast %0 : tensor<100x100xf32> to tensor<100x?xf32> + %2 = tensor.extract_slice %1 [1, 1] [98, 98] [1, 1] : tensor<100x?xf32> to tensor<98x98xf32> + %3 = tensor.cast %2 : tensor<98x98xf32> to tensor + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %dim0 = tensor.dim %arg1, %c0 : tensor + %dim1 = tensor.dim %arg1, %c1 : tensor + %4 = tensor.extract_slice %3 [1, 1] [%dim0, %dim1] [1, 1] : tensor to tensor + %5 = linalg.elemwise_binary ins(%4, %arg1 : tensor, tensor) + outs(%arg1: tensor) -> tensor + return %5 : tensor +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.elemwise_binary"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1], apply_cleanup = true} + : (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op) + transform.yield + } +} + +// ----- + +// CHECK-LABEL: func.func @fuse_unrelated_slice +func.func @fuse_unrelated_slices(%arg0: tensor, %arg1: tensor) -> (tensor, tensor<10x10xf32>) { + + // CHECK: %[[SLICE1:.+]] = tensor.extract_slice + // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[SLICE1]] + // CHECK: %[[RES:.*]] = scf.for + // CHECK: scf.for + // CHECK: linalg.elemwise_unary + // CHECK: linalg.elemwise_binary + // CHECK: return %[[RES]], %[[SLICE2]] + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %dim0 = tensor.dim %arg1, %c0 : tensor + %dim1 = tensor.dim %arg1, %c1 : tensor + %slice1 = tensor.extract_slice %arg0 [1, 1] [%dim0, %dim1] [1, 1] : tensor to tensor + %slice2 = tensor.extract_slice %slice1 [1, 1] [10, 10] [1, 1] : tensor to tensor<10x10xf32> + %0 = linalg.elemwise_unary ins(%arg0 : tensor) + outs(%arg0: tensor) -> tensor + %1 = tensor.extract_slice %0 [1, 1] [%dim0, %dim1] [1, 1] : tensor to tensor + %2 = linalg.elemwise_binary ins(%1, %arg1 : tensor, tensor) + outs(%arg1: tensor) -> tensor + return %2, %slice2 : tensor, tensor<10x10xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.elemwise_binary"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1], apply_cleanup = true} + : (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op) + transform.yield + } +} From ac5f771ebd8b92da0c88d973a0ee908cdabd5ea6 Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Wed, 6 Nov 2024 17:12:12 +0100 Subject: [PATCH 36/46] emitc.tu: Automatically create block for body The auto-generated builder created an emitc.tu that had an empty region. This is a bit cumbersome to work with, as you would always manually needed to create a block in it. Do what ModuleOp::build does and always create that block. Also accept StringRef as argument for id instead of requiring a StringAttr. --- mlir/include/mlir/Dialect/EmitC/IR/EmitC.td | 8 ++++++++ mlir/lib/Dialect/EmitC/IR/EmitC.cpp | 10 ++++++++++ 2 files changed, 18 insertions(+) diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td index 78c420997dac65..99005af984b54d 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td @@ -85,7 +85,11 @@ def EmitC_TranslationUnitOp : EmitC_Op<"tu", let regions = (region SizedRegion<1>:$bodyRegion); let assemblyFormat = "$id attr-dict-with-keyword $bodyRegion"; + let builders = [OpBuilder<(ins CArg<"StringRef">:$id)>]; let extraClassDeclaration = [{ + /// Construct a module from the given location with an optional name. + static TranslationUnitOp create(Location loc, StringRef name); + //===------------------------------------------------------------------===// // OpAsmOpInterface Methods //===------------------------------------------------------------------===// @@ -96,6 +100,10 @@ def EmitC_TranslationUnitOp : EmitC_Op<"tu", return "emitc"; } }]; + + // We need to ensure that the body region has a block; + // the auto-generated builders do not guarantee that. + let skipDefaultBuilders = 1; } def EmitC_AddOp : EmitC_BinaryOp<"add", [CExpression]> { diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index ee44f524c91428..c2fb835c2ebc34 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -1251,6 +1251,16 @@ GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) { return success(); } +//===----------------------------------------------------------------------===// +// TranslationUnitOp +//===----------------------------------------------------------------------===// +void TranslationUnitOp::build(OpBuilder &builder, OperationState &state, + StringRef id) { + state.addRegion()->emplaceBlock(); + state.attributes.push_back( + builder.getNamedAttr("id", builder.getStringAttr(id))); +} + //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// From 831eb6640aea7d076c5dd752dd5ecdf245a78c7a Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Wed, 6 Nov 2024 17:14:33 +0100 Subject: [PATCH 37/46] emitc.include: don't require the parent to be a ModuleOp `#include` make sense everywhere, and in particular we need to allow them inside a `emitc.tu`. But sometimes we might even want to have an `#include` in a function body. --- mlir/include/mlir/Dialect/EmitC/IR/EmitC.td | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td index 78c420997dac65..3f290f68b85dc8 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td @@ -774,7 +774,7 @@ def EmitC_ReturnOp : EmitC_Op<"return", [Pure, HasParent<"FuncOp">, } def EmitC_IncludeOp - : EmitC_Op<"include", [HasParent<"ModuleOp">]> { + : EmitC_Op<"include", []> { let summary = "Include operation"; let description = [{ The `emitc.include` operation allows to define a source file inclusion via the From 69cbbb5541dde4b12fc6880675d813e612b2bc92 Mon Sep 17 00:00:00 2001 From: Tiago Trevisan Jost Date: Fri, 8 Nov 2024 13:49:25 +0100 Subject: [PATCH 38/46] fix: fuse locations of double reshapes when folding. --- .../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 4 ++ .../Tosa/canonicalize_with_debuginfo.mlir | 43 ++++++++++++++++++- 2 files changed, 46 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index 5b123096686aa6..c3d9d2a773ae70 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -1107,6 +1107,10 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) { if (auto reshapeOp = llvm::dyn_cast_if_present( getInput1().getDefiningOp())) { getInput1Mutable().assign(reshapeOp.getInput1()); + + // Fuse locations so that first ReshapeOp location isn't lost. + getResult().getDefiningOp()->setLoc( + mlir::FusedLoc::get(getContext(), {reshapeOp->getLoc(), getLoc()})); return getResult(); } diff --git a/mlir/test/Dialect/Tosa/canonicalize_with_debuginfo.mlir b/mlir/test/Dialect/Tosa/canonicalize_with_debuginfo.mlir index d206136b5fe2a1..3daf80fc33c506 100644 --- a/mlir/test/Dialect/Tosa/canonicalize_with_debuginfo.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize_with_debuginfo.mlir @@ -43,4 +43,45 @@ func.func @canonicalize_optimize_sqrt_reciprocal_bf16(%arg0: tensor<1x5x1x1xbf16 return %2 : tensor<1x5x1x1xbf16> } #loc0 = loc("Pow_B") -#loc1 = loc("Reciprocal_C") \ No newline at end of file +#loc1 = loc("Reciprocal_C") + +// ----- + +// CHECK-LABEL: @reshape_canonicalize_double +func.func @reshape_canonicalize_double(%arg0: tensor) -> tensor { + // CHECK: %[[VAL_1:.*]] = tosa.reshape %arg0 {new_shape = array} {{.*}} loc([[LOC:.*]]) + // CHECK: return %[[VAL_1]] + %0 = tosa.reshape %arg0 {new_shape = array}: (tensor) -> tensor<5x?xf32> loc(#loc0) + %1 = tosa.reshape %0 {new_shape = array}: (tensor<5x?xf32>) -> tensor loc(#loc1) + return %1 : tensor +} +#loc0 = loc("reshape1") +#loc1 = loc("reshape2") + +// CHECK-DAG: #[[A:.*]] = loc("reshape1") +// CHECK-DAG: #[[B:.*]] = loc("reshape2") +// CHECK-DAG: [[LOC]] = loc(fused[#[[A]], #[[B]]]) + +// ----- + +// CHECK-LABEL: @reshape_canonicalize_double_fused_locs +func.func @reshape_canonicalize_double_fused_locs(%arg0: tensor) -> tensor { + // CHECK: %[[VAL_1:.*]] = tosa.reshape %arg0 {new_shape = array} {{.*}} loc([[LOC:.*]]) + // CHECK: return %[[VAL_1]] + %0 = tosa.reshape %arg0 {new_shape = array}: (tensor) -> tensor<5x?xf32> loc(#fused_loc0) + %1 = tosa.reshape %0 {new_shape = array}: (tensor<5x?xf32>) -> tensor loc(#fused_loc1) + return %1 : tensor +} +#loc0 = loc("reshape1_1") +#loc1 = loc("reshape1_2") +#loc2 = loc("reshape2_1") +#loc3 = loc("reshape2_2") + +// CHECK-DAG: #[[A:.*]] = loc("reshape1_1") +// CHECK-DAG: #[[B:.*]] = loc("reshape1_2") +// CHECK-DAG: #[[C:.*]] = loc("reshape2_1") +// CHECK-DAG: #[[D:.*]] = loc("reshape2_2") +// CHECK-DAG: [[LOC]] = loc(fused[#[[A]], #[[B]], #[[C]], #[[D]]]) + +#fused_loc0 = loc(fused[#loc0, #loc1]) +#fused_loc1 = loc(fused[#loc2, #loc3]) \ No newline at end of file From 39c4494897dd292aa9169fd076531c7ba72ea487 Mon Sep 17 00:00:00 2001 From: Tiago Trevisan Jost Date: Mon, 11 Nov 2024 14:52:07 +0100 Subject: [PATCH 39/46] feat: improve CSE by fusing locations when replacing one op for the other. --- mlir/lib/Transforms/CSE.cpp | 5 ++ mlir/test/Transforms/cse_with_locations.mlir | 48 ++++++++++++++++++++ 2 files changed, 53 insertions(+) create mode 100644 mlir/test/Transforms/cse_with_locations.mlir diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp index 3affd88d158de5..db556e065bc2bb 100644 --- a/mlir/lib/Transforms/CSE.cpp +++ b/mlir/lib/Transforms/CSE.cpp @@ -171,6 +171,11 @@ void CSEDriver::replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op, // current op. if (isa(existing->getLoc()) && !isa(op->getLoc())) existing->setLoc(op->getLoc()); + else { + // Otherwise, fuse both locations. + existing->setLoc(mlir::FusedLoc::get(existing->getContext(), + {existing->getLoc(), op->getLoc()})); + } ++numCSE; } diff --git a/mlir/test/Transforms/cse_with_locations.mlir b/mlir/test/Transforms/cse_with_locations.mlir new file mode 100644 index 00000000000000..be2b5b04c0027b --- /dev/null +++ b/mlir/test/Transforms/cse_with_locations.mlir @@ -0,0 +1,48 @@ +// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(cse))' -mlir-print-debuginfo | FileCheck %s + +// CHECK-LABEL: @many +func.func @many(f32, f32) -> (f32, f32) { +^bb0(%a : f32, %b : f32): + // All operations have locations. Must have locations of Add0, Add1, Add2, Add3. + %c = arith.addf %a, %b : f32 loc(#loc0) + %d = arith.addf %a, %b : f32 loc(#loc1) + %e = arith.addf %a, %b : f32 loc(#loc2) + %f = arith.addf %a, %b : f32 loc(#loc3) + // CHECK-NEXT: %[[VAR_0:[0-9a-zA-Z_]+]] = arith.addf %{{.*}}, %{{.*}} : f32 loc([[LOC_ABCD:.*]]) + + // First operation has unknown location. Must have locations of Add0, Add1, Add2. + %g = arith.addf %c, %d : f32 loc(#loc) + %h = arith.addf %e, %f : f32 loc(#loc0) + %i = arith.addf %c, %e : f32 loc(#fused_loc0) + // CHECK-NEXT: %[[VAR_1:[0-9a-zA-Z_]+]] = arith.addf %[[VAR_0]], %[[VAR_0]] : f32 loc([[LOC_ABC:.*]]) + + // Last operation has unknown location. Must have locations of Add2, Add3. + %j = arith.addf %g, %h : f32 loc(#fused_loc1) + %k = arith.addf %h, %i : f32 loc(#loc) + // CHECK-NEXT: %[[VAR_2:[0-9a-zA-Z_]+]] = arith.addf %[[VAR_1]], %[[VAR_1]] : f32 loc([[LOC_CD:.*]]) + + // Two operations with fused locations. Must have locations of Add1, Add2, Add3. + %l = arith.addf %j, %k : f32 loc(#fused_loc0) + %m = arith.addf %j, %k : f32 loc(#fused_loc1) + // CHECK-NEXT: %[[VAR_3:[0-9a-zA-Z_]+]] = arith.addf %[[VAR_2]], %[[VAR_2]] : f32 loc([[LOC_BCD:.*]]) + + // CHECK-NEXT: return %[[VAR_3]], %[[VAR_3]] : f32, f32 + return %l, %m : f32, f32 +} +#loc = loc(unknown) +#loc0 = loc("Add0") +#loc1 = loc("Add1") +#loc2 = loc("Add2") +#loc3 = loc("Add3") + +#fused_loc0 = loc(fused[#loc1, #loc2]) +#fused_loc1 = loc(fused[#loc2, #loc3]) + +// CHECK-DAG: #[[LOC_A:.*]] = loc("Add0") +// CHECK-DAG: #[[LOC_B:.*]] = loc("Add1") +// CHECK-DAG: #[[LOC_C:.*]] = loc("Add2") +// CHECK-DAG: #[[LOC_D:.*]] = loc("Add3") +// CHECK-DAG: [[LOC_ABCD]] = loc(fused[#[[LOC_A]], #[[LOC_B]], #[[LOC_C]], #[[LOC_D]]]) +// CHECK-DAG: [[LOC_ABC]] = loc(fused[#[[LOC_A]], #[[LOC_B]], #[[LOC_C]]]) +// CHECK-DAG: [[LOC_BCD]] = loc(fused[#[[LOC_B]], #[[LOC_C]], #[[LOC_D]]]) +// CHECK-DAG: [[LOC_CD]] = loc(fused[#[[LOC_C]], #[[LOC_D]]]) \ No newline at end of file From 99f8f981abc42ed837bef8b5ceee52a9ae0907e7 Mon Sep 17 00:00:00 2001 From: Jose Lopes Date: Tue, 12 Nov 2024 09:31:14 +0000 Subject: [PATCH 40/46] Make EliminateLibm work on EmitC::FuncOp --- mlir/include/mlir/Dialect/EmitC/IR/EmitC.td | 6 +++++ .../EmitC/Transforms/EliminateLibm.cpp | 10 +++---- mlir/test/Dialect/EmitC/eliminate_libm.mlir | 26 +++++++++---------- 3 files changed, 24 insertions(+), 18 deletions(-) diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td index ab71bb680e1509..0015ff0cd88107 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td @@ -752,6 +752,12 @@ def EmitC_FuncOp : EmitC_Op<"func", [ /// Returns the result types of this function. ArrayRef getResultTypes() { return getFunctionType().getResults(); } + + //===------------------------------------------------------------------===// + // SymbolOpInterface Methods + //===------------------------------------------------------------------===// + + bool isDeclaration() { return isExternal(); } }]; let hasCustomAssemblyFormat = 1; let hasVerifier = 1; diff --git a/mlir/lib/Dialect/EmitC/Transforms/EliminateLibm.cpp b/mlir/lib/Dialect/EmitC/Transforms/EliminateLibm.cpp index 1484867c6e4c5b..d2d87cdc9aa1ab 100644 --- a/mlir/lib/Dialect/EmitC/Transforms/EliminateLibm.cpp +++ b/mlir/lib/Dialect/EmitC/Transforms/EliminateLibm.cpp @@ -37,9 +37,9 @@ namespace { /// Replace all Libm calls (where callee has `libm` attribute + no definition) /// by opaque calls -struct OpacifyLibmCall : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(func::CallOp callOp, +struct OpacifyLibmCall : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(emitc::CallOp callOp, PatternRewriter &rewriter) const override { auto *st = SymbolTable::getNearestSymbolTable(callOp); @@ -65,8 +65,8 @@ struct EliminateLibmPass : public impl::EliminateLibmBase { MLIRContext *context = module->getContext(); // Find the first math.h inclusion - SmallVector libmPrototypes; - module.walk([&libmPrototypes](func::FuncOp funcOp) { + SmallVector libmPrototypes; + module.walk([&libmPrototypes](emitc::FuncOp funcOp) { if (funcOp->hasAttr("libm") && funcOp.isDeclaration()) libmPrototypes.push_back(funcOp); }); diff --git a/mlir/test/Dialect/EmitC/eliminate_libm.mlir b/mlir/test/Dialect/EmitC/eliminate_libm.mlir index 681df7b4308acc..d307ba6ddd8350 100644 --- a/mlir/test/Dialect/EmitC/eliminate_libm.mlir +++ b/mlir/test/Dialect/EmitC/eliminate_libm.mlir @@ -1,29 +1,29 @@ // RUN: mlir-opt %s --eliminate-libm --verify-diagnostics --split-input-file | FileCheck %s // CHECK: emitc.include <"cmath"> -// CHECK-NOT: func.func private @expm1 -// CHECK-DAG: func.func @call_expm1(%[[IN:.*]]: f64) -> f64 +// CHECK-NOT: emitc.func private @expm1 +// CHECK-DAG: emitc.func @call_expm1(%[[IN:.*]]: f64) -> f64 // CHECK-DAG: %[[RESULT:.*]] = emitc.call_opaque "expm1"(%[[IN]]) : (f64) -> f64 // CHECK-DAG: return %[[RESULT]] module { - func.func private @expm1(f64) -> f64 attributes {libm, llvm.readnone} - func.func @call_expm1(%in : f64) -> f64 { - %e1 = func.call @expm1(%in) : (f64) -> f64 - return %e1 : f64 + emitc.func private @expm1(f64) -> f64 attributes {libm, llvm.readnone, specifiers = ["extern"]} + emitc.func @call_expm1(%in : f64) -> f64 { + %e1 = emitc.call @expm1(%in) : (f64) -> f64 + emitc.return %e1 : f64 } } // ----- // CHECK-NOT: emitc.include <"cmath"> -// CHECK: func.func private @expm1 -// CHECK: func.func @call_expm1(%[[IN:.*]]: f64) -> f64 -// CHECK-NEXT: %[[RESULT:.*]] = call @expm1(%[[IN]]) : (f64) -> f64 +// CHECK: emitc.func private @expm1 +// CHECK: emitc.func @call_expm1(%[[IN:.*]]: f64) -> f64 +// CHECK-NEXT: %[[RESULT:.*]] = emitc.call @expm1(%[[IN]]) : (f64) -> f64 // CHECK-NEXT: return %[[RESULT]] module { - func.func private @expm1(f64) -> f64 attributes {llvm.readnone} - func.func @call_expm1(%in : f64) -> f64 { - %e1 = func.call @expm1(%in) : (f64) -> f64 - return %e1 : f64 + emitc.func private @expm1(f64) -> f64 attributes {llvm.readnone} + emitc.func @call_expm1(%in : f64) -> f64 { + %e1 = emitc.call @expm1(%in) : (f64) -> f64 + emitc.return %e1 : f64 } } From 1fc2b9872010a70fb8d64de5cf4e326324eb9c40 Mon Sep 17 00:00:00 2001 From: Jose Lopes Date: Tue, 12 Nov 2024 10:17:21 +0000 Subject: [PATCH 41/46] Fix test --- mlir/test/Dialect/EmitC/func.mlir | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mlir/test/Dialect/EmitC/func.mlir b/mlir/test/Dialect/EmitC/func.mlir index c7486bc493c315..06a9202e76f59b 100644 --- a/mlir/test/Dialect/EmitC/func.mlir +++ b/mlir/test/Dialect/EmitC/func.mlir @@ -16,12 +16,12 @@ emitc.func @f(%x: i32 ref) { // ----- -// CHECK: emitc.func @f +// CHECK: emitc.func private @f // CHECK-SAME: i32 ref -emitc.func @f(i32 ref) +emitc.func private @f(i32 ref) // ----- -// CHECK: emitc.func @f +// CHECK: emitc.func private @f // CHECK-SAME: i32 ref -emitc.func @f(i32 {emitc.reference}) +emitc.func private @f(i32 {emitc.reference}) From a4a93fba250a8735540c22232ef949546cbaa1ce Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Thu, 14 Nov 2024 12:02:31 +0100 Subject: [PATCH 42/46] emitc: Do not add newlines after ModuleOp, TranslationUnitOp --- mlir/lib/Target/Cpp/TranslateToCpp.cpp | 3 +++ mlir/test/Target/Cpp/tu.mlir | 4 ++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp index 60ad20bf7a0926..cc32828a0c544c 100644 --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -1638,6 +1638,9 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) { if (hasDeferredEmission(&op)) return success(); + if (isa(op)) + return success(); // skip adding newlines + if (getEmittedExpression() || (isa(op) && shouldBeInlined(cast(op)))) diff --git a/mlir/test/Target/Cpp/tu.mlir b/mlir/test/Target/Cpp/tu.mlir index ca10e0263a64fc..7d15f5a502fe6d 100644 --- a/mlir/test/Target/Cpp/tu.mlir +++ b/mlir/test/Target/Cpp/tu.mlir @@ -1,5 +1,5 @@ -// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s --check-prefix NO-FILTER -// RUN: mlir-translate -mlir-to-cpp -translation-unit-id=non-existing %s | FileCheck %s --check-prefix NON-EXISTING +// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s --check-prefix NO-FILTER --allow-empty +// RUN: mlir-translate -mlir-to-cpp -translation-unit-id=non-existing %s | FileCheck %s --check-prefix NON-EXISTING --allow-empty // RUN: mlir-translate -mlir-to-cpp -translation-unit-id=tu_one %s | FileCheck %s --check-prefix TU-ONE // RUN: mlir-translate -mlir-to-cpp -translation-unit-id=tu_two %s | FileCheck %s --check-prefix TU-TWO From 7326995c58238afcd15e1d6697f340f82e3f7eda Mon Sep 17 00:00:00 2001 From: josel-amd <166385423+josel-amd@users.noreply.github.com> Date: Fri, 22 Nov 2024 09:53:53 +0100 Subject: [PATCH 43/46] Fix yield conversion of scf.if/scf.for to emitc (#401) * Fix conversion for scf.for and scf.if --- mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp | 63 +++++++++++++------ mlir/test/Conversion/SCFToEmitC/for.mlir | 58 ++++++++++++++--- mlir/test/Conversion/SCFToEmitC/if.mlir | 27 ++++++++ 3 files changed, 123 insertions(+), 25 deletions(-) diff --git a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp index 51490c79ce4904..41c69eed208608 100644 --- a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp +++ b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp @@ -79,25 +79,31 @@ createVariablesForResults(T op, const TypeConverter *typeConverter, // Create a series of assign ops assigning given values to given variables at // the current insertion point of given rewriter. -static void assignValues(ValueRange values, SmallVector &variables, +static void assignValues(ValueRange values, ValueRange variables, ConversionPatternRewriter &rewriter, Location loc) { for (auto [value, var] : llvm::zip(values, variables)) rewriter.create(loc, var, value); } -static void lowerYield(SmallVector &resultVariables, - ConversionPatternRewriter &rewriter, - scf::YieldOp yield) { +static LogicalResult lowerYield(Operation *op, ValueRange resultVariables, + ConversionPatternRewriter &rewriter, + scf::YieldOp yield) { Location loc = yield.getLoc(); - ValueRange operands = yield.getOperands(); OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(yield); - assignValues(operands, resultVariables, rewriter, loc); + SmallVector yieldOperands; + if (failed(rewriter.getRemappedValues(yield.getOperands(), yieldOperands))) { + return rewriter.notifyMatchFailure(op, "failed to lower yield operands"); + } + + assignValues(yieldOperands, resultVariables, rewriter, loc); rewriter.create(loc); rewriter.eraseOp(yield); + + return success(); } LogicalResult @@ -118,22 +124,32 @@ ForLowering::matchAndRewrite(ForOp forOp, OpAdaptor adaptor, emitc::ForOp loweredFor = rewriter.create( loc, adaptor.getLowerBound(), adaptor.getUpperBound(), adaptor.getStep()); - // Propagate any attributes from the ODS forOp to the lowered emitc::for op. - loweredFor->setAttrs(forOp->getAttrs()); - Block *loweredBody = loweredFor.getBody(); // Erase the auto-generated terminator for the lowered for op. rewriter.eraseOp(loweredBody->getTerminator()); + // Convert the original region types into the new types by adding unrealized + // casts in the beginning of the loop. This performs the conversion in place. + if (failed(rewriter.convertRegionTypes(&forOp.getRegion(), + *getTypeConverter(), nullptr))) { + return rewriter.notifyMatchFailure(forOp, "region types conversion failed"); + } + + // Register the replacements for the block arguments and inline the body of + // the scf.for loop into the body of the emitc::for loop. + Block *scfBody = &(forOp.getRegion().front()); SmallVector replacingValues; replacingValues.push_back(loweredFor.getInductionVar()); replacingValues.append(resultVariables.begin(), resultVariables.end()); + rewriter.mergeBlocks(scfBody, loweredBody, replacingValues); + + auto result = lowerYield(forOp, resultVariables, rewriter, + cast(loweredBody->getTerminator())); - Block *adaptorBody = &(adaptor.getRegion().front()); - rewriter.mergeBlocks(adaptorBody, loweredBody, replacingValues); - lowerYield(resultVariables, rewriter, - cast(loweredBody->getTerminator())); + if (failed(result)) { + return result; + } rewriter.replaceOp(forOp, resultVariables); return success(); @@ -169,11 +185,16 @@ IfLowering::matchAndRewrite(IfOp ifOp, OpAdaptor adaptor, // emitc::if regions, but the scf::yield is replaced not only with an // emitc::yield, but also with a sequence of emitc::assign ops that set the // yielded values into the result variables. - auto lowerRegion = [&resultVariables, &rewriter](Region ®ion, - Region &loweredRegion) { + auto lowerRegion = [&resultVariables, &rewriter, + &ifOp](Region ®ion, Region &loweredRegion) { rewriter.inlineRegionBefore(region, loweredRegion, loweredRegion.end()); Operation *terminator = loweredRegion.back().getTerminator(); - lowerYield(resultVariables, rewriter, cast(terminator)); + auto result = lowerYield(ifOp, resultVariables, rewriter, + cast(terminator)); + if (failed(result)) { + return result; + } + return success(); }; Region &thenRegion = adaptor.getThenRegion(); @@ -185,11 +206,17 @@ IfLowering::matchAndRewrite(IfOp ifOp, OpAdaptor adaptor, rewriter.create(loc, adaptor.getCondition(), false, false); Region &loweredThenRegion = loweredIf.getThenRegion(); - lowerRegion(thenRegion, loweredThenRegion); + auto result = lowerRegion(thenRegion, loweredThenRegion); + if (failed(result)) { + return result; + } if (hasElseBlock) { Region &loweredElseRegion = loweredIf.getElseRegion(); - lowerRegion(elseRegion, loweredElseRegion); + auto result = lowerRegion(elseRegion, loweredElseRegion); + if (failed(result)) { + return result; + } } rewriter.replaceOp(ifOp, resultVariables); diff --git a/mlir/test/Conversion/SCFToEmitC/for.mlir b/mlir/test/Conversion/SCFToEmitC/for.mlir index b422aaa4545d9b..79a53ec8fd4c08 100644 --- a/mlir/test/Conversion/SCFToEmitC/for.mlir +++ b/mlir/test/Conversion/SCFToEmitC/for.mlir @@ -99,11 +99,55 @@ func.func @nested_for_yield(%arg0 : index, %arg1 : index, %arg2 : index) -> f32 // CHECK-NEXT: return %[[VAL_4]] : f32 // CHECK-NEXT: } -func.func @loop_with_attr(%arg0 : index, %arg1 : index, %arg2 : index) { - scf.for %i0 = %arg0 to %arg1 step %arg2 { - %c1 = arith.constant 1 : index - } {test.value = 5 : index} - return +func.func @for_yield_index(%arg0 : index, %arg1 : index, %arg2 : index) -> index { + %zero = arith.constant 0 : index + %r = scf.for %i0 = %arg0 to %arg1 step %arg2 iter_args(%acc = %zero) -> index { + scf.yield %acc : index + } + return %r : index } -// CHECK-LABEL: func.func @loop_with_attr -// CHECK: {test.value = 5 : index} + +// CHECK-LABEL: func.func @for_yield_index( +// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index, %[[ARG_2:.*]]: index) -> index { +// CHECK: %[[VAL_0:.*]] = builtin.unrealized_conversion_cast %[[ARG_2]] : index to !emitc.size_t +// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[ARG_1]] : index to !emitc.size_t +// CHECK: %[[VAL_2:.*]] = builtin.unrealized_conversion_cast %[[ARG_0]] : index to !emitc.size_t +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_3:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to !emitc.size_t +// CHECK: %[[VAL_4:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.size_t +// CHECK: emitc.assign %[[VAL_3]] : !emitc.size_t to %[[VAL_4]] : !emitc.size_t +// CHECK: emitc.for %[[VAL_5:.*]] = %[[VAL_2]] to %[[VAL_1]] step %[[VAL_0]] { +// CHECK: emitc.assign %[[VAL_4]] : !emitc.size_t to %[[VAL_4]] : !emitc.size_t +// CHECK: } +// CHECK: %[[VAL_8:.*]] = builtin.unrealized_conversion_cast %[[VAL_4]] : !emitc.size_t to index +// CHECK: return %[[VAL_8]] : index +// CHECK: } + + +func.func @for_yield_update_loop_carried_var(%arg0 : index, %arg1 : index, %arg2 : index) -> index { + %zero = arith.constant 0 : index + %r = scf.for %i0 = %arg0 to %arg1 step %arg2 iter_args(%acc = %zero) -> index { + %sn = arith.addi %acc, %acc : index + scf.yield %sn: index + } + return %r : index + } + +// CHECK-LABEL: func.func @for_yield_update_loop_carried_var( +// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index, %[[ARG_2:.*]]: index) -> index { +// CHECK: %[[VAL_0:.*]] = builtin.unrealized_conversion_cast %[[ARG_2]] : index to !emitc.size_t +// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[ARG_1]] : index to !emitc.size_t +// CHECK: %[[VAL_2:.*]] = builtin.unrealized_conversion_cast %[[ARG_0]] : index to !emitc.size_t +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_3:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to !emitc.size_t +// CHECK: %[[VAL_4:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.size_t +// CHECK: emitc.assign %[[VAL_3]] : !emitc.size_t to %[[VAL_4]] : !emitc.size_t +// CHECK: emitc.for %[[ARG_3:.*]] = %[[VAL_2]] to %[[VAL_1]] step %[[VAL_0]] { +// CHECK: %[[VAL_5:.*]] = builtin.unrealized_conversion_cast %[[VAL_4]] : !emitc.size_t to index +// CHECK: %[[VAL_6:.*]] = arith.addi %[[VAL_5]], %[[VAL_5]] : index +// CHECK: %[[VAL_8:.*]] = builtin.unrealized_conversion_cast %[[VAL_6]] : index to !emitc.size_t +// CHECK: emitc.assign %[[VAL_8]] : !emitc.size_t to %[[VAL_4]] : !emitc.size_t +// CHECK: } +// CHECK: %[[VAL_9:.*]] = builtin.unrealized_conversion_cast %[[VAL_4]] : !emitc.size_t to index +// CHECK: return %[[VAL_9]] : index +// CHECK: } diff --git a/mlir/test/Conversion/SCFToEmitC/if.mlir b/mlir/test/Conversion/SCFToEmitC/if.mlir index afc9abc761eb4c..eba1dda213e706 100644 --- a/mlir/test/Conversion/SCFToEmitC/if.mlir +++ b/mlir/test/Conversion/SCFToEmitC/if.mlir @@ -68,3 +68,30 @@ func.func @test_if_yield(%arg0: i1, %arg1: f32) { // CHECK-NEXT: } // CHECK-NEXT: return // CHECK-NEXT: } + + +func.func @test_if_yield_index(%arg0: i1, %arg1: f32) { + %0 = arith.constant 0 : index + %1 = arith.constant 1 : index + %x = scf.if %arg0 -> (index) { + scf.yield %0 : index + } else { + scf.yield %1 : index + } + return +} + +// CHECK: func.func @test_if_yield_index( +// CHECK-SAME: %[[ARG_0:.*]]: i1, %[[ARG_1:.*]]: f32) { +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_0:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to !emitc.size_t +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[C1]] : index to !emitc.size_t +// CHECK: %[[VAL_2:.*]] = "emitc.variable"() <{value = #emitc.opaque<"">}> : () -> !emitc.size_t +// CHECK: emitc.if %[[ARG_0]] { +// CHECK: emitc.assign %[[VAL_0]] : !emitc.size_t to %[[VAL_2]] : !emitc.size_t +// CHECK: } else { +// CHECK: emitc.assign %[[VAL_1]] : !emitc.size_t to %[[VAL_2]] : !emitc.size_t +// CHECK: } +// CHECK: return +// CHECK: } From 20a6720fc06275f17688ea90407409d28a0357d0 Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Mon, 25 Nov 2024 17:43:39 +0100 Subject: [PATCH 44/46] Add -mlir-reproducer-before-all (#402) * Add -mlir-reproducer-before-all --- mlir/include/mlir-c/Pass.h | 5 ++ mlir/include/mlir/Pass/PassManager.h | 4 + mlir/lib/Bindings/Python/Pass.cpp | 8 ++ mlir/lib/CAPI/IR/Pass.cpp | 5 ++ mlir/lib/Pass/IRPrinting.cpp | 77 +++++++++++++++++++ mlir/lib/Pass/PassManagerOptions.cpp | 7 ++ .../mlir/_mlir_libs/_mlir/passmanager.pyi | 1 + mlir/test/Pass/reproducer-before-all.mlir | 26 +++++++ 8 files changed, 133 insertions(+) create mode 100644 mlir/test/Pass/reproducer-before-all.mlir diff --git a/mlir/include/mlir-c/Pass.h b/mlir/include/mlir-c/Pass.h index 35db138305d1e2..0ab2c29bf3f777 100644 --- a/mlir/include/mlir-c/Pass.h +++ b/mlir/include/mlir-c/Pass.h @@ -78,6 +78,11 @@ mlirPassManagerRunOnOp(MlirPassManager passManager, MlirOperation op); MLIR_CAPI_EXPORTED void mlirPassManagerEnableIRPrinting(MlirPassManager passManager); +/// Enable lir-reproducer-before-all. +MLIR_CAPI_EXPORTED void +mlirPassManagerEnableReproducerBeforeAll(MlirPassManager passManager, + MlirStringRef outputDir); + /// Enable / disable verify-each. MLIR_CAPI_EXPORTED void mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable); diff --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h index d9bab431e2e0cc..e2d78823c835ad 100644 --- a/mlir/include/mlir/Pass/PassManager.h +++ b/mlir/include/mlir/Pass/PassManager.h @@ -423,6 +423,10 @@ class PassManager : public OpPassManager { llvm::StringRef printTreeDir = ".pass_manager_output", OpPrintingFlags opPrintingFlags = OpPrintingFlags()); + /// Dump a reproducer before each pass into a file in the given output + /// directory. + void enableReproducerBeforeAll(llvm::StringRef outputDir); + //===--------------------------------------------------------------------===// // Pass Timing diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp index a68421b61641f6..e19eb450634ac1 100644 --- a/mlir/lib/Bindings/Python/Pass.cpp +++ b/mlir/lib/Bindings/Python/Pass.cpp @@ -78,6 +78,14 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) { mlirPassManagerEnableIRPrinting(passManager.get()); }, "Enable mlir-print-ir-after-all.") + .def( + "enable_reproducer_before_all", + [](PyPassManager &passManager, const std::string &outputDir) { + mlirPassManagerEnableReproducerBeforeAll( + passManager.get(), + mlirStringRefCreate(outputDir.data(), outputDir.size())); + }, + "Enable mlir-reproducer-before-all.") .def( "enable_verifier", [](PyPassManager &passManager, bool enable) { diff --git a/mlir/lib/CAPI/IR/Pass.cpp b/mlir/lib/CAPI/IR/Pass.cpp index d242baae99c086..0ae054c7c639f9 100644 --- a/mlir/lib/CAPI/IR/Pass.cpp +++ b/mlir/lib/CAPI/IR/Pass.cpp @@ -48,6 +48,11 @@ void mlirPassManagerEnableIRPrinting(MlirPassManager passManager) { return unwrap(passManager)->enableIRPrinting(); } +void mlirPassManagerEnableReproducerBeforeAll(MlirPassManager passManager, + MlirStringRef outputDir) { + return unwrap(passManager)->enableReproducerBeforeAll(unwrap(outputDir)); +} + void mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable) { unwrap(passManager)->enableVerifier(enable); } diff --git a/mlir/lib/Pass/IRPrinting.cpp b/mlir/lib/Pass/IRPrinting.cpp index 9ffda6402cc07a..75f3e35b890200 100644 --- a/mlir/lib/Pass/IRPrinting.cpp +++ b/mlir/lib/Pass/IRPrinting.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "PassDetail.h" +#include "mlir/IR/AsmState.h" #include "mlir/IR/SymbolTable.h" #include "mlir/Pass/PassManager.h" #include "mlir/Support/FileUtilities.h" @@ -345,6 +346,74 @@ struct FileTreeIRPrinterConfig : public PassManager::IRPrinterConfig { llvm::DenseMap counters; }; +/// Print a pass pipeline like `builtin.module(func.func(cse))` +/// from a list of scopes and the pass. +template +void printAsPassPipeline(RangeT scopes, Pass *pass, raw_ostream &os) { + // Add pass scopes like 'builtin.module(emitc.tu(' + for (OperationName scope : scopes) + os << scope << "("; + pass->printAsTextualPipeline(os); + for (OperationName _ : scopes) + os << ")"; +} + +/// A pass instrumentation to dump the IR before each pass into +/// numbered files. +/// It includes a mlir_reproducer info to rerun the pass. +class ReproducerBeforeAll : public PassInstrumentation { +public: + ReproducerBeforeAll(mlir::StringRef outputDir) : outputDir(outputDir) {} + void runBeforePass(Pass *pass, Operation *op) override; + + std::string outputDir; + + uint32_t counter = 0; +}; + +void ReproducerBeforeAll::runBeforePass(Pass *pass, Operation *op) { + // Skip adator passes (which adopt FuncOp passes to ModuleOp pass managers). + if (isa(pass)) + return; + + llvm::SmallString<128> path(outputDir); + if (failed(createDirectoryOrPrintErr(path))) + return; + + // Open output file. + std::string fileName = + llvm::formatv("{0,0+2}_{1}.mlir", counter++, pass->getArgument()); + llvm::sys::path::append(path, fileName); + + std::string error; + std::unique_ptr file = openOutputFile(path, &error); + if (!file) { + llvm::errs() << "Error opening output file " << path << ": " << error + << "\n"; + return; + } + + SmallVector scopes; + scopes.push_back(op->getName()); + while (Operation *parentOp = op->getParentOp()) { + scopes.push_back(parentOp->getName()); + op = parentOp; + } + + std::string pipelineStr; + llvm::raw_string_ostream passOS(pipelineStr); + printAsPassPipeline(llvm::reverse(scopes), pass, passOS); + + AsmState state(op); + state.attachResourcePrinter("mlir_reproducer", + [&](Operation *op, AsmResourceBuilder &builder) { + builder.buildString("pipeline", pipelineStr); + builder.buildBool("disable_threading", true); + builder.buildBool("verify_each", true); + }); + op->print(file->os(), state); + file->keep(); +} } // namespace /// Add an instrumentation to print the IR before and after pass execution, @@ -383,3 +452,11 @@ void PassManager::enableIRPrintingToFileTree( printModuleScope, printAfterOnlyOnChange, printAfterOnlyOnFailure, opPrintingFlags, printTreeDir)); } + +/// Add an instrumentation to print the IR before and after pass execution. +void PassManager::enableReproducerBeforeAll(StringRef outputDir) { + if (getContext()->isMultithreadingEnabled()) + llvm::report_fatal_error("IR printing can't be setup on a pass-manager " + "without disabling multi-threading first."); + addInstrumentation(std::make_unique(outputDir)); +} diff --git a/mlir/lib/Pass/PassManagerOptions.cpp b/mlir/lib/Pass/PassManagerOptions.cpp index dd119a75f40696..3bdfb16336edb3 100644 --- a/mlir/lib/Pass/PassManagerOptions.cpp +++ b/mlir/lib/Pass/PassManagerOptions.cpp @@ -64,6 +64,10 @@ struct PassManagerOptions { "tree rooted at this directory. Use in conjunction with " "mlir-print-ir-* flags")}; + llvm::cl::opt reproducerBeforeAllDir{ + "mlir-reproducer-before-all", + llvm::cl::desc("Save a reproducer before each pass to this directory")}; + /// Add an IR printing instrumentation if enabled by any 'print-ir' flags. void addPrinterInstrumentation(PassManager &pm); @@ -151,6 +155,9 @@ LogicalResult mlir::applyPassManagerCLOptions(PassManager &pm) { pm.enableCrashReproducerGeneration(options->reproducerFile, options->localReproducer); + if (!options->reproducerBeforeAllDir.empty()) + pm.enableReproducerBeforeAll(options->reproducerBeforeAllDir); + // Enable statistics dumping. if (options->passStatistics) pm.enableStatistics(options->passStatisticsDisplayMode); diff --git a/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi b/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi index c072d5e0fb86f3..b91296d44a41ce 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi @@ -17,6 +17,7 @@ class PassManager: def _CAPICreate(self) -> object: ... def _testing_release(self) -> None: ... def enable_ir_printing(self) -> None: ... + def enable_reproducer_before_all(self, output_dir: str) -> None: ... def enable_verifier(self, enable: bool) -> None: ... @staticmethod def parse(pipeline: str, context: Optional[_ir.Context] = None) -> PassManager: ... diff --git a/mlir/test/Pass/reproducer-before-all.mlir b/mlir/test/Pass/reproducer-before-all.mlir new file mode 100644 index 00000000000000..bec7e257e59e67 --- /dev/null +++ b/mlir/test/Pass/reproducer-before-all.mlir @@ -0,0 +1,26 @@ +// RUN: rm -rf %t || true +// RUN: mlir-opt %s -mlir-disable-threading -mlir-reproducer-before-all=%t \ +// RUN: -pass-pipeline='builtin.module(canonicalize,cse,func.func(canonicalize))' +// RUN: FileCheck %s -input-file=%t/00_canonicalize.mlir --check-prefixes CHECK0 +// RUN: FileCheck %s -input-file=%t/01_cse.mlir --check-prefixes CHECK1 +// RUN: FileCheck %s -input-file=%t/02_canonicalize.mlir --check-prefixes CHECK2 + +builtin.module @outer { + func.func @symA() { + return + } +} + +// CHECK0: module @outer { +// CHECK0: {-# +// CHECK0-NEXT: external_resources: { +// CHECK0-NEXT: mlir_reproducer: { +// CHECK0-NEXT: pipeline: "builtin.module(canonicalize +// CHECK0-NEXT: disable_threading: true, +// CHECK0-NEXT: verify_each: true +// CHECK0-NEXT: } +// CHECK0-NEXT: } +// CHECK0-NEXT: #-} + +// CHECK1: pipeline: "builtin.module(cse +// CHECK2: pipeline: "builtin.module(func.func(canonicalize From e25d20732231ef267634a5d0dc2a652a6563f9ec Mon Sep 17 00:00:00 2001 From: lmendesp-amd Date: Wed, 27 Nov 2024 16:37:22 +0100 Subject: [PATCH 45/46] Include comments with template argument names in Cpp code from EmitC (#403) * Include comments with template arg names in Cpp code from EmitC * Apply suggestions from code review Co-authored-by: Corentin Ferry Co-authored-by: Matthias Gehre * Test for the presence of template arg names when there are no template args --------- Co-authored-by: Corentin Ferry Co-authored-by: Matthias Gehre --- mlir/include/mlir/Dialect/EmitC/IR/EmitC.td | 3 ++- mlir/lib/Dialect/EmitC/IR/EmitC.cpp | 13 ++++++++++ mlir/lib/Target/Cpp/TranslateToCpp.cpp | 26 +++++++++++++++++--- mlir/test/Dialect/EmitC/invalid_ops.mlir | 24 ++++++++++++++++++ mlir/test/Dialect/EmitC/ops.mlir | 7 ++++++ mlir/test/Target/Cpp/template_arg_names.mlir | 14 +++++++++++ 6 files changed, 83 insertions(+), 4 deletions(-) create mode 100644 mlir/test/Target/Cpp/template_arg_names.mlir diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td index 0015ff0cd88107..681c8709e574b0 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td @@ -292,6 +292,7 @@ def EmitC_CallOpaqueOp : EmitC_Op<"call_opaque", [CExpression]> { Arg:$callee, Arg, "the order of operands and further attributes">:$args, Arg, "template arguments">:$template_args, + Arg, "template argument names">:$template_arg_names, Variadic:$operands ); let results = (outs Variadic); @@ -302,7 +303,7 @@ def EmitC_CallOpaqueOp : EmitC_Op<"call_opaque", [CExpression]> { "::mlir::ValueRange":$operands, CArg<"::mlir::ArrayAttr", "{}">:$args, CArg<"::mlir::ArrayAttr", "{}">:$template_args), [{ - build($_builder, $_state, resultTypes, callee, args, template_args, + build($_builder, $_state, resultTypes, callee, args, template_args, {}, operands); }] > diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index c2fb835c2ebc34..7c8267a234368e 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -355,6 +355,19 @@ LogicalResult emitc::CallOpaqueOp::verify() { } } + if (std::optional templateArgNames = getTemplateArgNames()) { + if (std::optional templateArgsAttr = getTemplateArgs()) { + if ((*templateArgNames).size() && + (*templateArgNames).size() != (*templateArgsAttr).size()) { + return emitOpError("number of template argument names must be equal to " + "number of template arguments"); + } + } else { + return emitOpError("should not have names for template arguments if it " + "does not have template arguments"); + } + } + if (llvm::any_of(getResultTypes(), llvm::IsaPred)) { return emitOpError() << "cannot return array type"; } diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp index cc32828a0c544c..ee94967ab56a87 100644 --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -659,11 +659,31 @@ static LogicalResult printOperation(CppEmitter &emitter, return success(); }; + auto emitNamedArgs = + [&](std::tuple tuple) + -> LogicalResult { + Attribute attr = std::get<0>(tuple); + StringAttr argName = cast(std::get<1>(tuple)); + + os << "/*" << argName.str() << "=*/"; + return emitArgs(attr); + }; + if (callOpaqueOp.getTemplateArgs()) { os << "<"; - if (failed(interleaveCommaWithError(*callOpaqueOp.getTemplateArgs(), os, - emitArgs))) - return failure(); + if (callOpaqueOp.getTemplateArgNames() && + !callOpaqueOp.getTemplateArgNames()->empty()) { + if (failed(interleaveCommaWithError( + llvm::zip_equal(*callOpaqueOp.getTemplateArgs(), + *callOpaqueOp.getTemplateArgNames()), + os, emitNamedArgs))) { + return failure(); + } + } else { + if (failed(interleaveCommaWithError(*callOpaqueOp.getTemplateArgs(), os, + emitArgs))) + return failure(); + } os << ">"; } diff --git a/mlir/test/Dialect/EmitC/invalid_ops.mlir b/mlir/test/Dialect/EmitC/invalid_ops.mlir index 3b4c6046a08c5a..7f0e89b57b01af 100644 --- a/mlir/test/Dialect/EmitC/invalid_ops.mlir +++ b/mlir/test/Dialect/EmitC/invalid_ops.mlir @@ -524,3 +524,27 @@ func.func @test_verbatim(%arg0 : !emitc.ptr, %arg1 : i32) { emitc.verbatim "{a} " args %arg0, %arg1 : !emitc.ptr, i32 return } + +// ----- + +func.func @template_args_with_names(%arg0: i32) { + // expected-error @+1 {{'emitc.call_opaque' op number of template argument names must be equal to number of template arguments}} + emitc.call_opaque "kernel1"(%arg0) {template_arg_names = ["N", "P"], template_args = [42 : i32]} : (i32) -> () + return +} + +// ----- + +func.func @template_args_with_names(%arg0: i32) { + // expected-error @+1 {{'emitc.call_opaque' op number of template argument names must be equal to number of template arguments}} + emitc.call_opaque "kernel1"(%arg0) {template_arg_names = ["N"], template_args = [42 : i32, 56 : i32]} : (i32) -> () + return +} + +// ----- + +func.func @template_args_with_names(%arg0: i32) { + // expected-error @+1 {{'emitc.call_opaque' op should not have names for template arguments if it does not have template arguments}} + emitc.call_opaque "kernel1"(%arg0) {template_arg_names = ["N"]} : (i32) -> () + return +} diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir index 4e86642c2a3a95..8fe0a828c84dbf 100644 --- a/mlir/test/Dialect/EmitC/ops.mlir +++ b/mlir/test/Dialect/EmitC/ops.mlir @@ -282,3 +282,10 @@ func.func @member_access(%arg0: !emitc.opaque<"mystruct">, %arg1: !emitc.opaque< %2 = "emitc.member_of_ptr" (%arg2) {member = "a"} : (!emitc.ptr>) -> i32 return } + +func.func @template_args_with_names(%arg0: i32, %arg1: f32) { + emitc.call_opaque "kernel1"(%arg0, %arg1) {template_arg_names = ["N", "P"], template_args = [42 : i32, 56]} : (i32, f32) -> () + emitc.call_opaque "kernel2"(%arg0, %arg1) {template_arg_names = ["N"], template_args = [42 : i32]} : (i32, f32) -> () + emitc.call_opaque "kernel3"(%arg0, %arg1) {template_arg_names = [], template_args = [#emitc.opaque<"42">]} : (i32, f32) -> () + return +} diff --git a/mlir/test/Target/Cpp/template_arg_names.mlir b/mlir/test/Target/Cpp/template_arg_names.mlir new file mode 100644 index 00000000000000..f4e504b0594746 --- /dev/null +++ b/mlir/test/Target/Cpp/template_arg_names.mlir @@ -0,0 +1,14 @@ +// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s -check-prefix=CPP-DEFAULT + +// CPP-DEFAULT-LABEL: void basic +func.func @basic(%arg0: i32, %arg1: f32) { + emitc.call_opaque "kernel3"(%arg0, %arg1) : (i32, f32) -> () +// CPP-DEFAULT: kernel3( + emitc.call_opaque "kernel4"(%arg0, %arg1) {template_arg_names = ["N", "P"], template_args = [42 : i32, 56]} : (i32, f32) -> () +// CPP-DEFAULT: kernel4( + emitc.call_opaque "kernel4"(%arg0, %arg1) {template_arg_names = ["N"], template_args = [#emitc.opaque<"42">]} : (i32, f32) -> () +// CPP-DEFAULT: kernel4( + return +} + + From f3f49190caedec8ebd3d021ffc60d39cf5b5bddf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lu=C3=ADs=20Miguel=20Sousa?= Date: Thu, 28 Nov 2024 15:58:05 +0100 Subject: [PATCH 46/46] Readability: Add option to emit constants values in place * Readability: Add option to emit constants values in place, instead of producing dedicated variables. --- mlir/include/mlir/Target/Cpp/CppEmitter.h | 3 +- mlir/lib/Target/Cpp/TranslateRegistration.cpp | 8 ++- mlir/lib/Target/Cpp/TranslateToCpp.cpp | 51 ++++++++++++++++--- .../Cpp/emitc-constants-as-variables.mlir | 19 +++++++ 4 files changed, 71 insertions(+), 10 deletions(-) create mode 100644 mlir/test/Target/Cpp/emitc-constants-as-variables.mlir diff --git a/mlir/include/mlir/Target/Cpp/CppEmitter.h b/mlir/include/mlir/Target/Cpp/CppEmitter.h index d76cfc9107332e..1c7ba78eba0c93 100644 --- a/mlir/include/mlir/Target/Cpp/CppEmitter.h +++ b/mlir/include/mlir/Target/Cpp/CppEmitter.h @@ -26,7 +26,8 @@ namespace emitc { /// arguments are declared at the beginning of the function. LogicalResult translateToCpp(Operation *op, raw_ostream &os, bool declareVariablesAtTop = false, - StringRef onlyTu = ""); + StringRef onlyTu = "", + bool constantsAsVariables = true); } // namespace emitc } // namespace mlir diff --git a/mlir/lib/Target/Cpp/TranslateRegistration.cpp b/mlir/lib/Target/Cpp/TranslateRegistration.cpp index 7e2bc9ad012b38..dfe8bcb106f124 100644 --- a/mlir/lib/Target/Cpp/TranslateRegistration.cpp +++ b/mlir/lib/Target/Cpp/TranslateRegistration.cpp @@ -34,13 +34,19 @@ void registerToCppTranslation() { llvm::cl::desc("Only emit the translation unit with the matching id"), llvm::cl::init("")); + static llvm::cl::opt constantsAsVariables( + "constants-as-variables", + llvm::cl::desc("Use variables to hold the constant values"), + llvm::cl::init(true)); + TranslateFromMLIRRegistration reg( "mlir-to-cpp", "translate from mlir to cpp", [](Operation *op, raw_ostream &output) { return emitc::translateToCpp( op, output, /*declareVariablesAtTop=*/declareVariablesAtTop, - /*onlyTu=*/onlyTu); + /*onlyTu=*/onlyTu, + /*constantsAsVariables=*/constantsAsVariables); }, [](DialectRegistry ®istry) { // clang-format off diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp index ee94967ab56a87..89bd5b434e0d05 100644 --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -116,7 +116,7 @@ namespace { /// Emitter that uses dialect specific emitters to emit C++ code. struct CppEmitter { explicit CppEmitter(raw_ostream &os, bool declareVariablesAtTop, - StringRef onlyTu); + StringRef onlyTu, bool constantsAsVariables); /// Emits attribute or returns failure. LogicalResult emitAttribute(Location loc, Attribute attr); @@ -235,6 +235,10 @@ struct CppEmitter { /// Returns whether this translation unit should be emitted bool shouldEmitTu(TranslationUnitOp tu) { return tu.getId() == onlyTu; } + /// Returns whether the value of ConstantOps should be stored in variables + /// or emmited directly in their usage locations. + bool shouldUseConstantsAsVariables() { return constantsAsVariables; } + /// Get expression currently being emitted. ExpressionOp getEmittedExpression() { return emittedExpression; } @@ -265,6 +269,9 @@ struct CppEmitter { /// Only emit translation units whos id matches this value. std::string onlyTu; + /// Use variables to hold the constant values + bool constantsAsVariables; + /// Map from value to name of C++ variable that contain the name. ValueMapper valueMapper; @@ -365,6 +372,10 @@ static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation, static LogicalResult printOperation(CppEmitter &emitter, emitc::ConstantOp constantOp) { + if (!emitter.shouldUseConstantsAsVariables()) { + return success(); + } + Operation *operation = constantOp.getOperation(); Attribute value = constantOp.getValue(); @@ -1218,9 +1229,9 @@ static LogicalResult printOperation(CppEmitter &emitter, } CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop, - StringRef onlyTu) + StringRef onlyTu, bool constantsAsVariables) : os(os), declareVariablesAtTop(declareVariablesAtTop), - onlyTu(onlyTu.str()) { + onlyTu(onlyTu.str()), constantsAsVariables(constantsAsVariables) { valueInScopeCount.push(0); labelInScopeCount.push(0); } @@ -1425,8 +1436,25 @@ LogicalResult CppEmitter::emitExpression(ExpressionOp expressionOp) { } LogicalResult CppEmitter::emitOperand(Value value) { + Operation *def = value.getDefiningOp(); + if (!shouldUseConstantsAsVariables()) { + if (auto constant = dyn_cast_if_present(def)) { + os << "(("; + + if (failed(emitType(constant.getLoc(), constant.getType()))) { + return failure(); + } + os << ") "; + + if (failed(emitAttribute(constant.getLoc(), constant.getValue()))) { + return failure(); + } + os << ")"; + return success(); + } + } + if (isPartOfCurrentExpression(value)) { - Operation *def = value.getDefiningOp(); assert(def && "Expected operand to be defined by an operation"); FailureOr precedence = getOperatorPrecedence(def); if (failed(precedence)) @@ -1452,7 +1480,7 @@ LogicalResult CppEmitter::emitOperand(Value value) { return success(); } - auto expressionOp = dyn_cast_if_present(value.getDefiningOp()); + auto expressionOp = dyn_cast_if_present(def); if (expressionOp && shouldBeInlined(expressionOp)) return emitExpression(expressionOp); @@ -1671,7 +1699,13 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) { trailingSemicolon = false; } - os << (trailingSemicolon ? ";\n" : "\n"); + bool trailingNewline = true; + if (!shouldUseConstantsAsVariables() && isa(op)) { + trailingSemicolon = false; + trailingNewline = false; + } + + os << (trailingSemicolon ? ";" : "") << (trailingNewline ? "\n" : ""); return success(); } @@ -1837,7 +1871,8 @@ LogicalResult CppEmitter::emitTupleType(Location loc, ArrayRef types) { LogicalResult emitc::translateToCpp(Operation *op, raw_ostream &os, bool declareVariablesAtTop, - StringRef onlyTu) { - CppEmitter emitter(os, declareVariablesAtTop, onlyTu); + StringRef onlyTu, + bool constantsAsVariables) { + CppEmitter emitter(os, declareVariablesAtTop, onlyTu, constantsAsVariables); return emitter.emitOperation(*op, /*trailingSemicolon=*/false); } diff --git a/mlir/test/Target/Cpp/emitc-constants-as-variables.mlir b/mlir/test/Target/Cpp/emitc-constants-as-variables.mlir new file mode 100644 index 00000000000000..5774bdc47308ff --- /dev/null +++ b/mlir/test/Target/Cpp/emitc-constants-as-variables.mlir @@ -0,0 +1,19 @@ +// RUN: mlir-translate -mlir-to-cpp -constants-as-variables=false %s | FileCheck %s -check-prefix=CPP-DEFAULT + +func.func @test() { + %start = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t + %stop = "emitc.constant"() <{value = 10 : index}> : () -> !emitc.size_t + %step = "emitc.constant"() <{value = 1 : index}> : () -> !emitc.size_t + + emitc.for %iter = %start to %stop step %step { + emitc.yield + } + + return +} + +// CPP-DEFAULT: void test() { +// CPP-DEFAULT-NEXT: for (size_t v1 = ((size_t) 0); v1 < ((size_t) 10); v1 += ((size_t) 1)) { +// CPP-DEFAULT-NEXT: } +// CPP-DEFAULT-NEXT: return; +// CPP-DEFAULT-NEXT: }