This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

get_device/get_device_type doesn't support functions/closures #82

CarloLucibello opened this issue Oct 13, 2024 · 2 comments · Fixed by #87


CarloLucibello (Contributor) commented Oct 13, 2024

Ideally, the following example should return `nothing` instead of erroring:

```julia
julia> using MLDataDevices

julia> get_device(sum)
ERROR: MethodError: no method matching get_device(::typeof(sum))
The function `get_device` exists, but no method is defined for this combination of argument types.

Closest candidates are:
  get_device(::Symbol)
   @ MLDataDevices ~/.julia/packages/MLDataDevices/pD2je/src/internal.jl:153
  get_device(::String)
   @ MLDataDevices ~/.julia/packages/MLDataDevices/pD2je/src/internal.jl:153
  get_device(::Nothing)
   @ MLDataDevices ~/.julia/packages/MLDataDevices/pD2je/src/internal.jl:153
  ...

Stacktrace:
 [1] mapreduce_first(f::typeof(MLDataDevices.Internal.get_device), op::Function, x::Function)
   @ Base ./reduce.jl:421
 [2] _mapreduce(f::typeof(MLDataDevices.Internal.get_device), op::typeof(MLDataDevices.Internal.combine_devices), ::IndexLinear, A::Vector{…})
   @ Base ./reduce.jl:432
 [3] _mapreduce_dim(f::Function, op::Function, ::Base._InitialValue, A::Vector{typeof(sum)}, ::Colon)
   @ Base ./reducedim.jl:337
 [4] mapreduce(f::Function, op::Function, A::Vector{typeof(sum)})
   @ Base ./reducedim.jl:329
 [5] get_device(x::Function)
   @ MLDataDevices ~/.julia/packages/MLDataDevices/pD2je/src/public.jl:349
 [6] get_device(x::Function)
   @ Flux ~/juliadev/Flux/src/devices.jl:1
 [7] top-level scope
   @ REPL[12]:1
Some type information was truncated. Use `show(err)` to see complete types.
```
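Why `nothing` fits here: a named function such as `sum` is the sole instance of a type with no fields, so it holds no arrays that could live on any device. The sketch below checks this in plain Julia; `ToyCPU` and `toy_device` are hypothetical stand-ins for illustration, not part of MLDataDevices.

```julia
# Hypothetical stand-ins; not MLDataDevices' implementation.
struct ToyCPU end

toy_device(::AbstractArray) = ToyCPU()   # plain Arrays live on the CPU

# A function whose type has no fields (like `sum`) carries no data,
# so the sketch reports `nothing`; closures are deliberately left out here.
function toy_device(f::Function)
    fieldcount(typeof(f)) == 0 || error("closure: captured fields must be inspected")
    return nothing
end

println(fieldcount(typeof(sum)))   # 0
println(toy_device(sum))           # nothing
println(toy_device(rand(3)))       # ToyCPU()
```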

In this case, a `CPUDevice` should be returned instead:

```julia
julia> using Flux

julia> get_device(Dense(3 => 3, tanh))
ERROR: MethodError: no method matching get_device(::typeof(tanh))
The function `get_device` exists, but no method is defined for this combination of argument types.

Closest candidates are:
  get_device(::Symbol)
   @ MLDataDevices ~/.julia/packages/MLDataDevices/pD2je/src/internal.jl:153
  get_device(::String)
   @ MLDataDevices ~/.julia/packages/MLDataDevices/pD2je/src/internal.jl:153
  get_device(::Nothing)
   @ MLDataDevices ~/.julia/packages/MLDataDevices/pD2je/src/internal.jl:153
  ...

Stacktrace:
 [1] _mapreduce(f::typeof(MLDataDevices.Internal.get_device), op::typeof(MLDataDevices.Internal.combine_devices), ::IndexLinear, A::Vector{…})
   @ Base ./reduce.jl:440
 [2] _mapreduce_dim
   @ ./reducedim.jl:337 [inlined]
 [3] mapreduce
   @ ./reducedim.jl:329 [inlined]
 [4] get_device
   @ ~/.julia/packages/MLDataDevices/pD2je/src/public.jl:349 [inlined]
 [5] get_device(x::Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}})
   @ Flux ~/juliadev/Flux/src/devices.jl:1
 [6] top-level scope
   @ REPL[15]:1
```
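`Dense(3 => 3, tanh)` stores its activation `tanh` alongside its weight matrix and bias vector, and the stack trace shows the error being raised while `get_device` maps over the layer's contents. One way to get a CPU result here is to let a fieldless function contribute `nothing` and treat `nothing` as the identity when per-field results are combined. The sketch uses hypothetical `ToyDense`, `leaf_device` and `combine` names (not Flux or MLDataDevices API) and omits the conflicting-device check a real implementation would need.

```julia
# Hypothetical sketch; names are illustrative, not Flux or MLDataDevices API.
struct ToyCPU end

# A Dense-like layer: two arrays plus an activation function.
struct ToyDense{M,V,F}
    weight::M
    bias::V
    σ::F
end

leaf_device(::AbstractArray) = ToyCPU()
leaf_device(::Function) = nothing   # assumes a fieldless function such as `tanh`

# `nothing` is the identity element: it never overrides a real device.
combine(::Nothing, ::Nothing) = nothing
combine(::Nothing, d) = d
combine(d, ::Nothing) = d
combine(a::ToyCPU, ::ToyCPU) = a

layer_device(l::ToyDense) =
    reduce(combine, (leaf_device(getfield(l, i)) for i in 1:nfields(l)))

layer = ToyDense(rand(Float32, 3, 3), rand(Float32, 3), tanh)
println(layer_device(layer))   # ToyCPU()
```

If every field is a fieldless function, the reduction yields `nothing`, which matches the behavior requested for `get_device(sum)`.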
avik-pal (Member) commented

Right, we need to make sure closures are supported correctly here. For example:

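```julia
using CUDA, MLDataDevices
x = cu(rand(2))
myfn = let x = x; y -> x .+ y; end  # closure: `x` is stored as a captured field
```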

`get_device(myfn)` should return `get_device(x)`.
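Julia closures store the variables they capture as fields of a compiler-generated type, so a device query can find `x` by walking those fields. The sketch below uses plain CPU arrays instead of `cu` arrays so it runs without a GPU; `make_adder` is a hypothetical helper. It also shows a caveat: a named top-level function that reads a global captures nothing, so it exposes no fields to inspect.

```julia
# Build a real closure: `x` becomes a field of the returned function object.
make_adder(x) = y -> x .+ y

x = rand(2)                 # stands in for `cu(rand(2))`
myfn = make_adder(x)

println(fieldnames(typeof(myfn)))   # (:x,)
println(getfield(myfn, :x) === x)   # true

# A top-level named function that reads the global `x` captures nothing.
g(y) = x .+ y
println(fieldcount(typeof(g)))      # 0
```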

avik-pal changed the title from "get_device errors on functions" to "get_device/get_device_type doesn't support functions/closures" on Oct 13, 2024
CarloLucibello (Contributor, Author) commented

Should `get_device(x) = nothing` be defined as a generic fallback?
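A blanket `get_device(x) = nothing` would also return `nothing` for any user struct that wraps arrays but lacks a dedicated method. One possible alternative, sketched here with hypothetical `ToyCPU`/`toy_device`/`Wrapper` names (not MLDataDevices' API), is a fallback that recurses into an object's fields and returns `nothing` only when none of them reports a device. Named functions, closures, tuples and plain structs are then all handled by the same method.

```julia
# Hypothetical fallback design; not MLDataDevices' actual implementation.
struct ToyCPU end

toy_device(::AbstractArray) = ToyCPU()
toy_device(::Union{Number,AbstractString,Symbol,Nothing}) = nothing

# Generic fallback: inspect every field. A fieldless object such as `sum`
# yields an empty list and therefore `nothing`.
function toy_device(x)
    found = filter(!isnothing, [toy_device(getfield(x, i)) for i in 1:nfields(x)])
    isempty(found) && return nothing
    all(==(first(found)), found) || error("fields live on different devices")
    return first(found)
end

struct Wrapper{A}   # hypothetical user type with no dedicated method
    data::A
end

closure = let w = rand(3)
    y -> w .+ y
end

println(toy_device(sum))                # nothing
println(toy_device(Wrapper(rand(2))))   # ToyCPU()
println(toy_device(closure))            # ToyCPU()
println(toy_device((tanh, rand(2))))    # ToyCPU()
```

The trade-off: the recursive fallback walks arbitrary objects, including ones never meant to be inspected, whereas a plain `nothing` fallback is cheaper but can hide arrays nested inside unrecognized types.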
