Skip to content
This repository has been archived by the owner on Jul 17, 2024. It is now read-only.

Commit

Permalink
chore: Implement IncrementalScoreCalculator using classes instead of …
Browse files Browse the repository at this point in the history
…decorators

- Since there can only be one function signature in Python, and Java
  allows many, it might be the case that the top function signature
  in Python does not match its parent's function signature. Since
  the interface calls the parent's function signature, the wrong
  method would be called. To prevent this, we need to look up
  the 'canonical' method of the type, which is conveniently stored
  as an attribute on the type.

- Fix a bug in function __get__ descriptor; in particular, when called
  on a type, it should return the unbounded function instead of binding
  the function to the type.

- Make the ABC check less strict. In particular, only collections.abc
  and Protocol are banned, since collections.abc contain classes that
  should be Protocols but are instead ABC, and Protocols only define
  the structure and do not play a part in type hierarchy.
  • Loading branch information
Christopher-Chianelli committed Apr 25, 2024
1 parent aa61c30 commit 119c144
Show file tree
Hide file tree
Showing 6 changed files with 239 additions and 65 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,14 @@
import ai.timefold.jpyinterpreter.types.BoundPythonLikeFunction;
import ai.timefold.jpyinterpreter.types.PythonLikeFunction;
import ai.timefold.jpyinterpreter.types.PythonLikeType;
import ai.timefold.jpyinterpreter.types.PythonNone;

