Skip to content

Commit

Permalink
ASM: suggest OpResult name for BGV/CKKS/Openfhe
Browse files Browse the repository at this point in the history
  • Loading branch information
ZenithalHourlyRate committed Dec 21, 2024
1 parent 276bc7c commit 09754c7
Show file tree
Hide file tree
Showing 13 changed files with 186 additions and 5 deletions.
13 changes: 12 additions & 1 deletion lib/Dialect/BGV/IR/BGVOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,24 @@ include "lib/Dialect/LWE/IR/LWETraits.td"
include "lib/Dialect/Polynomial/IR/PolynomialAttributes.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/BuiltinAttributes.td"
include "mlir/IR/OpAsmInterface.td"

class BGV_Op<string mnemonic, list<Trait> traits = []> :
Op<BGV_Dialect, mnemonic, traits> {
Op<BGV_Dialect, mnemonic, traits # [OpAsmOpInterface]> {
let cppNamespace = "::mlir::heir::bgv";
let assemblyFormat = [{
operands attr-dict `:` functional-type(operands, results)
}];

let extraClassDeclaration = [{
//===------------------------------------------------------------------===//
// OpAsmOpInterface Methods
//===------------------------------------------------------------------===//

void getAsmResultNames(::mlir::OpAsmSetValueNameFn setNameFn) {
::mlir::heir::getAsmResultNames(*this, setNameFn);
}
}];
}

class BGV_CiphertextPlaintextOp<string mnemonic, list<Trait> traits =
Expand Down
12 changes: 11 additions & 1 deletion lib/Dialect/CKKS/IR/CKKSOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,21 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/BuiltinAttributes.td"

class CKKS_Op<string mnemonic, list<Trait> traits = []> :
Op<CKKS_Dialect, mnemonic, traits> {
Op<CKKS_Dialect, mnemonic, traits # [OpAsmOpInterface]> {
let cppNamespace = "::mlir::heir::ckks";
let assemblyFormat = [{
operands attr-dict `:` functional-type(operands, results)
}];

let extraClassDeclaration = [{
//===------------------------------------------------------------------===//
// OpAsmOpInterface Methods
//===------------------------------------------------------------------===//

void getAsmResultNames(::mlir::OpAsmSetValueNameFn setNameFn) {
::mlir::heir::getAsmResultNames(*this, setNameFn);
}
}];
}

