Skip to content

Commit

Permalink
New PM and Opaque pointer progress (#960)
Browse files Browse the repository at this point in the history
* New PM and Opaque pointer progress

* Continuing

* Version invariance

* ensure command line arg parsed

* Fix build

* work on llvm 10
  • Loading branch information
wsmoses authored Dec 28, 2022
1 parent e0bb22c commit d54daa7
Show file tree
Hide file tree
Showing 14 changed files with 247 additions and 89 deletions.
28 changes: 21 additions & 7 deletions enzyme/Enzyme/AdjointGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -4278,7 +4278,11 @@ class AdjointGenerator
auto *PowF = CI.getCalledValue();
#endif
assert(PowF);
auto FT = cast<FunctionType>(PowF->getType()->getPointerElementType());
// Determine the function type of the callee PowF. When PowF is directly a
// Function, ask it for its function type; otherwise fall back to the pointee
// type of the callee's pointer type. FT must be assigned on both paths so it
// is never left null for the code below.
FunctionType *FT = nullptr;
if (auto F = dyn_cast<Function>(PowF))
  FT = F->getFunctionType();
else
  FT = cast<FunctionType>(PowF->getType()->getPointerElementType());

if (vdiff && !gutils->isConstantValue(orig_ops[0])) {

Expand Down Expand Up @@ -12089,8 +12093,11 @@ class AdjointGenerator
"whose runtime value is inactive",
gutils->getNewFromOriginal(orig->getDebugLoc()), orig);

auto ft =
cast<FunctionType>(callval->getType()->getPointerElementType());
FunctionType *ft = nullptr;
if (auto F = dyn_cast<Function>(callval))
ft = F->getFunctionType();
else
ft = cast<FunctionType>(callval->getType()->getPointerElementType());

std::set<llvm::Type *> seen;
DIFFE_TYPE subretType = whatType(orig->getType(), Mode,
Expand Down Expand Up @@ -12165,8 +12172,11 @@ class AdjointGenerator
// sub_index_map = fnandtapetype.tapeIndices;

assert(newcalled);
FunctionType *FT =
cast<FunctionType>(newcalled->getType()->getPointerElementType());
FunctionType *FT = nullptr;
if (auto F = dyn_cast<Function>(newcalled))
FT = F->getFunctionType();
else
FT = cast<FunctionType>(newcalled->getType()->getPointerElementType());

// llvm::errs() << "seeing sub_index_map of " << sub_index_map->size()
// << " in ap " << cast<Function>(called)->getName() << "\n";
Expand Down Expand Up @@ -12545,8 +12555,12 @@ class AdjointGenerator

assert(newcalled);
// if (auto NC = dyn_cast<Function>(newcalled)) {
FunctionType *FT =
cast<FunctionType>(newcalled->getType()->getPointerElementType());
FunctionType *FT = nullptr;
if (auto F = dyn_cast<Function>(newcalled))
FT = F->getFunctionType();
else {
FT = cast<FunctionType>(newcalled->getType()->getPointerElementType());
}

if (false) {
badfn:;
Expand Down
26 changes: 11 additions & 15 deletions enzyme/Enzyme/CacheUtility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1010,15 +1010,13 @@ AllocaInst *CacheUtility::createCacheForScope(LimitContext ctx, Type *T,
/*available*/ ValueToValueMapTy());

#if LLVM_VERSION_MAJOR > 7
storeInto = v.CreateLoad(storeInto->getType()->getPointerElementType(),
storeInto);
storeInto = v.CreateLoad(types[i + 1], storeInto);
#if LLVM_VERSION_MAJOR >= 10
cast<LoadInst>(storeInto)->setAlignment(Align(alignSize));
#else
cast<LoadInst>(storeInto)->setAlignment(alignSize);
#endif
storeInto = v.CreateGEP(storeInto->getType()->getPointerElementType(),
storeInto, idx);
storeInto = v.CreateGEP(types[i], storeInto, idx);
#else
storeInto = v.CreateLoad(storeInto);
cast<LoadInst>(storeInto)->setAlignment(alignSize);
Expand Down Expand Up @@ -1603,13 +1601,13 @@ Value *CacheUtility::getCachePointer(bool inForwardPass, IRBuilder<> &BuilderM,

/// Perform the final load from the cache, applying requisite invariant
/// group and alignment
llvm::Value *CacheUtility::loadFromCachePointer(llvm::IRBuilder<> &BuilderM,
llvm::Value *CacheUtility::loadFromCachePointer(Type *T,
llvm::IRBuilder<> &BuilderM,
llvm::Value *cptr,
llvm::Value *cache) {
// Retrieve the actual result
#if LLVM_VERSION_MAJOR > 7
auto result =
BuilderM.CreateLoad(cptr->getType()->getPointerElementType(), cptr);
auto result = BuilderM.CreateLoad(T, cptr);
#else
auto result = BuilderM.CreateLoad(cptr);
#endif
Expand Down Expand Up @@ -1639,11 +1637,10 @@ llvm::Value *CacheUtility::loadFromCachePointer(llvm::IRBuilder<> &BuilderM,

/// Given an allocation specified by the LimitContext ctx and cache, lookup the
/// underlying cached value.
Value *
CacheUtility::lookupValueFromCache(bool inForwardPass, IRBuilder<> &BuilderM,
LimitContext ctx, Value *cache, bool isi1,
const ValueToValueMapTy &available,
Value *extraSize, Value *extraOffset) {
Value *CacheUtility::lookupValueFromCache(
Type *T, bool inForwardPass, IRBuilder<> &BuilderM, LimitContext ctx,
Value *cache, bool isi1, const ValueToValueMapTy &available,
Value *extraSize, Value *extraOffset) {
// Get the underlying cache pointer
auto cptr =
getCachePointer(inForwardPass, BuilderM, ctx, cache, isi1,
Expand All @@ -1652,15 +1649,14 @@ CacheUtility::lookupValueFromCache(bool inForwardPass, IRBuilder<> &BuilderM,
// Optionally apply the additional offset
if (extraOffset) {
#if LLVM_VERSION_MAJOR > 7
cptr = BuilderM.CreateGEP(cptr->getType()->getPointerElementType(), cptr,
extraOffset);
cptr = BuilderM.CreateGEP(T, cptr, extraOffset);
#else
cptr = BuilderM.CreateGEP(cptr, extraOffset);
#endif
cast<GetElementPtrInst>(cptr)->setIsInBounds(true);
}

Value *result = loadFromCachePointer(BuilderM, cptr, cache);
Value *result = loadFromCachePointer(T, BuilderM, cptr, cache);

// If using the efficient bool cache, do the corresponding
// mask and shift to retrieve the actual value
Expand Down
13 changes: 8 additions & 5 deletions enzyme/Enzyme/CacheUtility.h
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ class CacheUtility {

/// Perform the final load from the cache, applying requisite invariant
/// group and alignment
llvm::Value *loadFromCachePointer(llvm::IRBuilder<> &BuilderM,
llvm::Value *loadFromCachePointer(llvm::Type *T, llvm::IRBuilder<> &BuilderM,
llvm::Value *cptr, llvm::Value *cache);

public:
Expand Down Expand Up @@ -387,10 +387,13 @@ class CacheUtility {

/// Given an allocation specified by the LimitContext ctx and cache, lookup
/// the underlying cached value.
llvm::Value *lookupValueFromCache(
bool inForwardPass, llvm::IRBuilder<> &BuilderM, LimitContext ctx,
llvm::Value *cache, bool isi1, const llvm::ValueToValueMapTy &available,
llvm::Value *extraSize = nullptr, llvm::Value *extraOffset = nullptr);
llvm::Value *lookupValueFromCache(llvm::Type *T, bool inForwardPass,
llvm::IRBuilder<> &BuilderM,
LimitContext ctx, llvm::Value *cache,
bool isi1,
const llvm::ValueToValueMapTy &available,
llvm::Value *extraSize = nullptr,
llvm::Value *extraOffset = nullptr);

protected:
// List of values loaded from the cache
Expand Down
81 changes: 79 additions & 2 deletions enzyme/Enzyme/Enzyme.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -610,7 +610,12 @@ static bool replaceOriginalCall(CallInst *CI, Function *fn, Value *diffret,
} else if (CI->hasStructRetAttr()) {
Value *sret = CI->getArgOperand(0);
PointerType *stype = cast<PointerType>(sret->getType());
StructType *st = dyn_cast<StructType>(stype->getPointerElementType());
#if LLVM_VERSION_MAJOR >= 15
auto sret_ty = CI->getParamStructRetType(0);
#else
auto sret_ty = stype->getPointerElementType();
#endif
StructType *st = dyn_cast<StructType>(sret_ty);

// Assign results to struct allocated at the call site.
if (st && st->isLayoutIdentical(diffretsty)) {
Expand All @@ -625,7 +630,7 @@ static bool replaceOriginalCall(CallInst *CI, Function *fn, Value *diffret,
}
} else {
auto &DL = fn->getParent()->getDataLayout();
if (DL.getTypeSizeInBits(stype->getPointerElementType()) !=
if (DL.getTypeSizeInBits(sret_ty) !=
DL.getTypeSizeInBits(diffret->getType())) {
EmitFailure("IllegalReturnCast", CI->getDebugLoc(), CI,
"Cannot cast return type of gradient ",
Expand Down Expand Up @@ -2547,10 +2552,82 @@ class EnzymeNewPM final : public EnzymeBase,

AnalysisKey EnzymeNewPM::Key;

#ifdef ENZYME_RUNPASS
#include "PreserveNVVM.h"
#include "llvm/Passes/PassBuilder.h"
#include "llvm/Transforms/IPO/GlobalOpt.h"
#include "llvm/Transforms/Scalar/GVN.h"
#include "llvm/Transforms/Scalar/LoopDeletion.h"
#include "llvm/Transforms/Scalar/SROA.h"
#endif

extern "C" ::llvm::PassPluginLibraryInfo LLVM_ATTRIBUTE_WEAK
llvmGetPassPluginInfo() {
return {LLVM_PLUGIN_API_VERSION, "EnzymeNewPM", "v0.1",
[](llvm::PassBuilder &PB) {
#ifdef ENZYME_RUNPASS
#if LLVM_VERSION_MAJOR < 14
using OptimizationLevel = llvm::PassBuilder::OptimizationLevel;
#endif
#if LLVM_VERSION_MAJOR >= 12
auto loadPass = [](ModulePassManager &MPM, OptimizationLevel)
#else
auto loadPass = [](ModulePassManager &MPM)
#endif
{
MPM.addPass(PreserveNVVMNewPM(/*Begin*/ true));
FunctionPassManager OptimizerPM;
FunctionPassManager OptimizerPM2;
#if LLVM_VERSION_MAJOR >= 14
OptimizerPM.addPass(llvm::GVNPass());
OptimizerPM.addPass(llvm::SROAPass());
#else
OptimizerPM.addPass(llvm::GVN());
OptimizerPM.addPass(llvm::SROA());
#endif
MPM.addPass(
createModuleToFunctionPassAdaptor(std::move(OptimizerPM)));
MPM.addPass(EnzymeNewPM(/*PostOpt=*/true));
MPM.addPass(PreserveNVVMNewPM(/*Begin*/ false));
#if LLVM_VERSION_MAJOR >= 14
OptimizerPM2.addPass(llvm::GVNPass());
OptimizerPM2.addPass(llvm::SROAPass());
#else
OptimizerPM2.addPass(llvm::GVN());
OptimizerPM2.addPass(llvm::SROA());
#endif

LoopPassManager LPM1;
LPM1.addPass(LoopDeletionPass());
OptimizerPM2.addPass(
createFunctionToLoopPassAdaptor(std::move(LPM1)));

MPM.addPass(
createModuleToFunctionPassAdaptor(std::move(OptimizerPM2)));
MPM.addPass(GlobalOptPass());
};
// TODO need for perf reasons to move Enzyme pass to the pre vectorization.
#if LLVM_VERSION_MAJOR >= 12
PB.registerPipelineEarlySimplificationEPCallback(loadPass);
#else
PB.registerPipelineStartEPCallback(loadPass);
#endif

#if LLVM_VERSION_MAJOR >= 12
auto loadNVVM = [](ModulePassManager &MPM, OptimizationLevel)
#else
auto loadNVVM = [](ModulePassManager &MPM)
#endif
{ MPM.addPass(PreserveNVVMNewPM(/*Begin*/ true)); };

// We should register at vectorizer start for consistency, however,
// that requires a functionpass, and we have a modulepass.
// PB.registerVectorizerStartEPCallback(loadPass);
PB.registerPipelineStartEPCallback(loadNVVM);
#if LLVM_VERSION_MAJOR >= 15
PB.registerFullLinkTimeOptimizationEarlyEPCallback(loadNVVM);
#endif
#endif
PB.registerPipelineParsingCallback(
[](llvm::StringRef Name, llvm::ModulePassManager &MPM,
llvm::ArrayRef<llvm::PassBuilder::PipelineElement>) {
Expand Down
Loading

0 comments on commit d54daa7

Please sign in to comment.