Skip to content
This repository has been archived by the owner on Jul 17, 2024. It is now read-only.

Commit

Permalink
feat: Add support for problem changes (#34)
Browse files Browse the repository at this point in the history
* feat: Add support for problem changes

- Since the methods in ProblemChangeDirector take
  interfaces, ProblemChange as a whole cannot be
  translated to pure Java (since that requires
  supporting casting an arbitary callable to any
  Java interface, which is not suported yet (nor
  planned to be supported)).

- Thus we goes for a more ugly approach: the
  ProblemChange runs in Python, and when a
  problem change director method is called,
  we compile/translate the supplied function
  to Java.

- We do a trick where we replace the Python
  working solution clone with the actual Java
  working solution in the closure before compiling
  the function so changes are applied to the right
  object. After the method is called, we then update
  the Python working solution from the java working
  solution so changes are reflected in it too.

- Users implement a ProblemChange by extending an
  abstract base class (which, among other things,
  raises an error if not all of its methods are
  implemented).
  • Loading branch information
Christopher-Chianelli authored Apr 16, 2024
1 parent 30a13b9 commit fb72d9c
Show file tree
Hide file tree
Showing 8 changed files with 379 additions and 67 deletions.
5 changes: 4 additions & 1 deletion jpyinterpreter/src/main/python/conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,10 @@ def is_c_native(item):
or module == '': # if we cannot find module, assume it is not native
return False

return is_native_module(importlib.import_module(module))
try:
return is_native_module(importlib.import_module(module))
except:
return True


def init_type_to_compiled_java_class():
Expand Down
117 changes: 61 additions & 56 deletions tests/test_solver_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class Value:
@dataclass
class Entity:
code: Annotated[str, PlanningId]
value: Annotated[Value, PlanningVariable] = field(default=None)
value: Annotated[Value, PlanningVariable] = field(default=None, compare=False)

@constraint_provider
def my_constraints(constraint_factory: ConstraintFactory):
Expand Down Expand Up @@ -54,24 +54,22 @@ class Solution:
ValueRangeProvider]
score: Annotated[SimpleScore, PlanningScore] = field(default=None)

# TODO: Support problem changes
# @Problem_change
# class UseOnlyEntityAndValueProblemChange:
# def __init__(self, entity, value):
# self.entity = entity
# self.value = value
#
# def doChange(self, solution: Solution, problem_change_director: timefold.solver.types.ProblemChangeDirector):
# problem_facts_to_remove = solution.value_list.copy()
# entities_to_remove = solution.entity_list.copy()
# for problem_fact in problem_facts_to_remove:
# problem_change_director.removeProblemFact(problem_fact,
# lambda value: solution.value_list.remove(problem_fact))
# for removed_entity in entities_to_remove:
# problem_change_director.removeEntity(removed_entity,
# lambda entity: solution.entity_list.remove(removed_entity))
# problem_change_director.addEntity(self.entity, lambda entity: solution.entity_list.append(entity))
# problem_change_director.addProblemFact(self.value, lambda value: solution.value_list.append(value))
class UseOnlyEntityAndValueProblemChange(ProblemChange[Solution]):
def __init__(self, entity, value):
self.entity = entity
self.value = value

def do_change(self, solution: Solution, problem_change_director: ProblemChangeDirector):
problem_facts_to_remove = solution.value_list.copy()
entities_to_remove = solution.entity_list.copy()
for problem_fact in problem_facts_to_remove:
problem_change_director.remove_problem_fact(problem_fact,
lambda value: solution.value_list.remove(value))
for removed_entity in entities_to_remove:
problem_change_director.remove_entity(removed_entity,
lambda entity: solution.entity_list.remove(entity))
problem_change_director.add_entity(self.entity, lambda entity: solution.entity_list.append(entity))
problem_change_director.add_problem_fact(self.value, lambda value: solution.value_list.append(value))

