Skip to content

Commit

Permalink
Try #1506:
Browse files Browse the repository at this point in the history
  • Loading branch information
bors[bot] authored Nov 1, 2023
2 parents 28bbd23 + bf6c14a commit ed2a692
Show file tree
Hide file tree
Showing 11 changed files with 541 additions and 505 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
Unrolled = "9602ed7d-8fef-5bc8-8597-8f21381861e8"

[compat]
Adapt = "3"
Expand All @@ -50,6 +51,7 @@ RootSolvers = "0.3, 0.4"
Static = "0.4, 0.5, 0.6, 0.7, 0.8"
StaticArrays = "1"
UnPack = "1"
Unrolled = "0.1"
julia = "1.8"

[extras]
Expand Down
8 changes: 7 additions & 1 deletion benchmarks/bickleyjet/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ uuid = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d"
version = "0.5.5"

[[deps.ClimaCore]]
deps = ["Adapt", "BandedMatrices", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "LinearAlgebra", "PkgVersion", "RecursiveArrayTools", "Requires", "RootSolvers", "SparseArrays", "Static", "StaticArrays", "Statistics", "UnPack"]
deps = ["Adapt", "BandedMatrices", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "LinearAlgebra", "PkgVersion", "RecursiveArrayTools", "Requires", "RootSolvers", "SparseArrays", "Static", "StaticArrays", "Statistics", "UnPack", "Unrolled"]
path = "../.."
uuid = "d414da3d-4745-48bb-8d80-42e94e092884"
version = "0.10.55"
Expand Down Expand Up @@ -1238,6 +1238,12 @@ git-tree-sha1 = "e2d817cc500e960fdbafcf988ac8436ba3208bfd"
uuid = "45397f5d-5981-4c77-b2b3-fc36d6e9b728"
version = "1.6.3"

[[deps.Unrolled]]
deps = ["MacroTools"]
git-tree-sha1 = "6cc9d682755680e0f0be87c56392b7651efc2c7b"
uuid = "9602ed7d-8fef-5bc8-8597-8f21381861e8"
version = "0.1.5"

[[deps.UnsafeAtomics]]
git-tree-sha1 = "6331ac3440856ea1988316b46045303bef658278"
uuid = "013be700-e6cd-48c3-b4a1-df204f14c38f"
Expand Down
8 changes: 7 additions & 1 deletion docs/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ uuid = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d"
version = "0.5.5"

[[deps.ClimaCore]]
deps = ["Adapt", "BandedMatrices", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "LinearAlgebra", "PkgVersion", "RecursiveArrayTools", "Requires", "RootSolvers", "SparseArrays", "Static", "StaticArrays", "Statistics", "UnPack"]
deps = ["Adapt", "BandedMatrices", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "LinearAlgebra", "PkgVersion", "RecursiveArrayTools", "Requires", "RootSolvers", "SparseArrays", "Static", "StaticArrays", "Statistics", "UnPack", "Unrolled"]
path = ".."
uuid = "d414da3d-4745-48bb-8d80-42e94e092884"
version = "0.10.55"
Expand Down Expand Up @@ -2492,6 +2492,12 @@ git-tree-sha1 = "e2d817cc500e960fdbafcf988ac8436ba3208bfd"
uuid = "45397f5d-5981-4c77-b2b3-fc36d6e9b728"
version = "1.6.3"

[[deps.Unrolled]]
deps = ["MacroTools"]
git-tree-sha1 = "6cc9d682755680e0f0be87c56392b7651efc2c7b"
uuid = "9602ed7d-8fef-5bc8-8597-8f21381861e8"
version = "0.1.5"

[[deps.UnsafeAtomics]]
git-tree-sha1 = "6331ac3440856ea1988316b46045303bef658278"
uuid = "013be700-e6cd-48c3-b4a1-df204f14c38f"
Expand Down
8 changes: 7 additions & 1 deletion examples/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ uuid = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d"
version = "0.5.5"

[[deps.ClimaCore]]
deps = ["Adapt", "BandedMatrices", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "LinearAlgebra", "PkgVersion", "RecursiveArrayTools", "Requires", "RootSolvers", "SparseArrays", "Static", "StaticArrays", "Statistics", "UnPack"]
deps = ["Adapt", "BandedMatrices", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "LinearAlgebra", "PkgVersion", "RecursiveArrayTools", "Requires", "RootSolvers", "SparseArrays", "Static", "StaticArrays", "Statistics", "UnPack", "Unrolled"]
path = ".."
uuid = "d414da3d-4745-48bb-8d80-42e94e092884"
version = "0.10.55"
Expand Down Expand Up @@ -2031,6 +2031,12 @@ git-tree-sha1 = "e2d817cc500e960fdbafcf988ac8436ba3208bfd"
uuid = "45397f5d-5981-4c77-b2b3-fc36d6e9b728"
version = "1.6.3"

