Skip to content

Commit

Permalink
Most CA patches are upstreamed (#126)
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal authored Aug 7, 2022
1 parent 6041cdb commit e7e64c1
Show file tree
Hide file tree
Showing 6 changed files with 10 additions and 45 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Lux"
uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
authors = ["Avik Pal <avikpal@mit.edu> and contributors"]
version = "0.4.14"
version = "0.4.15"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand All @@ -26,7 +26,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
Adapt = "3"
CUDA = "3"
ChainRulesCore = "1"
ComponentArrays = "0.11 - 0.12.5"
ComponentArrays = "0.13"
FillArrays = "0.13"
Functors = "0.2, 0.3"
NNlib = "0.8"
Expand Down
2 changes: 1 addition & 1 deletion examples/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ AbstractDifferentiation = "0.4"
ArgParse = "1"
Augmentor = "0.6"
CUDA = "3"
ComponentArrays = "0.11, 0.12"
ComponentArrays = "0.13"
DataLoaders = "0.1"
DiffEqSensitivity = "6"
Flux = "0.13"
Expand Down
12 changes: 3 additions & 9 deletions src/adapt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,6 @@ struct LuxCUDAAdaptor <: LuxDeviceAdaptor end
adapt_storage(::LuxCUDAAdaptor, x) = CUDA.cu(x)
adapt_storage(::LuxCUDAAdaptor, x::FillArrays.AbstractFill) = CUDA.cu(collect(x))
adapt_storage(::LuxCUDAAdaptor, x::Zygote.OneElement) = CUDA.cu(collect(x))
function adapt_storage(to::LuxCUDAAdaptor, x::ComponentArray)
return ComponentArray(adapt_storage(to, getdata(x)), getaxes(x))
end
adapt_storage(::LuxCUDAAdaptor, rng::AbstractRNG) = rng

function adapt_storage(::LuxCPUAdaptor,
Expand All @@ -17,9 +14,6 @@ function adapt_storage(::LuxCPUAdaptor,
return x
end
adapt_storage(::LuxCPUAdaptor, x::AbstractArray) = adapt(Array, x)
function adapt_storage(to::LuxCPUAdaptor, x::ComponentArray)
return ComponentArray(adapt_storage(to, getdata(x)), getaxes(x))
end
adapt_storage(::LuxCPUAdaptor, rng::AbstractRNG) = rng
# TODO(@avik-pal): SparseArrays
function adapt_storage(::LuxCPUAdaptor,
Expand Down Expand Up @@ -55,12 +49,12 @@ function check_use_cuda()
if use_cuda[] === nothing
use_cuda[] = CUDA.functional()
if use_cuda[] && !CUDA.has_cudnn()
@warn """CUDA.jl found cuda, but did not find libcudnn. Some functionality
@warn """CUDA.jl found cuda, but did not find libcudnn. Some functionality
will not be available."""
end
if !(use_cuda[])
@info """The GPU function is being called but the GPU is not accessible.
Defaulting back to the CPU. (No action is required if you want
@info """The GPU function is being called but the GPU is not accessible.
Defaulting back to the CPU. (No action is required if you want
to run on the CPU).""" maxlog=1
end
end
Expand Down
29 changes: 0 additions & 29 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,40 +122,11 @@ function calc_padding(lt, ::SamePad, k::NTuple{N, T}, dilation, stride) where {N
end

# Handling ComponentArrays
# NOTE(@avik-pal): We should probably upstream some of these
function Base.zero(c::ComponentArray{T, N, <:CuArray{T}}) where {T, N}
return ComponentArray(zero(getdata(c)), getaxes(c))
end

Base.vec(c::ComponentArray) = getdata(c)

Base.:-(x::ComponentArray) = ComponentArray(-getdata(x), getaxes(x))

function Base.similar(c::ComponentArray, l::Vararg{Union{Integer, AbstractUnitRange}})
return similar(getdata(c), l)
end

function Functors.functor(::Type{<:ComponentArray}, c)
return NamedTuple{propertynames(c)}(getproperty.((c,), propertynames(c))),
ComponentArray
end

function ComponentArrays.make_carray_args(nt::NamedTuple)
data, ax = ComponentArrays.make_carray_args(Vector, nt)
data = length(data) == 0 ? Float32[] :
(length(data) == 1 ? [data[1]] : reduce(vcat, data))
return (data, ax)
end

## For being able to print empty ComponentArrays
function ComponentArrays.last_index(f::FlatAxis)
nt = ComponentArrays.indexmap(f)
length(nt) == 0 && return 0
return ComponentArrays.last_index(last(nt))
end

ComponentArrays.recursive_length(nt::NamedTuple{(), Tuple{}}) = 0

Optimisers.setup(opt, ps::ComponentArray) = Optimisers.setup(opt, getdata(ps))

function Optimisers.update(tree, ps::ComponentArray, gs::ComponentArray)
Expand Down
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
CUDA = "3"
ComponentArrays = "0.11, 0.12"
ComponentArrays = "0.13"
FiniteDifferences = "0.12"
Functors = "0.2, 0.3"
JET = "0.4, 0.5, 0.6"
Expand Down
6 changes: 3 additions & 3 deletions test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,9 @@ end
@test ps_c.weight == ps.weight
@test ps_c.bias == ps.bias

@test p_flat == vec(ps_c)
@test -p_flat == vec(-ps_c)
@test zero(p_flat) == vec(zero(ps_c))
@test p_flat == getdata(ps_c)
@test -p_flat == getdata(-ps_c)
@test zero(p_flat) == getdata(zero(ps_c))

@test_nowarn similar(ps_c, 10)
@test_nowarn similar(ps_c)
Expand Down

0 comments on commit e7e64c1

Please sign in to comment.