-
Notifications
You must be signed in to change notification settings - Fork 9
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Generalize broadcast #35
Conversation
Companion PR to EnzymeAD/Enzyme#1952 Untested, but should enable arbitrary broadcast support (if all scalar functions are overloaded) |
src/Reactant.jl
Outdated
@@ -429,7 +429,7 @@ function append_path(path, i) | |||
return (path..., i) | |||
end | |||
|
|||
@inline function make_tracer(seen::IdDict, prev::RT, path, mode, data) where {RT} | |||
@inline function make_tracer(seen::IdDict, prev::RT, path, mode, data; toscalar=false, tobatch=nothing) where {RT} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What do these kwargs do?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
whether the tracer we create should automatically change all arrays to scalars, or scalars to arrays
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just some fix requests and a couple of aesthetic suggestions.
Also, we should add a test for this.
src/overloads.jl
Outdated
@@ -457,6 +457,95 @@ for (jlop, hloop) in ( | |||
end | |||
end | |||
|
|||
|
|||
function elem_apply(f, args::VarArgs{Nargs}) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Need where
here:
function elem_apply(f, args::VarArgs{Nargs}) | |
function elem_apply(f, args::VarArgs{Nargs}) where {Nargs} |
src/utils.jl
Outdated
@@ -13,7 +13,7 @@ function apply(f, args...; kwargs...) | |||
return f(args...; kwargs...) | |||
end | |||
|
|||
function make_mlir_fn(mod, f, args, kwargs, name="main", concretein=true) | |||
function make_mlir_fn(mod, f, args, kwargs, name="main", concretein=true; toscalar=false) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe we could turn name
and concretein
into kwargs?
function make_mlir_fn(mod, f, args, kwargs, name="main", concretein=true; toscalar=false) | |
function make_mlir_fn(mod, f, args, kwargs; name="main", concretein=true, toscalar=false) |
src/overloads.jl
Outdated
function elem_apply(f, args::VarArgs{Nargs}) | ||
primf = f.val | ||
|
||
mod = MLIR.IR.module() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
mod = MLIR.IR.module() | |
mod = MLIR.IR.mmodule() |
Line 2 in 5134fe0
|
src/overloads.jl
Outdated
OutShape = nothing | ||
for (k, v) in seen_args | ||
invmap[v] = k | ||
OutShape size(k) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OutShape size(k) | |
OutShape = size(k) |
No description provided.