Skip to content

Commit

Permalink
Custom sret (#875)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored Oct 3, 2022
1 parent dfa4088 commit 24774ad
Show file tree
Hide file tree
Showing 11 changed files with 142 additions and 10 deletions.
87 changes: 87 additions & 0 deletions enzyme/Enzyme/AdjointGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -11594,6 +11594,7 @@ class AdjointGenerator
SmallVector<Value *, 8> args;
std::vector<DIFFE_TYPE> argsInverted;
std::map<int, Type *> gradByVal;
std::map<int, Attribute> structAttrs;

#if LLVM_VERSION_MAJOR >= 14
for (unsigned i = 0; i < orig->arg_size(); ++i)
Expand All @@ -11602,6 +11603,18 @@ class AdjointGenerator
#endif
{

if (orig->paramHasAttr(i, Attribute::StructRet)) {
structAttrs[args.size()] =
#if LLVM_VERSION_MAJOR >= 12
// TODO persist types
Attribute::get(orig->getContext(), "enzyme_sret");
// Attribute::get(orig->getContext(), "enzyme_sret",
// orig->getParamAttr(i, Attribute::StructRet).getValueAsType());
#else
Attribute::get(orig->getContext(), "enzyme_sret");
#endif
}

auto argi = gutils->getNewFromOriginal(orig->getArgOperand(i));

#if LLVM_VERSION_MAJOR >= 9
Expand Down Expand Up @@ -11646,6 +11659,33 @@ class AdjointGenerator
continue;
}

if (orig->paramHasAttr(i, Attribute::StructRet)) {
structAttrs[args.size()] =
Attribute::get(orig->getContext(), "enzyme_sret");
if (gutils->getWidth() == 1) {
structAttrs[args.size()] =
#if LLVM_VERSION_MAJOR >= 12
// TODO persist types
Attribute::get(orig->getContext(), "enzyme_sret");
// Attribute::get(orig->getContext(), "enzyme_sret",
// orig->getParamAttr(i, Attribute::StructRet).getValueAsType());
#else
Attribute::get(orig->getContext(), "enzyme_sret");
#endif
} else {
structAttrs[args.size()] =
#if LLVM_VERSION_MAJOR >= 12
// TODO persist types
Attribute::get(orig->getContext(), "enzyme_sret");
// Attribute::get(orig->getContext(), "enzyme_sret_v",
// gutils->getShadowType(orig->getParamAttr(ii,
// Attribute::StructRet).getValueAsType()));
#else
Attribute::get(orig->getContext(), "enzyme_sret_v");
#endif
}
}

assert(argTy == DIFFE_TYPE::DUP_ARG || argTy == DIFFE_TYPE::DUP_NONEED);

args.push_back(
Expand Down Expand Up @@ -11755,6 +11795,9 @@ class AdjointGenerator
Attribute::getWithByValType(diffes->getContext(), pair.second));
}
#endif
for (auto pair : structAttrs) {
diffes->addParamAttr(pair.first, pair.second);
}

auto newcall = gutils->getNewFromOriginal(orig);
auto ifound = gutils->invertedPointers.find(orig);
Expand Down Expand Up @@ -11814,6 +11857,7 @@ class AdjointGenerator
SmallVector<Instruction *, 4> userReplace;
std::map<int, Type *> preByVal;
std::map<int, Type *> gradByVal;
std::map<int, Attribute> structAttrs;

bool replaceFunction = false;

Expand All @@ -11839,6 +11883,19 @@ class AdjointGenerator
preByVal[pre_args.size()] = orig->getParamByValType(i);
}
#endif
if (orig->paramHasAttr(i, Attribute::StructRet)) {
structAttrs[pre_args.size()] =
#if LLVM_VERSION_MAJOR >= 12
// TODO persist types
Attribute::get(orig->getContext(), "enzyme_sret");
// Attribute::get(orig->getContext(), "enzyme_sret",
// orig->getParamAttr(ii, Attribute::StructRet).getValueAsType());
#else
// TODO persist types
Attribute::get(orig->getContext(), "enzyme_sret");
// Attribute::get(orig->getContext(), "enzyme_sret");
#endif
}

pre_args.push_back(argi);

Expand Down Expand Up @@ -11894,6 +11951,30 @@ class AdjointGenerator
auto argType = argi->getType();

