Skip to content

Commit

Permalink
fix build
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Jan 1, 2025
1 parent 419db81 commit bdb9901
Show file tree
Hide file tree
Showing 6 changed files with 14 additions and 8 deletions.
3 changes: 3 additions & 0 deletions enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"

#include "llvm/ADT/BreadthFirstIterator.h"

#include "EnzymeLogic.h"
Expand Down
2 changes: 2 additions & 0 deletions enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@

#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"

#include "EnzymeLogic.h"
#include "Interfaces/GradientUtils.h"
Expand Down
8 changes: 4 additions & 4 deletions enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,12 @@ mlir::enzyme::MGradientUtilsReverse::MGradientUtilsReverse(
ArrayRef<DIFFE_TYPE> ReturnActivity, ArrayRef<DIFFE_TYPE> ArgDiffeTypes_,
IRMapping &originalToNewFn_,
std::map<Operation *, Operation *> &originalToNewFnOps_,
DerivativeMode mode_, unsigned width)
DerivativeMode mode_, unsigned width, StringRef postpasses)
: MDiffeGradientUtils(Logic, newFunc_, oldFunc_, TA_, /*MTypeResults*/ {},
invertedPointers_, returnPrimals, returnShadows,
constantvalues_, activevals_, ReturnActivity,
ArgDiffeTypes_, originalToNewFn_, originalToNewFnOps_,
mode_, width, /*omp*/ false) {}
mode_, width, /*omp*/ false, postpasses) {}

Type mlir::enzyme::MGradientUtilsReverse::getIndexCacheType() {
Type indexType = getIndexType();
Expand Down Expand Up @@ -138,7 +138,7 @@ MGradientUtilsReverse *MGradientUtilsReverse::CreateFromClone(
FunctionOpInterface todiff, MTypeAnalysis &TA, MFnTypeInfo &oldTypeInfo,
const ArrayRef<bool> returnPrimals, const ArrayRef<bool> returnShadows,
ArrayRef<DIFFE_TYPE> retType, ArrayRef<DIFFE_TYPE> constant_args,
mlir::Type additionalArg) {
mlir::Type additionalArg, llvm::StringRef postpasses) {
std::string prefix;

switch (mode_) {
Expand Down Expand Up @@ -174,5 +174,5 @@ MGradientUtilsReverse *MGradientUtilsReverse::CreateFromClone(
return new MGradientUtilsReverse(
Logic, newFunc, todiff, TA, invertedPointers, returnPrimals,
returnShadows, constant_values, nonconstant_values, retType,
constant_args, originalToNew, originalToNewOps, mode_, width);
constant_args, originalToNew, originalToNewOps, mode_, width, postpasses);
}
4 changes: 2 additions & 2 deletions enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class MGradientUtilsReverse : public MDiffeGradientUtils {
ArrayRef<DIFFE_TYPE> ArgDiffeTypes_,
IRMapping &originalToNewFn_,
std::map<Operation *, Operation *> &originalToNewFnOps_,
DerivativeMode mode_, unsigned width);
DerivativeMode mode_, unsigned width, llvm::StringRef postpasses);

IRMapping mapReverseModeBlocks;

Expand Down Expand Up @@ -69,7 +69,7 @@ class MGradientUtilsReverse : public MDiffeGradientUtils {
FunctionOpInterface todiff, MTypeAnalysis &TA, MFnTypeInfo &oldTypeInfo,
const ArrayRef<bool> returnPrimals, const ArrayRef<bool> returnShadows,
llvm::ArrayRef<DIFFE_TYPE> retType,
llvm::ArrayRef<DIFFE_TYPE> constant_args, mlir::Type additionalArg);
llvm::ArrayRef<DIFFE_TYPE> constant_args, mlir::Type additionalArg, llvm::StringRef postpasses);
};

} // namespace enzyme
Expand Down
1 change: 1 addition & 0 deletions enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/PassManager.h"

#define DEBUG_TYPE "enzyme"

Expand Down
4 changes: 2 additions & 2 deletions enzyme/Enzyme/MLIR/Passes/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ def DifferentiatePass : Pass<"enzyme"> {
];
let options = [
Option<
/*C++ variable name=*/"postopt",
/*CLI argument=*/"postopt",
/*C++ variable name=*/"postpasses",
/*CLI argument=*/"postpasses",
/*type=*/"std::string",
/*default=*/"",
/*description=*/"Optimization passes to apply to generated derivative functions"
Expand Down

0 comments on commit bdb9901

Please sign in to comment.