
@trace function calls #366

Open

wants to merge 23 commits into main
Conversation

jumerckx (Collaborator):

no caching, multi-return, kwargs, ... yet.

@Pangoraw linked an issue on Dec 12, 2024 that may be closed by this pull request
src/ControlFlow.jl
src/Compiler.jl Outdated
@@ -287,12 +289,16 @@ function compile_mlir(f, args; kwargs...)
end
end

const callcache = ScopedValue{Dict}()
jumerckx (Collaborator, Author):

This seems like a simple way to store the cache. Fancier solutions (e.g. integrating with absint?) are probably possible as well, but I'm not sure whether it's worth investigating right now.

Member:

I think, ironically, we should do something slightly different here; we can do something like we do with mmodule (to get the current MLIR module):

function mmodule(; throw_error::Core.Bool=true)

If we define the cache scope during the outermost compile / module generation, then we're certain it doesn't get messed up, and we can also cache calls across subfunctions etc. nicely [which this does too, to be clear, but maybe not quite as cleanly?]
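The mmodule-style pattern described here could look roughly like the following sketch, against Base.ScopedValues on Julia 1.11+. The names callcache, current_callcache, and with_callcache are hypothetical, not the PR's actual API:

```julia
using Base.ScopedValues  # Julia 1.11+; the ScopedValues.jl package backports this

# The call cache is a ScopedValue whose scope is established once,
# at the outermost compile entry point.
const callcache = ScopedValue{Union{Nothing,Dict{Any,Any}}}(nothing)

# Errors if called outside an established cache scope.
current_callcache() = something(callcache[])

# Establish the scope at the outermost entry; nested code all sees one cache.
with_callcache(f) = with(f, callcache => Dict{Any,Any}())

# usage: nested lookups share the cache, so each key is computed only once
hits = Ref(0)
result = with_callcache() do
    cache = current_callcache()
    for _ in 1:3
        get!(cache, :fn_key) do
            hits[] += 1    # simulated expensive tracing
            :compiled
        end
    end
    cache[:fn_key]
end
```

Because the scope is created once at the outermost entry, subfunction tracing cannot accidentally set up a second, conflicting cache.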

jumerckx (Collaborator, Author):

Added here: 588cebb (#366)

@@ -3,6 +3,9 @@ module ReactantCore
using ExpressionExplorer: ExpressionExplorer
using MacroTools: MacroTools

using Base.ScopedValues
const enable_tracing = ScopedValue{Bool}(false)
jumerckx (Collaborator, Author):

I think we simply want to trace function calls every time in make_mlir_fn. I added this ScopedValue, which is set at the beginning, so it's just a simple check.
If it's okay to use ScopedValues in ReactantCore, the ScopedValues compat package probably needs to be added to Project.toml for backwards compatibility with Julia versions before 1.11 (?)


@jumerckx marked this pull request as ready for review on December 14, 2024 15:39
seen_cache,
args,
(),
CallCache;
Member:

maybe we can call this mode TracedToTypes or something (since really it's converting Traced values to a type expression)

Member:

This also isn't sufficient, since the cache key will only consider the outermost object as the key.

For example

mutable struct T
   x
end

X = T(1.0)
cache_key1 = make_key(X)
X.x = 2.0
cache_key2 = make_key(X)

Presumably the two keys will be the same since the hash key will be based on the object ID, not the potentially recursive structure.

I think we probably want to make a version of make_tracer that returns a tuple of all of the leaf objects, and we can obviously replace all tracedvalues/numbers with just the type.

jumerckx (Collaborator, Author):

Just the leaves is not enough, I think. For example, foo and bar below have the same type and the same leaves but should be cached separately.

struct Foo
    x
end

struct Bar
    x
end

foo = Foo(Foo(1))
bar = Foo(Bar(1))

Member:

Maybe

If leaf: the object
Otherwise: typeof(object), objects(children…)…
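A minimal sketch of this proposal, using the Foo/Bar example from above. make_key is a hypothetical helper, and is_leaf stands in for detecting traced/primitive values:

```julia
# Leaves stay as values; containers contribute their type plus the
# recursively-built keys of their children.
struct Foo; x; end
struct Bar; x; end

is_leaf(x) = x isa Number  # stand-in for "traced/primitive value"

make_key(x) = is_leaf(x) ? x :
    (typeof(x), (make_key(getfield(x, f)) for f in fieldnames(typeof(x)))...)

foo = Foo(Foo(1))
bar = Foo(Bar(1))
make_key(foo)  # (Foo, (Foo, 1))
make_key(bar)  # (Foo, (Bar, 1))
```

With the types interleaved, foo and bar produce distinct keys even though their leaves are identical.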

jumerckx (Collaborator, Author):

Right, that could work. I'll experiment a bit.

jumerckx (Collaborator, Author):

68559c5 (#366) What do you think of this approach? I created a Cached type with custom equality and hashing to be used as dictionary keys.
The equality and hashing recursively visit all fields and check type and actual values.
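A rough sketch of how such recursive hashing could look. recursive_hash is a hypothetical stand-in, not the PR's implementation (the PR's recursive_equal is analogous on the equality side):

```julia
# Cached wraps an object; hashing visits types and field values recursively,
# so structurally-equal objects collide as dictionary keys.
struct Cached
    obj
end

function recursive_hash(x, h::UInt)
    h = hash(typeof(x), h)                 # mix in the type at every level
    if isempty(fieldnames(typeof(x)))
        return hash(x, h)                  # leaf: hash the value itself
    end
    for f in fieldnames(typeof(x))
        h = recursive_hash(getfield(x, f), h)
    end
    return h
end

Base.hash(c::Cached, h::UInt) = recursive_hash(c.obj, h)

struct Foo; x; end
struct Bar; x; end
```

Two fresh Foo(1) instances hash identically, while Foo(1) and Bar(1) differ because the type is mixed in at every level.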

jumerckx (Collaborator, Author):

Trying out extending make_tracer to generate a tuple or array containing types and leaf values.
How should we handle self-recursive objects, i.e. a contains b contains a?
The other trace modes return early when they encounter an object for the second time (by consulting seen).
But for generating a cache key, that behavior doesn't seem desired because we want the same cache key for:

a = [1, 2, 3]
b = [1, 2, 3]
# (a, a) should have same cache key as (a, b)

@wsmoses the example you gave against deepcopy was file handles I believe; does deepcopy simply not work for those, or are there other issues going on there?

src/Tracing.jl Outdated
@@ -382,6 +383,12 @@ function make_tracer(
if mode == ConcreteToTraced
throw("Cannot trace existing trace type")
end
if mode == CallCache
if !haskey(seen, prev)
Member:

Edit: I started writing this, but I realize the design has a different issue we need to solve first that makes this point moot (https://github.com/EnzymeAD/Reactant.jl/pull/366/files#r1885172358)

So this is going to present an issue for some code.

For example

struct Wrap
   x::AbstractArray
end

We will try to set x to the MLIR.Type, which is busted.

I think what we want to do here is make a fake abstractarray or even rarray type that solely contains the size and eltype.
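A sketch of that placeholder type. TracedStub is a hypothetical name for the fake array type carrying only eltype and size:

```julia
# A placeholder array type that records only eltype and size, so fields
# typed ::AbstractArray (like Wrap.x) still accept it.
struct TracedStub{T,N} <: AbstractArray{T,N}
    size::NTuple{N,Int}
end
Base.size(s::TracedStub) = s.size

struct Wrap
    x::AbstractArray
end

w = Wrap(TracedStub{Float64,2}((3, 4)))
eltype(w.x)  # Float64
size(w.x)    # (3, 4)
```

Because the stub subtypes AbstractArray, it slips through abstractly-typed fields where a raw MLIR.Type would fail.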

@jumerckx force-pushed the jm/funccall branch 2 times, most recently from e3a201f to ef02755, on December 16, 2024 08:54
@jumerckx changed the title from "WIP @trace function calls" to "@trace function calls" on Dec 16, 2024
struct Cached
obj
end
Base.:(==)(a::Cached, b::Cached) = recursive_equal(a.obj, b.obj)
Member:

I'm a bit worried about this key in the context of mutation. For example

x = [1.0]
cache something with x

x[1] = 2.0
check cache (the key in the dict is still based on the original x object, which now hashes to a different thing)

jumerckx (Collaborator, Author):

Right, I hadn't accounted for that. So the problem is that the original key is corrupted when the object is mutated. This can lead to too many functions being generated, but shouldn't lead to incorrect code generation, I believe.

struct Foo
    x
end

a = rand(10)
b = deepcopy(a)
foo = Cached(Foo(a))

cache = Dict(Cached(Foo(a))=>"cache entry")
cache[Cached(Foo(a))] # "cache entry"
a[1] = 0.
cache[Cached(Foo(a))] # KeyError (as it should)
cache[Cached(Foo(b))] # KeyError!

I think it should suffice to add a deepcopy, i.e.:
Base.:(==)(a::Cached, b::Cached) = recursive_equal(deepcopy(a.obj), deepcopy(b.obj))?

Member:

deepcopies themselves are prone to failures (and are generally advised against in julialang/julia). Is the recursive tuple/array generation potentially still on the table?

jumerckx (Collaborator, Author), Dec 23, 2024:

I have started implementing this but am not sure how to handle recursive data structures. I had written about this in a previous comment chain here: #366

> Trying out extending make_tracer to generate a tuple or array containing types and leaf values.
> How should we handle self-recursive objects, i.e. a contains b contains a?
> The other trace modes return early when they encounter an object for the second time (by consulting seen).
> But for generating a cache key, that behavior doesn't seem desired because we want the same cache key for:
>
> a = [1, 2, 3]
> b = [1, 2, 3]
> # (a, a) should have same cache key as (a, b)
>
> @wsmoses the example you gave against deepcopy was file handles I believe; does deepcopy simply not work for those, or are there other issues going on there?

I found this RE: file handles: https://discourse.julialang.org/t/deepcopy-dangers-and-drawbacks-for-nested-data-structures/111998/9 but I'm not certain this is a problem in our case, as the stuff we'd be deepcopying would only be used for hashing and equality checks later on?

Member:

ooh okay, what if the IdDict mapped the structure to the size of the dict at the time of insertion. That way we can represent the fact that it had a pointer to the 2nd object, or whatnot.

jumerckx (Collaborator, Author):

I'm not sure I understand what you're proposing, could you clarify?

Thinking out loud with a more problematic example:

a.field = a

# vs.

a′.field = b
b.field = a′

Assuming a, a′, and b have the same type: a and a′ should point to the same cache entry while the structure in memory isn't identical.
One way I see to implement equality checks between two objects is to keep track of visited pairs:

(a, a′)
(a, b )
(a, a′) <-- already seen, done checking equality!

I'm still not completely convinced on the reservation against deepcopy, but maybe if I understand what you're proposing I'll see the light :)
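The visited-pairs idea above can be sketched like this (a hypothetical helper, not the PR's recursive_equal):

```julia
# Structural equality with a visited-pair set so that cyclic structures
# terminate: once a pair of objects is already being compared, assume equal.
function recursive_equal(a, b, seen=Set{NTuple{2,UInt}}())
    a === b && return true
    typeof(a) == typeof(b) || return false
    if ismutable(a)
        pair = (objectid(a), objectid(b))
        pair in seen && return true  # this pair is already being compared
        push!(seen, pair)
    end
    isempty(fieldnames(typeof(a))) && return a == b  # leaf: compare values
    return all(recursive_equal(getfield(a, f), getfield(b, f), seen)
               for f in fieldnames(typeof(a)))
end

# a.field = a  vs.  b.field = b: structurally equivalent despite the cycles
mutable struct Node
    x
end
a = Node(nothing); a.x = a
b = Node(nothing); b.x = b
recursive_equal(a, b)  # true
```

The seen set plays the same role here as the (a, a′), (a, b) pair tracking described in the comment.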

Member:

So in this case I think it's fine if they cache to different things.

So my idea is essentially as follows:

Create a map of object to ID (which is just the index of the time we inserted the object).

For each object:
- If it is a traced value, emit that value
- If it is a primitive type (Int, etc.), emit that object
- Otherwise we are an object:
  - Check if we've emitted that object before; if so, emit a new type Object(id), where id is the value in the map
  - Otherwise:
    - map[object] = map.size()
    - Emit Object(id)
    - For each subfield, emit that field
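A self-contained sketch of this emission scheme. All names are hypothetical except VisitedObject, which later lands in the PR; Number/Symbol/String stand in for leaf values:

```julia
# Back-reference marker for objects already emitted into the key.
struct VisitedObject
    id::Int
end

function emit_key!(out, seen::IdDict{Any,Int}, x)
    if x isa Number || x isa Symbol || x isa String
        push!(out, x)                       # leaf: emit the value itself
    elseif haskey(seen, x)
        push!(out, VisitedObject(seen[x]))  # shared/cyclic object: back-reference
    else
        seen[x] = length(seen) + 1          # id = insertion index
        push!(out, typeof(x))               # emit the object's type...
        for f in fieldnames(typeof(x))
            emit_key!(out, seen, getfield(x, f))  # ...then each field
        end
    end
    return out
end

make_key(x) = emit_key!(Any[], IdDict{Any,Int}(), x)

mutable struct Self
    x
end
s = Self(nothing); s.x = s
make_key(s)  # Any[Self, VisitedObject(1)]
```

The self-referential example terminates because the second visit to s hits the seen map and emits VisitedObject(1) instead of recursing.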

Member:

Actually I think it should emit object(pointer_from_objref(object)), not the map index.

Essentially because you could have something like

if x === global

inside the code, and obviously they won't structurally be equivalent (tho perhaps it's best to have two modes: one where it's a structural cache with the id as the map index, and a second one where it is fine with globals)

Project.toml Outdated
jumerckx and others added 6 commits January 2, 2025 14:56
Repurposes the path argument of `make_tracer` and builds a vector containing:
* MLIR type for traced values
* Julia type for objects
* actual value for primitive types
* `VisitedObject(id)` for objects that were already encountered (== stored in `seen`).

jumerckx commented Jan 2, 2025

I did a rebase.
I also added the functionality to make_tracer, abusing the path argument to store the final vector of types/values as the cache key.
I don't quite get the comment about

if x === global

To me this would imply x and global are structurally equivalent, no?
So maybe I've missed something there. Depending on whether the seen passed to make_tracer is an IdDict or a Dict, caching behavior changes.

Successfully merging this pull request may close these issues.

@trace function_call() to introduce function barrier
3 participants