if (argTy == DIFFE_TYPE::DUP_ARG || argTy == DIFFE_TYPE::DUP_NONEED) {
if (orig->paramHasAttr(i, Attribute::StructRet)) {
if (gutils->getWidth() == 1) {
structAttrs[pre_args.size()] =
#if LLVM_VERSION_MAJOR >= 12
// TODO persist types
Attribute::get(orig->getContext(), "enzyme_sret");
// Attribute::get(orig->getContext(), "enzyme_sret",
// orig->getParamAttr(ii, Attribute::StructRet).getValueAsType());
#else
Attribute::get(orig->getContext(), "enzyme_sret");
#endif
} else {
structAttrs[pre_args.size()] =
#if LLVM_VERSION_MAJOR >= 12
// TODO persist types
Attribute::get(orig->getContext(), "enzyme_sret_v");
// Attribute::get(orig->getContext(), "enzyme_sret_v",
// gutils->getShadowType(orig->getParamAttr(ii,
// Attribute::StructRet).getValueAsType()));
#else
Attribute::get(orig->getContext(), "enzyme_sret_v");
#endif
}
}
if (Mode != DerivativeMode::ReverseModePrimal) {
IRBuilder<> Builder2(call.getParent());
getReverseBuilder(Builder2);
Expand Down Expand Up @@ -12140,6 +12221,9 @@ class AdjointGenerator
pair.second));
}
#endif
for (auto pair : structAttrs) {
augmentcall->addParamAttr(pair.first, pair.second);
}

if (!augmentcall->getType()->isVoidTy())
augmentcall->setName(orig->getName() + "_augmented");
Expand Down Expand Up @@ -12508,6 +12592,9 @@ class AdjointGenerator
diffes->getContext(), pair.second));
}
#endif
for (auto pair : structAttrs) {
diffes->addParamAttr(pair.first, pair.second);
}

unsigned structidx = 0;
if (replaceFunction) {
Expand Down
5 changes: 5 additions & 0 deletions enzyme/Enzyme/EnzymeLogic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2482,6 +2482,11 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
if (nf->hasParamAttribute(attrIndex, Attribute::NoAlias)) {
NewF->addParamAttr(attrIndex, Attribute::NoAlias);
}
for (auto name : {"enzyme_sret", "enzyme_sret_v"})
if (nf->getAttributes().hasParamAttr(attrIndex, name)) {
NewF->addParamAttr(attrIndex,
nf->getAttributes().getParamAttr(attrIndex, name));
}

j->setName(i->getName());
++j;
Expand Down
40 changes: 40 additions & 0 deletions enzyme/Enzyme/FunctionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2063,6 +2063,17 @@ Function *PreProcessCache::CloneFunctionWithReturns(
bool hasPtrInput = false;
unsigned ii = 0, jj = 0;
for (auto i = F->arg_begin(), j = NewF->arg_begin(); i != F->arg_end();) {
if (F->hasParamAttribute(ii, Attribute::StructRet)) {
// TODO persist types
NewF->addParamAttr(jj, Attribute::get(F->getContext(), "enzyme_sret"));
/*
#if LLVM_VERSION_MAJOR >= 12
NewF->addParamAttr(jj, Attribute::get(F->getContext(), "enzyme_sret",
F->getParamAttribute(ii, Attribute::StructRet).getValueAsType())); #else
NewF->addParamAttr(jj, Attribute::get(F->getContext(), "enzyme_sret"));
#endif
*/
}
if (constant_args[ii] == DIFFE_TYPE::CONSTANT) {
if (!i->hasByValAttr())
constants.insert(i);
Expand All @@ -2084,6 +2095,35 @@ Function *PreProcessCache::CloneFunctionWithReturns(
if (F->hasParamAttribute(ii, Attribute::NoCapture) && width == 1) {
NewF->addParamAttr(jj + 1, Attribute::NoCapture);
}
// TODO: find a way to keep sret for shadow
if (F->hasParamAttribute(ii, Attribute::StructRet)) {
if (width == 1) {
#if LLVM_VERSION_MAJOR >= 12
// TODO persist types
NewF->addParamAttr(jj + 1,
Attribute::get(F->getContext(), "enzyme_sret"));
// NewF->addParamAttr(jj + 1, Attribute::get(F->getContext(),
// "enzyme_sret", F->getParamAttribute(ii,
// Attribute::StructRet).getValueAsType()));
#else
NewF->addParamAttr(jj + 1,
Attribute::get(F->getContext(), "enzyme_sret"));
#endif
} else {
#if LLVM_VERSION_MAJOR >= 12
// TODO persist types
NewF->addParamAttr(jj + 1,
Attribute::get(F->getContext(), "enzyme_sret_v"));
// NewF->addParamAttr(jj + 1, Attribute::get(F->getContext(),
// "enzyme_sret_v",
// GradientUtils::getShadowType(F->getParamAttribute(ii,
// Attribute::StructRet).getValueAsType(), width)));
#else
NewF->addParamAttr(jj + 1,
Attribute::get(F->getContext(), "enzyme_sret_v"));
#endif
}
}

j->setName(i->getName());
++j;
Expand Down
2 changes: 1 addition & 1 deletion enzyme/test/Enzyme/ForwardMode/sret.ll
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ attributes #6 = { nounwind }
; CHECK-NEXT: }


; CHECK: define internal void @fwddiffe_Z6squared(%"struct.std::array"* noalias nocapture %agg.result, %"struct.std::array"* nocapture %"agg.result'", double %x, double %"x'")
; CHECK: define internal void @fwddiffe_Z6squared(%"struct.std::array"* noalias nocapture "enzyme_sret" %agg.result, %"struct.std::array"* nocapture "enzyme_sret" %"agg.result'", double %x, double %"x'")
; CHECK-NEXT: entry:
; CHECK-NEXT: %"arrayinit.begin'ipg" = getelementptr inbounds %"struct.std::array", %"struct.std::array"* %"agg.result'", i64 0, i32 0, i64 0
; CHECK-NEXT: %arrayinit.begin = getelementptr inbounds %"struct.std::array", %"struct.std::array"* %agg.result, i64 0, i32 0, i64 0
Expand Down
4 changes: 2 additions & 2 deletions enzyme/test/Enzyme/ForwardMode/sret12.ll
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ attributes #6 = { nounwind }
; CHECK-NEXT: }


; CHECK: define internal void @fwddiffe_Z6squared(%"struct.std::array"* noalias nocapture align 8 %agg.result, %"struct.std::array"* nocapture %"agg.result'", double %x, double %"x'") #0 {
; CHECK: define internal void @fwddiffe_Z6squared(%"struct.std::array"* noalias nocapture align 8 "enzyme_sret" %agg.result, %"struct.std::array"* nocapture "enzyme_sret" %"agg.result'", double %x, double %"x'") #0 {
; CHECK-NEXT: entry:
; CHECK-NEXT: %"arrayinit.begin'ipg" = getelementptr inbounds %"struct.std::array", %"struct.std::array"* %"agg.result'", i64 0, i32 0, i64 0
; CHECK-NEXT: %arrayinit.begin = getelementptr inbounds %"struct.std::array", %"struct.std::array"* %agg.result, i64 0, i32 0, i64 0
Expand All @@ -103,4 +103,4 @@ attributes #6 = { nounwind }
; CHECK-NEXT: store double %x, double* %arrayinit.element3, align 8
; CHECK-NEXT: store double %"x'", double* %"arrayinit.element3'ipg", align 8
; CHECK-NEXT: ret void
; CHECK-NEXT: }
; CHECK-NEXT: }
2 changes: 1 addition & 1 deletion enzyme/test/Enzyme/ForwardModeSplit/sret.ll
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ attributes #6 = { nounwind }
; CHECK-NEXT: }


