
Merge branch 'main' into remove-ops
Pangoraw authored Jan 3, 2025
2 parents b869aab + 7bc73fa commit cf998f8
Showing 34 changed files with 460 additions and 81 deletions.
4 changes: 4 additions & 0 deletions Readme.md
@@ -39,6 +39,10 @@ brew install enzyme
```
spack install enzyme
```
[Nix](https://nixos.org/)
```
nix-shell -p enzyme
```
To get involved or if you have questions, please join our [mailing list](https://groups.google.com/d/forum/enzyme-dev).
1 change: 1 addition & 0 deletions enzyme/.bazelversion
@@ -0,0 +1 @@
6.5.0
49 changes: 47 additions & 2 deletions enzyme/Enzyme/ActivityAnalysis.cpp
@@ -579,6 +579,11 @@ bool ActivityAnalyzer::isFunctionArgumentConstant(CallInst *CI, Value *val) {
if (Name == "jl_reshape_array" || Name == "ijl_reshape_array")
return val != CI->getArgOperand(1);

// Only the arg at index 0 impacts activity
if (Name == "jl_genericmemory_copy_slice" ||
Name == "ijl_genericmemory_copy_slice")
return val != CI->getArgOperand(0);

// Allocations, deallocations, and c++ guards don't impact the activity
// of arguments
if (isAllocationFunction(Name, TLI) || isDeallocationFunction(Name, TLI))
@@ -660,6 +665,13 @@ static inline void propagateArgumentInformation(
return;
}

// Only the arg at index 0 impacts activity
if (Name == "jl_genericmemory_copy_slice" ||
Name == "ijl_genericmemory_copy_slice") {
propagateFromOperand(CI.getArgOperand(0));
return;
}

// Only the arg at index 1 impacts activity
if (Name == "jl_reshape_array" || Name == "ijl_reshape_array") {
propagateFromOperand(CI.getArgOperand(1));
@@ -1554,6 +1566,26 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) {
ReEvaluateValueIfInactiveValue[II->getOperand(0)].insert(TmpOrig);
}
}
} else if (auto RMW = dyn_cast<AtomicRMWInst>(TmpOrig)) {
if (directions == UP) {
if (isConstantValue(TR, RMW->getPointerOperand())) {
InsertConstantValue(TR, Val);
return true;
}
} else {
if (UpHypothesis->isConstantValue(TR, RMW->getPointerOperand())) {
InsertConstantValue(TR, Val);
insertConstantsFrom(TR, *UpHypothesis);
return true;
}
}
if (EnzymeEnableRecursiveHypotheses) {
ReEvaluateValueIfInactiveValue[RMW->getPointerOperand()].insert(Val);
if (TmpOrig != Val) {
ReEvaluateValueIfInactiveValue[RMW->getPointerOperand()].insert(
TmpOrig);
}
}
} else if (auto op = dyn_cast<CallInst>(TmpOrig)) {
if (isInactiveCall(*op) || op->hasFnAttr("enzyme_inactive_val") ||
op->getAttributes().hasAttribute(llvm::AttributeList::ReturnIndex,
@@ -1940,7 +1972,7 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) {
isRefSet(AARes)) {
if (EnzymePrintActivity)
llvm::errs() << "potential active load: " << *I << "\n";
if (isa<LoadInst>(I) || isNVLoad(I)) {
if (isa<LoadInst>(I) || isNVLoad(I) || isa<AtomicRMWInst>(I)) {
// If the ref'ing value is a load check if the loaded value is
// active
if (!Hypothesis->isConstantValue(TR, I)) {
@@ -2696,6 +2728,11 @@ bool ActivityAnalyzer::isValueInactiveFromUsers(TypeResults const &TR,
if (AllocaSet.count(TmpOrig)) {
continue;
}
// We are storing the value into itself [or a pointer derived from it]
if (TmpOrig == val) {
continue;
}
if (isa<AllocaInst>(TmpOrig)) {
newAllocaSet.insert(TmpOrig);
continue;
@@ -2797,8 +2834,16 @@ bool ActivityAnalyzer::isValueInactiveFromUsers(TypeResults const &TR,
if (isa<AllocaInst>(TmpOrig) || isAllocationCall(TmpOrig, TLI)) {
done.insert(
std::make_tuple((User *)SI, SI->getPointerOperand(), UA));
// If we are capturing a variable v, we need to check any loads or
// stores into that variable, even if we are checking only for
// stores.
auto UA2 = UA;
if (UA == UseActivity::OnlyStores ||
UA == UseActivity::OnlyNonPointerStores ||
UA == UseActivity::AllStores)
UA2 = UseActivity::None;
for (const auto a : TmpOrig->users()) {
todo.push_back(std::make_tuple(a, TmpOrig, UA));
todo.push_back(std::make_tuple(a, TmpOrig, UA2));
}
AllocaSet.insert(TmpOrig);
if (EnzymePrintActivity)
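The ActivityAnalysis hunks above all apply one rule to a new Julia runtime call: for `jl_genericmemory_copy_slice`, as already for `jl_reshape_array`, exactly one operand can carry activity, so every other argument is reported constant and argument information propagates only from that operand. A minimal standalone sketch of that dispatch — `callArgIsConstant` is a hypothetical helper, not Enzyme's API:
```
#include <string>

// Returns true when argument `argIdx` of `callee` provably cannot affect
// activity, mirroring the checks in isFunctionArgumentConstant above.
bool callArgIsConstant(const std::string &callee, unsigned argIdx) {
  // Copying a slice of generic memory: only operand 0 can carry activity.
  if (callee == "jl_genericmemory_copy_slice" ||
      callee == "ijl_genericmemory_copy_slice")
    return argIdx != 0;
  // Reshaping an array: only operand 1 can carry activity.
  if (callee == "jl_reshape_array" || callee == "ijl_reshape_array")
    return argIdx != 1;
  return false; // unknown callee: conservatively assume the argument matters
}
```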
46 changes: 24 additions & 22 deletions enzyme/Enzyme/AdjointGenerator.h
@@ -1380,31 +1380,33 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
ss << "Cannot deduce adding type (cast) of " << I;
EmitNoTypeError(str, I, gutils, Builder2);
}
assert(FT);

auto rule = [&](Value *dif) {
if (I.getOpcode() == CastInst::CastOps::FPTrunc ||
I.getOpcode() == CastInst::CastOps::FPExt) {
return Builder2.CreateFPCast(dif, op0->getType());
} else if (I.getOpcode() == CastInst::CastOps::BitCast) {
return Builder2.CreateBitCast(dif, op0->getType());
} else if (I.getOpcode() == CastInst::CastOps::Trunc) {
// TODO CHECK THIS
return Builder2.CreateZExt(dif, op0->getType());
} else {
std::string s;
llvm::raw_string_ostream ss(s);
ss << *I.getParent()->getParent() << "\n";
ss << "cannot handle above cast " << I << "\n";
EmitNoDerivativeError(ss.str(), I, gutils, Builder2);
return (llvm::Value *)UndefValue::get(op0->getType());
}
};
if (FT) {

auto rule = [&](Value *dif) {
if (I.getOpcode() == CastInst::CastOps::FPTrunc ||
I.getOpcode() == CastInst::CastOps::FPExt) {
return Builder2.CreateFPCast(dif, op0->getType());
} else if (I.getOpcode() == CastInst::CastOps::BitCast) {
return Builder2.CreateBitCast(dif, op0->getType());
} else if (I.getOpcode() == CastInst::CastOps::Trunc) {
// TODO CHECK THIS
return Builder2.CreateZExt(dif, op0->getType());
} else {
std::string s;
llvm::raw_string_ostream ss(s);
ss << *I.getParent()->getParent() << "\n";
ss << "cannot handle above cast " << I << "\n";
EmitNoDerivativeError(ss.str(), I, gutils, Builder2);
return (llvm::Value *)UndefValue::get(op0->getType());
}
};

Value *dif = diffe(&I, Builder2);
Value *diff = applyChainRule(op0->getType(), Builder2, rule, dif);
Value *dif = diffe(&I, Builder2);
Value *diff = applyChainRule(op0->getType(), Builder2, rule, dif);

addToDiffe(orig_op0, diff, Builder2, FT);
addToDiffe(orig_op0, diff, Builder2, FT);
}
}

Type *diffTy = gutils->getShadowType(I.getType());
31 changes: 28 additions & 3 deletions enzyme/Enzyme/CallDerivatives.cpp
@@ -29,7 +29,8 @@
using namespace llvm;

extern "C" {
void (*EnzymeShadowAllocRewrite)(LLVMValueRef, void *) = nullptr;
void (*EnzymeShadowAllocRewrite)(LLVMValueRef, void *, LLVMValueRef, uint64_t,
LLVMValueRef, uint8_t) = nullptr;
}

void AdjointGenerator::handleMPI(llvm::CallInst &call, llvm::Function *called,
@@ -3014,6 +3015,9 @@ bool AdjointGenerator::handleKnownCallDerivatives(
bb, anti, getIndex(&call, CacheType::Shadow, BuilderZ));
} else {
bool zeroed = false;
uint64_t idx = 0;
Value *prev = nullptr;
auto rule = [&]() {
Value *anti =
bb.CreateCall(call.getFunctionType(), call.getCalledOperand(),
@@ -3058,8 +3062,12 @@
if (funcName == "julia.gc_alloc_obj" ||
funcName == "jl_gc_alloc_typed" ||
funcName == "ijl_gc_alloc_typed") {
if (EnzymeShadowAllocRewrite)
EnzymeShadowAllocRewrite(wrap(anti), gutils);
if (EnzymeShadowAllocRewrite) {
bool used = unnecessaryInstructions.find(&call) ==
unnecessaryInstructions.end();
EnzymeShadowAllocRewrite(wrap(anti), gutils, wrap(&call),
idx, wrap(prev), used);
}
}
}
if (Mode == DerivativeMode::ReverseModeCombined ||
@@ -3075,6 +3083,8 @@
zeroed = true;
}
}
idx++;
prev = anti;
return anti;
};

@@ -3224,6 +3234,8 @@
args.push_back(gutils->getNewFromOriginal(arg));
}

uint64_t idx = 0;
Value *prev = gutils->getNewFromOriginal(&call);
auto rule = [&]() {
SmallVector<ValueType, 2> BundleTypes(args.size(), ValueType::Primal);

@@ -3236,6 +3248,19 @@
CI->setCallingConv(call.getCallingConv());
CI->setTailCallKind(call.getTailCallKind());
CI->setDebugLoc(dbgLoc);

if (funcName == "julia.gc_alloc_obj" ||
funcName == "jl_gc_alloc_typed" ||
funcName == "ijl_gc_alloc_typed") {
if (EnzymeShadowAllocRewrite) {
bool used = unnecessaryInstructions.find(&call) ==
unnecessaryInstructions.end();
EnzymeShadowAllocRewrite(wrap(CI), gutils, wrap(&call), idx,
wrap(prev), used);
}
}
idx++;
prev = CI;
return CI;
};

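Both hunks above thread extra context through the `EnzymeShadowAllocRewrite` hook, whose signature widens at the top of this file. A sketch of a client callback matching the new signature; the parameter meanings are inferred from these call sites and may not be exact: the new shadow allocation, the `GradientUtils*` as an opaque pointer, the original allocation call, the vector-lane index, the shadow built for the previous lane (null, or the primal call, on the first lane depending on the path), and whether the original allocation is still considered used.
```
#include "llvm-c/Core.h"
#include <cstdint>

extern "C" void (*EnzymeShadowAllocRewrite)(LLVMValueRef, void *, LLVMValueRef,
                                            uint64_t, LLVMValueRef, uint8_t);

static void myShadowAllocRewrite(LLVMValueRef newAlloc, void *gutils,
                                 LLVMValueRef origCall, uint64_t idx,
                                 LLVMValueRef prevShadow, uint8_t used) {
  // A Julia-facing client could retag or zero-initialize `newAlloc` here.
  (void)newAlloc; (void)gutils; (void)origCall;
  (void)idx; (void)prevShadow; (void)used;
}

// Registration mirrors the extern "C" declaration at the top of the file.
void registerMyRewrite() { EnzymeShadowAllocRewrite = myShadowAllocRewrite; }
```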
10 changes: 7 additions & 3 deletions enzyme/Enzyme/DiffeGradientUtils.cpp
@@ -1179,9 +1179,13 @@ void DiffeGradientUtils::addToInvertedPtrDiffe(
// the pointers and conditionally execute.
if ((!isa<AllocaInst>(basePtr) && !isAllocationCall(basePtr, TLI)) &&
runtimeActivity && !merge) {
Value *shadow = Builder2.CreateICmpNE(
lookupM(getNewFromOriginal(origptr), Builder2),
lookupM(invertPointerM(origptr, Builder2), Builder2));
Value *primal_val = lookupM(getNewFromOriginal(origptr), Builder2);
Value *shadow_val =
lookupM(invertPointerM(origptr, Builder2), Builder2);
if (getWidth() != 1) {
shadow_val = extractMeta(Builder2, shadow_val, 0);
}
Value *shadow = Builder2.CreateICmpNE(primal_val, shadow_val);

BasicBlock *current = Builder2.GetInsertBlock();
BasicBlock *conditional =
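The fix above matters for vector width greater than 1: `invertPointerM` then yields an aggregate of W shadow pointers, so the runtime-activity test compares the primal pointer against lane 0 rather than against the whole aggregate. A condensed sketch under that assumption (`extractMeta` is Enzyme's helper; plain `CreateExtractValue` stands in here):
```
#include "llvm/IR/IRBuilder.h"

// Build the runtime-activity guard: if the pointers differ, a genuine
// shadow exists and the adjoint update must execute. With width > 1 the
// shadow is an aggregate of W pointers; lane 0 is a representative.
llvm::Value *buildActivityGuard(llvm::IRBuilder<> &B, llvm::Value *primalPtr,
                                llvm::Value *shadow, unsigned width) {
  llvm::Value *shadowPtr =
      width == 1 ? shadow : B.CreateExtractValue(shadow, {0});
  return B.CreateICmpNE(primalPtr, shadowPtr);
}
```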
10 changes: 10 additions & 0 deletions enzyme/Enzyme/GradientUtils.cpp
@@ -3822,6 +3822,9 @@ bool GradientUtils::legalRecompute(const Value *val,
}
}

if (isa<AtomicRMWInst>(val))
return false;

if (auto phi = dyn_cast<PHINode>(val)) {
if (auto uiv = hasUninverted(val)) {
if (auto dli = dyn_cast_or_null<LoadInst>(uiv)) {
@@ -3835,6 +3838,13 @@
}
}

auto found = fictiousPHIs.find(const_cast<llvm::PHINode *>(phi));
if (found != fictiousPHIs.end()) {
auto orig = found->second;
if (isa<AtomicRMWInst>(orig))
return false;
}

if (phi->getNumIncomingValues() == 0) {
llvm::errs() << *oldFunc << "\n";
llvm::errs() << *newFunc << "\n";
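Both additions above encode the same invariant: an `atomicrmw` result can never be rematerialized, because re-executing the instruction would re-apply its memory update, so its value must always come from the cache. The second hunk extends this through `fictiousPHIs`, whose entries proxy for such original instructions. A reduced illustration — not Enzyme's actual `legalRecompute`:
```
#include "llvm/IR/Instructions.h"

// Values whose recomputation has side effects must be cached for the
// reverse pass rather than recomputed on demand.
bool mayRecompute(const llvm::Value *V) {
  if (llvm::isa<llvm::AtomicRMWInst>(V))
    return false; // replaying would double-apply the atomic update
  return true;    // all other legality checks elided
}
```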
16 changes: 16 additions & 0 deletions enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -192,4 +192,20 @@ def GenericAdjointOp : Enzyme_Op<"genericAdjoint", [AttrSizedOperandSegments]> {

}

def BroadcastOp : Enzyme_Op<"broadcast"> {
let description = [{
Broadcast the operand by prepending extra dimensions whose sizes are given by the `shape` attribute.
For scalar operands, a ranked tensor is created.

NOTE: Only works for scalar and *ranked* tensor operands for now.
}];

let arguments = (ins AnyType:$input, DenseI64ArrayAttr:$shape);
let results = (outs AnyRankedTensor:$output);

let builders = [
OpBuilder<(ins "Value":$input, "ArrayRef<int64_t>":$shape)>
];
}

#endif // ENZYME_OPS
15 changes: 15 additions & 0 deletions enzyme/Enzyme/MLIR/Dialect/Ops.cpp
@@ -27,6 +27,7 @@
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/IntegerSet.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/Debug.h"

@@ -191,3 +192,17 @@ LogicalResult BatchOp::verifySymbolUses(SymbolTableCollection &symbolTable) {

return success();
}

//===----------------------------------------------------------------------===//
// BroadcastOp
//===----------------------------------------------------------------------===//

void BroadcastOp::build(OpBuilder &builder, OperationState &result, Value input,
ArrayRef<int64_t> shape) {
auto shapeAttr = builder.getDenseI64ArrayAttr(shape);
auto resultTy = input.getType();
for (auto s : llvm::reverse(shape)) {
resultTy = resultTy.cast<AutoDiffTypeInterface>().getShadowType(s);
}
build(builder, result, resultTy, input, shapeAttr);
}
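A hypothetical use of the builder defined above: broadcast a scalar into a rank-2 result by prepending the dimensions `[4, 3]`, as the `.td` description specifies. The include paths and the `mlir::enzyme` namespace spelling are assumptions about the surrounding tree.
```
#include "llvm/ADT/ArrayRef.h"
#include "mlir/IR/Builders.h"
// Plus the Enzyme dialect header declaring enzyme::BroadcastOp
// (its path is tree-specific).

mlir::Value broadcastTo4x3(mlir::OpBuilder &builder, mlir::Location loc,
                           mlir::Value input) {
  // `input`'s type must implement AutoDiffTypeInterface so that
  // getShadowType can produce the broadcasted result type.
  return builder.create<mlir::enzyme::BroadcastOp>(
      loc, input, llvm::ArrayRef<int64_t>{4, 3});
}
```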
@@ -17,6 +17,7 @@
#include "Interfaces/GradientUtilsReverse.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/Support/LogicalResult.h"

@@ -69,3 +70,10 @@ void mlir::enzyme::registerArithDialectAutoDiffInterface(
arith::ConstantOp::attachInterface<ArithConstantOpBatchInterface>(*context);
});
}

void mlir::enzyme::registerTensorDialectAutoDiffInterface(
DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *context, tensor::TensorDialect *) {
registerInterfaces(context);
});
}
@@ -74,7 +74,8 @@ void mlir::enzyme::detail::branchingForwardHandler(Operation *inst,
newVals.push_back(gutils->invertPointerM(op, builder));
} else {
Type retTy =
arg.getType().cast<AutoDiffTypeInterface>().getShadowType();
arg.getType().cast<AutoDiffTypeInterface>().getShadowType(
gutils->width);
auto toret = retTy.cast<AutoDiffTypeInterface>().createNullValue(
builder, op.getLoc());
newVals.push_back(toret);
@@ -146,7 +147,7 @@ LogicalResult mlir::enzyme::detail::memoryIdentityForwardHandler(
if (auto iface =
dyn_cast<AutoDiffTypeInterface>(operand.get().getType())) {
if (!iface.isMutable()) {
Type retTy = iface.getShadowType();
Type retTy = iface.getShadowType(gutils->width);
auto toret = retTy.cast<AutoDiffTypeInterface>().createNullValue(
builder, operand.get().getLoc());
newOperands.push_back(toret);
@@ -346,7 +347,7 @@ LogicalResult mlir::enzyme::detail::controlFlowForwardHandler(
<< result.getType() << "\n";
return failure();
}
newOpResultTypes.push_back(typeIface.getShadowType());
newOpResultTypes.push_back(typeIface.getShadowType(gutils->width));
}

SmallVector<Value> newOperands;
@@ -432,4 +433,5 @@ void mlir::enzyme::registerCoreDialectAutodiffInterfaces(
enzyme::registerCFDialectAutoDiffInterface(registry);
enzyme::registerLinalgDialectAutoDiffInterface(registry);
enzyme::registerFuncDialectAutoDiffInterface(registry);
enzyme::registerTensorDialectAutoDiffInterface(registry);
}
@@ -260,6 +260,7 @@ void registerCFDialectAutoDiffInterface(DialectRegistry &registry);
void registerLinalgDialectAutoDiffInterface(DialectRegistry &registry);
void registerMathDialectAutoDiffInterface(DialectRegistry &registry);
void registerFuncDialectAutoDiffInterface(DialectRegistry &registry);
void registerTensorDialectAutoDiffInterface(DialectRegistry &registry);

void registerCoreDialectAutodiffInterfaces(DialectRegistry &registry);
