Skip to content
This repository has been archived by the owner on Jul 17, 2024. It is now read-only.

Commit

Permalink
chore: Implement VariableListener using base classes instead of decor…
Browse files Browse the repository at this point in the history
…ators

- Made all methods as optional, since 99% of the time, they do nothing
  • Loading branch information
Christopher-Chianelli committed Apr 24, 2024
1 parent 45132a1 commit bd011c1
Show file tree
Hide file tree
Showing 7 changed files with 208 additions and 146 deletions.
3 changes: 3 additions & 0 deletions tests/test_constraint_streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,6 +589,9 @@ def define_constraints(constraint_factory: ConstraintFactory):
'rewardConfigurableLong',
'rewardLong',
'_handler', # JPype handler field should be ignored
# Unimplemented
'toConnectedRanges',
'toConnectedTemporalRanges',
# These methods are deprecated
'from_',
'fromUnfiltered',
Expand Down
57 changes: 12 additions & 45 deletions tests/test_custom_shadow_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,30 +9,14 @@


def test_custom_shadow_variable():
@variable_listener
class MyVariableListener:
def afterVariableChanged(self, score_director, entity):
score_director.beforeVariableChanged(entity, 'value_squared')
class MyVariableListener(VariableListener):
def after_variable_changed(self, score_director: ScoreDirector, entity):
score_director.before_variable_changed(entity, 'value_squared')
if entity.value is None:
entity.value_squared = None
else:
entity.value_squared = entity.value ** 2
score_director.afterVariableChanged(entity, 'value_squared')

def beforeVariableChanged(self, score_director, entity):
pass

def beforeEntityAdded(self, score_director, entity):
pass

def afterEntityAdded(self, score_director, entity):
pass

def beforeEntityRemoved(self, score_director, entity):
pass

def afterEntityRemoved(self, score_director, entity):
pass
score_director.after_variable_changed(entity, 'value_squared')

@planning_entity
@dataclass
Expand Down Expand Up @@ -79,34 +63,18 @@ class MySolution:


def test_custom_shadow_variable_with_variable_listener_ref():
@variable_listener
class MyVariableListener:
def afterVariableChanged(self, score_director, entity):
score_director.beforeVariableChanged(entity, 'twice_value')
score_director.beforeVariableChanged(entity, 'value_squared')
class MyVariableListener(VariableListener):
def after_variable_changed(self, score_director: ScoreDirector, entity):
score_director.before_variable_changed(entity, 'twice_value')
score_director.before_variable_changed(entity, 'value_squared')
if entity.value is None:
entity.twice_value = None
entity.value_squared = None
else:
entity.twice_value = 2 * entity.value
entity.value_squared = entity.value ** 2
score_director.afterVariableChanged(entity, 'value_squared')
score_director.afterVariableChanged(entity, 'twice_value')

def beforeVariableChanged(self, score_director, entity):
pass

def beforeEntityAdded(self, score_director, entity):
pass

def afterEntityAdded(self, score_director, entity):
pass

def beforeEntityRemoved(self, score_director, entity):
pass

def afterEntityRemoved(self, score_director, entity):
pass
score_director.after_variable_changed(entity, 'value_squared')
score_director.after_variable_changed(entity, 'twice_value')

@planning_entity
@dataclass
Expand All @@ -115,9 +83,8 @@ class MyPlanningEntity:
field(default=None)
value_squared: Annotated[Optional[int], ShadowVariable(
variable_listener_class=MyVariableListener, source_variable_name='value')] = field(default=None)
# TODO: Use PiggyBackShadowVariable
twice_value: Annotated[Optional[int], ShadowVariable(
variable_listener_class=MyVariableListener, source_variable_name='value')] = field(default=None)
twice_value: Annotated[Optional[int], PiggybackShadowVariable(shadow_variable_name='value_squared')] = (
field(default=None))

@constraint_provider
def my_constraints(constraint_factory: ConstraintFactory):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,8 @@ def get_class(python_class: Union[Type, Callable]) -> JClass:
from ai.timefold.jpyinterpreter.types.wrappers import OpaquePythonReference
from jpyinterpreter import is_c_native, get_java_type_for_python_type

if python_class is None:
return cast(JClass, None)
if isinstance(python_class, jpype.JClass):
return cast(JClass, python_class).class_
if isinstance(python_class, Class):
Expand Down
127 changes: 26 additions & 101 deletions timefold-solver-python-core/src/main/python/annotation/_annotations.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import jpype

from ..api import VariableListener
from ..constraint import ConstraintFactory
from .._timefold_java_interop import ensure_init, _generate_constraint_provider_class, register_java_class
from jpyinterpreter import JavaAnnotation
Expand All @@ -11,8 +12,7 @@
from ai.timefold.solver.core.api.score.stream import Constraint as _Constraint
from ai.timefold.solver.core.api.score import Score as _Score
from ai.timefold.solver.core.api.score.calculator import IncrementalScoreCalculator as _IncrementalScoreCalculator
from ai.timefold.solver.core.api.domain.variable import PlanningVariableGraphType as _PlanningVariableGraphType, \
VariableListener as _VariableListener
from ai.timefold.solver.core.api.domain.variable import PlanningVariableGraphType as _PlanningVariableGraphType


