diff --git a/enzyme/Enzyme/CApi.cpp b/enzyme/Enzyme/CApi.cpp index 209e3d3a02c4..081a665d80cf 100644 --- a/enzyme/Enzyme/CApi.cpp +++ b/enzyme/Enzyme/CApi.cpp @@ -23,6 +23,7 @@ //===----------------------------------------------------------------------===// #include "CApi.h" #include "EnzymeLogic.h" +#include "LibraryFuncs.h" #include "SCEV/TargetLibraryInfo.h" #include "llvm/ADT/Triple.h" @@ -184,6 +185,23 @@ void FreeTypeAnalysis(EnzymeTypeAnalysisRef TAR) { delete TA; } +void EnzymeRegisterAllocationHandler(char *Name, CustomShadowAlloc AHandle, + CustomShadowFree FHandle) { + shadowHandlers[std::string(Name)] = + [=](IRBuilder<> &B, CallInst *CI, + ArrayRef Args) -> llvm::Value * { + SmallVector refs; + for (auto a : Args) + refs.push_back(wrap(a)); + return unwrap(AHandle(wrap(&B), wrap(CI), Args.size(), refs.data())); + }; + shadowErasers[std::string(Name)] = [=](IRBuilder<> &B, Value *ToFree, + Function *AllocF) -> llvm::CallInst * { + return cast_or_null( + unwrap(FHandle(wrap(&B), wrap(ToFree), wrap(AllocF)))); + }; +} + LLVMValueRef EnzymeCreatePrimalAndGradient( LLVMValueRef todiff, CDIFFE_TYPE retType, CDIFFE_TYPE *constant_args, size_t constant_args_size, EnzymeTypeAnalysisRef TA, diff --git a/enzyme/Enzyme/CApi.h b/enzyme/Enzyme/CApi.h index c32ac6bc70ec..7ff0f54eb4a7 100644 --- a/enzyme/Enzyme/CApi.h +++ b/enzyme/Enzyme/CApi.h @@ -145,6 +145,14 @@ LLVMValueRef EnzymeExtractFunctionFromAugmentation(EnzymeAugmentedReturnPtr ret); LLVMTypeRef EnzymeExtractTapeTypeFromAugmentation(EnzymeAugmentedReturnPtr ret); +typedef LLVMValueRef (*CustomShadowAlloc)(LLVMBuilderRef, LLVMValueRef, + size_t /*numArgs*/, LLVMValueRef *); +typedef LLVMValueRef (*CustomShadowFree)(LLVMBuilderRef, LLVMValueRef, + LLVMValueRef); + +void EnzymeRegisterAllocationHandler(char *Name, CustomShadowAlloc AHandle, + CustomShadowFree FHandle); + #ifdef __cplusplus } #endif diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index bed206d826d7..0c6ab462c27b 100644 --- a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp @@ -45,6 +45,13 @@ #include +std::map &, CallInst *, + ArrayRef)>> + shadowHandlers; +std::map &, Value *, Function *)>> + shadowErasers; + llvm::cl::opt EnzymeNewCache("enzyme-new-cache", cl::init(true), cl::Hidden, cl::desc("Use new cache decision algorithm")); diff --git a/enzyme/Enzyme/GradientUtils.h b/enzyme/Enzyme/GradientUtils.h index 6c2b1ae44526..499fee6cea62 100644 --- a/enzyme/Enzyme/GradientUtils.h +++ b/enzyme/Enzyme/GradientUtils.h @@ -73,6 +73,12 @@ using namespace llvm; enum class DerivativeMode { Forward, Reverse, Both }; +#include "llvm-c/Core.h" + +extern std::map &, CallInst *, ArrayRef)>> + shadowHandlers; + static inline std::string to_string(DerivativeMode mode) { switch (mode) { case DerivativeMode::Forward: @@ -525,6 +531,22 @@ class GradientUtils : public CacheUtility { for (unsigned i = 0; i < orig->getNumArgOperands(); ++i) { args.push_back(getNewFromOriginal(orig->getArgOperand(i))); } + + if (shadowHandlers.find(orig->getCalledFunction()->getName().str()) != + shadowHandlers.end()) { + Value *anti = shadowHandlers[orig->getCalledFunction()->getName().str()]( + bb, orig, args); + invertedPointers[orig] = anti; + // assert(placeholder != anti); + bb.SetInsertPoint(placeholder->getNextNode()); + replaceAWithB(placeholder, anti); + erase(placeholder); + + anti = cacheForReverse(bb, anti, idx); + invertedPointers[orig] = anti; + return anti; + } + Value *anti = bb.CreateCall(orig->getCalledFunction(), args, orig->getName() + "'mi"); cast(anti)->setAttributes(orig->getAttributes()); @@ -575,16 +597,22 @@ class GradientUtils : public CacheUtility { *orig); } } - auto dst_arg = bb.CreateBitCast( - anti, Type::getInt8PtrTy(orig->getContext(), - anti->getType()->getPointerAddressSpace())); + + Value *dst_arg = anti; + + dst_arg = bb.CreateBitCast( + dst_arg, + Type::getInt8PtrTy(orig->getContext(), + anti->getType()->getPointerAddressSpace())); + auto val_arg = ConstantInt::get(Type::getInt8Ty(orig->getContext()), 0); Value *size; // todo check if this memset is legal and if a write barrier is needed - if (orig->getCalledFunction()->getName() == "julia.gc_alloc_obj") + if (orig->getCalledFunction()->getName() == "julia.gc_alloc_obj") { size = args[1]; - else + } else { size = args[0]; + } auto len_arg = bb.CreateZExtOrTrunc(size, Type::getInt64Ty(orig->getContext())); auto volatile_arg = ConstantInt::getFalse(orig->getContext()); diff --git a/enzyme/Enzyme/LibraryFuncs.h b/enzyme/Enzyme/LibraryFuncs.h index ce00580ac1bf..25ec662f346e 100644 --- a/enzyme/Enzyme/LibraryFuncs.h +++ b/enzyme/Enzyme/LibraryFuncs.h @@ -28,6 +28,15 @@ #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" +extern std::map &, llvm::CallInst *, + llvm::ArrayRef)>> + shadowHandlers; +extern std::map &, llvm::Value *, llvm::Function *)>> + shadowErasers; + /// Return whether a given function is a known C/C++ memory allocation function /// For updating below one should read MemoryBuiltins.cpp, TargetLibraryInfo.cpp static inline bool isAllocationFunction(const llvm::Function &F, @@ -38,6 +47,9 @@ static inline bool isAllocationFunction(const llvm::Function &F, return true; if (F.getName() == "julia.gc_alloc_obj") return true; + if (shadowHandlers.find(F.getName().str()) != shadowHandlers.end()) + return true; + using namespace llvm; llvm::LibFunc libfunc; if (!TLI.getLibFunc(F, libfunc)) @@ -193,8 +205,12 @@ freeKnownAllocation(llvm::IRBuilder<> &builder, llvm::Value *tofree, allocationfn.getName() == "__rust_alloc_zeroed") { llvm_unreachable("todo - hook in rust allocation fns"); } - if (allocationfn.getName() == "julia.gc_alloc_obj") { + if (allocationfn.getName() == "julia.gc_alloc_obj") return nullptr; + + if (shadowErasers.find(allocationfn.getName().str()) != shadowErasers.end()) { + return shadowErasers[allocationfn.getName().str()](builder, tofree, + &allocationfn); } llvm::LibFunc libfunc; diff --git a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp index 8fcbbad88051..bb85fb8e7e46 100644 --- a/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp +++ b/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp @@ -2500,26 +2500,7 @@ void TypeAnalyzer::visitCallInst(CallInst &call) { analyzeFuncTypes<__VA_ARGS__>(::fn, call, *this); \ return; \ } - // All these are always valid => no direction check - // CONSIDER(malloc) - // TODO consider handling other allocation functions integer inputs - if (isAllocationFunction(*ci, interprocedural.TLI)) { - size_t Idx = 0; - for (auto &Arg : ci->args()) { - if (Arg.getType()->isIntegerTy()) { - updateAnalysis(call.getOperand(Idx), - TypeTree(BaseType::Integer).Only(-1), &call); - } - Idx++; - } - assert(ci->getReturnType()->isPointerTy()); - updateAnalysis(&call, TypeTree(BaseType::Pointer).Only(-1), &call); - return; - } - if (ci->getName().startswith("_ZN3std2io5stdio6_print") || - ci->getName().startswith("_ZN4core3fmt")) { - return; - } + auto customrule = interprocedural.CustomRules.find(ci->getName().str()); if (customrule != interprocedural.CustomRules.end()) { auto returnAnalysis = getAnalysis(&call); @@ -2544,6 +2525,26 @@ void TypeAnalyzer::visitCallInst(CallInst &call) { } return; } + // All these are always valid => no direction check + // CONSIDER(malloc) + // TODO consider handling other allocation functions integer inputs + if (isAllocationFunction(*ci, interprocedural.TLI)) { + size_t Idx = 0; + for (auto &Arg : ci->args()) { + if (Arg.getType()->isIntegerTy()) { + updateAnalysis(call.getOperand(Idx), + TypeTree(BaseType::Integer).Only(-1), &call); + } + Idx++; + } + assert(ci->getReturnType()->isPointerTy()); + updateAnalysis(&call, TypeTree(BaseType::Pointer).Only(-1), &call); + return; + } + if (ci->getName().startswith("_ZN3std2io5stdio6_print") || + ci->getName().startswith("_ZN4core3fmt")) { + return; + } /// MPI if (ci->getName() == "MPI_Init") { TypeTree ptrint;