Skip to content

Commit

Permalink
Additional julia updates
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed May 17, 2021
1 parent 0b2624f commit a18e093
Show file tree
Hide file tree
Showing 6 changed files with 104 additions and 26 deletions.
18 changes: 18 additions & 0 deletions enzyme/Enzyme/CApi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
//===----------------------------------------------------------------------===//
#include "CApi.h"
#include "EnzymeLogic.h"
#include "LibraryFuncs.h"
#include "SCEV/TargetLibraryInfo.h"

#include "llvm/ADT/Triple.h"
Expand Down Expand Up @@ -184,6 +185,23 @@ void FreeTypeAnalysis(EnzymeTypeAnalysisRef TAR) {
delete TA;
}

void EnzymeRegisterAllocationHandler(char *Name, CustomShadowAlloc AHandle,
CustomShadowFree FHandle) {
shadowHandlers[std::string(Name)] =
[=](IRBuilder<> &B, CallInst *CI,
ArrayRef<Value *> Args) -> llvm::Value * {
SmallVector<LLVMValueRef, 3> refs;
for (auto a : Args)
refs.push_back(wrap(a));
return unwrap(AHandle(wrap(&B), wrap(CI), Args.size(), refs.data()));
};
shadowErasers[std::string(Name)] = [=](IRBuilder<> &B, Value *ToFree,
Function *AllocF) -> llvm::CallInst * {
return cast_or_null<CallInst>(
unwrap(FHandle(wrap(&B), wrap(ToFree), wrap(AllocF))));
};
}

LLVMValueRef EnzymeCreatePrimalAndGradient(
LLVMValueRef todiff, CDIFFE_TYPE retType, CDIFFE_TYPE *constant_args,
size_t constant_args_size, EnzymeTypeAnalysisRef TA,
Expand Down
8 changes: 8 additions & 0 deletions enzyme/Enzyme/CApi.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,14 @@ LLVMValueRef
EnzymeExtractFunctionFromAugmentation(EnzymeAugmentedReturnPtr ret);
LLVMTypeRef EnzymeExtractTapeTypeFromAugmentation(EnzymeAugmentedReturnPtr ret);

typedef LLVMValueRef (*CustomShadowAlloc)(LLVMBuilderRef, LLVMValueRef,
size_t /*numArgs*/, LLVMValueRef *);
typedef LLVMValueRef (*CustomShadowFree)(LLVMBuilderRef, LLVMValueRef,
LLVMValueRef);

void EnzymeRegisterAllocationHandler(char *Name, CustomShadowAlloc AHandle,
CustomShadowFree FHandle);

#ifdef __cplusplus
}
#endif
Expand Down
7 changes: 7 additions & 0 deletions enzyme/Enzyme/GradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,13 @@

#include <algorithm>

std::map<std::string, std::function<llvm::Value *(IRBuilder<> &, CallInst *,
ArrayRef<Value *>)>>
shadowHandlers;
std::map<std::string,
std::function<llvm::CallInst *(IRBuilder<> &, Value *, Function *)>>
shadowErasers;

llvm::cl::opt<bool>
EnzymeNewCache("enzyme-new-cache", cl::init(true), cl::Hidden,
cl::desc("Use new cache decision algorithm"));
Expand Down
38 changes: 33 additions & 5 deletions enzyme/Enzyme/GradientUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,12 @@ using namespace llvm;

enum class DerivativeMode { Forward, Reverse, Both };

#include "llvm-c/Core.h"

extern std::map<std::string, std::function<llvm::Value *(
IRBuilder<> &, CallInst *, ArrayRef<Value *>)>>
shadowHandlers;