[[deps.Unrolled]]
deps = ["MacroTools"]
git-tree-sha1 = "6cc9d682755680e0f0be87c56392b7651efc2c7b"
uuid = "9602ed7d-8fef-5bc8-8597-8f21381861e8"
version = "0.1.5"

[[deps.UnsafeAtomics]]
git-tree-sha1 = "6331ac3440856ea1988316b46045303bef658278"
uuid = "013be700-e6cd-48c3-b4a1-df204f14c38f"
Expand Down
8 changes: 7 additions & 1 deletion perf/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ uuid = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d"
version = "0.5.5"

[[deps.ClimaCore]]
deps = ["Adapt", "BandedMatrices", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "LinearAlgebra", "PkgVersion", "RecursiveArrayTools", "Requires", "RootSolvers", "SparseArrays", "Static", "StaticArrays", "Statistics", "UnPack"]
deps = ["Adapt", "BandedMatrices", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "LinearAlgebra", "PkgVersion", "RecursiveArrayTools", "Requires", "RootSolvers", "SparseArrays", "Static", "StaticArrays", "Statistics", "UnPack", "Unrolled"]
path = ".."
uuid = "d414da3d-4745-48bb-8d80-42e94e092884"
version = "0.10.55"
Expand Down Expand Up @@ -2077,6 +2077,12 @@ git-tree-sha1 = "e2d817cc500e960fdbafcf988ac8436ba3208bfd"
uuid = "45397f5d-5981-4c77-b2b3-fc36d6e9b728"
version = "1.6.3"

[[deps.Unrolled]]
deps = ["MacroTools"]
git-tree-sha1 = "6cc9d682755680e0f0be87c56392b7651efc2c7b"
uuid = "9602ed7d-8fef-5bc8-8597-8f21381861e8"
version = "0.1.5"

[[deps.UnsafeAtomics]]
git-tree-sha1 = "6331ac3440856ea1988316b46045303bef658278"
uuid = "013be700-e6cd-48c3-b4a1-df204f14c38f"
Expand Down
3 changes: 3 additions & 0 deletions src/MatrixFields/MatrixFields.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ import StaticArrays: SMatrix, SVector
import BandedMatrices: BandedMatrix, band, _BandedMatrix
import ClimaComms

import Unrolled: unrolled_foreach, unrolled_map, unrolled_reduce
import Unrolled: unrolled_in, unrolled_any, unrolled_all, unrolled_filter

import ..Utilities: PlusHalf, half
import ..RecursiveApply:
rmap, rmaptype, rpromote_type, rzero, rconvert, radd, rsub, rmul, rdiv
Expand Down
56 changes: 29 additions & 27 deletions src/MatrixFields/field_name.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,18 +73,20 @@ is_child_name(
::FieldName{parent_name_chain},
) where {child_name_chain, parent_name_chain} =
length(child_name_chain) >= length(parent_name_chain) &&
child_name_chain[1:length(parent_name_chain)] == parent_name_chain
unrolled_take(child_name_chain, Val(length(parent_name_chain))) ==
parent_name_chain

names_are_overlapping(name1, name2) =
is_overlapping_name(name1, name2) =
is_child_name(name1, name2) || is_child_name(name2, name1)

extract_internal_name(
child_name::FieldName{child_name_chain},
parent_name::FieldName{parent_name_chain},
) where {child_name_chain, parent_name_chain} =
is_child_name(child_name, parent_name) ?
FieldName(child_name_chain[(length(parent_name_chain) + 1):end]...) :
error("$child_name is not a child name of $parent_name")
FieldName(
unrolled_drop(child_name_chain, Val(length(parent_name_chain)))...,
) : error("$child_name is not a child name of $parent_name")

