From 55c3f850a4c536051c44a8c0fd95547ed16c8250 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luk=C3=A1=C5=A1=20Petrovick=C3=BD?= Date: Wed, 26 Jun 2024 11:38:50 +0200 Subject: [PATCH 1/9] feat: introduce fairness --- tests/test_constraint_verifier.py | 41 +++++ .../main/python/score/_constraint_stream.py | 144 +++++++++++++++++- .../src/main/python/score/_group_by.py | 134 ++++++++++++++++ 3 files changed, 317 insertions(+), 2 deletions(-) diff --git a/tests/test_constraint_verifier.py b/tests/test_constraint_verifier.py index 05b71c8e..b9bc1a4c 100644 --- a/tests/test_constraint_verifier.py +++ b/tests/test_constraint_verifier.py @@ -5,9 +5,16 @@ from timefold.solver.config import * from timefold.solver.test import * +import inspect +import re from typing import Annotated, List from dataclasses import dataclass, field +from ai.timefold.solver.test.api.score.stream import (ConstraintVerifier as JavaConstraintVerifier, + SingleConstraintAssertion as JavaSingleConstraintAssertion, + SingleConstraintVerification as JavaSingleConstraintVerification, + MultiConstraintAssertion as JavaMultiConstraintAssertion, + MultiConstraintVerification as JavaMultiConstraintVerification) def verifier_suite(verifier: ConstraintVerifier, same_value, is_value_one, solution, e1, e2, e3, v1, v2, v3): @@ -268,3 +275,37 @@ class Solution: verifier_suite(verifier, same_value, is_value_one, solution, e1, e2, e3, v1, v2, v3) + + +ignored_java_functions = { + 'equals', + 'getClass', + 'hashCode', + 'notify', + 'notifyAll', + 'toString', + 'wait', + 'with_constraint_stream_impl_type' +} + + +def test_has_all_methods(): + for python_type, java_type in ((ConstraintVerifier, JavaConstraintVerifier), + (SingleConstraintAssertion, JavaSingleConstraintAssertion), + (SingleConstraintVerification, JavaSingleConstraintVerification), + (MultiConstraintAssertion, JavaMultiConstraintAssertion), + (MultiConstraintVerification, JavaMultiConstraintVerification)): + missing = [] + for function_name, function_impl in inspect.getmembers(java_type, inspect.isfunction): + if function_name in ignored_java_functions: + continue + snake_case_name = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', function_name) + # change h_t_t_p -> http + snake_case_name = re.sub('([a-z0-9])([A-Z])', r'\1_\2', snake_case_name).lower() + if not hasattr(python_type, snake_case_name): + missing.append(snake_case_name) + + if missing: + raise AssertionError(f'{python_type} is missing methods ({missing}) ' + f'from java_type ({java_type}).)') + diff --git a/timefold-solver-python-core/src/main/python/score/_constraint_stream.py b/timefold-solver-python-core/src/main/python/score/_constraint_stream.py index d6395ebb..2afdcfd1 100644 --- a/timefold-solver-python-core/src/main/python/score/_constraint_stream.py +++ b/timefold-solver-python-core/src/main/python/score/_constraint_stream.py @@ -454,6 +454,20 @@ def concat(self, other): else: raise RuntimeError(f'Unhandled constraint stream type {type(other)}.') + def complement(self, cls: type[A]) -> 'UniConstraintStream[A]': + """ + Adds to the stream all instances of a given class which are not yet present in it. + These instances must be present in the solution, + which means the class needs to be either a planning entity or a problem fact. + + Parameters + ---------- + cls : Type[A] + the type of the instances to add to the stream. + """ + result = self.delegate.complement(get_class(cls)) + return TriConstraintCollector(result, self.package, self.a_type) + def penalize(self, constraint_weight: ScoreType, match_weigher: Callable[[A], int] = None) -> \ 'UniConstraintBuilder[A, ScoreType]': """ @@ -1000,6 +1014,40 @@ def concat(self, other): else: raise RuntimeError(f'Unhandled constraint stream type {type(other)}.') + @overload + def complement(self, cls: type[A]) -> 'BiConstraintStream[A, B]': + ... + + @overload + def complement(self, cls: type[A], padding: Callable[[A], B]) -> 'BiConstraintStream[A, B]': + ... + + def complement(self, cls: type[A], padding=None): + """ + Adds to the stream all instances of a given class which are not yet present in it. + These instances must be present in the solution, + which means the class needs to be either a planning entity or a problem fact. + + The instances will be read from the first element of the input tuple. + When an output tuple needs to be created for the newly inserted instances, + the first element will be the new instance. + The rest of the tuple will be padded with the result of the padding function. + + Parameters + ---------- + cls : Type[A] + the type of the instances to add to the stream. + + padding : Callable[[A], B] + a function that computes the padding value for the second fact in the new tuple. + """ + if None == padding: + result = self.delegate.complement(get_class(cls)) + return TriConstraintCollector(result, self.package, self.a_type, self.b_type) + java_padding = function_cast(padding, self.a_type) + result = self.delegate.complement(get_class(cls), java_padding) + return TriConstraintCollector(result, self.package, self.a_type, self.b_type) + def penalize(self, constraint_weight: ScoreType, match_weigher: Callable[[A, B], int] = None) -> \ 'BiConstraintBuilder[A, B, ScoreType]': """ @@ -1544,6 +1592,51 @@ def concat(self, other): else: raise RuntimeError(f'Unhandled constraint stream type {type(other)}.') + @overload + def complement(self, cls: type[A]) -> 'TriConstraintStream[A, B, C]': + ... + + @overload + def complement(self, cls: type[A], padding_b: Callable[[A], B], padding_c: Callable[[A], C]) \ + -> 'TriConstraintStream[A, B, C]': + ... + + def complement(self, cls: type[A], padding_b=None, padding_c=None): + """ + Adds to the stream all instances of a given class which are not yet present in it. + These instances must be present in the solution, + which means the class needs to be either a planning entity or a problem fact. + + The instances will be read from the first element of the input tuple. + When an output tuple needs to be created for the newly inserted instances, + the first element will be the new instance. + The rest of the tuple will be padded with the result of the padding function, + applied on the new instance. + + Padding functions are optional, but if one is provided, then both must-be provided. + + Parameters + ---------- + cls : Type[A] + the type of the instances to add to the stream. + + padding_b : Callable[[A], B] + a function that computes the padding value for the second fact in the new tuple. + + padding_c : Callable[[A], C] + a function that computes the padding value for the third fact in the new tuple. + """ + if None == padding_b == padding_c: + result = self.delegate.complement(get_class(cls)) + return TriConstraintCollector(result, self.package, self.a_type, self.b_type, self.c_type) + specified_count = sum(x is not None for x in [padding_b, padding_c]) + if specified_count != 0: + raise ValueError(f'If a padding function is provided, both are expected, got {specified_count} instead.') + java_padding_b = function_cast(padding_b, self.a_type) + java_padding_c = function_cast(padding_c, self.a_type) + result = self.delegate.complement(get_class(cls), java_padding_b, java_padding_c) + return TriConstraintCollector(result, self.package, self.a_type, self.b_type, self.c_type) + def penalize(self, constraint_weight: ScoreType, match_weigher: Callable[[A, B, C], int] = None) -> 'TriConstraintBuilder[A, B, C, ScoreType]': """ @@ -2016,7 +2109,6 @@ def map(self, *mapping_functions): JClass('java.lang.Object')) if len(mapping_functions) == 4: return QuadConstraintStream(self.delegate.map(*translated_functions), self.package, - JClass('java.lang.Object'), JClass('java.lang.Object'), JClass('java.lang.Object'), JClass('java.lang.Object')) raise RuntimeError(f'Impossible state: missing case for {len(mapping_functions)}.') @@ -2027,7 +2119,6 @@ def flatten_last(self, flattening_function) -> 'QuadConstraintStream[A,B,C,D]': """ translated_function = function_cast(flattening_function, self.d_type) return QuadConstraintStream(self.delegate.flattenLast(translated_function), self.package, - self.a_type, self.b_type, self.c_type, JClass('java.lang.Object')) def distinct(self) -> 'QuadConstraintStream[A,B,C,D]': @@ -2083,6 +2174,55 @@ def concat(self, other): else: raise RuntimeError(f'Unhandled constraint stream type {type(other)}.') + @overload + def complement(self, cls: type[A]) -> 'QuadConstraintStream[A, B, C, D]': + ... + + @overload + def complement(self, cls: type[A], padding_b: Callable[[A], B], padding_c: Callable[[A], C], + padding_d: Callable[[A], D]) -> 'QuadConstraintStream[A, B, C, D]': + ... + + def complement(self, cls: type[A], padding_b=None, padding_c=None, padding_d=None): + """ + Adds to the stream all instances of a given class which are not yet present in it. + These instances must be present in the solution, + which means the class needs to be either a planning entity or a problem fact. + + The instances will be read from the first element of the input tuple. + When an output tuple needs to be created for the newly inserted instances, + the first element will be the new instance. + The rest of the tuple will be padded with the result of the padding function, + applied on the new instance. + + Padding functions are optional, but if one is provided, then all three must-be provided. + + Parameters + ---------- + cls : Type[A] + the type of the instances to add to the stream. + + padding_b : Callable[[A], B] + a function that computes the padding value for the second fact in the new tuple. + + padding_c : Callable[[A], C] + a function that computes the padding value for the third fact in the new tuple. + + padding_d : Callable[[A], D] + a function that computes the padding value for the fourth fact in the new tuple. + """ + if None == padding_b == padding_c == padding_d: + result = self.delegate.complement(get_class(cls)) + return QuadConstraintCollector(result, self.package, self.a_type, self.b_type, self.c_type, self.d_type) + specified_count = sum(x is not None for x in [padding_b, padding_c, padding_d]) + if specified_count != 0: + raise ValueError(f'If a padding function is provided, all 3 are expected, got {specified_count} instead.') + java_padding_b = function_cast(padding_b, self.a_type) + java_padding_c = function_cast(padding_c, self.a_type) + java_padding_d = function_cast(padding_d, self.a_type) + result = self.delegate.complement(get_class(cls), java_padding_b, java_padding_c, java_padding_d) + return QuadConstraintCollector(result, self.package, self.a_type, self.b_type, self.c_type, self.d_type) + def penalize(self, constraint_weight: ScoreType, match_weigher: Callable[[A, B, C, D], int] = None) -> 'QuadConstraintBuilder[A, B, C, D, ScoreType]': """ diff --git a/timefold-solver-python-core/src/main/python/score/_group_by.py b/timefold-solver-python-core/src/main/python/score/_group_by.py index 68c482c9..ba7727ea 100644 --- a/timefold-solver-python-core/src/main/python/score/_group_by.py +++ b/timefold-solver-python-core/src/main/python/score/_group_by.py @@ -3,6 +3,7 @@ from typing import Callable, Any, Sequence, TypeVar, List, Set, Dict, TYPE_CHECKING, overload if TYPE_CHECKING: from ai.timefold.solver.core.api.score.stream.common import SequenceChain + from ai.timefold.solver.core.api.score.stream.common import LoadBalance from ai.timefold.solver.core.api.score.stream.uni import UniConstraintCollector from ai.timefold.solver.core.api.score.stream.bi import BiConstraintCollector from ai.timefold.solver.core.api.score.stream.tri import TriConstraintCollector @@ -61,6 +62,14 @@ class CollectAndThenCollector: mapping_function: Callable +@dataclasses.dataclass +class LoadBalanceCollector: + collector_creator: Callable + balanced_item_function: Callable + load_function: Callable | None + initial_load_function: Callable | None + + def extract_collector(collector_info, *type_arguments): if isinstance(collector_info, NoArgsConstraintCollector): return collector_info.collector_creator() @@ -89,6 +98,15 @@ def extract_collector(collector_info, *type_arguments): delegate_collector = extract_collector(collector_info.delegate_collector, *type_arguments) mapping_function = function_cast(collector_info.mapping_function, JClass('java.lang.Object')) return collector_info.collector_creator(delegate_collector, mapping_function) + elif isinstance(collector_info, LoadBalanceCollector): + balanced_item_function = function_cast(collector_info.balanced_item_function, *type_arguments) + if collector_info.load_function is None: + return collector_info.collector_creator(balanced_item_function) + load_function = function_cast(collector_info.load_function, *type_arguments) + if collector_info.initial_load_function is None: + return collector_info.collector_creator(balanced_item_function, load_function) + initial_load_function = function_cast(collector_info.initial_load_function, *type_arguments) + return collector_info.collector_creator(balanced_item_function, load_function, initial_load_function) else: raise ValueError(f'Invalid Collector: {collector_info}. ' f'Create Collectors via timefold.solver.constraint.ConstraintCollectors.') @@ -135,6 +153,7 @@ class ConstraintCollectors: C = TypeVar('C') D = TypeVar('D') E = TypeVar('E') + Balanced = TypeVar('Balanced') # Method return type variables A_ = TypeVar('A_') @@ -142,6 +161,7 @@ class ConstraintCollectors: C_ = TypeVar('C_') D_ = TypeVar('D_') E_ = TypeVar('E_') + Balanced_ = TypeVar('Balanced_') @staticmethod def _delegate(): @@ -993,6 +1013,120 @@ def to_sorted_map(key_mapper, value_mapper, merge_function_or_set_creator=None): else: raise ValueError + @overload + @staticmethod + def load_balance(balanced_item_function: Callable[[A], Balanced_]) -> \ + 'UniConstraintCollector[A, Any, LoadBalance[Balanced_]]': + ... + + @overload + @staticmethod + def load_balance(balanced_item_function: Callable[[A], Balanced_], load_function: Callable[[A], int]) -> \ + 'UniConstraintCollector[A, Any, LoadBalance[Balanced_]]': + ... + + @overload + @staticmethod + def load_balance(balanced_item_function: Callable[[A], Balanced_], load_function: Callable[[A], int], + initial_load_function: Callable[[A], int]) -> \ + 'UniConstraintCollector[A, Any, LoadBalance[Balanced_]]': + ... + + @overload + @staticmethod + def load_balance(balanced_item_function: Callable[[A, B], Balanced_]) -> \ + 'BiConstraintCollector[A, B, Any, LoadBalance[Balanced_]]': + ... + + @overload + @staticmethod + def load_balance(balanced_item_function: Callable[[A, B], Balanced_], load_function: Callable[[A, B], int]) -> \ + 'BiConstraintCollector[A, B, Any, LoadBalance[Balanced_]]': + ... + + @overload + @staticmethod + def load_balance(balanced_item_function: Callable[[A, B], Balanced_], load_function: Callable[[A, B], int], + initial_load_function: Callable[[A, B], int]) -> \ + 'BiConstraintCollector[A, B, Any, LoadBalance[Balanced_]]': + ... + + @overload + @staticmethod + def load_balance(balanced_item_function: Callable[[A, B, C], Balanced_]) -> \ + 'TriConstraintCollector[A, B, C, Any, LoadBalance[Balanced_]]': + ... + + @overload + @staticmethod + def load_balance(balanced_item_function: Callable[[A, B, C], Balanced_], + load_function: Callable[[A, B, C], int]) -> \ + 'TriConstraintCollector[A, B, C, Any, LoadBalance[Balanced_]]': + ... + + @overload + @staticmethod + def load_balance(balanced_item_function: Callable[[A, B, C], Balanced_], load_function: Callable[[A, B, C], int], + initial_load_function: Callable[[A, B, C], int]) -> \ + 'TriConstraintCollector[A, B, C, Any, LoadBalance[Balanced_]]': + ... + + @overload + @staticmethod + def load_balance(balanced_item_function: Callable[[A, B, C, D], Balanced_]) -> \ + 'QuadConstraintCollector[A, B, C, D, Any, LoadBalance[Balanced_]]': + ... + + @overload + @staticmethod + def load_balance(balanced_item_function: Callable[[A, B, C, D], Balanced_], + load_function: Callable[[A, B, C, D], int]) -> \ + 'QuadConstraintCollector[A, B, C, D, Any, LoadBalance[Balanced_]]': + ... + + @overload + @staticmethod + def load_balance(balanced_item_function: Callable[[A, B, C, D], Balanced_], + load_function: Callable[[A, B, C, D], int], + initial_load_function: Callable[[A, B, C, D], int]) -> \ + 'QuadConstraintCollector[A, B, C, D, Any, LoadBalance[Balanced_]]': + ... + + @staticmethod + def load_balance(balanced_item_function, load_function=None, initial_load_function=None): + """ + Returns a collector that takes a stream of items and calculates the unfairness measure from them. + The load for every item is provided by the load_function, + with the starting load provided by the initial_load_function. + + When this collector is used in a constraint stream, + it is recommended to use a score type which supports real numbers. + This is so that the unfairness measure keeps its precision + without forcing the other constraints to be multiplied by a large constant, + which would otherwise be required to implement fixed-point arithmetic. + + Parameters + ---------- + balanced_item_function: + The function that returns the item which should be load-balanced. + load_function: + How much the item should count for in the formula. + initial_load_function: + The initial value of the metric, allowing to provide initial state + without requiring the entire previous planning windows in the working memory. + If this function is provided, load_function must be provided as well. + """ + if None == load_function == initial_load_function: + return LoadBalanceCollector(ConstraintCollectors._delegate().loadBalance, balanced_item_function, None, + None) + elif None == initial_load_function: + return LoadBalanceCollector(ConstraintCollectors._delegate().loadBalance, balanced_item_function, + load_function, None) + elif None == load_function: + raise ValueError("load_function cannot be None if initial_load_function is not None") + else: + return LoadBalanceCollector(ConstraintCollectors._delegate().loadBalance, balanced_item_function, + load_function, initial_load_function) # Must be at the bottom, constraint_stream depends on this module from ._constraint_stream import * From 48a7a6a4061fb5c17673a1f41d5ebb9b0c234ac9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luk=C3=A1=C5=A1=20Petrovick=C3=BD?= Date: Fri, 28 Jun 2024 09:49:37 +0200 Subject: [PATCH 2/9] chore: avoid IDE warnings --- .../src/main/python/score/_group_by.py | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/timefold-solver-python-core/src/main/python/score/_group_by.py b/timefold-solver-python-core/src/main/python/score/_group_by.py index ba7727ea..578705d9 100644 --- a/timefold-solver-python-core/src/main/python/score/_group_by.py +++ b/timefold-solver-python-core/src/main/python/score/_group_by.py @@ -1013,78 +1013,78 @@ def to_sorted_map(key_mapper, value_mapper, merge_function_or_set_creator=None): else: raise ValueError - @overload + @overload # noqa @staticmethod def load_balance(balanced_item_function: Callable[[A], Balanced_]) -> \ 'UniConstraintCollector[A, Any, LoadBalance[Balanced_]]': ... - @overload + @overload # noqa @staticmethod def load_balance(balanced_item_function: Callable[[A], Balanced_], load_function: Callable[[A], int]) -> \ 'UniConstraintCollector[A, Any, LoadBalance[Balanced_]]': ... - @overload + @overload # noqa @staticmethod def load_balance(balanced_item_function: Callable[[A], Balanced_], load_function: Callable[[A], int], initial_load_function: Callable[[A], int]) -> \ 'UniConstraintCollector[A, Any, LoadBalance[Balanced_]]': ... - @overload + @overload # noqa @staticmethod def load_balance(balanced_item_function: Callable[[A, B], Balanced_]) -> \ 'BiConstraintCollector[A, B, Any, LoadBalance[Balanced_]]': ... - @overload + @overload # noqa @staticmethod def load_balance(balanced_item_function: Callable[[A, B], Balanced_], load_function: Callable[[A, B], int]) -> \ 'BiConstraintCollector[A, B, Any, LoadBalance[Balanced_]]': ... - @overload + @overload # noqa @staticmethod def load_balance(balanced_item_function: Callable[[A, B], Balanced_], load_function: Callable[[A, B], int], initial_load_function: Callable[[A, B], int]) -> \ 'BiConstraintCollector[A, B, Any, LoadBalance[Balanced_]]': ... - @overload + @overload # noqa @staticmethod def load_balance(balanced_item_function: Callable[[A, B, C], Balanced_]) -> \ 'TriConstraintCollector[A, B, C, Any, LoadBalance[Balanced_]]': ... - @overload + @overload # noqa @staticmethod def load_balance(balanced_item_function: Callable[[A, B, C], Balanced_], load_function: Callable[[A, B, C], int]) -> \ 'TriConstraintCollector[A, B, C, Any, LoadBalance[Balanced_]]': ... - @overload + @overload # noqa @staticmethod def load_balance(balanced_item_function: Callable[[A, B, C], Balanced_], load_function: Callable[[A, B, C], int], initial_load_function: Callable[[A, B, C], int]) -> \ 'TriConstraintCollector[A, B, C, Any, LoadBalance[Balanced_]]': ... - @overload + @overload # noqa @staticmethod def load_balance(balanced_item_function: Callable[[A, B, C, D], Balanced_]) -> \ 'QuadConstraintCollector[A, B, C, D, Any, LoadBalance[Balanced_]]': ... - @overload + @overload # noqa @staticmethod def load_balance(balanced_item_function: Callable[[A, B, C, D], Balanced_], load_function: Callable[[A, B, C, D], int]) -> \ 'QuadConstraintCollector[A, B, C, D, Any, LoadBalance[Balanced_]]': ... - @overload + @overload # noqa @staticmethod def load_balance(balanced_item_function: Callable[[A, B, C, D], Balanced_], load_function: Callable[[A, B, C, D], int], From efecb6aa12fbcaaf36423fec5cfab1f8cde55067 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luk=C3=A1=C5=A1=20Petrovick=C3=BD?= Date: Fri, 28 Jun 2024 11:03:51 +0200 Subject: [PATCH 3/9] chore: add concat overloads --- tests/test_constraint_streams.py | 35 +++ .../main/python/score/_constraint_stream.py | 277 ++++++++++++++---- 2 files changed, 256 insertions(+), 56 deletions(-) diff --git a/tests/test_constraint_streams.py b/tests/test_constraint_streams.py index 488de865..e862b67a 100644 --- a/tests/test_constraint_streams.py +++ b/tests/test_constraint_streams.py @@ -436,6 +436,41 @@ def define_constraints(constraint_factory: ConstraintFactory): assert score_manager.explain(problem).score.score == 1 +def test_complement(): + @constraint_provider + def define_constraints(constraint_factory: ConstraintFactory): + return [ + constraint_factory.for_each(Entity) + .filter(lambda e: e.value.number == 1) + .complement(Entity) + .reward(SimpleScore.ONE) + .as_constraint('Count') + ] + + score_manager = create_score_manager(define_constraints) + entity_a: Entity = Entity('A') + entity_b: Entity = Entity('B') + + value_1 = Value(1) + value_2 = Value(2) + value_3 = Value(3) + + problem = Solution([entity_a, entity_b], [value_1, value_2, value_3]) + + assert score_manager.explain(problem).score.score == 0 + + entity_a.value = value_1 + + assert score_manager.explain(problem).score.score == 1 + + entity_b.value = value_2 + + assert score_manager.explain(problem).score.score == 2 + + entity_b.value = value_3 + + assert score_manager.explain(problem).score.score == 2 + def test_custom_indictments(): @dataclass(unsafe_hash=True) diff --git a/timefold-solver-python-core/src/main/python/score/_constraint_stream.py b/timefold-solver-python-core/src/main/python/score/_constraint_stream.py index 2afdcfd1..88be532f 100644 --- a/timefold-solver-python-core/src/main/python/score/_constraint_stream.py +++ b/timefold-solver-python-core/src/main/python/score/_constraint_stream.py @@ -418,15 +418,29 @@ def concat(self, other: 'UniConstraintStream[A]') -> 'UniConstraintStream[A]': def concat(self, other: 'BiConstraintStream[A, B_]') -> 'BiConstraintStream[A, B_]': ... + @overload + def concat(self, other: 'BiConstraintStream[A, B_]', padding_b: Callable[[A], B_]) -> 'BiConstraintStream[A, B_]': + ... + @overload def concat(self, other: 'TriConstraintStream[A, B_, C_]') -> 'TriConstraintStream[A, B_, C_]': ... + @overload + def concat(self, other: 'TriConstraintStream[A, B_, C_]', padding_b: Callable[[A], B_], + padding_c: Callable[[A], C_]) -> 'TriConstraintStream[A, B_, C_]': + ... + @overload def concat(self, other: 'QuadConstraintStream[A, B_, C_, D_]') -> 'QuadConstraintStream[A, B_, C_, D_]': ... - def concat(self, other): + @overload + def concat(self, other: 'QuadConstraintStream[A, B_, C_, D_]', padding_b: Callable[[A], B_], + padding_c: Callable[[A], C_], padding_d: Callable[[A], D_]) -> 'QuadConstraintStream[A, B_, C_, D_]': + ... + + def concat(self, other, padding_b=None, padding_c=None, padding_d=None): """ The concat building block allows you to create a constraint stream containing tuples of two other constraint streams. @@ -436,21 +450,46 @@ def concat(self, other): when they come from the same source of data, the tuples will be repeated downstream. If this is undesired, use the distinct building block. """ + specified_count = sum(x is not None for x in [padding_b, padding_c, padding_d]) if isinstance(other, UniConstraintStream): - return UniConstraintStream(self.delegate.concat(other.delegate), self.package, - self.a_type) + if specified_count == 0: + return UniConstraintStream(self.delegate.concat(other.delegate), self.package, + self.a_type) + else: + raise ValueError(f'Concatenating UniConstraintStreams requires no padding functions, ' + f'got {specified_count} instead.') elif isinstance(other, BiConstraintStream): - return BiConstraintStream(self.delegate.concat(other.delegate), self.package, - self.a_type, - other.b_type) + if specified_count == 0: + return BiConstraintStream(self.delegate.concat(other.delegate), self.package, + self.a_type, other.b_type) + elif specified_count > 1: + raise ValueError(f'Concatenating Uni and BiConstraintStream requires 1 padding function, ' + f'got {specified_count} instead.') + elif padding_b is None: + raise ValueError(f'Concatenating Uni and BiConstraintStream requires padding_b to be provided.') + return BiConstraintStream(self.delegate.concat(other.delegate, padding_b), self.package, + self.a_type, other.b_type) elif isinstance(other, TriConstraintStream): - return TriConstraintStream(self.delegate.concat(other.delegate), self.package, - self.a_type, - other.b_type, other.c_type) + if specified_count == 0: + return TriConstraintStream(self.delegate.concat(other.delegate), self.package, + self.a_type, other.b_type, other.c_type) + elif specified_count != 2: + raise ValueError(f'Concatenating Uni and TriConstraintStream requires 2 padding functions, ' + f'got {specified_count} instead.') + elif padding_d is not None: + raise ValueError(f'Concatenating Uni and TriConstraintStream requires ' + f'padding_b and padding_c to be provided.') + return TriConstraintStream(self.delegate.concat(other.delegate, padding_b, padding_c), self.package, + self.a_type, other.b_type, other.c_type) elif isinstance(other, QuadConstraintStream): - return QuadConstraintStream(self.delegate.concat(other.delegate), self.package, - self.a_type, - other.b_type, other.c_type, other.d_type) + if specified_count == 0: + return QuadConstraintStream(self.delegate.concat(other.delegate), + self.package, self.a_type, other.b_type, other.c_type, other.d_type) + elif specified_count != 3: + raise ValueError(f'Concatenating Uni and QuadConstraintStream requires 3 padding functions, ' + f'got {specified_count} instead.') + return QuadConstraintStream(self.delegate.concat(other.delegate, padding_b, padding_c, padding_d), + self.package, self.a_type, other.b_type, other.c_type, other.d_type) else: raise RuntimeError(f'Unhandled constraint stream type {type(other)}.') @@ -466,7 +505,7 @@ def complement(self, cls: type[A]) -> 'UniConstraintStream[A]': the type of the instances to add to the stream. """ result = self.delegate.complement(get_class(cls)) - return TriConstraintCollector(result, self.package, self.a_type) + return UniConstraintStream(result, self.package, self.a_type) def penalize(self, constraint_weight: ScoreType, match_weigher: Callable[[A], int] = None) -> \ 'UniConstraintBuilder[A, ScoreType]': @@ -974,6 +1013,10 @@ def distinct(self) -> 'BiConstraintStream[A,B]': def concat(self, other: 'UniConstraintStream[A]') -> 'BiConstraintStream[A, B]': ... + @overload + def concat(self, other: 'UniConstraintStream[A]', padding_b: Callable[[A], B]) -> 'BiConstraintStream[A, B]': + ... + @overload def concat(self, other: 'BiConstraintStream[A, B]') -> 'BiConstraintStream[A, B]': ... @@ -982,11 +1025,21 @@ def concat(self, other: 'BiConstraintStream[A, B]') -> 'BiConstraintStream[A, B] def concat(self, other: 'TriConstraintStream[A, B, C_]') -> 'TriConstraintStream[A, B, C_]': ... + @overload + def concat(self, other: 'TriConstraintStream[A, B, C_]', padding_c: Callable[[A, B], C_]) \ + -> 'TriConstraintStream[A, B, C_]': + ... + @overload def concat(self, other: 'QuadConstraintStream[A, B, C_, D_]') -> 'QuadConstraintStream[A, B, C_, D_]': ... - def concat(self, other): + @overload + def concat(self, other: 'QuadConstraintStream[A, B, C_, D_]', padding_c: Callable[[A, B], C_], + padding_d: Callable[[A, B], D_]) -> 'QuadConstraintStream[A, B, C_, D_]': + ... + + def concat(self, other, padding_b=None, padding_c=None, padding_d=None): """ The concat building block allows you to create a constraint stream containing tuples of two other constraint streams. @@ -996,21 +1049,48 @@ def concat(self, other): when they come from the same source of data, the tuples will be repeated downstream. If this is undesired, use the distinct building block. """ + specified_count = sum(x is not None for x in [padding_b, padding_c, padding_d]) if isinstance(other, UniConstraintStream): - return BiConstraintStream(self.delegate.concat(other.delegate), self.package, + if specified_count == 0: + return BiConstraintStream(self.delegate.concat(other.delegate), self.package, + self.a_type, self.b_type) + elif specified_count != 1: + raise ValueError(f'Concatenating Bi and UniConstraintStream requires one padding function, ' + f'got {specified_count} instead.') + elif padding_b is None: + raise ValueError(f'Concatenating Bi and UniConstraintStream requires padding_b to be provided.') + return BiConstraintStream(self.delegate.concat(other.delegate, padding_b), self.package, self.a_type, self.b_type) elif isinstance(other, BiConstraintStream): - return BiConstraintStream(self.delegate.concat(other.delegate), self.package, - self.a_type, - self.b_type) + if specified_count == 0: + return BiConstraintStream(self.delegate.concat(other.delegate), self.package, + self.a_type, self.b_type) + else: + raise ValueError(f'Concatenating BiConstraintStreams requires no padding function, ' + f'got {specified_count} instead.') elif isinstance(other, TriConstraintStream): - return TriConstraintStream(self.delegate.concat(other.delegate), self.package, - self.a_type, - self.b_type, other.c_type) + if specified_count == 0: + return TriConstraintStream(self.delegate.concat(other.delegate), self.package, + self.a_type, self.b_type, other.c_type) + elif specified_count != 1: + raise ValueError(f'Concatenating Bi and TriConstraintStream requires one padding function, ' + f'got {specified_count} instead.') + elif padding_c is None: + raise ValueError(f'Concatenating Bi and TriConstraintStream requires padding_c to be provided.') + return TriConstraintStream(self.delegate.concat(other.delegate, padding_c), self.package, + self.a_type, self.b_type, other.c_type) elif isinstance(other, QuadConstraintStream): - return QuadConstraintStream(self.delegate.concat(other.delegate), self.package, - self.a_type, - self.b_type, other.c_type, other.d_type) + if specified_count == 0: + return QuadConstraintStream(self.delegate.concat(other.delegate), self.package, + self.a_type, self.b_type, other.c_type, other.d_type) + elif specified_count != 2: + raise ValueError(f'Concatenating Bi and QuadConstraintStream requires two padding functions, ' + f'got {specified_count} instead.') + elif padding_b is not None: + raise ValueError(f'Concatenating Bi and QuadConstraintStream requires ' + f'padding_c and padding_d to be provided.') + return QuadConstraintStream(self.delegate.concat(other.delegate, padding_c, padding_d), self.package, + self.a_type, self.b_type, other.c_type, other.d_type) else: raise RuntimeError(f'Unhandled constraint stream type {type(other)}.') @@ -1043,10 +1123,10 @@ def complement(self, cls: type[A], padding=None): """ if None == padding: result = self.delegate.complement(get_class(cls)) - return TriConstraintCollector(result, self.package, self.a_type, self.b_type) + return BiConstraintStream(result, self.package, self.a_type, self.b_type) java_padding = function_cast(padding, self.a_type) result = self.delegate.complement(get_class(cls), java_padding) - return TriConstraintCollector(result, self.package, self.a_type, self.b_type) + return BiConstraintStream(result, self.package, self.a_type, self.b_type) def penalize(self, constraint_weight: ScoreType, match_weigher: Callable[[A, B], int] = None) -> \ 'BiConstraintBuilder[A, B, ScoreType]': @@ -1551,10 +1631,20 @@ def distinct(self) -> 'TriConstraintStream[A, B, C]': def concat(self, other: 'UniConstraintStream[A]') -> 'TriConstraintStream[A, B, C]': ... + @overload + def concat(self, other: 'UniConstraintStream[A]', padding_b: Callable[[A], B], padding_c: Callable[[A], C]) \ + -> 'TriConstraintStream[A, B, C]': + ... + @overload def concat(self, other: 'BiConstraintStream[A, B]') -> 'TriConstraintStream[A, B, C]': ... + @overload + def concat(self, other: 'BiConstraintStream[A, B]', padding_c: Callable[[A, B], C]) \ + -> 'TriConstraintStream[A, B, C]': + ... + @overload def concat(self, other: 'TriConstraintStream[A, B, C]') -> 'TriConstraintStream[A, B, C]': ... @@ -1563,7 +1653,12 @@ def concat(self, other: 'TriConstraintStream[A, B, C]') -> 'TriConstraintStream[ def concat(self, other: 'QuadConstraintStream[A, B, C, D_]') -> 'QuadConstraintStream[A, B, C, D_]': ... - def concat(self, other): + @overload + def concat(self, other: 'QuadConstraintStream[A, B, C, D_]', padding_d: Callable[[A, B, C], D_]) \ + -> 'QuadConstraintStream[A, B, C, D_]': + ... + + def concat(self, other, padding_b=None, padding_c=None, padding_d=None): """ The concat building block allows you to create a constraint stream containing tuples of two other constraint streams. @@ -1573,22 +1668,48 @@ def concat(self, other): when they come from the same source of data, the tuples will be repeated downstream. If this is undesired, use the distinct building block. """ + specified_count = sum(x is not None for x in [padding_b, padding_c, padding_d]) if isinstance(other, UniConstraintStream): - return TriConstraintStream(self.delegate.concat(other.delegate), self.package, - self.a_type, - self.b_type, self.c_type) + if specified_count == 0: + return TriConstraintStream(self.delegate.concat(other.delegate), self.package, + self.a_type, self.b_type, self.c_type) + elif specified_count != 2: + raise ValueError(f'Concatenating Tri and UniConstraintStream requires 2 padding functions, ' + f'got {specified_count} instead.') + elif padding_d is not None: + raise ValueError(f'Concatenating Tri and UniConstraintStream requires ' + f'padding_b and padding_c to be provided.') + return TriConstraintStream(self.delegate.concat(other.delegate, padding_b, padding_c), self.package, + self.a_type, self.b_type, self.c_type) elif isinstance(other, BiConstraintStream): - return TriConstraintStream(self.delegate.concat(other.delegate), self.package, - self.a_type, - self.b_type, self.c_type) + if specified_count == 0: + return TriConstraintStream(self.delegate.concat(other.delegate), self.package, + self.a_type, self.b_type, self.c_type) + elif specified_count != 1: + raise ValueError(f'Concatenating Tri and BiConstraintStream requires 1 padding function, ' + f'got {specified_count} instead.') + elif padding_c is None: + raise ValueError(f'Concatenating Tri and BiConstraintStream requires padding_c to be provided.') + return TriConstraintStream(self.delegate.concat(other.delegate, padding_c), self.package, + self.a_type, self.b_type, self.c_type) elif isinstance(other, TriConstraintStream): - return TriConstraintStream(self.delegate.concat(other.delegate), self.package, - self.a_type, - self.b_type, self.c_type) + if specified_count == 0: + return TriConstraintStream(self.delegate.concat(other.delegate), self.package, + self.a_type, self.b_type, self.c_type) + else: + raise ValueError(f'Concatenating TriConstraintStreams requires no padding functions, ' + f'got {specified_count} instead.') elif isinstance(other, QuadConstraintStream): - return QuadConstraintStream(self.delegate.concat(other.delegate), self.package, - self.a_type, - self.b_type, self.c_type, other.d_type) + if specified_count == 0: + return QuadConstraintStream(self.delegate.concat(other.delegate), self.package, + self.a_type, self.b_type, self.c_type, other.d_type) + elif specified_count != 1: + raise ValueError(f'Concatenating Tri and QuadConstraintStream requires 1 padding function, ' + f'got {specified_count} instead.') + elif padding_d is None: + raise ValueError(f'Concatenating Tri and QuadConstraintStream requires padding_d to be provided.') + return QuadConstraintStream(self.delegate.concat(other.delegate, padding_d), self.package, + self.a_type, self.b_type, self.c_type, other.d_type) else: raise RuntimeError(f'Unhandled constraint stream type {type(other)}.') @@ -1628,14 +1749,14 @@ def complement(self, cls: type[A], padding_b=None, padding_c=None): """ if None == padding_b == padding_c: result = self.delegate.complement(get_class(cls)) - return TriConstraintCollector(result, self.package, self.a_type, self.b_type, self.c_type) + return TriConstraintStream(result, self.package, self.a_type, self.b_type, self.c_type) specified_count = sum(x is not None for x in [padding_b, padding_c]) if specified_count != 0: raise ValueError(f'If a padding function is provided, both are expected, got {specified_count} instead.') java_padding_b = function_cast(padding_b, self.a_type) java_padding_c = function_cast(padding_c, self.a_type) result = self.delegate.complement(get_class(cls), java_padding_b, java_padding_c) - return TriConstraintCollector(result, self.package, self.a_type, self.b_type, self.c_type) + return TriConstraintStream(result, self.package, self.a_type, self.b_type, self.c_type) def penalize(self, constraint_weight: ScoreType, match_weigher: Callable[[A, B, C], int] = None) -> 'TriConstraintBuilder[A, B, C, ScoreType]': @@ -2133,6 +2254,20 @@ def distinct(self) -> 'QuadConstraintStream[A,B,C,D]': def concat(self, other: 'UniConstraintStream[A]') -> 'QuadConstraintStream[A, B, C, D]': ... + @overload + def concat(self, other: 'UniConstraintStream[A]', padding_b: Callable[[A], B], padding_c: Callable[[A], C], + padding_d: Callable[[A], D]) -> 'QuadConstraintStream[A, B, C, D]': + ... + + @overload + def concat(self, other: 'BiConstraintStream[A, B]') -> 'QuadConstraintStream[A, B, C, D]': + ... + + @overload + def concat(self, other: 'BiConstraintStream[A, B]', padding_c: Callable[[A, B], C], + padding_d: Callable[[A, B], D]) -> 'QuadConstraintStream[A, B, C, D]': + ... + @overload def concat(self, other: 'BiConstraintStream[A, B]') -> 'QuadConstraintStream[A, B, C, D]': ... @@ -2141,11 +2276,16 @@ def concat(self, other: 'BiConstraintStream[A, B]') -> 'QuadConstraintStream[A, def concat(self, other: 'TriConstraintStream[A, B, C]') -> 'QuadConstraintStream[A, B, C, D]': ... + @overload + def concat(self, other: 'TriConstraintStream[A, B, C]', padding_d: Callable[[A, B, C], D]) \ + -> 'QuadConstraintStream[A, B, C, D]': + ... + @overload def concat(self, other: 'QuadConstraintStream[A, B, C, D]') -> 'QuadConstraintStream[A, B, C, D]': ... - def concat(self, other): + def concat(self, other, padding_b=None, padding_c=None, padding_d=None): """ The concat building block allows you to create a constraint stream containing tuples of two other constraint streams. @@ -2155,22 +2295,47 @@ def concat(self, other): when they come from the same source of data, the tuples will be repeated downstream. If this is undesired, use the distinct building block. """ + specified_count = sum(x is not None for x in [padding_b, padding_c, padding_d]) if isinstance(other, UniConstraintStream): - return QuadConstraintStream(self.delegate.concat(other.delegate), self.package, - self.a_type, - self.b_type, self.c_type, self.d_type) + if specified_count == 0: + return QuadConstraintStream(self.delegate.concat(other.delegate), self.package, + self.a_type, self.b_type, self.c_type, self.d_type) + elif specified_count != 3: + raise ValueError(f'Concatenating Uni and QuadConstraintStream requires 3 padding functions, ' + f'got {specified_count} instead.') + return QuadConstraintStream(self.delegate.concat(other.delegate, padding_b, padding_c, padding_d), + self.package, + self.a_type, self.b_type, self.c_type, self.d_type) elif isinstance(other, BiConstraintStream): - return QuadConstraintStream(self.delegate.concat(other.delegate), self.package, - self.a_type, - self.b_type, self.c_type, self.d_type) + if specified_count == 0: + return QuadConstraintStream(self.delegate.concat(other.delegate), self.package, + self.a_type, self.b_type, self.c_type, self.d_type) + elif specified_count != 2: + raise ValueError(f'Concatenating Bi and QuadConstraintStream requires 2 padding functions, ' + f'got {specified_count} instead.') + elif padding_b is not None: + raise ValueError(f'Concatenating Bi and QuadConstraintStream requires ' + f'padding_c and padding_d to be provided.') + return QuadConstraintStream(self.delegate.concat(other.delegate, padding_c, padding_d), self.package, + self.a_type, self.b_type, self.c_type, self.d_type) elif isinstance(other, TriConstraintStream): - return QuadConstraintStream(self.delegate.concat(other.delegate), self.package, - self.a_type, - self.b_type, self.c_type, self.d_type) + if specified_count == 0: + return QuadConstraintStream(self.delegate.concat(other.delegate), self.package, + self.a_type, self.b_type, self.c_type, self.d_type) + elif specified_count != 1: + raise ValueError(f'Concatenating Tri and QuadConstraintStream requires 1 padding function, ' + f'got {specified_count} instead.') + elif padding_d is None: + raise ValueError(f'Concatenating Bi and QuadConstraintStream requires padding_d to be provided.') + return QuadConstraintStream(self.delegate.concat(other.delegate, padding_d), self.package, + self.a_type, self.b_type, self.c_type, self.d_type) elif isinstance(other, QuadConstraintStream): - return QuadConstraintStream(self.delegate.concat(other.delegate), self.package, - self.a_type, - self.b_type, self.c_type, self.d_type) + if specified_count == 0: + return QuadConstraintStream(self.delegate.concat(other.delegate), self.package, + self.a_type, self.b_type, self.c_type, self.d_type) + else: + raise ValueError(f'Concatenating QuadConstraintStreams requires no padding functions, ' + f'got {specified_count} instead.') else: raise RuntimeError(f'Unhandled constraint stream type {type(other)}.') @@ -2213,7 +2378,7 @@ def complement(self, cls: type[A], padding_b=None, padding_c=None, padding_d=Non """ if None == padding_b == padding_c == padding_d: result = self.delegate.complement(get_class(cls)) - return QuadConstraintCollector(result, self.package, self.a_type, self.b_type, self.c_type, self.d_type) + return QuadConstraintStream(result, self.package, self.a_type, self.b_type, self.c_type, self.d_type) specified_count = sum(x is not None for x in [padding_b, padding_c, padding_d]) if specified_count != 0: raise ValueError(f'If a padding function is provided, all 3 are expected, got {specified_count} instead.') @@ -2221,7 +2386,7 @@ def complement(self, cls: type[A], padding_b=None, padding_c=None, padding_d=Non java_padding_c = function_cast(padding_c, self.a_type) java_padding_d = function_cast(padding_d, self.a_type) result = self.delegate.complement(get_class(cls), java_padding_b, java_padding_c, java_padding_d) - return QuadConstraintCollector(result, self.package, self.a_type, self.b_type, self.c_type, self.d_type) + return QuadConstraintStream(result, self.package, self.a_type, self.b_type, self.c_type, self.d_type) def penalize(self, constraint_weight: ScoreType, match_weigher: Callable[[A, B, C, D], int] = None) -> 'QuadConstraintBuilder[A, B, C, D, ScoreType]': From e60cd2881a69c937a5bb57d64c3a14d586ab39ec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luk=C3=A1=C5=A1=20Petrovick=C3=BD?= Date: Fri, 28 Jun 2024 12:20:09 +0200 Subject: [PATCH 4/9] chore: add if(Not)Exists overloads --- tests/test_constraint_streams.py | 81 +++++ .../main/python/score/_constraint_stream.py | 342 +++++++++++------- 2 files changed, 302 insertions(+), 121 deletions(-) diff --git a/tests/test_constraint_streams.py b/tests/test_constraint_streams.py index e862b67a..555906e8 100644 --- a/tests/test_constraint_streams.py +++ b/tests/test_constraint_streams.py @@ -265,6 +265,87 @@ def define_constraints(constraint_factory: ConstraintFactory): assert score_manager.explain(problem).score.score == 8 +def test_if_exists_uni(): + @constraint_provider + def define_constraints(constraint_factory: ConstraintFactory): + return [ + constraint_factory.for_each(Entity) + .if_exists(Entity, Joiners.equal(lambda entity: entity.code)) + .reward(SimpleScore.ONE, lambda e1: e1.value.number) + .as_constraint('Count') + ] + + score_manager = create_score_manager(define_constraints) + entity_a1: Entity = Entity('A') + entity_a2: Entity = Entity('A') + entity_b1: Entity = Entity('B') + entity_b2: Entity = Entity('B') + + value_1 = Value(1) + value_2 = Value(2) + + problem = Solution([entity_a1, entity_a2, entity_b1, entity_b2], [value_1, value_2]) + + entity_a1.value = value_1 + + # With itself + assert score_manager.explain(problem).score.score == 1 + + entity_a1.value = value_1 + entity_a2.value = value_1 + + entity_b1.value = value_2 + entity_b2.value = value_2 + + # 1 + 2 + 1 + 2 + assert score_manager.explain(problem).score.score == 6 + + entity_a1.value = value_2 + entity_b1.value = value_1 + + # 1 + 2 + 1 + 2 + assert score_manager.explain(problem).score.score == 6 + + +def test_if_not_exists_uni(): + @constraint_provider + def define_constraints(constraint_factory: ConstraintFactory): + return [ + constraint_factory.for_each(Entity) + .if_not_exists(Entity, Joiners.equal(lambda entity: entity.code)) + .reward(SimpleScore.ONE, lambda e1: e1.value.number) + .as_constraint('Count') + ] + + score_manager = create_score_manager(define_constraints) + entity_a1: Entity = Entity('A') + entity_a2: Entity = Entity('A') + entity_b1: Entity = Entity('B') + entity_b2: Entity = Entity('B') + + value_1 = Value(1) + value_2 = Value(2) + + problem = Solution([entity_a1, entity_a2, entity_b1, entity_b2], [value_1, value_2]) + + entity_a1.value = value_1 + + assert score_manager.explain(problem).score.score == 0 + + entity_a1.value = value_1 + entity_a2.value = value_1 + + entity_b1.value = value_2 + entity_b2.value = value_2 + + assert score_manager.explain(problem).score.score == 0 + + entity_a1.value = value_2 + entity_b1.value = value_1 + + assert score_manager.explain(problem).score.score == 0 + + def test_map(): @constraint_provider def define_constraints(constraint_factory: ConstraintFactory): diff --git a/timefold-solver-python-core/src/main/python/score/_constraint_stream.py b/timefold-solver-python-core/src/main/python/score/_constraint_stream.py index 88be532f..32461644 100644 --- a/timefold-solver-python-core/src/main/python/score/_constraint_stream.py +++ b/timefold-solver-python-core/src/main/python/score/_constraint_stream.py @@ -58,7 +58,7 @@ def filter(self, predicate: Callable[[A], bool]) -> 'UniConstraintStream[A]': def join(self, unistream_or_type: Union['UniConstraintStream[B_]', Type[B_]], *joiners: 'BiJoiner[A, B_]') -> \ 'BiConstraintStream[A,B_]': """ - Create a new `BiConstraintStream` for every combination of A and B that satisfy all specified joiners. + Create a new `BiConstraintStream` for every combination of A and B that satisfies all specified joiners. """ b_type = None if isinstance(unistream_or_type, UniConstraintStream): @@ -72,26 +72,42 @@ def join(self, unistream_or_type: Union['UniConstraintStream[B_]', Type[B_]], *j return BiConstraintStream(join_result, self.package, self.a_type, b_type) + @overload def if_exists(self, item_type: Type[B_], *joiners: 'BiJoiner[A, B_]') -> 'UniConstraintStream[A]': + ... + + @overload + def if_exists(self, other_stream: 'UniConstraintStream[B_]', *joiners: 'BiJoiner[A, B_]') \ + -> 'UniConstraintStream[A]': + ... + + def if_exists(self, unistream_or_type: Union['UniConstraintStream[B_]', Type[B_]], + *joiners: 'BiJoiner[A, B_]') -> 'UniConstraintStream[A]': """ - Create a new UniConstraintStream for every A where B exists that satisfy all specified joiners. + Create a new `UniConstraintStream` for every A where B exists that satisfies all specified joiners. """ - item_type = get_class(item_type) - return UniConstraintStream(self.delegate.ifExists(item_type, - extract_joiners(joiners, self.a_type, item_type)), + b_type = None + if isinstance(unistream_or_type, UniConstraintStream): + b_type = unistream_or_type.a_type + unistream_or_type = unistream_or_type.delegate + else: + b_type = get_class(unistream_or_type) + unistream_or_type = b_type + return UniConstraintStream(self.delegate.ifExists(unistream_or_type, + extract_joiners(joiners, + self.a_type, b_type)), self.package, self.a_type) def if_exists_including_unassigned(self, item_type: Type[B_], *joiners: 'BiJoiner[A, B_]') -> \ 'UniConstraintStream[A]': """ - Create a new `UniConstraintStream` for every A where B exists that satisfy all specified joiners. + Create a new `UniConstraintStream` for every A where B exists that satisfies all specified joiners. """ item_type = get_class(item_type) return UniConstraintStream(self.delegate.ifExistsIncludingUnassigned(item_type, - extract_joiners(joiners, self.a_type, - item_type)), - self.package, - self.a_type) + extract_joiners(joiners, + self.a_type, item_type)), + self.package, self.a_type) def if_exists_other(self, item_type: Type[B_], *joiners: 'BiJoiner[A, B_]') -> 'UniConstraintStream[A]': """ @@ -101,8 +117,7 @@ def if_exists_other(self, item_type: Type[B_], *joiners: 'BiJoiner[A, B_]') -> ' item_type = get_class(item_type) return UniConstraintStream(self.delegate.ifExistsOther(cast(Type['A_'], item_type), extract_joiners(joiners, - self.a_type, - item_type)), + self.a_type, item_type)), self.package, self.a_type) def if_exists_other_including_unassigned(self, item_type: Type, *joiners: 'BiJoiner') -> \ @@ -115,65 +130,69 @@ def if_exists_other_including_unassigned(self, item_type: Type, *joiners: 'BiJoi item_type = get_class(item_type) return UniConstraintStream(self.delegate.ifExistsOtherIncludingUnassigned(cast(Type['A_'], item_type), extract_joiners(joiners, - self.a_type, - item_type)), - self.package, - - self.a_type) + self.a_type, item_type)), + self.package, self.a_type) + @overload def if_not_exists(self, item_type: Type[B_], *joiners: 'BiJoiner[A, B_]') -> 'UniConstraintStream[A]': + ... + + @overload + def if_not_exists(self, other_stream: 'UniConstraintStream[B_]', *joiners: 'BiJoiner[A, B_]') \ + -> 'UniConstraintStream[A]': + ... + + def if_not_exists(self, unistream_or_type: Union['UniConstraintStream[B_]', Type[B_]], + *joiners: 'BiJoiner[A, B_]') -> 'UniConstraintStream[A]': """ - Create a new `UniConstraintStream` for every A where there does not exist a B where all specified joiners - are satisfied. + Create a new `UniConstraintStream` for every A where B does not exist that satisfies all specified joiners. """ - item_type = get_class(item_type) - return UniConstraintStream(self.delegate.ifNotExists(item_type, extract_joiners(joiners, self.a_type, - item_type)), + b_type = None + if isinstance(unistream_or_type, UniConstraintStream): + b_type = unistream_or_type.a_type + unistream_or_type = unistream_or_type.delegate + else: + b_type = get_class(unistream_or_type) + unistream_or_type = b_type + return UniConstraintStream(self.delegate.ifNotExists(unistream_or_type, + extract_joiners(joiners, + self.a_type, b_type)), self.package, self.a_type) def if_not_exists_including_unassigned(self, item_type: Type[B_], *joiners: 'BiJoiner[A, B_]') -> \ 'UniConstraintStream[A]': """ - Create a new `UniConstraintStream` for every A where there does not exist a B where all specified joiners are - satisfied. - """ + Create a new `UniConstraintStream` for every A where B does not exist that satisfies all specified joiners. + """ item_type = get_class(item_type) return UniConstraintStream(self.delegate.ifNotExistsIncludingUnassigned(item_type, extract_joiners(joiners, - self.a_type, - item_type)), - self.package, - - self.a_type) + self.a_type, item_type)), + self.package, self.a_type) def if_not_exists_other(self, item_type: Type[B_], *joiners: 'BiJoiner[A, B_]') -> \ 'UniConstraintStream[A]': """ - Create a new `UniConstraintStream` for every A where there does not exist a different A where all specified - joiners are satisfied. + Create a new `UniConstraintStream` for every A where B does not exist that satisfies all specified joiners. """ item_type = get_class(item_type) return UniConstraintStream(self.delegate.ifNotExistsOther(cast(Type['A_'], item_type), extract_joiners(joiners, self.a_type, item_type)), - self.package, - - self.a_type) + self.package, self.a_type) def if_not_exists_other_including_unassigned(self, item_type: Type[B_], *joiners: 'BiJoiner[A, B_]') -> \ 'UniConstraintStream[A]': """ - Create a new `UniConstraintStream` for every A where there does not exist a different A where all specified - joiners are satisfied. + Create a new `UniConstraintStream` for every A where a different A does not exist + that satisfies all specified joiners. """ item_type = get_class(item_type) return UniConstraintStream(self.delegate.ifNotExistsOtherIncludingUnassigned(cast(Type['A_'], item_type), extract_joiners(joiners, - self.a_type, - item_type)), - self.package, - self.a_type) + self.a_type, item_type)), + self.package, self.a_type) @overload def group_by(self, key_mapping: Callable[[A], A_]) -> 'UniConstraintStream[A_]': @@ -714,7 +733,7 @@ def filter(self, predicate: Callable[[A, B], bool]) -> 'BiConstraintStream[A,B]' def join(self, unistream_or_type: Union[UniConstraintStream[C_], Type[C_]], *joiners: 'TriJoiner[A,B,C_]') -> 'TriConstraintStream[A,B,C_]': """ - Create a new `TriConstraintStream` for every combination of A, B and C that satisfy all specified joiners. + Create a new `TriConstraintStream` for every combination of A, B and C that satisfies all specified joiners. """ c_type = None if isinstance(unistream_or_type, UniConstraintStream): @@ -724,62 +743,86 @@ def join(self, unistream_or_type: Union[UniConstraintStream[C_], Type[C_]], c_type = get_class(unistream_or_type) unistream_or_type = c_type - join_result = self.delegate.join(unistream_or_type, extract_joiners(joiners, self.a_type, self.b_type, c_type)) + join_result = self.delegate.join(unistream_or_type, extract_joiners(joiners, + self.a_type, self.b_type, c_type)) return TriConstraintStream(join_result, self.package, - self.a_type, self.b_type, c_type) + @overload def if_exists(self, item_type: Type[C_], *joiners: 'TriJoiner[A, B, C_]') -> 'BiConstraintStream[A,B]': + ... + + @overload + def if_exists(self, other_stream: 'UniConstraintStream[C_]', *joiners: 'TriJoiner[A, B, C_]') \ + -> 'BiConstraintStream[A,B]': + ... + + def if_exists(self, unistream_or_type: Union['UniConstraintStream[C_]', Type[C_]], + *joiners: 'TriJoiner[A, B, C_]') -> 'BiConstraintStream[A,B]': """ - Create a new `BiConstraintStream` for every A, B where C exists that satisfy all specified joiners. + Create a new `BiConstraintStream` for every A, B where C exists that satisfies all specified joiners. """ - item_type = get_class(item_type) - return BiConstraintStream(self.delegate.ifExists(item_type, - extract_joiners(joiners, self.a_type, - self.b_type, item_type)), - self.package, - self.a_type, self.b_type) + c_type = None + if isinstance(unistream_or_type, UniConstraintStream): + c_type = unistream_or_type.a_type + unistream_or_type = unistream_or_type.delegate + else: + c_type = get_class(unistream_or_type) + unistream_or_type = c_type + return BiConstraintStream(self.delegate.ifExists(unistream_or_type, + extract_joiners(joiners, + self.a_type, self.b_type, c_type)), + self.package, self.a_type, self.b_type) def if_exists_including_unassigned(self, item_type: Type[C_], *joiners: 'TriJoiner[A, B, C_]') -> \ 'BiConstraintStream[A,B]': """ - Create a new `BiConstraintStream` for every A, B where C exists that satisfy all specified joiners. + Create a new `BiConstraintStream` for every A, B where C exists that satisfies all specified joiners. """ item_type = get_class(item_type) return BiConstraintStream(self.delegate.ifExistsIncludingUnassigned(item_type, extract_joiners(joiners, self.a_type, self.b_type, item_type)), - self.package, + self.package, self.a_type, self.b_type) - self.a_type, self.b_type) + @overload + def if_not_exists(self, item_type: Type[C_], *joiners: 'TriJoiner[A, B, C_]') -> 'BiConstraintStream[A,B]': + ... - def if_not_exists(self, item_type: Type[C_], *joiners: 'TriJoiner[A, B, C_]') -> \ - 'BiConstraintStream[A,B]': + @overload + def if_not_exists(self, other_stream: 'UniConstraintStream[C_]', *joiners: 'TriJoiner[A, B, C_]')\ + -> 'BiConstraintStream[A,B]': + ... + + def if_not_exists(self, unistream_or_type: Union['UniConstraintStream[C_]', Type[C_]], + *joiners: 'TriJoiner[A, B, C_]') -> 'BiConstraintStream[A,B]': """ - Create a new `BiConstraintStream` for every A, B, where there does not exist a C where all specified joiners - are satisfied. - """ - item_type = get_class(item_type) - return BiConstraintStream(self.delegate.ifNotExists(item_type, extract_joiners(joiners, self.a_type, - self.b_type, item_type)), - self.package, - self.a_type, self.b_type) + Create a new `BiConstraintStream` for every A, B where C does not exist that satisfies all specified joiners. + """ + c_type = None + if isinstance(unistream_or_type, UniConstraintStream): + c_type = unistream_or_type.a_type + unistream_or_type = unistream_or_type.delegate + else: + c_type = get_class(unistream_or_type) + unistream_or_type = c_type + return BiConstraintStream(self.delegate.ifNotExists(unistream_or_type, + extract_joiners(joiners, + self.a_type, self.b_type, c_type)), + self.package, self.a_type, self.b_type) def if_not_exists_including_unassigned(self, item_type: Type[C_], *joiners: 'TriJoiner[A, B, C_]') -> \ 'BiConstraintStream[A,B]': """ - Create a new `BiConstraintStream` for every A, B, where there does not exist a C where all specified joiners - are satisfied. + Create a new `BiConstraintStream` for every A, B where C does not exist that satisfies all specified joiners. """ item_type = get_class(item_type) return BiConstraintStream(self.delegate.ifNotExistsIncludingUnassigned(item_type, extract_joiners(joiners, - self.a_type, - self.b_type, + self.a_type, self.b_type, item_type)), - self.package, - self.a_type, self.b_type) + self.package, self.a_type, self.b_type) @overload def group_by(self, key_mapping: Callable[[A, B], A_]) -> 'UniConstraintStream[A_]': @@ -1349,7 +1392,7 @@ def filter(self, predicate: Callable[[A, B, C], bool]) -> 'TriConstraintStream[A def join(self, unistream_or_type: Union[UniConstraintStream[D_], Type[D_]], *joiners: 'QuadJoiner[A, B, C, D_]') -> 'QuadConstraintStream[A,B,C,D_]': """ - Create a new `QuadConstraintStream` for every combination of A, B and C that satisfy all specified joiners. + Create a new `QuadConstraintStream` for every combination of A, B and C that satisfies all specified joiners. """ d_type = None if isinstance(unistream_or_type, UniConstraintStream): @@ -1359,54 +1402,85 @@ def join(self, unistream_or_type: Union[UniConstraintStream[D_], Type[D_]], d_type = get_class(unistream_or_type) unistream_or_type = d_type - join_result = self.delegate.join(unistream_or_type, extract_joiners(joiners, self.a_type, self.b_type, - self.c_type, d_type)) + join_result = self.delegate.join(unistream_or_type, extract_joiners(joiners, + self.a_type, self.b_type, self.c_type, + d_type)) return QuadConstraintStream(join_result, self.package, self.a_type, self.b_type, self.c_type, d_type) + @overload def if_exists(self, item_type: Type[D_], *joiners: 'QuadJoiner[A, B, C, D_]') -> \ 'TriConstraintStream[A,B,C]': + ... + + @overload + def if_exists(self, other_stream: 'UniConstraintStream[D_]', *joiners: 'QuadJoiner[A, B, C, D_]') -> \ + 'TriConstraintStream[A,B,C]': + ... + + def if_exists(self, unistream_or_type: Union['UniConstraintStream[D_]', Type[D_]], + *joiners: 'QuadJoiner[A, B, C, D_]') -> 'TriConstraintStream[A,B,C]': """ - Create a new `TriConstraintStream` for every A, B, C where D exists that satisfy all specified joiners. + Create a new `TriConstraintStream` for every A, B, C where D exists that satisfies all specified joiners. """ - item_type = get_class(item_type) - return TriConstraintStream(self.delegate.ifExists(item_type, extract_joiners(joiners, self.a_type, - self.b_type, self.c_type, - item_type)), self.package, - self.a_type, self.b_type, self.c_type) + d_type = None + if isinstance(unistream_or_type, UniConstraintStream): + d_type = unistream_or_type.a_type + unistream_or_type = unistream_or_type.delegate + else: + d_type = get_class(unistream_or_type) + unistream_or_type = d_type + return TriConstraintStream(self.delegate.ifExists(unistream_or_type, + extract_joiners(joiners, + self.a_type, self.b_type, self.c_type, + d_type)), + self.package, self.a_type, self.b_type, self.c_type) def if_exists_including_unassigned(self, item_type: Type[D_], *joiners: 'QuadJoiner[A, B, C, D_]') -> \ 'TriConstraintStream[A,B,C]': """ - Create a new `TriConstraintStream` for every A, B where D exists that satisfy all specified joiners. + Create a new `TriConstraintStream` for every A, B where D exists that satisfies all specified joiners. """ item_type = get_class(item_type) - return TriConstraintStream(self.delegate.ifExistsIncludingUnassigned(item_type, extract_joiners(joiners, - self.a_type, - self.b_type, - self.c_type, - item_type)), + return TriConstraintStream(self.delegate.ifExistsIncludingUnassigned(item_type, + extract_joiners(joiners, + self.a_type, self.b_type, + self.c_type, item_type)), self.package, self.a_type, self.b_type, self.c_type) + @overload def if_not_exists(self, item_type: Type[D_], *joiners: 'QuadJoiner[A, B, C, D_]') -> \ 'TriConstraintStream[A,B,C]': + ... + + @overload + def if_not_exists(self, other_stream: 'UniConstraintStream[D_]', *joiners: 'QuadJoiner[A, B, C, D_]') -> \ + 'TriConstraintStream[A,B,C]': + ... + + def if_not_exists(self, unistream_or_type: Union['UniConstraintStream[D_]', Type[D_]], + *joiners: 'QuadJoiner[A, B, C, D_]') -> 'TriConstraintStream[A,B,C]': """ - Create a new `TriConstraintStream` for every A, B, C where there does not exist a D where all specified joiners - are satisfied. + Create a new `TriConstraintStream` for every A, B, C where D does not exist + that satisfies all specified joiners. """ - item_type = get_class(item_type) - return TriConstraintStream(self.delegate.ifNotExists(item_type, extract_joiners(joiners, - self.a_type, - self.b_type, - self.c_type, - item_type)), + d_type = None + if isinstance(unistream_or_type, UniConstraintStream): + d_type = unistream_or_type.a_type + unistream_or_type = unistream_or_type.delegate + else: + d_type = get_class(unistream_or_type) + unistream_or_type = d_type + return TriConstraintStream(self.delegate.ifNotExists(unistream_or_type, + extract_joiners(joiners, + self.a_type, self.b_type, self.c_type, + d_type)), self.package, self.a_type, self.b_type, self.c_type) def if_not_exists_including_unassigned(self, item_type: Type[D_], *joiners: 'QuadJoiner[A, B, C, D_]') -> \ 'TriConstraintStream[A,B,C]': """ - Create a new `TriConstraintStream` for every A, B, C where there does not exist a D where all specified joiners - are satisfied. + Create a new `TriConstraintStream` for every A, B, C where D does not exist that satisfies all specified joiners. """ item_type = get_class(item_type) return TriConstraintStream(self.delegate.ifNotExistsIncludingUnassigned(item_type, @@ -1575,8 +1649,8 @@ def map(self, *mapping_functions): raise ValueError(f'At least one mapping function is required for map.') if len(mapping_functions) > 4: raise ValueError(f'At most four mapping functions can be passed to map (got {len(mapping_functions)}).') - translated_functions = tuple(map(lambda mapping_function: function_cast(mapping_function, self.a_type, - self.b_type, self.c_type), + translated_functions = tuple(map(lambda mapping_function: function_cast(mapping_function, + self.a_type, self.b_type, self.c_type), mapping_functions)) if len(mapping_functions) == 1: return UniConstraintStream(self.delegate.map(*translated_functions), self.package, @@ -1989,25 +2063,38 @@ def filter(self, predicate: Callable[[A, B, C, D], bool]) -> 'QuadConstraintStre self.a_type, self.b_type, self.c_type, self.d_type) + @overload def if_exists(self, item_type: Type[E_], *joiners: 'PentaJoiner[A, B, C, D, E_]') -> \ 'QuadConstraintStream[A,B,C,D]': + ... + + @overload + def if_exists(self, other_stream: 'UniConstraintCollector[E_]', *joiners: 'PentaJoiner[A, B, C, D, E_]') -> \ + 'QuadConstraintStream[A,B,C,D]': + ... + + def if_exists(self, unistream_or_type: Union['UniConstraintStream[E_]', Type[E_]], + *joiners: 'PentaJoiner[A, B, C, D, E_]') -> 'QuadConstraintStream[A,B,C,D]': """ - Create a new `QuadConstraintStream` for every A, B, C, D where E exists that satisfy all specified joiners. + Create a new `QuadConstraintStream` for every A, B, C, D where E exists that satisfies all specified joiners. """ - item_type = get_class(item_type) - return QuadConstraintStream(self.delegate.ifExists(item_type, extract_joiners(joiners, - self.a_type, - self.b_type, - self.c_type, - self.d_type, - item_type)), - self.package, - self.a_type, self.b_type, self.c_type, self.d_type) + e_type = None + if isinstance(unistream_or_type, UniConstraintStream): + e_type = unistream_or_type.a_type + unistream_or_type = unistream_or_type.delegate + else: + e_type = get_class(unistream_or_type) + unistream_or_type = e_type + return QuadConstraintStream(self.delegate.ifExists(unistream_or_type, + extract_joiners(joiners, + self.a_type, self.b_type, self.c_type, + self.d_type, e_type)), + self.package, self.a_type, self.b_type, self.c_type, self.d_type) def if_exists_including_unassigned(self, item_type: Type[E_], *joiners: 'PentaJoiner[A, B, C, D, E_]') -> \ 'QuadConstraintStream[A,B,C,D]': """ - Create a new `QuadConstraintStream` for every A, B, C, D where E exists that satisfy all specified joiners. + Create a new `QuadConstraintStream` for every A, B, C, D where E exists that satisfies all specified joiners. """ item_type = get_class(item_type) return QuadConstraintStream(self.delegate.ifExistsIncludingUnassigned(item_type, @@ -2020,27 +2107,40 @@ def if_exists_including_unassigned(self, item_type: Type[E_], *joiners: 'PentaJo self.package, self.a_type, self.b_type, self.c_type, self.d_type) + @overload def if_not_exists(self, item_type: Type[E_], *joiners: 'PentaJoiner[A, B, C, D, E_]') -> \ 'QuadConstraintStream[A,B,C,D]': + ... + + @overload + def if_not_exists(self, other_stream: 'UniConstraintCollector[E_]', *joiners: 'PentaJoiner[A, B, C, D, E_]') -> \ + 'QuadConstraintStream[A,B,C,D]': + ... + + def if_not_exists(self, unistream_or_type: Union['UniConstraintStream[E_]', Type[E_]], + *joiners: 'PentaJoiner[A, B, C, D, E_]') -> 'QuadConstraintStream[A,B,C,D]': """ - Create a new `QuadConstraintStream` for every A, B, C, D where there does not exist an E where all specified - joiners are satisfied. + Create a new `QuadConstraintStream` for every A, B, C, D where E does not exist + that satisfies all specified joiners. """ - item_type = get_class(item_type) - return QuadConstraintStream(self.delegate.ifNotExists(item_type, extract_joiners(joiners, - self.a_type, - self.b_type, - self.c_type, - self.d_type, - item_type)), - self.package, - self.a_type, self.b_type, self.c_type, self.d_type) + e_type = None + if isinstance(unistream_or_type, UniConstraintStream): + e_type = unistream_or_type.a_type + unistream_or_type = unistream_or_type.delegate + else: + e_type = get_class(unistream_or_type) + unistream_or_type = e_type + return QuadConstraintStream(self.delegate.ifNotExists(unistream_or_type, + extract_joiners(joiners, + self.a_type, self.b_type, self.c_type, + self.d_type, e_type)), + self.package, self.a_type, self.b_type, self.c_type, self.d_type) def if_not_exists_including_unassigned(self, item_type: Type[E_], *joiners: 'PentaJoiner[A, B, C, D, E_]') -> \ 'QuadConstraintStream[A,B,C,D]': """ - Create a new `QuadConstraintStream` for every A, B, C, D where there does not exist an E where all specified - joiners are satisfied. + Create a new `QuadConstraintStream` for every A, B, C, + D where E does not exist that satisfies all specified joiners. """ item_type = get_class(item_type) return QuadConstraintStream(self.delegate.ifNotExistsIncludingUnassigned(item_type, From 552b30738b91b156677190c1e4198fb4e3bd903b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luk=C3=A1=C5=A1=20Petrovick=C3=BD?= Date: Fri, 28 Jun 2024 12:20:09 +0200 Subject: [PATCH 5/9] chore: add SingleConstraintAssertion overloads --- tests/test_constraint_streams.py | 12 +- tests/test_constraint_verifier.py | 12 +- .../src/main/python/test/__init__.py | 350 +++++++++++++++++- 3 files changed, 363 insertions(+), 11 deletions(-) diff --git a/tests/test_constraint_streams.py b/tests/test_constraint_streams.py index 555906e8..86088008 100644 --- a/tests/test_constraint_streams.py +++ b/tests/test_constraint_streams.py @@ -721,6 +721,7 @@ def define_constraints(constraint_factory: ConstraintFactory): def test_has_all_methods(): + missing = [] for python_type, java_type in ((UniConstraintStream, JavaUniConstraintStream), (BiConstraintStream, JavaBiConstraintStream), (TriConstraintStream, JavaTriConstraintStream), @@ -732,7 +733,6 @@ def test_has_all_methods(): (Joiners, JavaJoiners), (ConstraintCollectors, JavaConstraintCollectors), (ConstraintFactory, JavaConstraintFactory)): - missing = [] for function_name, function_impl in inspect.getmembers(java_type, inspect.isfunction): if function_name in ignored_java_functions: continue @@ -745,8 +745,10 @@ def test_has_all_methods(): # change h_t_t_p -> http snake_case_name = re.sub('([a-z0-9])([A-Z])', r'\1_\2', snake_case_name).lower() if not hasattr(python_type, snake_case_name): - missing.append(snake_case_name) + missing.append((java_type, python_type, snake_case_name)) - if missing: - raise AssertionError(f'{python_type} is missing methods ({missing}) ' - f'from java_type ({java_type}).)') + if missing: + assertion_msg = '' + for java_type, python_type, snake_case_name in missing: + assertion_msg += f'{python_type} is missing a method ({snake_case_name}) from java_type ({java_type}).)\n' + raise AssertionError(assertion_msg) diff --git a/tests/test_constraint_verifier.py b/tests/test_constraint_verifier.py index b9bc1a4c..bddaff5f 100644 --- a/tests/test_constraint_verifier.py +++ b/tests/test_constraint_verifier.py @@ -285,17 +285,17 @@ class Solution: 'notifyAll', 'toString', 'wait', - 'with_constraint_stream_impl_type' + 'withConstraintStreamImplType' } def test_has_all_methods(): + missing = [] for python_type, java_type in ((ConstraintVerifier, JavaConstraintVerifier), (SingleConstraintAssertion, JavaSingleConstraintAssertion), (SingleConstraintVerification, JavaSingleConstraintVerification), (MultiConstraintAssertion, JavaMultiConstraintAssertion), (MultiConstraintVerification, JavaMultiConstraintVerification)): - missing = [] for function_name, function_impl in inspect.getmembers(java_type, inspect.isfunction): if function_name in ignored_java_functions: continue @@ -303,9 +303,11 @@ def test_has_all_methods(): # change h_t_t_p -> http snake_case_name = re.sub('([a-z0-9])([A-Z])', r'\1_\2', snake_case_name).lower() if not hasattr(python_type, snake_case_name): - missing.append(snake_case_name) + missing.append((java_type, python_type, snake_case_name)) if missing: - raise AssertionError(f'{python_type} is missing methods ({missing}) ' - f'from java_type ({java_type}).)') + assertion_msg = '' + for java_type, python_type, snake_case_name in missing: + assertion_msg += f'{python_type} is missing a method ({snake_case_name}) from java_type ({java_type}).)\n' + raise AssertionError(assertion_msg) diff --git a/timefold-solver-python-core/src/main/python/test/__init__.py b/timefold-solver-python-core/src/main/python/test/__init__.py index cc036b77..e5056bb8 100644 --- a/timefold-solver-python-core/src/main/python/test/__init__.py +++ b/timefold-solver-python-core/src/main/python/test/__init__.py @@ -25,7 +25,7 @@ if TYPE_CHECKING: # These imports require a JVM to be running, so only import if type checking - from ai.timefold.solver.core.api.score.stream import Constraint, ConstraintFactory + from ai.timefold.solver.core.api.score.stream import Constraint, ConstraintFactory, ConstraintJustification from ai.timefold.solver.core.config.solver import SolverConfig from ai.timefold.solver.core.api.score import Score @@ -173,6 +173,112 @@ class SingleConstraintAssertion: def __init__(self, delegate): self.delegate = delegate + def justifies_with(self, message: str = None, *justifications: 'ConstraintJustification') \ + -> 'SingleConstraintAssertion': + """ + Asserts that the constraint being tested, given a set of facts, results in given justifications. + + Parameters + ---------- + justifications : ConstraintVerifier + zero or more justification to check for + + message : str, optional + description of the scenario being asserted + + Raises + ------ + AssertionError + when the expected justifications are not observed + """ + from java.lang import AssertionError as JavaAssertionError # noqa + try: + if message is None: + return self.delegate.justifiesWith(justifications) + else: + return self.delegate.justifiesWith(message, justifications) + except JavaAssertionError as e: + raise AssertionError(e.getMessage()) + + def justifies_with_exactly(self, message: str = None, *justifications: 'ConstraintJustification') \ + -> 'SingleConstraintAssertion': + """ + Asserts that the constraint being tested, given a set of facts, results in given justifications an no others. + + Parameters + ---------- + justifications : ConstraintVerifier + zero or more justification to check for + + message : str, optional + description of the scenario being asserted + + Raises + ------ + AssertionError + when the expected justifications are not observed + """ + from java.lang import AssertionError as JavaAssertionError # noqa + try: + if message is None: + return self.delegate.justifiesWithExactly(justifications) + else: + return self.delegate.justifiesWithExactly(message, justifications) + except JavaAssertionError as e: + raise AssertionError(e.getMessage()) + + def indicts_with(self, message: str = None, *indictments) -> 'SingleConstraintAssertion': + """ + Asserts that the constraint being tested, given a set of facts, results in given indictments. + + Parameters + ---------- + indictments : ConstraintVerifier + zero or more indictments to check for + + message : str, optional + description of the scenario being asserted + + Raises + ------ + AssertionError + when the expected indictments are not observed + """ + from java.lang import AssertionError as JavaAssertionError # noqa + try: + if message is None: + return self.delegate.indictsWith(indictments) + else: + return self.delegate.indictsWith(message, indictments) + except JavaAssertionError as e: + raise AssertionError(e.getMessage()) + + def indicts_with_exactly(self, message: str = None, *indictments) -> 'SingleConstraintAssertion': + """ + Asserts that the constraint being tested, given a set of facts, results in given indictments an no others. + + Parameters + ---------- + indictments : ConstraintVerifier + zero or more justification to check for + + message : str, optional + description of the scenario being asserted + + Raises + ------ + AssertionError + when the expected indictments are not observed + """ + from java.lang import AssertionError as JavaAssertionError # noqa + try: + if message is None: + return self.delegate.indictsWithExactly(indictments) + else: + return self.delegate.indictsWithExactly(message, indictments) + except JavaAssertionError as e: + raise AssertionError(e.getMessage()) + def penalizes(self, times: int = None, message: str = None) -> None: """ Asserts that the Constraint being tested, given a set of facts, results in a given number of penalties. @@ -208,6 +314,66 @@ def penalizes(self, times: int = None, message: str = None) -> None: except JavaAssertionError as e: raise AssertionError(e.getMessage()) + def penalizes_less_than(self, times: int, message: str = None) -> None: + """ + Asserts that the Constraint being tested, given a set of facts, + results in less than a given number of penalties. + + Ignores the constraint and match weights: it only asserts the number of matches + For example: if there are two matches with weight of 10 each, this assertion will check for 2 matches. + + Parameters + ---------- + times : int + the expected number of penalties. + + message : str, optional + description of the scenario being asserted + + Raises + ------ + AssertionError + when the expected penalty is not observed if `times` is provided + """ + from java.lang import AssertionError as JavaAssertionError # noqa + try: + if times is not None and message is None: + self.delegate.penalizesLessThan(times) + else: + self.delegate.penalizesLessThan(times, message) + except JavaAssertionError as e: + raise AssertionError(e.getMessage()) + + def penalizes_more_than(self, times: int, message: str = None) -> None: + """ + Asserts that the Constraint being tested, given a set of facts, + results in more than a given number of penalties. + + Ignores the constraint and match weights: it only asserts the number of matches + For example: if there are two matches with weight of 10 each, this assertion will check for 2 matches. + + Parameters + ---------- + times : int + the expected number of penalties. + + message : str, optional + description of the scenario being asserted + + Raises + ------ + AssertionError + when the expected penalty is not observed if `times` is provided + """ + from java.lang import AssertionError as JavaAssertionError # noqa + try: + if times is not None and message is None: + self.delegate.penalizesMoreThan(times) + else: + self.delegate.penalizesMoreThan(times, message) + except JavaAssertionError as e: + raise AssertionError(e.getMessage()) + def penalizes_by(self, match_weight_total: int, message: str = None): """ Asserts that the `Constraint` being tested, given a set of facts, results in a specific penalty. @@ -238,6 +404,68 @@ def penalizes_by(self, match_weight_total: int, message: str = None): except JavaAssertionError as e: raise AssertionError(e.getMessage()) + def penalizes_by_less_than(self, match_weight_total: int, message: str = None): + """ + Asserts that the `Constraint` being tested, given a set of facts, results in less than a specific penalty. + + Ignores the constraint weight: it only asserts the match weights. + For example: a match with a match weight of 10 on a constraint with a constraint weight of -2hard reduces the + score by -20hard. + In that case, this assertion checks for 10. + + Parameters + ---------- + match_weight_total : int + the expected penalty + + message : str, optional + description of the scenario being asserted + + Raises + ------ + AssertionError + when the expected penalty is not observed + """ + from java.lang import AssertionError as JavaAssertionError # noqa + try: + if message is None: + self.delegate.penalizesByLessThan(match_weight_total) + else: + self.delegate.penalizesByLessThan(match_weight_total, message) + except JavaAssertionError as e: + raise AssertionError(e.getMessage()) + + def penalizes_by_more_than(self, match_weight_total: int, message: str = None): + """ + Asserts that the `Constraint` being tested, given a set of facts, results in more than a specific penalty. + + Ignores the constraint weight: it only asserts the match weights. + For example: a match with a match weight of 10 on a constraint with a constraint weight of -2hard reduces the + score by -20hard. + In that case, this assertion checks for 10. + + Parameters + ---------- + match_weight_total : int + the expected penalty + + message : str, optional + description of the scenario being asserted + + Raises + ------ + AssertionError + when the expected penalty is not observed + """ + from java.lang import AssertionError as JavaAssertionError # noqa + try: + if message is None: + self.delegate.penalizesByMoreThan(match_weight_total) + else: + self.delegate.penalizesByMoreThan(match_weight_total, message) + except JavaAssertionError as e: + raise AssertionError(e.getMessage()) + def rewards(self, times: int = None, message: str = None): """ Asserts that the Constraint being tested, given a set of facts, results in a given number of rewards. @@ -273,6 +501,66 @@ def rewards(self, times: int = None, message: str = None): except JavaAssertionError as e: raise AssertionError(e.getMessage()) + def rewards_less_than(self, times: int, message: str = None): + """ + Asserts that the Constraint being tested, given a set of facts, + results in a less than a given number of rewards. + + Ignores the constraint and match weights: it only asserts the number of matches + For example: if there are two matches with weight of 10 each, this assertion will check for 2 matches. + + Parameters + ---------- + times : int + the expected number of rewards. + + message : str, optional + description of the scenario being asserted + + Raises + ------ + AssertionError + when the expected reward is not observed if times is provided + """ + from java.lang import AssertionError as JavaAssertionError # noqa + try: + if times is not None and message is None: + self.delegate.rewardsLessThan(times) + else: + self.delegate.rewardsLessThan(times, message) + except JavaAssertionError as e: + raise AssertionError(e.getMessage()) + + def rewards_more_than(self, times: int, message: str = None): + """ + Asserts that the Constraint being tested, given a set of facts, + results in more than a given number of rewards. + + Ignores the constraint and match weights: it only asserts the number of matches + For example: if there are two matches with weight of 10 each, this assertion will check for 2 matches. + + Parameters + ---------- + times : int + the expected number of rewards. + + message : str, optional + description of the scenario being asserted + + Raises + ------ + AssertionError + when the expected reward is not observed if times is provided + """ + from java.lang import AssertionError as JavaAssertionError # noqa + try: + if times is not None and message is None: + self.delegate.rewardsMoreThan(times) + else: + self.delegate.rewardsMoreThan(times, message) + except JavaAssertionError as e: + raise AssertionError(e.getMessage()) + def rewards_with(self, match_weight_total: int, message: str = None): """ Asserts that the Constraint being tested, given a set of facts, results in a specific reward. @@ -303,6 +591,66 @@ def rewards_with(self, match_weight_total: int, message: str = None): except JavaAssertionError as e: raise AssertionError(e.getMessage()) + def rewards_with_less_than(self, match_weight_total: int, message: str = None): + """ + Asserts that the Constraint being tested, given a set of facts, results in less than a specific reward. + Ignores the constraint weight: it only asserts the match weights. + For example: a match with a match weight of 10 on a constraint with a constraint weight of + -2hard reduces the score by -20hard. + In that case, this assertion checks for 10. + + Parameters + ---------- + match_weight_total : int + at least 0, expected sum of match weights of matches of the constraint. + + message : str, optional + description of the scenario being asserted + + Raises + ------ + AssertionError + when the expected reward is not observed + """ + from java.lang import AssertionError as JavaAssertionError # noqa + try: + if message is None: + self.delegate.rewardsWithLessThan(match_weight_total) + else: + self.delegate.rewardsWithLessThan(match_weight_total, message) + except JavaAssertionError as e: + raise AssertionError(e.getMessage()) + + def rewards_with_more_than(self, match_weight_total: int, message: str = None): + """ + Asserts that the Constraint being tested, given a set of facts, results in more than a specific reward. + Ignores the constraint weight: it only asserts the match weights. + For example: a match with a match weight of 10 on a constraint with a constraint weight of + -2hard reduces the score by -20hard. + In that case, this assertion checks for 10. + + Parameters + ---------- + match_weight_total : int + at least 0, expected sum of match weights of matches of the constraint. + + message : str, optional + description of the scenario being asserted + + Raises + ------ + AssertionError + when the expected reward is not observed + """ + from java.lang import AssertionError as JavaAssertionError # noqa + try: + if message is None: + self.delegate.rewardsWithMoreThan(match_weight_total) + else: + self.delegate.rewardsWithMoreThan(match_weight_total, message) + except JavaAssertionError as e: + raise AssertionError(e.getMessage()) + class MultiConstraintAssertion: def __init__(self, delegate): From 97d496bb5b2f56f9d3cbd48d4e201f78969eb569 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luk=C3=A1=C5=A1=20Petrovick=C3=BD?= Date: Wed, 3 Jul 2024 07:56:24 +0200 Subject: [PATCH 6/9] wip: add load_balance test --- tests/test_collectors.py | 29 ++++++++++++++----- tests/test_constraint_streams.py | 23 +++++++++++++++ .../src/main/python/score/_group_by.py | 21 +++++++------- 3 files changed, 56 insertions(+), 17 deletions(-) diff --git a/tests/test_collectors.py b/tests/test_collectors.py index f581e7b5..d0b26b96 100644 --- a/tests/test_collectors.py +++ b/tests/test_collectors.py @@ -563,24 +563,39 @@ def define_constraints(constraint_factory: ConstraintFactory): assert score_manager.explain(problem).score == SimpleScore.of(4) -def test_flatten_last(): +def test_load_balance(): @constraint_provider def define_constraints(constraint_factory: ConstraintFactory): return [ constraint_factory.for_each(Entity) - .map(lambda entity: (1, 2, 3)) - .flatten_last(lambda the_tuple: the_tuple) - .reward(SimpleScore.ONE) - .as_constraint('Count') + .group_by(ConstraintCollectors.load_balance( + lambda entity: entity.value + )) + .reward(SimpleScore.ONE, + lambda balance: balance.unfairness().multiply(BigDecimal.valueOf(1000)).intValue()) + .as_constraint('Balanced value') ] score_manager = create_score_manager(define_constraints) entity_a: Entity = Entity('A') + entity_b: Entity = Entity('B') + entity_c: Entity = Entity('C') value_1 = Value(1) + value_2 = Value(2) - problem = Solution([entity_a], [value_1]) + problem = Solution([entity_a, entity_b, entity_c], [value_1, value_2]) entity_a.value = value_1 + entity_b.value = value_1 + entity_c.value = value_1 - assert score_manager.explain(problem).score == SimpleScore.of(3) + assert score_manager.explain(problem).score == SimpleScore.of(0) + + entity_c.value = value_2 + + assert score_manager.explain(problem).score == SimpleScore.of(2) // FIXME + + entity_b.value = value_2 + + assert score_manager.explain(problem).score == SimpleScore.of(4) // FIXME diff --git a/tests/test_constraint_streams.py b/tests/test_constraint_streams.py index 86088008..f22c284d 100644 --- a/tests/test_constraint_streams.py +++ b/tests/test_constraint_streams.py @@ -223,6 +223,29 @@ def define_constraints(constraint_factory: ConstraintFactory): assert score_manager.explain(problem).score.score == 1 +def test_flatten_last(): + @constraint_provider + def define_constraints(constraint_factory: ConstraintFactory): + return [ + constraint_factory.for_each(Entity) + .map(lambda entity: (1, 2, 3)) + .flatten_last(lambda the_tuple: the_tuple) + .reward(SimpleScore.ONE) + .as_constraint('Count') + ] + + score_manager = create_score_manager(define_constraints) + + entity_a: Entity = Entity('A') + + value_1 = Value(1) + + problem = Solution([entity_a], [value_1]) + entity_a.value = value_1 + + assert score_manager.explain(problem).score == SimpleScore.of(3) + + def test_join_uni(): @constraint_provider def define_constraints(constraint_factory: ConstraintFactory): diff --git a/timefold-solver-python-core/src/main/python/score/_group_by.py b/timefold-solver-python-core/src/main/python/score/_group_by.py index 578705d9..f2626c90 100644 --- a/timefold-solver-python-core/src/main/python/score/_group_by.py +++ b/timefold-solver-python-core/src/main/python/score/_group_by.py @@ -122,18 +122,18 @@ def perform_group_by(constraint_stream, package, group_by_args, *type_arguments) created_collector = extract_collector(collector_info, *type_arguments) actual_group_by_args.append(created_collector) - if len(group_by_args) == 1: + if len(group_by_args) is 1: return UniConstraintStream(constraint_stream.groupBy(*actual_group_by_args), package, JClass('java.lang.Object')) - elif len(group_by_args) == 2: + elif len(group_by_args) is 2: return BiConstraintStream(constraint_stream.groupBy(*actual_group_by_args), package, JClass('java.lang.Object'), JClass('java.lang.Object')) - elif len(group_by_args) == 3: + elif len(group_by_args) is 3: return TriConstraintStream(constraint_stream.groupBy(*actual_group_by_args), package, JClass('java.lang.Object'), JClass('java.lang.Object'), JClass('java.lang.Object')) - elif len(group_by_args) == 4: + elif len(group_by_args) is 4: return QuadConstraintStream(constraint_stream.groupBy(*actual_group_by_args), package, JClass('java.lang.Object'), JClass('java.lang.Object'), JClass('java.lang.Object'), @@ -1107,27 +1107,28 @@ def load_balance(balanced_item_function, load_function=None, initial_load_functi Parameters ---------- - balanced_item_function: + balanced_item_function: Callable[[ParameterTypes, ...], Balanced_] The function that returns the item which should be load-balanced. - load_function: + load_function: Callable[[ParameterTypes, ...], int] How much the item should count for in the formula. - initial_load_function: + initial_load_function: Callable[[ParameterTypes, ...], int] The initial value of the metric, allowing to provide initial state without requiring the entire previous planning windows in the working memory. If this function is provided, load_function must be provided as well. """ - if None == load_function == initial_load_function: + if load_function is None and initial_load_function is None: return LoadBalanceCollector(ConstraintCollectors._delegate().loadBalance, balanced_item_function, None, None) - elif None == initial_load_function: + elif initial_load_function is None: return LoadBalanceCollector(ConstraintCollectors._delegate().loadBalance, balanced_item_function, load_function, None) - elif None == load_function: + elif load_function is None: raise ValueError("load_function cannot be None if initial_load_function is not None") else: return LoadBalanceCollector(ConstraintCollectors._delegate().loadBalance, balanced_item_function, load_function, initial_load_function) + # Must be at the bottom, constraint_stream depends on this module from ._constraint_stream import * from ._function_translator import * From 18ee7cda1ee46280d644f9e22a1417fa382c4855 Mon Sep 17 00:00:00 2001 From: Christopher Chianelli Date: Mon, 8 Jul 2024 00:11:16 -0400 Subject: [PATCH 7/9] test: unfairness test --- tests/test_collectors.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_collectors.py b/tests/test_collectors.py index d0b26b96..fbb23d9a 100644 --- a/tests/test_collectors.py +++ b/tests/test_collectors.py @@ -572,7 +572,7 @@ def define_constraints(constraint_factory: ConstraintFactory): lambda entity: entity.value )) .reward(SimpleScore.ONE, - lambda balance: balance.unfairness().multiply(BigDecimal.valueOf(1000)).intValue()) + lambda balance: balance.unfairness().movePointRight(3).intValue()) .as_constraint('Balanced value') ] @@ -594,8 +594,8 @@ def define_constraints(constraint_factory: ConstraintFactory): entity_c.value = value_2 - assert score_manager.explain(problem).score == SimpleScore.of(2) // FIXME + assert score_manager.explain(problem).score == SimpleScore.of(707) entity_b.value = value_2 - assert score_manager.explain(problem).score == SimpleScore.of(4) // FIXME + assert score_manager.explain(problem).score == SimpleScore.of(707) From d454af1c6f9bf02808cb79bd357052feb5110dca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luk=C3=A1=C5=A1=20Petrovick=C3=BD?= Date: Mon, 8 Jul 2024 07:33:17 +0200 Subject: [PATCH 8/9] test: better test data --- tests/test_collectors.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_collectors.py b/tests/test_collectors.py index fbb23d9a..2c589d29 100644 --- a/tests/test_collectors.py +++ b/tests/test_collectors.py @@ -585,17 +585,17 @@ def define_constraints(constraint_factory: ConstraintFactory): value_1 = Value(1) value_2 = Value(2) - problem = Solution([entity_a, entity_b, entity_c], [value_1, value_2]) + problem = Solution([entity_a, entity_b], [value_1]) entity_a.value = value_1 entity_b.value = value_1 entity_c.value = value_1 assert score_manager.explain(problem).score == SimpleScore.of(0) - entity_c.value = value_2 + problem = Solution([entity_a, entity_b, entity_c], [value_1, value_2]) - assert score_manager.explain(problem).score == SimpleScore.of(707) + assert score_manager.explain(problem).score == SimpleScore.of(0) - entity_b.value = value_2 + entity_c.value = value_2 assert score_manager.explain(problem).score == SimpleScore.of(707) From 1643b2316cf8677a231db67d3e572fb075dae4b2 Mon Sep 17 00:00:00 2001 From: Christopher Chianelli Date: Mon, 8 Jul 2024 09:35:06 -0400 Subject: [PATCH 9/9] fix: Add tests for new ConstraintVerifier methods --- tests/test_constraint_verifier.py | 110 +++++++++++++++++- .../src/main/python/test/__init__.py | 52 +++++++-- 2 files changed, 146 insertions(+), 16 deletions(-) diff --git a/tests/test_constraint_verifier.py b/tests/test_constraint_verifier.py index bddaff5f..e0c300e9 100644 --- a/tests/test_constraint_verifier.py +++ b/tests/test_constraint_verifier.py @@ -17,6 +17,7 @@ MultiConstraintVerification as JavaMultiConstraintVerification) def verifier_suite(verifier: ConstraintVerifier, same_value, is_value_one, + EntityValueIndictment, EntityValueJustification, EntityValuePairJustification, solution, e1, e2, e3, v1, v2, v3): verifier.verify_that(same_value) \ .given(e1, e2) \ @@ -37,6 +38,11 @@ def verifier_suite(verifier: ConstraintVerifier, same_value, is_value_one, .given(e1, e2) \ .penalizes(1) + with pytest.raises(AssertionError): + verifier.verify_that(same_value) \ + .given(e1, e2) \ + .indicts_with_exactly(EntityValueIndictment(e1, v1)) + e1.value = v1 e2.value = v1 e3.value = v1 @@ -47,7 +53,47 @@ def verifier_suite(verifier: ConstraintVerifier, same_value, is_value_one, verifier.verify_that(same_value) \ .given(e1, e2) \ - .penalizes() + .penalizes(1) + + verifier.verify_that(same_value) \ + .given(e1, e2) \ + .indicts_with(EntityValueIndictment(e1, e1.value), EntityValueIndictment(e2, v1)) \ + .penalizes_by(1) + + verifier.verify_that(same_value) \ + .given(e1, e2) \ + .indicts_with_exactly(EntityValueIndictment(e1, e1.value), EntityValueIndictment(e2, v1)) \ + .penalizes_by(1) + + verifier.verify_that(same_value) \ + .given(e1, e2) \ + .indicts_with(EntityValueIndictment(e1, v1)) + + verifier.verify_that(same_value) \ + .given(e1, e2) \ + .justifies_with(EntityValuePairJustification((e1, e2), v1, SimpleScore(-1))) \ + .penalizes_by(1) + + verifier.verify_that(same_value) \ + .given(e1, e2) \ + .justifies_with_exactly(EntityValuePairJustification((e1, e2), v1, SimpleScore(-1))) \ + .penalizes_by(1) + + with pytest.raises(AssertionError): + verifier.verify_that(same_value) \ + .given(e1, e2) \ + .indicts_with_exactly(EntityValueIndictment(e1, v1)) + + + with pytest.raises(AssertionError): + verifier.verify_that(same_value) \ + .given(e1, e2) \ + .justifies_with(EntityValuePairJustification((e1, e2), v1, SimpleScore(1))) + + with pytest.raises(AssertionError): + verifier.verify_that(same_value) \ + .given(e1, e2, e3) \ + .justifies_with_exactly(EntityValuePairJustification((e1, e2), v1, SimpleScore(1))) with pytest.raises(AssertionError): verifier.verify_that(same_value) \ @@ -68,10 +114,28 @@ def verifier_suite(verifier: ConstraintVerifier, same_value, is_value_one, .given(e1, e2, e3) \ .penalizes(3) + verifier.verify_that(same_value) \ + .given(e1, e2, e3) \ + .penalizes_more_than(2) + + verifier.verify_that(same_value) \ + .given(e1, e2, e3) \ + .penalizes_less_than(4) + verifier.verify_that(same_value) \ .given(e1, e2, e3) \ .penalizes() + with pytest.raises(AssertionError): + verifier.verify_that(same_value) \ + .given(e1, e2, e3) \ + .penalizes_more_than(3) + + with pytest.raises(AssertionError): + verifier.verify_that(same_value) \ + .given(e1, e2, e3) \ + .penalizes_less_than(3) + with pytest.raises(AssertionError): verifier.verify_that(same_value) \ .given(e1, e2, e3) \ @@ -199,21 +263,55 @@ def verifier_suite(verifier: ConstraintVerifier, same_value, is_value_one, def test_constraint_verifier_create(): - @dataclass + @dataclass(unsafe_hash=True) class Value: code: str + def __str__(self): + return f'Value({self.code})' + @planning_entity - @dataclass + @dataclass(unsafe_hash=True) class Entity: code: str - value: Annotated[Value, PlanningVariable] = field(default=None) + value: Annotated[Value | None, PlanningVariable] = field(default=None) + + def __str__(self): + return f'Entity({self.code}, {self.value})' + + @dataclass(unsafe_hash=True) + class EntityValueIndictment: + entity: Entity + value: Value + + def __str__(self): + return f'EntityValueIndictment({self.entity}, {self.value})' + + @dataclass(unsafe_hash=True) + class EntityValueJustification(ConstraintJustification): + entity: Entity + value: Value + score: SimpleScore + + def __str__(self): + return f'EntityValueJustification({self.entity}, {self.value}, {self.score})' + + @dataclass(unsafe_hash=True) + class EntityValuePairJustification(ConstraintJustification): + entities: tuple[Entity] + value: Value + score: SimpleScore + + def __str__(self): + return f'EntityValuePairJustification({self.entities}, {self.value}, {self.score})' def same_value(constraint_factory: ConstraintFactory): return (constraint_factory.for_each(Entity) .join(Entity, Joiners.less_than(lambda e: e.code), Joiners.equal(lambda e: e.value)) .penalize(SimpleScore.ONE) + .indict_with(lambda e1, e2: [EntityValueIndictment(e1, e1.value), EntityValueIndictment(e2, e2.value)]) + .justify_with(lambda e1, e2, score: EntityValuePairJustification((e1, e2), e1.value, score)) .as_constraint('Same Value') ) @@ -221,6 +319,8 @@ def is_value_one(constraint_factory: ConstraintFactory): return (constraint_factory.for_each(Entity) .filter(lambda e: e.value.code == 'v1') .reward(SimpleScore.ONE) + .indict_with(lambda e: [EntityValueIndictment(e, e.value)]) + .justify_with(lambda e, score: EntityValueJustification(e, e.value, score)) .as_constraint('Value 1') ) @@ -259,6 +359,7 @@ class Solution: solution = Solution([e1, e2, e3], [v1, v2, v3]) verifier_suite(verifier, same_value, is_value_one, + EntityValueIndictment, EntityValueJustification, EntityValuePairJustification, solution, e1, e2, e3, v1, v2, v3) verifier = ConstraintVerifier.build(my_constraints, Solution, Entity) @@ -274,6 +375,7 @@ class Solution: solution = Solution([e1, e2, e3], [v1, v2, v3]) verifier_suite(verifier, same_value, is_value_one, + EntityValueIndictment, EntityValueJustification, EntityValuePairJustification, solution, e1, e2, e3, v1, v2, v3) diff --git a/timefold-solver-python-core/src/main/python/test/__init__.py b/timefold-solver-python-core/src/main/python/test/__init__.py index e5056bb8..e3fd4333 100644 --- a/timefold-solver-python-core/src/main/python/test/__init__.py +++ b/timefold-solver-python-core/src/main/python/test/__init__.py @@ -173,7 +173,7 @@ class SingleConstraintAssertion: def __init__(self, delegate): self.delegate = delegate - def justifies_with(self, message: str = None, *justifications: 'ConstraintJustification') \ + def justifies_with(self, *justifications: 'ConstraintJustification', message: str = None) \ -> 'SingleConstraintAssertion': """ Asserts that the constraint being tested, given a set of facts, results in given justifications. @@ -192,15 +192,22 @@ def justifies_with(self, message: str = None, *justifications: 'ConstraintJustif when the expected justifications are not observed """ from java.lang import AssertionError as JavaAssertionError # noqa + from _jpyinterpreter import convert_to_java_python_like_object + from java.util import HashMap + reference_map = HashMap() + wrapped_justifications = [] + for justification in justifications: + wrapped_justification = convert_to_java_python_like_object(justification, reference_map) + wrapped_justifications.append(wrapped_justification) try: if message is None: - return self.delegate.justifiesWith(justifications) + return SingleConstraintAssertion(self.delegate.justifiesWith(*wrapped_justifications)) else: - return self.delegate.justifiesWith(message, justifications) + return SingleConstraintAssertion(self.delegate.justifiesWith(message, *wrapped_justifications)) except JavaAssertionError as e: raise AssertionError(e.getMessage()) - def justifies_with_exactly(self, message: str = None, *justifications: 'ConstraintJustification') \ + def justifies_with_exactly(self, *justifications: 'ConstraintJustification', message: str = None) \ -> 'SingleConstraintAssertion': """ Asserts that the constraint being tested, given a set of facts, results in given justifications an no others. @@ -219,15 +226,22 @@ def justifies_with_exactly(self, message: str = None, *justifications: 'Constrai when the expected justifications are not observed """ from java.lang import AssertionError as JavaAssertionError # noqa + from _jpyinterpreter import convert_to_java_python_like_object + from java.util import HashMap + reference_map = HashMap() + wrapped_justifications = [] + for justification in justifications: + wrapped_justification = convert_to_java_python_like_object(justification, reference_map) + wrapped_justifications.append(wrapped_justification) try: if message is None: - return self.delegate.justifiesWithExactly(justifications) + return SingleConstraintAssertion(self.delegate.justifiesWithExactly(*wrapped_justifications)) else: - return self.delegate.justifiesWithExactly(message, justifications) + return SingleConstraintAssertion(self.delegate.justifiesWithExactly(message, *wrapped_justifications)) except JavaAssertionError as e: raise AssertionError(e.getMessage()) - def indicts_with(self, message: str = None, *indictments) -> 'SingleConstraintAssertion': + def indicts_with(self, *indictments, message: str = None) -> 'SingleConstraintAssertion': """ Asserts that the constraint being tested, given a set of facts, results in given indictments. @@ -245,15 +259,22 @@ def indicts_with(self, message: str = None, *indictments) -> 'SingleConstraintAs when the expected indictments are not observed """ from java.lang import AssertionError as JavaAssertionError # noqa + from _jpyinterpreter import convert_to_java_python_like_object + from java.util import HashMap + reference_map = HashMap() + wrapped_indictments = [] + for indictment in indictments: + wrapped_indictment = convert_to_java_python_like_object(indictment, reference_map) + wrapped_indictments.append(wrapped_indictment) try: if message is None: - return self.delegate.indictsWith(indictments) + return SingleConstraintAssertion(self.delegate.indictsWith(*wrapped_indictments)) else: - return self.delegate.indictsWith(message, indictments) + return SingleConstraintAssertion(self.delegate.indictsWith(message, *wrapped_indictments)) except JavaAssertionError as e: raise AssertionError(e.getMessage()) - def indicts_with_exactly(self, message: str = None, *indictments) -> 'SingleConstraintAssertion': + def indicts_with_exactly(self, *indictments, message: str = None) -> 'SingleConstraintAssertion': """ Asserts that the constraint being tested, given a set of facts, results in given indictments an no others. @@ -271,11 +292,18 @@ def indicts_with_exactly(self, message: str = None, *indictments) -> 'SingleCons when the expected indictments are not observed """ from java.lang import AssertionError as JavaAssertionError # noqa + from _jpyinterpreter import convert_to_java_python_like_object + from java.util import HashMap + reference_map = HashMap() + wrapped_indictments = [] + for indictment in indictments: + wrapped_indictment = convert_to_java_python_like_object(indictment, reference_map) + wrapped_indictments.append(wrapped_indictment) try: if message is None: - return self.delegate.indictsWithExactly(indictments) + return SingleConstraintAssertion(self.delegate.indictsWithExactly(*wrapped_indictments)) else: - return self.delegate.indictsWithExactly(message, indictments) + return SingleConstraintAssertion(self.delegate.indictsWithExactly(message, *wrapped_indictments)) except JavaAssertionError as e: raise AssertionError(e.getMessage())