Skip to content
This repository has been archived by the owner on Jul 17, 2024. It is now read-only.

feat: Add support for problem changes #34

Merged
merged 3 commits into from
Apr 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion jpyinterpreter/src/main/python/conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,10 @@ def is_c_native(item):
or module == '': # if we cannot find module, assume it is not native
return False

return is_native_module(importlib.import_module(module))
try:
return is_native_module(importlib.import_module(module))
except:
return True


def init_type_to_compiled_java_class():
Expand Down
117 changes: 61 additions & 56 deletions tests/test_solver_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class Value:
@dataclass
class Entity:
code: Annotated[str, PlanningId]
value: Annotated[Value, PlanningVariable] = field(default=None)
value: Annotated[Value, PlanningVariable] = field(default=None, compare=False)
Christopher-Chianelli marked this conversation as resolved.
Show resolved Hide resolved

@constraint_provider
def my_constraints(constraint_factory: ConstraintFactory):
Expand Down Expand Up @@ -54,24 +54,22 @@ class Solution:
ValueRangeProvider]
score: Annotated[SimpleScore, PlanningScore] = field(default=None)

# TODO: Support problem changes
# @Problem_change
# class UseOnlyEntityAndValueProblemChange:
# def __init__(self, entity, value):
# self.entity = entity
# self.value = value
#
# def doChange(self, solution: Solution, problem_change_director: timefold.solver.types.ProblemChangeDirector):
# problem_facts_to_remove = solution.value_list.copy()
# entities_to_remove = solution.entity_list.copy()
# for problem_fact in problem_facts_to_remove:
# problem_change_director.removeProblemFact(problem_fact,
# lambda value: solution.value_list.remove(problem_fact))
# for removed_entity in entities_to_remove:
# problem_change_director.removeEntity(removed_entity,
# lambda entity: solution.entity_list.remove(removed_entity))
# problem_change_director.addEntity(self.entity, lambda entity: solution.entity_list.append(entity))
# problem_change_director.addProblemFact(self.value, lambda value: solution.value_list.append(value))
class UseOnlyEntityAndValueProblemChange(ProblemChange[Solution]):
def __init__(self, entity, value):
self.entity = entity
self.value = value

def do_change(self, solution: Solution, problem_change_director: ProblemChangeDirector):
problem_facts_to_remove = solution.value_list.copy()
entities_to_remove = solution.entity_list.copy()
for problem_fact in problem_facts_to_remove:
problem_change_director.remove_problem_fact(problem_fact,
lambda value: solution.value_list.remove(value))
for removed_entity in entities_to_remove:
problem_change_director.remove_entity(removed_entity,
lambda entity: solution.entity_list.remove(entity))
problem_change_director.add_entity(self.entity, lambda entity: solution.entity_list.append(entity))
problem_change_director.add_problem_fact(self.value, lambda value: solution.value_list.append(value))

solver_config = SolverConfig(
solution_class=Solution,
Expand All @@ -97,27 +95,27 @@ def assert_solver_run(solver_manager, solver_job):
assert 3 in value_list
assert solver_manager.get_solver_status(1) == SolverStatus.NOT_SOLVING

# def assert_problem_change_solver_run(solver_manager, solver_job):
# assert solver_manager.get_solver_status(1) != SolverStatus.NOT_SOLVING
# solver_manager.addProblemChange(1, UseOnlyEntityAndValueProblemChange(Entity('D'), Value(6)))
# lock.release()
# solution = solver_job.get_final_best_solution()
# assert solution.score.score() == 6
# assert len(solution.entity_list) == 1
# assert len(solution.value_range) == 1
# assert solution.entity_list[0].code == 'D'
# assert solution.entity_list[0].value.value == 6
# assert solution.value_range[0].value == 6
# assert solver_manager.get_solver_status(1) == SolverStatus.NOT_SOLVING
def assert_problem_change_solver_run(solver_manager, solver_job):
assert solver_manager.get_solver_status(1) != SolverStatus.NOT_SOLVING
solver_manager.add_problem_change(1, UseOnlyEntityAndValueProblemChange(Entity('D'), Value(6)))
lock.release()
solution = solver_job.get_final_best_solution()
assert solution.score.score() == 6
assert len(solution.entity_list) == 1
assert len(solution.value_list) == 1
assert solution.entity_list[0].code == 'D'
assert solution.entity_list[0].value.value == 6
assert solution.value_list[0].value == 6
assert solver_manager.get_solver_status(1) == SolverStatus.NOT_SOLVING

with SolverManager.create(SolverFactory.create(solver_config)) as solver_manager:
lock.acquire()
solver_job = solver_manager.solve(1, problem)
assert_solver_run(solver_manager, solver_job)

# lock.acquire()
# solver_job = solver_manager.solve(1, problem)
# assert_problem_change_solver_run(solver_manager, solver_job)
lock.acquire()
solver_job = solver_manager.solve(1, problem)
assert_problem_change_solver_run(solver_manager, solver_job)

def get_problem(problem_id):
assert problem_id == 1
Expand All @@ -129,9 +127,11 @@ def get_problem(problem_id):
.with_problem_finder(get_problem)).run()
assert_solver_run(solver_manager, solver_job)