append_internal_name(
::FieldName{name_chain},
Expand Down Expand Up @@ -118,41 +120,41 @@ struct FieldNameTreeNode{V <: FieldName, S <: NTuple{<:Any, FieldNameTree}} <:
subtrees::S
end

FieldNameTree(x) = make_subtree_at_name(x, @name())
function make_subtree_at_name(x, name)
FieldNameTree(x) = subtree_at_name(x, @name())
function subtree_at_name(x, name)
internal_names = top_level_names(get_field(x, name))
isempty(internal_names) && return FieldNameTreeLeaf(name)
subsubtrees = unrolled_map(internal_names) do internal_name
make_subtree_at_name(x, append_internal_name(name, internal_name))
return if isempty(internal_names)
FieldNameTreeLeaf(name)
else
subsubtrees_at_name =
recursively_unrolled_map(internal_names) do internal_name
subtree_at_name(x, append_internal_name(name, internal_name))
end
FieldNameTreeNode(name, subsubtrees_at_name)
end
return FieldNameTreeNode(name, subsubtrees)
end

is_valid_name(name, tree::FieldNameTreeLeaf) = name == tree.name
is_valid_name(name, tree::FieldNameTreeNode) =
is_valid_name(name, tree) =
name == tree.name ||
is_child_name(name, tree.name) &&
unrolled_any(subtree -> is_valid_name(name, subtree), tree.subtrees)
tree isa FieldNameTreeNode &&
recursively_unrolled_any(tree.subtrees) do subtree
is_valid_name(name, subtree)
end

function child_names(name, tree)
is_valid_name(name, tree) || error("$name is not a valid name")
subtree = get_subtree_at_name(name, tree)
subtree isa FieldNameTreeNode ||
error("FieldNameTree does not contain any child names for $name")
subtree isa FieldNameTreeNode || error("$name does not have child names")
return unrolled_map(subsubtree -> subsubtree.name, subtree.subtrees)
end
get_subtree_at_name(name, tree::FieldNameTreeLeaf) =
name == tree.name ? tree :
error("FieldNameTree does not contain the name $name")
get_subtree_at_name(name, tree::FieldNameTreeNode) =
get_subtree_at_name(name, tree) =
if name == tree.name
tree
elseif is_valid_name(name, tree)
subtree_that_contains_name = unrolled_findonly(tree.subtrees) do subtree
is_child_name(name, subtree.name)
end
get_subtree_at_name(name, subtree_that_contains_name)
else
error("FieldNameTree does not contain the name $name")
subtree = unrolled_findonly(tree.subtrees) do subtree
is_valid_name(name, subtree)
end
get_subtree_at_name(name, subtree)
end

################################################################################
Expand All @@ -175,7 +177,7 @@ if hasfield(Method, :recursion_relation)
for m in methods(wrapped_prop_names)
m.recursion_relation = dont_limit
end
for m in methods(make_subtree_at_name)
for m in methods(subtree_at_name)
m.recursion_relation = dont_limit
end
for m in methods(is_valid_name)
Expand Down
21 changes: 10 additions & 11 deletions src/MatrixFields/field_name_dict.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,18 +66,16 @@ const FieldMatrixBroadcasted = FieldNameDict{
dict_type(::FieldNameDict{T1, T2}) where {T1, T2} = FieldNameDict{T1, T2}

function Base.show(io::IO, dict::FieldNameDict)
strings = map((key, value) -> " $key => $value", pairs(dict))
print(io, "$(dict_type(dict))($(join(strings, ",\n")))")
strings = map(pair -> "\n$(pair[1]) => $(pair[2])", pairs(dict))
print(io, "$(dict_type(dict))($(join(strings, ","))\n)")
end

Base.keys(dict::FieldNameDict) = dict.keys

Base.values(dict::FieldNameDict) = dict.entries

Base.pairs(dict::FieldNameDict) =
unrolled_map(unrolled_zip(keys(dict).values, values(dict))) do key_entry_tup
key_entry_tup[1] => key_entry_tup[2]
end
unrolled_map((key, value) -> key => value, keys(dict).values, values(dict))

Base.length(dict::FieldNameDict) = length(keys(dict))

Expand Down Expand Up @@ -118,18 +116,19 @@ function get_internal_entry(
# See note above matrix_product_keys in field_name_set.jl for more details.
T = eltype(eltype(entry))
if name_pair == (@name(), @name())
# multiplication case 1, either argument
entry
elseif broadcasted_has_field(T, name_pair[1]) && name_pair[2] == @name()
elseif name_pair[1] == name_pair[2]
# multiplication case 3 or 4, first argument
@assert T <: SingleValue && !broadcasted_has_field(T, name_pair[1])
entry
elseif name_pair[2] == @name()
# multiplication case 2 or 4, second argument
@assert broadcasted_has_field(T, name_pair[1])
Base.broadcasted(entry) do matrix_row
map(matrix_row) do matrix_row_entry
broadcasted_get_field(matrix_row_entry, name_pair[1])
end
end # Note: This assumes that the entry is in a FieldMatrixBroadcasted.
elseif T <: SingleValue && name_pair[1] == name_pair[2]
# multiplication case 3 or 4, first argument
entry
else
unsupported_internal_entry_error(entry, name_pair)
end
Expand Down Expand Up @@ -249,7 +248,7 @@ Base.Broadcast.broadcasted(
arg3,
args...,
) =
unrolled_foldl((arg1, arg2, arg3, args...)) do arg1′, arg2′
foldl((arg1, arg2, arg3, args...)) do arg1′, arg2′
Base.Broadcast.broadcasted(f, arg1′, arg2′)
end

Expand Down
Loading

0 comments on commit ed2a692

Please sign in to comment.