diff --git a/Readme.md b/Readme.md index 7dd210b8b7e..544e12dfc5d 100644 --- a/Readme.md +++ b/Readme.md @@ -39,6 +39,10 @@ brew install enzyme ``` spack install enzyme ``` +[Nix](https://nixos.org/) +``` +nix-shell -p enzyme +``` To get involved or if you have questions, please join our [mailing list](https://groups.google.com/d/forum/enzyme-dev). diff --git a/enzyme/.bazelversion b/enzyme/.bazelversion new file mode 100644 index 00000000000..f22d756da39 --- /dev/null +++ b/enzyme/.bazelversion @@ -0,0 +1 @@ +6.5.0 diff --git a/enzyme/Enzyme/ActivityAnalysis.cpp b/enzyme/Enzyme/ActivityAnalysis.cpp index 2a22512732b..44e88abe351 100644 --- a/enzyme/Enzyme/ActivityAnalysis.cpp +++ b/enzyme/Enzyme/ActivityAnalysis.cpp @@ -579,6 +579,11 @@ bool ActivityAnalyzer::isFunctionArgumentConstant(CallInst *CI, Value *val) { if (Name == "jl_reshape_array" || Name == "ijl_reshape_array") return val != CI->getArgOperand(1); + // Only the 0-th arg impacts activity + if (Name == "jl_genericmemory_copy_slice" || + Name == "ijl_genericmemory_copy_slice") + return val != CI->getArgOperand(0); + // Allocations, deallocations, and c++ guards don't impact the activity // of arguments if (isAllocationFunction(Name, TLI) || isDeallocationFunction(Name, TLI)) @@ -660,6 +665,13 @@ static inline void propagateArgumentInformation( return; } + // Only the 0-th arg impacts activity + if (Name == "jl_genericmemory_copy_slice" || + Name == "ijl_genericmemory_copy_slice") { + propagateFromOperand(CI.getArgOperand(0)); + return; + } + // Only the 1-th arg impacts activity if (Name == "jl_reshape_array" || Name == "ijl_reshape_array") { propagateFromOperand(CI.getArgOperand(1)); @@ -1554,6 +1566,26 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) { ReEvaluateValueIfInactiveValue[II->getOperand(0)].insert(TmpOrig); } } + } else if (auto RMW = dyn_cast(TmpOrig)) { + if (directions == UP) { + if (isConstantValue(TR, RMW->getPointerOperand())) { + InsertConstantValue(TR, 
Val); + return true; + } + } else { + if (UpHypothesis->isConstantValue(TR, RMW->getPointerOperand())) { + InsertConstantValue(TR, Val); + insertConstantsFrom(TR, *UpHypothesis); + return true; + } + } + if (EnzymeEnableRecursiveHypotheses) { + ReEvaluateValueIfInactiveValue[RMW->getPointerOperand()].insert(Val); + if (TmpOrig != Val) { + ReEvaluateValueIfInactiveValue[RMW->getPointerOperand()].insert( + TmpOrig); + } + } } else if (auto op = dyn_cast(TmpOrig)) { if (isInactiveCall(*op) || op->hasFnAttr("enzyme_inactive_val") || op->getAttributes().hasAttribute(llvm::AttributeList::ReturnIndex, @@ -1940,7 +1972,7 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) { isRefSet(AARes)) { if (EnzymePrintActivity) llvm::errs() << "potential active load: " << *I << "\n"; - if (isa(I) || isNVLoad(I)) { + if (isa(I) || isNVLoad(I) || isa(I)) { // If the ref'ing value is a load check if the loaded value is // active if (!Hypothesis->isConstantValue(TR, I)) { @@ -2696,6 +2728,11 @@ bool ActivityAnalyzer::isValueInactiveFromUsers(TypeResults const &TR, if (AllocaSet.count(TmpOrig)) { continue; } + // We are literally storing our value into ourselves [or relevant + // derived pointer] + if (TmpOrig == val) { + continue; + } if (isa(TmpOrig)) { newAllocaSet.insert(TmpOrig); continue; @@ -2797,8 +2834,16 @@ bool ActivityAnalyzer::isValueInactiveFromUsers(TypeResults const &TR, if (isa(TmpOrig) || isAllocationCall(TmpOrig, TLI)) { done.insert( std::make_tuple((User *)SI, SI->getPointerOperand(), UA)); + // If we are capturing a variable v, we need to check any loads or + // stores into that variable, even if we are checking only for + // stores. 
+ auto UA2 = UA; + if (UA == UseActivity::OnlyStores || + UA == UseActivity::OnlyNonPointerStores || + UA == UseActivity::AllStores) + UA2 = UseActivity::None; for (const auto a : TmpOrig->users()) { - todo.push_back(std::make_tuple(a, TmpOrig, UA)); + todo.push_back(std::make_tuple(a, TmpOrig, UA2)); } AllocaSet.insert(TmpOrig); if (EnzymePrintActivity) diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h index 96b8494302e..655bdca6943 100644 --- a/enzyme/Enzyme/AdjointGenerator.h +++ b/enzyme/Enzyme/AdjointGenerator.h @@ -1380,31 +1380,33 @@ class AdjointGenerator : public llvm::InstVisitor { ss << "Cannot deduce adding type (cast) of " << I; EmitNoTypeError(str, I, gutils, Builder2); } - assert(FT); - auto rule = [&](Value *dif) { - if (I.getOpcode() == CastInst::CastOps::FPTrunc || - I.getOpcode() == CastInst::CastOps::FPExt) { - return Builder2.CreateFPCast(dif, op0->getType()); - } else if (I.getOpcode() == CastInst::CastOps::BitCast) { - return Builder2.CreateBitCast(dif, op0->getType()); - } else if (I.getOpcode() == CastInst::CastOps::Trunc) { - // TODO CHECK THIS - return Builder2.CreateZExt(dif, op0->getType()); - } else { - std::string s; - llvm::raw_string_ostream ss(s); - ss << *I.getParent()->getParent() << "\n"; - ss << "cannot handle above cast " << I << "\n"; - EmitNoDerivativeError(ss.str(), I, gutils, Builder2); - return (llvm::Value *)UndefValue::get(op0->getType()); - } - }; + if (FT) { + + auto rule = [&](Value *dif) { + if (I.getOpcode() == CastInst::CastOps::FPTrunc || + I.getOpcode() == CastInst::CastOps::FPExt) { + return Builder2.CreateFPCast(dif, op0->getType()); + } else if (I.getOpcode() == CastInst::CastOps::BitCast) { + return Builder2.CreateBitCast(dif, op0->getType()); + } else if (I.getOpcode() == CastInst::CastOps::Trunc) { + // TODO CHECK THIS + return Builder2.CreateZExt(dif, op0->getType()); + } else { + std::string s; + llvm::raw_string_ostream ss(s); + ss << *I.getParent()->getParent() << "\n"; + 
ss << "cannot handle above cast " << I << "\n"; + EmitNoDerivativeError(ss.str(), I, gutils, Builder2); + return (llvm::Value *)UndefValue::get(op0->getType()); + } + }; - Value *dif = diffe(&I, Builder2); - Value *diff = applyChainRule(op0->getType(), Builder2, rule, dif); + Value *dif = diffe(&I, Builder2); + Value *diff = applyChainRule(op0->getType(), Builder2, rule, dif); - addToDiffe(orig_op0, diff, Builder2, FT); + addToDiffe(orig_op0, diff, Builder2, FT); + } } Type *diffTy = gutils->getShadowType(I.getType()); diff --git a/enzyme/Enzyme/CallDerivatives.cpp b/enzyme/Enzyme/CallDerivatives.cpp index 22df3dab9a7..bc5a095cfae 100644 --- a/enzyme/Enzyme/CallDerivatives.cpp +++ b/enzyme/Enzyme/CallDerivatives.cpp @@ -29,7 +29,8 @@ using namespace llvm; extern "C" { -void (*EnzymeShadowAllocRewrite)(LLVMValueRef, void *) = nullptr; +void (*EnzymeShadowAllocRewrite)(LLVMValueRef, void *, LLVMValueRef, uint64_t, + LLVMValueRef, uint8_t) = nullptr; } void AdjointGenerator::handleMPI(llvm::CallInst &call, llvm::Function *called, @@ -3014,6 +3015,9 @@ bool AdjointGenerator::handleKnownCallDerivatives( bb, anti, getIndex(&call, CacheType::Shadow, BuilderZ)); } else { bool zeroed = false; + uint64_t idx = 0; + Value *prev = nullptr; + ; auto rule = [&]() { Value *anti = bb.CreateCall(call.getFunctionType(), call.getCalledOperand(), @@ -3058,8 +3062,12 @@ bool AdjointGenerator::handleKnownCallDerivatives( if (funcName == "julia.gc_alloc_obj" || funcName == "jl_gc_alloc_typed" || funcName == "ijl_gc_alloc_typed") { - if (EnzymeShadowAllocRewrite) - EnzymeShadowAllocRewrite(wrap(anti), gutils); + if (EnzymeShadowAllocRewrite) { + bool used = unnecessaryInstructions.find(&call) == + unnecessaryInstructions.end(); + EnzymeShadowAllocRewrite(wrap(anti), gutils, wrap(&call), + idx, wrap(prev), used); + } } } if (Mode == DerivativeMode::ReverseModeCombined || @@ -3075,6 +3083,8 @@ bool AdjointGenerator::handleKnownCallDerivatives( zeroed = true; } } + idx++; + prev = anti; 
return anti; }; @@ -3224,6 +3234,8 @@ bool AdjointGenerator::handleKnownCallDerivatives( args.push_back(gutils->getNewFromOriginal(arg)); } + uint64_t idx = 0; + Value *prev = gutils->getNewFromOriginal(&call); auto rule = [&]() { SmallVector BundleTypes(args.size(), ValueType::Primal); @@ -3236,6 +3248,19 @@ bool AdjointGenerator::handleKnownCallDerivatives( CI->setCallingConv(call.getCallingConv()); CI->setTailCallKind(call.getTailCallKind()); CI->setDebugLoc(dbgLoc); + + if (funcName == "julia.gc_alloc_obj" || + funcName == "jl_gc_alloc_typed" || + funcName == "ijl_gc_alloc_typed") { + if (EnzymeShadowAllocRewrite) { + bool used = unnecessaryInstructions.find(&call) == + unnecessaryInstructions.end(); + EnzymeShadowAllocRewrite(wrap(CI), gutils, wrap(&call), idx, + wrap(prev), used); + } + } + idx++; + prev = CI; return CI; }; diff --git a/enzyme/Enzyme/DiffeGradientUtils.cpp b/enzyme/Enzyme/DiffeGradientUtils.cpp index eba0de11f54..e88e762e7f0 100644 --- a/enzyme/Enzyme/DiffeGradientUtils.cpp +++ b/enzyme/Enzyme/DiffeGradientUtils.cpp @@ -1179,9 +1179,13 @@ void DiffeGradientUtils::addToInvertedPtrDiffe( // the pointers and conditionally execute. 
if ((!isa(basePtr) && !isAllocationCall(basePtr, TLI)) && runtimeActivity && !merge) { - Value *shadow = Builder2.CreateICmpNE( - lookupM(getNewFromOriginal(origptr), Builder2), - lookupM(invertPointerM(origptr, Builder2), Builder2)); + Value *primal_val = lookupM(getNewFromOriginal(origptr), Builder2); + Value *shadow_val = + lookupM(invertPointerM(origptr, Builder2), Builder2); + if (getWidth() != 1) { + shadow_val = extractMeta(Builder2, shadow_val, 0); + } + Value *shadow = Builder2.CreateICmpNE(primal_val, shadow_val); BasicBlock *current = Builder2.GetInsertBlock(); BasicBlock *conditional = diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index c169365371a..414374dd779 100644 --- a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp @@ -3822,6 +3822,9 @@ bool GradientUtils::legalRecompute(const Value *val, } } + if (isa(val)) + return false; + if (auto phi = dyn_cast(val)) { if (auto uiv = hasUninverted(val)) { if (auto dli = dyn_cast_or_null(uiv)) { @@ -3835,6 +3838,13 @@ bool GradientUtils::legalRecompute(const Value *val, } } + auto found = fictiousPHIs.find(const_cast(phi)); + if (found != fictiousPHIs.end()) { + auto orig = found->second; + if (isa(orig)) + return false; + } + if (phi->getNumIncomingValues() == 0) { llvm::errs() << *oldFunc << "\n"; llvm::errs() << *newFunc << "\n"; diff --git a/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td b/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td index be139fb3d8b..72672a95940 100644 --- a/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td +++ b/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td @@ -192,4 +192,20 @@ def GenericAdjointOp : Enzyme_Op<"genericAdjoint", [AttrSizedOperandSegments]> { } +def BroadcastOp : Enzyme_Op<"broadcast"> { + let description = [{ + Broadcast the operand by adding extra dimensions with sizes provided by the `shape` attribute to the front. + For scalar operands, ranked tensor is created. + + NOTE: Only works for scalar and *ranked* tensor operands for now. 
+ }]; + + let arguments = (ins AnyType:$input, DenseI64ArrayAttr:$shape); + let results = (outs AnyRankedTensor:$output); + + let builders = [ + OpBuilder<(ins "Value":$input, "ArrayRef":$shape)> + ]; +} + #endif // ENZYME_OPS diff --git a/enzyme/Enzyme/MLIR/Dialect/Ops.cpp b/enzyme/Enzyme/MLIR/Dialect/Ops.cpp index 3e318542730..7e48db2d583 100644 --- a/enzyme/Enzyme/MLIR/Dialect/Ops.cpp +++ b/enzyme/Enzyme/MLIR/Dialect/Ops.cpp @@ -27,6 +27,7 @@ #include "mlir/IR/IRMapping.h" #include "mlir/IR/IntegerSet.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" #include "llvm/Support/Debug.h" @@ -191,3 +192,17 @@ LogicalResult BatchOp::verifySymbolUses(SymbolTableCollection &symbolTable) { return success(); } + +//===----------------------------------------------------------------------===// +// BroadcastOp +//===----------------------------------------------------------------------===// + +void BroadcastOp::build(OpBuilder &builder, OperationState &result, Value input, + ArrayRef shape) { + auto shapeAttr = builder.getDenseI64ArrayAttr(shape); + auto resultTy = input.getType(); + for (auto s : llvm::reverse(shape)) { + resultTy = resultTy.cast().getShadowType(s); + } + build(builder, result, resultTy, input, shapeAttr); +} diff --git a/enzyme/Enzyme/MLIR/Implementations/ArithAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/ArithAutoDiffOpInterfaceImpl.cpp index 9b27503d79d..8d3650969d0 100644 --- a/enzyme/Enzyme/MLIR/Implementations/ArithAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/ArithAutoDiffOpInterfaceImpl.cpp @@ -17,6 +17,7 @@ #include "Interfaces/GradientUtilsReverse.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/DialectRegistry.h" #include "mlir/Support/LogicalResult.h" @@ -69,3 +70,10 @@ void mlir::enzyme::registerArithDialectAutoDiffInterface( arith::ConstantOp::attachInterface(*context); }); } + +void 
mlir::enzyme::registerTensorDialectAutoDiffInterface( + DialectRegistry ®istry) { + registry.addExtension(+[](MLIRContext *context, tensor::TensorDialect *) { + registerInterfaces(context); + }); +} diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp index 355808cdbcc..f727dca2f87 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp @@ -74,7 +74,8 @@ void mlir::enzyme::detail::branchingForwardHandler(Operation *inst, newVals.push_back(gutils->invertPointerM(op, builder)); } else { Type retTy = - arg.getType().cast().getShadowType(); + arg.getType().cast().getShadowType( + gutils->width); auto toret = retTy.cast().createNullValue( builder, op.getLoc()); newVals.push_back(toret); @@ -146,7 +147,7 @@ LogicalResult mlir::enzyme::detail::memoryIdentityForwardHandler( if (auto iface = dyn_cast(operand.get().getType())) { if (!iface.isMutable()) { - Type retTy = iface.getShadowType(); + Type retTy = iface.getShadowType(gutils->width); auto toret = retTy.cast().createNullValue( builder, operand.get().getLoc()); newOperands.push_back(toret); @@ -346,7 +347,7 @@ LogicalResult mlir::enzyme::detail::controlFlowForwardHandler( << result.getType() << "\n"; return failure(); } - newOpResultTypes.push_back(typeIface.getShadowType()); + newOpResultTypes.push_back(typeIface.getShadowType(gutils->width)); } SmallVector newOperands; @@ -432,4 +433,5 @@ void mlir::enzyme::registerCoreDialectAutodiffInterfaces( enzyme::registerCFDialectAutoDiffInterface(registry); enzyme::registerLinalgDialectAutoDiffInterface(registry); enzyme::registerFuncDialectAutoDiffInterface(registry); + enzyme::registerTensorDialectAutoDiffInterface(registry); } diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h 
b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h index d6f28ccfc73..650f6c6326b 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h @@ -260,6 +260,7 @@ void registerCFDialectAutoDiffInterface(DialectRegistry ®istry); void registerLinalgDialectAutoDiffInterface(DialectRegistry ®istry); void registerMathDialectAutoDiffInterface(DialectRegistry ®istry); void registerFuncDialectAutoDiffInterface(DialectRegistry ®istry); +void registerTensorDialectAutoDiffInterface(DialectRegistry ®istry); void registerCoreDialectAutodiffInterfaces(DialectRegistry ®istry); diff --git a/enzyme/Enzyme/MLIR/Implementations/FuncAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/FuncAutoDiffOpInterfaceImpl.cpp index 5308304f5b7..54845c740d3 100644 --- a/enzyme/Enzyme/MLIR/Implementations/FuncAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/FuncAutoDiffOpInterfaceImpl.cpp @@ -73,7 +73,7 @@ class AutoDiffCallFwd fn, RetActivity, ArgActivity, gutils->TA, returnPrimal, mode, freeMemory, width, /* addedType */ nullptr, type_args, volatile_args, - /* augmented */ nullptr); + /* augmented */ nullptr, gutils->postpasses); SmallVector fwdArguments; @@ -173,7 +173,7 @@ class AutoDiffCallRev auto revFn = gutils->Logic.CreateReverseDiff( fn, RetActivity, ArgActivity, gutils->TA, returnPrimal, returnShadow, mode, freeMemory, width, /*addedType*/ nullptr, type_args, - volatile_args, /*augmented*/ nullptr); + volatile_args, /*augmented*/ nullptr, gutils->postpasses); SmallVector revArguments; diff --git a/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp b/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp index 8a9057a5853..c212a89398f 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/CloneFunction.cpp @@ -258,9 +258,11 @@ FunctionOpInterface CloneFunctionWithReturns( mlir::Value val = 
blk.getArgument(i); mlir::Value dval; if (i == ArgActivity.size() - 1) - dval = blk.addArgument(val.getType(), val.getLoc()); + dval = blk.addArgument(getShadowType(val.getType(), width), + val.getLoc()); else - dval = blk.insertArgument(blk.args_begin() + i + 1, val.getType(), + dval = blk.insertArgument(blk.args_begin() + i + 1, + getShadowType(val.getType(), width), val.getLoc()); ptrInputs.map(oval, dval); } diff --git a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp index fbd337813bc..7a5770ccdaa 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp @@ -13,6 +13,9 @@ #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Dominance.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Pass/PassRegistry.h" + #include "llvm/ADT/BreadthFirstIterator.h" #include "EnzymeLogic.h" @@ -78,7 +81,8 @@ FunctionOpInterface mlir::enzyme::MEnzymeLogic::CreateForwardDiff( std::vector ArgActivity, MTypeAnalysis &TA, std::vector returnPrimals, DerivativeMode mode, bool freeMemory, size_t width, mlir::Type addedType, MFnTypeInfo type_args, - std::vector volatile_args, void *augmented) { + std::vector volatile_args, void *augmented, + llvm::StringRef postpasses) { if (fn.getFunctionBody().empty()) { llvm::errs() << fn << "\n"; llvm_unreachable("Differentiating empty function"); @@ -105,7 +109,7 @@ FunctionOpInterface mlir::enzyme::MEnzymeLogic::CreateForwardDiff( auto gutils = MDiffeGradientUtils::CreateFromClone( *this, mode, width, fn, TA, type_args, returnPrimalsP, returnShadowsP, RetActivity, ArgActivity, addedType, - /*omp*/ false); + /*omp*/ false, postpasses); ForwardCachedFunctions[tup] = gutils->newFunc; insert_or_assign2( @@ -195,10 +199,19 @@ FunctionOpInterface mlir::enzyme::MEnzymeLogic::CreateForwardDiff( if (!valid) return nullptr; - // if (PostOpt) - // PPC.optimizeIntermediate(nf); - // if 
(EnzymePrint) { - // llvm::errs() << nf << "\n"; - //} + if (postpasses != "") { + mlir::PassManager pm(nf->getContext()); + std::string error_message; + // llvm::raw_string_ostream error_stream(error_message); + mlir::LogicalResult result = mlir::parsePassPipeline(postpasses, pm); + if (mlir::failed(result)) { + return nullptr; + } + + if (!mlir::succeeded(pm.run(nf))) { + return nullptr; + } + } + return nf; } diff --git a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h index c8cad6eee27..aef498d5227 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h +++ b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h @@ -196,14 +196,17 @@ class MEnzymeLogic { std::vector returnPrimals, DerivativeMode mode, bool freeMemory, size_t width, mlir::Type addedType, MFnTypeInfo type_args, std::vector volatile_args, - void *augmented); - - FunctionOpInterface CreateReverseDiff( - FunctionOpInterface fn, std::vector retType, - std::vector constants, MTypeAnalysis &TA, - std::vector returnPrimals, std::vector returnShadows, - DerivativeMode mode, bool freeMemory, size_t width, mlir::Type addedType, - MFnTypeInfo type_args, std::vector volatile_args, void *augmented); + void *augmented, llvm::StringRef postpasses); + + FunctionOpInterface + CreateReverseDiff(FunctionOpInterface fn, std::vector retType, + std::vector constants, MTypeAnalysis &TA, + std::vector returnPrimals, + std::vector returnShadows, DerivativeMode mode, + bool freeMemory, size_t width, mlir::Type addedType, + MFnTypeInfo type_args, std::vector volatile_args, + void *augmented, llvm::StringRef postpasses); + void initializeShadowValues(SmallVector &dominatorToposortBlocks, MGradientUtilsReverse *gutils); diff --git a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp index 0812a7ccde5..7ca0e9ea72f 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp 
@@ -11,6 +11,8 @@ #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Pass/PassRegistry.h" #include "EnzymeLogic.h" #include "Interfaces/GradientUtils.h" @@ -182,7 +184,8 @@ FunctionOpInterface MEnzymeLogic::CreateReverseDiff( std::vector constants, MTypeAnalysis &TA, std::vector returnPrimals, std::vector returnShadows, DerivativeMode mode, bool freeMemory, size_t width, mlir::Type addedType, - MFnTypeInfo type_args, std::vector volatile_args, void *augmented) { + MFnTypeInfo type_args, std::vector volatile_args, void *augmented, + llvm::StringRef postpasses) { if (fn.getFunctionBody().empty()) { llvm::errs() << fn << "\n"; @@ -214,7 +217,7 @@ FunctionOpInterface MEnzymeLogic::CreateReverseDiff( MGradientUtilsReverse *gutils = MGradientUtilsReverse::CreateFromClone( *this, mode, width, fn, TA, type_args, returnPrimalsP, returnShadowsP, - retType, constants, addedType); + retType, constants, addedType, postpasses); ReverseCachedFunctions[tup] = gutils->newFunc; @@ -254,5 +257,19 @@ FunctionOpInterface MEnzymeLogic::CreateReverseDiff( if (!res.succeeded()) return nullptr; + if (postpasses != "") { + mlir::PassManager pm(nf->getContext()); + std::string error_message; + // llvm::raw_string_ostream error_stream(error_message); + mlir::LogicalResult result = mlir::parsePassPipeline(postpasses, pm); + if (mlir::failed(result)) { + return nullptr; + } + + if (!mlir::succeeded(pm.run(nf))) { + return nullptr; + } + } + return nf; } diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp index 1ec4212dc5a..0dab1032af9 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp @@ -37,15 +37,15 @@ mlir::enzyme::MGradientUtils::MGradientUtils( ArrayRef ReturnActivity, ArrayRef ArgDiffeTypes_, IRMapping &originalToNewFn_, std::map &originalToNewFnOps_, - DerivativeMode 
mode, unsigned width, bool omp) + DerivativeMode mode, unsigned width, bool omp, llvm::StringRef postpasses) : newFunc(newFunc_), Logic(Logic), mode(mode), oldFunc(oldFunc_), invertedPointers(invertedPointers_), originalToNewFn(originalToNewFn_), originalToNewFnOps(originalToNewFnOps_), blocksNotForAnalysis(), activityAnalyzer(std::make_unique( blocksNotForAnalysis, constantvalues_, activevals_, ReturnActivity)), - TA(TA_), TR(TR_), omp(omp), returnPrimals(returnPrimals), - returnShadows(returnShadows), width(width), ArgDiffeTypes(ArgDiffeTypes_), - RetDiffeTypes(ReturnActivity) {} + TA(TA_), TR(TR_), omp(omp), postpasses(postpasses), + returnPrimals(returnPrimals), returnShadows(returnShadows), width(width), + ArgDiffeTypes(ArgDiffeTypes_), RetDiffeTypes(ReturnActivity) {} mlir::Value mlir::enzyme::MGradientUtils::getNewFromOriginal( const mlir::Value originst) const { @@ -108,7 +108,8 @@ mlir::Value mlir::enzyme::MGradientUtils::invertPointerM(mlir::Value v, return invertedPointers.lookupOrNull(v); if (isConstantValue(v)) { - if (auto iface = v.getType().dyn_cast()) { + if (auto iface = + getShadowType(v.getType()).dyn_cast()) { OpBuilder::InsertionGuard guard(Builder2); if (auto op = v.getDefiningOp()) Builder2.setInsertionPoint(getNewFromOriginal(op)); diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h index 1fac52caab3..085bd678f83 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h +++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h @@ -36,6 +36,7 @@ class MGradientUtils { MTypeAnalysis &TA; MTypeResults TR; bool omp; + llvm::StringRef postpasses; const llvm::ArrayRef returnPrimals; const llvm::ArrayRef returnShadows; @@ -58,7 +59,8 @@ class MGradientUtils { ArrayRef ArgDiffeTypes_, IRMapping &originalToNewFn_, std::map &originalToNewFnOps_, - DerivativeMode mode, unsigned width, bool omp); + DerivativeMode mode, unsigned width, bool omp, + llvm::StringRef postpasses); void erase(Operation *op) { 
op->erase(); } void replaceOrigOpWith(Operation *op, ValueRange vals) { for (auto &&[res, rep] : llvm::zip(op->getResults(), vals)) { @@ -113,11 +115,12 @@ class MDiffeGradientUtils : public MGradientUtils { ArrayRef RetActivity, ArrayRef ArgActivity, IRMapping &origToNew_, std::map &origToNewOps_, - DerivativeMode mode, unsigned width, bool omp) + DerivativeMode mode, unsigned width, bool omp, + llvm::StringRef postpasses) : MGradientUtils(Logic, newFunc_, oldFunc_, TA, TR, invertedPointers_, returnPrimals, returnShadows, constantvalues_, activevals_, RetActivity, ArgActivity, origToNew_, - origToNewOps_, mode, width, omp), + origToNewOps_, mode, width, omp, postpasses), initializationBlock(&*(newFunc.getFunctionBody().begin())) {} // Technically diffe constructor @@ -127,7 +130,7 @@ class MDiffeGradientUtils : public MGradientUtils { const llvm::ArrayRef returnPrimals, const llvm::ArrayRef returnShadows, ArrayRef RetActivity, ArrayRef ArgActivity, - mlir::Type additionalArg, bool omp) { + mlir::Type additionalArg, bool omp, llvm::StringRef postpasses) { std::string prefix; switch (mode) { @@ -163,7 +166,8 @@ class MDiffeGradientUtils : public MGradientUtils { return new MDiffeGradientUtils( Logic, newFunc, todiff, TA, TR, invertedPointers, returnPrimals, returnShadows, constant_values, nonconstant_values, RetActivity, - ArgActivity, originalToNew, originalToNewOps, mode, width, omp); + ArgActivity, originalToNew, originalToNewOps, mode, width, omp, + postpasses); } }; diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.cpp b/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.cpp index 793b073de0f..c9fe98bc5a5 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.cpp @@ -37,12 +37,12 @@ mlir::enzyme::MGradientUtilsReverse::MGradientUtilsReverse( ArrayRef ReturnActivity, ArrayRef ArgDiffeTypes_, IRMapping &originalToNewFn_, std::map &originalToNewFnOps_, - DerivativeMode mode_, 
unsigned width) + DerivativeMode mode_, unsigned width, StringRef postpasses) : MDiffeGradientUtils(Logic, newFunc_, oldFunc_, TA_, /*MTypeResults*/ {}, invertedPointers_, returnPrimals, returnShadows, constantvalues_, activevals_, ReturnActivity, ArgDiffeTypes_, originalToNewFn_, originalToNewFnOps_, - mode_, width, /*omp*/ false) {} + mode_, width, /*omp*/ false, postpasses) {} Type mlir::enzyme::MGradientUtilsReverse::getIndexCacheType() { Type indexType = getIndexType(); @@ -138,7 +138,7 @@ MGradientUtilsReverse *MGradientUtilsReverse::CreateFromClone( FunctionOpInterface todiff, MTypeAnalysis &TA, MFnTypeInfo &oldTypeInfo, const ArrayRef returnPrimals, const ArrayRef returnShadows, ArrayRef retType, ArrayRef constant_args, - mlir::Type additionalArg) { + mlir::Type additionalArg, llvm::StringRef postpasses) { std::string prefix; switch (mode_) { @@ -174,5 +174,5 @@ MGradientUtilsReverse *MGradientUtilsReverse::CreateFromClone( return new MGradientUtilsReverse( Logic, newFunc, todiff, TA, invertedPointers, returnPrimals, returnShadows, constant_values, nonconstant_values, retType, - constant_args, originalToNew, originalToNewOps, mode_, width); + constant_args, originalToNew, originalToNewOps, mode_, width, postpasses); } diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.h b/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.h index b6b63c6d13d..7f2d26cba2e 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.h +++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.h @@ -36,7 +36,8 @@ class MGradientUtilsReverse : public MDiffeGradientUtils { ArrayRef ArgDiffeTypes_, IRMapping &originalToNewFn_, std::map &originalToNewFnOps_, - DerivativeMode mode_, unsigned width); + DerivativeMode mode_, unsigned width, + llvm::StringRef postpasses); IRMapping mapReverseModeBlocks; @@ -64,12 +65,14 @@ class MGradientUtilsReverse : public MDiffeGradientUtils { void createReverseModeBlocks(Region &oldFunc, Region &newFunc); - static 
MGradientUtilsReverse *CreateFromClone( - MEnzymeLogic &Logic, DerivativeMode mode_, unsigned width, - FunctionOpInterface todiff, MTypeAnalysis &TA, MFnTypeInfo &oldTypeInfo, - const ArrayRef returnPrimals, const ArrayRef returnShadows, - llvm::ArrayRef retType, - llvm::ArrayRef constant_args, mlir::Type additionalArg); + static MGradientUtilsReverse * + CreateFromClone(MEnzymeLogic &Logic, DerivativeMode mode_, unsigned width, + FunctionOpInterface todiff, MTypeAnalysis &TA, + MFnTypeInfo &oldTypeInfo, const ArrayRef returnPrimals, + const ArrayRef returnShadows, + llvm::ArrayRef retType, + llvm::ArrayRef constant_args, + mlir::Type additionalArg, llvm::StringRef postpasses); }; } // namespace enzyme diff --git a/enzyme/Enzyme/MLIR/Passes/CMakeLists.txt b/enzyme/Enzyme/MLIR/Passes/CMakeLists.txt index 0445fc43064..99db4d80034 100644 --- a/enzyme/Enzyme/MLIR/Passes/CMakeLists.txt +++ b/enzyme/Enzyme/MLIR/Passes/CMakeLists.txt @@ -31,6 +31,7 @@ add_mlir_dialect_library(MLIREnzymeTransforms MLIRFuncDialect MLIRFuncTransforms MLIRGPUDialect + MLIRTensorDialect MLIRIR MLIRLLVMDialect MLIRMathDialect diff --git a/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp b/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp index d83532db35a..c91f5400fef 100644 --- a/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp +++ b/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp @@ -18,6 +18,7 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Builders.h" #include "mlir/Interfaces/FunctionInterfaces.h" +#include "mlir/Pass/PassManager.h" #define DEBUG_TYPE "enzyme" @@ -31,6 +32,18 @@ struct DifferentiatePass : public DifferentiatePassBase { void runOnOperation() override; + void getDependentDialects(DialectRegistry ®istry) const override { + mlir::OpPassManager pm; + mlir::LogicalResult result = mlir::parsePassPipeline(postpasses, pm); + if (!mlir::failed(result)) { + pm.getDependentDialects(registry); + } + + registry + .insert(); + } + static std::vector mode_from_fn(FunctionOpInterface fn, 
DerivativeMode mode) { std::vector retTypes; @@ -150,7 +163,7 @@ struct DifferentiatePass : public DifferentiatePassBase { FunctionOpInterface newFunc = Logic.CreateForwardDiff( fn, retType, constants, TA, returnPrimals, mode, freeMemory, width, /*addedType*/ nullptr, type_args, volatile_args, - /*augmented*/ nullptr); + /*augmented*/ nullptr, postpasses); if (!newFunc) return failure(); @@ -276,7 +289,7 @@ struct DifferentiatePass : public DifferentiatePassBase { Logic.CreateReverseDiff(fn, retType, arg_activities, TA, returnPrimals, returnShadows, mode, freeMemory, width, /*addedType*/ nullptr, type_args, volatile_args, - /*augmented*/ nullptr); + /*augmented*/ nullptr, postpasses); if (!newFunc) return failure(); diff --git a/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp b/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp index 629a567815e..1e01c8f87bc 100644 --- a/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp +++ b/enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp @@ -121,13 +121,13 @@ struct DifferentiateWrapperPass returnPrimal, mode, freeMemory, width, /*addedType*/ nullptr, type_args, volatile_args, - /*augmented*/ nullptr); + /*augmented*/ nullptr, ""); } else { newFunc = Logic.CreateReverseDiff( fn, RetActivity, ArgActivity, TA, returnPrimal, returnShadow, mode, freeMemory, width, /*addedType*/ nullptr, type_args, volatile_args, - /*augmented*/ nullptr); + /*augmented*/ nullptr, ""); } if (!newFunc) { signalPassFailure(); diff --git a/enzyme/Enzyme/MLIR/Passes/Passes.h b/enzyme/Enzyme/MLIR/Passes/Passes.h index ec674cd33bb..fff304a7e49 100644 --- a/enzyme/Enzyme/MLIR/Passes/Passes.h +++ b/enzyme/Enzyme/MLIR/Passes/Passes.h @@ -15,6 +15,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Complex/IR/Complex.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "Dialect/Dialect.h" @@ -80,6 +81,10 @@ namespace affine { class AffineDialect; } // end namespace affine +namespace tensor { +class TensorDialect; +} // end namespace tensor + namespace LLVM { 
class LLVMDialect; } // end namespace LLVM diff --git a/enzyme/Enzyme/MLIR/Passes/Passes.td b/enzyme/Enzyme/MLIR/Passes/Passes.td index a8d885f8ac9..d3494956a12 100644 --- a/enzyme/Enzyme/MLIR/Passes/Passes.td +++ b/enzyme/Enzyme/MLIR/Passes/Passes.td @@ -16,7 +16,17 @@ def DifferentiatePass : Pass<"enzyme"> { let dependentDialects = [ "arith::ArithDialect", "complex::ComplexDialect", - "cf::ControlFlowDialect" + "cf::ControlFlowDialect", + "tensor::TensorDialect", + ]; + let options = [ + Option< + /*C++ variable name=*/"postpasses", + /*CLI argument=*/"postpasses", + /*type=*/"std::string", + /*default=*/"", + /*description=*/"Optimization passes to apply to generated derivative functions" + >, ]; let constructor = "mlir::enzyme::createDifferentiatePass()"; } diff --git a/enzyme/Enzyme/Utils.cpp b/enzyme/Enzyme/Utils.cpp index 933a22304e2..f4655bac845 100644 --- a/enzyme/Enzyme/Utils.cpp +++ b/enzyme/Enzyme/Utils.cpp @@ -102,6 +102,10 @@ llvm::cl::opt EnzymeMemmoveWarning( llvm::cl::opt EnzymeRuntimeError( "enzyme-runtime-error", cl::init(false), cl::Hidden, cl::desc("Emit Runtime errors instead of compile time ones")); + +llvm::cl::opt EnzymeNonPower2Cache( + "enzyme-non-power2-cache", cl::init(false), cl::Hidden, + cl::desc("Enable caching of integers whose bit width is not a power of 2")); } void ZeroMemory(llvm::IRBuilder<> &Builder, llvm::Type *T, llvm::Value *obj, diff --git a/enzyme/Enzyme/Utils.h b/enzyme/Enzyme/Utils.h index 02ce4b8b47e..089a99c8691 100644 --- a/enzyme/Enzyme/Utils.h +++ b/enzyme/Enzyme/Utils.h @@ -91,6 +91,7 @@ enum class ErrorType { extern "C" { /// Print additional debug info relevant to performance extern llvm::cl::opt EnzymePrintPerf; +extern llvm::cl::opt EnzymeNonPower2Cache; extern llvm::cl::opt EnzymeStrongZero; extern llvm::cl::opt EnzymeBlasCopy; extern llvm::cl::opt EnzymeLapackCopy; @@ -1194,6 +1195,10 @@ static inline bool hasNoCache(llvm::Value *op) { } } } + if (auto IT = dyn_cast(op->getType())) + if
(!isPowerOf2_64(IT->getBitWidth()) && !EnzymeNonPower2Cache) + return true; + return false; } diff --git a/enzyme/test/MLIR/ForwardMode/batched_branch.mlir b/enzyme/test/MLIR/ForwardMode/batched_branch.mlir new file mode 100644 index 00000000000..d663eea5afe --- /dev/null +++ b/enzyme/test/MLIR/ForwardMode/batched_branch.mlir @@ -0,0 +1,26 @@ +// RUN: %eopt --enzyme %s | FileCheck %s + +module { + func.func @square(%x : f64, %y : f64) -> f64 { + %c = arith.cmpf ult, %x, %y : f64 + cf.cond_br %c, ^blk2(%x : f64), ^blk2(%y : f64) + + ^blk2(%r : f64): + return %r : f64 + } + func.func @dsq(%x : f64, %dx : tensor<2xf64>, %y : f64, %dy : tensor<2xf64>) -> tensor<2xf64> { + %r = enzyme.fwddiff @square(%x, %dx, %y, %dy) { activity=[#enzyme, #enzyme], ret_activity=[#enzyme], width=2 } : (f64, tensor<2xf64>, f64, tensor<2xf64>) -> (tensor<2xf64>) + return %r : tensor<2xf64> + } +} + +// CHECK: func.func @dsq(%[[arg0:.+]]: f64, %[[arg1:.+]]: tensor<2xf64>, %[[arg2:.+]]: f64, %[[arg3:.+]]: tensor<2xf64>) -> tensor<2xf64> { +// CHECK-NEXT: %[[i0:.+]] = call @fwddiffe2square(%[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]]) : (f64, tensor<2xf64>, f64, tensor<2xf64>) -> tensor<2xf64> +// CHECK-NEXT: return %[[i0]] : tensor<2xf64> +// CHECK-NEXT: } +// CHECK: func.func private @fwddiffe2square(%[[arg0:.+]]: f64, %[[arg1:.+]]: tensor<2xf64>, %[[arg2:.+]]: f64, %[[arg3:.+]]: tensor<2xf64>) -> tensor<2xf64> { +// CHECK-NEXT: %[[i0:.+]] = arith.cmpf ult, %[[arg0]], %[[arg2]] : f64 +// CHECK-NEXT: cf.cond_br %[[i0]], ^bb1(%[[arg0]], %[[arg1]] : f64, tensor<2xf64>), ^bb1(%[[arg2]], %[[arg3]] : f64, tensor<2xf64>) +// CHECK-NEXT: ^bb1(%[[i1:.+]]: f64, %[[i2:.+]]: tensor<2xf64>): // 2 preds: ^bb0, ^bb0 +// CHECK-NEXT: return %[[i2]] : tensor<2xf64> +// CHECK-NEXT: } diff --git a/enzyme/test/MLIR/ForwardMode/batched_for.mlir b/enzyme/test/MLIR/ForwardMode/batched_for.mlir new file mode 100644 index 00000000000..3ec17ec50f5 --- /dev/null +++ b/enzyme/test/MLIR/ForwardMode/batched_for.mlir @@ -0,0
+1,33 @@ +// RUN: %eopt --enzyme %s | FileCheck %s + +module { + func.func @square(%x : f64) -> f64 { + %cst = arith.constant 10.000000e+00 : f64 + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index + %r = scf.for %arg1 = %c0 to %c10 step %c1 iter_args(%arg2 = %cst) -> (f64) { + %n = arith.addf %arg2, %x : f64 + scf.yield %n : f64 + } + return %r : f64 + } + func.func @dsq(%x : f64, %dx : tensor<2xf64>) -> tensor<2xf64> { + %r = enzyme.fwddiff @square(%x, %dx) { activity=[#enzyme], ret_activity=[#enzyme], width=2 } : (f64, tensor<2xf64>) -> (tensor<2xf64>) + return %r : tensor<2xf64> + } +} + +// CHECK: func.func private @fwddiffe2square(%[[arg0:.+]]: f64, %[[arg1:.+]]: tensor<2xf64>) -> tensor<2xf64> { +// CHECK-DAG: %[[cst:.+]] = arith.constant dense<0.000000e+00> : tensor<2xf64> +// CHECK-DAG: %[[cst_0:.+]] = arith.constant 1.000000e+01 : f64 +// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[c10:.+]] = arith.constant 10 : index +// CHECK-NEXT: %[[i0:.+]]:2 = scf.for %[[arg2:.+]] = %[[c0]] to %[[c10]] step %[[c1]] iter_args(%[[arg3:.+]] = %[[cst_0]], %[[arg4:.+]] = %[[cst]]) -> (f64, tensor<2xf64>) { +// CHECK-NEXT: %[[i1:.+]] = arith.addf %[[arg4]], %[[arg1]] : tensor<2xf64> +// CHECK-NEXT: %[[i2:.+]] = arith.addf %[[arg3]], %[[arg0]] : f64 +// CHECK-NEXT: scf.yield %[[i2]], %[[i1]] : f64, tensor<2xf64> +// CHECK-NEXT: } +// CHECK-NEXT: return %[[i0]]#1 : tensor<2xf64> +// CHECK-NEXT: } diff --git a/enzyme/test/MLIR/ForwardMode/batched_if.mlir b/enzyme/test/MLIR/ForwardMode/batched_if.mlir new file mode 100644 index 00000000000..33c9e1b9fe8 --- /dev/null +++ b/enzyme/test/MLIR/ForwardMode/batched_if.mlir @@ -0,0 +1,43 @@ +// RUN: %eopt --enzyme %s | FileCheck %s + +module { + func.func @square(%x : f64, %c : i1) -> f64 { + %c2 = arith.constant 2.000000e+00 : f64 + %c10 = arith.constant 10.000000e+00 : f64 + %r:2 = scf.if %c -> (f64, f64) { + 
%mul = arith.mulf %x, %x : f64 + scf.yield %mul, %c2 : f64, f64 + } else { + %add = arith.addf %x, %x : f64 + scf.yield %add, %c10 : f64, f64 + } + %res = arith.mulf %r#0, %r#1 : f64 + return %res : f64 + } + func.func @dsq(%x : f64, %dx : tensor<2xf64>, %c : i1) -> tensor<2xf64> { + %r = enzyme.fwddiff @square(%x, %dx, %c) { activity=[#enzyme, #enzyme], ret_activity=[#enzyme], width=2 } : (f64, tensor<2xf64>, i1) -> (tensor<2xf64>) + return %r : tensor<2xf64> + } +} + +// CHECK: func.func private @fwddiffe2square(%[[arg0:.+]]: f64, %[[arg1:.+]]: tensor<2xf64>, %[[arg2:.+]]: i1) -> tensor<2xf64> { +// CHECK-DAG: %[[cst2:.+]] = arith.constant 2.000000e+00 : f64 +// CHECK-DAG: %[[cst10:.+]] = arith.constant 1.000000e+01 : f64 +// CHECK-NEXT: %[[r0:.+]]:3 = scf.if %[[arg2]] -> (f64, tensor<2xf64>, f64) { +// CHECK-NEXT: %[[t4:.+]] = "enzyme.broadcast"(%[[arg0]]) <{shape = array}> : (f64) -> tensor<2xf64> +// CHECK-NEXT: %[[t5:.+]] = arith.mulf %[[arg1]], %[[t4]] : tensor<2xf64> +// CHECK-NEXT: %[[t6:.+]] = "enzyme.broadcast"(%[[arg0]]) <{shape = array}> : (f64) -> tensor<2xf64> +// CHECK-NEXT: %[[t7:.+]] = arith.mulf %[[arg1]], %[[t6]] : tensor<2xf64> +// CHECK-NEXT: %[[t8:.+]] = arith.addf %[[t5]], %[[t7]] : tensor<2xf64> +// CHECK-NEXT: %[[t9:.+]] = arith.mulf %[[arg0]], %[[arg0]] : f64 +// CHECK-NEXT: scf.yield %[[t9]], %[[t8]], %[[cst2]] : f64, tensor<2xf64>, f64 +// CHECK-NEXT: } else { +// CHECK-NEXT: %[[e4:.+]] = arith.addf %[[arg1]], %[[arg1]] : tensor<2xf64> +// CHECK-NEXT: %[[e5:.+]] = arith.addf %[[arg0]], %[[arg0]] : f64 +// CHECK-NEXT: scf.yield %[[e5]], %[[e4]], %[[cst10]] : f64, tensor<2xf64>, f64 +// CHECK-NEXT: } +// CHECK-NEXT: %[[r1:.+]] = "enzyme.broadcast"(%[[r0]]#2) <{shape = array}> : (f64) -> tensor<2xf64> +// CHECK-NEXT: %[[r2:.+]] = arith.mulf %[[r0]]#1, %[[r1]] : tensor<2xf64> +// CHECK-NEXT: %[[r3:.+]] = arith.mulf %[[r0]]#0, %[[r0]]#2 : f64 +// CHECK-NEXT: return %[[r2]] : tensor<2xf64> +// CHECK-NEXT: } diff --git 
a/enzyme/test/MLIR/ForwardMode/batched_scalar.mlir b/enzyme/test/MLIR/ForwardMode/batched_scalar.mlir new file mode 100644 index 00000000000..d384bdd0933 --- /dev/null +++ b/enzyme/test/MLIR/ForwardMode/batched_scalar.mlir @@ -0,0 +1,26 @@ +// RUN: %eopt --enzyme %s | FileCheck %s + +module { + func.func @square(%x : f64) -> f64{ + %y = arith.mulf %x, %x : f64 + return %y : f64 + } + func.func @dsq(%x : f64, %dx : tensor<2xf64>) -> tensor<2xf64> { + %r = enzyme.fwddiff @square(%x, %dx) { activity=[#enzyme], ret_activity=[#enzyme], width=2 } : (f64, tensor<2xf64>) -> (tensor<2xf64>) + return %r : tensor<2xf64> + } +} + +// CHECK: func.func @dsq(%[[arg0:.+]]: f64, %[[arg1:.+]]: tensor<2xf64>) -> tensor<2xf64> { +// CHECK-NEXT: %[[i0:.+]] = call @fwddiffe2square(%[[arg0]], %[[arg1]]) : (f64, tensor<2xf64>) -> tensor<2xf64> +// CHECK-NEXT: return %[[i0]] : tensor<2xf64> +// CHECK-NEXT: } +// CHECK: func.func private @fwddiffe2square(%[[arg0:.+]]: f64, %[[arg1:.+]]: tensor<2xf64>) -> tensor<2xf64> { +// CHECK-NEXT: %[[s0:.+]] = "enzyme.broadcast"(%[[arg0]]) <{shape = array}> : (f64) -> tensor<2xf64> +// CHECK-NEXT: %[[i0:.+]] = arith.mulf %[[arg1]], %[[s0]] : tensor<2xf64> +// CHECK-NEXT: %[[s1:.+]] = "enzyme.broadcast"(%[[arg0]]) <{shape = array}> : (f64) -> tensor<2xf64> +// CHECK-NEXT: %[[i1:.+]] = arith.mulf %[[arg1]], %[[s1]] : tensor<2xf64> +// CHECK-NEXT: %[[i2:.+]] = arith.addf %[[i0]], %[[i1]] : tensor<2xf64> +// CHECK-NEXT: %[[i3:.+]] = arith.mulf %[[arg0]], %[[arg0]] : f64 +// CHECK-NEXT: return %[[i2]] : tensor<2xf64> +// CHECK-NEXT: } diff --git a/enzyme/test/MLIR/ForwardMode/batched_tensor.mlir b/enzyme/test/MLIR/ForwardMode/batched_tensor.mlir new file mode 100644 index 00000000000..2a565f9ff41 --- /dev/null +++ b/enzyme/test/MLIR/ForwardMode/batched_tensor.mlir @@ -0,0 +1,26 @@ +// RUN: %eopt --enzyme %s | FileCheck %s + +module { + func.func @square(%x : tensor<10xf64>) -> tensor<10xf64>{ + %y = arith.mulf %x, %x : tensor<10xf64> + return %y :
tensor<10xf64> + } + func.func @dsq(%x : tensor<10xf64>, %dx : tensor<2x10xf64>) -> tensor<2x10xf64> { + %r = enzyme.fwddiff @square(%x, %dx) { activity=[#enzyme], ret_activity=[#enzyme], width=2 } : (tensor<10xf64>, tensor<2x10xf64>) -> (tensor<2x10xf64>) + return %r : tensor<2x10xf64> + } +} + +// CHECK: func.func @dsq(%[[arg0:.+]]: tensor<10xf64>, %[[arg1:.+]]: tensor<2x10xf64>) -> tensor<2x10xf64> { +// CHECK-NEXT: %[[i0:.+]] = call @fwddiffe2square(%[[arg0]], %[[arg1]]) : (tensor<10xf64>, tensor<2x10xf64>) -> tensor<2x10xf64> +// CHECK-NEXT: return %[[i0]] : tensor<2x10xf64> +// CHECK-NEXT: } +// CHECK: func.func private @fwddiffe2square(%[[arg0:.+]]: tensor<10xf64>, %[[arg1:.+]]: tensor<2x10xf64>) -> tensor<2x10xf64> { +// CHECK-NEXT: %[[s0:.+]] = "enzyme.broadcast"(%[[arg0]]) <{shape = array}> : (tensor<10xf64>) -> tensor<2x10xf64> +// CHECK-NEXT: %[[i0:.+]] = arith.mulf %[[arg1]], %[[s0]] : tensor<2x10xf64> +// CHECK-NEXT: %[[s1:.+]] = "enzyme.broadcast"(%[[arg0]]) <{shape = array}> : (tensor<10xf64>) -> tensor<2x10xf64> +// CHECK-NEXT: %[[i1:.+]] = arith.mulf %[[arg1]], %[[s1]] : tensor<2x10xf64> +// CHECK-NEXT: %[[i2:.+]] = arith.addf %[[i0]], %[[i1]] : tensor<2x10xf64> +// CHECK-NEXT: %[[i3:.+]] = arith.mulf %[[arg0]], %[[arg0]] : tensor<10xf64> +// CHECK-NEXT: return %[[i2]] : tensor<2x10xf64> +// CHECK-NEXT: } diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index 900c5c813cd..dccbc7b7923 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -275,8 +275,19 @@ SmallVector prepareArgs(const Twine &curIndent, raw_ostream &os, os << ord; } if (!vecValue && !startsWith(ord, "local")) { - if (newFromOriginal && (!lookup || intrinsic != MLIRDerivatives)) + if (newFromOriginal && (!lookup || intrinsic != MLIRDerivatives)) { os << ")"; + if (intrinsic == MLIRDerivatives) { + os << ";\n"; + os << "if (gutils->width != 1) {\n" + << " " << argName << 
"_" << (idx - 1) + << " = builder.create(\n" + << " op.getLoc(),\n" + << " " << argName << "_" << (idx - 1) << ",\n" + << " llvm::SmallVector({gutils->width}));\n" + << "}"; + } + } if (lookup && intrinsic != MLIRDerivatives) os << ", " << builder << ")";