public class FunctionBuiltinOperations {
public static PythonLikeObject bindFunctionToInstance(final PythonLikeFunction function, final PythonLikeObject instance,
final PythonLikeType type) {
if (instance == PythonNone.INSTANCE) {
return function;
}
return new BoundPythonLikeFunction(instance, function);
}

Expand Down
20 changes: 14 additions & 6 deletions jpyinterpreter/src/main/python/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import inspect
import sys
import abc
from typing import Protocol

from jpype import JInt, JBoolean, JProxy, JClass, JArray


Expand Down Expand Up @@ -505,6 +507,7 @@ def force_update_type(python_type, java_type):


def translate_python_class_to_java_class(python_class):
import collections.abc as collections_abc
from .annotations import erase_generic_args, convert_java_annotation, copy_type_annotations
from .conversions import (
init_type_to_compiled_java_class, is_banned_module, is_c_native, convert_to_java_python_like_object
Expand All @@ -523,16 +526,21 @@ def translate_python_class_to_java_class(python_class):
if raw_type in type_to_compiled_java_class:
return type_to_compiled_java_class[raw_type]

if python_class == abc.ABC or inspect.isabstract(python_class): # TODO: Implement a class for interfaces?
if Protocol in python_class.__bases__:
python_class_java_type = BuiltinTypes.BASE_TYPE
type_to_compiled_java_class[python_class] = python_class_java_type
return python_class_java_type

if hasattr(python_class, '__module__') and python_class.__module__ is not None and \
is_banned_module(python_class.__module__):
python_class_java_type = CPythonType.getType(JProxy(OpaquePythonReference, inst=python_class, convert=True))
type_to_compiled_java_class[python_class] = python_class_java_type
return python_class_java_type
if hasattr(python_class, '__module__') and python_class.__module__ is not None:
if python_class.__module__ == collections_abc.Collection.__module__:
python_class_java_type = BuiltinTypes.BASE_TYPE
type_to_compiled_java_class[python_class] = python_class_java_type
return python_class_java_type

if is_banned_module(python_class.__module__):
python_class_java_type = CPythonType.getType(JProxy(OpaquePythonReference, inst=python_class, convert=True))
type_to_compiled_java_class[python_class] = python_class_java_type
return python_class_java_type

if isinstance(python_class, JArray):
python_class_java_type = CPythonType.getType(JProxy(OpaquePythonReference, inst=python_class, convert=True))
Expand Down
99 changes: 48 additions & 51 deletions tests/test_incremental_score_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,19 @@ class Queen:
column: int
row: Annotated[Optional[int], PlanningVariable] = field(default=None)

def getColumnIndex(self):
def get_column_index(self):
return self.column

def getRowIndex(self):
def get_row_index(self):
if self.row is None:
return -1
return self.row

def getAscendingDiagonalIndex(self):
return self.getColumnIndex() + self.getRowIndex()
def get_ascending_diagonal_index(self):
return self.get_column_index() + self.get_row_index()

def getDescendingDiagonalIndex(self):
return self.getColumnIndex() - self.getRowIndex()
def get_descending_diagonal_index(self):
return self.get_column_index() - self.get_row_index()

def __eq__(self, other):
return self.code == other.code
Expand All @@ -48,14 +48,13 @@ class Solution:


def test_constraint_match_disabled_incremental_score_calculator():
@incremental_score_calculator
class IncrementalScoreCalculator:
class NQueensIncrementalScoreCalculator(IncrementalScoreCalculator):
score: int
row_index_map: dict
ascending_diagonal_index_map: dict
descending_diagonal_index_map: dict

def resetWorkingSolution(self, working_solution: Solution):
def reset_working_solution(self, working_solution: Solution):
n = working_solution.n
self.row_index_map = dict()
self.ascending_diagonal_index_map = dict()
Expand All @@ -71,22 +70,22 @@ def resetWorkingSolution(self, working_solution: Solution):
for queen in working_solution.queen_list:
self.insert(queen)

def beforeEntityAdded(self, entity: any):
def before_entity_added(self, entity: any):
pass

def afterEntityAdded(self, entity: any):
def after_entity_added(self, entity: any):
self.insert(entity)

def beforeVariableChanged(self, entity: any, variableName: str):
def before_variable_changed(self, entity: any, variableName: str):
self.retract(entity)

def afterVariableChanged(self, entity: any, variableName: str):
def after_variable_changed(self, entity: any, variableName: str):
self.insert(entity)

def beforeEntityRemoved(self, entity: any):
def before_entity_removed(self, entity: any):
self.retract(entity)

def afterEntityRemoved(self, entity: any):
def after_entity_removed(self, entity: any):
pass

def insert(self, queen: Queen):
Expand All @@ -95,10 +94,10 @@ def insert(self, queen: Queen):
row_index_list = self.row_index_map[row_index]
self.score -= len(row_index_list)
row_index_list.append(queen)
ascending_diagonal_index_list = self.ascending_diagonal_index_map[queen.getAscendingDiagonalIndex()]
ascending_diagonal_index_list = self.ascending_diagonal_index_map[queen.get_ascending_diagonal_index()]
self.score -= len(ascending_diagonal_index_list)
ascending_diagonal_index_list.append(queen)
descending_diagonal_index_list = self.descending_diagonal_index_map[queen.getDescendingDiagonalIndex()]
descending_diagonal_index_list = self.descending_diagonal_index_map[queen.get_descending_diagonal_index()]
self.score -= len(descending_diagonal_index_list)
descending_diagonal_index_list.append(queen)

Expand All @@ -108,21 +107,21 @@ def retract(self, queen: Queen):
row_index_list = self.row_index_map[row_index]
row_index_list.remove(queen)
self.score += len(row_index_list)
ascending_diagonal_index_list = self.ascending_diagonal_index_map[queen.getAscendingDiagonalIndex()]
ascending_diagonal_index_list = self.ascending_diagonal_index_map[queen.get_ascending_diagonal_index()]
ascending_diagonal_index_list.remove(queen)
self.score += len(ascending_diagonal_index_list)
descending_diagonal_index_list = self.descending_diagonal_index_map[queen.getDescendingDiagonalIndex()]
descending_diagonal_index_list = self.descending_diagonal_index_map[queen.get_descending_diagonal_index()]
descending_diagonal_index_list.remove(queen)
self.score += len(descending_diagonal_index_list)

def calculateScore(self) -> HardSoftScore:
def calculate_score(self) -> HardSoftScore:
return SimpleScore.of(self.score)

solver_config = SolverConfig(
solution_class=Solution,
entity_class_list=[Queen],
score_director_factory_config=ScoreDirectorFactoryConfig(
incremental_score_calculator_class=IncrementalScoreCalculator
incremental_score_calculator_class=NQueensIncrementalScoreCalculator
),
termination_config=TerminationConfig(
best_score_limit='0'
Expand All @@ -141,22 +140,22 @@ def calculateScore(self) -> HardSoftScore:
right_queen = solution.queen_list[j]
assert left_queen.row is not None and right_queen.row is not None
assert left_queen.row != right_queen.row
assert left_queen.getAscendingDiagonalIndex() != right_queen.getAscendingDiagonalIndex()
assert left_queen.getDescendingDiagonalIndex() != right_queen.getDescendingDiagonalIndex()
assert left_queen.get_ascending_diagonal_index() != right_queen.get_ascending_diagonal_index()
assert left_queen.get_descending_diagonal_index() != right_queen.get_descending_diagonal_index()


@pytest.mark.skip(reason="Special case where you want to convert all items of the list before returning."
"Doing this for all conversions would be expensive."
"This feature is not that important, so skipping for now.")
def test_constraint_match_enabled_incremental_score_calculator():
@incremental_score_calculator
class IncrementalScoreCalculator:
class NQueensIncrementalScoreCalculator(ConstraintMatchAwareIncrementalScoreCalculator):
score: int
row_index_map: dict
ascending_diagonal_index_map: dict
descending_diagonal_index_map: dict

def resetWorkingSolution(self, working_solution: Solution, constraint_match_enabled=False):
def reset_working_solution(self, working_solution: Solution, constraint_match_enabled=False):
n = working_solution.n
self.row_index_map = dict()
self.ascending_diagonal_index_map = dict()
Expand All @@ -172,22 +171,22 @@ def resetWorkingSolution(self, working_solution: Solution, constraint_match_enab
for queen in working_solution.queen_list:
self.insert(queen)

def beforeEntityAdded(self, entity: any):
def before_entity_added(self, entity: any):
pass

def afterEntityAdded(self, entity: any):
def after_entity_added(self, entity: any):
self.insert(entity)

def beforeVariableChanged(self, entity: any, variableName: str):
def before_variable_changed(self, entity: any, variableName: str):
self.retract(entity)

def afterVariableChanged(self, entity: any, variableName: str):
def after_variable_changed(self, entity: any, variableName: str):
self.insert(entity)

def beforeEntityRemoved(self, entity: any):
def before_entity_removed(self, entity: any):
self.retract(entity)

def afterEntityRemoved(self, entity: any):
def after_entity_removed(self, entity: any):
pass

def insert(self, queen: Queen):
Expand All @@ -197,10 +196,10 @@ def insert(self, queen: Queen):
row_index_list = self.row_index_map[row_index]
self.score -= len(row_index_list)
row_index_list.append(queen)
ascending_diagonal_index_list = self.ascending_diagonal_index_map[queen.getAscendingDiagonalIndex()]
ascending_diagonal_index_list = self.ascending_diagonal_index_map[queen.get_ascending_diagonal_index()]
self.score -= len(ascending_diagonal_index_list)
ascending_diagonal_index_list.append(queen)
descending_diagonal_index_list = self.descending_diagonal_index_map[queen.getDescendingDiagonalIndex()]
descending_diagonal_index_list = self.descending_diagonal_index_map[queen.get_descending_diagonal_index()]
self.score -= len(descending_diagonal_index_list)
descending_diagonal_index_list.append(queen)

Expand All @@ -211,17 +210,17 @@ def retract(self, queen: Queen):
row_index_list = self.row_index_map[row_index]
row_index_list.remove(queen)
self.score += len(row_index_list)
ascending_diagonal_index_list = self.ascending_diagonal_index_map[queen.getAscendingDiagonalIndex()]
ascending_diagonal_index_list = self.ascending_diagonal_index_map[queen.get_ascending_diagonal_index()]
ascending_diagonal_index_list.remove(queen)
self.score += len(ascending_diagonal_index_list)
descending_diagonal_index_list = self.descending_diagonal_index_map[queen.getDescendingDiagonalIndex()]
descending_diagonal_index_list = self.descending_diagonal_index_map[queen.get_descending_diagonal_index()]
descending_diagonal_index_list.remove(queen)
self.score += len(descending_diagonal_index_list)

def calculateScore(self) -> HardSoftScore:
def calculate_score(self) -> HardSoftScore:
return SimpleScore.of(self.score)

def getConstraintMatchTotals(self):
def get_constraint_match_totals(self):
row_conflict_constraint_match_total = DefaultConstraintMatchTotal(
'NQueens',
'Row Conflict',
Expand Down Expand Up @@ -255,14 +254,14 @@ def getConstraintMatchTotals(self):
descending_diagonal_constraint_match_total
]

def getIndictmentMap(self):
def get_indictment_map(self):
return None

solver_config = SolverConfig(
solution_class=Solution,
entity_class_list=[Queen],
score_director_factory_config=ScoreDirectorFactoryConfig(
incremental_score_calculator_class=IncrementalScoreCalculator
incremental_score_calculator_class=NQueensIncrementalScoreCalculator
),
termination_config=TerminationConfig(
best_score_limit='0'
Expand All @@ -282,8 +281,8 @@ def getIndictmentMap(self):
right_queen = solution.queen_list[j]
assert left_queen.row is not None and right_queen.row is not None
assert left_queen.row != right_queen.row
assert left_queen.getAscendingDiagonalIndex() != right_queen.getAscendingDiagonalIndex()
assert left_queen.getDescendingDiagonalIndex() != right_queen.getDescendingDiagonalIndex()
assert left_queen.get_ascending_diagonal_index() != right_queen.get_ascending_diagonal_index()
assert left_queen.get_descending_diagonal_index() != right_queen.get_descending_diagonal_index()

score_manager = SolutionManager.create(solver_factory)
constraint_match_total_map = score_manager.explain(solution).constraint_match_total_map
Expand Down Expand Up @@ -315,21 +314,19 @@ def getIndictmentMap(self):


def test_error_message_for_missing_methods():
with pytest.raises(ValueError, match=(
f"The following required methods are missing from @incremental_score_calculator class "
f".*IncrementalScoreCalculatorMissingMethods.*: "
f"\\['resetWorkingSolution', 'beforeEntityRemoved', 'afterEntityRemoved', 'calculateScore'\\]"
)):
with pytest.raises(TypeError): # Exact error message from ABC changes between versions
@incremental_score_calculator
class IncrementalScoreCalculatorMissingMethods:
def beforeEntityAdded(self, entity: any):
class IncrementalScoreCalculatorMissingMethods(IncrementalScoreCalculator):
def before_entity_added(self, entity):
pass

def afterEntityAdded(self, entity: any):
def after_entity_added(self, entity):
pass

def beforeVariableChanged(self, entity: any, variableName: str):
def before_variable_changed(self, entity, variable_name: str):
pass

def afterVariableChanged(self, entity: any, variableName: str):
def after_variable_changed(self, entity, variable_name: str):
pass

score_calculator = IncrementalScoreCalculatorMissingMethods()
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@
from ._solution_manager import *
from ._score_director import *
from ._variable_listener import *
from ._incremental_score_calculator import *
Loading

0 comments on commit 119c144

Please sign in to comment.