Skip to content

Commit

Permalink
Simplify set complements
Browse files Browse the repository at this point in the history
  • Loading branch information
dennisYatunin committed Oct 27, 2023
1 parent 4f73269 commit 59ce2ff
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 190 deletions.
2 changes: 1 addition & 1 deletion src/MatrixFields/field_name_dict.jl
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ Base.Broadcast.broadcasted(
arg3,
args...,
) =
unrolled_foldl((arg1, arg2, arg3, args...)) do arg1′, arg2′
foldl((arg1, arg2, arg3, args...)) do arg1′, arg2′
Base.Broadcast.broadcasted(f, arg1′, arg2′)
end

Expand Down
97 changes: 45 additions & 52 deletions src/MatrixFields/field_name_set.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ end

set_string(set) = values_string(set.values)

set_complement(set) = setdiff(universal_set(eltype(set)), set)
set_complement(set) = setdiff(universal_set(set.name_tree, eltype(set)), set)

is_subset_that_covers_set(set1, set2) =
issubset(set1, set2) && isempty(setdiff(set2, set1))
Expand All @@ -116,9 +116,11 @@ function corresponding_matrix_keys(set::FieldVectorKeys)
return FieldMatrixKeys(result_values, set.name_tree)
end

function cartesian_product(set1::FieldVectorKeys, set2::FieldVectorKeys)
name_tree = combine_name_trees(set1.name_tree, set2.name_tree)
result_values = unrolled_product((set1.values, set2.values))
function cartesian_product(row_set::FieldVectorKeys, col_set::FieldVectorKeys)
name_tree = combine_name_trees(row_set.name_tree, col_set.name_tree)
result_values = unrolled_flatmap(row_set.values) do row_name
unrolled_map(col_name -> (row_name, col_name), col_set.values)
end
return FieldMatrixKeys(result_values, name_tree)
end

Expand Down Expand Up @@ -278,8 +280,16 @@ combine_name_trees(name_tree1, name_tree2) =
error("Mismatched FieldNameTrees: The ability to combine different \
FieldNameTrees has not been implemented")

universal_set(::Type{FieldName}) = FieldVectorKeys((@name(),))
universal_set(::Type{FieldNamePair}) = FieldMatrixKeys(((@name(), @name()),))
universal_set(::Nothing, ::Type{FieldName}) = error(
"Missing FieldNameTree: Cannot compute complement of FieldNameSet without \
a FieldNameTree",
)
universal_set(name_tree, ::Type{FieldName}) =
FieldVectorKeys(child_names(@name(), name_tree), name_tree)
function universal_set(name_tree, ::Type{FieldNamePair})
row_set = universal_set(name_tree, FieldName)
return cartesian_product(row_set, row_set)
end

is_valid_value(name::FieldName, name_tree) = is_valid_name(name, name_tree)
is_valid_value(name_pair::FieldNamePair, name_tree) =
Expand Down Expand Up @@ -329,49 +339,32 @@ function union_values(values, new_values, name_tree)
and $(values_string(overlapping_values)) without a FieldNameTree",
)

