Skip to content

Commit

Permalink
Switch to using Unrolled.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
dennisYatunin committed Oct 20, 2023
1 parent 9c22047 commit ae46d25
Show file tree
Hide file tree
Showing 11 changed files with 119 additions and 209 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
Unrolled = "9602ed7d-8fef-5bc8-8597-8f21381861e8"

[compat]
Adapt = "3"
Expand All @@ -50,6 +51,7 @@ RootSolvers = "0.3, 0.4"
Static = "0.4, 0.5, 0.6, 0.7, 0.8"
StaticArrays = "1"
UnPack = "1"
Unrolled = "0.1"
julia = "1.8"

[extras]
Expand Down
10 changes: 8 additions & 2 deletions benchmarks/bickleyjet/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -157,10 +157,10 @@ uuid = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d"
version = "0.5.5"

[[deps.ClimaCore]]
deps = ["Adapt", "BandedMatrices", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "LinearAlgebra", "PkgVersion", "RecursiveArrayTools", "Requires", "RootSolvers", "SparseArrays", "Static", "StaticArrays", "Statistics", "UnPack"]
deps = ["Adapt", "BandedMatrices", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "LinearAlgebra", "PkgVersion", "RecursiveArrayTools", "Requires", "RootSolvers", "SparseArrays", "Static", "StaticArrays", "Statistics", "UnPack", "Unrolled"]
path = "../.."
uuid = "d414da3d-4745-48bb-8d80-42e94e092884"
version = "0.10.53"
version = "0.10.54"

[[deps.CodecZlib]]
deps = ["TranscodingStreams", "Zlib_jll"]
Expand Down Expand Up @@ -1235,6 +1235,12 @@ git-tree-sha1 = "e2d817cc500e960fdbafcf988ac8436ba3208bfd"
uuid = "45397f5d-5981-4c77-b2b3-fc36d6e9b728"
version = "1.6.3"

[[deps.Unrolled]]
deps = ["MacroTools"]
git-tree-sha1 = "6cc9d682755680e0f0be87c56392b7651efc2c7b"
uuid = "9602ed7d-8fef-5bc8-8597-8f21381861e8"
version = "0.1.5"

[[deps.UnsafeAtomics]]
git-tree-sha1 = "6331ac3440856ea1988316b46045303bef658278"
uuid = "013be700-e6cd-48c3-b4a1-df204f14c38f"
Expand Down
10 changes: 8 additions & 2 deletions docs/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -274,10 +274,10 @@ uuid = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d"
version = "0.5.5"

[[deps.ClimaCore]]
deps = ["Adapt", "BandedMatrices", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "LinearAlgebra", "PkgVersion", "RecursiveArrayTools", "Requires", "RootSolvers", "SparseArrays", "Static", "StaticArrays", "Statistics", "UnPack"]
deps = ["Adapt", "BandedMatrices", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "LinearAlgebra", "PkgVersion", "RecursiveArrayTools", "Requires", "RootSolvers", "SparseArrays", "Static", "StaticArrays", "Statistics", "UnPack", "Unrolled"]
path = ".."
uuid = "d414da3d-4745-48bb-8d80-42e94e092884"
version = "0.10.53"
version = "0.10.54"

[[deps.ClimaCoreMakie]]
deps = ["ClimaCore", "Makie"]
Expand Down Expand Up @@ -2480,6 +2480,12 @@ git-tree-sha1 = "e2d817cc500e960fdbafcf988ac8436ba3208bfd"
uuid = "45397f5d-5981-4c77-b2b3-fc36d6e9b728"
version = "1.6.3"

[[deps.Unrolled]]
deps = ["MacroTools"]
git-tree-sha1 = "6cc9d682755680e0f0be87c56392b7651efc2c7b"
uuid = "9602ed7d-8fef-5bc8-8597-8f21381861e8"
version = "0.1.5"

[[deps.UnsafeAtomics]]
git-tree-sha1 = "6331ac3440856ea1988316b46045303bef658278"
uuid = "013be700-e6cd-48c3-b4a1-df204f14c38f"
Expand Down
10 changes: 8 additions & 2 deletions examples/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -238,10 +238,10 @@ uuid = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d"
version = "0.5.5"

[[deps.ClimaCore]]
deps = ["Adapt", "BandedMatrices", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "LinearAlgebra", "PkgVersion", "RecursiveArrayTools", "Requires", "RootSolvers", "SparseArrays", "Static", "StaticArrays", "Statistics", "UnPack"]
deps = ["Adapt", "BandedMatrices", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "LinearAlgebra", "PkgVersion", "RecursiveArrayTools", "Requires", "RootSolvers", "SparseArrays", "Static", "StaticArrays", "Statistics", "UnPack", "Unrolled"]
path = ".."
uuid = "d414da3d-4745-48bb-8d80-42e94e092884"
version = "0.10.53"
version = "0.10.54"

