Skip to content

Commit

Permalink
fix ConstantExpr handling in CreateAugmentedPrimal (#743)
Browse files Browse the repository at this point in the history
* fix ConstantExpr handling in CreateAugmentedPrimal

* add testcase

* fix testcase

* Update constexpr.ll

* respect lifetime

Co-authored-by: Tim Gymnich <timgymnich@me.com>
  • Loading branch information
ZuseZ4 and tgymnich authored Jul 21, 2022
1 parent e8ed87c commit 4eb8421
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 24 deletions.
48 changes: 24 additions & 24 deletions enzyme/Enzyme/EnzymeLogic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2232,6 +2232,20 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
gutils->newFunc->getEntryBlock().getFirstNonPHIOrDbgOrLifetime());
}

//! Keep track of inverted pointers we may need to return
ValueToValueMapTy invertedRetPs;
if (shadowReturnUsed) {
for (BasicBlock &BB : *gutils->oldFunc) {
if (auto ri = dyn_cast<ReturnInst>(BB.getTerminator())) {
if (Value *orig_oldval = ri->getReturnValue()) {
auto newri = gutils->getNewFromOriginal(ri);
IRBuilder<> BuilderZ(newri);
invertedRetPs[newri] = gutils->invertPointerM(orig_oldval, BuilderZ);
}
}
}
}

(IRBuilder<>(gutils->inversionAllocs)).CreateUnreachable();
DeleteDeadBlock(gutils->inversionAllocs);

Expand Down Expand Up @@ -2290,20 +2304,6 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
#endif
}

//! Keep track of inverted pointers we may need to return
ValueToValueMapTy invertedRetPs;
if (shadowReturnUsed) {
for (BasicBlock &BB : *gutils->oldFunc) {
if (auto ri = dyn_cast<ReturnInst>(BB.getTerminator())) {
if (Value *orig_oldval = ri->getReturnValue()) {
auto newri = gutils->getNewFromOriginal(ri);
IRBuilder<> BuilderZ(newri);
invertedRetPs[newri] = gutils->invertPointerM(orig_oldval, BuilderZ);
}
}
}
}

gutils->eraseFictiousPHIs();

if (llvm::verifyFunction(*gutils->newFunc, &llvm::errs())) {
Expand Down Expand Up @@ -2412,22 +2412,21 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
Function *NewF = Function::Create(
FTy, nf->getLinkage(), "augmented_" + todiff->getName(), nf->getParent());

unsigned ii = 0, jj = 0;
unsigned attrIndex = 0;
auto i = nf->arg_begin(), j = NewF->arg_begin();
for (; i != nf->arg_end();) {
while (i != nf->arg_end()) {
VMap[i] = j;
if (nf->hasParamAttribute(ii, Attribute::NoCapture)) {
NewF->addParamAttr(jj, Attribute::NoCapture);
if (nf->hasParamAttribute(attrIndex, Attribute::NoCapture)) {
NewF->addParamAttr(attrIndex, Attribute::NoCapture);
}
if (nf->hasParamAttribute(ii, Attribute::NoAlias)) {
NewF->addParamAttr(jj, Attribute::NoAlias);
if (nf->hasParamAttribute(attrIndex, Attribute::NoAlias)) {
NewF->addParamAttr(attrIndex, Attribute::NoAlias);
}

j->setName(i->getName());
++j;
++jj;
++i;
++ii;
++attrIndex;
}

SmallVector<ReturnInst *, 4> Returns;
Expand Down Expand Up @@ -2617,9 +2616,10 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
if (auto ggep = dyn_cast<GetElementPtrInst>(gep)) {
ggep->setIsInBounds(true);
}
if (isa<ConstantData>(invertedRetPs[ri]))
if (isa<ConstantExpr>(invertedRetPs[ri]) ||
isa<ConstantData>(invertedRetPs[ri])) {
ib.CreateStore(invertedRetPs[ri], gep);
else {
} else {
assert(VMap[invertedRetPs[ri]]);
ib.CreateStore(VMap[invertedRetPs[ri]], gep);
}
Expand Down
26 changes: 26 additions & 0 deletions enzyme/test/Enzyme/ReverseMode/constexpr.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -S | FileCheck %s

@_ZTId = external dso_local constant i8*

define i8* @_ZNK4implIdE4typeEv() {
ret i8* bitcast (i8** @_ZTId to i8*)
}

declare void @_Z17__enzyme_virtualreverse(i8*)

define void @_Z18wrapper_1body_intsv() {
call void @_Z17__enzyme_virtualreverse(i8* bitcast (i8* ()* @_ZNK4implIdE4typeEv to i8*))
ret void
}

; CHECK: define internal { i8*, i8*, i8* } @augmented__ZNK4implIdE4typeEv()
; CHECK-NEXT: %1 = alloca { i8*, i8*, i8* }
; CHECK-NEXT: %2 = getelementptr inbounds { i8*, i8*, i8* }, { i8*, i8*, i8* }* %1, i32 0, i32 0
; CHECK-NEXT: store i8* null, i8** %2
; CHECK-NEXT: %3 = getelementptr inbounds { i8*, i8*, i8* }, { i8*, i8*, i8* }* %1, i32 0, i32 1
; CHECK-NEXT: store i8* bitcast (i8** @_ZTId to i8*), i8** %3
; CHECK-NEXT: %4 = getelementptr inbounds { i8*, i8*, i8* }, { i8*, i8*, i8* }* %1, i32 0, i32 2
; CHECK-NEXT: store i8* bitcast (i8** @_ZTId_shadow to i8*), i8** %4
; CHECK-NEXT: %5 = load { i8*, i8*, i8* }, { i8*, i8*, i8* }* %1
; CHECK-NEXT: ret { i8*, i8*, i8* } %5
; CHECK-NEXT: }

0 comments on commit 4eb8421

Please sign in to comment.