Skip to content
This repository has been archived by the owner on Jul 17, 2024. It is now read-only.

Commit

Permalink
feat: add support for fairness constraints (#94)
Browse files Browse the repository at this point in the history
Co-authored-by: Christopher Chianelli <christopher@timefold.ai>
  • Loading branch information
triceo and Christopher-Chianelli authored Jul 8, 2024
1 parent 4396e2d commit e1dac95
Show file tree
Hide file tree
Showing 6 changed files with 1,410 additions and 193 deletions.
29 changes: 22 additions & 7 deletions tests/test_collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,24 +563,39 @@ def define_constraints(constraint_factory: ConstraintFactory):
assert score_manager.explain(problem).score == SimpleScore.of(4)


def test_flatten_last():
def test_load_balance():
@constraint_provider
def define_constraints(constraint_factory: ConstraintFactory):
return [
constraint_factory.for_each(Entity)
.map(lambda entity: (1, 2, 3))
.flatten_last(lambda the_tuple: the_tuple)
.reward(SimpleScore.ONE)
.as_constraint('Count')
.group_by(ConstraintCollectors.load_balance(
lambda entity: entity.value
))
.reward(SimpleScore.ONE,
lambda balance: balance.unfairness().movePointRight(3).intValue())
.as_constraint('Balanced value')
]

score_manager = create_score_manager(define_constraints)

entity_a: Entity = Entity('A')
entity_b: Entity = Entity('B')
entity_c: Entity = Entity('C')

value_1 = Value(1)
value_2 = Value(2)

problem = Solution([entity_a], [value_1])
problem = Solution([entity_a, entity_b], [value_1])
entity_a.value = value_1
entity_b.value = value_1
entity_c.value = value_1

assert score_manager.explain(problem).score == SimpleScore.of(3)
assert score_manager.explain(problem).score == SimpleScore.of(0)

problem = Solution([entity_a, entity_b, entity_c], [value_1, value_2])

assert score_manager.explain(problem).score == SimpleScore.of(0)

entity_c.value = value_2

assert score_manager.explain(problem).score == SimpleScore.of(707)
151 changes: 146 additions & 5 deletions tests/test_constraint_streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,29 @@ def define_constraints(constraint_factory: ConstraintFactory):
assert score_manager.explain(problem).score.score == 1


def test_flatten_last():
@constraint_provider
def define_constraints(constraint_factory: ConstraintFactory):
return [
constraint_factory.for_each(Entity)
.map(lambda entity: (1, 2, 3))
.flatten_last(lambda the_tuple: the_tuple)
.reward(SimpleScore.ONE)
.as_constraint('Count')
]

score_manager = create_score_manager(define_constraints)

entity_a: Entity = Entity('A')

value_1 = Value(1)

problem = Solution([entity_a], [value_1])
entity_a.value = value_1

assert score_manager.explain(problem).score == SimpleScore.of(3)


def test_join_uni():
@constraint_provider
def define_constraints(constraint_factory: ConstraintFactory):
Expand Down Expand Up @@ -265,6 +288,87 @@ def define_constraints(constraint_factory: ConstraintFactory):
assert score_manager.explain(problem).score.score == 8


def test_if_exists_uni():
@constraint_provider
def define_constraints(constraint_factory: ConstraintFactory):
return [
constraint_factory.for_each(Entity)
.if_exists(Entity, Joiners.equal(lambda entity: entity.code))
.reward(SimpleScore.ONE, lambda e1: e1.value.number)
.as_constraint('Count')
]

score_manager = create_score_manager(define_constraints)
entity_a1: Entity = Entity('A')
entity_a2: Entity = Entity('A')
entity_b1: Entity = Entity('B')
entity_b2: Entity = Entity('B')

value_1 = Value(1)
value_2 = Value(2)

problem = Solution([entity_a1, entity_a2, entity_b1, entity_b2], [value_1, value_2])

entity_a1.value = value_1

# With itself
assert score_manager.explain(problem).score.score == 1

entity_a1.value = value_1
entity_a2.value = value_1

entity_b1.value = value_2
entity_b2.value = value_2

# 1 + 2 + 1 + 2
assert score_manager.explain(problem).score.score == 6

entity_a1.value = value_2
entity_b1.value = value_1

# 1 + 2 + 1 + 2
assert score_manager.explain(problem).score.score == 6


def test_if_not_exists_uni():
@constraint_provider
def define_constraints(constraint_factory: ConstraintFactory):
return [
constraint_factory.for_each(Entity)
.if_not_exists(Entity, Joiners.equal(lambda entity: entity.code))
.reward(SimpleScore.ONE, lambda e1: e1.value.number)
.as_constraint('Count')
]

score_manager = create_score_manager(define_constraints)
entity_a1: Entity = Entity('A')
entity_a2: Entity = Entity('A')
entity_b1: Entity = Entity('B')
entity_b2: Entity = Entity('B')

value_1 = Value(1)
value_2 = Value(2)

problem = Solution([entity_a1, entity_a2, entity_b1, entity_b2], [value_1, value_2])

entity_a1.value = value_1

assert score_manager.explain(problem).score.score == 0

entity_a1.value = value_1
entity_a2.value = value_1

entity_b1.value = value_2
entity_b2.value = value_2

assert score_manager.explain(problem).score.score == 0

entity_a1.value = value_2
entity_b1.value = value_1

assert score_manager.explain(problem).score.score == 0


def test_map():
@constraint_provider
def define_constraints(constraint_factory: ConstraintFactory):
Expand Down Expand Up @@ -436,6 +540,41 @@ def define_constraints(constraint_factory: ConstraintFactory):

assert score_manager.explain(problem).score.score == 1

def test_complement():
@constraint_provider
def define_constraints(constraint_factory: ConstraintFactory):
return [
constraint_factory.for_each(Entity)
.filter(lambda e: e.value.number == 1)
.complement(Entity)
.reward(SimpleScore.ONE)
.as_constraint('Count')
]

score_manager = create_score_manager(define_constraints)
entity_a: Entity = Entity('A')
entity_b: Entity = Entity('B')

value_1 = Value(1)
value_2 = Value(2)
value_3 = Value(3)

problem = Solution([entity_a, entity_b], [value_1, value_2, value_3])

assert score_manager.explain(problem).score.score == 0

entity_a.value = value_1

assert score_manager.explain(problem).score.score == 1

entity_b.value = value_2

assert score_manager.explain(problem).score.score == 2

entity_b.value = value_3

assert score_manager.explain(problem).score.score == 2


def test_custom_indictments():
@dataclass(unsafe_hash=True)
Expand Down Expand Up @@ -630,6 +769,7 @@ def define_constraints(constraint_factory: ConstraintFactory):


def test_has_all_methods():
missing = []
for python_type, java_type in ((UniConstraintStream, JavaUniConstraintStream),
(BiConstraintStream, JavaBiConstraintStream),
(TriConstraintStream, JavaTriConstraintStream),
Expand All @@ -641,7 +781,6 @@ def test_has_all_methods():
(Joiners, JavaJoiners),
(ConstraintCollectors, JavaConstraintCollectors),
(ConstraintFactory, JavaConstraintFactory)):
missing = []
for function_name, function_impl in inspect.getmembers(java_type, inspect.isfunction):
if function_name in ignored_java_functions:
continue
Expand All @@ -654,8 +793,10 @@ def test_has_all_methods():
# change h_t_t_p -> http
snake_case_name = re.sub('([a-z0-9])([A-Z])', r'\1_\2', snake_case_name).lower()
if not hasattr(python_type, snake_case_name):
missing.append(snake_case_name)
missing.append((java_type, python_type, snake_case_name))

if missing:
raise AssertionError(f'{python_type} is missing methods ({missing}) '
f'from java_type ({java_type}).)')
if missing:
assertion_msg = ''
for java_type, python_type, snake_case_name in missing:
assertion_msg += f'{python_type} is missing a method ({snake_case_name}) from java_type ({java_type}).)\n'
raise AssertionError(assertion_msg)
Loading

0 comments on commit e1dac95

Please sign in to comment.