class CKKS_CiphertextPlaintextOp<string mnemonic, list<Trait> traits =
Expand Down
9 changes: 9 additions & 0 deletions lib/Dialect/LWE/IR/LWEDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,15 @@ namespace mlir {
namespace heir {
namespace lwe {

std::string lweSuggestNameForType(Type type) {
return llvm::TypeSwitch<Type, std::string>(type)
.Case<NewLWECiphertextType>([&](Type) { return "ct"; })
.Case<NewLWEPlaintextType>([&](Type) { return "pt"; })
.Case<NewLWEPublicKeyType>([&](Type) { return "pkey"; })
.Case<NewLWESecretKeyType>([&](Type) { return "skey"; })
.Default([&](Type) { return ""; }); // use the default numbering.
}

class LWEOpAsmDialectInterface : public OpAsmDialectInterface {
public:
using OpAsmDialectInterface::OpAsmDialectInterface;
Expand Down
14 changes: 13 additions & 1 deletion lib/Dialect/LWE/IR/LWEOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ include "mlir/IR/BuiltinAttributeInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/CommonAttrConstraints.td"
include "mlir/IR/OpAsmInterface.td"

class HasEncoding<
string encodingHolder,
Expand Down Expand Up @@ -85,11 +86,22 @@ class KeyAndCiphertextMatch<

// LWE Operations are always Pure by design
class LWE_Op<string mnemonic, list<Trait> traits = []> :
Op<LWE_Dialect, mnemonic, traits # [Pure]> {
Op<LWE_Dialect, mnemonic, traits # [Pure, OpAsmOpInterface]> {
let cppNamespace = "::mlir::heir::lwe";
let assemblyFormat = [{
operands attr-dict `:` functional-type(operands, results)
}];


let extraClassDeclaration = [{
//===------------------------------------------------------------------===//
// OpAsmOpInterface Methods
//===------------------------------------------------------------------===//

void getAsmResultNames(::mlir::OpAsmSetValueNameFn setNameFn) {
::mlir::heir::getAsmResultNames(*this, setNameFn);
}
}];
}

class LWE_BinOp<string mnemonic, list<Trait> traits = []> :
Expand Down
14 changes: 14 additions & 0 deletions lib/Dialect/LWE/IR/LWETypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,18 @@
#define GET_TYPEDEF_CLASSES
#include "lib/Dialect/LWE/IR/LWETypes.h.inc"

namespace mlir {
namespace heir {

// just declaration here
void getAsmResultNames(Operation *op, ::mlir::OpAsmSetValueNameFn setNameFn);

namespace lwe {

std::string lweSuggestNameForType(Type type);

} // namespace lwe
} // namespace heir
} // namespace mlir

#endif // LIB_DIALECT_LWE_IR_LWETYPES_H_
9 changes: 9 additions & 0 deletions lib/Dialect/Openfhe/IR/OpenfheDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,15 @@ namespace mlir {
namespace heir {
namespace openfhe {

std::string openfheSuggestNameForType(Type type) {
return llvm::TypeSwitch<Type, std::string>(type)
.Case<CryptoContextType>([&](Type) { return "cc"; })
.Case<CCParamsType>([&](Type) { return "params"; })
.Case<PublicKeyType>([&](Type) { return "pkey"; })
.Case<PrivateKeyType>([&](Type) { return "skey"; })
.Default([&](Type) { return ""; }); // use the default numbering.
}

void OpenfheDialect::initialize() {
addTypes<
#define GET_TYPEDEF_LIST
Expand Down
12 changes: 11 additions & 1 deletion lib/Dialect/Openfhe/IR/OpenfheOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,21 @@ include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"

class Openfhe_Op<string mnemonic, list<Trait> traits = []> :
Op<Openfhe_Dialect, mnemonic, traits> {
Op<Openfhe_Dialect, mnemonic, traits # [OpAsmOpInterface]> {
let assemblyFormat = [{
operands attr-dict `:` functional-type(operands, results)
}];
let cppNamespace = "::mlir::heir::openfhe";

let extraClassDeclaration = [{
//===------------------------------------------------------------------===//
// OpAsmOpInterface Methods
//===------------------------------------------------------------------===//

void getAsmResultNames(::mlir::OpAsmSetValueNameFn setNameFn) {
::mlir::heir::getAsmResultNames(*this, setNameFn);
}
}];
}

class Openfhe_UnaryTypeSwitchOp<string mnemonic, list<Trait> traits = []>
Expand Down
14 changes: 14 additions & 0 deletions lib/Dialect/Openfhe/IR/OpenfheTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,18 @@
#define GET_TYPEDEF_CLASSES
#include "lib/Dialect/Openfhe/IR/OpenfheTypes.h.inc"

namespace mlir {
namespace heir {

// just declaration here
void getAsmResultNames(Operation *op, ::mlir::OpAsmSetValueNameFn setNameFn);

namespace openfhe {

std::string openfheSuggestNameForType(Type type);

} // namespace openfhe
} // namespace heir
} // namespace mlir

#endif // LIB_DIALECT_OPENFHE_IR_OPENFHETYPES_H_
12 changes: 12 additions & 0 deletions lib/Utils/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,15 @@ cc_library(
"@llvm-project//mlir:Support",
],
)

cc_library(
name = "OpAsmInterfaceHelper",
srcs = ["OpAsmInterfaceHelper.cpp"],
hdrs = ["OpAsmInterfaceHelper.h"],
deps = [
"@heir//lib/Dialect/LWE/IR:Dialect",
"@heir//lib/Dialect/Openfhe/IR:Dialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
],
)
41 changes: 41 additions & 0 deletions lib/Utils/OpAsmInterfaceHelper.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#include "lib/Utils/OpAsmInterfaceHelper.h"

#include "lib/Dialect/LWE/IR/LWETypes.h"
#include "lib/Dialect/Openfhe/IR/OpenfheTypes.h"

namespace mlir {
namespace heir {

void suggestNameForValue(Value value, ::mlir::OpAsmSetValueNameFn setNameFn) {
auto suggestFunctions = {
lwe::lweSuggestNameForType,
openfhe::openfheSuggestNameForType,
};
// only the first suggestion is used
// if no suggestion then do nothing
for (auto suggest : suggestFunctions) {
auto suggested = suggest(value.getType());
if (!suggested.empty()) {
setNameFn(value, suggested);
return;
}
}
}

void getAsmBlockArgumentNames(Operation* op, Region& region,
::mlir::OpAsmSetValueNameFn setNameFn) {
for (auto& block : region) {
for (auto arg : block.getArguments()) {
suggestNameForValue(arg, setNameFn);
}
}
}

void getAsmResultNames(Operation* op, ::mlir::OpAsmSetValueNameFn setNameFn) {
for (auto result : op->getResults()) {
suggestNameForValue(result, setNameFn);
}
}

} // namespace heir
} // namespace mlir
17 changes: 17 additions & 0 deletions lib/Utils/OpAsmInterfaceHelper.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#ifndef LIB_UTILS_OPASMINTERFACEHELPER_
#define LIB_UTILS_OPASMINTERFACEHELPER_

#include "mlir/include/mlir/IR/DialectImplementation.h" // from @llvm-project

namespace mlir {
namespace heir {

void getAsmResultNames(Operation *op, ::mlir::OpAsmSetValueNameFn setNameFn);

void getAsmBlockArgumentNames(Operation *op, Region &region,
::mlir::OpAsmSetValueNameFn setNameFn);

} // namespace heir
} // namespace mlir

#endif // LIB_UTILS_OPASMINTERFACEHELPER_
3 changes: 3 additions & 0 deletions tools/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ cc_binary(
"@heir//lib/Transforms/StraightLineVectorizer",
"@heir//lib/Transforms/TensorToScalars",
"@heir//lib/Transforms/UnusedMemRef",
"@heir//lib/Utils:OpAsmInterfaceHelper",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AffineDialect",
"@llvm-project//mlir:AffineToStandard",
Expand Down Expand Up @@ -171,6 +172,7 @@ cc_binary(
"@heir//lib/Target/TfheRust:TfheRustEmitter",
"@heir//lib/Target/TfheRustBool:TfheRustBoolEmitter",
"@heir//lib/Target/Verilog:VerilogEmitter",
"@heir//lib/Utils:OpAsmInterfaceHelper",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:TranslateLib",
],
Expand Down Expand Up @@ -198,6 +200,7 @@ cc_binary(
"@heir//lib/Dialect/TensorExt/IR:Dialect",
"@heir//lib/Dialect/TfheRust/IR:Dialect",
"@heir//lib/Dialect/TfheRustBool/IR:Dialect",
"@heir//lib/Utils:OpAsmInterfaceHelper",
"@llvm-project//mlir:AffineDialect",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:BufferizationDialect",
Expand Down
21 changes: 20 additions & 1 deletion tools/heir-opt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
#include "lib/Transforms/StraightLineVectorizer/StraightLineVectorizer.h"
#include "lib/Transforms/TensorToScalars/TensorToScalars.h"
#include "lib/Transforms/UnusedMemRef/UnusedMemRef.h"
#include "lib/Utils/OpAsmInterfaceHelper.h"
#include "mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h" // from @llvm-project
#include "mlir/include/mlir/Conversion/ArithToLLVM/ArithToLLVM.h" // from @llvm-project
#include "mlir/include/mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" // from @llvm-project
Expand Down Expand Up @@ -119,7 +120,25 @@
using namespace mlir;
using namespace tosa;
using namespace heir;
using mlir::func::FuncOp;

// hack here: another template specialization for FuncOp
// expect linker to pick this one
//
// This is really unsafe as it depends on ::mlir::detail,
// which is not a expected behavior. However, the current
// OpAsmOpInterface declaration in MLIR already has a default implementation
// so we can not provide another implementation for it (MLIR does not
// support it)
//
// for detail, check #1219
template <>
void ::mlir::detail::OpAsmOpInterfaceInterfaceTraits::
Model<mlir::func::FuncOp>::getAsmBlockArgumentNames(
mlir::detail::OpAsmOpInterfaceInterfaceTraits::Concept const *,
mlir::Operation *op, mlir::Region &region,
::mlir::OpAsmSetValueNameFn setNameFn) {
::mlir::heir::getAsmBlockArgumentNames(op, region, setNameFn);
}

int main(int argc, char **argv) {
mlir::DialectRegistry registry;
Expand Down

0 comments on commit 09754c7

Please sign in to comment.