From 09754c74a67c070f3afdd8ad7672b86654ab4483 Mon Sep 17 00:00:00 2001 From: Zenithal Date: Sat, 21 Dec 2024 07:44:20 +0000 Subject: [PATCH] ASM: suggest OpResult name for BGV/CKKS/Openfhe --- lib/Dialect/BGV/IR/BGVOps.td | 13 ++++++- lib/Dialect/CKKS/IR/CKKSOps.td | 12 ++++++- lib/Dialect/LWE/IR/LWEDialect.cpp | 9 +++++ lib/Dialect/LWE/IR/LWEOps.td | 14 +++++++- lib/Dialect/LWE/IR/LWETypes.h | 14 ++++++++ lib/Dialect/Openfhe/IR/OpenfheDialect.cpp | 9 +++++ lib/Dialect/Openfhe/IR/OpenfheOps.td | 12 ++++++- lib/Dialect/Openfhe/IR/OpenfheTypes.h | 14 ++++++++ lib/Utils/BUILD | 12 +++++++ lib/Utils/OpAsmInterfaceHelper.cpp | 41 +++++++++++++++++++++++ lib/Utils/OpAsmInterfaceHelper.h | 17 ++++++++++ tools/BUILD | 3 ++ tools/heir-opt.cpp | 21 +++++++++++- 13 files changed, 186 insertions(+), 5 deletions(-) create mode 100644 lib/Utils/OpAsmInterfaceHelper.cpp create mode 100644 lib/Utils/OpAsmInterfaceHelper.h diff --git a/lib/Dialect/BGV/IR/BGVOps.td b/lib/Dialect/BGV/IR/BGVOps.td index 9521dac83..2eb964b9a 100644 --- a/lib/Dialect/BGV/IR/BGVOps.td +++ b/lib/Dialect/BGV/IR/BGVOps.td @@ -11,13 +11,24 @@ include "lib/Dialect/LWE/IR/LWETraits.td" include "lib/Dialect/Polynomial/IR/PolynomialAttributes.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/IR/BuiltinAttributes.td" +include "mlir/IR/OpAsmInterface.td" class BGV_Op traits = []> : - Op { + Op { let cppNamespace = "::mlir::heir::bgv"; let assemblyFormat = [{ operands attr-dict `:` functional-type(operands, results) }]; + + let extraClassDeclaration = [{ + //===------------------------------------------------------------------===// + // OpAsmOpInterface Methods + //===------------------------------------------------------------------===// + + void getAsmResultNames(::mlir::OpAsmSetValueNameFn setNameFn) { + ::mlir::heir::getAsmResultNames(*this, setNameFn); + } + }]; } class BGV_CiphertextPlaintextOp traits = diff --git a/lib/Dialect/CKKS/IR/CKKSOps.td b/lib/Dialect/CKKS/IR/CKKSOps.td index 06ccb20c9..20bb8baf1 100644 --- a/lib/Dialect/CKKS/IR/CKKSOps.td +++ b/lib/Dialect/CKKS/IR/CKKSOps.td @@ -13,11 +13,21 @@ include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/IR/BuiltinAttributes.td" class CKKS_Op traits = []> : - Op { + Op { let cppNamespace = "::mlir::heir::ckks"; let assemblyFormat = [{ operands attr-dict `:` functional-type(operands, results) }]; + + let extraClassDeclaration = [{ + //===------------------------------------------------------------------===// + // OpAsmOpInterface Methods + //===------------------------------------------------------------------===// + + void getAsmResultNames(::mlir::OpAsmSetValueNameFn setNameFn) { + ::mlir::heir::getAsmResultNames(*this, setNameFn); + } + }]; } class CKKS_CiphertextPlaintextOp traits = diff --git a/lib/Dialect/LWE/IR/LWEDialect.cpp b/lib/Dialect/LWE/IR/LWEDialect.cpp index 46f0cd11f..70706b143 100644 --- a/lib/Dialect/LWE/IR/LWEDialect.cpp +++ b/lib/Dialect/LWE/IR/LWEDialect.cpp @@ -39,6 +39,15 @@ namespace mlir { namespace heir { namespace lwe { +std::string lweSuggestNameForType(Type type) { + return llvm::TypeSwitch(type) + .Case([&](Type) { return "ct"; }) + .Case([&](Type) { return "pt"; }) + .Case([&](Type) { return "pkey"; }) + .Case([&](Type) { return "skey"; }) + .Default([&](Type) { return ""; }); // use the default numbering. +} + class LWEOpAsmDialectInterface : public OpAsmDialectInterface { public: using OpAsmDialectInterface::OpAsmDialectInterface; diff --git a/lib/Dialect/LWE/IR/LWEOps.td b/lib/Dialect/LWE/IR/LWEOps.td index 4a036c732..470728899 100644 --- a/lib/Dialect/LWE/IR/LWEOps.td +++ b/lib/Dialect/LWE/IR/LWEOps.td @@ -10,6 +10,7 @@ include "mlir/IR/BuiltinAttributeInterfaces.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/IR/CommonAttrConstraints.td" +include "mlir/IR/OpAsmInterface.td" class HasEncoding< string encodingHolder, @@ -85,11 +86,22 @@ class KeyAndCiphertextMatch< // LWE Operations are always Pure by design class LWE_Op traits = []> : - Op { + Op { let cppNamespace = "::mlir::heir::lwe"; let assemblyFormat = [{ operands attr-dict `:` functional-type(operands, results) }]; + + + let extraClassDeclaration = [{ + //===------------------------------------------------------------------===// + // OpAsmOpInterface Methods + //===------------------------------------------------------------------===// + + void getAsmResultNames(::mlir::OpAsmSetValueNameFn setNameFn) { + ::mlir::heir::getAsmResultNames(*this, setNameFn); + } + }]; } class LWE_BinOp traits = []> : diff --git a/lib/Dialect/LWE/IR/LWETypes.h b/lib/Dialect/LWE/IR/LWETypes.h index 8ef4dec1f..cb4f182bd 100644 --- a/lib/Dialect/LWE/IR/LWETypes.h +++ b/lib/Dialect/LWE/IR/LWETypes.h @@ -7,4 +7,18 @@ #define GET_TYPEDEF_CLASSES #include "lib/Dialect/LWE/IR/LWETypes.h.inc" +namespace mlir { +namespace heir { + +// just declaration here +void getAsmResultNames(Operation *op, ::mlir::OpAsmSetValueNameFn setNameFn); + +namespace lwe { + +std::string lweSuggestNameForType(Type type); + +} // namespace lwe +} // namespace heir +} // namespace mlir + #endif // LIB_DIALECT_LWE_IR_LWETYPES_H_ diff --git a/lib/Dialect/Openfhe/IR/OpenfheDialect.cpp b/lib/Dialect/Openfhe/IR/OpenfheDialect.cpp index 2a996f4d0..0a3caf9a4 100644 --- a/lib/Dialect/Openfhe/IR/OpenfheDialect.cpp +++ b/lib/Dialect/Openfhe/IR/OpenfheDialect.cpp @@ -15,6 +15,15 @@ namespace mlir { namespace heir { namespace openfhe { +std::string openfheSuggestNameForType(Type type) { + return llvm::TypeSwitch(type) + .Case([&](Type) { return "cc"; }) + .Case([&](Type) { return "params"; }) + .Case([&](Type) { return "pkey"; }) + .Case([&](Type) { return "skey"; }) + .Default([&](Type) { return ""; }); // use the default numbering. +} + void OpenfheDialect::initialize() { addTypes< #define GET_TYPEDEF_LIST diff --git a/lib/Dialect/Openfhe/IR/OpenfheOps.td b/lib/Dialect/Openfhe/IR/OpenfheOps.td index 95ba52854..dec66a948 100644 --- a/lib/Dialect/Openfhe/IR/OpenfheOps.td +++ b/lib/Dialect/Openfhe/IR/OpenfheOps.td @@ -11,11 +11,21 @@ include "mlir/IR/OpBase.td" include "mlir/Interfaces/SideEffectInterfaces.td" class Openfhe_Op traits = []> : - Op { + Op { let assemblyFormat = [{ operands attr-dict `:` functional-type(operands, results) }]; let cppNamespace = "::mlir::heir::openfhe"; + + let extraClassDeclaration = [{ + //===------------------------------------------------------------------===// + // OpAsmOpInterface Methods + //===------------------------------------------------------------------===// + + void getAsmResultNames(::mlir::OpAsmSetValueNameFn setNameFn) { + ::mlir::heir::getAsmResultNames(*this, setNameFn); + } + }]; } class Openfhe_UnaryTypeSwitchOp traits = []> diff --git a/lib/Dialect/Openfhe/IR/OpenfheTypes.h b/lib/Dialect/Openfhe/IR/OpenfheTypes.h index beb488bce..03d49df86 100644 --- a/lib/Dialect/Openfhe/IR/OpenfheTypes.h +++ b/lib/Dialect/Openfhe/IR/OpenfheTypes.h @@ -6,4 +6,18 @@ #define GET_TYPEDEF_CLASSES #include "lib/Dialect/Openfhe/IR/OpenfheTypes.h.inc" +namespace mlir { +namespace heir { + +// just declaration here +void getAsmResultNames(Operation *op, ::mlir::OpAsmSetValueNameFn setNameFn); + +namespace openfhe { + +std::string openfheSuggestNameForType(Type type); + +} // namespace openfhe +} // namespace heir +} // namespace mlir + #endif // LIB_DIALECT_OPENFHE_IR_OPENFHETYPES_H_ diff --git a/lib/Utils/BUILD b/lib/Utils/BUILD index d74d96bb8..f798e1142 100644 --- a/lib/Utils/BUILD +++ b/lib/Utils/BUILD @@ -47,3 +47,15 @@ cc_library( "@llvm-project//mlir:Support", ], ) + +cc_library( + name = "OpAsmInterfaceHelper", + srcs = ["OpAsmInterfaceHelper.cpp"], + hdrs = ["OpAsmInterfaceHelper.h"], + deps = [ + "@heir//lib/Dialect/LWE/IR:Dialect", + "@heir//lib/Dialect/Openfhe/IR:Dialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) diff --git a/lib/Utils/OpAsmInterfaceHelper.cpp b/lib/Utils/OpAsmInterfaceHelper.cpp new file mode 100644 index 000000000..2522158e4 --- /dev/null +++ b/lib/Utils/OpAsmInterfaceHelper.cpp @@ -0,0 +1,41 @@ +#include "lib/Utils/OpAsmInterfaceHelper.h" + +#include "lib/Dialect/LWE/IR/LWETypes.h" +#include "lib/Dialect/Openfhe/IR/OpenfheTypes.h" + +namespace mlir { +namespace heir { + +void suggestNameForValue(Value value, ::mlir::OpAsmSetValueNameFn setNameFn) { + auto suggestFunctions = { + lwe::lweSuggestNameForType, + openfhe::openfheSuggestNameForType, + }; + // only the first suggestion is used + // if no suggestion then do nothing + for (auto suggest : suggestFunctions) { + auto suggested = suggest(value.getType()); + if (!suggested.empty()) { + setNameFn(value, suggested); + return; + } + } +} + +void getAsmBlockArgumentNames(Operation* op, Region& region, + ::mlir::OpAsmSetValueNameFn setNameFn) { + for (auto& block : region) { + for (auto arg : block.getArguments()) { + suggestNameForValue(arg, setNameFn); + } + } +} + +void getAsmResultNames(Operation* op, ::mlir::OpAsmSetValueNameFn setNameFn) { + for (auto result : op->getResults()) { + suggestNameForValue(result, setNameFn); + } +} + +} // namespace heir +} // namespace mlir diff --git a/lib/Utils/OpAsmInterfaceHelper.h b/lib/Utils/OpAsmInterfaceHelper.h new file mode 100644 index 000000000..67c2f46be --- /dev/null +++ b/lib/Utils/OpAsmInterfaceHelper.h @@ -0,0 +1,17 @@ +#ifndef LIB_UTILS_OPASMINTERFACEHELPER_ +#define LIB_UTILS_OPASMINTERFACEHELPER_ + +#include "mlir/include/mlir/IR/DialectImplementation.h" // from @llvm-project + +namespace mlir { +namespace heir { + +void getAsmResultNames(Operation *op, ::mlir::OpAsmSetValueNameFn setNameFn); + +void getAsmBlockArgumentNames(Operation *op, Region ®ion, + ::mlir::OpAsmSetValueNameFn setNameFn); + +} // namespace heir +} // namespace mlir + +#endif // LIB_UTILS_OPASMINTERFACEHELPER_ diff --git a/tools/BUILD b/tools/BUILD index cb835caec..400a4fd4a 100644 --- a/tools/BUILD +++ b/tools/BUILD @@ -106,6 +106,7 @@ cc_binary( "@heir//lib/Transforms/StraightLineVectorizer", "@heir//lib/Transforms/TensorToScalars", "@heir//lib/Transforms/UnusedMemRef", + "@heir//lib/Utils:OpAsmInterfaceHelper", "@llvm-project//llvm:Support", "@llvm-project//mlir:AffineDialect", "@llvm-project//mlir:AffineToStandard", @@ -171,6 +172,7 @@ cc_binary( "@heir//lib/Target/TfheRust:TfheRustEmitter", "@heir//lib/Target/TfheRustBool:TfheRustBoolEmitter", "@heir//lib/Target/Verilog:VerilogEmitter", + "@heir//lib/Utils:OpAsmInterfaceHelper", "@llvm-project//llvm:Support", "@llvm-project//mlir:TranslateLib", ], @@ -198,6 +200,7 @@ cc_binary( "@heir//lib/Dialect/TensorExt/IR:Dialect", "@heir//lib/Dialect/TfheRust/IR:Dialect", "@heir//lib/Dialect/TfheRustBool/IR:Dialect", + "@heir//lib/Utils:OpAsmInterfaceHelper", "@llvm-project//mlir:AffineDialect", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:BufferizationDialect", diff --git a/tools/heir-opt.cpp b/tools/heir-opt.cpp index dc51c1b19..8a5d02373 100644 --- a/tools/heir-opt.cpp +++ b/tools/heir-opt.cpp @@ -67,6 +67,7 @@ #include "lib/Transforms/StraightLineVectorizer/StraightLineVectorizer.h" #include "lib/Transforms/TensorToScalars/TensorToScalars.h" #include "lib/Transforms/UnusedMemRef/UnusedMemRef.h" +#include "lib/Utils/OpAsmInterfaceHelper.h" #include "mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h" // from @llvm-project #include "mlir/include/mlir/Conversion/ArithToLLVM/ArithToLLVM.h" // from @llvm-project #include "mlir/include/mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" // from @llvm-project @@ -119,7 +120,25 @@ using namespace mlir; using namespace tosa; using namespace heir; -using mlir::func::FuncOp; + +// hack here: another template specialization for FuncOp +// expect linker to pick this one +// +// This is really unsafe as it depends on ::mlir::detail, +// which is not a expected behavior. However, the current +// OpAsmOpInterface declaration in MLIR already has a default implementation +// so we can not provide another implementation for it (MLIR does not +// support it) +// +// for detail, check #1219 +template <> +void ::mlir::detail::OpAsmOpInterfaceInterfaceTraits:: + Model::getAsmBlockArgumentNames( + mlir::detail::OpAsmOpInterfaceInterfaceTraits::Concept const *, + mlir::Operation *op, mlir::Region ®ion, + ::mlir::OpAsmSetValueNameFn setNameFn) { + ::mlir::heir::getAsmBlockArgumentNames(op, region, setNameFn); +} int main(int argc, char **argv) { mlir::DialectRegistry registry;