Skip to content
This repository has been archived by the owner on Jul 17, 2024. It is now read-only.

feat: introduce fairness #94

Merged
merged 9 commits into from
Jul 8, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 22 additions & 7 deletions tests/test_collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,24 +563,39 @@ def define_constraints(constraint_factory: ConstraintFactory):
assert score_manager.explain(problem).score == SimpleScore.of(4)


def test_flatten_last():
def test_load_balance():
triceo marked this conversation as resolved.
Show resolved Hide resolved
@constraint_provider
def define_constraints(constraint_factory: ConstraintFactory):
return [
constraint_factory.for_each(Entity)
.map(lambda entity: (1, 2, 3))
.flatten_last(lambda the_tuple: the_tuple)
.reward(SimpleScore.ONE)
.as_constraint('Count')
.group_by(ConstraintCollectors.load_balance(
lambda entity: entity.value
))
.reward(SimpleScore.ONE,
lambda balance: balance.unfairness().multiply(BigDecimal.valueOf(1000)).intValue())
.as_constraint('Balanced value')
]

score_manager = create_score_manager(define_constraints)

entity_a: Entity = Entity('A')
entity_b: Entity = Entity('B')
entity_c: Entity = Entity('C')

value_1 = Value(1)
value_2 = Value(2)

problem = Solution([entity_a], [value_1])
problem = Solution([entity_a, entity_b, entity_c], [value_1, value_2])
entity_a.value = value_1
entity_b.value = value_1
entity_c.value = value_1

assert score_manager.explain(problem).score == SimpleScore.of(3)
assert score_manager.explain(problem).score == SimpleScore.of(0)

entity_c.value = value_2

assert score_manager.explain(problem).score == SimpleScore.of(2) // FIXME

entity_b.value = value_2

assert score_manager.explain(problem).score == SimpleScore.of(4) // FIXME
151 changes: 146 additions & 5 deletions tests/test_constraint_streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,29 @@ def define_constraints(constraint_factory: ConstraintFactory):
assert score_manager.explain(problem).score.score == 1


def test_flatten_last():
@constraint_provider
def define_constraints(constraint_factory: ConstraintFactory):
return [
constraint_factory.for_each(Entity)
.map(lambda entity: (1, 2, 3))
.flatten_last(lambda the_tuple: the_tuple)
.reward(SimpleScore.ONE)
.as_constraint('Count')
]

score_manager = create_score_manager(define_constraints)

entity_a: Entity = Entity('A')

value_1 = Value(1)

problem = Solution([entity_a], [value_1])
entity_a.value = value_1

assert score_manager.explain(problem).score == SimpleScore.of(3)


def test_join_uni():
@constraint_provider
def define_constraints(constraint_factory: ConstraintFactory):
Expand Down Expand Up @@ -265,6 +288,87 @@ def define_constraints(constraint_factory: ConstraintFactory):
assert score_manager.explain(problem).score.score == 8


def test_if_exists_uni():
@constraint_provider
def define_constraints(constraint_factory: ConstraintFactory):
return [
constraint_factory.for_each(Entity)
.if_exists(Entity, Joiners.equal(lambda entity: entity.code))
.reward(SimpleScore.ONE, lambda e1: e1.value.number)
.as_constraint('Count')
]

score_manager = create_score_manager(define_constraints)
entity_a1: Entity = Entity('A')
entity_a2: Entity = Entity('A')
entity_b1: Entity = Entity('B')
entity_b2: Entity = Entity('B')

value_1 = Value(1)
value_2 = Value(2)

problem = Solution([entity_a1, entity_a2, entity_b1, entity_b2], [value_1, value_2])

entity_a1.value = value_1

# With itself
assert score_manager.explain(problem).score.score == 1

entity_a1.value = value_1
entity_a2.value = value_1

entity_b1.value = value_2
entity_b2.value = value_2

# 1 + 2 + 1 + 2
assert score_manager.explain(problem).score.score == 6

entity_a1.value = value_2
entity_b1.value = value_1

# 1 + 2 + 1 + 2
assert score_manager.explain(problem).score.score == 6


def test_if_not_exists_uni():
@constraint_provider
def define_constraints(constraint_factory: ConstraintFactory):
return [
constraint_factory.for_each(Entity)
.if_not_exists(Entity, Joiners.equal(lambda entity: entity.code))
.reward(SimpleScore.ONE, lambda e1: e1.value.number)
.as_constraint('Count')
]

score_manager = create_score_manager(define_constraints)
entity_a1: Entity = Entity('A')
entity_a2: Entity = Entity('A')
entity_b1: Entity = Entity('B')
entity_b2: Entity = Entity('B')

value_1 = Value(1)
value_2 = Value(2)

problem = Solution([entity_a1, entity_a2, entity_b1, entity_b2], [value_1, value_2])

entity_a1.value = value_1

assert score_manager.explain(problem).score.score == 0

entity_a1.value = value_1
entity_a2.value = value_1

entity_b1.value = value_2
entity_b2.value = value_2

assert score_manager.explain(problem).score.score == 0

