From 6ca41f406c188899fc5ff9348f13eb28e7575411 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 28 Apr 2023 11:28:14 -0400 Subject: [PATCH] propertynames of CA from type --- Project.toml | 4 ++-- ext/LuxComponentArraysExt.jl | 7 +++++-- src/utils.jl | 11 ++++++----- 3 files changed, 13 insertions(+), 9 deletions(-) diff --git a/Project.toml b/Project.toml index 6ab4882bd4..b35d69bb61 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Lux" uuid = "b2108857-7c20-44ae-9111-449ecde12c47" authors = ["Avik Pal and contributors"] -version = "0.4.52" +version = "0.4.53" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -43,7 +43,7 @@ LuxZygoteExt = "Zygote" Adapt = "3" ChainRulesCore = "1" ComponentArrays = "0.13" -FillArrays = "0.13" +FillArrays = "0.13, 1" Flux = "0.13" Functors = "0.2, 0.3, 0.4" LuxCUDA = "0.1" diff --git a/ext/LuxComponentArraysExt.jl b/ext/LuxComponentArraysExt.jl index bc5502f399..c98d26b852 100644 --- a/ext/LuxComponentArraysExt.jl +++ b/ext/LuxComponentArraysExt.jl @@ -6,8 +6,11 @@ using Functors, Lux, Optimisers import TruncatedStacktraces: @truncate_stacktrace import ChainRulesCore as CRC -@inline function Lux._getproperty(x::ComponentArray, ::Val{prop}) where {prop} - return prop in propertynames(x) ? getproperty(x, prop) : nothing +@generated function Lux._getproperty(x::ComponentArray{T, N, A, Tuple{Ax}}, + ::Val{v}) where {v, T, N, A, + Ax <: ComponentArrays.AbstractAxis} + names = propertynames(ComponentArrays.indexmap(Ax)) + return v ∈ names ? :(x.$v) : :(nothing) end function Functors.functor(::Type{<:ComponentArray}, c) diff --git a/src/utils.jl b/src/utils.jl index 7c78e4c735..222891032a 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -213,11 +213,12 @@ end # If doesn't have a property, return nothing @generated function _getproperty(x::NamedTuple{names}, ::Val{v}) where {names, v} - if v in names - return :(x.$v) - else - return :(nothing) - end + return v ∈ names ? :(x.$v) : :(nothing) +end + +## Slow-fallback +@inline function _getproperty(x, ::Val{v}) where {v} + return v ∈ propertynames(x) ? getproperty(x, v) : nothing end @inline function _eachslice(x::AbstractArray, ::Val{dims}) where {dims}