Solution_ = TypeVar('Solution_')
Expand All @@ -35,7 +35,7 @@ class PlanningVariable(JavaAnnotation):
def __init__(self, *,
value_range_provider_refs: List[str] = None,
allows_unassigned: bool = False,
graph_type: '_PlanningVariableGraphType' = None):
graph_type=None):
ensure_init()
from ai.timefold.solver.core.api.domain.variable import PlanningVariable as JavaPlanningVariable
super().__init__(JavaPlanningVariable,
Expand Down Expand Up @@ -75,19 +75,37 @@ def __init__(self, *,

class ShadowVariable(JavaAnnotation):
def __init__(self, *,
variable_listener_class: Type['_VariableListener'] = None,
variable_listener_class: Type[VariableListener] = None,
source_variable_name: str,
source_entity_class: Type = None):
ensure_init()
from .._timefold_java_interop import get_class
from jpyinterpreter import get_java_type_for_python_type
from ai.timefold.jpyinterpreter import PythonClassTranslator
from ai.timefold.solver.core.api.domain.variable import (
ShadowVariable as JavaShadowVariable)
ShadowVariable as JavaShadowVariable, VariableListener as JavaVariableListener)

super().__init__(JavaShadowVariable,
{
'variableListenerClass': get_class(variable_listener_class),
'sourceVariableName': PythonClassTranslator.getJavaFieldName(source_variable_name),
'sourceEntityClass': source_entity_class,
'sourceEntityClass': get_class(source_entity_class),
})


class PiggybackShadowVariable(JavaAnnotation):
def __init__(self, *,
shadow_variable_name: str,
shadow_entity_class: Type = None):
ensure_init()
from .._timefold_java_interop import get_class
from ai.timefold.jpyinterpreter import PythonClassTranslator
from ai.timefold.solver.core.api.domain.variable import (
PiggybackShadowVariable as JavaPiggybackShadowVariable)
super().__init__(JavaPiggybackShadowVariable,
{
'shadowVariableName': PythonClassTranslator.getJavaFieldName(shadow_variable_name),
'shadowEntityClass': get_class(shadow_entity_class),
})


Expand Down Expand Up @@ -455,100 +473,6 @@ def resetWorkingSolution(self, workingSolution: Solution_, constraintMatchEnable
return register_java_class(incremental_score_calculator, java_class)


def variable_listener(variable_listener_class: Type['_VariableListener'] = None, /, *,
require_unique_entity_events: bool = False) -> Type['_VariableListener']:
"""Changes shadow variables when a genuine planning variable changes.
Important: it must only change the shadow variable(s) for which it's configured!
It should never change a genuine variable or a problem fact.
It can change its shadow variable(s) on multiple entity instances
(for example: an arrival_time change affects all trailing entities too).
It is recommended that implementations be kept stateless.
If state must be implemented, implementations may need to override the default methods
resetWorkingSolution(score_director: ScoreDirector) and close().
The following methods must exist:
def beforeEntityAdded(score_director: ScoreDirector[Solution_], entity: Entity_);
def afterEntityAdded(score_director: ScoreDirector[Solution_], entity: Entity_);
def beforeEntityRemoved(score_director: ScoreDirector[Solution_], entity: Entity_);
def afterEntityRemoved(score_director: ScoreDirector[Solution_], entity: Entity_);
def beforeVariableChanged(score_director: ScoreDirector[Solution_], entity: Entity_);
def afterVariableChanged(score_director: ScoreDirector[Solution_], entity: Entity_);
If the implementation is stateful, then the following methods should also be defined:
def resetWorkingSolution(score_director: ScoreDirector)
def close()
:param require_unique_entity_events: Set to True to guarantee that each of the before/after methods will only be
called once per entity instance per operation type (add, change or remove).
When set to True, this has a slight performance loss.
When set to False, it's often easier to make the listener implementation
correct and fast.
Defaults to False
:type variable_listener_class: '_VariableListener'
:type require_unique_entity_events: bool
:rtype: Type
"""
ensure_init()

def variable_listener_wrapper(the_variable_listener_class):
from jpyinterpreter import translate_python_class_to_java_class, generate_proxy_class_for_translated_class
from ai.timefold.solver.core.api.domain.variable import VariableListener
methods = ['beforeEntityAdded',
'afterEntityAdded',
'beforeVariableChanged',
'afterVariableChanged',
'beforeEntityRemoved',
'afterEntityRemoved']

missing_method_list = []
for method in methods:
if not callable(getattr(the_variable_listener_class, method, None)):
missing_method_list.append(method)
if len(missing_method_list) != 0:
raise ValueError(f'The following required methods are missing from @variable_listener class '
f'{the_variable_listener_class}: {missing_method_list}')

method_on_class = getattr(the_variable_listener_class, 'requiresUniqueEntityEvents', None)
if method_on_class is None:
def class_requires_unique_entity_events(self):
return require_unique_entity_events

setattr(the_variable_listener_class, 'requiresUniqueEntityEvents', class_requires_unique_entity_events)

method_on_class = getattr(the_variable_listener_class, 'close', None)
if method_on_class is None:
def close(self):
pass

setattr(the_variable_listener_class, 'close', close)

method_on_class = getattr(the_variable_listener_class, 'resetWorkingSolution', None)
if method_on_class is None:
def reset_working_solution(self, score_director):
pass

setattr(the_variable_listener_class, 'resetWorkingSolution', reset_working_solution)

translated_class = translate_python_class_to_java_class(the_variable_listener_class)
java_class = generate_proxy_class_for_translated_class(VariableListener, translated_class)
return register_java_class(the_variable_listener_class, java_class)

if variable_listener_class: # Called as @variable_listener
return variable_listener_wrapper(variable_listener_class)
else: # Called as @variable_listener(require_unique_entity_events=True)
return variable_listener_wrapper


def problem_change(problem_change_class: Type['_ProblemChange']) -> \
Type['_ProblemChange']:
"""A ProblemChange represents a change in 1 or more planning entities or problem facts of a PlanningSolution.
Expand Down Expand Up @@ -599,6 +523,7 @@ def wrapper_doChange(self, solution, problem_change_director):

__all__ = ['PlanningId', 'PlanningScore', 'PlanningPin', 'PlanningVariable',
'PlanningListVariable', 'PlanningVariableReference', 'ShadowVariable',
'PiggybackShadowVariable',
'IndexShadowVariable', 'AnchorShadowVariable', 'InverseRelationShadowVariable',
'ProblemFactProperty', 'ProblemFactCollectionProperty',
'PlanningEntityProperty', 'PlanningEntityCollectionProperty',
Expand All @@ -607,4 +532,4 @@ def wrapper_doChange(self, solution, problem_change_director):
'planning_entity', 'planning_solution', 'constraint_configuration',
'nearby_distance_meter',
'constraint_provider', 'easy_score_calculator', 'incremental_score_calculator',
'variable_listener', 'problem_change']
'problem_change']
2 changes: 2 additions & 0 deletions timefold-solver-python-core/src/main/python/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@
from ._solver_factory import *
from ._solver_manager import *
from ._solution_manager import *
from ._score_director import *
from ._variable_listener import *
72 changes: 72 additions & 0 deletions timefold-solver-python-core/src/main/python/api/_score_director.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
class ScoreDirector:
def __init__(self, delegate):
self._delegate = delegate

def after_entity_added(self, entity) -> None:
self._delegate.afterEntityAdded(entity)

def after_entity_removed(self, entity) -> None:
self._delegate.afterEntityRemoved(entity)

def after_list_variable_changed(self, entity, variable_name: str, start: int, end: int) -> None:
self._delegate.afterListVariableChanged(entity, variable_name, start, end)

def after_list_variable_element_assigned(self, entity, variable_name: str, element) -> None:
self._delegate.afterListVariableElementAssigned(entity, variable_name, element)

def after_list_variable_element_unassigned(self, entity, variable_name: str, element) -> None:
self._delegate.afterListVariableElementUnassigned(entity, variable_name, element)

def after_problem_fact_added(self, entity) -> None:
self._delegate.afterProblemFactAdded(entity)

def after_problem_fact_removed(self, entity) -> None:
self._delegate.afterProblemFactRemoved(entity)

def after_problem_property_changed(self, entity) -> None:
self._delegate.afterProblemPropertyChanged(entity)

def after_variable_changed(self, entity, variable_name: str) -> None:
self._delegate.afterVariableChanged(entity, variable_name)

def before_entity_added(self, entity) -> None:
self._delegate.beforeEntityAdded(entity)

def before_entity_removed(self, entity) -> None:
self._delegate.beforeEntityRemoved(entity)

def before_list_variable_changed(self, entity, variable_name: str, start: int, end: int) -> None:
self._delegate.beforeListVariableChanged(entity, variable_name, start, end)

def before_list_variable_element_assigned(self, entity, variable_name: str, element) -> None:
self._delegate.beforeListVariableElementAssigned(entity, variable_name, element)

def before_list_variable_element_unassigned(self, entity, variable_name: str, element) -> None:
self._delegate.beforeListVariableElementUnassigned(entity, variable_name, element)

def before_problem_fact_added(self, entity) -> None:
self._delegate.beforeProblemFactAdded(entity)

def before_problem_fact_removed(self, entity) -> None:
self._delegate.beforeProblemFactRemoved(entity)

def before_problem_property_changed(self, entity) -> None:
self._delegate.beforeProblemPropertyChanged(entity)

def before_variable_changed(self, entity, variable_name: str) -> None:
self._delegate.beforeVariableChanged(entity, variable_name)

def get_working_solution(self):
return self._delegate.getWorkingSolution()

def look_up_working_object(self, working_object):
return self._delegate.lookUpWorkingObject(working_object)

def look_up_working_object_or_return_none(self, working_object):
return self._delegate.lookUpWorkingObject(working_object)

def trigger_variable_listeners(self) -> None:
self._delegate.triggerVariableListeners()


__all__ = ['ScoreDirector']
Loading

0 comments on commit bd011c1

Please sign in to comment.