solver_config = SolverConfig(
solution_class=Solution,
Expand All @@ -97,27 +95,27 @@ def assert_solver_run(solver_manager, solver_job):
assert 3 in value_list
assert solver_manager.get_solver_status(1) == SolverStatus.NOT_SOLVING

# def assert_problem_change_solver_run(solver_manager, solver_job):
# assert solver_manager.get_solver_status(1) != SolverStatus.NOT_SOLVING
# solver_manager.addProblemChange(1, UseOnlyEntityAndValueProblemChange(Entity('D'), Value(6)))
# lock.release()
# solution = solver_job.get_final_best_solution()
# assert solution.score.score() == 6
# assert len(solution.entity_list) == 1
# assert len(solution.value_range) == 1
# assert solution.entity_list[0].code == 'D'
# assert solution.entity_list[0].value.value == 6
# assert solution.value_range[0].value == 6
# assert solver_manager.get_solver_status(1) == SolverStatus.NOT_SOLVING
def assert_problem_change_solver_run(solver_manager, solver_job):
assert solver_manager.get_solver_status(1) != SolverStatus.NOT_SOLVING
solver_manager.add_problem_change(1, UseOnlyEntityAndValueProblemChange(Entity('D'), Value(6)))
lock.release()
solution = solver_job.get_final_best_solution()
assert solution.score.score() == 6
assert len(solution.entity_list) == 1
assert len(solution.value_list) == 1
assert solution.entity_list[0].code == 'D'
assert solution.entity_list[0].value.value == 6
assert solution.value_list[0].value == 6
assert solver_manager.get_solver_status(1) == SolverStatus.NOT_SOLVING

with SolverManager.create(SolverFactory.create(solver_config)) as solver_manager:
lock.acquire()
solver_job = solver_manager.solve(1, problem)
assert_solver_run(solver_manager, solver_job)

# lock.acquire()
# solver_job = solver_manager.solve(1, problem)
# assert_problem_change_solver_run(solver_manager, solver_job)
lock.acquire()
solver_job = solver_manager.solve(1, problem)
assert_problem_change_solver_run(solver_manager, solver_job)

def get_problem(problem_id):
assert problem_id == 1
Expand All @@ -129,9 +127,11 @@ def get_problem(problem_id):
.with_problem_finder(get_problem)).run()
assert_solver_run(solver_manager, solver_job)

# lock.acquire()
#solver_job = solver_manager.solve(1, get_problem)
#assert_problem_change_solver_run(solver_manager, solver_job)
lock.acquire()
solver_job = (solver_manager.solve_builder()
.with_problem_id(1)
.with_problem_finder(get_problem)).run()
assert_problem_change_solver_run(solver_manager, solver_job)

solution_list = []
semaphore = Semaphore(0)
Expand All @@ -150,15 +150,16 @@ def on_best_solution_changed(solution):
assert semaphore.acquire(timeout=1)
assert len(solution_list) == 1

# solution_list = []
# lock.acquire()
# solver_job = (solver_manager.solve_builder()
# .with_problem_id(1)
# .with_problem_finder(get_problem)
# .with_best_solution_consumer(on_best_solution_changed)
# ).run()
#assert_problem_change_solver_run(solver_manager, solver_job)
# assert len(solution_list) == 1
solution_list = []
lock.acquire()
solver_job = (solver_manager.solve_builder()
.with_problem_id(1)
.with_problem_finder(get_problem)
.with_best_solution_consumer(on_best_solution_changed)
).run()
assert_problem_change_solver_run(solver_manager, solver_job)
assert semaphore.acquire(timeout=1)
assert len(solution_list) == 1

solution_list = []
lock.acquire()
Expand All @@ -175,16 +176,20 @@ def on_best_solution_changed(solution):
assert semaphore.acquire(timeout=1)
assert len(solution_list) == 2

# solution_list = []
# lock.acquire()
# solver_job = (solver_manager.solve_builder()
# .with_problem_id(1)
# .with_problem_finder(get_problem)
# .with_best_solution_consumer(on_best_solution_changed)
# .with_final_best_solution_consumer(on_best_solution_changed)
# ).run()
# assert_problem_change_solver_run(solver_manager, solver_job)
# assert len(solution_list) == 2
solution_list = []
lock.acquire()
solver_job = (solver_manager.solve_builder()
.with_problem_id(1)
.with_problem_finder(get_problem)
.with_best_solution_consumer(on_best_solution_changed)
.with_final_best_solution_consumer(on_best_solution_changed)
).run()
assert_problem_change_solver_run(solver_manager, solver_job)
# Wait for 2 acquires, one for best solution consumer,
# another for final best solution consumer
assert semaphore.acquire(timeout=1)
assert semaphore.acquire(timeout=1)
assert len(solution_list) == 2


@pytest.mark.filterwarnings("ignore:.*Exception in thread.*:pytest.PytestUnhandledThreadExceptionWarning")
Expand Down
135 changes: 135 additions & 0 deletions tests/test_solver_problem_change.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
from timefold.solver.api import *
from timefold.solver.annotation import *
from timefold.solver.config import *
from timefold.solver.constraint import *
from timefold.solver.score import *

from dataclasses import dataclass, field
from typing import Annotated, List
from threading import Thread


@planning_entity
@dataclass
class Entity:
code: Annotated[str, PlanningId]
value: Annotated[int, PlanningVariable] = field(default=None, compare=False)


