diff --git a/Project.toml b/Project.toml index 18acf21e44..413f9cb8f4 100644 --- a/Project.toml +++ b/Project.toml @@ -28,6 +28,7 @@ Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" +Unrolled = "9602ed7d-8fef-5bc8-8597-8f21381861e8" [compat] Adapt = "3" @@ -50,6 +51,7 @@ RootSolvers = "0.3, 0.4" Static = "0.4, 0.5, 0.6, 0.7, 0.8" StaticArrays = "1" UnPack = "1" +Unrolled = "0.1" julia = "1.8" [extras] diff --git a/benchmarks/bickleyjet/Manifest.toml b/benchmarks/bickleyjet/Manifest.toml index fc0a8268ad..48d90b1cd0 100644 --- a/benchmarks/bickleyjet/Manifest.toml +++ b/benchmarks/bickleyjet/Manifest.toml @@ -157,7 +157,7 @@ uuid = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d" version = "0.5.5" [[deps.ClimaCore]] -deps = ["Adapt", "BandedMatrices", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "LinearAlgebra", "PkgVersion", "RecursiveArrayTools", "Requires", "RootSolvers", "SparseArrays", "Static", "StaticArrays", "Statistics", "UnPack"] +deps = ["Adapt", "BandedMatrices", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "LinearAlgebra", "PkgVersion", "RecursiveArrayTools", "Requires", "RootSolvers", "SparseArrays", "Static", "StaticArrays", "Statistics", "UnPack", "Unrolled"] path = "../.." uuid = "d414da3d-4745-48bb-8d80-42e94e092884" version = "0.10.55" @@ -1238,6 +1238,12 @@ git-tree-sha1 = "e2d817cc500e960fdbafcf988ac8436ba3208bfd" uuid = "45397f5d-5981-4c77-b2b3-fc36d6e9b728" version = "1.6.3" +[[deps.Unrolled]] +deps = ["MacroTools"] +git-tree-sha1 = "6cc9d682755680e0f0be87c56392b7651efc2c7b" +uuid = "9602ed7d-8fef-5bc8-8597-8f21381861e8" +version = "0.1.5" + [[deps.UnsafeAtomics]] git-tree-sha1 = "6331ac3440856ea1988316b46045303bef658278" uuid = "013be700-e6cd-48c3-b4a1-df204f14c38f" diff --git a/docs/Manifest.toml b/docs/Manifest.toml index 622c3d83e4..cf667d45da 100644 --- a/docs/Manifest.toml +++ b/docs/Manifest.toml @@ -274,7 +274,7 @@ uuid = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d" version = "0.5.5" [[deps.ClimaCore]] -deps = ["Adapt", "BandedMatrices", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "LinearAlgebra", "PkgVersion", "RecursiveArrayTools", "Requires", "RootSolvers", "SparseArrays", "Static", "StaticArrays", "Statistics", "UnPack"] +deps = ["Adapt", "BandedMatrices", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "LinearAlgebra", "PkgVersion", "RecursiveArrayTools", "Requires", "RootSolvers", "SparseArrays", "Static", "StaticArrays", "Statistics", "UnPack", "Unrolled"] path = ".." uuid = "d414da3d-4745-48bb-8d80-42e94e092884" version = "0.10.55" @@ -2492,6 +2492,12 @@ git-tree-sha1 = "e2d817cc500e960fdbafcf988ac8436ba3208bfd" uuid = "45397f5d-5981-4c77-b2b3-fc36d6e9b728" version = "1.6.3" +[[deps.Unrolled]] +deps = ["MacroTools"] +git-tree-sha1 = "6cc9d682755680e0f0be87c56392b7651efc2c7b" +uuid = "9602ed7d-8fef-5bc8-8597-8f21381861e8" +version = "0.1.5" + [[deps.UnsafeAtomics]] git-tree-sha1 = "6331ac3440856ea1988316b46045303bef658278" uuid = "013be700-e6cd-48c3-b4a1-df204f14c38f" diff --git a/examples/Manifest.toml b/examples/Manifest.toml index 56b04e6b5d..b809ee2c27 100644 --- a/examples/Manifest.toml +++ b/examples/Manifest.toml @@ -238,7 +238,7 @@ uuid = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d" version = "0.5.5" [[deps.ClimaCore]] -deps = ["Adapt", "BandedMatrices", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "LinearAlgebra", "PkgVersion", "RecursiveArrayTools", "Requires", "RootSolvers", "SparseArrays", "Static", "StaticArrays", "Statistics", "UnPack"] +deps = ["Adapt", "BandedMatrices", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "LinearAlgebra", "PkgVersion", "RecursiveArrayTools", "Requires", "RootSolvers", "SparseArrays", "Static", "StaticArrays", "Statistics", "UnPack", "Unrolled"] path = ".." uuid = "d414da3d-4745-48bb-8d80-42e94e092884" version = "0.10.55" @@ -2031,6 +2031,12 @@ git-tree-sha1 = "e2d817cc500e960fdbafcf988ac8436ba3208bfd" uuid = "45397f5d-5981-4c77-b2b3-fc36d6e9b728" version = "1.6.3" +[[deps.Unrolled]] +deps = ["MacroTools"] +git-tree-sha1 = "6cc9d682755680e0f0be87c56392b7651efc2c7b" +uuid = "9602ed7d-8fef-5bc8-8597-8f21381861e8" +version = "0.1.5" + [[deps.UnsafeAtomics]] git-tree-sha1 = "6331ac3440856ea1988316b46045303bef658278" uuid = "013be700-e6cd-48c3-b4a1-df204f14c38f" diff --git a/perf/Manifest.toml b/perf/Manifest.toml index 89288bb4d7..3aea42a040 100644 --- a/perf/Manifest.toml +++ b/perf/Manifest.toml @@ -217,7 +217,7 @@ uuid = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d" version = "0.5.5" [[deps.ClimaCore]] -deps = ["Adapt", "BandedMatrices", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "LinearAlgebra", "PkgVersion", "RecursiveArrayTools", "Requires", "RootSolvers", "SparseArrays", "Static", "StaticArrays", "Statistics", "UnPack"] +deps = ["Adapt", "BandedMatrices", "BlockArrays", "CUDA", "ClimaComms", "CubedSphere", "DataStructures", "DocStringExtensions", "ForwardDiff", "GaussQuadrature", "GilbertCurves", "HDF5", "InteractiveUtils", "IntervalSets", "LinearAlgebra", "PkgVersion", "RecursiveArrayTools", "Requires", "RootSolvers", "SparseArrays", "Static", "StaticArrays", "Statistics", "UnPack", "Unrolled"] path = ".." uuid = "d414da3d-4745-48bb-8d80-42e94e092884" version = "0.10.55" @@ -2077,6 +2077,12 @@ git-tree-sha1 = "e2d817cc500e960fdbafcf988ac8436ba3208bfd" uuid = "45397f5d-5981-4c77-b2b3-fc36d6e9b728" version = "1.6.3" +[[deps.Unrolled]] +deps = ["MacroTools"] +git-tree-sha1 = "6cc9d682755680e0f0be87c56392b7651efc2c7b" +uuid = "9602ed7d-8fef-5bc8-8597-8f21381861e8" +version = "0.1.5" + [[deps.UnsafeAtomics]] git-tree-sha1 = "6331ac3440856ea1988316b46045303bef658278" uuid = "013be700-e6cd-48c3-b4a1-df204f14c38f" diff --git a/src/MatrixFields/MatrixFields.jl b/src/MatrixFields/MatrixFields.jl index 957c75de91..0f59f8d9cd 100644 --- a/src/MatrixFields/MatrixFields.jl +++ b/src/MatrixFields/MatrixFields.jl @@ -49,6 +49,9 @@ import StaticArrays: SMatrix, SVector import BandedMatrices: BandedMatrix, band, _BandedMatrix import ClimaComms +import Unrolled: unrolled_foreach, unrolled_map, unrolled_reduce +import Unrolled: unrolled_in, unrolled_any, unrolled_all, unrolled_filter + import ..Utilities: PlusHalf, half import ..RecursiveApply: rmap, rmaptype, rpromote_type, rzero, rconvert, radd, rsub, rmul, rdiv diff --git a/src/MatrixFields/field_name.jl b/src/MatrixFields/field_name.jl index f027b1c1e1..7f159339c7 100644 --- a/src/MatrixFields/field_name.jl +++ b/src/MatrixFields/field_name.jl @@ -73,9 +73,10 @@ is_child_name( ::FieldName{parent_name_chain}, ) where {child_name_chain, parent_name_chain} = length(child_name_chain) >= length(parent_name_chain) && - child_name_chain[1:length(parent_name_chain)] == parent_name_chain + unrolled_take(child_name_chain, Val(length(parent_name_chain))) == + parent_name_chain -names_are_overlapping(name1, name2) = +is_overlapping_name(name1, name2) = is_child_name(name1, name2) || is_child_name(name2, name1) extract_internal_name( @@ -83,8 +84,9 @@ extract_internal_name( parent_name::FieldName{parent_name_chain}, ) where {child_name_chain, parent_name_chain} = is_child_name(child_name, parent_name) ? - FieldName(child_name_chain[(length(parent_name_chain) + 1):end]...) : - error("$child_name is not a child name of $parent_name") + FieldName( + unrolled_drop(child_name_chain, Val(length(parent_name_chain)))..., + ) : error("$child_name is not a child name of $parent_name") append_internal_name( ::FieldName{name_chain}, @@ -118,41 +120,41 @@ struct FieldNameTreeNode{V <: FieldName, S <: NTuple{<:Any, FieldNameTree}} <: subtrees::S end -FieldNameTree(x) = make_subtree_at_name(x, @name()) -function make_subtree_at_name(x, name) +FieldNameTree(x) = subtree_at_name(x, @name()) +function subtree_at_name(x, name) internal_names = top_level_names(get_field(x, name)) - isempty(internal_names) && return FieldNameTreeLeaf(name) - subsubtrees = unrolled_map(internal_names) do internal_name - make_subtree_at_name(x, append_internal_name(name, internal_name)) + return if isempty(internal_names) + FieldNameTreeLeaf(name) + else + subsubtrees_at_name = + recursively_unrolled_map(internal_names) do internal_name + subtree_at_name(x, append_internal_name(name, internal_name)) + end + FieldNameTreeNode(name, subsubtrees_at_name) end - return FieldNameTreeNode(name, subsubtrees) end -is_valid_name(name, tree::FieldNameTreeLeaf) = name == tree.name -is_valid_name(name, tree::FieldNameTreeNode) = +is_valid_name(name, tree) = name == tree.name || - is_child_name(name, tree.name) && - unrolled_any(subtree -> is_valid_name(name, subtree), tree.subtrees) + tree isa FieldNameTreeNode && + recursively_unrolled_any(tree.subtrees) do subtree + is_valid_name(name, subtree) + end function child_names(name, tree) + is_valid_name(name, tree) || error("$name is not a valid name") subtree = get_subtree_at_name(name, tree) - subtree isa FieldNameTreeNode || - error("FieldNameTree does not contain any child names for $name") + subtree isa FieldNameTreeNode || error("$name does not have child names") return unrolled_map(subsubtree -> subsubtree.name, subtree.subtrees) end -get_subtree_at_name(name, tree::FieldNameTreeLeaf) = - name == tree.name ? tree : - error("FieldNameTree does not contain the name $name") -get_subtree_at_name(name, tree::FieldNameTreeNode) = +get_subtree_at_name(name, tree) = if name == tree.name tree - elseif is_valid_name(name, tree) - subtree_that_contains_name = unrolled_findonly(tree.subtrees) do subtree - is_child_name(name, subtree.name) - end - get_subtree_at_name(name, subtree_that_contains_name) else - error("FieldNameTree does not contain the name $name") + subtree = unrolled_findonly(tree.subtrees) do subtree + is_valid_name(name, subtree) + end + get_subtree_at_name(name, subtree) end ################################################################################ @@ -175,7 +177,7 @@ if hasfield(Method, :recursion_relation) for m in methods(wrapped_prop_names) m.recursion_relation = dont_limit end - for m in methods(make_subtree_at_name) + for m in methods(subtree_at_name) m.recursion_relation = dont_limit end for m in methods(is_valid_name) diff --git a/src/MatrixFields/field_name_dict.jl b/src/MatrixFields/field_name_dict.jl index ed26b40822..f540f1a0e9 100644 --- a/src/MatrixFields/field_name_dict.jl +++ b/src/MatrixFields/field_name_dict.jl @@ -66,8 +66,8 @@ const FieldMatrixBroadcasted = FieldNameDict{ dict_type(::FieldNameDict{T1, T2}) where {T1, T2} = FieldNameDict{T1, T2} function Base.show(io::IO, dict::FieldNameDict) - strings = map((key, value) -> " $key => $value", pairs(dict)) - print(io, "$(dict_type(dict))($(join(strings, ",\n")))") + strings = map(pair -> "\n$(pair[1]) => $(pair[2])", pairs(dict)) + print(io, "$(dict_type(dict))($(join(strings, ","))\n)") end Base.keys(dict::FieldNameDict) = dict.keys @@ -75,9 +75,7 @@ Base.keys(dict::FieldNameDict) = dict.keys Base.values(dict::FieldNameDict) = dict.entries Base.pairs(dict::FieldNameDict) = - unrolled_map(unrolled_zip(keys(dict).values, values(dict))) do key_entry_tup - key_entry_tup[1] => key_entry_tup[2] - end + unrolled_map((key, value) -> key => value, keys(dict).values, values(dict)) Base.length(dict::FieldNameDict) = length(keys(dict)) @@ -118,18 +116,19 @@ function get_internal_entry( # See note above matrix_product_keys in field_name_set.jl for more details. T = eltype(eltype(entry)) if name_pair == (@name(), @name()) - # multiplication case 1, either argument entry - elseif broadcasted_has_field(T, name_pair[1]) && name_pair[2] == @name() + elseif name_pair[1] == name_pair[2] + # multiplication case 3 or 4, first argument + @assert T <: SingleValue && !broadcasted_has_field(T, name_pair[1]) + entry + elseif name_pair[2] == @name() # multiplication case 2 or 4, second argument + @assert broadcasted_has_field(T, name_pair[1]) Base.broadcasted(entry) do matrix_row map(matrix_row) do matrix_row_entry broadcasted_get_field(matrix_row_entry, name_pair[1]) end end # Note: This assumes that the entry is in a FieldMatrixBroadcasted. - elseif T <: SingleValue && name_pair[1] == name_pair[2] - # multiplication case 3 or 4, first argument - entry else unsupported_internal_entry_error(entry, name_pair) end @@ -249,7 +248,7 @@ Base.Broadcast.broadcasted( arg3, args..., ) = - unrolled_foldl((arg1, arg2, arg3, args...)) do arg1′, arg2′ + foldl((arg1, arg2, arg3, args...)) do arg1′, arg2′ Base.Broadcast.broadcasted(f, arg1′, arg2′) end diff --git a/src/MatrixFields/field_name_set.jl b/src/MatrixFields/field_name_set.jl index 1f28960bff..6dc35224a4 100644 --- a/src/MatrixFields/field_name_set.jl +++ b/src/MatrixFields/field_name_set.jl @@ -79,63 +79,53 @@ end Base.:(==)(set1::FieldNameSet, set2::FieldNameSet) = issubset(set1, set2) && issubset(set2, set1) -function Base.intersect(set1::FieldNameSet{T}, set2::FieldNameSet{T}) where {T} +function Base.union(set1::FieldNameSet{T}, set2::FieldNameSet{T}) where {T} name_tree = combine_name_trees(set1.name_tree, set2.name_tree) - values1′, values2′ = set1.values, set2.values - values1, values2 = non_overlapping_values(values1′, values2′, name_tree) - result_values = unrolled_filter(values2) do value - unrolled_any(isequal(value), values1) - end + result_values = union_values(set1.values, set2.values, name_tree) return FieldNameSet{T}(result_values, name_tree) end -function Base.union(set1::FieldNameSet{T}, set2::FieldNameSet{T}) where {T} +function Base.intersect(set1::FieldNameSet{T}, set2::FieldNameSet{T}) where {T} name_tree = combine_name_trees(set1.name_tree, set2.name_tree) - values1′, values2′ = set1.values, set2.values - values1, values2 = non_overlapping_values(values1′, values2′, name_tree) - values2_minus_values1 = unrolled_filter(values2) do value - !unrolled_any(isequal(value), values1) + all_values = union_values(set1.values, set2.values, name_tree) + result_values = unrolled_filter(all_values) do value + is_value_in_set(value, set1.values, name_tree) && + is_value_in_set(value, set2.values, name_tree) end - result_values = (values1..., values2_minus_values1...) return FieldNameSet{T}(result_values, name_tree) end function Base.setdiff(set1::FieldNameSet{T}, set2::FieldNameSet{T}) where {T} name_tree = combine_name_trees(set1.name_tree, set2.name_tree) - set2_complement_values = set_complement_values(T, set2.values, name_tree) - set2_complement = FieldNameSet{T}(set2_complement_values, name_tree) - return intersect(set1, set2_complement) + all_values = union_values(set1.values, set2.values, name_tree) + result_values = unrolled_filter(all_values) do value + !is_value_in_set(value, set2.values, name_tree) + end + return FieldNameSet{T}(result_values, name_tree) end -set_string(set) = - length(set) == 2 ? join(set.values, " and ") : - join(set.values, ", ", ", and ") +set_string(set) = values_string(set.values) + +set_complement(set) = setdiff(universal_set(set.name_tree, eltype(set)), set) is_subset_that_covers_set(set1, set2) = issubset(set1, set2) && isempty(setdiff(set2, set1)) -function set_complement(set::FieldNameSet{T}) where {T} - result_values = set_complement_values(T, set.values, set.name_tree) - return FieldNameSet{T}(result_values, set.name_tree) -end - function corresponding_matrix_keys(set::FieldVectorKeys) result_values = unrolled_map(name -> (name, name), set.values) return FieldMatrixKeys(result_values, set.name_tree) end -function cartesian_product(set1::FieldVectorKeys, set2::FieldVectorKeys) - name_tree = combine_name_trees(set1.name_tree, set2.name_tree) - result_values = unrolled_mapflatten(set1.values) do row_name - unrolled_map(col_name -> (row_name, col_name), set2.values) - end +function cartesian_product(row_set::FieldVectorKeys, col_set::FieldVectorKeys) + name_tree = combine_name_trees(row_set.name_tree, col_set.name_tree) + result_values = unrolled_product(row_set.values, col_set.values) return FieldMatrixKeys(result_values, name_tree) end function matrix_row_keys(set::FieldMatrixKeys) result_values′ = unrolled_map(name_pair -> name_pair[1], set.values) result_values = - unique_and_non_overlapping_values(result_values′, set.name_tree) + remove_duplicates_and_overlaps(result_values′, set.name_tree) return FieldVectorKeys(result_values, set.name_tree) end @@ -147,12 +137,16 @@ end function matrix_diagonal_keys(set::FieldMatrixKeys) result_values′ = unrolled_filter(set.values) do name_pair - names_are_overlapping(name_pair[1], name_pair[2]) + is_overlapping_name(name_pair[1], name_pair[2]) end result_values = unrolled_map(result_values′) do name_pair - name_pair[1] == name_pair[2] ? name_pair : - is_child_value(name_pair[1], name_pair[2]) ? - (name_pair[1], name_pair[1]) : (name_pair[2], name_pair[2]) + if name_pair[1] == name_pair[2] + name_pair + elseif is_child_value(name_pair[1], name_pair[2]) + (name_pair[1], name_pair[1]) + else + (name_pair[2], name_pair[2]) + end end return FieldMatrixKeys(result_values, set.name_tree) end @@ -181,10 +175,10 @@ because we cannot extract internal columns from FieldNameDict entries. =# function matrix_product_keys(set1::FieldMatrixKeys, set2::FieldNameSet) name_tree = combine_name_trees(set1.name_tree, set2.name_tree) - result_values′ = unrolled_mapflatten(set1.values) do name_pair1 + result_values′ = unrolled_flatmap(set1.values) do name_pair1 overlapping_set2_values = unrolled_filter(set2.values) do value2 row_name2 = eltype(set2) <: FieldName ? value2 : value2[1] - names_are_overlapping(name_pair1[2], row_name2) + is_overlapping_name(name_pair1[2], row_name2) end unrolled_map(overlapping_set2_values) do value2 row_name2 = eltype(set2) <: FieldName ? value2 : value2[1] @@ -200,8 +194,8 @@ function matrix_product_keys(set1::FieldMatrixKeys, set2::FieldNameSet) end end end - result_values = unique_and_non_overlapping_values(result_values′, name_tree) - # Note: the modification of result_values may trigger multiplication case 4. + # Note: Modifying result_values here can trigger multiplication case 4. + result_values = remove_duplicates_and_overlaps(result_values′, name_tree) return FieldNameSet{eltype(set2)}(result_values, name_tree) end function summand_names_for_matrix_product( @@ -212,14 +206,14 @@ function summand_names_for_matrix_product( product_row_name = eltype(set2) <: FieldName ? product_key : product_key[1] name_tree = combine_name_trees(set1.name_tree, set2.name_tree) overlapping_set1_values = unrolled_filter(set1.values) do name_pair1 - names_are_overlapping(product_row_name, name_pair1[1]) + is_overlapping_name(product_row_name, name_pair1[1]) end - result_values = unrolled_mapflatten(overlapping_set1_values) do name_pair1 + result_values = unrolled_flatmap(overlapping_set1_values) do name_pair1 overlapping_set2_values = unrolled_filter(set2.values) do value2 row_name2 = eltype(set2) <: FieldName ? value2 : value2[1] - names_are_overlapping(name_pair1[2], row_name2) && ( + is_overlapping_name(name_pair1[2], row_name2) && ( eltype(set2) <: FieldName || - names_are_overlapping(product_key[2], value2[2]) + is_overlapping_name(product_key[2], value2[2]) ) end unrolled_map(overlapping_set2_values) do value2 @@ -259,24 +253,23 @@ check_values(values, name_tree) = (isnothing(name_tree) || is_valid_value(value, name_tree)) || error( "Invalid FieldNameSet value: $value is incompatible with name_tree", ) - duplicate_values = unrolled_filter(isequal(value), values) - length(duplicate_values) == 1 || error( - "Duplicate FieldNameSet values: $(length(duplicate_values)) copies \ - of $value have been passed to a FieldNameSet constructor", + num_duplicate_values = length(unrolled_filter(isequal(value), values)) + num_duplicate_values == 1 || error( + "Duplicate FieldNameSet values: $num_duplicate_values copies of \ + $value have been passed to a FieldNameSet constructor", ) overlapping_values = unrolled_filter(values) do value′ - value != value′ && values_are_overlapping(value, value′) - end - if !isempty(overlapping_values) - overlapping_values_string = - length(overlapping_values) == 2 ? - join(overlapping_values, " or ") : - join(overlapping_values, ", ", ", or ") - error("Overlapping FieldNameSet values: $value cannot be in the \ - same FieldNameSet as $overlapping_values_string") + value′ != value && is_overlapping_value(value, value′) end + isempty(overlapping_values) || error( + "Overlapping FieldNameSet values: $value cannot be in the same \ + FieldNameSet as $(values_string(overlapping_values))", + ) end +values_string(values) = + length(values) == 2 ? join(values, " and ") : join(values, ", ", ", and ") + combine_name_trees(::Nothing, ::Nothing) = nothing combine_name_trees(name_tree1, ::Nothing) = name_tree1 combine_name_trees(::Nothing, name_tree2) = name_tree2 @@ -285,158 +278,166 @@ combine_name_trees(name_tree1, name_tree2) = error("Mismatched FieldNameTrees: The ability to combine different \ FieldNameTrees has not been implemented") +universal_set(::Nothing, ::Type{FieldName}) = error( + "Missing FieldNameTree: Cannot compute complement of FieldNameSet without \ + a FieldNameTree", +) +universal_set(name_tree, ::Type{FieldName}) = + FieldVectorKeys(child_names(@name(), name_tree), name_tree) +function universal_set(name_tree, ::Type{FieldNamePair}) + row_set = universal_set(name_tree, FieldName) + return cartesian_product(row_set, row_set) +end + is_valid_value(name::FieldName, name_tree) = is_valid_name(name, name_tree) is_valid_value(name_pair::FieldNamePair, name_tree) = is_valid_name(name_pair[1], name_tree) && is_valid_name(name_pair[2], name_tree) -values_are_overlapping(name1::FieldName, name2::FieldName) = - names_are_overlapping(name1, name2) -values_are_overlapping(name_pair1::FieldNamePair, name_pair2::FieldNamePair) = - names_are_overlapping(name_pair1[1], name_pair2[1]) && - names_are_overlapping(name_pair1[2], name_pair2[2]) - is_child_value(name1::FieldName, name2::FieldName) = is_child_name(name1, name2) is_child_value(name_pair1::FieldNamePair, name_pair2::FieldNamePair) = is_child_name(name_pair1[1], name_pair2[1]) && is_child_name(name_pair1[2], name_pair2[2]) +is_overlapping_value(name1::FieldName, name2::FieldName) = + is_overlapping_name(name1, name2) +is_overlapping_value(name_pair1::FieldNamePair, name_pair2::FieldNamePair) = + is_overlapping_name(name_pair1[1], name_pair2[1]) && + is_overlapping_name(name_pair1[2], name_pair2[2]) + is_value_in_set(value, values, name_tree) = - if unrolled_any(isequal(value), values) - true - elseif unrolled_any(value′ -> is_child_value(value, value′), values) - isnothing(name_tree) && error( - "Cannot check if $value is in FieldNameSet without a FieldNameTree", - ) - is_valid_value(value, name_tree) - else - false - end + unrolled_in(value, values) || + unrolled_any(value′ -> is_child_value(value, value′), values) && ( + isnothing(name_tree) ? + error( + "Missing FieldNameTree: Cannot check if $value is in FieldNameSet \ + without a FieldNameTree", + ) : is_valid_value(value, name_tree) + ) -function non_overlapping_values(values1, values2, name_tree) - new_values1 = unrolled_mapflatten(values1) do value - value_or_non_overlapping_children(value, values2, name_tree) - end - new_values2 = unrolled_mapflatten(values2) do value - value_or_non_overlapping_children(value, values1, name_tree) - end - if eltype(values1) <: FieldName - new_values1, new_values2 - else - # Repeat the above operation to handle complex matrix key overlaps. - new_values1′ = unrolled_mapflatten(new_values1) do value - value_or_non_overlapping_children(value, new_values2, name_tree) - end - new_values2′ = unrolled_mapflatten(new_values2) do value - value_or_non_overlapping_children(value, new_values1, name_tree) +function remove_duplicates_and_overlaps(values, name_tree) + unique_values = unrolled_unique(values) + overlapping_values, non_overlapping_values = + unrolled_split(unique_values) do value + unrolled_any(unique_values) do value′ + value != value′ && is_overlapping_value(value, value′) + end end - return new_values1′, new_values2′ + isempty(overlapping_values) && return unique_values + isnothing(name_tree) && + error("Missing FieldNameTree: Cannot eliminate overlaps among \ + $(values_string(overlapping_values)) without a FieldNameTree") + expanded_overlapping_values = unrolled_flatmap(overlapping_values) do value + values_overlapping_with_value = + unrolled_filter(overlapping_values) do value′ + value != value′ && is_overlapping_value(value, value′) + end + expand_child_values(value, values_overlapping_with_value, name_tree) end + no_longer_overlapping_values = + remove_duplicates_and_overlaps(expanded_overlapping_values, name_tree) + return (non_overlapping_values..., no_longer_overlapping_values...) end -function unique_and_non_overlapping_values(values, name_tree) - new_values = unrolled_mapflatten(values) do value - value_or_non_overlapping_children(value, values, name_tree) - end - return unrolled_unique(new_values) +# The function union_values(values1, values2, name_tree) generates the same +# result as remove_duplicates_and_overlaps((values1..., values2...), name_tree), +# but it is slightly more efficient (and faster to compile) because it makes use +# of the fact that values1 == remove_duplicates_and_overlaps(values1, name_tree) +# and values2 == remove_duplicates_and_overlaps(values2, name_tree). +function union_values(values1, values2, name_tree) + unique_values2 = + unrolled_filter(value2 -> !unrolled_in(value2, values1), values2) + overlapping_values1, non_overlapping_values1 = + unrolled_split(values1) do value1 + unrolled_any(unique_values2) do value2 + is_overlapping_value(value1, value2) + end + end + isempty(overlapping_values1) && return (values1..., unique_values2...) + overlapping_values2, non_overlapping_values2 = + unrolled_split(unique_values2) do value2 + unrolled_any(values1) do value1 + is_overlapping_value(value1, value2) + end + end + isnothing(name_tree) && error( + "Missing FieldNameTree: Cannot eliminate overlaps between \ + $overlapping_values1 and $overlapping_values2 without a FieldNameTree", + ) + expanded_overlapping_values1 = + unrolled_flatmap(overlapping_values1) do value1 + values2_overlapping_value1 = + unrolled_filter(overlapping_values2) do value2 + is_overlapping_value(value1, value2) + end + expand_child_values(value1, values2_overlapping_value1, name_tree) + end + expanded_overlapping_values2 = + unrolled_flatmap(overlapping_values2) do value2 + values1_overlapping_value2 = + unrolled_filter(overlapping_values1) do value1 + is_overlapping_value(value1, value2) + end + expand_child_values(value2, values1_overlapping_value2, name_tree) + end + union_of_overlapping_values = union_values( + expanded_overlapping_values1, + expanded_overlapping_values2, + name_tree, + ) + return ( + non_overlapping_values1..., + non_overlapping_values2..., + union_of_overlapping_values..., + ) end -function value_or_non_overlapping_children(name::FieldName, names, name_tree) - need_child_names = unrolled_any(names) do name′ - is_child_value(name′, name) && name′ != name - end - need_child_names || return (name,) - isnothing(name_tree) && - error("Cannot compute child names of $name without a FieldNameTree") - return unrolled_mapflatten(child_names(name, name_tree)) do child_name - value_or_non_overlapping_children(child_name, names, name_tree) - end -end -function value_or_non_overlapping_children( +expand_child_values(name::FieldName, overlapping_names, name_tree) = + unrolled_all(overlapping_names) do name′ + name′ != name && is_child_name(name′, name) + end ? child_names(name, name_tree) : (name,) +function expand_child_values( name_pair::FieldNamePair, - name_pairs, + overlapping_name_pairs, name_tree, ) - need_row_child_names = unrolled_any(name_pairs) do name_pair′ - is_child_value(name_pair′, name_pair) && name_pair′[1] != name_pair[1] - end - need_col_child_names = unrolled_any(name_pairs) do name_pair′ - is_child_value(name_pair′, name_pair) && name_pair′[2] != name_pair[2] - end - need_row_child_names || need_col_child_names || return (name_pair,) - isnothing(name_tree) && error( - "Cannot compute child name pairs of $name_pair without a FieldNameTree", - ) - row_name_children = - need_row_child_names ? child_names(name_pair[1], name_tree) : - (name_pair[1],) - col_name_children = - need_col_child_names ? child_names(name_pair[2], name_tree) : - (name_pair[2],) - return unrolled_mapflatten(row_name_children) do row_name_child - unrolled_mapflatten(col_name_children) do col_name_child - child_pair = (row_name_child, col_name_child) - value_or_non_overlapping_children(child_pair, name_pairs, name_tree) + row_name, col_name = name_pair + row_name_children_needed = + unrolled_all(overlapping_name_pairs) do name_pair′ + name_pair′[1] != row_name && is_child_name(name_pair′[1], row_name) end - end -end - -set_complement_values(_, _, ::Nothing) = - error("Cannot compute complement of a FieldNameSet without a FieldNameTree") -set_complement_values(::Type{<:FieldName}, names, name_tree::FieldNameTree) = - complement_values_in_subtree(names, name_tree) -set_complement_values( - ::Type{<:FieldNamePair}, - name_pairs, - name_tree::FieldNameTree, -) = complement_values_in_subtree_pair(name_pairs, (name_tree, name_tree)) - -function complement_values_in_subtree(names, subtree) - name = subtree.name - unrolled_all(name′ -> !is_child_value(name, name′), names) || return () - unrolled_any(name′ -> is_child_value(name′, name), names) || return (name,) - return unrolled_mapflatten(subtree.subtrees) do subsubtree - complement_values_in_subtree(names, subsubtree) - end -end - -function complement_values_in_subtree_pair(name_pairs, subtree_pair) - name_pair = (subtree_pair[1].name, subtree_pair[2].name) - is_name_pair_in_complement = unrolled_all(name_pairs) do name_pair′ - !is_child_value(name_pair, name_pair′) - end - is_name_pair_in_complement || return () - need_row_subsubtrees = unrolled_any(name_pairs) do name_pair′ - is_child_value(name_pair′, name_pair) && name_pair′[1] != name_pair[1] - end - need_col_subsubtrees = unrolled_any(name_pairs) do name_pair′ - is_child_value(name_pair′, name_pair) && name_pair′[2] != name_pair[2] - end - need_row_subsubtrees || need_col_subsubtrees || return (name_pair,) - row_subsubtrees = - need_row_subsubtrees ? subtree_pair[1].subtrees : (subtree_pair[1],) - col_subsubtrees = - need_col_subsubtrees ? subtree_pair[2].subtrees : (subtree_pair[2],) - return unrolled_mapflatten(row_subsubtrees) do row_subsubtree - unrolled_mapflatten(col_subsubtrees) do col_subsubtree - subsubtree_pair = (row_subsubtree, col_subsubtree) - complement_values_in_subtree_pair(name_pairs, subsubtree_pair) + col_name_children_needed = + unrolled_all(overlapping_name_pairs) do name_pair′ + name_pair′[2] != col_name && is_child_name(name_pair′[2], col_name) end + row_name_children = + row_name_children_needed ? child_names(row_name, name_tree) : () + col_name_children = + col_name_children_needed ? child_names(col_name, name_tree) : () + # Note: We need special cases for when either row_name or col_name only has + # one child name, since automatically expanding that name can generate + # results with more expansions than are necessary. For example, it can lead + # to a situation like issubset(set1, set2) && union(set1, set2) != set2, + # where union(set1, set2) has too many expanded values. + return if length(row_name_children) > 1 && length(col_name_children) > 1 || + length(row_name_children) == 1 && length(col_name_children) == 1 + unrolled_product(row_name_children, col_name_children) + elseif length(row_name_children) > 0 && length(col_name_children) <= 1 + unrolled_product(row_name_children, (col_name,)) + elseif length(row_name_children) <= 1 && length(col_name_children) > 0 + unrolled_product((row_name,), col_name_children) + else # length(row_name_children) == 0 && length(col_name_children) == 0 + (name_pair,) end end -################################################################################ - # This is required for type-stability as of Julia 1.9. if hasfield(Method, :recursion_relation) dont_limit = (args...) -> true - for m in methods(value_or_non_overlapping_children) - m.recursion_relation = dont_limit - end - for m in methods(complement_values_in_subtree) + for m in methods(remove_duplicates_and_overlaps) m.recursion_relation = dont_limit end - for m in methods(complement_values_in_subtree_pair) + for m in methods(union_values) m.recursion_relation = dont_limit end end diff --git a/src/MatrixFields/unrolled_functions.jl b/src/MatrixFields/unrolled_functions.jl index 947a45084d..b2a6f564df 100644 --- a/src/MatrixFields/unrolled_functions.jl +++ b/src/MatrixFields/unrolled_functions.jl @@ -1,91 +1,63 @@ -@inline unrolled_zip(values1, values2) = - isempty(values1) || isempty(values2) ? () : - ( - (first(values1), first(values2)), - unrolled_zip(Base.tail(values1), Base.tail(values2))..., - ) +# The following functions are more easily inferrable versions of `values[1:N]` +# and `values[(N + 1):end]`, named after `Iterators.take` and `Iterators.drop`. +# Note that `Base.tail` is roughly equivalent to `unrolled_drop` with `N == 1`. -@inline unrolled_map(f::F, values) where {F} = - isempty(values) ? () : - (f(first(values)), unrolled_map(f, Base.tail(values))...) - -unrolled_foldl(f::F, values) where {F} = - isempty(values) ? - error("unrolled_foldl requires init for an empty collection of values") : - _unrolled_foldl(f, first(values), Base.tail(values)) -unrolled_foldl(f::F, values, init) where {F} = _unrolled_foldl(f, init, values) -@inline _unrolled_foldl(f::F, result, values) where {F} = - isempty(values) ? result : - _unrolled_foldl(f, f(result, first(values)), Base.tail(values)) - -# The @inline annotations are needed to avoid allocations when there are a lot -# of values. - -# Using first and tail instead of [1] and [2:end] restricts us to Tuples, but it -# also results in less compilation time. - -# This is required to make the unrolled functions type-stable, as of Julia 1.9. -if hasfield(Method, :recursion_relation) - dont_limit = (args...) -> true - for m in methods(unrolled_zip) - m.recursion_relation = dont_limit - end - for m in methods(unrolled_map) - m.recursion_relation = dont_limit - end - for m in methods(_unrolled_foldl) - m.recursion_relation = dont_limit - end -end +unrolled_take(values, ::Val{N}) where {N} = ntuple(i -> values[i], Val(N)) +unrolled_drop(values, ::Val{N}) where {N} = + ntuple(i -> values[N + i], Val(length(values) - N)) -################################################################################ +# The following functions build upon the generated functions in Unrolled.jl. +# The first three are alternatives to `Iterators.flatten`, `Iterators.flatmap`, +# and `Iterators.product`. -unrolled_foreach(f::F, values) where {F} = (unrolled_map(f, values); nothing) - -unrolled_any(f::F, values) where {F} = - unrolled_foldl(|, unrolled_map(f, values), false) +unrolled_flatten(values) = + unrolled_reduce((tuple1, tuple2) -> (tuple1..., tuple2...), (), values) -unrolled_all(f::F, values) where {F} = - unrolled_foldl(&, unrolled_map(f, values), true) +unrolled_flatmap(f::F, values) where {F} = + unrolled_flatten(unrolled_map(f, values)) -unrolled_filter(f::F, values) where {F} = - unrolled_foldl(values, ()) do filtered_values, value - f(value) ? (filtered_values..., value) : filtered_values +unrolled_product(values1, values2) = + unrolled_flatmap(values1) do value1 + unrolled_map(value2 -> (value1, value2), values2) end unrolled_unique(values) = - unrolled_foldl(values, ()) do unique_values, value - unrolled_any(isequal(value), unique_values) ? unique_values : + unrolled_reduce((), values) do value, unique_values + unrolled_in(value, unique_values) ? unique_values : (unique_values..., value) end -unrolled_flatten(values) = - unrolled_foldl(values, ()) do flattened_values, value - (flattened_values..., value...) - end - -# Non-standard functions: - -unrolled_mapflatten(f::F, values) where {F} = - unrolled_flatten(unrolled_map(f, values)) +unrolled_split(f::F, values) where {F} = + (unrolled_filter(f, values), unrolled_filter(value -> !f(value), values)) function unrolled_findonly(f::F, values) where {F} filtered_values = unrolled_filter(f, values) - length(filtered_values) == 1 || - error("unrolled_findonly requires that exactly one value makes f true") - return first(filtered_values) + return length(filtered_values) == 1 ? filtered_values[1] : + error("unrolled_findonly requires that exactly 1 value makes f true") end -# This is required to make functions defined elsewhere type-stable, as of Julia -# 1.9. Specifically, if an unrolled function is used to implement the recursion -# of another function, it needs to have its recursion limit disabled in order -# for that other function to be type-stable. +# The following functions are recursion-based alternatives to the generated +# functions in Unrolled.jl. These should be used instead of their generated +# counterparts when implementing recursion in other functions. For example, if a +# function `func` needs to map over some values, and if the computation for each +# value recursively calls `func`, the map should be implemented using +# `recursively_unrolled_map` instead of `unrolled_map`. + +@inline recursively_unrolled_map(f::F, values) where {F} = + isempty(values) ? () : + (f(values[1]), recursively_unrolled_map(f, Base.tail(values))...) + +@inline recursively_unrolled_any(f::F, values) where {F} = + isempty(values) ? false : + f(values[1]) || recursively_unrolled_any(f, Base.tail(values)) + +# This is required for type-stability as of Julia 1.9. if hasfield(Method, :recursion_relation) dont_limit = (args...) -> true - for m in methods(unrolled_any) + for m in methods(recursively_unrolled_map) m.recursion_relation = dont_limit - end # for is_valid_name - for m in methods(unrolled_mapflatten) + end + for m in methods(recursively_unrolled_any) m.recursion_relation = dont_limit - end # for complement_values_in_subtree and value_or_non_overlapping_children + end end diff --git a/test/MatrixFields/field_names.jl b/test/MatrixFields/field_names.jl index 6834a9eef1..c18ef0bbbb 100644 --- a/test/MatrixFields/field_names.jl +++ b/test/MatrixFields/field_names.jl @@ -1,4 +1,4 @@ -import ClimaCore.MatrixFields: @name +import ClimaCore.MatrixFields: @name, is_subset_that_covers_set include("matrix_field_test_utils.jl") @@ -53,10 +53,10 @@ const x = (; foo = Foo(0), a = (; b = 1, c = ((; d = 2), (;), ((), nothing)))) @test_all MatrixFields.is_child_name(@name(a.c.:(1).d), @name(a)) @test_all !MatrixFields.is_child_name(@name(a.c.:(1).d), @name(foo)) - @test_all MatrixFields.names_are_overlapping(@name(a), @name(a.c.:(1).d)) - @test_all MatrixFields.names_are_overlapping(@name(a.c.:(1).d), @name(a)) - @test_all !MatrixFields.names_are_overlapping(@name(foo), @name(a.c.:(1).d)) - @test_all !MatrixFields.names_are_overlapping(@name(a.c.:(1).d), @name(foo)) + @test_all MatrixFields.is_overlapping_name(@name(a), @name(a.c.:(1).d)) + @test_all MatrixFields.is_overlapping_name(@name(a.c.:(1).d), @name(a)) + @test_all !MatrixFields.is_overlapping_name(@name(foo), @name(a.c.:(1).d)) + @test_all !MatrixFields.is_overlapping_name(@name(a.c.:(1).d), @name(foo)) @test_all MatrixFields.extract_internal_name(@name(a.c.:(1).d), @name(a)) == @name(c.:(1).d) @@ -75,9 +75,9 @@ const x = (; foo = Foo(0), a = (; b = 1, c = ((; d = 2), (;), ((), nothing)))) (@name(1), @name(2), @name(3)) end -@testset "FieldNameTree Unit Tests" begin - name_tree = MatrixFields.FieldNameTree(x) +const name_tree = MatrixFields.FieldNameTree(x) +@testset "FieldNameTree Unit Tests" begin @test_all MatrixFields.FieldNameTree(x) == name_tree @test_all MatrixFields.is_valid_name(@name(), name_tree) @@ -94,27 +94,29 @@ end (@name(a.b), @name(a.c)) @test_all MatrixFields.child_names(@name(a.c), name_tree) == (@name(a.c.:(1)), @name(a.c.:(2)), @name(a.c.:(3))) - @test_throws "does not contain any child names" MatrixFields.child_names( + @test_throws "does not have child names" MatrixFields.child_names( @name(a.c.:(2)), name_tree, ) - @test_throws "does not contain the name" MatrixFields.child_names( + @test_throws "is not a valid name" MatrixFields.child_names( @name(foo.invalid_name), name_tree, ) end -@testset "FieldNameSet Unit Tests" begin - name_tree = MatrixFields.FieldNameTree(x) - vector_keys(names...) = MatrixFields.FieldVectorKeys(names, name_tree) - matrix_keys(name_pairs...) = - MatrixFields.FieldMatrixKeys(name_pairs, name_tree) +vector_keys(names...) = MatrixFields.FieldVectorKeys(names, name_tree) +matrix_keys(name_pairs...) = MatrixFields.FieldMatrixKeys(name_pairs, name_tree) + +vector_keys_no_tree(names...) = MatrixFields.FieldVectorKeys(names) +matrix_keys_no_tree(name_pairs...) = MatrixFields.FieldMatrixKeys(name_pairs) - vector_keys_no_tree(names...) = MatrixFields.FieldVectorKeys(names) - matrix_keys_no_tree(name_pairs...) = - MatrixFields.FieldMatrixKeys(name_pairs) +drop_tree(set::MatrixFields.FieldVectorKeys) = + MatrixFields.FieldVectorKeys(set.values) +drop_tree(set::MatrixFields.FieldMatrixKeys) = + MatrixFields.FieldMatrixKeys(set.values) - @testset "FieldNameSet Construction" begin +@testset "FieldNameSet Unit Tests" begin + @testset "FieldNameSet Constructors" begin @test_throws "Invalid FieldNameSet value" vector_keys( @name(foo.invalid_name), ) @@ -145,38 +147,64 @@ end end end - @testset "FieldNameSet Iteration" begin - v_set1 = vector_keys(@name(foo), @name(a.c)) - v_set1_no_tree = vector_keys_no_tree(@name(foo), @name(a.c)) - m_set1 = matrix_keys((@name(foo), @name(a.c)), (@name(a.b), @name(foo))) - m_set1_no_tree = matrix_keys_no_tree( - (@name(foo), @name(a.c)), - (@name(a.b), @name(foo)), - ) + v_set1 = vector_keys(@name(foo), @name(a.c)) + m_set1 = matrix_keys((@name(foo), @name(a.c)), (@name(a.b), @name(foo))) - @test_all map(name -> (name, name), v_set1) == - ((@name(foo), @name(foo)), (@name(a.c), @name(a.c))) - @test_all map(name_pair -> name_pair[1], m_set1) == - (@name(foo), @name(a.b)) + # Proper subsets of v_set1 and m_set1. + v_set2 = vector_keys(@name(foo)) + m_set2 = matrix_keys((@name(foo), @name(a.c))) - @test_all isnothing(foreach(name -> (name, name), v_set1)) - @test_all isnothing(foreach(name_pair -> name_pair[1], m_set1)) + # Subsets that cover v_set1 and m_set1. + v_set3 = vector_keys( + @name(foo.value), + @name(a.c.:(1)), + @name(a.c.:(2)), + @name(a.c.:(3)), + ) + m_set3 = matrix_keys( + (@name(foo.value), @name(a.c.:(1))), + (@name(foo), @name(a.c.:(2))), + (@name(foo), @name(a.c.:(3))), + (@name(a.b), @name(foo.value)), + ) + # Sets that overlap with v_set1 and m_set1, but are neither subsets nor + # supersets of those sets. Some of the values in m_set4 overlap with + # those in m_set1, but are neither children nor parents of those values + # (this is only possible with matrix keys). + v_set4 = vector_keys(@name(a.b), @name(a.c.:(1)), @name(a.c.:(2))) + m_set4 = matrix_keys( + (@name(), @name(a.c.:(1))), + (@name(foo.value), @name(foo)), + (@name(foo.value), @name(a.c.:(2))), + (@name(a), @name(foo.value)), + (@name(a.c.:(3)), @name(a.c.:(3))), + ) + + @testset "FieldNameSet Basic Operations" begin @test string(v_set1) == "FieldVectorKeys(@name(foo), @name(a.c); )" - @test string(v_set1_no_tree) == + @test string(drop_tree(v_set1)) == "FieldVectorKeys(@name(foo), @name(a.c))" @test string(m_set1) == "FieldMatrixKeys((@name(foo), @name(a.c)), \ (@name(a.b), @name(foo)); )" - @test string(m_set1_no_tree) == "FieldMatrixKeys((@name(foo), \ + @test string(drop_tree(m_set1)) == "FieldMatrixKeys((@name(foo), \ @name(a.c)), (@name(a.b), @name(foo)))" - for set in (v_set1, v_set1_no_tree) + @test_all map(name -> (name, name), v_set1) == + ((@name(foo), @name(foo)), (@name(a.c), @name(a.c))) + @test_all map(name_pair -> name_pair[1], m_set1) == + (@name(foo), @name(a.b)) + + @test_all isnothing(foreach(name -> (name, name), v_set1)) + @test_all isnothing(foreach(name_pair -> name_pair[1], m_set1)) + + for set in (v_set1, drop_tree(v_set1)) @test_all @name(foo) in set @test_all !(@name(a.b) in set) @test_all !(@name(invalid_name) in set) end - for set in (m_set1, m_set1_no_tree) + for set in (m_set1, drop_tree(m_set1)) @test_all (@name(foo), @name(a.c)) in set @test_all !((@name(foo), @name(a.b)) in set) @test_all !((@name(foo), @name(invalid_name)) in set) @@ -184,133 +212,195 @@ end @test_all @name(foo.value) in v_set1 @test_all !(@name(foo.invalid_name) in v_set1) - @test_throws "FieldNameTree" @name(foo.value) in v_set1_no_tree - @test_throws "FieldNameTree" @name(foo.invalid_name) in v_set1_no_tree + @test_throws "FieldNameTree" @name(foo.value) in drop_tree(v_set1) + @test_throws "FieldNameTree" @name(foo.invalid_name) in + drop_tree(v_set1) @test_all (@name(foo.value), @name(a.c)) in m_set1 @test_all !((@name(foo.invalid_name), @name(a.c)) in m_set1) @test_throws "FieldNameTree" (@name(foo.value), @name(a.c)) in - m_set1_no_tree + drop_tree(m_set1) @test_throws "FieldNameTree" (@name(foo.invalid_name), @name(a.c)) in - m_set1_no_tree + drop_tree(m_set1) end - @testset "FieldNameSet Operations for Addition/Subtraction" begin - v_set1 = vector_keys(@name(foo), @name(a.c)) - v_set1_no_tree = vector_keys_no_tree(@name(foo), @name(a.c)) - m_set1 = matrix_keys((@name(foo), @name(a.c)), (@name(a.b), @name(foo))) - m_set1_no_tree = matrix_keys_no_tree( - (@name(foo), @name(a.c)), - (@name(a.b), @name(foo)), - ) - - v_set2 = vector_keys(@name(foo)) - v_set2_no_tree = vector_keys_no_tree(@name(foo)) - m_set2 = matrix_keys((@name(foo), @name(a.c))) - m_set2_no_tree = matrix_keys_no_tree((@name(foo), @name(a.c))) + @testset "FieldNameSet Complement Sets" begin + @test_all MatrixFields.set_complement(v_set1) == + vector_keys_no_tree(@name(a.b)) + @test_all MatrixFields.set_complement(v_set2) == + vector_keys_no_tree(@name(a)) + @test_all MatrixFields.set_complement(v_set3) == + vector_keys_no_tree(@name(a.b)) + @test_all MatrixFields.set_complement(v_set4) == + vector_keys_no_tree(@name(foo), @name(a.c.:(3))) - v_set3 = vector_keys( - @name(foo.value), - @name(a.c.:(1)), - @name(a.c.:(2)), - @name(a.c.:(3)), - ) - v_set3_no_tree = vector_keys_no_tree( - @name(foo.value), - @name(a.c.:(1)), - @name(a.c.:(2)), - @name(a.c.:(3)), + @test_all MatrixFields.set_complement(m_set1) == matrix_keys_no_tree( + (@name(foo), @name(foo)), + (@name(foo), @name(a.b)), + (@name(a), @name(a)), + (@name(a.c), @name(foo)), ) - m_set3 = matrix_keys( - (@name(foo), @name(a.c.:(1))), - (@name(foo), @name(a.c.:(2))), - (@name(foo.value), @name(a.c.:(3))), - (@name(a.b), @name(foo)), + @test_all MatrixFields.set_complement(m_set2) == matrix_keys_no_tree( + (@name(foo), @name(foo)), + (@name(foo), @name(a.b)), + (@name(a), @name(foo)), + (@name(a), @name(a)), ) - m_set3_no_tree = matrix_keys_no_tree( - (@name(foo), @name(a.c.:(1))), - (@name(foo), @name(a.c.:(2))), - (@name(foo.value), @name(a.c.:(3))), - (@name(a.b), @name(foo)), + @test_all MatrixFields.set_complement(m_set3) == matrix_keys_no_tree( + (@name(foo), @name(foo)), + (@name(foo), @name(a.b)), + (@name(a), @name(a)), + (@name(a.c), @name(foo)), ) - m_set3_no_tree′ = matrix_keys_no_tree( - (@name(foo.value), @name(a.c.:(1))), - (@name(foo.value), @name(a.c.:(2))), - (@name(foo.value), @name(a.c.:(3))), - (@name(a.b), @name(foo)), + @test_all MatrixFields.set_complement(m_set4) == matrix_keys_no_tree( + (@name(foo), @name(a.b)), + (@name(foo), @name(a.c.:(3))), + (@name(a), @name(a.b)), + (@name(a), @name(a.c.:(2))), + (@name(a.b), @name(a.c.:(3))), + (@name(a.c.:(1)), @name(a.c.:(3))), + (@name(a.c.:(2)), @name(a.c.:(3))), ) + for set in (drop_tree(v_set1), drop_tree(m_set1)) + @test_throws "FieldNameTree" MatrixFields.set_complement(set) + end + end + + @testset "FieldNameSet Binary Set Operations" begin for (set1, set2) in ( (v_set1, v_set2), - (v_set1, v_set2_no_tree), - (v_set1_no_tree, v_set2), + (v_set1, drop_tree(v_set2)), + (drop_tree(v_set1), v_set2), + (drop_tree(v_set1), drop_tree(v_set2)), (m_set1, m_set2), - (m_set1, m_set2_no_tree), - (m_set1_no_tree, m_set2), + (m_set1, drop_tree(m_set2)), + (drop_tree(m_set1), m_set2), + (drop_tree(m_set1), drop_tree(m_set2)), ) @test_all set1 != set2 - @test_all !issubset(set1, set2) - @test_all issubset(set2, set1) - @test_all intersect(set1, set2) == set2 - @test_all union(set1, set2) == set1 - @test_all !MatrixFields.is_subset_that_covers_set(set1, set2) - @test_all !MatrixFields.is_subset_that_covers_set(set2, set1) - end - - for (set1, set2) in - ((v_set1_no_tree, v_set2_no_tree), (m_set1_no_tree, m_set2_no_tree)) - @test_all set1 != set2 - @test_all !issubset(set1, set2) - @test_all issubset(set2, set1) - @test_all intersect(set1, set2) == set2 - @test_all union(set1, set2) == set1 - @test_all !MatrixFields.is_subset_that_covers_set(set1, set2) - @test_throws "FieldNameTree" MatrixFields.is_subset_that_covers_set( - set2, - set1, - ) + @test_all !issubset(set1, set2) && issubset(set2, set1) + @test_all !is_subset_that_covers_set(set1, set2) && + !is_subset_that_covers_set(set2, set1) + @test_all intersect(set1, set2) == intersect(set2, set1) == set2 + @test_all union(set1, set2) == union(set2, set1) == set1 + if eltype(set1) <: MatrixFields.FieldName + @test_all setdiff(set1, set2) == vector_keys(@name(a.c)) + else + @test_all setdiff(set1, set2) == + matrix_keys((@name(a.b), @name(foo))) + end + @test_all isempty(setdiff(set2, set1)) end for (set1, set3) in ( (v_set1, v_set3), - (v_set1, v_set3_no_tree), - (v_set1_no_tree, v_set3), + (v_set1, drop_tree(v_set3)), + (drop_tree(v_set1), v_set3), + (m_set1, m_set3), + (m_set1, drop_tree(m_set3)), + (drop_tree(m_set1), m_set3), ) @test_all set1 != set3 - @test_all !issubset(set1, set3) - @test_all issubset(set3, set1) - @test_all intersect(set1, set3) == set3 - @test_all union(set1, set3) == set3 - @test_all !MatrixFields.is_subset_that_covers_set(set1, set3) - @test_all MatrixFields.is_subset_that_covers_set(set3, set1) + @test_all !issubset(set1, set3) && issubset(set3, set1) + @test_all !is_subset_that_covers_set(set1, set3) && + is_subset_that_covers_set(set3, set1) + @test_all intersect(set1, set3) == intersect(set3, set1) == set3 + @test_all union(set1, set3) == union(set3, set1) == set3 + @test_all isempty(setdiff(set1, set3)) && + isempty(setdiff(set3, set1)) end for (set1, set3) in ( - (m_set1, m_set3), - (m_set1, m_set3_no_tree), - (m_set1_no_tree, m_set3), + (drop_tree(v_set1), drop_tree(v_set3)), + (drop_tree(m_set1), drop_tree(m_set3)), ) @test_all set1 != set3 @test_all !issubset(set1, set3) - @test_all issubset(set3, set1) - @test_all intersect(set1, set3) == m_set3_no_tree′ - @test_all union(set1, set3) == m_set3_no_tree′ - @test_all !MatrixFields.is_subset_that_covers_set(set1, set3) - @test_all MatrixFields.is_subset_that_covers_set(set3, set1) + @test_all !is_subset_that_covers_set(set1, set3) + @test_throws "FieldNameTree" issubset(set3, set1) + @test_throws "FieldNameTree" is_subset_that_covers_set(set3, set1) + @test_throws "FieldNameTree" intersect(set1, set3) + @test_throws "FieldNameTree" intersect(set3, set1) + @test_throws "FieldNameTree" union(set1, set3) + @test_throws "FieldNameTree" union(set3, set1) + @test_throws "FieldNameTree" setdiff(set1, set3) + @test_throws "FieldNameTree" setdiff(set3, set1) end - for (set1, set3) in - ((v_set1_no_tree, v_set3_no_tree), (m_set1_no_tree, m_set3_no_tree)) - @test_all set1 != set3 - @test_all !issubset(set1, set3) - @test_throws "FieldNameTree" issubset(set3, set1) - @test_throws "FieldNameTree" intersect(set1, set3) == set3 - @test_throws "FieldNameTree" union(set1, set3) == set3 - @test_all !MatrixFields.is_subset_that_covers_set(set1, set3) - @test_throws "FieldNameTree" MatrixFields.is_subset_that_covers_set( - set3, - set1, - ) + for (set1, set4) in ( + (v_set1, v_set4), + (v_set1, drop_tree(v_set4)), + (drop_tree(v_set1), v_set4), + (m_set1, m_set4), + (m_set1, drop_tree(m_set4)), + (drop_tree(m_set1), m_set4), + ) + @test_all set1 != set4 + @test_all !issubset(set1, set4) && !issubset(set4, set1) + @test_all !is_subset_that_covers_set(set1, set4) && + !is_subset_that_covers_set(set4, set1) + if eltype(set1) <: MatrixFields.FieldName + @test_all intersect(set1, set4) == + intersect(set4, set1) == + vector_keys_no_tree(@name(a.c.:(1)), @name(a.c.:(2))) + @test_all union(set1, set4) == + union(set4, set1) == + vector_keys_no_tree( + @name(foo), + @name(a.b), + @name(a.c.:(1)), + @name(a.c.:(2)), + @name(a.c.:(3)), + ) + @test_all setdiff(set1, set4) == + vector_keys_no_tree(@name(foo), @name(a.c.:(3))) + @test_all setdiff(set4, set1) == vector_keys_no_tree(@name(a.b)) + else + @test_all intersect(set1, set4) == + intersect(set4, set1) == + matrix_keys_no_tree( + (@name(foo), @name(a.c.:(1))), + (@name(foo.value), @name(a.c.:(2))), + (@name(a.b), @name(foo.value)), + ) + @test_all union(set1, set4) == + union(set4, set1) == + matrix_keys_no_tree( + (@name(foo), @name(a.c.:(1))), + (@name(foo), @name(a.c.:(3))), + (@name(foo.value), @name(foo)), + (@name(foo.value), @name(a.c.:(2))), + (@name(a), @name(a.c.:(1))), + (@name(a.b), @name(foo.value)), + (@name(a.c), @name(foo.value)), + (@name(a.c.:(3)), @name(a.c.:(3))), + ) + @test_all setdiff(set1, set4) == + matrix_keys_no_tree((@name(foo), @name(a.c.:(3)))) + @test_all setdiff(set4, set1) == matrix_keys_no_tree( + (@name(foo.value), @name(foo)), + (@name(a), @name(a.c.:(1))), + (@name(a.c), @name(foo.value)), + (@name(a.c.:(3)), @name(a.c.:(3))), + ) + end + end + + for (set1, set4) in ( + (drop_tree(v_set1), drop_tree(v_set4)), + (drop_tree(m_set1), drop_tree(m_set4)), + ) + @test_all set1 != set4 + @test_all !issubset(set1, set4) && !issubset(set4, set1) + @test_all !is_subset_that_covers_set(set1, set4) && + !is_subset_that_covers_set(set4, set1) + @test_throws "FieldNameTree" intersect(set1, set4) + @test_throws "FieldNameTree" intersect(set4, set1) + @test_throws "FieldNameTree" union(set1, set4) + @test_throws "FieldNameTree" union(set4, set1) + @test_throws "FieldNameTree" setdiff(set1, set4) + @test_throws "FieldNameTree" setdiff(set4, set1) end end @@ -503,111 +593,54 @@ end end @testset "Other FieldNameSet Operations" begin - v_set1 = vector_keys(@name(foo), @name(a.c)) - v_set1_no_tree = vector_keys_no_tree(@name(foo), @name(a.c)) - m_set1 = matrix_keys((@name(foo), @name(a.c)), (@name(a.b), @name(foo))) - m_set1_no_tree = matrix_keys_no_tree( - (@name(foo), @name(a.c)), - (@name(a.b), @name(foo)), - ) - - v_set2 = vector_keys(@name(foo.value), @name(a.c.:(1)), @name(a.c.:(3))) - v_set2_no_tree = vector_keys_no_tree( - @name(foo.value), - @name(a.c.:(1)), - @name(a.c.:(3)) - ) - m_set2 = matrix_keys( - (@name(foo), @name(foo)), - (@name(foo), @name(a.c.:(1))), - (@name(foo.value), @name(a.c.:(3))), - (@name(a.b), @name(foo.value)), - (@name(a), @name(a.c)), - ) - m_set2_no_tree = matrix_keys_no_tree( - (@name(foo), @name(foo)), - (@name(foo), @name(a.c.:(1))), - (@name(foo.value), @name(a.c.:(3))), - (@name(a.b), @name(foo.value)), - (@name(a), @name(a.c)), - ) - - @test_all MatrixFields.set_complement(v_set2) == - vector_keys(@name(a.b), @name(a.c.:(2))) - @test_throws "FieldNameTree" MatrixFields.set_complement(v_set2_no_tree) - - @test_all MatrixFields.set_complement(m_set2) == matrix_keys( - (@name(foo.value), @name(a.b)), - (@name(foo.value), @name(a.c.:(2))), - (@name(a.c), @name(foo.value)), - (@name(a), @name(a.b)), - ) - @test_throws "FieldNameTree" MatrixFields.set_complement(m_set2_no_tree) - - for (set1, set2) in ( - (v_set1, v_set2), - (v_set1, v_set2_no_tree), - (v_set1_no_tree, v_set2), - ) - @test_all setdiff(set1, set2) == vector_keys(@name(a.c.:(2))) - end - - for (set1, set2) in ( - (m_set1, m_set2), - (m_set1, m_set2_no_tree), - (m_set1_no_tree, m_set2), - ) - @test_all setdiff(set1, set2) == - matrix_keys((@name(foo.value), @name(a.c.:(2)))) - end - - for (set1, set2) in - ((v_set1_no_tree, v_set2_no_tree), (m_set1_no_tree, m_set2_no_tree)) - @test_throws "FieldNameTree" setdiff(set1, set2) - end - # With one exception, none of the following operations require a # FieldNameTree. - @test_all MatrixFields.corresponding_matrix_keys(v_set1_no_tree) == - matrix_keys( + @test_all MatrixFields.corresponding_matrix_keys(drop_tree(v_set1)) == + matrix_keys_no_tree( (@name(foo), @name(foo)), (@name(a.c), @name(a.c)), ) @test_all MatrixFields.cartesian_product( - v_set1_no_tree, - v_set2_no_tree, - ) == matrix_keys( - (@name(foo), @name(foo.value)), + drop_tree(v_set1), + drop_tree(v_set4), + ) == matrix_keys_no_tree( + (@name(foo), @name(a.b)), (@name(foo), @name(a.c.:(1))), - (@name(foo), @name(a.c.:(3))), - (@name(a.c), @name(foo.value)), + (@name(foo), @name(a.c.:(2))), + (@name(a.c), @name(a.b)), (@name(a.c), @name(a.c.:(1))), - (@name(a.c), @name(a.c.:(3))), + (@name(a.c), @name(a.c.:(2))), ) - @test_all MatrixFields.matrix_row_keys(m_set1_no_tree) == - vector_keys(@name(foo), @name(a.b)) + @test_all MatrixFields.matrix_row_keys(drop_tree(m_set1)) == + vector_keys_no_tree(@name(foo), @name(a.b)) - @test_all MatrixFields.matrix_row_keys(m_set2) == - vector_keys(@name(foo.value), @name(a.b), @name(a.c)) + @test_all MatrixFields.matrix_row_keys(m_set4) == vector_keys_no_tree( + @name(foo.value), + @name(a.b), + @name(a.c.:(1)), + @name(a.c.:(2)), + @name(a.c.:(3)) + ) @test_throws "FieldNameTree" MatrixFields.matrix_row_keys( - m_set2_no_tree, + drop_tree(m_set4), ) - @test_all MatrixFields.matrix_off_diagonal_keys(m_set2_no_tree) == - matrix_keys( - (@name(foo), @name(a.c.:(1))), - (@name(foo.value), @name(a.c.:(3))), - (@name(a.b), @name(foo.value)), - (@name(a), @name(a.c)), + @test_all MatrixFields.matrix_off_diagonal_keys(drop_tree(m_set4)) == + matrix_keys_no_tree( + (@name(), @name(a.c.:(1))), + (@name(foo.value), @name(foo)), + (@name(foo.value), @name(a.c.:(2))), + (@name(a), @name(foo.value)), ) - @test_all MatrixFields.matrix_diagonal_keys(m_set2_no_tree) == - matrix_keys( - (@name(foo), @name(foo)), - (@name(a.c), @name(a.c)), + @test_all MatrixFields.matrix_diagonal_keys(drop_tree(m_set4)) == + matrix_keys_no_tree( + (@name(foo.value), @name(foo.value)), + (@name(a.c.:(1)), @name(a.c.:(1))), + (@name(a.c.:(3)), @name(a.c.:(3))), ) end end