Skip to content

Commit

Permalink
propertynames of CA from type
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Apr 28, 2023
1 parent b3c4188 commit 6ca41f4
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 9 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Lux"
uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
authors = ["Avik Pal <avikpal@mit.edu> and contributors"]
version = "0.4.52"
version = "0.4.53"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down Expand Up @@ -43,7 +43,7 @@ LuxZygoteExt = "Zygote"
Adapt = "3"
ChainRulesCore = "1"
ComponentArrays = "0.13"
FillArrays = "0.13"
FillArrays = "0.13, 1"
Flux = "0.13"
Functors = "0.2, 0.3, 0.4"
LuxCUDA = "0.1"
Expand Down
7 changes: 5 additions & 2 deletions ext/LuxComponentArraysExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,11 @@ using Functors, Lux, Optimisers
import TruncatedStacktraces: @truncate_stacktrace
import ChainRulesCore as CRC

@inline function Lux._getproperty(x::ComponentArray, ::Val{prop}) where {prop}
return prop in propertynames(x) ? getproperty(x, prop) : nothing
@generated function Lux._getproperty(x::ComponentArray{T, N, A, Tuple{Ax}},
::Val{v}) where {v, T, N, A,
Ax <: ComponentArrays.AbstractAxis}
names = propertynames(ComponentArrays.indexmap(Ax))
return v names ? :(x.$v) : :(nothing)
end

function Functors.functor(::Type{<:ComponentArray}, c)
Expand Down
11 changes: 6 additions & 5 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -213,11 +213,12 @@ end

# If doesn't have a property, return nothing
@generated function _getproperty(x::NamedTuple{names}, ::Val{v}) where {names, v}
if v in names
return :(x.$v)
else
return :(nothing)
end
return v names ? :(x.$v) : :(nothing)
end

## Slow-fallback
@inline function _getproperty(x, ::Val{v}) where {v}
return v propertynames(x) ? getproperty(x, v) : nothing
end

@inline function _eachslice(x::AbstractArray, ::Val{dims}) where {dims}
Expand Down

0 comments on commit 6ca41f4

Please sign in to comment.