Skip to content

Commit

Permalink
MLIR: post optimization pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Jan 1, 2025
1 parent 7cf9e90 commit 727a2d0
Show file tree
Hide file tree
Showing 9 changed files with 77 additions and 22 deletions.
1 change: 1 addition & 0 deletions enzyme/.bazelversion
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
6.5.0
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ class AutoDiffCallFwd
fn, RetActivity, ArgActivity, gutils->TA, returnPrimal, mode,
freeMemory, width,
/* addedType */ nullptr, type_args, volatile_args,
/* augmented */ nullptr);
/* augmented */ nullptr,
gutils->postpasses);

SmallVector<Value> fwdArguments;

Expand Down Expand Up @@ -173,7 +174,7 @@ class AutoDiffCallRev
auto revFn = gutils->Logic.CreateReverseDiff(
fn, RetActivity, ArgActivity, gutils->TA, returnPrimal, returnShadow,
mode, freeMemory, width, /*addedType*/ nullptr, type_args,
volatile_args, /*augmented*/ nullptr);
volatile_args, /*augmented*/ nullptr, gutils->postpasses);

SmallVector<Value> revArguments;

Expand Down
24 changes: 17 additions & 7 deletions enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ FunctionOpInterface mlir::enzyme::MEnzymeLogic::CreateForwardDiff(
std::vector<DIFFE_TYPE> ArgActivity, MTypeAnalysis &TA,
std::vector<bool> returnPrimals, DerivativeMode mode, bool freeMemory,
size_t width, mlir::Type addedType, MFnTypeInfo type_args,
std::vector<bool> volatile_args, void *augmented) {
std::vector<bool> volatile_args, void *augmented, llvm::StringRef postpasses) {
if (fn.getFunctionBody().empty()) {
llvm::errs() << fn << "\n";
llvm_unreachable("Differentiating empty function");
Expand All @@ -105,7 +105,7 @@ FunctionOpInterface mlir::enzyme::MEnzymeLogic::CreateForwardDiff(
auto gutils = MDiffeGradientUtils::CreateFromClone(
*this, mode, width, fn, TA, type_args, returnPrimalsP, returnShadowsP,
RetActivity, ArgActivity, addedType,
/*omp*/ false);
/*omp*/ false, postpasses);
ForwardCachedFunctions[tup] = gutils->newFunc;

insert_or_assign2<MForwardCacheKey, FunctionOpInterface>(
Expand Down Expand Up @@ -195,10 +195,20 @@ FunctionOpInterface mlir::enzyme::MEnzymeLogic::CreateForwardDiff(
if (!valid)
return nullptr;

// if (PostOpt)
// PPC.optimizeIntermediate(nf);
// if (EnzymePrint) {
// llvm::errs() << nf << "\n";
//}
if (postpasses != "") {
mlir::PassManager pm(nf->getContext());
std::string error_message;
//llvm::raw_string_ostream error_stream(error_message);
mlir::LogicalResult result =
mlir::parsePassPipeline(postpasses, pm);
if (mlir::failed(result)) {
return nullptr;
}

if (!mlir::succeeded(pm.run(nf))) {
return nullptr;
}
}

return nf;
}
6 changes: 4 additions & 2 deletions enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h
Original file line number Diff line number Diff line change
Expand Up @@ -196,14 +196,16 @@ class MEnzymeLogic {
std::vector<bool> returnPrimals, DerivativeMode mode,
bool freeMemory, size_t width, mlir::Type addedType,
MFnTypeInfo type_args, std::vector<bool> volatile_args,
void *augmented);
void *augmented, llvm::StringRef postpasses);

FunctionOpInterface CreateReverseDiff(
FunctionOpInterface fn, std::vector<DIFFE_TYPE> retType,
std::vector<DIFFE_TYPE> constants, MTypeAnalysis &TA,
std::vector<bool> returnPrimals, std::vector<bool> returnShadows,
DerivativeMode mode, bool freeMemory, size_t width, mlir::Type addedType,
MFnTypeInfo type_args, std::vector<bool> volatile_args, void *augmented);
MFnTypeInfo type_args, std::vector<bool> volatile_args, void *augmented,
llvm::StringRef postpasses);

void
initializeShadowValues(SmallVector<mlir::Block *> &dominatorToposortBlocks,
MGradientUtilsReverse *gutils);
Expand Down
20 changes: 18 additions & 2 deletions enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,8 @@ FunctionOpInterface MEnzymeLogic::CreateReverseDiff(
std::vector<DIFFE_TYPE> constants, MTypeAnalysis &TA,
std::vector<bool> returnPrimals, std::vector<bool> returnShadows,
DerivativeMode mode, bool freeMemory, size_t width, mlir::Type addedType,
MFnTypeInfo type_args, std::vector<bool> volatile_args, void *augmented) {
MFnTypeInfo type_args, std::vector<bool> volatile_args, void *augmented,
llvm::StringRef postpasses) {

if (fn.getFunctionBody().empty()) {
llvm::errs() << fn << "\n";
Expand Down Expand Up @@ -214,7 +215,7 @@ FunctionOpInterface MEnzymeLogic::CreateReverseDiff(

MGradientUtilsReverse *gutils = MGradientUtilsReverse::CreateFromClone(
*this, mode, width, fn, TA, type_args, returnPrimalsP, returnShadowsP,
retType, constants, addedType);
retType, constants, addedType, postpasses);

ReverseCachedFunctions[tup] = gutils->newFunc;

Expand Down Expand Up @@ -254,5 +255,20 @@ FunctionOpInterface MEnzymeLogic::CreateReverseDiff(
if (!res.succeeded())
return nullptr;

if (postpasses != "") {
mlir::PassManager pm(nf->getContext());
std::string error_message;
//llvm::raw_string_ostream error_stream(error_message);
mlir::LogicalResult result =
mlir::parsePassPipeline(postpasses, pm);
if (mlir::failed(result)) {
return nullptr;
}

if (!mlir::succeeded(pm.run(nf))) {
return nullptr;
}
}

return nf;
}
4 changes: 2 additions & 2 deletions enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,13 @@ mlir::enzyme::MGradientUtils::MGradientUtils(
ArrayRef<DIFFE_TYPE> ReturnActivity, ArrayRef<DIFFE_TYPE> ArgDiffeTypes_,
IRMapping &originalToNewFn_,
std::map<Operation *, Operation *> &originalToNewFnOps_,
DerivativeMode mode, unsigned width, bool omp)
DerivativeMode mode, unsigned width, bool omp, llvm::StringRef postpasses)
: newFunc(newFunc_), Logic(Logic), mode(mode), oldFunc(oldFunc_),
invertedPointers(invertedPointers_), originalToNewFn(originalToNewFn_),
originalToNewFnOps(originalToNewFnOps_), blocksNotForAnalysis(),
activityAnalyzer(std::make_unique<enzyme::ActivityAnalyzer>(
blocksNotForAnalysis, constantvalues_, activevals_, ReturnActivity)),
TA(TA_), TR(TR_), omp(omp), returnPrimals(returnPrimals),
TA(TA_), TR(TR_), omp(omp), postpasses(postpasses), returnPrimals(returnPrimals),
returnShadows(returnShadows), width(width), ArgDiffeTypes(ArgDiffeTypes_),
RetDiffeTypes(ReturnActivity) {}

Expand Down
11 changes: 6 additions & 5 deletions enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class MGradientUtils {
MTypeAnalysis &TA;
MTypeResults TR;
bool omp;
llvm::StringRef postpasses;
const llvm::ArrayRef<bool> returnPrimals;
const llvm::ArrayRef<bool> returnShadows;

Expand All @@ -58,7 +59,7 @@ class MGradientUtils {
ArrayRef<DIFFE_TYPE> ArgDiffeTypes_,
IRMapping &originalToNewFn_,
std::map<Operation *, Operation *> &originalToNewFnOps_,
DerivativeMode mode, unsigned width, bool omp);
DerivativeMode mode, unsigned width, bool omp, llvm::StringRef postpasses);
void erase(Operation *op) { op->erase(); }
void replaceOrigOpWith(Operation *op, ValueRange vals) {
for (auto &&[res, rep] : llvm::zip(op->getResults(), vals)) {
Expand Down Expand Up @@ -113,11 +114,11 @@ class MDiffeGradientUtils : public MGradientUtils {
ArrayRef<DIFFE_TYPE> RetActivity,
ArrayRef<DIFFE_TYPE> ArgActivity, IRMapping &origToNew_,
std::map<Operation *, Operation *> &origToNewOps_,
DerivativeMode mode, unsigned width, bool omp)
DerivativeMode mode, unsigned width, bool omp, llvm::StringRef postpasses)
: MGradientUtils(Logic, newFunc_, oldFunc_, TA, TR, invertedPointers_,
returnPrimals, returnShadows, constantvalues_,
activevals_, RetActivity, ArgActivity, origToNew_,
origToNewOps_, mode, width, omp),
origToNewOps_, mode, width, omp, postpasses),
initializationBlock(&*(newFunc.getFunctionBody().begin())) {}

// Technically diffe constructor
Expand All @@ -127,7 +128,7 @@ class MDiffeGradientUtils : public MGradientUtils {
const llvm::ArrayRef<bool> returnPrimals,
const llvm::ArrayRef<bool> returnShadows,
ArrayRef<DIFFE_TYPE> RetActivity, ArrayRef<DIFFE_TYPE> ArgActivity,
mlir::Type additionalArg, bool omp) {
mlir::Type additionalArg, bool omp, llvm::StringRef postpasses) {
std::string prefix;

switch (mode) {
Expand Down Expand Up @@ -163,7 +164,7 @@ class MDiffeGradientUtils : public MGradientUtils {
return new MDiffeGradientUtils(
Logic, newFunc, todiff, TA, TR, invertedPointers, returnPrimals,
returnShadows, constant_values, nonconstant_values, RetActivity,
ArgActivity, originalToNew, originalToNewOps, mode, width, omp);
ArgActivity, originalToNew, originalToNewOps, mode, width, omp, postpasses);
}
};

Expand Down
19 changes: 17 additions & 2 deletions enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,21 @@ struct DifferentiatePass : public DifferentiatePassBase<DifferentiatePass> {

void runOnOperation() override;


void getDependentDialects(DialectRegistry &registry) const override {
mlir::PassManager pm(nf->getContext());
std::string error_message;
//llvm::raw_string_ostream error_stream(error_message);
mlir::LogicalResult result =
mlir::parsePassPipeline(postpasses, pm);
if (!mlir::failed(result)) {
pm.getDependentDialects(registry);
}

registry.insert<mlir::arith::ArithDialect, mlir::complex::ComplexDialect,
mlir::cf::ControlFlowDialect, mlir::tensor::TensorDialect>();
}

static std::vector<DIFFE_TYPE> mode_from_fn(FunctionOpInterface fn,
DerivativeMode mode) {
std::vector<DIFFE_TYPE> retTypes;
Expand Down Expand Up @@ -150,7 +165,7 @@ struct DifferentiatePass : public DifferentiatePassBase<DifferentiatePass> {
FunctionOpInterface newFunc = Logic.CreateForwardDiff(
fn, retType, constants, TA, returnPrimals, mode, freeMemory, width,
/*addedType*/ nullptr, type_args, volatile_args,
/*augmented*/ nullptr);
/*augmented*/ nullptr, postpasses);
if (!newFunc)
return failure();

Expand Down Expand Up @@ -276,7 +291,7 @@ struct DifferentiatePass : public DifferentiatePassBase<DifferentiatePass> {
Logic.CreateReverseDiff(fn, retType, arg_activities, TA, returnPrimals,
returnShadows, mode, freeMemory, width,
/*addedType*/ nullptr, type_args, volatile_args,
/*augmented*/ nullptr);
/*augmented*/ nullptr, postpasses);
if (!newFunc)
return failure();

Expand Down
9 changes: 9 additions & 0 deletions enzyme/Enzyme/MLIR/Passes/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,15 @@ def DifferentiatePass : Pass<"enzyme"> {
"cf::ControlFlowDialect",
"tensor::TensorDialect",
];
let options = [
Option<
/*C++ variable name=*/"postopt",
/*CLI argument=*/"postopt",
/*type=*/"std::string",
/*default=*/"",
/*description=*/"Optimization passes to apply to generated derivative functions"
>,
],
let constructor = "mlir::enzyme::createDifferentiatePass()";
}

Expand Down

0 comments on commit 727a2d0

Please sign in to comment.