Skip to content
This repository has been archived by the owner on Jul 17, 2024. It is now read-only.

Commit

Permalink
feat: Add SolverConfigOverride support
Browse files Browse the repository at this point in the history
- Added missing type info to Solver API Classes
- Unlike Java, the SolverConfig in Python is Generic,
  since:

  - Python users do not need to specify types of variables,
    and no warnings are emitted for using a raw type

  - Allows a smart enough type checker to deduce the generic type
    of a SolverFactory, SolverManager, etc. from the generic
    type of the SolverConfig
  • Loading branch information
Christopher-Chianelli committed Apr 18, 2024
1 parent 244e1d9 commit 196d76f
Show file tree
Hide file tree
Showing 7 changed files with 188 additions and 52 deletions.
57 changes: 57 additions & 0 deletions tests/test_solver_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from timefold.solver.api import *
from timefold.solver.annotation import *
from timefold.solver.config import *
from timefold.solver.constraint import *
from timefold.solver.score import *

from dataclasses import dataclass, field
from typing import Annotated, List


@planning_entity
@dataclass
class Entity:
code: Annotated[str, PlanningId]
value: Annotated[int, PlanningVariable] = field(default=None, compare=False)


@constraint_provider
def my_constraints(constraint_factory: ConstraintFactory):
return [
constraint_factory.for_each(Entity)
.reward(SimpleScore.ONE, lambda entity: entity.value)
.as_constraint('Maximize value'),
]


@planning_solution
@dataclass
class Solution:
entities: Annotated[List[Entity], PlanningEntityCollectionProperty]
value_range: Annotated[List[int], ValueRangeProvider]
score: Annotated[SimpleScore, PlanningScore] = field(default=None)

def __str__(self) -> str:
return str(self.entities)


def test_solver_config_override():
solver_config = SolverConfig(
solution_class=Solution,
entity_class_list=[Entity],
score_director_factory_config=ScoreDirectorFactoryConfig(
constraint_provider_function=my_constraints,
),
termination_config=TerminationConfig(
best_score_limit='9'
)
)
solver_factory = SolverFactory.create(solver_config)
solver = solver_factory.build_solver(SolverConfigOverride(
termination_config=TerminationConfig(
best_score_limit='3'
)
))
problem = Solution([Entity('A')], [1, 2, 3])
solution = solver.solve(problem)
assert solution.score.score() == 3
56 changes: 56 additions & 0 deletions tests/test_solver_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,3 +272,59 @@ def my_exception_handler(problem_id, exception):

assert the_problem_id == 1
assert the_exception is not None


def test_solver_config_override():
@dataclass
class Value:
value: Annotated[int, PlanningId]

@planning_entity
@dataclass
class Entity:
code: Annotated[str, PlanningId]
value: Annotated[Value, PlanningVariable] = field(default=None)

@constraint_provider
def my_constraints(constraint_factory: ConstraintFactory):
return [
constraint_factory.for_each(Entity)
.reward(SimpleScore.ONE, lambda entity: entity.value.value)
.as_constraint('Maximize Value')
]

@planning_solution
@dataclass
class Solution:
entity_list: Annotated[List[Entity], PlanningEntityCollectionProperty]
value_list: Annotated[List[Value],
DeepPlanningClone,
ProblemFactCollectionProperty,
ValueRangeProvider]
score: Annotated[SimpleScore, PlanningScore] = field(default=None)

solver_config = SolverConfig(
solution_class=Solution,
entity_class_list=[Entity],
score_director_factory_config=ScoreDirectorFactoryConfig(
constraint_provider_function=my_constraints
),
termination_config=TerminationConfig(
best_score_limit='9'
)
)
problem: Solution = Solution([Entity('A')], [Value(1), Value(2), Value(3)],
SimpleScore.ONE)
with SolverManager.create(SolverFactory.create(solver_config)) as solver_manager:
solver_job = (solver_manager.solve_builder()
.with_problem_id(1)
.with_problem(problem)
.with_config_override(SolverConfigOverride(
termination_config=TerminationConfig(
best_score_limit='3'
)
))
.run())

solution = solver_job.get_final_best_solution()
assert solution.score.score() == 3
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from ._solver_factory import SolverFactory
from .._timefold_java_interop import get_class

from typing import TypeVar, Union, TYPE_CHECKING
from typing import TypeVar, Generic, Union, TYPE_CHECKING


if TYPE_CHECKING:
Expand All @@ -15,41 +15,43 @@
ProblemId_ = TypeVar('ProblemId_')


class SolutionManager:
class SolutionManager(Generic[Solution_]):
_delegate: '_JavaSolutionManager'

def __init__(self, delegate: '_JavaSolutionManager'):
self._delegate = delegate