; CHECK: define internal void @fwddiffe_Z6squared(%"struct.std::array"* noalias nocapture %agg.result, %"struct.std::array"* nocapture %"agg.result'", double %x, double %"x'", i8* %tapeArg)
; CHECK: define internal void @fwddiffe_Z6squared(%"struct.std::array"* noalias nocapture "enzyme_sret" %agg.result, %"struct.std::array"* nocapture "enzyme_sret" %"agg.result'", double %x, double %"x'", i8* %tapeArg)
; CHECK-NEXT: entry:
; CHECK-NEXT: %"arrayinit.begin'ipg" = getelementptr inbounds %"struct.std::array", %"struct.std::array"* %"agg.result'", i64 0, i32 0, i64 0
; CHECK-NEXT: %mul = fmul double %x, %x
Expand Down
2 changes: 1 addition & 1 deletion enzyme/test/Enzyme/ForwardModeSplit/sret12.ll
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ attributes #6 = { nounwind }
; CHECK-NEXT: }


; CHECK: define internal void @fwddiffe_Z6squared(%"struct.std::array"* noalias nocapture align 8 %agg.result, %"struct.std::array"* nocapture %"agg.result'", double %x, double %"x'", i8* %tapeArg)
; CHECK: define internal void @fwddiffe_Z6squared(%"struct.std::array"* noalias nocapture align 8 "enzyme_sret" %agg.result, %"struct.std::array"* nocapture "enzyme_sret" %"agg.result'", double %x, double %"x'", i8* %tapeArg)
; CHECK-NEXT: entry:
; CHECK-NEXT: %"arrayinit.begin'ipg" = getelementptr inbounds %"struct.std::array", %"struct.std::array"* %"agg.result'", i64 0, i32 0, i64 0
; CHECK-NEXT: %mul = fmul double %x, %x
Expand Down
2 changes: 1 addition & 1 deletion enzyme/test/Enzyme/ForwardModeVector/sret.ll
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ attributes #4 = { nounwind "correctly-rounded-divide-sqrt-fp-math"="false" "disa
attributes #5 = { argmemonly nounwind }
attributes #6 = { nounwind }