# lock.acquire()
#solver_job = solver_manager.solve(1, get_problem)
#assert_problem_change_solver_run(solver_manager, solver_job)
lock.acquire()
solver_job = (solver_manager.solve_builder()
.with_problem_id(1)
.with_problem_finder(get_problem)).run()
assert_problem_change_solver_run(solver_manager, solver_job)

solution_list = []
semaphore = Semaphore(0)
Expand All @@ -150,15 +150,16 @@ def on_best_solution_changed(solution):
assert semaphore.acquire(timeout=1)
assert len(solution_list) == 1

# solution_list = []
# lock.acquire()
# solver_job = (solver_manager.solve_builder()
# .with_problem_id(1)
# .with_problem_finder(get_problem)
# .with_best_solution_consumer(on_best_solution_changed)
# ).run()
#assert_problem_change_solver_run(solver_manager, solver_job)
# assert len(solution_list) == 1
solution_list = []
lock.acquire()
solver_job = (solver_manager.solve_builder()
.with_problem_id(1)
.with_problem_finder(get_problem)
.with_best_solution_consumer(on_best_solution_changed)
).run()
assert_problem_change_solver_run(solver_manager, solver_job)
assert semaphore.acquire(timeout=1)
assert len(solution_list) == 1

solution_list = []
lock.acquire()
Expand All @@ -175,16 +176,20 @@ def on_best_solution_changed(solution):
assert semaphore.acquire(timeout=1)
assert len(solution_list) == 2

# solution_list = []
# lock.acquire()
# solver_job = (solver_manager.solve_builder()
# .with_problem_id(1)
# .with_problem_finder(get_problem)
# .with_best_solution_consumer(on_best_solution_changed)
# .with_final_best_solution_consumer(on_best_solution_changed)
# ).run()
# assert_problem_change_solver_run(solver_manager, solver_job)
# assert len(solution_list) == 2
solution_list = []
lock.acquire()
solver_job = (solver_manager.solve_builder()
.with_problem_id(1)
.with_problem_finder(get_problem)
.with_best_solution_consumer(on_best_solution_changed)
.with_final_best_solution_consumer(on_best_solution_changed)
).run()
assert_problem_change_solver_run(solver_manager, solver_job)
# Wait for 2 acquires, one for best solution consumer,
# another for final best solution consumer
assert semaphore.acquire(timeout=1)
assert semaphore.acquire(timeout=1)
assert len(solution_list) == 2


@pytest.mark.filterwarnings("ignore:.*Exception in thread.*:pytest.PytestUnhandledThreadExceptionWarning")
Expand Down
135 changes: 135 additions & 0 deletions tests/test_solver_problem_change.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
from timefold.solver.api import *
from timefold.solver.annotation import *
from timefold.solver.config import *
from timefold.solver.constraint import *
from timefold.solver.score import *

from dataclasses import dataclass, field
from typing import Annotated, List
from threading import Thread


@planning_entity
@dataclass
class Entity:
code: Annotated[str, PlanningId]
value: Annotated[int, PlanningVariable] = field(default=None, compare=False)


@constraint_provider
def maximize_constraints(constraint_factory: ConstraintFactory):
return [
constraint_factory.for_each(Entity)
.reward(SimpleScore.ONE, lambda entity: entity.value)
.as_constraint('Maximize value'),
]