static inline std::string to_string(DerivativeMode mode) {
switch (mode) {
case DerivativeMode::Forward:
Expand Down Expand Up @@ -525,6 +531,22 @@ class GradientUtils : public CacheUtility {
for (unsigned i = 0; i < orig->getNumArgOperands(); ++i) {
args.push_back(getNewFromOriginal(orig->getArgOperand(i)));
}

if (shadowHandlers.find(orig->getCalledFunction()->getName().str()) !=
shadowHandlers.end()) {
Value *anti = shadowHandlers[orig->getCalledFunction()->getName().str()](
bb, orig, args);
invertedPointers[orig] = anti;
// assert(placeholder != anti);
bb.SetInsertPoint(placeholder->getNextNode());
replaceAWithB(placeholder, anti);
erase(placeholder);

anti = cacheForReverse(bb, anti, idx);
invertedPointers[orig] = anti;
return anti;
}

Value *anti =
bb.CreateCall(orig->getCalledFunction(), args, orig->getName() + "'mi");
cast<CallInst>(anti)->setAttributes(orig->getAttributes());
Expand Down Expand Up @@ -575,16 +597,22 @@ class GradientUtils : public CacheUtility {
*orig);
}
}
auto dst_arg = bb.CreateBitCast(
anti, Type::getInt8PtrTy(orig->getContext(),
anti->getType()->getPointerAddressSpace()));

Value *dst_arg = anti;

dst_arg = bb.CreateBitCast(
dst_arg,
Type::getInt8PtrTy(orig->getContext(),
anti->getType()->getPointerAddressSpace()));

auto val_arg = ConstantInt::get(Type::getInt8Ty(orig->getContext()), 0);
Value *size;
// todo check if this memset is legal and if a write barrier is needed
if (orig->getCalledFunction()->getName() == "julia.gc_alloc_obj")
if (orig->getCalledFunction()->getName() == "julia.gc_alloc_obj") {
size = args[1];
else
} else {
size = args[0];
}
auto len_arg =
bb.CreateZExtOrTrunc(size, Type::getInt64Ty(orig->getContext()));
auto volatile_arg = ConstantInt::getFalse(orig->getContext());
Expand Down
18 changes: 17 additions & 1 deletion enzyme/Enzyme/LibraryFuncs.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,15 @@
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"

extern std::map<std::string, std::function<llvm::Value *(
llvm::IRBuilder<> &, llvm::CallInst *,
llvm::ArrayRef<llvm::Value *>)>>
shadowHandlers;
extern std::map<std::string,
std::function<llvm::CallInst *(
llvm::IRBuilder<> &, llvm::Value *, llvm::Function *)>>
shadowErasers;

/// Return whether a given function is a known C/C++ memory allocation function
/// For updating below one should read MemoryBuiltins.cpp, TargetLibraryInfo.cpp
static inline bool isAllocationFunction(const llvm::Function &F,
Expand All @@ -38,6 +47,9 @@ static inline bool isAllocationFunction(const llvm::Function &F,
return true;
if (F.getName() == "julia.gc_alloc_obj")
return true;
if (shadowHandlers.find(F.getName().str()) != shadowHandlers.end())
return true;

using namespace llvm;
llvm::LibFunc libfunc;
if (!TLI.getLibFunc(F, libfunc))
Expand Down Expand Up @@ -193,8 +205,12 @@ freeKnownAllocation(llvm::IRBuilder<> &builder, llvm::Value *tofree,
allocationfn.getName() == "__rust_alloc_zeroed") {
llvm_unreachable("todo - hook in rust allocation fns");
}
if (allocationfn.getName() == "julia.gc_alloc_obj") {
if (allocationfn.getName() == "julia.gc_alloc_obj")
return nullptr;

if (shadowErasers.find(allocationfn.getName().str()) != shadowErasers.end()) {
return shadowErasers[allocationfn.getName().str()](builder, tofree,
&allocationfn);
}

llvm::LibFunc libfunc;
Expand Down
41 changes: 21 additions & 20 deletions enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2500,26 +2500,7 @@ void TypeAnalyzer::visitCallInst(CallInst &call) {
analyzeFuncTypes<__VA_ARGS__>(::fn, call, *this); \
return; \
}
// All these are always valid => no direction check
// CONSIDER(malloc)
// TODO consider handling other allocation functions integer inputs
if (isAllocationFunction(*ci, interprocedural.TLI)) {
size_t Idx = 0;
for (auto &Arg : ci->args()) {
if (Arg.getType()->isIntegerTy()) {
updateAnalysis(call.getOperand(Idx),
TypeTree(BaseType::Integer).Only(-1), &call);
}
Idx++;
}
assert(ci->getReturnType()->isPointerTy());
updateAnalysis(&call, TypeTree(BaseType::Pointer).Only(-1), &call);
return;
}
if (ci->getName().startswith("_ZN3std2io5stdio6_print") ||
ci->getName().startswith("_ZN4core3fmt")) {
return;
}

auto customrule = interprocedural.CustomRules.find(ci->getName().str());
if (customrule != interprocedural.CustomRules.end()) {
auto returnAnalysis = getAnalysis(&call);
Expand All @@ -2544,6 +2525,26 @@ void TypeAnalyzer::visitCallInst(CallInst &call) {
}
return;
}
// All these are always valid => no direction check
// CONSIDER(malloc)
// TODO consider handling other allocation functions integer inputs
if (isAllocationFunction(*ci, interprocedural.TLI)) {
size_t Idx = 0;
for (auto &Arg : ci->args()) {
if (Arg.getType()->isIntegerTy()) {
updateAnalysis(call.getOperand(Idx),
TypeTree(BaseType::Integer).Only(-1), &call);
}
Idx++;
}
assert(ci->getReturnType()->isPointerTy());
updateAnalysis(&call, TypeTree(BaseType::Pointer).Only(-1), &call);
return;
}
if (ci->getName().startswith("_ZN3std2io5stdio6_print") ||
ci->getName().startswith("_ZN4core3fmt")) {
return;
}
/// MPI
if (ci->getName() == "MPI_Init") {
TypeTree ptrint;
Expand Down

0 comments on commit a18e093

Please sign in to comment.