; CHECK: define internal void @fwddiffe3_Z6squared(%"struct.std::array"* noalias nocapture align 8 %agg.result, [3 x %"struct.std::array"*] %"agg.result'", double %x, [3 x double] %"x'")
; CHECK: define internal void @fwddiffe3_Z6squared(%"struct.std::array"* noalias nocapture align 8 "enzyme_sret" %agg.result, [3 x %"struct.std::array"*] "enzyme_sret_v" %"agg.result'", double %x, [3 x double] %"x'")
; CHECK-NEXT: entry:
; CHECK-NEXT: %0 = extractvalue [3 x %"struct.std::array"*] %"agg.result'", 0
; CHECK-NEXT: %"arrayinit.begin'ipg" = getelementptr inbounds %"struct.std::array", %"struct.std::array"* %0, i64 0, i32 0, i64 0
Expand Down
2 changes: 1 addition & 1 deletion enzyme/test/Enzyme/ForwardModeVector/sret12.ll
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ attributes #5 = { argmemonly nofree nosync nounwind willreturn }
attributes #6 = { nounwind uwtable willreturn mustprogress "disable-tail-calls"="false" "frame-pointer"="none" "less-precise-fpmad"="false" "min-legal-vector-width"="0" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" "unsafe-fp-math"="false" "use-soft-float"="false" }
attributes #7 = { nounwind }

; CHECK: define {{[^@]+}}@fwddiffe3_Z6squared(%"struct.std::array"* noalias nocapture align 8 [[AGG_RESULT:%.*]], [3 x %"struct.std::array"*] %"agg.result'", double [[X:%.*]], [3 x double] %"x'")
; CHECK: define {{[^@]+}}@fwddiffe3_Z6squared(%"struct.std::array"* noalias nocapture align 8 "enzyme_sret" [[AGG_RESULT:%.*]], [3 x %"struct.std::array"*] "enzyme_sret_v" %"agg.result'", double [[X:%.*]], [3 x double] %"x'")
; CHECK-NEXT: entry:
; CHECK-NEXT: [[TMP0:%.*]] = extractvalue [3 x %"struct.std::array"*] %"agg.result'", 0
; CHECK-NEXT: %"arrayinit.begin'ipg" = getelementptr inbounds %"struct.std::array", %"struct.std::array"* [[TMP0]], i64 0, i32 0, i64 0
Expand Down
2 changes: 1 addition & 1 deletion enzyme/test/Enzyme/ReverseMode/sret2-12.ll
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ attributes #4 = { nounwind }
; CHECK-NEXT: }


; CHECK: define internal { double } @diffe_Z6squared(%"struct.std::array"* noalias nocapture align 8 %agg.result, %"struct.std::array"* nocapture %"agg.result'", double %x)
; CHECK: define internal { double } @diffe_Z6squared(%"struct.std::array"* noalias nocapture align 8 "enzyme_sret" %agg.result, %"struct.std::array"* nocapture "enzyme_sret" %"agg.result'", double %x)
; CHECK-NEXT: entry:
; CHECK-NEXT: %"arrayinit.begin'ipg" = getelementptr inbounds %"struct.std::array", %"struct.std::array"* %"agg.result'", i64 0, i32 0, i64 0
; CHECK-NEXT: %arrayinit.begin = getelementptr inbounds %"struct.std::array", %"struct.std::array"* %agg.result, i64 0, i32 0, i64 0
Expand Down
4 changes: 2 additions & 2 deletions enzyme/test/Enzyme/ReverseMode/sret2.ll
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ attributes #4 = { nounwind }
; CHECK-NEXT: }


; CHECK: define internal { double } @diffe_Z6squared(%"struct.std::array"* noalias nocapture %agg.result, %"struct.std::array"* nocapture %"agg.result'", double %x)
; CHECK: define internal { double } @diffe_Z6squared(%"struct.std::array"* noalias nocapture "enzyme_sret" %agg.result, %"struct.std::array"* nocapture "enzyme_sret" %"agg.result'", double %x)
; CHECK-NEXT: entry:
; CHECK-NEXT: %"arrayinit.begin'ipg" = getelementptr inbounds %"struct.std::array", %"struct.std::array"* %"agg.result'", i64 0, i32 0, i64 0
; CHECK-NEXT: %arrayinit.begin = getelementptr inbounds %"struct.std::array", %"struct.std::array"* %agg.result, i64 0, i32 0, i64 0
Expand Down Expand Up @@ -102,4 +102,4 @@ attributes #4 = { nounwind }
; CHECK-NEXT: %6 = fadd fast double %4, %5
; CHECK-NEXT: %7 = insertvalue { double } undef, double %6, 0
; CHECK-NEXT: ret { double } %7
; CHECK-NEXT: }
; CHECK-NEXT: }

0 comments on commit 24774ad

Please sign in to comment.