@constraint_provider
def minimize_constraints(constraint_factory: ConstraintFactory):
return [
constraint_factory.for_each(Entity)
.penalize(SimpleScore.ONE, lambda entity: entity.value)
.as_constraint('Minimize value'),
]


@planning_solution
@dataclass
class Solution:
entities: Annotated[List[Entity], PlanningEntityCollectionProperty]
value_range: Annotated[List[int], ValueRangeProvider]
score: Annotated[SimpleScore, PlanningScore] = field(default=None)

def __str__(self) -> str:
return str(self.entities)


class AddEntity(ProblemChange[Solution]):
entity: Entity

def __init__(self, entity: Entity):
self.entity = entity

def do_change(self, working_solution: Solution, problem_change_director: ProblemChangeDirector):
problem_change_director.add_entity(self.entity,
lambda working_entity: working_solution.entities.append(working_entity))


class RemoveEntity(ProblemChange[Solution]):
entity: Entity

def __init__(self, entity: Entity):
self.entity = entity

def do_change(self, working_solution: Solution, problem_change_director: ProblemChangeDirector):
problem_change_director.remove_entity(self.entity,
lambda working_entity: working_solution.entities.remove(working_entity))


def test_add_entity():
solver_config = SolverConfig(
solution_class=Solution,
entity_class_list=[Entity],
score_director_factory_config=ScoreDirectorFactoryConfig(
constraint_provider_function=maximize_constraints,
),
termination_config=TerminationConfig(
best_score_limit='6'
)
)

problem: Solution = Solution([Entity('A')], [1, 2, 3])
solver = SolverFactory.create(solver_config).build_solver()
result: Solution | None = None

def do_solve(problem: Solution):
nonlocal solver, result
result = solver.solve(problem)

thread = Thread(target=do_solve, args=(problem,), daemon=True)

thread.start()
solver.add_problem_change(AddEntity(Entity('B')))
thread.join(timeout=1)

if thread.is_alive():
raise AssertionError(f'Thread {thread} did not finish after 5 seconds')

assert result is not None
assert len(result.entities) == 2
assert result.score.score() == 6


def test_remove_entity():
solver_config = SolverConfig(
solution_class=Solution,
entity_class_list=[Entity],
score_director_factory_config=ScoreDirectorFactoryConfig(
constraint_provider_function=minimize_constraints,
),
termination_config=TerminationConfig(
best_score_limit='-1'
)
)

problem: Solution = Solution([Entity('A'), Entity('B')], [1, 2, 3])
solver = SolverFactory.create(solver_config).build_solver()
result: Solution | None = None

def do_solve(problem: Solution):
nonlocal solver, result
result = solver.solve(problem)

thread = Thread(target=do_solve, args=(problem,), daemon=True)

thread.start()
solver.add_problem_change(RemoveEntity(Entity('B')))
thread.join(timeout=1)

if thread.is_alive():
raise AssertionError(f'Thread {thread} did not finish after 5 seconds')

assert result is not None
assert len(result.entities) == 1
assert result.score.score() == -1
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .problem_change import *
from .solver import *
from .solver_factory import *
from .solver_manager import *
Expand Down
42 changes: 42 additions & 0 deletions timefold-solver-python-core/src/main/python/api/future.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from ..jpype_type_conversions import PythonBiFunction
from typing import Awaitable, TypeVar, TYPE_CHECKING
from asyncio import Future, get_event_loop, CancelledError

if TYPE_CHECKING:
from java.util.concurrent import (Future as JavaFuture,
CompletableFuture as JavaCompletableFuture)


Result = TypeVar('Result')


def wrap_future(future: 'JavaFuture[Result]') -> Awaitable[Result]:
async def get_result() -> Result:
nonlocal future
return future.get()

return get_result()


def wrap_completable_future(future: 'JavaCompletableFuture[Result]') -> Future[Result]:
loop = get_event_loop()
out = loop.create_future()

def result_handler(result, error):
nonlocal out
if error is not None:
out.set_exception(error)
else:
out.set_result(result)

def cancel_handler(python_future: Future):
nonlocal future
if isinstance(python_future.exception(), CancelledError):
future.cancel(True)

future.handle(PythonBiFunction(result_handler))
out.add_done_callback(cancel_handler)
return out


__all__ = ['wrap_future', 'wrap_completable_future']
Loading
Loading