Skip to content

Commit

Permalink
Fix shadow remat bug (#959)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored Dec 25, 2022
1 parent 4cc3434 commit e0bb22c
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 23 deletions.
24 changes: 23 additions & 1 deletion enzyme/Enzyme/GradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7432,7 +7432,7 @@ void GradientUtils::computeForwardingProperties(Instruction *V) {
if (shadowpromotable && !isConstantValue(V)) {
for (auto LI : shadowPointerLoads) {
// Is there a store which could occur after the load.
// In other words
// This subsequent store would invalidate any loads being re-performed.
SmallVector<Instruction *, 2> results;
mayExecuteAfter(results, LI, storingOps, outer);
for (auto res : results) {
Expand All @@ -7447,6 +7447,28 @@ void GradientUtils::computeForwardingProperties(Instruction *V) {
}
}
}
// If there is a store not reproduced in the reverse pass (e.g. as part
// of a write in a call), and this store is necessary to a pointer load of
// the shadow, this is not materializable since the load will not return
// the same value.
{
SmallVector<Instruction *, 2> nonReproducedStores;
for (auto S : storingOps)
if (!stores.count(S)) {
SmallVector<Instruction *, 2> results;
SmallPtrSet<Instruction *, 2> shadowPtrLoadSet(
shadowPointerLoads.begin(), shadowPointerLoads.end());
mayExecuteAfter(results, S, shadowPtrLoadSet, outer);
if (results.size()) {
EmitWarning("NotPromotable", *results[0],
" Could not promote shadow allocation ", *V,
" due to non-reproduced store ", *S,
" which may impact pointer load ", *results[0]);
shadowpromotable = false;
goto exitL;
}
}
}
exitL:;
if (shadowpromotable) {
backwardsOnlyShadows[V] = ShadowRematerializer(
Expand Down
5 changes: 1 addition & 4 deletions enzyme/benchmarks/ReverseMode/hand/Makefile.make
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
# RUN: cd %S && LD_LIBRARY_PATH="%bldpath:$LD_LIBRARY_PATH" BENCH="%bench" BENCHLINK="%blink" LOAD="%loadEnzyme" make -B hand-raw.ll results.json -f %s

# This test is broken
# XFAIL: *

.PHONY: clean

clean:
Expand All @@ -22,4 +19,4 @@ hand.o: hand-opt.ll
clang++ $^ -o $@ -lblas $(BENCHLINK)

results.json: hand.o
./$^
./$^
55 changes: 55 additions & 0 deletions enzyme/test/Enzyme/ReverseMode/nonpromoteshadow.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -simplifycfg -S | FileCheck %s

declare i8* @malloc(i64)

define void @set(double** writeonly nocapture %p) {
entry:
%m = call i8* @malloc(i64 8)
%ptr = bitcast i8* %m to double*
store double* %ptr, double** %p, align 8
ret void
}

define double @square(double %x) {
entry:
%a = alloca double*, align 8
call void @set(double** %a)
%m = load double*, double** %a, align 8
store double %x, double* %m, align 8
%ld = load double, double* %m, align 8
%mul = fmul double %ld, %ld
ret double %mul
}

declare dso_local i8* @__enzyme_virtualreverse(i8*)

define i8* @dsquare(double %x) local_unnamed_addr {
entry:
%call = tail call i8* @__enzyme_virtualreverse(i8* bitcast (double (double)* @square to i8*))
ret i8* %call
}

; CHECK: define internal { double } @diffesquare(double %x, double %differeturn, i8* %tapeArg)
; CHECK-NEXT: entry:
; CHECK-NEXT: %0 = bitcast i8* %tapeArg to { { i8*, i8* }, i8*, double }*
; CHECK-NEXT: %truetape = load { { i8*, i8* }, i8*, double }, { { i8*, i8* }, i8*, double }* %0
; CHECK-NEXT: tail call void @free(i8* nonnull %tapeArg)
; CHECK-NEXT: %"malloccall'mi" = extractvalue { { i8*, i8* }, i8*, double } %truetape, 1
; CHECK-NEXT: %"a'ipc" = bitcast i8* %"malloccall'mi" to double**
; CHECK-NEXT: %tapeArg1 = extractvalue { { i8*, i8* }, i8*, double } %truetape, 0
; CHECK-NEXT: %"m'ipl" = load double*, double** %"a'ipc", align 8
; CHECK-NEXT: %ld = extractvalue { { i8*, i8* }, i8*, double } %truetape, 2
; CHECK-NEXT: %m0diffeld = fmul fast double %differeturn, %ld
; CHECK-NEXT: %m1diffeld = fmul fast double %differeturn, %ld
; CHECK-NEXT: %1 = fadd fast double %m0diffeld, %m1diffeld
; CHECK-NEXT: %2 = load double, double* %"m'ipl", align 8
; CHECK-NEXT: %3 = fadd fast double %2, %1
; CHECK-NEXT: store double %3, double* %"m'ipl", align 8
; CHECK-NEXT: %4 = load double, double* %"m'ipl", align 8
; CHECK-NEXT: store double 0.000000e+00, double* %"m'ipl", align 8
; CHECK-NEXT: call void @diffeset(double** undef, double** undef, { i8*, i8* } %tapeArg1)
; CHECK-NEXT: tail call void @free(i8* nonnull %"malloccall'mi")
; CHECK-NEXT: %5 = insertvalue { double } undef, double %4, 0
; CHECK-NEXT: ret { double } %5
; CHECK-NEXT: }

49 changes: 31 additions & 18 deletions enzyme/test/Enzyme/ReverseMode/writeonlyretjlptr.ll
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,13 @@ define double @dsquare({} addrspace(10)* %x, {} addrspace(10)* %dx) {

; CHECK: define internal void @diffesquare({} addrspace(10)* %x, {} addrspace(10)* %"x'", double %differeturn)
; CHECK-NEXT: entry:
; CHECK-NEXT: %m = call fast double @augmented_mid({} addrspace(10)* %x, {} addrspace(10)* %"x'")
; CHECK-NEXT: %m_augmented = call { {} addrspace(10)*, double } @augmented_mid({} addrspace(10)* %x, {} addrspace(10)* %"x'")
; CHECK-NEXT: %subcache = extractvalue { {} addrspace(10)*, double } %m_augmented, 0
; CHECK-NEXT: %m = extractvalue { {} addrspace(10)*, double } %m_augmented, 1
; CHECK-NEXT: %m0diffem = fmul fast double %differeturn, %m
; CHECK-NEXT: %m1diffem = fmul fast double %differeturn, %m
; CHECK-NEXT: %0 = fadd fast double %m0diffem, %m1diffem
; CHECK-NEXT: call void @diffemid({} addrspace(10)* %x, {} addrspace(10)* %"x'", double %0)
; CHECK-NEXT: call void @diffemid({} addrspace(10)* %x, {} addrspace(10)* %"x'", double %0, {} addrspace(10)* %subcache)
; CHECK-NEXT: ret void
; CHECK-NEXT: }

Expand All @@ -58,32 +60,43 @@ define double @dsquare({} addrspace(10)* %x, {} addrspace(10)* %dx) {
; CHECK-NEXT: ret void
; CHECK-NEXT: }

; CHECK: define internal double @augmented_mid({} addrspace(10)* %x, {} addrspace(10)* %"x'")
; CHECK: define internal { {} addrspace(10)*, double } @augmented_mid({} addrspace(10)* %x, {} addrspace(10)* %"x'")
; CHECK-NEXT: %1 = alloca { {} addrspace(10)*, double }
; CHECK-NEXT: %2 = getelementptr inbounds { {} addrspace(10)*, double }, { {} addrspace(10)*, double }* %1, i32 0, i32 0
; CHECK-NEXT: store {} addrspace(10)* null, {} addrspace(10)** %2
; CHECK-NEXT: %r = alloca {} addrspace(10)*, i64 1, align 8
; CHECK-NEXT: %pg = call {}*** @julia.get_pgcstack()
; CHECK-NEXT: %"r'ai" = alloca {} addrspace(10)*, i64 1, align 8
; CHECK-NEXT: %1 = bitcast {} addrspace(10)** %"r'ai" to {}*
; CHECK-NEXT: %2 = bitcast {}* %1 to i8*
; CHECK-NEXT: call void @llvm.memset.p0i8.i64(i8* nonnull dereferenceable(8) dereferenceable_or_null(8) %2, i8 0, i64 8, i1 false)
; CHECK-NEXT: call void @augmented_subsq({} addrspace(10)** %r, {} addrspace(10)** %"r'ai", {} addrspace(10)* %x, {} addrspace(10)* %"x'")
; CHECK-NEXT: %p3 = bitcast {}*** %pg to {}**
; CHECK-NEXT: %p4 = getelementptr inbounds {}*, {}** %p3, i64 -12
; CHECK-NEXT: %p5 = getelementptr inbounds {}*, {}** %p4, i64 14
; CHECK-NEXT: %p6 = bitcast {}** %p5 to i8**
; CHECK-NEXT: %p7 = load i8*, i8** %p6, align 8
; CHECK-NEXT: %"al'mi" = call noalias nonnull dereferenceable(8) dereferenceable_or_null(8) {} addrspace(10)* @jl_gc_alloc_typed(i8* %p7, i64 8, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 139806792221568 to {}*) to {} addrspace(10)*))
; CHECK-NEXT: store {} addrspace(10)* %"al'mi", {} addrspace(10)** %2
; CHECK-NEXT: %3 = bitcast {} addrspace(10)* %"al'mi" to i8 addrspace(10)*
; CHECK-NEXT: call void @llvm.memset.p10i8.i64(i8 addrspace(10)* nonnull dereferenceable(8) dereferenceable_or_null(8) %3, i8 0, i64 8, i1 false)
; CHECK-NEXT: %"r'ipc" = bitcast {} addrspace(10)* %"al'mi" to {} addrspace(10)* addrspace(10)*
; CHECK-NEXT: %"addr'ipc" = addrspacecast {} addrspace(10)* addrspace(10)* %"r'ipc" to {} addrspace(10)**
; CHECK-NEXT: call void @augmented_subsq({} addrspace(10)** %r, {} addrspace(10)** %"addr'ipc", {} addrspace(10)* %x, {} addrspace(10)* %"x'")
; CHECK-NEXT: %l = load {} addrspace(10)*, {} addrspace(10)** %r, align 8
; CHECK-NEXT: %bc = bitcast {} addrspace(10)* %l to double addrspace(10)*
; CHECK-NEXT: %ld = load double, double addrspace(10)* %bc
; CHECK-NEXT: ret double %ld
; CHECK-NEXT: %4 = getelementptr inbounds { {} addrspace(10)*, double }, { {} addrspace(10)*, double }* %1, i32 0, i32 1
; CHECK-NEXT: store double %ld, double* %4
; CHECK-NEXT: %5 = load { {} addrspace(10)*, double }, { {} addrspace(10)*, double }* %1
; CHECK-NEXT: ret { {} addrspace(10)*, double } %5
; CHECK-NEXT: }

; CHECK: define internal void @diffemid({} addrspace(10)* %x, {} addrspace(10)* %"x'", double %differeturn)
; CHECK: define internal void @diffemid({} addrspace(10)* %x, {} addrspace(10)* %"x'", double %differeturn, {} addrspace(10)* %"al'mi")
; CHECK-NEXT: invert:
; CHECK-NEXT: %pg = call {}*** @julia.get_pgcstack()
; CHECK-NEXT: %"r'ai" = alloca {} addrspace(10)*, i64 1, align 8
; CHECK-NEXT: %0 = bitcast {} addrspace(10)** %"r'ai" to {}*
; CHECK-NEXT: %1 = bitcast {}* %0 to i8*
; CHECK-NEXT: call void @llvm.memset.p0i8.i64(i8* nonnull dereferenceable(8) dereferenceable_or_null(8) %1, i8 0, i64 8, i1 false)
; CHECK-NEXT: %"l'ipl" = load {} addrspace(10)*, {} addrspace(10)** %"r'ai", align 8
; CHECK-NEXT: %"r'ipc" = bitcast {} addrspace(10)* %"al'mi" to {} addrspace(10)* addrspace(10)*
; CHECK-NEXT: %"addr'ipc" = addrspacecast {} addrspace(10)* addrspace(10)* %"r'ipc" to {} addrspace(10)**
; CHECK-NEXT: %"l'ipl" = load {} addrspace(10)*, {} addrspace(10)** %"addr'ipc", align 8
; CHECK-NEXT: %"bc'ipc" = bitcast {} addrspace(10)* %"l'ipl" to double addrspace(10)*
; CHECK-NEXT: %2 = load double, double addrspace(10)* %"bc'ipc"
; CHECK-NEXT: %3 = fadd fast double %2, %differeturn
; CHECK-NEXT: store double %3, double addrspace(10)* %"bc'ipc"
; CHECK-NEXT: %0 = load double, double addrspace(10)* %"bc'ipc"
; CHECK-NEXT: %1 = fadd fast double %0, %differeturn
; CHECK-NEXT: store double %1, double addrspace(10)* %"bc'ipc"
; CHECK-NEXT: call void @diffesubsq({} addrspace(10)** null, {} addrspace(10)** null, {} addrspace(10)* %x, {} addrspace(10)* %"x'")
; CHECK-NEXT: ret void
; CHECK-NEXT: }
Expand Down

0 comments on commit e0bb22c

Please sign in to comment.