From 5844bf4fdf32c069493a7c60864f42bb37b58d51 Mon Sep 17 00:00:00 2001 From: Dennis Yatunin Date: Thu, 19 Oct 2023 16:01:28 -0700 Subject: [PATCH 1/3] Switch to using Unrolled.jl --- Project.toml | 2 + benchmarks/bickleyjet/Manifest.toml | 8 ++- docs/Manifest.toml | 8 ++- examples/Manifest.toml | 8 ++- perf/Manifest.toml | 8 ++- src/MatrixFields/MatrixFields.jl | 3 + src/MatrixFields/field_name.jl | 52 +++++++------- src/MatrixFields/field_name_dict.jl | 19 +++-- src/MatrixFields/field_name_set.jl | 98 +++++++------------------- src/MatrixFields/unrolled_functions.jl | 94 +++++------------------- test/MatrixFields/field_names.jl | 20 ++---- 11 files changed, 115 insertions(+), 205 deletions(-) diff --git a/Project.toml b/Project.toml index 18acf21e44..413f9cb8f4 100644 --- a/Project.toml +++ b/Project.toml @@ -28,6 +28,7 @@ Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" +Unrolled = "9602ed7d-8fef-5bc8-8597-8f21381861e8" [compat] Adapt = "3" @@ -50,6 +51,7 @@ RootSolvers = "0.3, 0.4" Static = "0.4, 0.5, 0.6, 0.7, 0.8" StaticArrays = "1" UnPack = "1" +Unrolled = "0.1" julia = "1.8" [extras] diff --git a/benchmarks/bickleyjet/Manifest.toml b/benchmarks/bickleyjet/Manifest.toml index fc0a8268ad..48d90b1cd0 100644 --- a/benchmarks/bickleyjet/Manifest.toml +++ b/benchmarks/bickleyjet/Manifest.toml @@ -157,7 +157,7 @@ uuid = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d" version = "0.5.5" [[deps.ClimaCore]] -deps = ["Adapt", "BandedMatrices", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "LinearAlgebra", "PkgVersion", "RecursiveArrayTools", "Requires", "RootSolvers", "SparseArrays", "Static", "StaticArrays", "Statistics", "UnPack"] +deps = ["Adapt", "BandedMatrices", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "LinearAlgebra", "PkgVersion", "RecursiveArrayTools", "Requires", "RootSolvers", "SparseArrays", "Static", "StaticArrays", "Statistics", "UnPack", "Unrolled"] path = "../.." uuid = "d414da3d-4745-48bb-8d80-42e94e092884" version = "0.10.55" @@ -1238,6 +1238,12 @@ git-tree-sha1 = "e2d817cc500e960fdbafcf988ac8436ba3208bfd" uuid = "45397f5d-5981-4c77-b2b3-fc36d6e9b728" version = "1.6.3" +[[deps.Unrolled]] +deps = ["MacroTools"] +git-tree-sha1 = "6cc9d682755680e0f0be87c56392b7651efc2c7b" +uuid = "9602ed7d-8fef-5bc8-8597-8f21381861e8" +version = "0.1.5" + [[deps.UnsafeAtomics]] git-tree-sha1 = "6331ac3440856ea1988316b46045303bef658278" uuid = "013be700-e6cd-48c3-b4a1-df204f14c38f" diff --git a/docs/Manifest.toml b/docs/Manifest.toml index 622c3d83e4..cf667d45da 100644 --- a/docs/Manifest.toml +++ b/docs/Manifest.toml @@ -274,7 +274,7 @@ uuid = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d" version = "0.5.5" [[deps.ClimaCore]] -deps = ["Adapt", "BandedMatrices", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "LinearAlgebra", "PkgVersion", "RecursiveArrayTools", "Requires", "RootSolvers", "SparseArrays", "Static", "StaticArrays", "Statistics", "UnPack"] +deps = ["Adapt", "BandedMatrices", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "LinearAlgebra", "PkgVersion", "RecursiveArrayTools", "Requires", "RootSolvers", "SparseArrays", "Static", "StaticArrays", "Statistics", "UnPack", "Unrolled"] path = ".." uuid = "d414da3d-4745-48bb-8d80-42e94e092884" version = "0.10.55" @@ -2492,6 +2492,12 @@ git-tree-sha1 = "e2d817cc500e960fdbafcf988ac8436ba3208bfd" uuid = "45397f5d-5981-4c77-b2b3-fc36d6e9b728" version = "1.6.3" +[[deps.Unrolled]] +deps = ["MacroTools"] +git-tree-sha1 = "6cc9d682755680e0f0be87c56392b7651efc2c7b" +uuid = "9602ed7d-8fef-5bc8-8597-8f21381861e8" +version = "0.1.5" + [[deps.UnsafeAtomics]] git-tree-sha1 = "6331ac3440856ea1988316b46045303bef658278" uuid = "013be700-e6cd-48c3-b4a1-df204f14c38f" diff --git a/examples/Manifest.toml b/examples/Manifest.toml index 56b04e6b5d..b809ee2c27 100644 --- a/examples/Manifest.toml +++ b/examples/Manifest.toml @@ -238,7 +238,7 @@ uuid = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d" version = "0.5.5" [[deps.ClimaCore]] -deps = ["Adapt", "BandedMatrices", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "LinearAlgebra", "PkgVersion", "RecursiveArrayTools", "Requires", "RootSolvers", "SparseArrays", "Static", "StaticArrays", "Statistics", "UnPack"] +deps = ["Adapt", "BandedMatrices", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "LinearAlgebra", "PkgVersion", "RecursiveArrayTools", "Requires", "RootSolvers", "SparseArrays", "Static", "StaticArrays", "Statistics", "UnPack", "Unrolled"] path = ".." uuid = "d414da3d-4745-48bb-8d80-42e94e092884" version = "0.10.55" @@ -2031,6 +2031,12 @@ git-tree-sha1 = "e2d817cc500e960fdbafcf988ac8436ba3208bfd" uuid = "45397f5d-5981-4c77-b2b3-fc36d6e9b728" version = "1.6.3" +[[deps.Unrolled]] +deps = ["MacroTools"] +git-tree-sha1 = "6cc9d682755680e0f0be87c56392b7651efc2c7b" +uuid = "9602ed7d-8fef-5bc8-8597-8f21381861e8" +version = "0.1.5" + [[deps.UnsafeAtomics]] git-tree-sha1 = "6331ac3440856ea1988316b46045303bef658278" uuid = "013be700-e6cd-48c3-b4a1-df204f14c38f" diff --git a/perf/Manifest.toml b/perf/Manifest.toml index 89288bb4d7..3aea42a040 100644 --- a/perf/Manifest.toml +++ b/perf/Manifest.toml @@ -217,7 +217,7 @@ uuid = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d" version = "0.5.5" [[deps.ClimaCore]] -deps = ["Adapt", "BandedMatrices", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "LinearAlgebra", "PkgVersion", "RecursiveArrayTools", "Requires", "RootSolvers", "SparseArrays", "Static", "StaticArrays", "Statistics", "UnPack"] +deps = ["Adapt", "BandedMatrices", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "LinearAlgebra", "PkgVersion", "RecursiveArrayTools", "Requires", "RootSolvers", "SparseArrays", "Static", "StaticArrays", "Statistics", "UnPack", "Unrolled"] path = ".." uuid = "d414da3d-4745-48bb-8d80-42e94e092884" version = "0.10.55" @@ -2077,6 +2077,12 @@ git-tree-sha1 = "e2d817cc500e960fdbafcf988ac8436ba3208bfd" uuid = "45397f5d-5981-4c77-b2b3-fc36d6e9b728" version = "1.6.3" +[[deps.Unrolled]] +deps = ["MacroTools"] +git-tree-sha1 = "6cc9d682755680e0f0be87c56392b7651efc2c7b" +uuid = "9602ed7d-8fef-5bc8-8597-8f21381861e8" +version = "0.1.5" + [[deps.UnsafeAtomics]] git-tree-sha1 = "6331ac3440856ea1988316b46045303bef658278" uuid = "013be700-e6cd-48c3-b4a1-df204f14c38f" diff --git a/src/MatrixFields/MatrixFields.jl b/src/MatrixFields/MatrixFields.jl index 957c75de91..0f59f8d9cd 100644 --- a/src/MatrixFields/MatrixFields.jl +++ b/src/MatrixFields/MatrixFields.jl @@ -49,6 +49,9 @@ import StaticArrays: SMatrix, SVector import BandedMatrices: BandedMatrix, band, _BandedMatrix import ClimaComms +import Unrolled: unrolled_foreach, unrolled_map, unrolled_reduce +import Unrolled: unrolled_in, unrolled_any, unrolled_all, unrolled_filter + import ..Utilities: PlusHalf, half import ..RecursiveApply: rmap, rmaptype, rpromote_type, rzero, rconvert, radd, rsub, rmul, rdiv diff --git a/src/MatrixFields/field_name.jl b/src/MatrixFields/field_name.jl index f027b1c1e1..a9bfbee87f 100644 --- a/src/MatrixFields/field_name.jl +++ b/src/MatrixFields/field_name.jl @@ -118,41 +118,42 @@ struct FieldNameTreeNode{V <: FieldName, S <: NTuple{<:Any, FieldNameTree}} <: subtrees::S end -FieldNameTree(x) = make_subtree_at_name(x, @name()) -function make_subtree_at_name(x, name) +FieldNameTree(x) = subtree_at_name(x, @name()) +function subtree_at_name(x, name) internal_names = top_level_names(get_field(x, name)) - isempty(internal_names) && return FieldNameTreeLeaf(name) - subsubtrees = unrolled_map(internal_names) do internal_name - make_subtree_at_name(x, append_internal_name(name, internal_name)) + return if isempty(internal_names) + FieldNameTreeLeaf(name) + else + FieldNameTreeNode(name, subtrees_at_names(x, name, internal_names)) end - return FieldNameTreeNode(name, subsubtrees) end - -is_valid_name(name, tree::FieldNameTreeLeaf) = name == tree.name -is_valid_name(name, tree::FieldNameTreeNode) = +subtrees_at_names(x, name, internal_names) = + isempty(internal_names) ? () : + ( + subtree_at_name(x, append_internal_name(name, internal_names[1])), + subtrees_at_names(x, name, internal_names[2:end])..., + ) + +is_valid_name(name, tree) = name == tree.name || - is_child_name(name, tree.name) && - unrolled_any(subtree -> is_valid_name(name, subtree), tree.subtrees) + tree isa FieldNameTreeNode && is_valid_name(name, tree.subtrees...) +is_valid_name(name, subtree, subtrees...) = + is_valid_name(name, subtree) || is_valid_name(name, subtrees...) function child_names(name, tree) + is_valid_name(name, tree) || error("$name is not a valid name") subtree = get_subtree_at_name(name, tree) - subtree isa FieldNameTreeNode || - error("FieldNameTree does not contain any child names for $name") + subtree isa FieldNameTreeNode || error("$name does not have child names") return unrolled_map(subsubtree -> subsubtree.name, subtree.subtrees) end -get_subtree_at_name(name, tree::FieldNameTreeLeaf) = - name == tree.name ? tree : - error("FieldNameTree does not contain the name $name") -get_subtree_at_name(name, tree::FieldNameTreeNode) = +get_subtree_at_name(name, tree) = if name == tree.name tree - elseif is_valid_name(name, tree) - subtree_that_contains_name = unrolled_findonly(tree.subtrees) do subtree - is_child_name(name, subtree.name) - end - get_subtree_at_name(name, subtree_that_contains_name) else - error("FieldNameTree does not contain the name $name") + subtree = unrolled_findonly(tree.subtrees) do subtree + is_valid_name(name, subtree) + end + get_subtree_at_name(name, subtree) end ################################################################################ @@ -175,7 +176,10 @@ if hasfield(Method, :recursion_relation) for m in methods(wrapped_prop_names) m.recursion_relation = dont_limit end - for m in methods(make_subtree_at_name) + for m in methods(subtree_at_name) + m.recursion_relation = dont_limit + end + for m in methods(subtrees_at_names) m.recursion_relation = dont_limit end for m in methods(is_valid_name) diff --git a/src/MatrixFields/field_name_dict.jl b/src/MatrixFields/field_name_dict.jl index ed26b40822..4a31cc56ac 100644 --- a/src/MatrixFields/field_name_dict.jl +++ b/src/MatrixFields/field_name_dict.jl @@ -66,8 +66,8 @@ const FieldMatrixBroadcasted = FieldNameDict{ dict_type(::FieldNameDict{T1, T2}) where {T1, T2} = FieldNameDict{T1, T2} function Base.show(io::IO, dict::FieldNameDict) - strings = map((key, value) -> " $key => $value", pairs(dict)) - print(io, "$(dict_type(dict))($(join(strings, ",\n")))") + strings = map(pair -> "\n$(pair[1]) => $(pair[2])", pairs(dict)) + print(io, "$(dict_type(dict))($(join(strings, ","))\n)") end Base.keys(dict::FieldNameDict) = dict.keys @@ -75,9 +75,7 @@ Base.keys(dict::FieldNameDict) = dict.keys Base.values(dict::FieldNameDict) = dict.entries Base.pairs(dict::FieldNameDict) = - unrolled_map(unrolled_zip(keys(dict).values, values(dict))) do key_entry_tup - key_entry_tup[1] => key_entry_tup[2] - end + unrolled_map((key, value) -> key => value, keys(dict).values, values(dict)) Base.length(dict::FieldNameDict) = length(keys(dict)) @@ -118,18 +116,19 @@ function get_internal_entry( # See note above matrix_product_keys in field_name_set.jl for more details. T = eltype(eltype(entry)) if name_pair == (@name(), @name()) - # multiplication case 1, either argument entry - elseif broadcasted_has_field(T, name_pair[1]) && name_pair[2] == @name() + elseif name_pair[1] == name_pair[2] + # multiplication case 3 or 4, first argument + @assert T <: SingleValue && !broadcasted_has_field(T, name_pair[1]) + entry + elseif name_pair[2] == @name() # multiplication case 2 or 4, second argument + @assert broadcasted_has_field(T, name_pair[1]) Base.broadcasted(entry) do matrix_row map(matrix_row) do matrix_row_entry broadcasted_get_field(matrix_row_entry, name_pair[1]) end end # Note: This assumes that the entry is in a FieldMatrixBroadcasted. - elseif T <: SingleValue && name_pair[1] == name_pair[2] - # multiplication case 3 or 4, first argument - entry else unsupported_internal_entry_error(entry, name_pair) end diff --git a/src/MatrixFields/field_name_set.jl b/src/MatrixFields/field_name_set.jl index 1f28960bff..a7743d79b4 100644 --- a/src/MatrixFields/field_name_set.jl +++ b/src/MatrixFields/field_name_set.jl @@ -79,46 +79,36 @@ end Base.:(==)(set1::FieldNameSet, set2::FieldNameSet) = issubset(set1, set2) && issubset(set2, set1) -function Base.intersect(set1::FieldNameSet{T}, set2::FieldNameSet{T}) where {T} +function Base.union(set1::FieldNameSet{T}, set2::FieldNameSet{T}) where {T} name_tree = combine_name_trees(set1.name_tree, set2.name_tree) values1′, values2′ = set1.values, set2.values values1, values2 = non_overlapping_values(values1′, values2′, name_tree) - result_values = unrolled_filter(values2) do value - unrolled_any(isequal(value), values1) - end - return FieldNameSet{T}(result_values, name_tree) + return FieldNameSet{T}(unrolled_union(values1, values2), name_tree) end -function Base.union(set1::FieldNameSet{T}, set2::FieldNameSet{T}) where {T} +function Base.intersect(set1::FieldNameSet{T}, set2::FieldNameSet{T}) where {T} name_tree = combine_name_trees(set1.name_tree, set2.name_tree) values1′, values2′ = set1.values, set2.values values1, values2 = non_overlapping_values(values1′, values2′, name_tree) - values2_minus_values1 = unrolled_filter(values2) do value - !unrolled_any(isequal(value), values1) - end - result_values = (values1..., values2_minus_values1...) - return FieldNameSet{T}(result_values, name_tree) + return FieldNameSet{T}(unrolled_intersect(values1, values2), name_tree) end function Base.setdiff(set1::FieldNameSet{T}, set2::FieldNameSet{T}) where {T} name_tree = combine_name_trees(set1.name_tree, set2.name_tree) - set2_complement_values = set_complement_values(T, set2.values, name_tree) - set2_complement = FieldNameSet{T}(set2_complement_values, name_tree) - return intersect(set1, set2_complement) + values1′, values2′ = set1.values, set2.values + values1, values2 = non_overlapping_values(values1′, values2′, name_tree) + return FieldNameSet{T}(unrolled_setdiff(values1, values2), name_tree) end set_string(set) = length(set) == 2 ? join(set.values, " and ") : join(set.values, ", ", ", and ") +set_complement(set) = setdiff(universal_set(eltype(set)), set) + is_subset_that_covers_set(set1, set2) = issubset(set1, set2) && isempty(setdiff(set2, set1)) -function set_complement(set::FieldNameSet{T}) where {T} - result_values = set_complement_values(T, set.values, set.name_tree) - return FieldNameSet{T}(result_values, set.name_tree) -end - function corresponding_matrix_keys(set::FieldVectorKeys) result_values = unrolled_map(name -> (name, name), set.values) return FieldMatrixKeys(result_values, set.name_tree) @@ -150,9 +140,13 @@ function matrix_diagonal_keys(set::FieldMatrixKeys) names_are_overlapping(name_pair[1], name_pair[2]) end result_values = unrolled_map(result_values′) do name_pair - name_pair[1] == name_pair[2] ? name_pair : - is_child_value(name_pair[1], name_pair[2]) ? - (name_pair[1], name_pair[1]) : (name_pair[2], name_pair[2]) + if name_pair[1] == name_pair[2] + name_pair + elseif is_child_value(name_pair[1], name_pair[2]) + (name_pair[1], name_pair[1]) + else + (name_pair[2], name_pair[2]) + end end return FieldMatrixKeys(result_values, set.name_tree) end @@ -302,7 +296,7 @@ is_child_value(name_pair1::FieldNamePair, name_pair2::FieldNamePair) = is_child_name(name_pair1[2], name_pair2[2]) is_value_in_set(value, values, name_tree) = - if unrolled_any(isequal(value), values) + if unrolled_in(value, values) true elseif unrolled_any(value′ -> is_child_value(value, value′), values) isnothing(name_tree) && error( @@ -313,6 +307,11 @@ is_value_in_set(value, values, name_tree) = false end +universal_set(::Type{FieldName}) = FieldVectorKeys((@name(),)) +universal_set(::Type{FieldNamePair}) = FieldMatrixKeys(((@name(), @name()),)) + +# TODO: Simplify the following code. + function non_overlapping_values(values1, values2, name_tree) new_values1 = unrolled_mapflatten(values1) do value value_or_non_overlapping_children(value, values2, name_tree) @@ -381,62 +380,13 @@ function value_or_non_overlapping_children( end end -set_complement_values(_, _, ::Nothing) = - error("Cannot compute complement of a FieldNameSet without a FieldNameTree") -set_complement_values(::Type{<:FieldName}, names, name_tree::FieldNameTree) = - complement_values_in_subtree(names, name_tree) -set_complement_values( - ::Type{<:FieldNamePair}, - name_pairs, - name_tree::FieldNameTree, -) = complement_values_in_subtree_pair(name_pairs, (name_tree, name_tree)) - -function complement_values_in_subtree(names, subtree) - name = subtree.name - unrolled_all(name′ -> !is_child_value(name, name′), names) || return () - unrolled_any(name′ -> is_child_value(name′, name), names) || return (name,) - return unrolled_mapflatten(subtree.subtrees) do subsubtree - complement_values_in_subtree(names, subsubtree) - end -end - -function complement_values_in_subtree_pair(name_pairs, subtree_pair) - name_pair = (subtree_pair[1].name, subtree_pair[2].name) - is_name_pair_in_complement = unrolled_all(name_pairs) do name_pair′ - !is_child_value(name_pair, name_pair′) - end - is_name_pair_in_complement || return () - need_row_subsubtrees = unrolled_any(name_pairs) do name_pair′ - is_child_value(name_pair′, name_pair) && name_pair′[1] != name_pair[1] - end - need_col_subsubtrees = unrolled_any(name_pairs) do name_pair′ - is_child_value(name_pair′, name_pair) && name_pair′[2] != name_pair[2] - end - need_row_subsubtrees || need_col_subsubtrees || return (name_pair,) - row_subsubtrees = - need_row_subsubtrees ? subtree_pair[1].subtrees : (subtree_pair[1],) - col_subsubtrees = - need_col_subsubtrees ? subtree_pair[2].subtrees : (subtree_pair[2],) - return unrolled_mapflatten(row_subsubtrees) do row_subsubtree - unrolled_mapflatten(col_subsubtrees) do col_subsubtree - subsubtree_pair = (row_subsubtree, col_subsubtree) - complement_values_in_subtree_pair(name_pairs, subsubtree_pair) - end - end -end - -################################################################################ - # This is required for type-stability as of Julia 1.9. if hasfield(Method, :recursion_relation) dont_limit = (args...) -> true - for m in methods(value_or_non_overlapping_children) - m.recursion_relation = dont_limit - end - for m in methods(complement_values_in_subtree) + for m in methods(unrolled_mapflatten) m.recursion_relation = dont_limit end - for m in methods(complement_values_in_subtree_pair) + for m in methods(value_or_non_overlapping_children) m.recursion_relation = dont_limit end end diff --git a/src/MatrixFields/unrolled_functions.jl b/src/MatrixFields/unrolled_functions.jl index 947a45084d..533aaf7785 100644 --- a/src/MatrixFields/unrolled_functions.jl +++ b/src/MatrixFields/unrolled_functions.jl @@ -1,70 +1,17 @@ -@inline unrolled_zip(values1, values2) = - isempty(values1) || isempty(values2) ? () : - ( - (first(values1), first(values2)), - unrolled_zip(Base.tail(values1), Base.tail(values2))..., - ) - -@inline unrolled_map(f::F, values) where {F} = - isempty(values) ? () : - (f(first(values)), unrolled_map(f, Base.tail(values))...) - -unrolled_foldl(f::F, values) where {F} = - isempty(values) ? - error("unrolled_foldl requires init for an empty collection of values") : - _unrolled_foldl(f, first(values), Base.tail(values)) -unrolled_foldl(f::F, values, init) where {F} = _unrolled_foldl(f, init, values) -@inline _unrolled_foldl(f::F, result, values) where {F} = - isempty(values) ? result : - _unrolled_foldl(f, f(result, first(values)), Base.tail(values)) - -# The @inline annotations are needed to avoid allocations when there are a lot -# of values. - -# Using first and tail instead of [1] and [2:end] restricts us to Tuples, but it -# also results in less compilation time. - -# This is required to make the unrolled functions type-stable, as of Julia 1.9. -if hasfield(Method, :recursion_relation) - dont_limit = (args...) -> true - for m in methods(unrolled_zip) - m.recursion_relation = dont_limit - end - for m in methods(unrolled_map) - m.recursion_relation = dont_limit - end - for m in methods(_unrolled_foldl) - m.recursion_relation = dont_limit - end -end - -################################################################################ - -unrolled_foreach(f::F, values) where {F} = (unrolled_map(f, values); nothing) - -unrolled_any(f::F, values) where {F} = - unrolled_foldl(|, unrolled_map(f, values), false) - -unrolled_all(f::F, values) where {F} = - unrolled_foldl(&, unrolled_map(f, values), true) - -unrolled_filter(f::F, values) where {F} = - unrolled_foldl(values, ()) do filtered_values, value - f(value) ? (filtered_values..., value) : filtered_values - end +# These functions are also defined in Unrolled.jl, but those versions use "in" +# instead of "unrolled_in". +unrolled_union(values1, values2) = + (values1..., unrolled_setdiff(values2, values1)...) +unrolled_intersect(values1, values2) = + unrolled_filter(x -> unrolled_in(x, values2), values1) +unrolled_setdiff(values1, values2) = + unrolled_filter(x -> !unrolled_in(x, values2), values1) unrolled_unique(values) = - unrolled_foldl(values, ()) do unique_values, value - unrolled_any(isequal(value), unique_values) ? unique_values : - (unique_values..., value) - end + unrolled_reduce(unrolled_union, (), unrolled_map(tuple, values)) unrolled_flatten(values) = - unrolled_foldl(values, ()) do flattened_values, value - (flattened_values..., value...) - end - -# Non-standard functions: + unrolled_reduce((tuple1, tuple2) -> (tuple1..., tuple2...), (), values) unrolled_mapflatten(f::F, values) where {F} = unrolled_flatten(unrolled_map(f, values)) @@ -73,19 +20,12 @@ function unrolled_findonly(f::F, values) where {F} filtered_values = unrolled_filter(f, values) length(filtered_values) == 1 || error("unrolled_findonly requires that exactly one value makes f true") - return first(filtered_values) + return filtered_values[1] end -# This is required to make functions defined elsewhere type-stable, as of Julia -# 1.9. Specifically, if an unrolled function is used to implement the recursion -# of another function, it needs to have its recursion limit disabled in order -# for that other function to be type-stable. -if hasfield(Method, :recursion_relation) - dont_limit = (args...) -> true - for m in methods(unrolled_any) - m.recursion_relation = dont_limit - end # for is_valid_name - for m in methods(unrolled_mapflatten) - m.recursion_relation = dont_limit - end # for complement_values_in_subtree and value_or_non_overlapping_children -end +# The way unrolled_reduce is defined in Unrolled.jl essentially makes it a +# backwards unrolled_foldl. +unrolled_foldl(f::F, values) where {F} = + isempty(values) ? + error("unrolled_foldl requires a nonempty collection of values") : + unrolled_reduce((x, y) -> f(y, x), values[1], values[2:end]) diff --git a/test/MatrixFields/field_names.jl b/test/MatrixFields/field_names.jl index 6834a9eef1..7724951e07 100644 --- a/test/MatrixFields/field_names.jl +++ b/test/MatrixFields/field_names.jl @@ -94,11 +94,11 @@ end (@name(a.b), @name(a.c)) @test_all MatrixFields.child_names(@name(a.c), name_tree) == (@name(a.c.:(1)), @name(a.c.:(2)), @name(a.c.:(3))) - @test_throws "does not contain any child names" MatrixFields.child_names( + @test_throws "does not have child names" MatrixFields.child_names( @name(a.c.:(2)), name_tree, ) - @test_throws "does not contain the name" MatrixFields.child_names( + @test_throws "is not a valid name" MatrixFields.child_names( @name(foo.invalid_name), name_tree, ) @@ -244,9 +244,11 @@ end (v_set1, v_set2), (v_set1, v_set2_no_tree), (v_set1_no_tree, v_set2), + (v_set1_no_tree, v_set2_no_tree), (m_set1, m_set2), (m_set1, m_set2_no_tree), (m_set1_no_tree, m_set2), + (m_set1_no_tree, m_set2_no_tree), ) @test_all set1 != set2 @test_all !issubset(set1, set2) @@ -257,20 +259,6 @@ end @test_all !MatrixFields.is_subset_that_covers_set(set2, set1) end - for (set1, set2) in - ((v_set1_no_tree, v_set2_no_tree), (m_set1_no_tree, m_set2_no_tree)) - @test_all set1 != set2 - @test_all !issubset(set1, set2) - @test_all issubset(set2, set1) - @test_all intersect(set1, set2) == set2 - @test_all union(set1, set2) == set1 - @test_all !MatrixFields.is_subset_that_covers_set(set1, set2) - @test_throws "FieldNameTree" MatrixFields.is_subset_that_covers_set( - set2, - set1, - ) - end - for (set1, set3) in ( (v_set1, v_set3), (v_set1, v_set3_no_tree), From ddd1d8a80f482ab92ded9c253c8cad7557a1b85d Mon Sep 17 00:00:00 2001 From: Dennis Yatunin Date: Thu, 26 Oct 2023 17:50:37 -0700 Subject: [PATCH 2/3] Extend unit tests and fix some bugs --- src/MatrixFields/field_name.jl | 24 +- src/MatrixFields/field_name_set.jl | 252 ++++++++------- src/MatrixFields/unrolled_functions.jl | 98 ++++-- test/MatrixFields/field_names.jl | 408 ++++++++++++++----------- 4 files changed, 456 insertions(+), 326 deletions(-) diff --git a/src/MatrixFields/field_name.jl b/src/MatrixFields/field_name.jl index a9bfbee87f..a389e5e816 100644 --- a/src/MatrixFields/field_name.jl +++ b/src/MatrixFields/field_name.jl @@ -75,7 +75,7 @@ is_child_name( length(child_name_chain) >= length(parent_name_chain) && child_name_chain[1:length(parent_name_chain)] == parent_name_chain -names_are_overlapping(name1, name2) = +is_overlapping_name(name1, name2) = is_child_name(name1, name2) || is_child_name(name2, name1) extract_internal_name( @@ -124,21 +124,20 @@ function subtree_at_name(x, name) return if isempty(internal_names) FieldNameTreeLeaf(name) else - FieldNameTreeNode(name, subtrees_at_names(x, name, internal_names)) + subsubtrees_at_name = + recursively_unrolled_map(internal_names) do internal_name + subtree_at_name(x, append_internal_name(name, internal_name)) + end + FieldNameTreeNode(name, subsubtrees_at_name) end end -subtrees_at_names(x, name, internal_names) = - isempty(internal_names) ? () : - ( - subtree_at_name(x, append_internal_name(name, internal_names[1])), - subtrees_at_names(x, name, internal_names[2:end])..., - ) is_valid_name(name, tree) = name == tree.name || - tree isa FieldNameTreeNode && is_valid_name(name, tree.subtrees...) -is_valid_name(name, subtree, subtrees...) = - is_valid_name(name, subtree) || is_valid_name(name, subtrees...) + tree isa FieldNameTreeNode && + recursively_unrolled_any(tree.subtrees) do subtree + is_valid_name(name, subtree) + end function child_names(name, tree) is_valid_name(name, tree) || error("$name is not a valid name") @@ -179,9 +178,6 @@ if hasfield(Method, :recursion_relation) for m in methods(subtree_at_name) m.recursion_relation = dont_limit end - for m in methods(subtrees_at_names) - m.recursion_relation = dont_limit - end for m in methods(is_valid_name) m.recursion_relation = dont_limit end diff --git a/src/MatrixFields/field_name_set.jl b/src/MatrixFields/field_name_set.jl index a7743d79b4..3dd16e52de 100644 --- a/src/MatrixFields/field_name_set.jl +++ b/src/MatrixFields/field_name_set.jl @@ -81,28 +81,30 @@ Base.:(==)(set1::FieldNameSet, set2::FieldNameSet) = function Base.union(set1::FieldNameSet{T}, set2::FieldNameSet{T}) where {T} name_tree = combine_name_trees(set1.name_tree, set2.name_tree) - values1′, values2′ = set1.values, set2.values - values1, values2 = non_overlapping_values(values1′, values2′, name_tree) - return FieldNameSet{T}(unrolled_union(values1, values2), name_tree) + result_values = union_values(set1.values, set2.values, name_tree) + return FieldNameSet{T}(result_values, name_tree) end function Base.intersect(set1::FieldNameSet{T}, set2::FieldNameSet{T}) where {T} name_tree = combine_name_trees(set1.name_tree, set2.name_tree) - values1′, values2′ = set1.values, set2.values - values1, values2 = non_overlapping_values(values1′, values2′, name_tree) - return FieldNameSet{T}(unrolled_intersect(values1, values2), name_tree) + all_values = union_values(set1.values, set2.values, name_tree) + result_values = unrolled_filter(all_values) do value + is_value_in_set(value, set1.values, name_tree) && + is_value_in_set(value, set2.values, name_tree) + end + return FieldNameSet{T}(result_values, name_tree) end function Base.setdiff(set1::FieldNameSet{T}, set2::FieldNameSet{T}) where {T} name_tree = combine_name_trees(set1.name_tree, set2.name_tree) - values1′, values2′ = set1.values, set2.values - values1, values2 = non_overlapping_values(values1′, values2′, name_tree) - return FieldNameSet{T}(unrolled_setdiff(values1, values2), name_tree) + all_values = union_values(set1.values, set2.values, name_tree) + result_values = unrolled_filter(all_values) do value + !is_value_in_set(value, set2.values, name_tree) + end + return FieldNameSet{T}(result_values, name_tree) end -set_string(set) = - length(set) == 2 ? join(set.values, " and ") : - join(set.values, ", ", ", and ") +set_string(set) = values_string(set.values) set_complement(set) = setdiff(universal_set(eltype(set)), set) @@ -116,16 +118,14 @@ end function cartesian_product(set1::FieldVectorKeys, set2::FieldVectorKeys) name_tree = combine_name_trees(set1.name_tree, set2.name_tree) - result_values = unrolled_mapflatten(set1.values) do row_name - unrolled_map(col_name -> (row_name, col_name), set2.values) - end + result_values = unrolled_product((set1.values, set2.values)) return FieldMatrixKeys(result_values, name_tree) end function matrix_row_keys(set::FieldMatrixKeys) result_values′ = unrolled_map(name_pair -> name_pair[1], set.values) result_values = - unique_and_non_overlapping_values(result_values′, set.name_tree) + remove_duplicates_and_overlaps(result_values′, set.name_tree) return FieldVectorKeys(result_values, set.name_tree) end @@ -137,7 +137,7 @@ end function matrix_diagonal_keys(set::FieldMatrixKeys) result_values′ = unrolled_filter(set.values) do name_pair - names_are_overlapping(name_pair[1], name_pair[2]) + is_overlapping_name(name_pair[1], name_pair[2]) end result_values = unrolled_map(result_values′) do name_pair if name_pair[1] == name_pair[2] @@ -175,10 +175,10 @@ because we cannot extract internal columns from FieldNameDict entries. =# function matrix_product_keys(set1::FieldMatrixKeys, set2::FieldNameSet) name_tree = combine_name_trees(set1.name_tree, set2.name_tree) - result_values′ = unrolled_mapflatten(set1.values) do name_pair1 + result_values′ = unrolled_flatmap(set1.values) do name_pair1 overlapping_set2_values = unrolled_filter(set2.values) do value2 row_name2 = eltype(set2) <: FieldName ? value2 : value2[1] - names_are_overlapping(name_pair1[2], row_name2) + is_overlapping_name(name_pair1[2], row_name2) end unrolled_map(overlapping_set2_values) do value2 row_name2 = eltype(set2) <: FieldName ? value2 : value2[1] @@ -194,8 +194,8 @@ function matrix_product_keys(set1::FieldMatrixKeys, set2::FieldNameSet) end end end - result_values = unique_and_non_overlapping_values(result_values′, name_tree) - # Note: the modification of result_values may trigger multiplication case 4. + # Note: Modifying result_values here can trigger multiplication case 4. + result_values = remove_duplicates_and_overlaps(result_values′, name_tree) return FieldNameSet{eltype(set2)}(result_values, name_tree) end function summand_names_for_matrix_product( @@ -206,14 +206,14 @@ function summand_names_for_matrix_product( product_row_name = eltype(set2) <: FieldName ? product_key : product_key[1] name_tree = combine_name_trees(set1.name_tree, set2.name_tree) overlapping_set1_values = unrolled_filter(set1.values) do name_pair1 - names_are_overlapping(product_row_name, name_pair1[1]) + is_overlapping_name(product_row_name, name_pair1[1]) end - result_values = unrolled_mapflatten(overlapping_set1_values) do name_pair1 + result_values = unrolled_flatmap(overlapping_set1_values) do name_pair1 overlapping_set2_values = unrolled_filter(set2.values) do value2 row_name2 = eltype(set2) <: FieldName ? value2 : value2[1] - names_are_overlapping(name_pair1[2], row_name2) && ( + is_overlapping_name(name_pair1[2], row_name2) && ( eltype(set2) <: FieldName || - names_are_overlapping(product_key[2], value2[2]) + is_overlapping_name(product_key[2], value2[2]) ) end unrolled_map(overlapping_set2_values) do value2 @@ -259,18 +259,17 @@ check_values(values, name_tree) = of $value have been passed to a FieldNameSet constructor", ) overlapping_values = unrolled_filter(values) do value′ - value != value′ && values_are_overlapping(value, value′) - end - if !isempty(overlapping_values) - overlapping_values_string = - length(overlapping_values) == 2 ? - join(overlapping_values, " or ") : - join(overlapping_values, ", ", ", or ") - error("Overlapping FieldNameSet values: $value cannot be in the \ - same FieldNameSet as $overlapping_values_string") + value′ != value && is_overlapping_value(value, value′) end + isempty(overlapping_values) || error( + "Overlapping FieldNameSet values: $value cannot be in the same \ + FieldNameSet as $(values_string(overlapping_values))", + ) end +values_string(values) = + length(values) == 2 ? join(values, " and ") : join(values, ", ", ", and ") + combine_name_trees(::Nothing, ::Nothing) = nothing combine_name_trees(name_tree1, ::Nothing) = name_tree1 combine_name_trees(::Nothing, name_tree2) = name_tree2 @@ -279,114 +278,145 @@ combine_name_trees(name_tree1, name_tree2) = error("Mismatched FieldNameTrees: The ability to combine different \ FieldNameTrees has not been implemented") +universal_set(::Type{FieldName}) = FieldVectorKeys((@name(),)) +universal_set(::Type{FieldNamePair}) = FieldMatrixKeys(((@name(), @name()),)) + is_valid_value(name::FieldName, name_tree) = is_valid_name(name, name_tree) is_valid_value(name_pair::FieldNamePair, name_tree) = is_valid_name(name_pair[1], name_tree) && is_valid_name(name_pair[2], name_tree) -values_are_overlapping(name1::FieldName, name2::FieldName) = - names_are_overlapping(name1, name2) -values_are_overlapping(name_pair1::FieldNamePair, name_pair2::FieldNamePair) = - names_are_overlapping(name_pair1[1], name_pair2[1]) && - names_are_overlapping(name_pair1[2], name_pair2[2]) - is_child_value(name1::FieldName, name2::FieldName) = is_child_name(name1, name2) is_child_value(name_pair1::FieldNamePair, name_pair2::FieldNamePair) = is_child_name(name_pair1[1], name_pair2[1]) && is_child_name(name_pair1[2], name_pair2[2]) +is_overlapping_value(name1::FieldName, name2::FieldName) = + is_overlapping_name(name1, name2) +is_overlapping_value(name_pair1::FieldNamePair, name_pair2::FieldNamePair) = + is_overlapping_name(name_pair1[1], name_pair2[1]) && + is_overlapping_name(name_pair1[2], name_pair2[2]) + is_value_in_set(value, values, name_tree) = - if unrolled_in(value, values) - true - elseif unrolled_any(value′ -> is_child_value(value, value′), values) - isnothing(name_tree) && error( - "Cannot check if $value is in FieldNameSet without a FieldNameTree", - ) - is_valid_value(value, name_tree) - else - false - end + unrolled_in(value, values) || + unrolled_any(value′ -> is_child_value(value, value′), values) && ( + isnothing(name_tree) ? + error( + "Missing FieldNameTree: Cannot check if $value is in FieldNameSet \ + without a FieldNameTree", + ) : is_valid_value(value, name_tree) + ) -universal_set(::Type{FieldName}) = FieldVectorKeys((@name(),)) -universal_set(::Type{FieldNamePair}) = FieldMatrixKeys(((@name(), @name()),)) +remove_duplicates_and_overlaps(values, name_tree) = + union_values((), values, name_tree) -# TODO: Simplify the following code. +function union_values(values, new_values, name_tree) + isempty(new_values) && return values -function non_overlapping_values(values1, values2, name_tree) - new_values1 = unrolled_mapflatten(values1) do value - value_or_non_overlapping_children(value, values2, name_tree) - end - new_values2 = unrolled_mapflatten(values2) do value - value_or_non_overlapping_children(value, values1, name_tree) + new_value = new_values[1] + more_new_values = new_values[2:end] + unrolled_in(new_value, values) && + return union_values(values, more_new_values, name_tree) + + overlapping_values, non_overlapping_values = unrolled_split(values) do value + is_overlapping_value(new_value, value) end - if eltype(values1) <: FieldName - new_values1, new_values2 - else - # Repeat the above operation to handle complex matrix key overlaps. - new_values1′ = unrolled_mapflatten(new_values1) do value - value_or_non_overlapping_children(value, new_values2, name_tree) + isempty(overlapping_values) && + return union_values((values..., new_value), more_new_values, name_tree) + + isnothing(name_tree) && error( + "Missing FieldNameTree: Cannot eliminate overlaps between $new_value \ + and $(values_string(overlapping_values)) without a FieldNameTree", + ) + + overlapping_values_that_are_children_of_value, other_overlapping_values = + unrolled_split(overlapping_values) do value + is_child_value(value, new_value) end - new_values2′ = unrolled_mapflatten(new_values2) do value - value_or_non_overlapping_children(value, new_values1, name_tree) + possible_union_values = if isempty(other_overlapping_values) + possible_children_of_value = available_sets_of_child_values( + new_value, + overlapping_values, + name_tree, + ) + recursively_unrolled_map( + possible_children_of_value, + ) do children_of_value + union_values( + values, + (children_of_value..., more_new_values...), + name_tree, + ) + end + else + possible_children_of_other_overlapping_values = unrolled_map( + unrolled_flatten, + unrolled_prodmap(other_overlapping_values) do value + available_sets_of_child_values(value, (new_value,), name_tree) + end, + ) + recursively_unrolled_map( + possible_children_of_other_overlapping_values, + ) do children_of_other_overlapping_values + values_and_children_of_values = ( + non_overlapping_values..., + overlapping_values_that_are_children_of_value..., + children_of_other_overlapping_values..., + ) + union_values(values_and_children_of_values, new_values, name_tree) end - return new_values1′, new_values2′ - end -end - -function unique_and_non_overlapping_values(values, name_tree) - new_values = unrolled_mapflatten(values) do value - value_or_non_overlapping_children(value, values, name_tree) end - return unrolled_unique(new_values) + return unrolled_argmin(length, possible_union_values) end -function value_or_non_overlapping_children(name::FieldName, names, name_tree) - need_child_names = unrolled_any(names) do name′ - is_child_value(name′, name) && name′ != name - end - need_child_names || return (name,) - isnothing(name_tree) && - error("Cannot compute child names of $name without a FieldNameTree") - return unrolled_mapflatten(child_names(name, name_tree)) do child_name - value_or_non_overlapping_children(child_name, names, name_tree) - end -end -function value_or_non_overlapping_children( +available_sets_of_child_values(name::FieldName, _, name_tree) = + (child_names(name, name_tree),) +function available_sets_of_child_values( name_pair::FieldNamePair, - name_pairs, + overlapping_name_pairs, name_tree, ) - need_row_child_names = unrolled_any(name_pairs) do name_pair′ - is_child_value(name_pair′, name_pair) && name_pair′[1] != name_pair[1] + row_name, col_name = name_pair + row_children_needed = unrolled_any(overlapping_name_pairs) do name_pair′ + name_pair′[1] != row_name && is_child_name(name_pair′[1], row_name) end - need_col_child_names = unrolled_any(name_pairs) do name_pair′ - is_child_value(name_pair′, name_pair) && name_pair′[2] != name_pair[2] + col_children_needed = unrolled_any(overlapping_name_pairs) do name_pair′ + name_pair′[2] != col_name && is_child_name(name_pair′[2], col_name) end - need_row_child_names || need_col_child_names || return (name_pair,) - isnothing(name_tree) && error( - "Cannot compute child name pairs of $name_pair without a FieldNameTree", - ) - row_name_children = - need_row_child_names ? child_names(name_pair[1], name_tree) : - (name_pair[1],) - col_name_children = - need_col_child_names ? child_names(name_pair[2], name_tree) : - (name_pair[2],) - return unrolled_mapflatten(row_name_children) do row_name_child - unrolled_mapflatten(col_name_children) do col_name_child - child_pair = (row_name_child, col_name_child) - value_or_non_overlapping_children(child_pair, name_pairs, name_tree) - end + row_children = + row_children_needed ? + unrolled_map(child_names(row_name, name_tree)) do child_of_row_name + (child_of_row_name, col_name) + end : nothing + col_children = + col_children_needed ? + unrolled_map(child_names(col_name, name_tree)) do child_of_col_name + (row_name, child_of_col_name) + end : nothing + n_row_children = row_children_needed ? length(row_children) : 0 + n_col_children = col_children_needed ? length(col_children) : 0 + # We are guaranteed that either n_row_children > 0 or n_col_children > 0. + return if n_row_children == n_col_children == 1 + only_child_of_name_pair = (row_children[1][1], col_children[1][2]) + ((only_child_of_name_pair,),) + elseif n_row_children > 0 && n_col_children <= 1 + (row_children,) + elseif n_col_children > 0 && n_row_children <= 1 + (col_children,) + else + @assert n_row_children > 1 && n_col_children > 1 + # If multiple row and column children are available, we might get + # different results depending on whether we expand the row or the column + # first, so we need to try both and pick the one that corresponds to the + # result with the shortest length. + (row_children, col_children) end end # This is required for type-stability as of Julia 1.9. if hasfield(Method, :recursion_relation) dont_limit = (args...) -> true - for m in methods(unrolled_mapflatten) - m.recursion_relation = dont_limit - end - for m in methods(value_or_non_overlapping_children) + for m in methods(union_values) m.recursion_relation = dont_limit end end diff --git a/src/MatrixFields/unrolled_functions.jl b/src/MatrixFields/unrolled_functions.jl index 533aaf7785..5442856008 100644 --- a/src/MatrixFields/unrolled_functions.jl +++ b/src/MatrixFields/unrolled_functions.jl @@ -1,31 +1,79 @@ -# These functions are also defined in Unrolled.jl, but those versions use "in" -# instead of "unrolled_in". -unrolled_union(values1, values2) = - (values1..., unrolled_setdiff(values2, values1)...) -unrolled_intersect(values1, values2) = - unrolled_filter(x -> unrolled_in(x, values2), values1) -unrolled_setdiff(values1, values2) = - unrolled_filter(x -> !unrolled_in(x, values2), values1) - -unrolled_unique(values) = - unrolled_reduce(unrolled_union, (), unrolled_map(tuple, values)) +# The following functions are extensions to the generated functions defined in +# Unrolled.jl. -unrolled_flatten(values) = - unrolled_reduce((tuple1, tuple2) -> (tuple1..., tuple2...), (), values) - -unrolled_mapflatten(f::F, values) where {F} = - unrolled_flatten(unrolled_map(f, values)) +unrolled_split(f::F, values) where {F} = + unrolled_filter(f, values), unrolled_filter(!f, values) function unrolled_findonly(f::F, values) where {F} filtered_values = unrolled_filter(f, values) - length(filtered_values) == 1 || - error("unrolled_findonly requires that exactly one value makes f true") - return filtered_values[1] + return length(filtered_values) == 1 ? filtered_values[1] : + error("unrolled_findonly requires that exactly 1 value makes f true") +end + +# The implementation of unrolled_reduce in Unrolled.jl makes it roughly the same +# as foldl, but with the order of arguments in every call to f flipped. +unrolled_foldl(f::F, values; init = nothing) where {F} = + if isnothing(init) + isempty(values) ? + error("unrolled_foldl requires init for empty collections of values") : + unrolled_reduce((x, y) -> f(y, x), values[1], values[2:end]) + else + unrolled_reduce((x, y) -> f(y, x), init, values) + end + +function unrolled_argmin(f::F, values) where {F} + values_and_fs = unrolled_map(value -> (value, f(value)), values) + min_value_and_f = + unrolled_foldl(values_and_fs) do (min_value, min_f), (new_value, new_f) + new_f < min_f ? (new_value, new_f) : (min_value, min_f) + end + return min_value_and_f[1] end -# The way unrolled_reduce is defined in Unrolled.jl essentially makes it a -# backwards unrolled_foldl. -unrolled_foldl(f::F, values) where {F} = - isempty(values) ? - error("unrolled_foldl requires a nonempty collection of values") : - unrolled_reduce((x, y) -> f(y, x), values[1], values[2:end]) +# This needs to use unrolled_reduce instead of unrolled_foldl in order for +# unrolled_product to be type-stable (otherwise, calling unrolled_flatmap inside +# of unrolled_foldl causes the compiler to hit a recursion limit). +unrolled_flatten(values) = + unrolled_reduce((), values) do value, flattened_values + (flattened_values..., value...) + end + +unrolled_flatmap(f::F, values) where {F} = + unrolled_flatten(unrolled_map(f, values)) + +unrolled_product(values) = + unrolled_foldl(values; init = ((),)) do product_values, value + unrolled_flatmap(product_values) do sub_values + unrolled_map(value) do sub_value + (sub_values..., sub_value) + end + end + end + +unrolled_prodmap(f::F, values) where {F} = + unrolled_product(unrolled_map(f, values)) + +# The following functions are recursion-based alternatives to the generated +# functions in Unrolled.jl. These should be used instead of their generated +# counterparts when implementing recursion in other functions. For example, if a +# function `func` needs to map over some values, and if the computation for each +# value recursively calls `func`, the map should be implemented using +# `recursively_unrolled_map` instead of `unrolled_map`. + +@inline recursively_unrolled_map(f::F, values) where {F} = + isempty(values) ? () : + (f(values[1]), recursively_unrolled_map(f, values[2:end])...) + +recursively_unrolled_any(f::F, values) where {F} = + unrolled_any(identity, recursively_unrolled_map(f, values)) + +# This is required for type-stability as of Julia 1.9. +if hasfield(Method, :recursion_relation) + dont_limit = (args...) -> true + for m in methods(recursively_unrolled_map) + m.recursion_relation = dont_limit + end + for m in methods(recursively_unrolled_any) + m.recursion_relation = dont_limit + end +end diff --git a/test/MatrixFields/field_names.jl b/test/MatrixFields/field_names.jl index 7724951e07..46b56ee069 100644 --- a/test/MatrixFields/field_names.jl +++ b/test/MatrixFields/field_names.jl @@ -1,4 +1,4 @@ -import ClimaCore.MatrixFields: @name +import ClimaCore.MatrixFields: @name, is_subset_that_covers_set include("matrix_field_test_utils.jl") @@ -53,10 +53,10 @@ const x = (; foo = Foo(0), a = (; b = 1, c = ((; d = 2), (;), ((), nothing)))) @test_all MatrixFields.is_child_name(@name(a.c.:(1).d), @name(a)) @test_all !MatrixFields.is_child_name(@name(a.c.:(1).d), @name(foo)) - @test_all MatrixFields.names_are_overlapping(@name(a), @name(a.c.:(1).d)) - @test_all MatrixFields.names_are_overlapping(@name(a.c.:(1).d), @name(a)) - @test_all !MatrixFields.names_are_overlapping(@name(foo), @name(a.c.:(1).d)) - @test_all !MatrixFields.names_are_overlapping(@name(a.c.:(1).d), @name(foo)) + @test_all MatrixFields.is_overlapping_name(@name(a), @name(a.c.:(1).d)) + @test_all MatrixFields.is_overlapping_name(@name(a.c.:(1).d), @name(a)) + @test_all !MatrixFields.is_overlapping_name(@name(foo), @name(a.c.:(1).d)) + @test_all !MatrixFields.is_overlapping_name(@name(a.c.:(1).d), @name(foo)) @test_all MatrixFields.extract_internal_name(@name(a.c.:(1).d), @name(a)) == @name(c.:(1).d) @@ -114,7 +114,7 @@ end matrix_keys_no_tree(name_pairs...) = MatrixFields.FieldMatrixKeys(name_pairs) - @testset "FieldNameSet Construction" begin + @testset "FieldNameSet Constructors" begin @test_throws "Invalid FieldNameSet value" vector_keys( @name(foo.invalid_name), ) @@ -145,23 +145,67 @@ end end end - @testset "FieldNameSet Iteration" begin - v_set1 = vector_keys(@name(foo), @name(a.c)) - v_set1_no_tree = vector_keys_no_tree(@name(foo), @name(a.c)) - m_set1 = matrix_keys((@name(foo), @name(a.c)), (@name(a.b), @name(foo))) - m_set1_no_tree = matrix_keys_no_tree( - (@name(foo), @name(a.c)), - (@name(a.b), @name(foo)), - ) - - @test_all map(name -> (name, name), v_set1) == - ((@name(foo), @name(foo)), (@name(a.c), @name(a.c))) - @test_all map(name_pair -> name_pair[1], m_set1) == - (@name(foo), @name(a.b)) + v_set1 = vector_keys(@name(foo), @name(a.c)) + v_set1_no_tree = vector_keys_no_tree(@name(foo), @name(a.c)) + m_set1 = matrix_keys((@name(foo), @name(a.c)), (@name(a.b), @name(foo))) + m_set1_no_tree = + matrix_keys_no_tree((@name(foo), @name(a.c)), (@name(a.b), @name(foo))) + + # Proper subsets of v_set1 and m_set1. + v_set2 = vector_keys(@name(foo)) + v_set2_no_tree = vector_keys_no_tree(@name(foo)) + m_set2 = matrix_keys((@name(foo), @name(a.c))) + m_set2_no_tree = matrix_keys_no_tree((@name(foo), @name(a.c))) + + # Subsets that cover v_set1 and m_set1. + v_set3 = vector_keys( + @name(foo.value), + @name(a.c.:(1)), + @name(a.c.:(2)), + @name(a.c.:(3)), + ) + v_set3_no_tree = vector_keys_no_tree( + @name(foo.value), + @name(a.c.:(1)), + @name(a.c.:(2)), + @name(a.c.:(3)), + ) + m_set3 = matrix_keys( + (@name(foo.value), @name(a.c.:(1))), + (@name(foo), @name(a.c.:(2))), + (@name(foo), @name(a.c.:(3))), + (@name(a.b), @name(foo.value)), + ) + m_set3_no_tree = matrix_keys_no_tree( + (@name(foo.value), @name(a.c.:(1))), + (@name(foo), @name(a.c.:(2))), + (@name(foo), @name(a.c.:(3))), + (@name(a.b), @name(foo.value)), + ) - @test_all isnothing(foreach(name -> (name, name), v_set1)) - @test_all isnothing(foreach(name_pair -> name_pair[1], m_set1)) + # Sets that overlap with v_set1 and m_set1, but are neither subsets nor + # supersets of those sets. Some of the values in m_set4 overlap with + # those in m_set1, but are neither children nor parents of those values + # (this is only possible with matrix keys). + v_set4 = vector_keys(@name(a.b), @name(a.c.:(1)), @name(a.c.:(2))) + v_set4_no_tree = + vector_keys_no_tree(@name(a.b), @name(a.c.:(1)), @name(a.c.:(2))) + m_set4 = matrix_keys( + (@name(), @name(a.c.:(1))), + (@name(foo.value), @name(foo)), + (@name(foo.value), @name(a.c.:(2))), + (@name(a), @name(foo.value)), + (@name(a.c.:(3)), @name(a.c.:(3))), + ) + m_set4_no_tree = matrix_keys_no_tree( + (@name(), @name(a.c.:(1))), + (@name(foo.value), @name(foo)), + (@name(foo.value), @name(a.c.:(2))), + (@name(a), @name(foo.value)), + (@name(a.c.:(3)), @name(a.c.:(3))), + ) + @testset "FieldNameSet Basic Operations" begin @test string(v_set1) == "FieldVectorKeys(@name(foo), @name(a.c); )" @test string(v_set1_no_tree) == @@ -171,6 +215,14 @@ end @test string(m_set1_no_tree) == "FieldMatrixKeys((@name(foo), \ @name(a.c)), (@name(a.b), @name(foo)))" + @test_all map(name -> (name, name), v_set1) == + ((@name(foo), @name(foo)), (@name(a.c), @name(a.c))) + @test_all map(name_pair -> name_pair[1], m_set1) == + (@name(foo), @name(a.b)) + + @test_all isnothing(foreach(name -> (name, name), v_set1)) + @test_all isnothing(foreach(name_pair -> name_pair[1], m_set1)) + for set in (v_set1, v_set1_no_tree) @test_all @name(foo) in set @test_all !(@name(a.b) in set) @@ -193,52 +245,43 @@ end m_set1_no_tree @test_throws "FieldNameTree" (@name(foo.invalid_name), @name(a.c)) in m_set1_no_tree - end - @testset "FieldNameSet Operations for Addition/Subtraction" begin - v_set1 = vector_keys(@name(foo), @name(a.c)) - v_set1_no_tree = vector_keys_no_tree(@name(foo), @name(a.c)) - m_set1 = matrix_keys((@name(foo), @name(a.c)), (@name(a.b), @name(foo))) - m_set1_no_tree = matrix_keys_no_tree( - (@name(foo), @name(a.c)), - (@name(a.b), @name(foo)), + @test_all MatrixFields.set_complement(v_set1) == + vector_keys_no_tree(@name(a.b)) + @test_throws "FieldNameTree" MatrixFields.set_complement(v_set1_no_tree) + + @test_all MatrixFields.set_complement(m_set1) == matrix_keys_no_tree( + (@name(foo), @name(foo)), + (@name(foo), @name(a.b)), + (@name(a.c), @name()), + (@name(a.b), @name(a)), ) + @test_throws "FieldNameTree" MatrixFields.set_complement(m_set1_no_tree) - v_set2 = vector_keys(@name(foo)) - v_set2_no_tree = vector_keys_no_tree(@name(foo)) - m_set2 = matrix_keys((@name(foo), @name(a.c))) - m_set2_no_tree = matrix_keys_no_tree((@name(foo), @name(a.c))) + @test_all MatrixFields.set_complement(v_set4) == + vector_keys_no_tree(@name(foo), @name(a.c.:(3))) + @test_throws "FieldNameTree" MatrixFields.set_complement(v_set4_no_tree) - v_set3 = vector_keys( - @name(foo.value), - @name(a.c.:(1)), - @name(a.c.:(2)), - @name(a.c.:(3)), - ) - v_set3_no_tree = vector_keys_no_tree( - @name(foo.value), - @name(a.c.:(1)), - @name(a.c.:(2)), - @name(a.c.:(3)), - ) - m_set3 = matrix_keys( - (@name(foo), @name(a.c.:(1))), - (@name(foo), @name(a.c.:(2))), - (@name(foo.value), @name(a.c.:(3))), - (@name(a.b), @name(foo)), - ) - m_set3_no_tree = matrix_keys_no_tree( - (@name(foo), @name(a.c.:(1))), - (@name(foo), @name(a.c.:(2))), - (@name(foo.value), @name(a.c.:(3))), - (@name(a.b), @name(foo)), - ) - m_set3_no_tree′ = matrix_keys_no_tree( - (@name(foo.value), @name(a.c.:(1))), - (@name(foo.value), @name(a.c.:(2))), - (@name(foo.value), @name(a.c.:(3))), - (@name(a.b), @name(foo)), + @test_all MatrixFields.set_complement(m_set4) == matrix_keys_no_tree( + (@name(), @name(a.b)), + (@name(foo), @name(a.c.:(3))), + (@name(a), @name(a.c.:(2))), + (@name(a.b), @name(a.c.:(3))), + (@name(a.c.:(1)), @name(a.c.:(3))), + (@name(a.c.:(2)), @name(a.c.:(3))), ) + @test_throws "FieldNameTree" MatrixFields.set_complement(m_set4_no_tree) + end + + @testset "FieldNameSet Binary Set Operations" begin + for set1 in (v_set1, v_set1_no_tree, m_set1, m_set1_no_tree) + @test_all set1 == set1 + @test_all issubset(set1, set1) + @test_all is_subset_that_covers_set(set1, set1) + @test_all intersect(set1, set1) == set1 + @test_all union(set1, set1) == set1 + @test_all isempty(setdiff(set1, set1)) + end for (set1, set2) in ( (v_set1, v_set2), @@ -251,54 +294,124 @@ end (m_set1_no_tree, m_set2_no_tree), ) @test_all set1 != set2 - @test_all !issubset(set1, set2) - @test_all issubset(set2, set1) - @test_all intersect(set1, set2) == set2 - @test_all union(set1, set2) == set1 - @test_all !MatrixFields.is_subset_that_covers_set(set1, set2) - @test_all !MatrixFields.is_subset_that_covers_set(set2, set1) + @test_all !issubset(set1, set2) && issubset(set2, set1) + @test_all !is_subset_that_covers_set(set1, set2) && + !is_subset_that_covers_set(set2, set1) + @test_all intersect(set1, set2) == intersect(set2, set1) == set2 + @test_all union(set1, set2) == union(set2, set1) == set1 + if eltype(set1) <: MatrixFields.FieldName + @test_all setdiff(set1, set2) == vector_keys(@name(a.c)) + else + @test_all setdiff(set1, set2) == + matrix_keys((@name(a.b), @name(foo))) + end + @test_all isempty(setdiff(set2, set1)) end for (set1, set3) in ( (v_set1, v_set3), (v_set1, v_set3_no_tree), (v_set1_no_tree, v_set3), - ) - @test_all set1 != set3 - @test_all !issubset(set1, set3) - @test_all issubset(set3, set1) - @test_all intersect(set1, set3) == set3 - @test_all union(set1, set3) == set3 - @test_all !MatrixFields.is_subset_that_covers_set(set1, set3) - @test_all MatrixFields.is_subset_that_covers_set(set3, set1) - end - - for (set1, set3) in ( (m_set1, m_set3), (m_set1, m_set3_no_tree), (m_set1_no_tree, m_set3), ) @test_all set1 != set3 - @test_all !issubset(set1, set3) - @test_all issubset(set3, set1) - @test_all intersect(set1, set3) == m_set3_no_tree′ - @test_all union(set1, set3) == m_set3_no_tree′ - @test_all !MatrixFields.is_subset_that_covers_set(set1, set3) - @test_all MatrixFields.is_subset_that_covers_set(set3, set1) + @test_all !issubset(set1, set3) && issubset(set3, set1) + @test_all !is_subset_that_covers_set(set1, set3) && + is_subset_that_covers_set(set3, set1) + @test_all intersect(set1, set3) == intersect(set3, set1) == set3 + @test_all union(set1, set3) == union(set3, set1) == set3 + @test_all isempty(setdiff(set1, set3)) && + isempty(setdiff(set3, set1)) end for (set1, set3) in ((v_set1_no_tree, v_set3_no_tree), (m_set1_no_tree, m_set3_no_tree)) @test_all set1 != set3 @test_all !issubset(set1, set3) + @test_all !is_subset_that_covers_set(set1, set3) @test_throws "FieldNameTree" issubset(set3, set1) - @test_throws "FieldNameTree" intersect(set1, set3) == set3 - @test_throws "FieldNameTree" union(set1, set3) == set3 - @test_all !MatrixFields.is_subset_that_covers_set(set1, set3) - @test_throws "FieldNameTree" MatrixFields.is_subset_that_covers_set( - set3, - set1, - ) + @test_throws "FieldNameTree" is_subset_that_covers_set(set3, set1) + @test_throws "FieldNameTree" intersect(set1, set3) + @test_throws "FieldNameTree" intersect(set3, set1) + @test_throws "FieldNameTree" union(set1, set3) + @test_throws "FieldNameTree" union(set3, set1) + @test_throws "FieldNameTree" setdiff(set1, set3) + @test_throws "FieldNameTree" setdiff(set3, set1) + end + + for (set1, set4) in ( + (v_set1, v_set4), + (v_set1, v_set4_no_tree), + (v_set1_no_tree, v_set4), + (m_set1, m_set4), + (m_set1, m_set4_no_tree), + (m_set1_no_tree, m_set4), + ) + @test_all set1 != set4 + @test_all !issubset(set1, set4) && !issubset(set4, set1) + @test_all !is_subset_that_covers_set(set1, set4) && + !is_subset_that_covers_set(set4, set1) + if eltype(set1) <: MatrixFields.FieldName + @test_all intersect(set1, set4) == + intersect(set4, set1) == + vector_keys_no_tree(@name(a.c.:(1)), @name(a.c.:(2))) + @test_all union(set1, set4) == + union(set4, set1) == + vector_keys_no_tree( + @name(foo), + @name(a.b), + @name(a.c.:(1)), + @name(a.c.:(2)), + @name(a.c.:(3)), + ) + @test_all setdiff(set1, set4) == + vector_keys_no_tree(@name(foo), @name(a.c.:(3))) + @test_all setdiff(set4, set1) == vector_keys_no_tree(@name(a.b)) + else + @test_all intersect(set1, set4) == + intersect(set4, set1) == + matrix_keys_no_tree( + (@name(foo), @name(a.c.:(1))), + (@name(foo.value), @name(a.c.:(2))), + (@name(a.b), @name(foo.value)), + ) + @test_all union(set1, set4) == + union(set4, set1) == + matrix_keys_no_tree( + (@name(foo), @name(a.c.:(1))), + (@name(foo), @name(a.c.:(3))), + (@name(foo.value), @name(foo)), + (@name(foo.value), @name(a.c.:(2))), + (@name(a), @name(a.c.:(1))), + (@name(a.b), @name(foo.value)), + (@name(a.c), @name(foo.value)), + (@name(a.c.:(3)), @name(a.c.:(3))), + ) + @test_all setdiff(set1, set4) == + matrix_keys_no_tree((@name(foo), @name(a.c.:(3)))) + @test_all setdiff(set4, set1) == matrix_keys_no_tree( + (@name(foo.value), @name(foo)), + (@name(a), @name(a.c.:(1))), + (@name(a.c), @name(foo.value)), + (@name(a.c.:(3)), @name(a.c.:(3))), + ) + end + end + + for (set1, set4) in + ((v_set1_no_tree, v_set4_no_tree), (m_set1_no_tree, m_set4_no_tree)) + @test_all set1 != set4 + @test_all !issubset(set1, set4) && !issubset(set4, set1) + @test_all !is_subset_that_covers_set(set1, set4) && + !is_subset_that_covers_set(set4, set1) + @test_throws "FieldNameTree" intersect(set1, set4) + @test_throws "FieldNameTree" intersect(set4, set1) + @test_throws "FieldNameTree" union(set1, set4) + @test_throws "FieldNameTree" union(set4, set1) + @test_throws "FieldNameTree" setdiff(set1, set4) + @test_throws "FieldNameTree" setdiff(set4, set1) end end @@ -491,111 +604,54 @@ end end @testset "Other FieldNameSet Operations" begin - v_set1 = vector_keys(@name(foo), @name(a.c)) - v_set1_no_tree = vector_keys_no_tree(@name(foo), @name(a.c)) - m_set1 = matrix_keys((@name(foo), @name(a.c)), (@name(a.b), @name(foo))) - m_set1_no_tree = matrix_keys_no_tree( - (@name(foo), @name(a.c)), - (@name(a.b), @name(foo)), - ) - - v_set2 = vector_keys(@name(foo.value), @name(a.c.:(1)), @name(a.c.:(3))) - v_set2_no_tree = vector_keys_no_tree( - @name(foo.value), - @name(a.c.:(1)), - @name(a.c.:(3)) - ) - m_set2 = matrix_keys( - (@name(foo), @name(foo)), - (@name(foo), @name(a.c.:(1))), - (@name(foo.value), @name(a.c.:(3))), - (@name(a.b), @name(foo.value)), - (@name(a), @name(a.c)), - ) - m_set2_no_tree = matrix_keys_no_tree( - (@name(foo), @name(foo)), - (@name(foo), @name(a.c.:(1))), - (@name(foo.value), @name(a.c.:(3))), - (@name(a.b), @name(foo.value)), - (@name(a), @name(a.c)), - ) - - @test_all MatrixFields.set_complement(v_set2) == - vector_keys(@name(a.b), @name(a.c.:(2))) - @test_throws "FieldNameTree" MatrixFields.set_complement(v_set2_no_tree) - - @test_all MatrixFields.set_complement(m_set2) == matrix_keys( - (@name(foo.value), @name(a.b)), - (@name(foo.value), @name(a.c.:(2))), - (@name(a.c), @name(foo.value)), - (@name(a), @name(a.b)), - ) - @test_throws "FieldNameTree" MatrixFields.set_complement(m_set2_no_tree) - - for (set1, set2) in ( - (v_set1, v_set2), - (v_set1, v_set2_no_tree), - (v_set1_no_tree, v_set2), - ) - @test_all setdiff(set1, set2) == vector_keys(@name(a.c.:(2))) - end - - for (set1, set2) in ( - (m_set1, m_set2), - (m_set1, m_set2_no_tree), - (m_set1_no_tree, m_set2), - ) - @test_all setdiff(set1, set2) == - matrix_keys((@name(foo.value), @name(a.c.:(2)))) - end - - for (set1, set2) in - ((v_set1_no_tree, v_set2_no_tree), (m_set1_no_tree, m_set2_no_tree)) - @test_throws "FieldNameTree" setdiff(set1, set2) - end - # With one exception, none of the following operations require a # FieldNameTree. @test_all MatrixFields.corresponding_matrix_keys(v_set1_no_tree) == - matrix_keys( + matrix_keys_no_tree( (@name(foo), @name(foo)), (@name(a.c), @name(a.c)), ) @test_all MatrixFields.cartesian_product( v_set1_no_tree, - v_set2_no_tree, - ) == matrix_keys( - (@name(foo), @name(foo.value)), + v_set4_no_tree, + ) == matrix_keys_no_tree( + (@name(foo), @name(a.b)), (@name(foo), @name(a.c.:(1))), - (@name(foo), @name(a.c.:(3))), - (@name(a.c), @name(foo.value)), + (@name(foo), @name(a.c.:(2))), + (@name(a.c), @name(a.b)), (@name(a.c), @name(a.c.:(1))), - (@name(a.c), @name(a.c.:(3))), + (@name(a.c), @name(a.c.:(2))), ) @test_all MatrixFields.matrix_row_keys(m_set1_no_tree) == - vector_keys(@name(foo), @name(a.b)) + vector_keys_no_tree(@name(foo), @name(a.b)) - @test_all MatrixFields.matrix_row_keys(m_set2) == - vector_keys(@name(foo.value), @name(a.b), @name(a.c)) + @test_all MatrixFields.matrix_row_keys(m_set4) == vector_keys_no_tree( + @name(foo.value), + @name(a.b), + @name(a.c.:(1)), + @name(a.c.:(2)), + @name(a.c.:(3)) + ) @test_throws "FieldNameTree" MatrixFields.matrix_row_keys( - m_set2_no_tree, + m_set4_no_tree, ) - @test_all MatrixFields.matrix_off_diagonal_keys(m_set2_no_tree) == - matrix_keys( - (@name(foo), @name(a.c.:(1))), - (@name(foo.value), @name(a.c.:(3))), - (@name(a.b), @name(foo.value)), - (@name(a), @name(a.c)), + @test_all MatrixFields.matrix_off_diagonal_keys(m_set4_no_tree) == + matrix_keys_no_tree( + (@name(), @name(a.c.:(1))), + (@name(foo.value), @name(foo)), + (@name(foo.value), @name(a.c.:(2))), + (@name(a), @name(foo.value)), ) - @test_all MatrixFields.matrix_diagonal_keys(m_set2_no_tree) == - matrix_keys( - (@name(foo), @name(foo)), - (@name(a.c), @name(a.c)), + @test_all MatrixFields.matrix_diagonal_keys(m_set4_no_tree) == + matrix_keys_no_tree( + (@name(foo.value), @name(foo.value)), + (@name(a.c.:(1)), @name(a.c.:(1))), + (@name(a.c.:(3)), @name(a.c.:(3))), ) end end From bf6c14ae050b685baaa0c9c926426f8b0e01712b Mon Sep 17 00:00:00 2001 From: Dennis Yatunin Date: Fri, 27 Oct 2023 01:11:25 -0700 Subject: [PATCH 3/3] Simplify set complements --- src/MatrixFields/field_name.jl | 8 +- src/MatrixFields/field_name_dict.jl | 2 +- src/MatrixFields/field_name_set.jl | 221 ++++++++++++++----------- src/MatrixFields/unrolled_functions.jl | 78 ++++----- test/MatrixFields/field_names.jl | 167 +++++++++---------- 5 files changed, 236 insertions(+), 240 deletions(-) diff --git a/src/MatrixFields/field_name.jl b/src/MatrixFields/field_name.jl index a389e5e816..7f159339c7 100644 --- a/src/MatrixFields/field_name.jl +++ b/src/MatrixFields/field_name.jl @@ -73,7 +73,8 @@ is_child_name( ::FieldName{parent_name_chain}, ) where {child_name_chain, parent_name_chain} = length(child_name_chain) >= length(parent_name_chain) && - child_name_chain[1:length(parent_name_chain)] == parent_name_chain + unrolled_take(child_name_chain, Val(length(parent_name_chain))) == + parent_name_chain is_overlapping_name(name1, name2) = is_child_name(name1, name2) || is_child_name(name2, name1) @@ -83,8 +84,9 @@ extract_internal_name( parent_name::FieldName{parent_name_chain}, ) where {child_name_chain, parent_name_chain} = is_child_name(child_name, parent_name) ? - FieldName(child_name_chain[(length(parent_name_chain) + 1):end]...) : - error("$child_name is not a child name of $parent_name") + FieldName( + unrolled_drop(child_name_chain, Val(length(parent_name_chain)))..., + ) : error("$child_name is not a child name of $parent_name") append_internal_name( ::FieldName{name_chain}, diff --git a/src/MatrixFields/field_name_dict.jl b/src/MatrixFields/field_name_dict.jl index 4a31cc56ac..f540f1a0e9 100644 --- a/src/MatrixFields/field_name_dict.jl +++ b/src/MatrixFields/field_name_dict.jl @@ -248,7 +248,7 @@ Base.Broadcast.broadcasted( arg3, args..., ) = - unrolled_foldl((arg1, arg2, arg3, args...)) do arg1′, arg2′ + foldl((arg1, arg2, arg3, args...)) do arg1′, arg2′ Base.Broadcast.broadcasted(f, arg1′, arg2′) end diff --git a/src/MatrixFields/field_name_set.jl b/src/MatrixFields/field_name_set.jl index 3dd16e52de..6dc35224a4 100644 --- a/src/MatrixFields/field_name_set.jl +++ b/src/MatrixFields/field_name_set.jl @@ -106,7 +106,7 @@ end set_string(set) = values_string(set.values) -set_complement(set) = setdiff(universal_set(eltype(set)), set) +set_complement(set) = setdiff(universal_set(set.name_tree, eltype(set)), set) is_subset_that_covers_set(set1, set2) = issubset(set1, set2) && isempty(setdiff(set2, set1)) @@ -116,9 +116,9 @@ function corresponding_matrix_keys(set::FieldVectorKeys) return FieldMatrixKeys(result_values, set.name_tree) end -function cartesian_product(set1::FieldVectorKeys, set2::FieldVectorKeys) - name_tree = combine_name_trees(set1.name_tree, set2.name_tree) - result_values = unrolled_product((set1.values, set2.values)) +function cartesian_product(row_set::FieldVectorKeys, col_set::FieldVectorKeys) + name_tree = combine_name_trees(row_set.name_tree, col_set.name_tree) + result_values = unrolled_product(row_set.values, col_set.values) return FieldMatrixKeys(result_values, name_tree) end @@ -253,10 +253,10 @@ check_values(values, name_tree) = (isnothing(name_tree) || is_valid_value(value, name_tree)) || error( "Invalid FieldNameSet value: $value is incompatible with name_tree", ) - duplicate_values = unrolled_filter(isequal(value), values) - length(duplicate_values) == 1 || error( - "Duplicate FieldNameSet values: $(length(duplicate_values)) copies \ - of $value have been passed to a FieldNameSet constructor", + num_duplicate_values = length(unrolled_filter(isequal(value), values)) + num_duplicate_values == 1 || error( + "Duplicate FieldNameSet values: $num_duplicate_values copies of \ + $value have been passed to a FieldNameSet constructor", ) overlapping_values = unrolled_filter(values) do value′ value′ != value && is_overlapping_value(value, value′) @@ -278,8 +278,16 @@ combine_name_trees(name_tree1, name_tree2) = error("Mismatched FieldNameTrees: The ability to combine different \ FieldNameTrees has not been implemented") -universal_set(::Type{FieldName}) = FieldVectorKeys((@name(),)) -universal_set(::Type{FieldNamePair}) = FieldMatrixKeys(((@name(), @name()),)) +universal_set(::Nothing, ::Type{FieldName}) = error( + "Missing FieldNameTree: Cannot compute complement of FieldNameSet without \ + a FieldNameTree", +) +universal_set(name_tree, ::Type{FieldName}) = + FieldVectorKeys(child_names(@name(), name_tree), name_tree) +function universal_set(name_tree, ::Type{FieldNamePair}) + row_set = universal_set(name_tree, FieldName) + return cartesian_product(row_set, row_set) +end is_valid_value(name::FieldName, name_tree) = is_valid_name(name, name_tree) is_valid_value(name_pair::FieldNamePair, name_tree) = @@ -307,115 +315,128 @@ is_value_in_set(value, values, name_tree) = ) : is_valid_value(value, name_tree) ) -remove_duplicates_and_overlaps(values, name_tree) = - union_values((), values, name_tree) - -function union_values(values, new_values, name_tree) - isempty(new_values) && return values - - new_value = new_values[1] - more_new_values = new_values[2:end] - unrolled_in(new_value, values) && - return union_values(values, more_new_values, name_tree) - - overlapping_values, non_overlapping_values = unrolled_split(values) do value - is_overlapping_value(new_value, value) +function remove_duplicates_and_overlaps(values, name_tree) + unique_values = unrolled_unique(values) + overlapping_values, non_overlapping_values = + unrolled_split(unique_values) do value + unrolled_any(unique_values) do value′ + value != value′ && is_overlapping_value(value, value′) + end + end + isempty(overlapping_values) && return unique_values + isnothing(name_tree) && + error("Missing FieldNameTree: Cannot eliminate overlaps among \ + $(values_string(overlapping_values)) without a FieldNameTree") + expanded_overlapping_values = unrolled_flatmap(overlapping_values) do value + values_overlapping_with_value = + unrolled_filter(overlapping_values) do value′ + value != value′ && is_overlapping_value(value, value′) + end + expand_child_values(value, values_overlapping_with_value, name_tree) end - isempty(overlapping_values) && - return union_values((values..., new_value), more_new_values, name_tree) + no_longer_overlapping_values = + remove_duplicates_and_overlaps(expanded_overlapping_values, name_tree) + return (non_overlapping_values..., no_longer_overlapping_values...) +end +# The function union_values(values1, values2, name_tree) generates the same +# result as remove_duplicates_and_overlaps((values1..., values2...), name_tree), +# but it is slightly more efficient (and faster to compile) because it makes use +# of the fact that values1 == remove_duplicates_and_overlaps(values1, name_tree) +# and values2 == remove_duplicates_and_overlaps(values2, name_tree). +function union_values(values1, values2, name_tree) + unique_values2 = + unrolled_filter(value2 -> !unrolled_in(value2, values1), values2) + overlapping_values1, non_overlapping_values1 = + unrolled_split(values1) do value1 + unrolled_any(unique_values2) do value2 + is_overlapping_value(value1, value2) + end + end + isempty(overlapping_values1) && return (values1..., unique_values2...) + overlapping_values2, non_overlapping_values2 = + unrolled_split(unique_values2) do value2 + unrolled_any(values1) do value1 + is_overlapping_value(value1, value2) + end + end isnothing(name_tree) && error( - "Missing FieldNameTree: Cannot eliminate overlaps between $new_value \ - and $(values_string(overlapping_values)) without a FieldNameTree", + "Missing FieldNameTree: Cannot eliminate overlaps between \ + $overlapping_values1 and $overlapping_values2 without a FieldNameTree", ) - - overlapping_values_that_are_children_of_value, other_overlapping_values = - unrolled_split(overlapping_values) do value - is_child_value(value, new_value) - end - possible_union_values = if isempty(other_overlapping_values) - possible_children_of_value = available_sets_of_child_values( - new_value, - overlapping_values, - name_tree, - ) - recursively_unrolled_map( - possible_children_of_value, - ) do children_of_value - union_values( - values, - (children_of_value..., more_new_values...), - name_tree, - ) + expanded_overlapping_values1 = + unrolled_flatmap(overlapping_values1) do value1 + values2_overlapping_value1 = + unrolled_filter(overlapping_values2) do value2 + is_overlapping_value(value1, value2) + end + expand_child_values(value1, values2_overlapping_value1, name_tree) end - else - possible_children_of_other_overlapping_values = unrolled_map( - unrolled_flatten, - unrolled_prodmap(other_overlapping_values) do value - available_sets_of_child_values(value, (new_value,), name_tree) - end, - ) - recursively_unrolled_map( - possible_children_of_other_overlapping_values, - ) do children_of_other_overlapping_values - values_and_children_of_values = ( - non_overlapping_values..., - overlapping_values_that_are_children_of_value..., - children_of_other_overlapping_values..., - ) - union_values(values_and_children_of_values, new_values, name_tree) + expanded_overlapping_values2 = + unrolled_flatmap(overlapping_values2) do value2 + values1_overlapping_value2 = + unrolled_filter(overlapping_values1) do value1 + is_overlapping_value(value1, value2) + end + expand_child_values(value2, values1_overlapping_value2, name_tree) end - end - return unrolled_argmin(length, possible_union_values) + union_of_overlapping_values = union_values( + expanded_overlapping_values1, + expanded_overlapping_values2, + name_tree, + ) + return ( + non_overlapping_values1..., + non_overlapping_values2..., + union_of_overlapping_values..., + ) end -available_sets_of_child_values(name::FieldName, _, name_tree) = - (child_names(name, name_tree),) -function available_sets_of_child_values( +expand_child_values(name::FieldName, overlapping_names, name_tree) = + unrolled_all(overlapping_names) do name′ + name′ != name && is_child_name(name′, name) + end ? child_names(name, name_tree) : (name,) +function expand_child_values( name_pair::FieldNamePair, overlapping_name_pairs, name_tree, ) row_name, col_name = name_pair - row_children_needed = unrolled_any(overlapping_name_pairs) do name_pair′ - name_pair′[1] != row_name && is_child_name(name_pair′[1], row_name) - end - col_children_needed = unrolled_any(overlapping_name_pairs) do name_pair′ - name_pair′[2] != col_name && is_child_name(name_pair′[2], col_name) - end - row_children = - row_children_needed ? - unrolled_map(child_names(row_name, name_tree)) do child_of_row_name - (child_of_row_name, col_name) - end : nothing - col_children = - col_children_needed ? - unrolled_map(child_names(col_name, name_tree)) do child_of_col_name - (row_name, child_of_col_name) - end : nothing - n_row_children = row_children_needed ? length(row_children) : 0 - n_col_children = col_children_needed ? length(col_children) : 0 - # We are guaranteed that either n_row_children > 0 or n_col_children > 0. - return if n_row_children == n_col_children == 1 - only_child_of_name_pair = (row_children[1][1], col_children[1][2]) - ((only_child_of_name_pair,),) - elseif n_row_children > 0 && n_col_children <= 1 - (row_children,) - elseif n_col_children > 0 && n_row_children <= 1 - (col_children,) - else - @assert n_row_children > 1 && n_col_children > 1 - # If multiple row and column children are available, we might get - # different results depending on whether we expand the row or the column - # first, so we need to try both and pick the one that corresponds to the - # result with the shortest length. - (row_children, col_children) + row_name_children_needed = + unrolled_all(overlapping_name_pairs) do name_pair′ + name_pair′[1] != row_name && is_child_name(name_pair′[1], row_name) + end + col_name_children_needed = + unrolled_all(overlapping_name_pairs) do name_pair′ + name_pair′[2] != col_name && is_child_name(name_pair′[2], col_name) + end + row_name_children = + row_name_children_needed ? child_names(row_name, name_tree) : () + col_name_children = + col_name_children_needed ? child_names(col_name, name_tree) : () + # Note: We need special cases for when either row_name or col_name only has + # one child name, since automatically expanding that name can generate + # results with more expansions than are necessary. For example, it can lead + # to a situation like issubset(set1, set2) && union(set1, set2) != set2, + # where union(set1, set2) has too many expanded values. + return if length(row_name_children) > 1 && length(col_name_children) > 1 || + length(row_name_children) == 1 && length(col_name_children) == 1 + unrolled_product(row_name_children, col_name_children) + elseif length(row_name_children) > 0 && length(col_name_children) <= 1 + unrolled_product(row_name_children, (col_name,)) + elseif length(row_name_children) <= 1 && length(col_name_children) > 0 + unrolled_product((row_name,), col_name_children) + else # length(row_name_children) == 0 && length(col_name_children) == 0 + (name_pair,) end end # This is required for type-stability as of Julia 1.9. if hasfield(Method, :recursion_relation) dont_limit = (args...) -> true + for m in methods(remove_duplicates_and_overlaps) + m.recursion_relation = dont_limit + end for m in methods(union_values) m.recursion_relation = dont_limit end diff --git a/src/MatrixFields/unrolled_functions.jl b/src/MatrixFields/unrolled_functions.jl index 5442856008..b2a6f564df 100644 --- a/src/MatrixFields/unrolled_functions.jl +++ b/src/MatrixFields/unrolled_functions.jl @@ -1,57 +1,40 @@ -# The following functions are extensions to the generated functions defined in -# Unrolled.jl. +# The following functions are more easily inferrable versions of `values[1:N]` +# and `values[(N + 1):end]`, named after `Iterators.take` and `Iterators.drop`. +# Note that `Base.tail` is roughly equivalent to `unrolled_drop` with `N == 1`. -unrolled_split(f::F, values) where {F} = - unrolled_filter(f, values), unrolled_filter(!f, values) - -function unrolled_findonly(f::F, values) where {F} - filtered_values = unrolled_filter(f, values) - return length(filtered_values) == 1 ? filtered_values[1] : - error("unrolled_findonly requires that exactly 1 value makes f true") -end +unrolled_take(values, ::Val{N}) where {N} = ntuple(i -> values[i], Val(N)) +unrolled_drop(values, ::Val{N}) where {N} = + ntuple(i -> values[N + i], Val(length(values) - N)) -# The implementation of unrolled_reduce in Unrolled.jl makes it roughly the same -# as foldl, but with the order of arguments in every call to f flipped. -unrolled_foldl(f::F, values; init = nothing) where {F} = - if isnothing(init) - isempty(values) ? - error("unrolled_foldl requires init for empty collections of values") : - unrolled_reduce((x, y) -> f(y, x), values[1], values[2:end]) - else - unrolled_reduce((x, y) -> f(y, x), init, values) - end - -function unrolled_argmin(f::F, values) where {F} - values_and_fs = unrolled_map(value -> (value, f(value)), values) - min_value_and_f = - unrolled_foldl(values_and_fs) do (min_value, min_f), (new_value, new_f) - new_f < min_f ? (new_value, new_f) : (min_value, min_f) - end - return min_value_and_f[1] -end +# The following functions build upon the generated functions in Unrolled.jl. +# The first three are alternatives to `Iterators.flatten`, `Iterators.flatmap`, +# and `Iterators.product`. -# This needs to use unrolled_reduce instead of unrolled_foldl in order for -# unrolled_product to be type-stable (otherwise, calling unrolled_flatmap inside -# of unrolled_foldl causes the compiler to hit a recursion limit). unrolled_flatten(values) = - unrolled_reduce((), values) do value, flattened_values - (flattened_values..., value...) - end + unrolled_reduce((tuple1, tuple2) -> (tuple1..., tuple2...), (), values) unrolled_flatmap(f::F, values) where {F} = unrolled_flatten(unrolled_map(f, values)) -unrolled_product(values) = - unrolled_foldl(values; init = ((),)) do product_values, value - unrolled_flatmap(product_values) do sub_values - unrolled_map(value) do sub_value - (sub_values..., sub_value) - end - end +unrolled_product(values1, values2) = + unrolled_flatmap(values1) do value1 + unrolled_map(value2 -> (value1, value2), values2) + end + +unrolled_unique(values) = + unrolled_reduce((), values) do value, unique_values + unrolled_in(value, unique_values) ? unique_values : + (unique_values..., value) end -unrolled_prodmap(f::F, values) where {F} = - unrolled_product(unrolled_map(f, values)) +unrolled_split(f::F, values) where {F} = + (unrolled_filter(f, values), unrolled_filter(value -> !f(value), values)) + +function unrolled_findonly(f::F, values) where {F} + filtered_values = unrolled_filter(f, values) + return length(filtered_values) == 1 ? filtered_values[1] : + error("unrolled_findonly requires that exactly 1 value makes f true") +end # The following functions are recursion-based alternatives to the generated # functions in Unrolled.jl. These should be used instead of their generated @@ -62,10 +45,11 @@ unrolled_prodmap(f::F, values) where {F} = @inline recursively_unrolled_map(f::F, values) where {F} = isempty(values) ? () : - (f(values[1]), recursively_unrolled_map(f, values[2:end])...) + (f(values[1]), recursively_unrolled_map(f, Base.tail(values))...) -recursively_unrolled_any(f::F, values) where {F} = - unrolled_any(identity, recursively_unrolled_map(f, values)) +@inline recursively_unrolled_any(f::F, values) where {F} = + isempty(values) ? false : + f(values[1]) || recursively_unrolled_any(f, Base.tail(values)) # This is required for type-stability as of Julia 1.9. if hasfield(Method, :recursion_relation) diff --git a/test/MatrixFields/field_names.jl b/test/MatrixFields/field_names.jl index 46b56ee069..c18ef0bbbb 100644 --- a/test/MatrixFields/field_names.jl +++ b/test/MatrixFields/field_names.jl @@ -75,9 +75,9 @@ const x = (; foo = Foo(0), a = (; b = 1, c = ((; d = 2), (;), ((), nothing)))) (@name(1), @name(2), @name(3)) end -@testset "FieldNameTree Unit Tests" begin - name_tree = MatrixFields.FieldNameTree(x) +const name_tree = MatrixFields.FieldNameTree(x) +@testset "FieldNameTree Unit Tests" begin @test_all MatrixFields.FieldNameTree(x) == name_tree @test_all MatrixFields.is_valid_name(@name(), name_tree) @@ -104,16 +104,18 @@ end ) end -@testset "FieldNameSet Unit Tests" begin - name_tree = MatrixFields.FieldNameTree(x) - vector_keys(names...) = MatrixFields.FieldVectorKeys(names, name_tree) - matrix_keys(name_pairs...) = - MatrixFields.FieldMatrixKeys(name_pairs, name_tree) +vector_keys(names...) = MatrixFields.FieldVectorKeys(names, name_tree) +matrix_keys(name_pairs...) = MatrixFields.FieldMatrixKeys(name_pairs, name_tree) - vector_keys_no_tree(names...) = MatrixFields.FieldVectorKeys(names) - matrix_keys_no_tree(name_pairs...) = - MatrixFields.FieldMatrixKeys(name_pairs) +vector_keys_no_tree(names...) = MatrixFields.FieldVectorKeys(names) +matrix_keys_no_tree(name_pairs...) = MatrixFields.FieldMatrixKeys(name_pairs) +drop_tree(set::MatrixFields.FieldVectorKeys) = + MatrixFields.FieldVectorKeys(set.values) +drop_tree(set::MatrixFields.FieldMatrixKeys) = + MatrixFields.FieldMatrixKeys(set.values) + +@testset "FieldNameSet Unit Tests" begin @testset "FieldNameSet Constructors" begin @test_throws "Invalid FieldNameSet value" vector_keys( @name(foo.invalid_name), @@ -146,16 +148,11 @@ end end v_set1 = vector_keys(@name(foo), @name(a.c)) - v_set1_no_tree = vector_keys_no_tree(@name(foo), @name(a.c)) m_set1 = matrix_keys((@name(foo), @name(a.c)), (@name(a.b), @name(foo))) - m_set1_no_tree = - matrix_keys_no_tree((@name(foo), @name(a.c)), (@name(a.b), @name(foo))) # Proper subsets of v_set1 and m_set1. v_set2 = vector_keys(@name(foo)) - v_set2_no_tree = vector_keys_no_tree(@name(foo)) m_set2 = matrix_keys((@name(foo), @name(a.c))) - m_set2_no_tree = matrix_keys_no_tree((@name(foo), @name(a.c))) # Subsets that cover v_set1 and m_set1. v_set3 = vector_keys( @@ -164,32 +161,18 @@ end @name(a.c.:(2)), @name(a.c.:(3)), ) - v_set3_no_tree = vector_keys_no_tree( - @name(foo.value), - @name(a.c.:(1)), - @name(a.c.:(2)), - @name(a.c.:(3)), - ) m_set3 = matrix_keys( (@name(foo.value), @name(a.c.:(1))), (@name(foo), @name(a.c.:(2))), (@name(foo), @name(a.c.:(3))), (@name(a.b), @name(foo.value)), ) - m_set3_no_tree = matrix_keys_no_tree( - (@name(foo.value), @name(a.c.:(1))), - (@name(foo), @name(a.c.:(2))), - (@name(foo), @name(a.c.:(3))), - (@name(a.b), @name(foo.value)), - ) # Sets that overlap with v_set1 and m_set1, but are neither subsets nor # supersets of those sets. Some of the values in m_set4 overlap with # those in m_set1, but are neither children nor parents of those values # (this is only possible with matrix keys). v_set4 = vector_keys(@name(a.b), @name(a.c.:(1)), @name(a.c.:(2))) - v_set4_no_tree = - vector_keys_no_tree(@name(a.b), @name(a.c.:(1)), @name(a.c.:(2))) m_set4 = matrix_keys( (@name(), @name(a.c.:(1))), (@name(foo.value), @name(foo)), @@ -197,22 +180,15 @@ end (@name(a), @name(foo.value)), (@name(a.c.:(3)), @name(a.c.:(3))), ) - m_set4_no_tree = matrix_keys_no_tree( - (@name(), @name(a.c.:(1))), - (@name(foo.value), @name(foo)), - (@name(foo.value), @name(a.c.:(2))), - (@name(a), @name(foo.value)), - (@name(a.c.:(3)), @name(a.c.:(3))), - ) @testset "FieldNameSet Basic Operations" begin @test string(v_set1) == "FieldVectorKeys(@name(foo), @name(a.c); )" - @test string(v_set1_no_tree) == + @test string(drop_tree(v_set1)) == "FieldVectorKeys(@name(foo), @name(a.c))" @test string(m_set1) == "FieldMatrixKeys((@name(foo), @name(a.c)), \ (@name(a.b), @name(foo)); )" - @test string(m_set1_no_tree) == "FieldMatrixKeys((@name(foo), \ + @test string(drop_tree(m_set1)) == "FieldMatrixKeys((@name(foo), \ @name(a.c)), (@name(a.b), @name(foo)))" @test_all map(name -> (name, name), v_set1) == @@ -223,12 +199,12 @@ end @test_all isnothing(foreach(name -> (name, name), v_set1)) @test_all isnothing(foreach(name_pair -> name_pair[1], m_set1)) - for set in (v_set1, v_set1_no_tree) + for set in (v_set1, drop_tree(v_set1)) @test_all @name(foo) in set @test_all !(@name(a.b) in set) @test_all !(@name(invalid_name) in set) end - for set in (m_set1, m_set1_no_tree) + for set in (m_set1, drop_tree(m_set1)) @test_all (@name(foo), @name(a.c)) in set @test_all !((@name(foo), @name(a.b)) in set) @test_all !((@name(foo), @name(invalid_name)) in set) @@ -236,62 +212,71 @@ end @test_all @name(foo.value) in v_set1 @test_all !(@name(foo.invalid_name) in v_set1) - @test_throws "FieldNameTree" @name(foo.value) in v_set1_no_tree - @test_throws "FieldNameTree" @name(foo.invalid_name) in v_set1_no_tree + @test_throws "FieldNameTree" @name(foo.value) in drop_tree(v_set1) + @test_throws "FieldNameTree" @name(foo.invalid_name) in + drop_tree(v_set1) @test_all (@name(foo.value), @name(a.c)) in m_set1 @test_all !((@name(foo.invalid_name), @name(a.c)) in m_set1) @test_throws "FieldNameTree" (@name(foo.value), @name(a.c)) in - m_set1_no_tree + drop_tree(m_set1) @test_throws "FieldNameTree" (@name(foo.invalid_name), @name(a.c)) in - m_set1_no_tree + drop_tree(m_set1) + end + @testset "FieldNameSet Complement Sets" begin @test_all MatrixFields.set_complement(v_set1) == vector_keys_no_tree(@name(a.b)) - @test_throws "FieldNameTree" MatrixFields.set_complement(v_set1_no_tree) + @test_all MatrixFields.set_complement(v_set2) == + vector_keys_no_tree(@name(a)) + @test_all MatrixFields.set_complement(v_set3) == + vector_keys_no_tree(@name(a.b)) + @test_all MatrixFields.set_complement(v_set4) == + vector_keys_no_tree(@name(foo), @name(a.c.:(3))) @test_all MatrixFields.set_complement(m_set1) == matrix_keys_no_tree( (@name(foo), @name(foo)), (@name(foo), @name(a.b)), - (@name(a.c), @name()), - (@name(a.b), @name(a)), + (@name(a), @name(a)), + (@name(a.c), @name(foo)), + ) + @test_all MatrixFields.set_complement(m_set2) == matrix_keys_no_tree( + (@name(foo), @name(foo)), + (@name(foo), @name(a.b)), + (@name(a), @name(foo)), + (@name(a), @name(a)), + ) + @test_all MatrixFields.set_complement(m_set3) == matrix_keys_no_tree( + (@name(foo), @name(foo)), + (@name(foo), @name(a.b)), + (@name(a), @name(a)), + (@name(a.c), @name(foo)), ) - @test_throws "FieldNameTree" MatrixFields.set_complement(m_set1_no_tree) - - @test_all MatrixFields.set_complement(v_set4) == - vector_keys_no_tree(@name(foo), @name(a.c.:(3))) - @test_throws "FieldNameTree" MatrixFields.set_complement(v_set4_no_tree) - @test_all MatrixFields.set_complement(m_set4) == matrix_keys_no_tree( - (@name(), @name(a.b)), + (@name(foo), @name(a.b)), (@name(foo), @name(a.c.:(3))), + (@name(a), @name(a.b)), (@name(a), @name(a.c.:(2))), (@name(a.b), @name(a.c.:(3))), (@name(a.c.:(1)), @name(a.c.:(3))), (@name(a.c.:(2)), @name(a.c.:(3))), ) - @test_throws "FieldNameTree" MatrixFields.set_complement(m_set4_no_tree) - end - @testset "FieldNameSet Binary Set Operations" begin - for set1 in (v_set1, v_set1_no_tree, m_set1, m_set1_no_tree) - @test_all set1 == set1 - @test_all issubset(set1, set1) - @test_all is_subset_that_covers_set(set1, set1) - @test_all intersect(set1, set1) == set1 - @test_all union(set1, set1) == set1 - @test_all isempty(setdiff(set1, set1)) + for set in (drop_tree(v_set1), drop_tree(m_set1)) + @test_throws "FieldNameTree" MatrixFields.set_complement(set) end + end + @testset "FieldNameSet Binary Set Operations" begin for (set1, set2) in ( (v_set1, v_set2), - (v_set1, v_set2_no_tree), - (v_set1_no_tree, v_set2), - (v_set1_no_tree, v_set2_no_tree), + (v_set1, drop_tree(v_set2)), + (drop_tree(v_set1), v_set2), + (drop_tree(v_set1), drop_tree(v_set2)), (m_set1, m_set2), - (m_set1, m_set2_no_tree), - (m_set1_no_tree, m_set2), - (m_set1_no_tree, m_set2_no_tree), + (m_set1, drop_tree(m_set2)), + (drop_tree(m_set1), m_set2), + (drop_tree(m_set1), drop_tree(m_set2)), ) @test_all set1 != set2 @test_all !issubset(set1, set2) && issubset(set2, set1) @@ -310,11 +295,11 @@ end for (set1, set3) in ( (v_set1, v_set3), - (v_set1, v_set3_no_tree), - (v_set1_no_tree, v_set3), + (v_set1, drop_tree(v_set3)), + (drop_tree(v_set1), v_set3), (m_set1, m_set3), - (m_set1, m_set3_no_tree), - (m_set1_no_tree, m_set3), + (m_set1, drop_tree(m_set3)), + (drop_tree(m_set1), m_set3), ) @test_all set1 != set3 @test_all !issubset(set1, set3) && issubset(set3, set1) @@ -326,8 +311,10 @@ end isempty(setdiff(set3, set1)) end - for (set1, set3) in - ((v_set1_no_tree, v_set3_no_tree), (m_set1_no_tree, m_set3_no_tree)) + for (set1, set3) in ( + (drop_tree(v_set1), drop_tree(v_set3)), + (drop_tree(m_set1), drop_tree(m_set3)), + ) @test_all set1 != set3 @test_all !issubset(set1, set3) @test_all !is_subset_that_covers_set(set1, set3) @@ -343,11 +330,11 @@ end for (set1, set4) in ( (v_set1, v_set4), - (v_set1, v_set4_no_tree), - (v_set1_no_tree, v_set4), + (v_set1, drop_tree(v_set4)), + (drop_tree(v_set1), v_set4), (m_set1, m_set4), - (m_set1, m_set4_no_tree), - (m_set1_no_tree, m_set4), + (m_set1, drop_tree(m_set4)), + (drop_tree(m_set1), m_set4), ) @test_all set1 != set4 @test_all !issubset(set1, set4) && !issubset(set4, set1) @@ -400,8 +387,10 @@ end end end - for (set1, set4) in - ((v_set1_no_tree, v_set4_no_tree), (m_set1_no_tree, m_set4_no_tree)) + for (set1, set4) in ( + (drop_tree(v_set1), drop_tree(v_set4)), + (drop_tree(m_set1), drop_tree(m_set4)), + ) @test_all set1 != set4 @test_all !issubset(set1, set4) && !issubset(set4, set1) @test_all !is_subset_that_covers_set(set1, set4) && @@ -607,15 +596,15 @@ end # With one exception, none of the following operations require a # FieldNameTree. - @test_all MatrixFields.corresponding_matrix_keys(v_set1_no_tree) == + @test_all MatrixFields.corresponding_matrix_keys(drop_tree(v_set1)) == matrix_keys_no_tree( (@name(foo), @name(foo)), (@name(a.c), @name(a.c)), ) @test_all MatrixFields.cartesian_product( - v_set1_no_tree, - v_set4_no_tree, + drop_tree(v_set1), + drop_tree(v_set4), ) == matrix_keys_no_tree( (@name(foo), @name(a.b)), (@name(foo), @name(a.c.:(1))), @@ -625,7 +614,7 @@ end (@name(a.c), @name(a.c.:(2))), ) - @test_all MatrixFields.matrix_row_keys(m_set1_no_tree) == + @test_all MatrixFields.matrix_row_keys(drop_tree(m_set1)) == vector_keys_no_tree(@name(foo), @name(a.b)) @test_all MatrixFields.matrix_row_keys(m_set4) == vector_keys_no_tree( @@ -636,10 +625,10 @@ end @name(a.c.:(3)) ) @test_throws "FieldNameTree" MatrixFields.matrix_row_keys( - m_set4_no_tree, + drop_tree(m_set4), ) - @test_all MatrixFields.matrix_off_diagonal_keys(m_set4_no_tree) == + @test_all MatrixFields.matrix_off_diagonal_keys(drop_tree(m_set4)) == matrix_keys_no_tree( (@name(), @name(a.c.:(1))), (@name(foo.value), @name(foo)), @@ -647,7 +636,7 @@ end (@name(a), @name(foo.value)), ) - @test_all MatrixFields.matrix_diagonal_keys(m_set4_no_tree) == + @test_all MatrixFields.matrix_diagonal_keys(drop_tree(m_set4)) == matrix_keys_no_tree( (@name(foo.value), @name(foo.value)), (@name(a.c.:(1)), @name(a.c.:(1))),