entity_a1.value = value_2
entity_b1.value = value_1

assert score_manager.explain(problem).score.score == 0


def test_map():
@constraint_provider
def define_constraints(constraint_factory: ConstraintFactory):
Expand Down Expand Up @@ -436,6 +540,41 @@ def define_constraints(constraint_factory: ConstraintFactory):

assert score_manager.explain(problem).score.score == 1

def test_complement():
@constraint_provider
def define_constraints(constraint_factory: ConstraintFactory):
return [
constraint_factory.for_each(Entity)
.filter(lambda e: e.value.number == 1)
.complement(Entity)
.reward(SimpleScore.ONE)
.as_constraint('Count')
]

score_manager = create_score_manager(define_constraints)
entity_a: Entity = Entity('A')
entity_b: Entity = Entity('B')

value_1 = Value(1)
value_2 = Value(2)
value_3 = Value(3)

problem = Solution([entity_a, entity_b], [value_1, value_2, value_3])

assert score_manager.explain(problem).score.score == 0

entity_a.value = value_1

assert score_manager.explain(problem).score.score == 1

entity_b.value = value_2

assert score_manager.explain(problem).score.score == 2

entity_b.value = value_3

assert score_manager.explain(problem).score.score == 2


def test_custom_indictments():
@dataclass(unsafe_hash=True)
Expand Down Expand Up @@ -605,6 +744,7 @@ def define_constraints(constraint_factory: ConstraintFactory):


def test_has_all_methods():
missing = []
for python_type, java_type in ((UniConstraintStream, JavaUniConstraintStream),
(BiConstraintStream, JavaBiConstraintStream),
(TriConstraintStream, JavaTriConstraintStream),
Expand All @@ -616,7 +756,6 @@ def test_has_all_methods():
(Joiners, JavaJoiners),
(ConstraintCollectors, JavaConstraintCollectors),
(ConstraintFactory, JavaConstraintFactory)):
missing = []
for function_name, function_impl in inspect.getmembers(java_type, inspect.isfunction):
if function_name in ignored_java_functions:
continue
Expand All @@ -629,8 +768,10 @@ def test_has_all_methods():
# change h_t_t_p -> http
snake_case_name = re.sub('([a-z0-9])([A-Z])', r'\1_\2', snake_case_name).lower()
if not hasattr(python_type, snake_case_name):
missing.append(snake_case_name)
missing.append((java_type, python_type, snake_case_name))

if missing:
raise AssertionError(f'{python_type} is missing methods ({missing}) '
f'from java_type ({java_type}).)')
if missing:
assertion_msg = ''
for java_type, python_type, snake_case_name in missing:
assertion_msg += f'{python_type} is missing a method ({snake_case_name}) from java_type ({java_type}).)\n'
raise AssertionError(assertion_msg)
43 changes: 43 additions & 0 deletions tests/test_constraint_verifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,16 @@
from timefold.solver.config import *
from timefold.solver.test import *

import inspect
import re
from typing import Annotated, List
from dataclasses import dataclass, field

from ai.timefold.solver.test.api.score.stream import (ConstraintVerifier as JavaConstraintVerifier,
SingleConstraintAssertion as JavaSingleConstraintAssertion,
SingleConstraintVerification as JavaSingleConstraintVerification,
MultiConstraintAssertion as JavaMultiConstraintAssertion,
MultiConstraintVerification as JavaMultiConstraintVerification)

def verifier_suite(verifier: ConstraintVerifier, same_value, is_value_one,
solution, e1, e2, e3, v1, v2, v3):
Expand Down Expand Up @@ -268,3 +275,39 @@ class Solution:

verifier_suite(verifier, same_value, is_value_one,
solution, e1, e2, e3, v1, v2, v3)

Christopher-Chianelli marked this conversation as resolved.
Show resolved Hide resolved

ignored_java_functions = {
'equals',
'getClass',
'hashCode',
'notify',
'notifyAll',
'toString',
'wait',
'withConstraintStreamImplType'
}


def test_has_all_methods():
missing = []
for python_type, java_type in ((ConstraintVerifier, JavaConstraintVerifier),
(SingleConstraintAssertion, JavaSingleConstraintAssertion),
(SingleConstraintVerification, JavaSingleConstraintVerification),
(MultiConstraintAssertion, JavaMultiConstraintAssertion),
(MultiConstraintVerification, JavaMultiConstraintVerification)):
for function_name, function_impl in inspect.getmembers(java_type, inspect.isfunction):
if function_name in ignored_java_functions:
continue
snake_case_name = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', function_name)
# change h_t_t_p -> http
snake_case_name = re.sub('([a-z0-9])([A-Z])', r'\1_\2', snake_case_name).lower()
if not hasattr(python_type, snake_case_name):
missing.append((java_type, python_type, snake_case_name))

if missing:
assertion_msg = ''
for java_type, python_type, snake_case_name in missing:
assertion_msg += f'{python_type} is missing a method ({snake_case_name}) from java_type ({java_type}).)\n'
raise AssertionError(assertion_msg)

Loading
Loading