@staticmethod
def create(solver_factory: 'SolverFactory'):
def create(solver_factory: 'SolverFactory[Solution_]') -> 'SolutionManager[Solution_]':
from ai.timefold.solver.core.api.solver import SolutionManager as JavaSolutionManager
return SolutionManager(JavaSolutionManager.create(solver_factory._delegate))

def update(self, solution, solution_update_policy=None) -> 'Score':
def update(self, solution: Solution_, solution_update_policy=None) -> 'Score':
# TODO handle solution_update_policy
from jpyinterpreter import convert_to_java_python_like_object, update_python_object_from_java
java_solution = convert_to_java_python_like_object(solution)
out = self._delegate.update(java_solution)
update_python_object_from_java(java_solution)
return out

def analyze(self, solution, score_analysis_fetch_policy=None, solution_update_policy=None) -> 'ScoreAnalysis':
def analyze(self, solution: Solution_, score_analysis_fetch_policy=None, solution_update_policy=None) \
-> 'ScoreAnalysis':
# TODO handle policies
from jpyinterpreter import convert_to_java_python_like_object
return ScoreAnalysis(self._delegate.analyze(convert_to_java_python_like_object(solution)))

def explain(self, solution, solution_update_policy=None) -> 'ScoreExplanation':
def explain(self, solution: Solution_, solution_update_policy=None) -> 'ScoreExplanation':
# TODO handle policies
from jpyinterpreter import convert_to_java_python_like_object
return ScoreExplanation(self._delegate.explain(convert_to_java_python_like_object(solution)))

def recommend_fit(self, solution, entity_or_element, proposition_function, score_analysis_fetch_policy=None):
def recommend_fit(self, solution: Solution_, entity_or_element, proposition_function,
score_analysis_fetch_policy=None):
# TODO
raise NotImplementedError


class ScoreExplanation:
class ScoreExplanation(Generic[Solution_]):
_delegate: '_JavaScoreExplanation'

def __init__(self, delegate: '_JavaScoreExplanation'):
Expand All @@ -70,7 +72,7 @@ def get_justification_list(self, justification_type=None):
def get_score(self) -> 'Score':
return self._delegate.getScore()

def get_solution(self):
def get_solution(self) -> Solution_:
from jpyinterpreter import unwrap_python_like_object
return unwrap_python_like_object(self._delegate.getSolution())

Expand Down
4 changes: 2 additions & 2 deletions timefold-solver-python-core/src/main/python/api/_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,10 @@ def terminate_early(self) -> bool:
def is_terminate_early(self) -> bool:
return self._delegate.isTerminateEarly()

def add_problem_change(self, problem_change: ProblemChange) -> None:
def add_problem_change(self, problem_change: ProblemChange[Solution_]) -> None:
self._delegate.addProblemChange(ProblemChangeWrapper(problem_change)) # noqa

def add_problem_changes(self, problem_changes: List[ProblemChange]) -> None:
def add_problem_changes(self, problem_changes: List[ProblemChange[Solution_]]) -> None:
self._delegate.addProblemChanges([ProblemChangeWrapper(problem_change) for problem_change in problem_changes]) # noqa

def is_every_problem_change_processed(self) -> bool:
Expand Down
16 changes: 10 additions & 6 deletions timefold-solver-python-core/src/main/python/api/_solver_factory.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from ._solver import Solver
from ..config import SolverConfig
from ..config import SolverConfig, SolverConfigOverride

from typing import TypeVar, TYPE_CHECKING
from typing import TypeVar, Generic, TYPE_CHECKING
from jpype import JClass

if TYPE_CHECKING:
Expand All @@ -12,7 +12,7 @@
Solution_ = TypeVar('Solution_')


class SolverFactory:
class SolverFactory(Generic[Solution_]):
_delegate: '_JavaSolverFactory'
_solution_class: JClass

Expand All @@ -21,14 +21,18 @@ def __init__(self, delegate: '_JavaSolverFactory', solution_class: JClass):
self._solution_class = solution_class

@staticmethod
def create(solver_config: SolverConfig):
def create(solver_config: SolverConfig[Solution_]) -> 'SolverFactory[Solution_]':
from ai.timefold.solver.core.api.solver import SolverFactory as JavaSolverFactory
solver_config = solver_config._to_java_solver_config()
delegate = JavaSolverFactory.create(solver_config) # noqa
return SolverFactory(delegate, solver_config.getSolutionClass()) # noqa

def build_solver(self):
return Solver(self._delegate.buildSolver(), self._solution_class)
def build_solver(self, solver_config_override: SolverConfigOverride = None) -> Solver[Solution_]:
if solver_config_override is None:
return Solver(self._delegate.buildSolver(), self._solution_class)
else:
return Solver(self._delegate.buildSolver(solver_config_override._to_java_solver_config_override()),
self._solution_class)


__all__ = ['SolverFactory']
Loading

0 comments on commit 196d76f

Please sign in to comment.