@constraint_provider
def maximize_constraints(constraint_factory: ConstraintFactory):
return [
constraint_factory.for_each(Entity)
.reward(SimpleScore.ONE, lambda entity: entity.value)
.as_constraint('Maximize value'),
]


@constraint_provider
def minimize_constraints(constraint_factory: ConstraintFactory):
return [
constraint_factory.for_each(Entity)
.penalize(SimpleScore.ONE, lambda entity: entity.value)
.as_constraint('Minimize value'),
]


@planning_solution
@dataclass
class Solution:
entities: Annotated[List[Entity], PlanningEntityCollectionProperty]
value_range: Annotated[List[int], ValueRangeProvider]
score: Annotated[SimpleScore, PlanningScore] = field(default=None)

def __str__(self) -> str:
return str(self.entities)


class AddEntity(ProblemChange[Solution]):
entity: Entity

def __init__(self, entity: Entity):
self.entity = entity

def do_change(self, working_solution: Solution, problem_change_director: ProblemChangeDirector):
problem_change_director.add_entity(self.entity,
lambda working_entity: working_solution.entities.append(working_entity))


class RemoveEntity(ProblemChange[Solution]):
entity: Entity

def __init__(self, entity: Entity):
self.entity = entity

def do_change(self, working_solution: Solution, problem_change_director: ProblemChangeDirector):
problem_change_director.remove_entity(self.entity,
lambda working_entity: working_solution.entities.remove(working_entity))


def test_add_entity():
solver_config = SolverConfig(
solution_class=Solution,
entity_class_list=[Entity],
score_director_factory_config=ScoreDirectorFactoryConfig(
constraint_provider_function=maximize_constraints,
),
termination_config=TerminationConfig(
best_score_limit='6'
)
)

problem: Solution = Solution([Entity('A')], [1, 2, 3])
solver = SolverFactory.create(solver_config).build_solver()
result: Solution | None = None

def do_solve(problem: Solution):
nonlocal solver, result
result = solver.solve(problem)

thread = Thread(target=do_solve, args=(problem,), daemon=True)

thread.start()
solver.add_problem_change(AddEntity(Entity('B')))
thread.join(timeout=1)

if thread.is_alive():
raise AssertionError(f'Thread {thread} did not finish after 5 seconds')

assert result is not None
assert len(result.entities) == 2
assert result.score.score() == 6


def test_remove_entity():
solver_config = SolverConfig(
solution_class=Solution,
entity_class_list=[Entity],
score_director_factory_config=ScoreDirectorFactoryConfig(
constraint_provider_function=minimize_constraints,
),
termination_config=TerminationConfig(
best_score_limit='-1'
)
)

problem: Solution = Solution([Entity('A'), Entity('B')], [1, 2, 3])
solver = SolverFactory.create(solver_config).build_solver()
result: Solution | None = None

def do_solve(problem: Solution):
nonlocal solver, result
result = solver.solve(problem)

thread = Thread(target=do_solve, args=(problem,), daemon=True)

thread.start()
solver.add_problem_change(RemoveEntity(Entity('B')))
thread.join(timeout=1)

if thread.is_alive():
raise AssertionError(f'Thread {thread} did not finish after 5 seconds')

assert result is not None
assert len(result.entities) == 1
assert result.score.score() == -1
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .problem_change import *
from .solver import *
from .solver_factory import *
from .solver_manager import *
Expand Down
42 changes: 42 additions & 0 deletions timefold-solver-python-core/src/main/python/api/future.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from ..jpype_type_conversions import PythonBiFunction
from typing import Awaitable, TypeVar, TYPE_CHECKING
from asyncio import Future, get_event_loop, CancelledError

if TYPE_CHECKING:
from java.util.concurrent import (Future as JavaFuture,
CompletableFuture as JavaCompletableFuture)


Result = TypeVar('Result')


def wrap_future(future: 'JavaFuture[Result]') -> Awaitable[Result]:
async def get_result() -> Result:
nonlocal future
return future.get()

return get_result()


def wrap_completable_future(future: 'JavaCompletableFuture[Result]') -> Future[Result]:
loop = get_event_loop()
out = loop.create_future()

def result_handler(result, error):
nonlocal out
if error is not None:
out.set_exception(error)
else:
out.set_result(result)

def cancel_handler(python_future: Future):
nonlocal future
if isinstance(python_future.exception(), CancelledError):
future.cancel(True)

future.handle(PythonBiFunction(result_handler))
out.add_done_callback(cancel_handler)
return out


__all__ = ['wrap_future', 'wrap_completable_future']
Loading

0 comments on commit fb72d9c

Please sign in to comment.