Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Jan 1, 2025
1 parent bdb9901 commit e906eb2
Show file tree
Hide file tree
Showing 8 changed files with 39 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,7 @@ class AutoDiffCallFwd
fn, RetActivity, ArgActivity, gutils->TA, returnPrimal, mode,
freeMemory, width,
/* addedType */ nullptr, type_args, volatile_args,
/* augmented */ nullptr,
gutils->postpasses);
/* augmented */ nullptr, gutils->postpasses);

SmallVector<Value> fwdArguments;

Expand Down
8 changes: 4 additions & 4 deletions enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@ FunctionOpInterface mlir::enzyme::MEnzymeLogic::CreateForwardDiff(
std::vector<DIFFE_TYPE> ArgActivity, MTypeAnalysis &TA,
std::vector<bool> returnPrimals, DerivativeMode mode, bool freeMemory,
size_t width, mlir::Type addedType, MFnTypeInfo type_args,
std::vector<bool> volatile_args, void *augmented, llvm::StringRef postpasses) {
std::vector<bool> volatile_args, void *augmented,
llvm::StringRef postpasses) {
if (fn.getFunctionBody().empty()) {
llvm::errs() << fn << "\n";
llvm_unreachable("Differentiating empty function");
Expand Down Expand Up @@ -201,9 +202,8 @@ FunctionOpInterface mlir::enzyme::MEnzymeLogic::CreateForwardDiff(
if (postpasses != "") {
mlir::PassManager pm(nf->getContext());
std::string error_message;
//llvm::raw_string_ostream error_stream(error_message);
mlir::LogicalResult result =
mlir::parsePassPipeline(postpasses, pm);
// llvm::raw_string_ostream error_stream(error_message);
mlir::LogicalResult result = mlir::parsePassPipeline(postpasses, pm);
if (mlir::failed(result)) {
return nullptr;
}
Expand Down
17 changes: 9 additions & 8 deletions enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h
Original file line number Diff line number Diff line change
Expand Up @@ -198,14 +198,15 @@ class MEnzymeLogic {
MFnTypeInfo type_args, std::vector<bool> volatile_args,
void *augmented, llvm::StringRef postpasses);

FunctionOpInterface CreateReverseDiff(
FunctionOpInterface fn, std::vector<DIFFE_TYPE> retType,
std::vector<DIFFE_TYPE> constants, MTypeAnalysis &TA,
std::vector<bool> returnPrimals, std::vector<bool> returnShadows,
DerivativeMode mode, bool freeMemory, size_t width, mlir::Type addedType,
MFnTypeInfo type_args, std::vector<bool> volatile_args, void *augmented,
llvm::StringRef postpasses);

FunctionOpInterface
CreateReverseDiff(FunctionOpInterface fn, std::vector<DIFFE_TYPE> retType,
std::vector<DIFFE_TYPE> constants, MTypeAnalysis &TA,
std::vector<bool> returnPrimals,
std::vector<bool> returnShadows, DerivativeMode mode,
bool freeMemory, size_t width, mlir::Type addedType,
MFnTypeInfo type_args, std::vector<bool> volatile_args,
void *augmented, llvm::StringRef postpasses);

void
initializeShadowValues(SmallVector<mlir::Block *> &dominatorToposortBlocks,
MGradientUtilsReverse *gutils);
Expand Down
5 changes: 2 additions & 3 deletions enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -260,9 +260,8 @@ FunctionOpInterface MEnzymeLogic::CreateReverseDiff(
if (postpasses != "") {
mlir::PassManager pm(nf->getContext());
std::string error_message;
//llvm::raw_string_ostream error_stream(error_message);
mlir::LogicalResult result =
mlir::parsePassPipeline(postpasses, pm);
// llvm::raw_string_ostream error_stream(error_message);
mlir::LogicalResult result = mlir::parsePassPipeline(postpasses, pm);
if (mlir::failed(result)) {
return nullptr;
}
Expand Down
6 changes: 3 additions & 3 deletions enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ mlir::enzyme::MGradientUtils::MGradientUtils(
originalToNewFnOps(originalToNewFnOps_), blocksNotForAnalysis(),
activityAnalyzer(std::make_unique<enzyme::ActivityAnalyzer>(
blocksNotForAnalysis, constantvalues_, activevals_, ReturnActivity)),
TA(TA_), TR(TR_), omp(omp), postpasses(postpasses), returnPrimals(returnPrimals),
returnShadows(returnShadows), width(width), ArgDiffeTypes(ArgDiffeTypes_),
RetDiffeTypes(ReturnActivity) {}
TA(TA_), TR(TR_), omp(omp), postpasses(postpasses),
returnPrimals(returnPrimals), returnShadows(returnShadows), width(width),
ArgDiffeTypes(ArgDiffeTypes_), RetDiffeTypes(ReturnActivity) {}

mlir::Value mlir::enzyme::MGradientUtils::getNewFromOriginal(
const mlir::Value originst) const {
Expand Down
9 changes: 6 additions & 3 deletions enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ class MGradientUtils {
ArrayRef<DIFFE_TYPE> ArgDiffeTypes_,
IRMapping &originalToNewFn_,
std::map<Operation *, Operation *> &originalToNewFnOps_,
DerivativeMode mode, unsigned width, bool omp, llvm::StringRef postpasses);
DerivativeMode mode, unsigned width, bool omp,
llvm::StringRef postpasses);
void erase(Operation *op) { op->erase(); }
void replaceOrigOpWith(Operation *op, ValueRange vals) {
for (auto &&[res, rep] : llvm::zip(op->getResults(), vals)) {
Expand Down Expand Up @@ -114,7 +115,8 @@ class MDiffeGradientUtils : public MGradientUtils {
ArrayRef<DIFFE_TYPE> RetActivity,
ArrayRef<DIFFE_TYPE> ArgActivity, IRMapping &origToNew_,
std::map<Operation *, Operation *> &origToNewOps_,
DerivativeMode mode, unsigned width, bool omp, llvm::StringRef postpasses)
DerivativeMode mode, unsigned width, bool omp,
llvm::StringRef postpasses)
: MGradientUtils(Logic, newFunc_, oldFunc_, TA, TR, invertedPointers_,
returnPrimals, returnShadows, constantvalues_,
activevals_, RetActivity, ArgActivity, origToNew_,
Expand Down Expand Up @@ -164,7 +166,8 @@ class MDiffeGradientUtils : public MGradientUtils {
return new MDiffeGradientUtils(
Logic, newFunc, todiff, TA, TR, invertedPointers, returnPrimals,
returnShadows, constant_values, nonconstant_values, RetActivity,
ArgActivity, originalToNew, originalToNewOps, mode, width, omp, postpasses);
ArgActivity, originalToNew, originalToNewOps, mode, width, omp,
postpasses);
}
};

Expand Down
17 changes: 10 additions & 7 deletions enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ class MGradientUtilsReverse : public MDiffeGradientUtils {
ArrayRef<DIFFE_TYPE> ArgDiffeTypes_,
IRMapping &originalToNewFn_,
std::map<Operation *, Operation *> &originalToNewFnOps_,
DerivativeMode mode_, unsigned width, llvm::StringRef postpasses);
DerivativeMode mode_, unsigned width,
llvm::StringRef postpasses);

IRMapping mapReverseModeBlocks;

Expand Down Expand Up @@ -64,12 +65,14 @@ class MGradientUtilsReverse : public MDiffeGradientUtils {

void createReverseModeBlocks(Region &oldFunc, Region &newFunc);

static MGradientUtilsReverse *CreateFromClone(
MEnzymeLogic &Logic, DerivativeMode mode_, unsigned width,
FunctionOpInterface todiff, MTypeAnalysis &TA, MFnTypeInfo &oldTypeInfo,
const ArrayRef<bool> returnPrimals, const ArrayRef<bool> returnShadows,
llvm::ArrayRef<DIFFE_TYPE> retType,
llvm::ArrayRef<DIFFE_TYPE> constant_args, mlir::Type additionalArg, llvm::StringRef postpasses);
static MGradientUtilsReverse *
CreateFromClone(MEnzymeLogic &Logic, DerivativeMode mode_, unsigned width,
FunctionOpInterface todiff, MTypeAnalysis &TA,
MFnTypeInfo &oldTypeInfo, const ArrayRef<bool> returnPrimals,
const ArrayRef<bool> returnShadows,
llvm::ArrayRef<DIFFE_TYPE> retType,
llvm::ArrayRef<DIFFE_TYPE> constant_args,
mlir::Type additionalArg, llvm::StringRef postpasses);
};

} // namespace enzyme
Expand Down
9 changes: 4 additions & 5 deletions enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,16 @@ struct DifferentiatePass : public DifferentiatePassBase<DifferentiatePass> {

void runOnOperation() override;


void getDependentDialects(DialectRegistry &registry) const override {
mlir::OpPassManager pm;
mlir::LogicalResult result =
mlir::parsePassPipeline(postpasses, pm);
mlir::LogicalResult result = mlir::parsePassPipeline(postpasses, pm);
if (!mlir::failed(result)) {
pm.getDependentDialects(registry);
}

registry.insert<mlir::arith::ArithDialect, mlir::complex::ComplexDialect,
mlir::cf::ControlFlowDialect, mlir::tensor::TensorDialect>();
registry
.insert<mlir::arith::ArithDialect, mlir::complex::ComplexDialect,
mlir::cf::ControlFlowDialect, mlir::tensor::TensorDialect>();
}

static std::vector<DIFFE_TYPE> mode_from_fn(FunctionOpInterface fn,
Expand Down

0 comments on commit e906eb2

Please sign in to comment.