overlapping_values_that_are_children_of_value, other_overlapping_values =
overlapping_children_of_new_value, other_overlapping_values =
unrolled_split(overlapping_values) do value
is_child_value(value, new_value)
end
possible_union_values = if isempty(other_overlapping_values)
possible_children_of_value = available_sets_of_child_values(
new_value,
overlapping_values,
name_tree,
)
recursively_unrolled_map(
possible_children_of_value,
) do children_of_value
union_values(
values,
(children_of_value..., more_new_values...),
name_tree,
)
end
return if isempty(other_overlapping_values)
children_of_new_value =
expand_along_row_or_column(new_value, overlapping_values, name_tree)
expanded_new_values = (children_of_new_value..., more_new_values...)
union_values(values, expanded_new_values, name_tree)
else
possible_children_of_other_overlapping_values = unrolled_map(
unrolled_flatten,
unrolled_prodmap(other_overlapping_values) do value
available_sets_of_child_values(value, (new_value,), name_tree)
end,
children_of_other_overlapping_values =
unrolled_flatmap(other_overlapping_values) do value
expand_along_row_or_column(value, (new_value,), name_tree)
end
expanded_values = (
non_overlapping_values...,
overlapping_children_of_new_value...,
children_of_other_overlapping_values...,
)
recursively_unrolled_map(
possible_children_of_other_overlapping_values,
) do children_of_other_overlapping_values
values_and_children_of_values = (
non_overlapping_values...,
overlapping_values_that_are_children_of_value...,
children_of_other_overlapping_values...,
)
union_values(values_and_children_of_values, new_values, name_tree)
end
union_values(expanded_values, new_values, name_tree)
end
return unrolled_argmin(length, possible_union_values)
end

available_sets_of_child_values(name::FieldName, _, name_tree) =
(child_names(name, name_tree),)
function available_sets_of_child_values(
expand_along_row_or_column(name::FieldName, _, name_tree) =
child_names(name, name_tree)
function expand_along_row_or_column(
name_pair::FieldNamePair,
overlapping_name_pairs,
name_tree,
Expand All @@ -397,19 +390,19 @@ function available_sets_of_child_values(
n_col_children = col_children_needed ? length(col_children) : 0
# We are guaranteed that either n_row_children > 0 or n_col_children > 0.
return if n_row_children == n_col_children == 1
only_child_of_name_pair = (row_children[1][1], col_children[1][2])
((only_child_of_name_pair,),)
((row_children[1][1], col_children[1][2]),)
elseif n_row_children > 0 && n_col_children <= 1
(row_children,)
row_children
elseif n_col_children > 0 && n_row_children <= 1
(col_children,)
else
@assert n_row_children > 1 && n_col_children > 1
# If multiple row and column children are available, we might get
# different results depending on whether we expand the row or the column
# first, so we need to try both and pick the one that corresponds to the
# result with the shortest length.
(row_children, col_children)
col_children
else # n_row_children > 1 && n_col_children > 1
# If multiple row and column children are needed, the result of
# union_values only depends on whether we expand the row or the column
# first when name_pair == (@name(), @name()). In this function, we make
# the arbitrary choice to expand the row first. Since we do not expect
# users to construct matrices that contain (@name(), @name()), this
# choice should not have any noticeable effect.
row_children
end
end

Expand Down
59 changes: 11 additions & 48 deletions src/MatrixFields/unrolled_functions.jl
Original file line number Diff line number Diff line change
@@ -1,58 +1,20 @@
# The following functions are extensions to the generated functions defined in
# Unrolled.jl.
# The following functions build upon the generated functions in Unrolled.jl.

unrolled_flatten(values) =
unrolled_reduce((tuple1, tuple2) -> (tuple1..., tuple2...), (), values)

unrolled_flatmap(f::F, values) where {F} =
unrolled_flatten(unrolled_map(f, values))

unrolled_split(f::F, values) where {F} =
unrolled_filter(f, values), unrolled_filter(!f, values)
(unrolled_filter(f, values), unrolled_filter(!f, values))

function unrolled_findonly(f::F, values) where {F}
filtered_values = unrolled_filter(f, values)
return length(filtered_values) == 1 ? filtered_values[1] :
error("unrolled_findonly requires that exactly 1 value makes f true")
end

# The implementation of unrolled_reduce in Unrolled.jl makes it roughly the same
# as foldl, but with the order of arguments in every call to f flipped.
unrolled_foldl(f::F, values; init = nothing) where {F} =
if isnothing(init)
isempty(values) ?
error("unrolled_foldl requires init for empty collections of values") :
unrolled_reduce((x, y) -> f(y, x), values[1], values[2:end])
else
unrolled_reduce((x, y) -> f(y, x), init, values)
end

function unrolled_argmin(f::F, values) where {F}
values_and_fs = unrolled_map(value -> (value, f(value)), values)
min_value_and_f =
unrolled_foldl(values_and_fs) do (min_value, min_f), (new_value, new_f)
new_f < min_f ? (new_value, new_f) : (min_value, min_f)
end
return min_value_and_f[1]
end

# This needs to use unrolled_reduce instead of unrolled_foldl in order for
# unrolled_product to be type-stable (otherwise, calling unrolled_flatmap inside
# of unrolled_foldl causes the compiler to hit a recursion limit).
unrolled_flatten(values) =
unrolled_reduce((), values) do value, flattened_values
(flattened_values..., value...)
end

unrolled_flatmap(f::F, values) where {F} =
unrolled_flatten(unrolled_map(f, values))

unrolled_product(values) =
unrolled_foldl(values; init = ((),)) do product_values, value
unrolled_flatmap(product_values) do sub_values
unrolled_map(value) do sub_value
(sub_values..., sub_value)
end
end
end

unrolled_prodmap(f::F, values) where {F} =
unrolled_product(unrolled_map(f, values))

# The following functions are recursion-based alternatives to the generated
# functions in Unrolled.jl. These should be used instead of their generated
# counterparts when implementing recursion in other functions. For example, if a
Expand All @@ -64,8 +26,9 @@ unrolled_prodmap(f::F, values) where {F} =
isempty(values) ? () :
(f(values[1]), recursively_unrolled_map(f, values[2:end])...)

recursively_unrolled_any(f::F, values) where {F} =
unrolled_any(identity, recursively_unrolled_map(f, values))
@inline recursively_unrolled_any(f::F, values) where {F} =
isempty(values) ? false :
f(values[1]) || recursively_unrolled_any(f, values[2:end])

# This is required for type-stability as of Julia 1.9.
if hasfield(Method, :recursion_relation)
Expand Down
Loading

0 comments on commit 59ce2ff

Please sign in to comment.