Skip to content

Commit

Permalink
PHI cache optimization
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed May 17, 2021
1 parent 674ea58 commit e1f76d6
Show file tree
Hide file tree
Showing 13 changed files with 495 additions and 24 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/enzyme-linux.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,5 +46,7 @@ jobs:
run: cd enzyme/build && make -j`nproc`
- name: make check-typeanalysis
run: cd enzyme/build && make check-typeanalysis -j`nproc`
- name: make check-activityanalysis
run: cd enzyme/build && make check-activityanalysis -j`nproc`
- name: make check-enzyme
run: cd enzyme/build && make check-enzyme -j`nproc`
3 changes: 2 additions & 1 deletion enzyme/Enzyme/ActivityAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,8 @@ std::set<std::string> KnownInactiveFunctions = {"__assert_fail",
"MPI_Finalize",
"_msize",
"ftnio_fmt_write64",
"f90_strcmp_klen"};
"f90_strcmp_klen",
"vprintf"};

/// Is the use of value val as an argument of call CI known to be inactive
/// This tool can only be used when in DOWN mode
Expand Down
170 changes: 170 additions & 0 deletions enzyme/Enzyme/ActivityAnalysisPrinter.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
//===- ActivityAnalysisPrinter.cpp - Printer utility pass for Activity Analysis
//----===//
//
// Enzyme Project
//
// Part of the Enzyme Project, under the Apache License v2.0 with LLVM
// Exceptions. See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// If using this code in an academic setting, please cite the following:
// @incollection{enzymeNeurips,
// title = {Instead of Rewriting Foreign Code for Machine Learning,
// Automatically Synthesize Fast Gradients},
// author = {Moses, William S. and Churavy, Valentin},
// booktitle = {Advances in Neural Information Processing Systems 33},
// year = {2020},
// note = {To appear in},
// }
//
//===----------------------------------------------------------------------===//
//
// This file contains a utility LLVM pass for printing derived Activity Analysis
// results of a given function.
//
//===----------------------------------------------------------------------===//
#include <llvm/Config/llvm-config.h>

#include "SCEV/ScalarEvolution.h"
#include "SCEV/ScalarEvolutionExpander.h"

#include "llvm/ADT/SmallVector.h"

#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Metadata.h"

#include "llvm/Support/Debug.h"
#include "llvm/Transforms/Scalar.h"

#include "llvm/Analysis/BasicAliasAnalysis.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/ScalarEvolution.h"

#include "llvm/Support/CommandLine.h"

#include "ActivityAnalysis.h"
#include "FunctionUtils.h"
#include "TypeAnalysis/TypeAnalysis.h"
#include "Utils.h"

using namespace llvm;
#ifdef DEBUG_TYPE
#undef DEBUG_TYPE
#endif
#define DEBUG_TYPE "activity-analysis-results"

/// Name of the single function whose activity analysis results are printed.
/// Set on the command line via the hidden "activity-analysis-func" option;
/// the default is the empty string.
static llvm::cl::opt<std::string>
FunctionToAnalyze("activity-analysis-func", cl::init(""), cl::Hidden,
cl::desc("Which function to analyze/print"));

