This repository has been archived by the owner on Jul 17, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Add support for problem changes (#34)
* feat: Add support for problem changes - Since the methods in ProblemChangeDirector take interfaces, ProblemChange as a whole cannot be translated to pure Java (since that requires supporting casting an arbitary callable to any Java interface, which is not suported yet (nor planned to be supported)). - Thus we goes for a more ugly approach: the ProblemChange runs in Python, and when a problem change director method is called, we compile/translate the supplied function to Java. - We do a trick where we replace the Python working solution clone with the actual Java working solution in the closure before compiling the function so changes are applied to the right object. After the method is called, we then update the Python working solution from the java working solution so changes are reflected in it too. - Users implement a ProblemChange by extending an abstract base class (which, among other things, raises an error if not all of its methods are implemented).
- Loading branch information
1 parent
30a13b9
commit fb72d9c
Showing
8 changed files
with
379 additions
and
67 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
from timefold.solver.api import * | ||
from timefold.solver.annotation import * | ||
from timefold.solver.config import * | ||
from timefold.solver.constraint import * | ||
from timefold.solver.score import * | ||
|
||
from dataclasses import dataclass, field | ||
from typing import Annotated, List | ||
from threading import Thread | ||
|
||
|
||
@planning_entity | ||
@dataclass | ||
class Entity: | ||
code: Annotated[str, PlanningId] | ||
value: Annotated[int, PlanningVariable] = field(default=None, compare=False) | ||
|
||
|
||
@constraint_provider | ||
def maximize_constraints(constraint_factory: ConstraintFactory): | ||
return [ | ||
constraint_factory.for_each(Entity) | ||
.reward(SimpleScore.ONE, lambda entity: entity.value) | ||
.as_constraint('Maximize value'), | ||
] | ||
|
||
|
||
@constraint_provider | ||
def minimize_constraints(constraint_factory: ConstraintFactory): | ||
return [ | ||
constraint_factory.for_each(Entity) | ||
.penalize(SimpleScore.ONE, lambda entity: entity.value) | ||
.as_constraint('Minimize value'), | ||
] | ||
|
||
|
||
@planning_solution | ||
@dataclass | ||
class Solution: | ||
entities: Annotated[List[Entity], PlanningEntityCollectionProperty] | ||
value_range: Annotated[List[int], ValueRangeProvider] | ||
score: Annotated[SimpleScore, PlanningScore] = field(default=None) | ||
|
||
def __str__(self) -> str: | ||
return str(self.entities) | ||
|
||
|
||
class AddEntity(ProblemChange[Solution]): | ||
entity: Entity | ||
|
||
def __init__(self, entity: Entity): | ||
self.entity = entity | ||
|
||
def do_change(self, working_solution: Solution, problem_change_director: ProblemChangeDirector): | ||
problem_change_director.add_entity(self.entity, | ||
lambda working_entity: working_solution.entities.append(working_entity)) | ||
|
||
|
||
class RemoveEntity(ProblemChange[Solution]): | ||
entity: Entity | ||
|
||
def __init__(self, entity: Entity): | ||
self.entity = entity | ||
|
||
def do_change(self, working_solution: Solution, problem_change_director: ProblemChangeDirector): | ||
problem_change_director.remove_entity(self.entity, | ||
lambda working_entity: working_solution.entities.remove(working_entity)) | ||
|
||
|
||
def test_add_entity(): | ||
solver_config = SolverConfig( | ||
solution_class=Solution, | ||
entity_class_list=[Entity], | ||
score_director_factory_config=ScoreDirectorFactoryConfig( | ||
constraint_provider_function=maximize_constraints, | ||
), | ||
termination_config=TerminationConfig( | ||
best_score_limit='6' | ||
) | ||
) | ||
|
||
problem: Solution = Solution([Entity('A')], [1, 2, 3]) | ||
solver = SolverFactory.create(solver_config).build_solver() | ||
result: Solution | None = None | ||
|
||
def do_solve(problem: Solution): | ||
nonlocal solver, result | ||
result = solver.solve(problem) | ||
|
||
thread = Thread(target=do_solve, args=(problem,), daemon=True) | ||
|
||
thread.start() | ||
solver.add_problem_change(AddEntity(Entity('B'))) | ||
thread.join(timeout=1) | ||
|
||
if thread.is_alive(): | ||
raise AssertionError(f'Thread {thread} did not finish after 5 seconds') | ||
|
||
assert result is not None | ||
assert len(result.entities) == 2 | ||
assert result.score.score() == 6 | ||
|
||
|
||
def test_remove_entity(): | ||
solver_config = SolverConfig( | ||
solution_class=Solution, | ||
entity_class_list=[Entity], | ||
score_director_factory_config=ScoreDirectorFactoryConfig( | ||
constraint_provider_function=minimize_constraints, | ||
), | ||
termination_config=TerminationConfig( | ||
best_score_limit='-1' | ||
) | ||
) | ||
|
||
problem: Solution = Solution([Entity('A'), Entity('B')], [1, 2, 3]) | ||
solver = SolverFactory.create(solver_config).build_solver() | ||
result: Solution | None = None | ||
|
||
def do_solve(problem: Solution): | ||
nonlocal solver, result | ||
result = solver.solve(problem) | ||
|
||
thread = Thread(target=do_solve, args=(problem,), daemon=True) | ||
|
||
thread.start() | ||
solver.add_problem_change(RemoveEntity(Entity('B'))) | ||
thread.join(timeout=1) | ||
|
||
if thread.is_alive(): | ||
raise AssertionError(f'Thread {thread} did not finish after 5 seconds') | ||
|
||
assert result is not None | ||
assert len(result.entities) == 1 | ||
assert result.score.score() == -1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
from .problem_change import * | ||
from .solver import * | ||
from .solver_factory import * | ||
from .solver_manager import * | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
from ..jpype_type_conversions import PythonBiFunction | ||
from typing import Awaitable, TypeVar, TYPE_CHECKING | ||
from asyncio import Future, get_event_loop, CancelledError | ||
|
||
if TYPE_CHECKING: | ||
from java.util.concurrent import (Future as JavaFuture, | ||
CompletableFuture as JavaCompletableFuture) | ||
|
||
|
||
Result = TypeVar('Result') | ||
|
||
|
||
def wrap_future(future: 'JavaFuture[Result]') -> Awaitable[Result]: | ||
async def get_result() -> Result: | ||
nonlocal future | ||
return future.get() | ||
|
||
return get_result() | ||
|
||
|
||
def wrap_completable_future(future: 'JavaCompletableFuture[Result]') -> Future[Result]: | ||
loop = get_event_loop() | ||
out = loop.create_future() | ||
|
||
def result_handler(result, error): | ||
nonlocal out | ||
if error is not None: | ||
out.set_exception(error) | ||
else: | ||
out.set_result(result) | ||
|
||
def cancel_handler(python_future: Future): | ||
nonlocal future | ||
if isinstance(python_future.exception(), CancelledError): | ||
future.cancel(True) | ||
|
||
future.handle(PythonBiFunction(result_handler)) | ||
out.add_done_callback(cancel_handler) | ||
return out | ||
|
||
|
||
__all__ = ['wrap_future', 'wrap_completable_future'] |
Oops, something went wrong.