[[deps.ClimaCorePlots]]
deps = ["ClimaCore", "RecipesBase", "StaticArrays", "TriplotBase"]
Expand Down Expand Up @@ -2019,6 +2019,12 @@ git-tree-sha1 = "e2d817cc500e960fdbafcf988ac8436ba3208bfd"
uuid = "45397f5d-5981-4c77-b2b3-fc36d6e9b728"
version = "1.6.3"

[[deps.Unrolled]]
deps = ["MacroTools"]
git-tree-sha1 = "6cc9d682755680e0f0be87c56392b7651efc2c7b"
uuid = "9602ed7d-8fef-5bc8-8597-8f21381861e8"
version = "0.1.5"

[[deps.UnsafeAtomics]]
git-tree-sha1 = "6331ac3440856ea1988316b46045303bef658278"
uuid = "013be700-e6cd-48c3-b4a1-df204f14c38f"
Expand Down
10 changes: 8 additions & 2 deletions perf/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -217,10 +217,10 @@ uuid = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d"
version = "0.5.5"

[[deps.ClimaCore]]
deps = ["Adapt", "BandedMatrices", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "LinearAlgebra", "PkgVersion", "RecursiveArrayTools", "Requires", "RootSolvers", "SparseArrays", "Static", "StaticArrays", "Statistics", "UnPack"]
deps = ["Adapt", "BandedMatrices", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "LinearAlgebra", "PkgVersion", "RecursiveArrayTools", "Requires", "RootSolvers", "SparseArrays", "Static", "StaticArrays", "Statistics", "UnPack", "Unrolled"]
path = ".."
uuid = "d414da3d-4745-48bb-8d80-42e94e092884"
version = "0.10.53"
version = "0.10.54"

[[deps.ClimaCorePlots]]
deps = ["ClimaCore", "RecipesBase", "StaticArrays", "TriplotBase"]
Expand Down Expand Up @@ -2077,6 +2077,12 @@ git-tree-sha1 = "e2d817cc500e960fdbafcf988ac8436ba3208bfd"
uuid = "45397f5d-5981-4c77-b2b3-fc36d6e9b728"
version = "1.6.3"

[[deps.Unrolled]]
deps = ["MacroTools"]
git-tree-sha1 = "6cc9d682755680e0f0be87c56392b7651efc2c7b"
uuid = "9602ed7d-8fef-5bc8-8597-8f21381861e8"
version = "0.1.5"

[[deps.UnsafeAtomics]]
git-tree-sha1 = "6331ac3440856ea1988316b46045303bef658278"
uuid = "013be700-e6cd-48c3-b4a1-df204f14c38f"
Expand Down
3 changes: 3 additions & 0 deletions src/MatrixFields/MatrixFields.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ import StaticArrays: SMatrix, SVector
import BandedMatrices: BandedMatrix, band, _BandedMatrix
import ClimaComms

import Unrolled: unrolled_foreach, unrolled_map, unrolled_reduce
import Unrolled: unrolled_in, unrolled_any, unrolled_all, unrolled_filter

import ..Utilities: PlusHalf, half
import ..RecursiveApply:
rmap, rmaptype, rpromote_type, rzero, rconvert, radd, rsub, rmul, rdiv
Expand Down
52 changes: 28 additions & 24 deletions src/MatrixFields/field_name.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,41 +118,42 @@ struct FieldNameTreeNode{V <: FieldName, S <: NTuple{<:Any, FieldNameTree}} <:
subtrees::S
end

FieldNameTree(x) = make_subtree_at_name(x, @name())
function make_subtree_at_name(x, name)
FieldNameTree(x) = subtree_at_name(x, @name())
function subtree_at_name(x, name)
internal_names = top_level_names(get_field(x, name))
isempty(internal_names) && return FieldNameTreeLeaf(name)
subsubtrees = unrolled_map(internal_names) do internal_name
make_subtree_at_name(x, append_internal_name(name, internal_name))
return if isempty(internal_names)
FieldNameTreeLeaf(name)
else
FieldNameTreeNode(name, subtrees_at_names(x, name, internal_names))
end
return FieldNameTreeNode(name, subsubtrees)
end