namespace {

class ActivityAnalysisPrinter : public FunctionPass {
public:
static char ID;
ActivityAnalysisPrinter() : FunctionPass(ID) {}

void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.addRequired<TargetLibraryInfoWrapperPass>();
}

bool runOnFunction(Function &F) override {
if (F.getName() != FunctionToAnalyze)
return /*changed*/ false;

#if LLVM_VERSION_MAJOR >= 10
auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
#else
auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
#endif

FnTypeInfo type_args(&F);
for (auto &a : type_args.Function->args()) {
TypeTree dt;
if (a.getType()->isFPOrFPVectorTy()) {
dt = ConcreteType(a.getType()->getScalarType());
} else if (a.getType()->isPointerTy()) {
auto et = cast<PointerType>(a.getType())->getElementType();
if (et->isFPOrFPVectorTy()) {
dt = TypeTree(ConcreteType(et->getScalarType())).Only(-1);
} else if (et->isPointerTy()) {
dt = TypeTree(ConcreteType(BaseType::Pointer)).Only(-1);
}
} else if (a.getType()->isIntOrIntVectorTy()) {
dt = ConcreteType(BaseType::Integer);
}
type_args.Arguments.insert(
std::pair<Argument *, TypeTree>(&a, dt.Only(-1)));
// TODO note that here we do NOT propagate constants in type info (and
// should consider whether we should)
type_args.KnownValues.insert(
std::pair<Argument *, std::set<int64_t>>(&a, {}));
}

TypeTree dt;
if (F.getReturnType()->isFPOrFPVectorTy()) {
dt = ConcreteType(F.getReturnType()->getScalarType());
} else if (F.getReturnType()->isPointerTy()) {
auto et = cast<PointerType>(F.getReturnType())->getElementType();
if (et->isFPOrFPVectorTy()) {
dt = TypeTree(ConcreteType(et->getScalarType())).Only(-1);
} else if (et->isPointerTy()) {
dt = TypeTree(ConcreteType(BaseType::Pointer)).Only(-1);
}
} else if (F.getReturnType()->isIntOrIntVectorTy()) {
dt = ConcreteType(BaseType::Integer);
}
type_args.Return = dt.Only(-1);

TypeAnalysis TA(TLI);
TypeResults TR = TA.analyzeFunction(type_args);

llvm::SmallPtrSet<llvm::Value *, 4> ConstantValues;
llvm::SmallPtrSet<llvm::Value *, 4> ActiveValues;
for (auto &a : type_args.Function->args()) {
if (a.getType()->isIntOrIntVectorTy()) {
ConstantValues.insert(&a);
} else {
ActiveValues.insert(&a);
}
}

PreProcessCache PPC;
bool ActiveReturns = F.getReturnType()->isFPOrFPVectorTy();
ActivityAnalyzer ATA(PPC.FAM.getResult<AAManager>(F), TLI, ConstantValues,
ActiveValues, ActiveReturns);

for (auto &a : F.args()) {
bool icv = ATA.isConstantValue(TR, &a);
llvm::errs().flush();
llvm::outs() << a << ": icv:" << icv << "\n";
llvm::outs().flush();
}
for (auto &BB : F) {
llvm::outs() << BB.getName() << "\n";
for (auto &I : BB) {
bool ici = ATA.isConstantInstruction(TR, &I);
bool icv = ATA.isConstantValue(TR, &I);
llvm::errs().flush();
llvm::outs() << I << ": icv:" << icv << " ici:" << ici << "\n";
llvm::outs().flush();
}
}
return /*changed*/ false;
}
};

} // namespace

// Definition of the pass's static ID member; FunctionPass(ID) in the
// constructor refers to this object.
char ActivityAnalysisPrinter::ID = 0;

// Registers the pass under the argument name "print-activity-analysis" with
// the description "Print Activity Analysis Results".
static RegisterPass<ActivityAnalysisPrinter>
X("print-activity-analysis", "Print Activity Analysis Results");
141 changes: 138 additions & 3 deletions enzyme/Enzyme/FunctionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@

#include "llvm/Analysis/TypeBasedAliasAnalysis.h"

#include "llvm/Analysis/CFLSteensAliasAnalysis.h"

#if LLVM_VERSION_MAJOR > 6
#include "llvm/Analysis/PhiValues.h"
#endif
Expand Down Expand Up @@ -121,6 +123,12 @@ static cl::opt<int>
EnzymeInlineCount("enzyme-inline-count", cl::init(10000), cl::Hidden,
cl::desc("Limit of number of functions to inline"));

