add Ops.convolution
#428
featureGroupCount::Int,
batchGroupCount::Int,
)
cp = ConvolutionParams(
Seems reasonable to me, but yeah, I am confused about the arg thing.
Why do we need to add a function to the C-API?
It's using the |
i know computing the size of the result of in the spec of
so, if you know all the input args and attributes of
@wsmoses i guess we could open an issue in the StableHLO repo to ask for better docs on how to compute result shapes? it would be very useful for other ops too. at least some coherent naming would improve the current situation.
Yes, completely: one can infer the result size from Julia. I started from Paul's implementation in
I could go either way on this. To be clear, I'm fine with either approach. The advantage of having a Julia reimplementation of the rules is that it'll be easier to debug shape mismatch issues on the Julia side. The con is that it requires a reimplementation, and of course if things change, that's annoying. That said, StableHLO shouldn't really change much, especially here. Another thing you could do is export this function in the StableHLO C bindings upstream.
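To make the "infer the result size from Julia" option concrete, here is a minimal sketch of the per-dimension output-size rule as the StableHLO spec describes it for convolution. The names `conv_output_dim` and `conv_output_spatial` are hypothetical and not part of Reactant, and the N×2 `(low, high)` padding layout is an assumption made for illustration.

```julia
# Hypothetical helper, not part of Reactant: output size of one spatial
# dimension of a convolution, following the shape rule in the StableHLO spec.
function conv_output_dim(input::Int, window::Int;
                         stride::Int=1, pad_low::Int=0, pad_high::Int=0,
                         lhs_dilation::Int=1, rhs_dilation::Int=1)
    # Dilation inserts (dilation - 1) holes between neighbouring elements.
    dilated_input = input == 0 ? 0 : (input - 1) * lhs_dilation + 1
    padded_input = pad_low + dilated_input + pad_high
    dilated_window = (window - 1) * rhs_dilation + 1
    # An empty input or an oversized window yields an empty result dimension.
    if padded_input == 0 || dilated_window > padded_input
        return 0
    end
    return fld(padded_input - dilated_window, stride) + 1
end

# Apply the rule to every spatial dimension at once.
# `padding` is assumed to be an N×2 matrix of (low, high) pairs.
function conv_output_spatial(input::NTuple{N,Int}, window::NTuple{N,Int};
                             stride=ntuple(_ -> 1, N),
                             padding=zeros(Int, N, 2)) where {N}
    return ntuple(N) do i
        conv_output_dim(input[i], window[i]; stride=stride[i],
                        pad_low=padding[i, 1], pad_high=padding[i, 2])
    end
end

conv_output_spatial((224, 224), (10, 10))  # (215, 215)
```

This covers only the spatial dimensions; the batch and feature dimensions of the result depend on the dimension numbers and group counts and are left out of the sketch.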
but this is the thing, you don't need to write the 34 constraints in Julia. actually, some of these constraints are not coded as assertions in the already available Ops, but you already use them as ways to compute the input arguments/attributes. in the Ops i wrote myself, i didn't write all the constraints. what i did is write easy checks so that we can quickly catch any superficial problem in Julia (e.g. you're using different eltypes or the shapes do not match) and i left the deep checks to XLA, which will raise an exception. now that we have propagation of XLA errors to Julia, we can also handle these XLA exceptions.
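As an illustration of the "easy checks in Julia, deep checks in XLA" split, a shallow pre-check could look roughly like this. `shallow_conv_check` is a hypothetical name and this is not Reactant's actual code. It only compares eltypes and ranks and assumes both operands are plain arrays.

```julia
# Hypothetical shallow pre-check, not Reactant's actual code: catch cheap,
# obvious mistakes in Julia and leave the deeper constraints to XLA.
function shallow_conv_check(lhs::AbstractArray, rhs::AbstractArray)
    eltype(lhs) == eltype(rhs) ||
        throw(ArgumentError("eltype mismatch: $(eltype(lhs)) vs $(eltype(rhs))"))
    ndims(lhs) == ndims(rhs) ||
        throw(ArgumentError("rank mismatch: $(ndims(lhs)) vs $(ndims(rhs))"))
    return nothing
end

# Matching eltypes and ranks: the check passes silently.
shallow_conv_check(randn(Float64, 224, 224, 1, 1), randn(Float64, 10, 10, 1, 1))
```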
if there are changes, i'm confident that with high probability they will be in these deep checks left to XLA, so we probably won't need to make modifications. but not computing the result shape is going to make debugging more difficult when it's something we can do ourselves. having said this, i'm not against this PR and i think it is really needed, so thanks @glou-nes for pushing this forward. we can merge it and refactor it in the future if we see that it's a problem. @wsmoses one thing to keep in mind is that the XLA exceptions will be thrown on the verification step, not during MLIR emission, so unless you've used
Thanks @mofeing for the explanation! I just want to make a point: with this, one gets the verification error directly at MLIR emission. For instance:
function bad_conv(x, ker)
    pp = dimension_number(4, 3, 2, 3, 4)
    Reactant.Ops.convolution(x, ker, pp; padding=[1 2; 3 4; 5 6]) # wrong padding here
end
x = Reactant.to_rarray(randn(Float64, 224, 224, 1, 1))
ker = Reactant.to_rarray(randn(Float64, 10, 10, 1, 1))
@code_hlo optimize=false bad_conv(x, ker)
We have:
loc("convolution"("/depot/dev/Reactant/src/Ops.jl":622:0)): error: expects padding-entries to have same dimension-size as size of window dimensions (2), but got: 3.
ERROR: AssertionError: cannot infer result type # assertion in Ops
[stacktrace]
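The verifier message above amounts to a size check on the padding matrix, which could also be done on the Julia side before emission. A minimal sketch, assuming the padding is given as one `(low, high)` row per window dimension as in the example; `check_padding` is a hypothetical name, not an existing Reactant function:

```julia
# Hypothetical check, not Reactant's actual code. Assumes `padding` holds one
# (low, high) row per spatial (window) dimension, as in the example above.
function check_padding(padding::AbstractMatrix{<:Integer}, n_spatial::Integer)
    size(padding) == (n_spatial, 2) ||
        throw(DimensionMismatch(
            "padding must be $(n_spatial)×2, got $(join(size(padding), "×"))"))
    return nothing
end

check_padding([1 2; 3 4], 2)          # passes
# check_padding([1 2; 3 4; 5 6], 2)   # would throw DimensionMismatch
```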
you're getting the location of the Ops.jl method, but not the location of the user code that calls it. for that, you need to use
Sure. Adding a stacktrace to every location makes the verifier output hard to read, for instance. Here we get the stacktrace in
Add convolution, using the StableHLO function that the verifier/interpreter uses to compute the output tensor. Add a structure because I found a limit with ccall:
The second one causes a segfault.
@wsmoses do you think this approach is viable?
Probably needs a new file when #421 lands.