is_valid_name(name, tree::FieldNameTreeLeaf) = name == tree.name
is_valid_name(name, tree::FieldNameTreeNode) =
subtrees_at_names(x, name, internal_names) =
isempty(internal_names) ? () :
(
subtree_at_name(x, append_internal_name(name, internal_names[1])),
subtrees_at_names(x, name, internal_names[2:end])...,
)

is_valid_name(name, tree) =
name == tree.name ||
is_child_name(name, tree.name) &&
unrolled_any(subtree -> is_valid_name(name, subtree), tree.subtrees)
tree isa FieldNameTreeNode && is_valid_name(name, tree.subtrees...)
is_valid_name(name, subtree, subtrees...) =
is_valid_name(name, subtree) || is_valid_name(name, subtrees...)

function child_names(name, tree)
is_valid_name(name, tree) || error("$name is not a valid name")
subtree = get_subtree_at_name(name, tree)
subtree isa FieldNameTreeNode ||
error("FieldNameTree does not contain any child names for $name")
subtree isa FieldNameTreeNode || error("$name does not have child names")
return unrolled_map(subsubtree -> subsubtree.name, subtree.subtrees)
end
get_subtree_at_name(name, tree::FieldNameTreeLeaf) =
name == tree.name ? tree :
error("FieldNameTree does not contain the name $name")
get_subtree_at_name(name, tree::FieldNameTreeNode) =
get_subtree_at_name(name, tree) =
if name == tree.name
tree
elseif is_valid_name(name, tree)
subtree_that_contains_name = unrolled_findonly(tree.subtrees) do subtree
is_child_name(name, subtree.name)
end
get_subtree_at_name(name, subtree_that_contains_name)
else
error("FieldNameTree does not contain the name $name")
subtree = unrolled_findonly(tree.subtrees) do subtree
is_valid_name(name, subtree)
end
get_subtree_at_name(name, subtree)
end

################################################################################
Expand All @@ -175,7 +176,10 @@ if hasfield(Method, :recursion_relation)
for m in methods(wrapped_prop_names)
m.recursion_relation = dont_limit
end
for m in methods(make_subtree_at_name)
for m in methods(subtree_at_name)
m.recursion_relation = dont_limit
end
for m in methods(subtrees_at_names)
m.recursion_relation = dont_limit
end
for m in methods(is_valid_name)
Expand Down
19 changes: 9 additions & 10 deletions src/MatrixFields/field_name_dict.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,18 +66,16 @@ const FieldMatrixBroadcasted = FieldNameDict{
dict_type(::FieldNameDict{T1, T2}) where {T1, T2} = FieldNameDict{T1, T2}

function Base.show(io::IO, dict::FieldNameDict)
strings = map((key, value) -> " $key => $value", pairs(dict))
print(io, "$(dict_type(dict))($(join(strings, ",\n")))")
strings = map(pair -> "\n$(pair[1]) => $(pair[2])", pairs(dict))
print(io, "$(dict_type(dict))($(join(strings, ","))\n)")
end

Base.keys(dict::FieldNameDict) = dict.keys

Base.values(dict::FieldNameDict) = dict.entries

Base.pairs(dict::FieldNameDict) =
unrolled_map(unrolled_zip(keys(dict).values, values(dict))) do key_entry_tup
key_entry_tup[1] => key_entry_tup[2]
end
unrolled_map((key, value) -> key => value, keys(dict).values, values(dict))

Base.length(dict::FieldNameDict) = length(keys(dict))

Expand Down Expand Up @@ -118,18 +116,19 @@ function get_internal_entry(
# See note above matrix_product_keys in field_name_set.jl for more details.
T = eltype(eltype(entry))
if name_pair == (@name(), @name())
# multiplication case 1, either argument
entry
elseif broadcasted_has_field(T, name_pair[1]) && name_pair[2] == @name()
elseif name_pair[1] == name_pair[2]
# multiplication case 3 or 4, first argument
@assert T <: SingleValue && !broadcasted_has_field(T, name_pair[1])
entry
elseif name_pair[2] == @name()
# multiplication case 2 or 4, second argument
@assert broadcasted_has_field(T, name_pair[1])
Base.broadcasted(entry) do matrix_row
map(matrix_row) do matrix_row_entry
broadcasted_get_field(matrix_row_entry, name_pair[1])
end
end # Note: This assumes that the entry is in a FieldMatrixBroadcasted.
elseif T <: SingleValue && name_pair[1] == name_pair[2]
# multiplication case 3 or 4, first argument
entry
else
unsupported_internal_entry_error(entry, name_pair)
end
Expand Down
Loading

0 comments on commit ae46d25

Please sign in to comment.