Skip to content

Commit

Permalink
Handle differentiation via invoke
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed May 24, 2021
1 parent 81e87fc commit 33bee0d
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 19 deletions.
78 changes: 60 additions & 18 deletions enzyme/Enzyme/Enzyme.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,7 @@ class Enzyme : public ModulePass {
}

/// Return whether successful
template <typename T>
bool HandleAutoDiff(T *CI, TargetLibraryInfo &TLI, bool PostOpt,
bool HandleAutoDiff(CallInst *CI, TargetLibraryInfo &TLI, bool PostOpt,
bool fwdMode) {

Value *fn = CI->getArgOperand(0);
Expand Down Expand Up @@ -575,9 +574,65 @@ class Enzyme : public ModulePass {

bool Changed = false;

for (BasicBlock &BB : F)
if (InvokeInst *II = dyn_cast<InvokeInst>(BB.getTerminator())) {

Function *Fn = II->getCalledFunction();

#if LLVM_VERSION_MAJOR >= 11
if (auto castinst = dyn_cast<ConstantExpr>(II->getCalledOperand()))
#else
if (auto castinst = dyn_cast<ConstantExpr>(II->getCalledValue()))
#endif
{
if (castinst->isCast())
if (auto fn = dyn_cast<Function>(castinst->getOperand(0)))
Fn = fn;
}
if (!Fn)
continue;

if (!(Fn->getName() == "__enzyme_float" ||
Fn->getName() == "__enzyme_double" ||
Fn->getName() == "__enzyme_integer" ||
Fn->getName() == "__enzyme_pointer" ||
Fn->getName().contains("__enzyme_call_inactive") ||
Fn->getName().contains("__enzyme_autodiff") ||
Fn->getName().contains("__enzyme_fwddiff")))
continue;

SmallVector<Value *, 16> CallArgs(II->arg_begin(), II->arg_end());
SmallVector<OperandBundleDef, 1> OpBundles;
II->getOperandBundlesAsDefs(OpBundles);
// Insert a normal call instruction...
#if LLVM_VERSION_MAJOR >= 8
CallInst *NewCall =
CallInst::Create(II->getFunctionType(), II->getCalledOperand(),
CallArgs, OpBundles, "", II);
#else
CallInst *NewCall =
CallInst::Create(II->getFunctionType(), II->getCalledValue(),
CallArgs, OpBundles, "", II);
#endif
NewCall->takeName(II);
NewCall->setCallingConv(II->getCallingConv());
NewCall->setAttributes(II->getAttributes());
NewCall->setDebugLoc(II->getDebugLoc());
II->replaceAllUsesWith(NewCall);

// Insert an unconditional branch to the normal destination.
BranchInst::Create(II->getNormalDest(), II);

// Remove any PHI node entries from the exception destination.
II->getUnwindDest()->removePredecessor(&BB);

// Remove the invoke instruction now.
BB.getInstList().erase(II);
Changed = true;
}

std::set<CallInst *> toLowerAuto;
std::set<CallInst *> toLowerFwd;
std::set<InvokeInst *> toLowerI;
std::set<CallInst *> InactiveCalls;
retry:;
for (BasicBlock &BB : F) {
Expand Down Expand Up @@ -752,15 +807,9 @@ class Enzyme : public ModulePass {
}
}

bool autoDiff = Fn && (Fn->getName() == "__enzyme_autodiff" ||
Fn->getName() == "enzyme_autodiff_" ||
Fn->getName().startswith("__enzyme_autodiff") ||
Fn->getName().contains("__enzyme_autodiff"));
bool autoDiff = Fn && Fn->getName().contains("__enzyme_autodiff");

bool fwdDiff = Fn && (Fn->getName() == "__enzyme_fwddiff" ||
Fn->getName() == "enzyme_fwddiff_" ||
Fn->getName().startswith("__enzyme_fwddiff") ||
Fn->getName().contains("__enzyme_fwddiff"));
bool fwdDiff = Fn && Fn->getName().contains("__enzyme_fwddiff");

if (autoDiff || fwdDiff) {
if (autoDiff) {
Expand Down Expand Up @@ -845,13 +894,6 @@ class Enzyme : public ModulePass {
break;
}

for (auto CI : toLowerI) {
successful &= HandleAutoDiff(CI, TLI, PostOpt, /*fwdMode*/ false);
Changed = true;
if (!successful)
break;
}

if (Changed) {
// TODO consider enabling when attributor does not delete
// dead internal functions, which invalidates Enzyme's cache
Expand Down
8 changes: 7 additions & 1 deletion enzyme/Enzyme/GradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2736,6 +2736,13 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM,
shouldRecompute(LI, incoming_available, &BuilderM);
}
}
if (!inst->mayReadOrWriteMemory()) {
reduceRegister |= tryLegalRecomputeCheck &&
legalRecompute(inst, incoming_available, &BuilderM) &&
shouldRecompute(inst, incoming_available, &BuilderM);
}
if (this->isOriginalBlock(*BuilderM.GetInsertBlock()))
reduceRegister = false;
}

if (!reduceRegister) {
Expand Down Expand Up @@ -2928,7 +2935,6 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM,
}
inst = cast<Instruction>(val);
assert(prelcssaInst->getType() == inst->getType());

assert(!this->isOriginalBlock(*BuilderM.GetInsertBlock()));

// Update index and caching per lcssa
Expand Down
32 changes: 32 additions & 0 deletions enzyme/test/Enzyme/ReverseMode/invoke.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -simplifycfg -adce -S | FileCheck %s

define double @sq(double %x) {
entry:
%0 = fmul fast double %x, %x
ret double %0
}

declare i32 @__gxx_personality_v0(...)

; Function Attrs: norecurse ssp uwtable
define double @caller(double %x) personality i8* bitcast (i32 (...)* @__gxx_personality_v0 to i8*) {
%res = invoke double (...) @_Z17__enzyme_autodiffz(double (double)* nonnull @sq, double %x)
to label %eblock unwind label %cblock

eblock:
ret double %res

cblock:
%lp = landingpad { i8*, i32 }
cleanup
ret double 0.000000e+00
}

declare double @_Z17__enzyme_autodiffz(...)

; CHECK: define double @caller(double %x)
; CHECK-NEXT: eblock:
; CHECK-NEXT: %0 = call { double } @diffesq(double %x, double 1.000000e+00)
; CHECK-NEXT: %1 = extractvalue { double } %0, 0
; CHECK-NEXT: ret double %1
; CHECK-NEXT: }

0 comments on commit 33bee0d

Please sign in to comment.