diff --git a/tests/test_solver_factory.py b/tests/test_solver_factory.py new file mode 100644 index 0000000..b10b9a0 --- /dev/null +++ b/tests/test_solver_factory.py @@ -0,0 +1,57 @@ +from timefold.solver.api import * +from timefold.solver.annotation import * +from timefold.solver.config import * +from timefold.solver.constraint import * +from timefold.solver.score import * + +from dataclasses import dataclass, field +from typing import Annotated, List + + +@planning_entity +@dataclass +class Entity: + code: Annotated[str, PlanningId] + value: Annotated[int, PlanningVariable] = field(default=None, compare=False) + + +@constraint_provider +def my_constraints(constraint_factory: ConstraintFactory): + return [ + constraint_factory.for_each(Entity) + .reward(SimpleScore.ONE, lambda entity: entity.value) + .as_constraint('Maximize value'), + ] + + +@planning_solution +@dataclass +class Solution: + entities: Annotated[List[Entity], PlanningEntityCollectionProperty] + value_range: Annotated[List[int], ValueRangeProvider] + score: Annotated[SimpleScore, PlanningScore] = field(default=None) + + def __str__(self) -> str: + return str(self.entities) + + +def test_solver_config_override(): + solver_config = SolverConfig( + solution_class=Solution, + entity_class_list=[Entity], + score_director_factory_config=ScoreDirectorFactoryConfig( + constraint_provider_function=my_constraints, + ), + termination_config=TerminationConfig( + best_score_limit='9' + ) + ) + solver_factory = SolverFactory.create(solver_config) + solver = solver_factory.build_solver(SolverConfigOverride( + termination_config=TerminationConfig( + best_score_limit='3' + ) + )) + problem = Solution([Entity('A')], [1, 2, 3]) + solution = solver.solve(problem) + assert solution.score.score() == 3 diff --git a/tests/test_solver_manager.py b/tests/test_solver_manager.py index 58c5829..dda057b 100644 --- a/tests/test_solver_manager.py +++ b/tests/test_solver_manager.py @@ -272,3 +272,59 @@ def my_exception_handler(problem_id, exception): assert the_problem_id == 1 assert the_exception is not None + + +def test_solver_config_override(): + @dataclass + class Value: + value: Annotated[int, PlanningId] + + @planning_entity + @dataclass + class Entity: + code: Annotated[str, PlanningId] + value: Annotated[Value, PlanningVariable] = field(default=None) + + @constraint_provider + def my_constraints(constraint_factory: ConstraintFactory): + return [ + constraint_factory.for_each(Entity) + .reward(SimpleScore.ONE, lambda entity: entity.value.value) + .as_constraint('Maximize Value') + ] + + @planning_solution + @dataclass + class Solution: + entity_list: Annotated[List[Entity], PlanningEntityCollectionProperty] + value_list: Annotated[List[Value], + DeepPlanningClone, + ProblemFactCollectionProperty, + ValueRangeProvider] + score: Annotated[SimpleScore, PlanningScore] = field(default=None) + + solver_config = SolverConfig( + solution_class=Solution, + entity_class_list=[Entity], + score_director_factory_config=ScoreDirectorFactoryConfig( + constraint_provider_function=my_constraints + ), + termination_config=TerminationConfig( + best_score_limit='9' + ) + ) + problem: Solution = Solution([Entity('A')], [Value(1), Value(2), Value(3)], + SimpleScore.ONE) + with SolverManager.create(SolverFactory.create(solver_config)) as solver_manager: + solver_job = (solver_manager.solve_builder() + .with_problem_id(1) + .with_problem(problem) + .with_config_override(SolverConfigOverride( + termination_config=TerminationConfig( + best_score_limit='3' + ) + )) + .run()) + + solution = solver_job.get_final_best_solution() + assert solution.score.score() == 3 diff --git a/timefold-solver-python-core/src/main/python/api/_solution_manager.py b/timefold-solver-python-core/src/main/python/api/_solution_manager.py index fc2113a..c0b2fed 100644 --- a/timefold-solver-python-core/src/main/python/api/_solution_manager.py +++ b/timefold-solver-python-core/src/main/python/api/_solution_manager.py @@ -1,7 +1,7 @@ from ._solver_factory import SolverFactory from .._timefold_java_interop import get_class -from typing import TypeVar, Union, TYPE_CHECKING +from typing import TypeVar, Generic, Union, TYPE_CHECKING if TYPE_CHECKING: @@ -15,18 +15,18 @@ ProblemId_ = TypeVar('ProblemId_') -class SolutionManager: +class SolutionManager(Generic[Solution_]): _delegate: '_JavaSolutionManager' def __init__(self, delegate: '_JavaSolutionManager'): self._delegate = delegate @staticmethod - def create(solver_factory: 'SolverFactory'): + def create(solver_factory: 'SolverFactory[Solution_]') -> 'SolutionManager[Solution_]': from ai.timefold.solver.core.api.solver import SolutionManager as JavaSolutionManager return SolutionManager(JavaSolutionManager.create(solver_factory._delegate)) - def update(self, solution, solution_update_policy=None) -> 'Score': + def update(self, solution: Solution_, solution_update_policy=None) -> 'Score': # TODO handle solution_update_policy from jpyinterpreter import convert_to_java_python_like_object, update_python_object_from_java java_solution = convert_to_java_python_like_object(solution) @@ -34,22 +34,24 @@ def update(self, solution, solution_update_policy=None) -> 'Score': update_python_object_from_java(java_solution) return out - def analyze(self, solution, score_analysis_fetch_policy=None, solution_update_policy=None) -> 'ScoreAnalysis': + def analyze(self, solution: Solution_, score_analysis_fetch_policy=None, solution_update_policy=None) \ + -> 'ScoreAnalysis': # TODO handle policies from jpyinterpreter import convert_to_java_python_like_object return ScoreAnalysis(self._delegate.analyze(convert_to_java_python_like_object(solution))) - def explain(self, solution, solution_update_policy=None) -> 'ScoreExplanation': + def explain(self, solution: Solution_, solution_update_policy=None) -> 'ScoreExplanation': # TODO handle policies from jpyinterpreter import convert_to_java_python_like_object return ScoreExplanation(self._delegate.explain(convert_to_java_python_like_object(solution))) - def recommend_fit(self, solution, entity_or_element, proposition_function, score_analysis_fetch_policy=None): + def recommend_fit(self, solution: Solution_, entity_or_element, proposition_function, + score_analysis_fetch_policy=None): # TODO raise NotImplementedError -class ScoreExplanation: +class ScoreExplanation(Generic[Solution_]): _delegate: '_JavaScoreExplanation' def __init__(self, delegate: '_JavaScoreExplanation'): @@ -70,7 +72,7 @@ def get_justification_list(self, justification_type=None): def get_score(self) -> 'Score': return self._delegate.getScore() - def get_solution(self): + def get_solution(self) -> Solution_: from jpyinterpreter import unwrap_python_like_object return unwrap_python_like_object(self._delegate.getSolution()) diff --git a/timefold-solver-python-core/src/main/python/api/_solver.py b/timefold-solver-python-core/src/main/python/api/_solver.py index 4659950..6c04e5f 100644 --- a/timefold-solver-python-core/src/main/python/api/_solver.py +++ b/timefold-solver-python-core/src/main/python/api/_solver.py @@ -59,10 +59,10 @@ def terminate_early(self) -> bool: def is_terminate_early(self) -> bool: return self._delegate.isTerminateEarly() - def add_problem_change(self, problem_change: ProblemChange) -> None: + def add_problem_change(self, problem_change: ProblemChange[Solution_]) -> None: self._delegate.addProblemChange(ProblemChangeWrapper(problem_change)) # noqa - def add_problem_changes(self, problem_changes: List[ProblemChange]) -> None: + def add_problem_changes(self, problem_changes: List[ProblemChange[Solution_]]) -> None: self._delegate.addProblemChanges([ProblemChangeWrapper(problem_change) for problem_change in problem_changes]) # noqa def is_every_problem_change_processed(self) -> bool: diff --git a/timefold-solver-python-core/src/main/python/api/_solver_factory.py b/timefold-solver-python-core/src/main/python/api/_solver_factory.py index 124b9a3..ed2f444 100644 --- a/timefold-solver-python-core/src/main/python/api/_solver_factory.py +++ b/timefold-solver-python-core/src/main/python/api/_solver_factory.py @@ -1,7 +1,7 @@ from ._solver import Solver -from ..config import SolverConfig +from ..config import SolverConfig, SolverConfigOverride -from typing import TypeVar, TYPE_CHECKING +from typing import TypeVar, Generic, TYPE_CHECKING from jpype import JClass if TYPE_CHECKING: @@ -12,7 +12,7 @@ Solution_ = TypeVar('Solution_') -class SolverFactory: +class SolverFactory(Generic[Solution_]): _delegate: '_JavaSolverFactory' _solution_class: JClass @@ -21,14 +21,18 @@ def __init__(self, delegate: '_JavaSolverFactory', solution_class: JClass): self._solution_class = solution_class @staticmethod - def create(solver_config: SolverConfig): + def create(solver_config: SolverConfig[Solution_]) -> 'SolverFactory[Solution_]': from ai.timefold.solver.core.api.solver import SolverFactory as JavaSolverFactory solver_config = solver_config._to_java_solver_config() delegate = JavaSolverFactory.create(solver_config) # noqa return SolverFactory(delegate, solver_config.getSolutionClass()) # noqa - def build_solver(self): - return Solver(self._delegate.buildSolver(), self._solution_class) + def build_solver(self, solver_config_override: SolverConfigOverride = None) -> Solver[Solution_]: + if solver_config_override is None: + return Solver(self._delegate.buildSolver(), self._solution_class) + else: + return Solver(self._delegate.buildSolver(solver_config_override._to_java_solver_config_override()), + self._solution_class) __all__ = ['SolverFactory'] diff --git a/timefold-solver-python-core/src/main/python/api/_solver_manager.py b/timefold-solver-python-core/src/main/python/api/_solver_manager.py index 9ed9f2b..7ce6fe1 100644 --- a/timefold-solver-python-core/src/main/python/api/_solver_manager.py +++ b/timefold-solver-python-core/src/main/python/api/_solver_manager.py @@ -1,9 +1,10 @@ from ._problem_change import ProblemChange, ProblemChangeWrapper +from ..config import SolverConfigOverride from ._solver_factory import SolverFactory from ._future import wrap_completable_future from asyncio import Future -from typing import TypeVar, TYPE_CHECKING +from typing import TypeVar, Generic, Callable, TYPE_CHECKING from datetime import timedelta from enum import Enum, auto as auto_enum @@ -27,69 +28,68 @@ def _from_java_enum(enum_value): return getattr(SolverStatus, enum_value.name()) -class SolverJob: +class SolverJob(Generic[Solution_, ProblemId_]): _delegate: '_JavaSolverJob' def __init__(self, delegate: '_JavaSolverJob'): self._delegate = delegate - def get_problem_id(self): + def get_problem_id(self) -> ProblemId_: from jpyinterpreter import unwrap_python_like_object return unwrap_python_like_object(self._delegate.getProblemId()) - def get_solver_status(self): + def get_solver_status(self) -> SolverStatus: return SolverStatus._from_java_enum(self._delegate.getSolverStatus()) def get_solving_duration(self) -> timedelta: return timedelta(milliseconds=self._delegate.getSolvingDuration().toMillis()) - def get_final_best_solution(self): + def get_final_best_solution(self) -> Solution_: from jpyinterpreter import unwrap_python_like_object return unwrap_python_like_object(self._delegate.getFinalBestSolution()) - def terminate_early(self): + def terminate_early(self) -> None: self._delegate.terminateEarly() def is_terminated_early(self) -> bool: return self._delegate.isTerminatedEarly() - def add_problem_change(self, problem_change: ProblemChange) -> Future[None]: + def add_problem_change(self, problem_change: ProblemChange[Solution_]) -> Future[None]: return wrap_completable_future(self._delegate.addProblemChange(ProblemChangeWrapper(problem_change))) -class SolverJobBuilder: +class SolverJobBuilder(Generic[Solution_, ProblemId_]): _delegate: '_JavaSolverJobBuilder' def __init__(self, delegate: '_JavaSolverJobBuilder'): self._delegate = delegate - def with_problem_id(self, problem_id) -> 'SolverJobBuilder': + def with_problem_id(self, problem_id: ProblemId_) -> 'SolverJobBuilder': from jpyinterpreter import convert_to_java_python_like_object return SolverJobBuilder(self._delegate.withProblemId(convert_to_java_python_like_object(problem_id))) - def with_problem(self, problem) -> 'SolverJobBuilder': + def with_problem(self, problem: Solution_) -> 'SolverJobBuilder': from jpyinterpreter import convert_to_java_python_like_object return SolverJobBuilder(self._delegate.withProblem(convert_to_java_python_like_object(problem))) - def with_config_override(self, config_override) -> 'SolverJobBuilder': - # TODO: Create wrapper object for config override - raise NotImplementedError + def with_config_override(self, config_override: SolverConfigOverride) -> 'SolverJobBuilder': + return SolverJobBuilder(self._delegate.withConfigOverride(config_override._to_java_solver_config_override())) - def with_problem_finder(self, problem_finder) -> 'SolverJobBuilder': + def with_problem_finder(self, problem_finder: Callable[[ProblemId_], Solution_]) -> 'SolverJobBuilder': from java.util.function import Function from jpyinterpreter import convert_to_java_python_like_object, unwrap_python_like_object java_finder = Function @ (lambda problem_id: convert_to_java_python_like_object( problem_finder(unwrap_python_like_object(problem_id)))) return SolverJobBuilder(self._delegate.withProblemFinder(java_finder)) - def with_best_solution_consumer(self, best_solution_consumer) -> 'SolverJobBuilder': + def with_best_solution_consumer(self, best_solution_consumer: Callable[[Solution_], None]) -> 'SolverJobBuilder': from java.util.function import Consumer from jpyinterpreter import unwrap_python_like_object java_consumer = Consumer @ (lambda solution: best_solution_consumer(unwrap_python_like_object(solution))) return SolverJobBuilder(self._delegate.withBestSolutionConsumer(java_consumer)) - def with_final_best_solution_consumer(self, final_best_solution_consumer) -> 'SolverJobBuilder': + def with_final_best_solution_consumer(self, final_best_solution_consumer: Callable[[Solution_], None]) -> 'SolverJobBuilder': from java.util.function import Consumer from jpyinterpreter import unwrap_python_like_object @@ -97,7 +97,7 @@ def with_final_best_solution_consumer(self, final_best_solution_consumer) -> 'So return SolverJobBuilder( self._delegate.withFinalBestSolutionConsumer(java_consumer)) - def with_exception_handler(self, exception_handler) -> 'SolverJobBuilder': + def with_exception_handler(self, exception_handler: Callable[[ProblemId_, Exception], None]) -> 'SolverJobBuilder': from java.util.function import BiConsumer from jpyinterpreter import unwrap_python_like_object @@ -106,22 +106,23 @@ def with_exception_handler(self, exception_handler) -> 'SolverJobBuilder': return SolverJobBuilder( self._delegate.withExceptionHandler(java_consumer)) - def run(self) -> SolverJob: + def run(self) -> SolverJob[Solution_, ProblemId_]: return SolverJob(self._delegate.run()) -class SolverManager: +class SolverManager(Generic[Solution_, ProblemId_]): _delegate: '_JavaSolverManager' def __init__(self, delegate: '_JavaSolverManager'): self._delegate = delegate @staticmethod - def create(solver_factory: 'SolverFactory'): + def create(solver_factory: 'SolverFactory[Solution_]') -> 'SolverManager[Solution_, ProblemId_]': from ai.timefold.solver.core.api.solver import SolverManager as JavaSolverManager return SolverManager(JavaSolverManager.create(solver_factory._delegate)) # noqa - def solve(self, problem_id, problem, final_best_solution_listener=None): + def solve(self, problem_id: ProblemId_, problem: Solution_, + final_best_solution_listener: Callable[[Solution_], None] = None) -> SolverJob[Solution_, ProblemId_]: builder = (self.solve_builder() .with_problem_id(problem_id) .with_problem(problem)) @@ -131,37 +132,38 @@ def solve(self, problem_id, problem, final_best_solution_listener=None): return builder.run() - def solve_and_listen(self, problem_id, problem, listener): + def solve_and_listen(self, problem_id: ProblemId_, problem: Solution_, listener: Callable[[Solution_], None]) \ + -> SolverJob[Solution_, ProblemId_]: return (self.solve_builder() .with_problem_id(problem_id) .with_problem(problem) .with_best_solution_consumer(listener) .run()) - def solve_builder(self) -> SolverJobBuilder: + def solve_builder(self) -> SolverJobBuilder[Solution_, ProblemId_]: return SolverJobBuilder(self._delegate.solveBuilder()) - def get_solver_status(self, problem_id): + def get_solver_status(self, problem_id: ProblemId_) -> SolverStatus: from jpyinterpreter import convert_to_java_python_like_object return SolverStatus._from_java_enum(self._delegate.getSolverStatus( convert_to_java_python_like_object(problem_id))) - def terminate_early(self, problem_id): + def terminate_early(self, problem_id: ProblemId_) -> None: from jpyinterpreter import convert_to_java_python_like_object self._delegate.terminateEarly(convert_to_java_python_like_object(problem_id)) - def add_problem_change(self, problem_id, problem_change: ProblemChange) -> Future[None]: + def add_problem_change(self, problem_id: ProblemId_, problem_change: ProblemChange[Solution_]) -> Future[None]: from jpyinterpreter import convert_to_java_python_like_object return wrap_completable_future(self._delegate.addProblemChange(convert_to_java_python_like_object(problem_id), ProblemChangeWrapper(problem_change))) - def close(self): + def close(self) -> None: self._delegate.close() - def __enter__(self): + def __enter__(self) -> 'SolverManager[Solution_, ProblemId_]': return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, exc_type, exc_val, exc_tb) -> None: self._delegate.close() diff --git a/timefold-solver-python-core/src/main/python/config/_config.py b/timefold-solver-python-core/src/main/python/config/_config.py index e8699c1..43c46b1 100644 --- a/timefold-solver-python-core/src/main/python/config/_config.py +++ b/timefold-solver-python-core/src/main/python/config/_config.py @@ -1,7 +1,7 @@ from ..constraint import ConstraintFactory from .._timefold_java_interop import is_enterprise_installed -from typing import Any, Optional, List, Type, Callable, TYPE_CHECKING +from typing import Any, Optional, List, Type, Callable, TypeVar, Generic, TYPE_CHECKING from dataclasses import dataclass, field from enum import Enum, auto from pathlib import Path @@ -97,9 +97,12 @@ def __init__(self, feature): f'install timefold-solver-enterprise.') +Solution_ = TypeVar('Solution_') + + @dataclass(kw_only=True) -class SolverConfig: - solution_class: Optional[Type] = field(default=None) +class SolverConfig(Generic[Solution_]): + solution_class: Optional[Type[Solution_]] = field(default=None) entity_class_list: Optional[List[Type]] = field(default=None) environment_mode: Optional[EnvironmentMode] = field(default=EnvironmentMode.REPRODUCIBLE) random_seed: Optional[int] = field(default=None) @@ -111,11 +114,11 @@ class SolverConfig: xml_source_file: Optional[Path] = field(default=None) @staticmethod - def create_from_xml_resource(path: Path): + def create_from_xml_resource(path: Path) -> 'SolverConfig': return SolverConfig(xml_source_file=path) @staticmethod - def create_from_xml_text(xml_text: str): + def create_from_xml_text(xml_text: str) -> 'SolverConfig': return SolverConfig(xml_source_text=xml_text) def _to_java_solver_config(self) -> '_JavaSolverConfig': @@ -272,6 +275,18 @@ def _to_java_termination_config(self, inherited_config: '_JavaTerminationConfig' return out +@dataclass(kw_only=True) +class SolverConfigOverride: + termination_config: Optional[TerminationConfig] = field(default=None) + + def _to_java_solver_config_override(self): + from ai.timefold.solver.core.api.solver import SolverConfigOverride + out = SolverConfigOverride() + if self.termination_config is not None: + out = out.withTerminationConfig(self.termination_config._to_java_termination_config()) + return out + + __all__ = ['Duration', 'EnvironmentMode', 'TerminationCompositionStyle', 'RequiresEnterpriseError', 'MoveThreadCount', - 'SolverConfig', 'ScoreDirectorFactoryConfig', 'TerminationConfig'] + 'SolverConfig', 'SolverConfigOverride', 'ScoreDirectorFactoryConfig', 'TerminationConfig']