@trace function calls #366
base: main
Conversation
src/Compiler.jl (Outdated)

```
@@ -287,12 +289,16 @@ function compile_mlir(f, args; kwargs...)
    end
end

const callcache = ScopedValue{Dict}()
```
This seems like a simple way to store the cache. Fancier solutions (e.g. integrating with absint?) are probably possible as well, but I'm not sure whether it's worth investigating as of now.
I think ironically we should do something slightly different here: something like what we do with `mmodule` (to get the current MLIR module), see `Reactant.jl/src/mlir/IR/Module.jl`, line 81 in 73899f5:

```julia
function mmodule(; throw_error::Core.Bool=true)
```
If we define the cache scope during the outermost compile / module generation, then we're certain that it doesn't get messed up, and we can also cache calls across subfunctions etc. nicely [which this does too, to be clear, but maybe not quite as cleanly?]
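A minimal sketch of that pattern, with hypothetical names (`with_callcache`, `current_callcache`); the real code would mirror `mmodule` rather than this exactly:

```julia
using Base.ScopedValues

# Hypothetical sketch: the cache scope is established once at the outermost
# compile, and any nested tracing code looks it up, mirroring `mmodule`.
const callcache = ScopedValue{Union{Nothing,Dict}}(nothing)

with_callcache(f) = with(f, callcache => Dict())

function current_callcache(; throw_error::Bool=true)
    cache = callcache[]
    cache === nothing && throw_error && error("no call cache in scope")
    return cache
end
```

Because ScopedValues are dynamically scoped, nested compilations each see their own cache, while subfunctions traced within one compile share it.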
Added here: 588cebb (#366)
lib/ReactantCore/src/ReactantCore.jl (Outdated)

```
@@ -3,6 +3,9 @@ module ReactantCore
using ExpressionExplorer: ExpressionExplorer
using MacroTools: MacroTools

using Base.ScopedValues
const enable_tracing = ScopedValue{Bool}(false)
```
I think we simply want to trace function calls every time in `make_mlir_fn`. I added this `ScopedValue`, which is set at the beginning, so it's just a simple check.
If it's okay to use ScopedValues in ReactantCore, ScopedValues probably needs to be added to Project.toml for backwards compatibility (?)
With the interpreter landed, we can have a function `within_tracing`, similar to https://github.com/EnzymeAD/Enzyme.jl/blob/7c0823fa64426a745dae8fc7a50980be0a0a8dc8/lib/EnzymeCore/src/EnzymeCore.jl#L561 and https://github.com/EnzymeAD/Enzyme.jl/blob/7c0823fa64426a745dae8fc7a50980be0a0a8dc8/src/compiler/interpreter.jl#L916, which would be nicer than a global.
Force-pushed from de287e2 to 5922ef0
src/ControlFlow.jl (Outdated)

```
seen_cache,
args,
(),
CallCache;
```
Maybe we can call this mode `TracedToTypes` or something (since really it's converting Traced to a Type expression).
This also isn't sufficient, since the cache key will only consider the outermost object as the key. For example:

```julia
mutable struct T
    x
end

X = T(1.0)
cache_key1 = make_key(X)
X.x = 2.0
cache_key2 = make_key(X)
```

Presumably the two keys will be the same, since the hash will be based on the object ID, not the potentially recursive structure.
I think we probably want to make a version of `make_tracer` that returns a tuple of all of the leaf objects, and we can obviously replace all traced values/numbers with just the type.
Just the leaves is not enough, I think. For example, `foo` and `bar` below have the same type and the same leaves, but should be cached separately.

```julia
struct Foo
    x
end
struct Bar
    x
end

foo = Foo(Foo(1))
bar = Foo(Bar(1))
```
Maybe:
- If leaf: the object
- Otherwise: `typeof(object)`, objects(children…)…
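A rough sketch of that scheme (`make_key` and the leaf set are illustrative, not Reactant API; note this naive version would loop forever on self-referential objects, which is exactly the issue raised further down):

```julia
# Hypothetical sketch: leaves contribute their value, containers contribute
# their type followed by the keys of their children, so Foo(Foo(1)) and
# Foo(Bar(1)) produce different keys even though the leaves agree.
is_leaf(x) = x isa Union{Number,Symbol,AbstractString}

function make_key(x)
    is_leaf(x) && return x
    T = typeof(x)
    return (T, (make_key(getfield(x, f)) for f in fieldnames(T))...)
end

struct Foo; x; end
struct Bar; x; end
```

Here `make_key(Foo(Foo(1)))` yields `(Foo, (Foo, 1))`, which differs from `make_key(Foo(Bar(1)))`, fixing the same-type-same-leaves collision above.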
Right, that could work. I'll experiment a bit.
68559c5 (#366): What do you think of this approach? I created a `Cached` type with custom equality and hashing to be used as dictionary keys. The equality and hashing recursively visit all fields and check types and actual values.
Trying out extending `make_tracer` to generate a tuple or array containing types and leaf values.
How should we handle self-recursive objects, i.e. `a` contains `b` contains `a`?
The other trace modes return early when they encounter an object for the second time (by consulting `seen`).
But for generating a cache key, that behavior doesn't seem desired, because we want the same cache key for:

```julia
a = [1, 2, 3]
b = [1, 2, 3]
# (a, a) should have same cache key as (a, b)
```

@wsmoses the example you gave against `deepcopy` was file handles, I believe. Does `deepcopy` simply not work for those, or are there other issues going on there?
src/Tracing.jl (Outdated)

```
@@ -382,6 +383,12 @@ function make_tracer(
if mode == ConcreteToTraced
    throw("Cannot trace existing trace type")
end
if mode == CallCache
    if !haskey(seen, prev)
```
Edit: I started writing this, but I realize the design has a different issue we need to solve first that makes this point moot (https://github.com/EnzymeAD/Reactant.jl/pull/366/files#r1885172358).
So this is going to present an issue for some code. For example:

```julia
struct Wrap
    x::AbstractArray
end
```

We will try to set `x` to the `MLIR.Type`, which is busted.
I think what we want to do here is make a fake AbstractArray, or even RArray, type that solely contains the size and eltype.
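One possible shape for such a placeholder type (hypothetical names; Reactant's real solution may differ):

```julia
# Hypothetical placeholder: records only eltype and shape, so a field typed
# `x::AbstractArray` can still hold it while the cache key ignores the data.
struct ShapeOnlyArray{T,N} <: AbstractArray{T,N}
    size::NTuple{N,Int}
end

Base.size(a::ShapeOnlyArray) = a.size

# Build a placeholder from any array, keeping just eltype and size.
placeholder(a::AbstractArray{T,N}) where {T,N} = ShapeOnlyArray{T,N}(size(a))
```

Since `placeholder(rand(3, 4)) isa AbstractArray`, a `Wrap(placeholder(...))` still constructs fine, yet the key only encodes `Float64` and `(3, 4)`.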
Force-pushed from e3a201f to ef02755
src/ControlFlow.jl (Outdated)

```
struct Cached
    obj
end
Base.:(==)(a::Cached, b::Cached) = recursive_equal(a.obj, b.obj)
```
I'm a bit worried about this key in the context of mutation. For example:

```julia
x = [1.0]
# cache something with x
x[1] = 2.0
# check cache: the key in the set is still based off of the x object,
# and would now hash to a different thing
```
Right, I hadn't accounted for that. So the problem is that the original key is corrupted when it is mutated. This can lead to too many functions being generated, but shouldn't lead to incorrect code generation, I believe.

```julia
struct Foo
    x
end

a = rand(10)
b = deepcopy(a)
foo = Cached(Foo(a))
cache = Dict(Cached(Foo(a)) => "cache entry")
cache[Cached(Foo(a))] # "cache entry"
a[1] = 0.0
cache[Cached(Foo(a))] # KeyError (as it should)
cache[Cached(Foo(b))] # KeyError!
```

I think it should suffice to add a `deepcopy`, i.e.:

```julia
Base.:(==)(a::Cached, b::Cached) = recursive_equal(deepcopy(a.obj), deepcopy(b.obj))
```

?
deepcopies themselves are prone to failures (and generally are advised against in julialang/julia). Is the recursive tuple/array generation potentially still on the table?
I have started implementing this but am not sure how to handle recursive data structures. I had written about this in a previous comment chain here: #366

> Trying out extending `make_tracer` to generate a tuple or array containing types and leaf values.
> How should we handle self-recursive objects, i.e. `a` contains `b` contains `a`?
> The other trace modes return early when they encounter an object for the second time (by consulting `seen`).
> But for generating a cache key, that behavior doesn't seem desired, because we want the same cache key for: `a = [1, 2, 3]; b = [1, 2, 3]; # (a, a) should have same cache key as (a, b)`
> @wsmoses the example you gave against deepcopy was file handles I believe, does deepcopy simply not work for those or are there other issues going on there?

I found this RE: file handles: https://discourse.julialang.org/t/deepcopy-dangers-and-drawbacks-for-nested-data-structures/111998/9, but I'm not certain that this is a problem in our case, as the stuff we'd be `deepcopy`ing will only be used for hashing and equality checks later on?
Ooh okay, what if the IdDict mapped the structure to the size of the dict at time of insertion? Then we can represent the fact that it had a pointer to the 2nd object, or whatnot.
I'm not sure I understand what you're proposing, could you clarify?
Thinking out loud with a more problematic example:

```julia
a.field = a
# vs.
a′.field = b
b.field = a′
```

Assuming `a`, `a′`, and `b` have the same type: `a` and `a′` should point to the same cache entry, while the structure in memory isn't identical.
One way I see to implement equality checks between two objects is to keep track of visited pairs:

```
(a, a′)
(a, b )
(a, a′) <-- already seen, done checking equality!
```

I'm still not completely convinced on the reservation against `deepcopy`, but maybe if I understand what you're proposing I'll see the light :)
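A sketch of that visited-pairs idea (hypothetical stand-in for the `recursive_equal` used by `Cached`; leaf handling is simplified):

```julia
# Hypothetical sketch: equality that tracks visited *pairs* of objects so
# that self-referential structures terminate instead of recursing forever.
function recursive_equal(a, b, seen=Set{NTuple{2,UInt}}())
    pair = (objectid(a), objectid(b))
    pair in seen && return true   # already comparing this pair: assume equal
    typeof(a) === typeof(b) || return false
    a isa Union{Number,Symbol,AbstractString} && return a == b
    push!(seen, pair)
    for f in fieldnames(typeof(a))
        recursive_equal(getfield(a, f), getfield(b, f), seen) || return false
    end
    return true
end
```

Assuming a pair equal when it is re-encountered is what makes the `a.field = a` versus `a′.field = b; b.field = a′` case compare equal despite the differing memory layout.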
So in this case I think it's fine if they cache to different things.
So my idea is essentially as follows:
- Create a map of object to ID (which is just the index of the time we inserted the object).
- For each object:
  - If a traced value, emit that value.
  - If a primitive type (Int etc.), emit that object.
  - Otherwise, we are an object:
    - Check if we've emitted that object before; if so, emit a new type `Object(id)`, where `id` is the value in the map.
    - Otherwise:
      - `map[object] = map.size()`
      - Emit `Object(id)`.
      - For each subfield, emit that field.
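The steps above could be sketched roughly like this (all names hypothetical; `ObjectRef` plays the role of `Object(id)`, and traced-value handling is elided):

```julia
# Hypothetical sketch of the proposed key emission: primitives are emitted
# as-is, repeated objects become a back-reference to their insertion index.
struct ObjectRef
    id::Int
end

function emit_key!(out, ids::IdDict{Any,Int}, x)
    if x isa Union{Number,Symbol,AbstractString,Nothing}
        push!(out, x)                   # primitive: emit the object itself
    elseif haskey(ids, x)
        push!(out, ObjectRef(ids[x]))   # seen before: emit Object(id)
    else
        ids[x] = length(ids)            # id = insertion order
        push!(out, typeof(x))           # emit the type marker
        if x isa AbstractArray
            foreach(el -> emit_key!(out, ids, el), x)
        else
            for f in fieldnames(typeof(x))
                emit_key!(out, ids, getfield(x, f))
            end
        end
    end
    return out
end

make_key(x) = emit_key!(Any[], IdDict{Any,Int}(), x)
```

With this, a self-referential `a.field = a` terminates with a back-reference, while two structurally identical but distinct objects emit different streams, matching the "fine if they cache to different things" stance.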
Actually I think it should emit `Object(pointer_from_objref(object))`, not the map index.
Essentially because you could have something like `if x === global` inside the code, and obviously they won't structurally be equivalent (though perhaps it's best to have two modes: one where it's a structural cache with the id as the map index, and a second where it is fine with globals).
Force-pushed from 9b5cd96 to 566e809
…ed values to their corresponding MLIR type. These transformed values can be used as keys in a dict (stored in a ScopedValue for ease). Cache hits are detected, but the cache is not yet used, because there is not yet a way to replace the MLIR data recursively in a traced object.

Repurposes the path argument of `make_tracer` and builds a vector containing:
- MLIR type for traced values
- Julia type for objects
- actual value for primitive types
- `VisitedObject(id)` for objects that were already encountered (== stored in `seen`)
I did a rebase.
To me this would imply: no caching, multi-return, kwargs, ... yet.