Skip to content

Commit

Permalink
Improve sret-style memory handling (#860)
Browse files Browse the repository at this point in the history
* Fix sret Julia GC issue

* Simplify

* Fixup

* Fix

* Fix version inv
  • Loading branch information
wsmoses authored Sep 29, 2022
1 parent a6a92b4 commit ab357c2
Show file tree
Hide file tree
Showing 10 changed files with 704 additions and 398 deletions.
35 changes: 34 additions & 1 deletion enzyme/Enzyme/AdjointGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -10377,7 +10377,12 @@ class AdjointGenerator
if (!forwardsShadow) {
if (Mode == DerivativeMode::ReverseModePrimal) {
// Needs a stronger replacement check/assertion.
Value *replacement = UndefValue::get(placeholder->getType());
Value *replacement;
if (EnzymeZeroCache)
replacement = ConstantPointerNull::get(
cast<PointerType>(placeholder->getType()));
else
replacement = UndefValue::get(placeholder->getType());
gutils->replaceAWithB(placeholder, replacement);
gutils->invertedPointers.erase(found);
gutils->invertedPointers.insert(std::make_pair(
Expand Down Expand Up @@ -11615,6 +11620,34 @@ class AdjointGenerator
gradByVal[args.size()] = orig->getParamByValType(i);
}
#endif

bool writeOnlyNoCapture = true;
#if LLVM_VERSION_MAJOR >= 8
if (!orig->doesNotCapture(i))
#else
if (!(orig->dataOperandHasImpliedAttr(i + 1, Attribute::NoCapture) ||
(called && called->hasParamAttribute(i, Attribute::NoCapture))))
#endif
{
writeOnlyNoCapture = false;
}
#if LLVM_VERSION_MAJOR >= 14
if (!orig->onlyWritesMemory(i))
#else
if (!(orig->dataOperandHasImpliedAttr(i + 1, Attribute::WriteOnly) ||
orig->dataOperandHasImpliedAttr(i + 1, Attribute::ReadNone) ||
(called && (called->hasParamAttribute(i, Attribute::WriteOnly) ||
called->hasParamAttribute(i, Attribute::ReadNone)))))
#endif
{
writeOnlyNoCapture = false;
}
if (writeOnlyNoCapture) {
if (EnzymeZeroCache)
argi = ConstantPointerNull::get(cast<PointerType>(argi->getType()));
else
argi = UndefValue::get(argi->getType());
}
args.push_back(lookup(argi, Builder2));
}

Expand Down
39 changes: 37 additions & 2 deletions enzyme/Enzyme/DifferentialUseAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@

#include "GradientUtils.h"

typedef std::pair<const Value *, ValueType> UsageKey;

// Determine if a value is needed directly to compute the adjoint
// of the given instruction user
static inline bool is_use_directly_needed_in_reverse(
Expand Down Expand Up @@ -297,6 +295,43 @@ static inline bool is_use_directly_needed_in_reverse(
// we still need even if instruction is inactive
if (funcName == "llvm.julia.gc_preserve_begin")
return true;

bool writeOnlyNoCapture = true;
auto F = getFunctionFromCall(const_cast<CallInst *>(CI));
#if LLVM_VERSION_MAJOR >= 14
for (size_t i = 0; i < CI->arg_size(); i++)
#else
for (size_t i = 0; i < CI->getNumArgOperands(); i++)
#endif
{
if (val == CI->getArgOperand(i)) {
#if LLVM_VERSION_MAJOR >= 8
if (!CI->doesNotCapture(i))
#else
if (!(CI->dataOperandHasImpliedAttr(i + 1, Attribute::NoCapture) ||
(F && F->hasParamAttribute(i, Attribute::NoCapture))))
#endif
{
writeOnlyNoCapture = false;
break;
}
#if LLVM_VERSION_MAJOR >= 14
if (!CI->onlyWritesMemory(i))
#else
if (!(CI->dataOperandHasImpliedAttr(i + 1, Attribute::WriteOnly) ||
CI->dataOperandHasImpliedAttr(i + 1, Attribute::ReadNone) ||
(F && (F->hasParamAttribute(i, Attribute::WriteOnly) ||
F->hasParamAttribute(i, Attribute::ReadNone)))))
#endif
{
writeOnlyNoCapture = false;
break;
}
}
}
// Don't need the primal argument if it is write only and not captured
if (writeOnlyNoCapture)
return false;
}

return !gutils->isConstantInstruction(user) ||
Expand Down
11 changes: 11 additions & 0 deletions enzyme/Enzyme/EnzymeLogic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ struct CacheAnalysis {
// Pointer operands originating from call instructions that are not
// malloc/free are conservatively considered uncacheable.
if (auto obj_op = dyn_cast<CallInst>(obj)) {
auto n = getFuncNameFromCall(obj_op);
// If this is a known allocation which is not captured or returned,
// a caller function cannot overwrite this (since it cannot access).
// Since we don't currently perform this check, we can instead check
Expand All @@ -193,6 +194,9 @@ struct CacheAnalysis {
if (allocationsWithGuaranteedFree.find(obj_op) !=
allocationsWithGuaranteedFree.end()) {

} else if (n == "julia.get_pgcstack" || n == "julia.ptls_states" ||
n == "jl_get_ptls_states") {

} else {
// OP is a non malloc/free call so we need to cache
mustcache = true;
Expand Down Expand Up @@ -267,6 +271,13 @@ struct CacheAnalysis {
oldFunc->getParent()->getDataLayout(), 100);
#endif

if (auto obj_op = dyn_cast<CallInst>(obj)) {
auto n = getFuncNameFromCall(obj_op);
if (n == "julia.get_pgcstack" || n == "julia.ptls_states" ||
n == "jl_get_ptls_states")
return false;
}

// Openmp bound and local thread id are unchanging
// definitionally cacheable.
if (omp)
Expand Down
18 changes: 17 additions & 1 deletion enzyme/Enzyme/FunctionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,7 @@ void RecursivelyReplaceAddressSpace(Value *AI, Value *rep, bool legal) {
auto AS = cast<PointerType>(rep->getType())->getAddressSpace();
if (AS == ASC->getDestAddressSpace()) {
ASC->replaceAllUsesWith(rep);
toErase.push_back(ASC);
continue;
}
ASC->setOperand(0, rep);
Expand Down Expand Up @@ -360,13 +361,28 @@ void RecursivelyReplaceAddressSpace(Value *AI, Value *rep, bool legal) {
continue;
}
}
IRBuilder<> B(CI);
auto Addr = B.CreateAddrSpaceCast(rep, prev->getType());
#if LLVM_VERSION_MAJOR >= 14
for (size_t i = 0; i < CI->arg_size(); i++)
#else
for (size_t i = 0; i < CI->getNumArgOperands(); i++)
#endif
{
if (CI->getArgOperand(i) == prev) {
CI->setArgOperand(i, Addr);
}
}
continue;
}
llvm::errs() << " rep: " << *rep << " prev: " << *prev << " inst: " << *inst
<< "\n";
llvm_unreachable("Illegal address space propagation");
}
for (auto I : llvm::reverse(toErase))

for (auto I : llvm::reverse(toErase)) {
I->eraseFromParent();
}
for (auto SI : toPostCache) {
IRBuilder<> B(SI->getNextNode());
PostCacheStore(SI, B);
Expand Down
Loading

0 comments on commit ab357c2

Please sign in to comment.