#if LLVM_VERSION_MAJOR >= 8
// Hidden boolean option "enzyme-phi-restructure", off by default. Only
// declared when building against LLVM 8 or newer.
static cl::opt<bool> EnzymePHIRestructure(
"enzyme-phi-restructure", cl::init(false), cl::Hidden,
cl::desc("Whether to restructure phi's to have better unwrap behavior"));
#endif

/// Is the use of value val as an argument of call CI potentially captured
bool couldFunctionArgumentCapture(llvm::CallInst *CI, llvm::Value *val) {
Function *F = CI->getCalledFunction();
Expand Down Expand Up @@ -411,8 +419,6 @@ OldAllocationSize(Value *Ptr, CallInst *Loc, Function *NewF, IntegerType *T,
continue;
}

// llvm::errs() << *NewF->getParent() << "\n";
// llvm::errs() << *NewF << "\n";
EmitFailure("DynamicReallocSize", Loc->getDebugLoc(), Loc,
"could not statically determine size of realloc ", *Loc,
" - because of - ", *next.first);
Expand Down Expand Up @@ -667,12 +673,17 @@ PreProcessCache::PreProcessCache() {
// disable for now, consider enabling in future
// FAM.registerPass([] { return SCEVAA(); });

// FAM.registerPass([] { return CFLSteensAA(); });

FAM.registerPass([] {
auto AM = AAManager();
AM.registerFunctionAnalysis<BasicAA>();
AM.registerFunctionAnalysis<TypeBasedAA>();
// AM.registerFunctionAnalysis<SCEVAA>();
AM.registerModuleAnalysis<GlobalsAA>();

// broken for different reasons
// AM.registerFunctionAnalysis<SCEVAA>();
// AM.registerFunctionAnalysis<CFLSteensAA>();
return AM;
});

Expand Down Expand Up @@ -1193,6 +1204,130 @@ Function *PreProcessCache::preprocessForClone(Function *F, bool topLevel) {
FAM.invalidate(*NewF, PA);
}

#if LLVM_VERSION_MAJOR >= 8
// Optional PHI restructuring, enabled by EnzymePHIRestructure.
//
// For every block B with at least three predecessors, look for a block BB
// with more than one successor such that each successor of BB leads to a
// distinct, non-empty group of B's direct predecessors. When one is found, a
// new block "tmpblk" is inserted between those predecessors and B, and B's phi
// nodes are split so that the values coming from that group are merged in
// "tmpblk" first. After each such rewrite the whole scan restarts.
if (EnzymePHIRestructure) {
// This block is never entered by falling through; it is only reached via
// `goto reset` below, after the CFG was modified. It invalidates the cached
// analyses for NewF and then continues into the scan that follows.
if (false) {
reset:;
PreservedAnalyses PA;
FAM.invalidate(*NewF, PA);
}

// Collect candidate merge blocks up front: those with >= 3 predecessors.
SmallVector<BasicBlock *, 4> MultiBlocks;
for (auto &B : *NewF) {
if (B.hasNPredecessorsOrMore(3))
MultiBlocks.push_back(&B);
}

LoopInfo &LI = FAM.getResult<LoopAnalysis>(*NewF);
for (BasicBlock *B : MultiBlocks) {

// Maps each CFG edge (pred, successor) visited by the backward walk below to
// the set of direct predecessors of B ("targets") whose walk traversed it.
std::map<std::pair</*pred*/ BasicBlock *, /*successor*/ BasicBlock *>,
std::set<BasicBlock *>>
done;
{
// Worklist of (edge, target) pairs still to be processed.
std::deque<std::tuple<
std::pair</*pred*/ BasicBlock *, /*successor*/ BasicBlock *>,
BasicBlock *>>
Q; // (edge, target)

// Seed: each edge P -> B is tagged with P itself as the target.
for (auto P : predecessors(B)) {
Q.emplace_back(std::make_pair(P, B), P);
}

// Breadth-first walk backwards through predecessors, propagating the target
// tag along every edge encountered.
for (std::tuple<
std::pair</*pred*/ BasicBlock *, /*successor*/ BasicBlock *>,
BasicBlock *>
trace;
Q.size() > 0;) {
trace = Q.front();
Q.pop_front();
auto edge = std::get<0>(trace);
auto block = edge.first;
auto target = std::get<1>(trace);

// Each (edge, target) combination is expanded only once.
if (done[edge].count(target))
continue;
done[edge].insert(target);

Loop *blockLoop = LI.getLoopFor(block);

for (BasicBlock *Pred : predecessors(block)) {
// Don't go up the backedge as we can use the last value if desired
// via lcssa
// (i.e. skip Pred when block is the header of its loop and Pred's innermost
// loop is that same loop)
if (blockLoop && blockLoop->getHeader() == block &&
blockLoop == LI.getLoopFor(Pred))
continue;

Q.push_back(
std::tuple<std::pair<BasicBlock *, BasicBlock *>, BasicBlock *>(
std::make_pair(Pred, block), target));
}
}
}

// Every block that is the source of at least one recorded edge.
SmallPtrSet<BasicBlock *, 4> Preds;
for (auto &pair : done) {
Preds.insert(pair.first.first);
}

for (auto BB : Preds) {
// BB is "illegal" as a split point if some successor edge of BB has no
// recorded target, or if the same target is recorded on two different
// successor edges of BB.
bool illegal = false;
// All targets recorded on any successor edge of BB.
SmallPtrSet<BasicBlock *, 2> UnionSet;
size_t numSuc = 0;
for (BasicBlock *sucI : successors(BB)) {
numSuc++;
// Note: operator[] default-inserts an empty set for edges never visited.
const auto &SI = done[std::make_pair(BB, sucI)];
if (SI.size() == 0) {
// sucI->getName();
illegal = true;
break;
}
for (auto si : SI) {
UnionSet.insert(si);

// The target must not also be recorded on any other successor edge of BB.
for (BasicBlock *sucJ : successors(BB)) {
if (sucI == sucJ)
continue;
if (done[std::make_pair(BB, sucJ)].count(si)) {
illegal = true;
goto endIllegal;
}
}
}
}
endIllegal:;

// Rewrite only if BB actually branches (more than one successor) and B's
// predecessor count differs from the number of collected targets.
if (!illegal && numSuc > 1 && !B->hasNPredecessors(UnionSet.size())) {
// New intermediate block, created in BB's parent function.
BasicBlock *Ins =
BasicBlock::Create(BB->getContext(), "tmpblk", BB->getParent());
IRBuilder<> Builder(Ins);
// For each phi in B: move the incoming values from the UnionSet blocks into
// a new phi in Ins, then feed that new phi into the original one.
for (auto &phi : B->phis()) {
auto nphi = Builder.CreatePHI(phi.getType(), 2);
SmallVector<BasicBlock *, 4> Blocks;

for (auto blk : UnionSet) {
nphi->addIncoming(phi.getIncomingValueForBlock(blk), blk);
phi.removeIncomingValue(blk, /*deleteifempty*/ false);
}

phi.addIncoming(nphi, Ins);
}
// Ins unconditionally branches to B.
Builder.CreateBr(B);
// Redirect every successor slot that pointed at B in the UnionSet blocks'
// terminators so that it points at Ins instead.
for (auto blk : UnionSet) {
auto term = blk->getTerminator();
for (unsigned Idx = 0, NumSuccessors = term->getNumSuccessors();
Idx != NumSuccessors; ++Idx)
if (term->getSuccessor(Idx) == B)
term->setSuccessor(Idx, Ins);
}
// The CFG changed: invalidate analyses and restart the scan from the top.
goto reset;
}
}
}
}
#endif

if (EnzymePrint)
llvm::errs() << "after simplification :\n" << *NewF << "\n";

Expand Down
Loading

0 comments on commit e1f76d6

Please sign in to comment.