From d6ad7a164dfaba540435a6ee95386570b79bae15 Mon Sep 17 00:00:00 2001 From: Christopher Chianelli Date: Mon, 8 Apr 2024 08:34:47 -0400 Subject: [PATCH] fix: Generate getters/setters that replace null with None (or vice versa) (#29) - Timefold considers None to be initialized, although it represents an uninitialized value - Instead of annotating the fields directly, we could annotate getters/setters instead - If None is assignable to the field type, the getter is `return this.field != None? this.field : null`; otherwise it's `return this.field` - If None is assignable to the field type, the setter is `this.field =(value != null)? value : None`; otherwise it's `this.field = value` - Assign the `__class__` field of a new class to itself to match CPython. --- .../jpyinterpreter/PythonClassTranslator.java | 121 +++++++++++++++--- .../jpyinterpreter/types/PythonLikeType.java | 4 +- jpyinterpreter/tests/test_classes.py | 2 +- tests/test_custom_shadow_variables.py | 26 ++-- tests/test_incremental_score_calculator.py | 43 +++---- 5 files changed, 133 insertions(+), 63 deletions(-) diff --git a/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/PythonClassTranslator.java b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/PythonClassTranslator.java index 5d08569..b0608cb 100644 --- a/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/PythonClassTranslator.java +++ b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/PythonClassTranslator.java @@ -23,6 +23,7 @@ import ai.timefold.jpyinterpreter.implementors.JavaEqualsImplementor; import ai.timefold.jpyinterpreter.implementors.JavaHashCodeImplementor; import ai.timefold.jpyinterpreter.implementors.JavaInterfaceImplementor; +import ai.timefold.jpyinterpreter.implementors.PythonConstantsImplementor; import ai.timefold.jpyinterpreter.opcodes.AbstractOpcode; import ai.timefold.jpyinterpreter.opcodes.Opcode; import ai.timefold.jpyinterpreter.opcodes.SelfOpcodeWithoutSource; @@ -205,6 +206,7 @@ public static PythonLikeType translatePythonClass(PythonCompiledClass pythonComp attributeNameToTypeMap.put(attributeName, type); FieldVisitor fieldVisitor; String javaFieldTypeDescriptor; + String signature = null; boolean isJavaType; if (type.getJavaTypeInternalName().equals(Type.getInternalName(JavaObjectWrapper.class))) { javaFieldTypeDescriptor = Type.getDescriptor(type.getJavaObjectWrapperType()); @@ -212,7 +214,6 @@ public static PythonLikeType translatePythonClass(PythonCompiledClass pythonComp null, null); isJavaType = true; } else { - String signature = null; if (typeHint.genericArgs() != null) { var signatureWriter = new SignatureWriter(); visitSignature(typeHint, signatureWriter); @@ -223,10 +224,12 @@ public static PythonLikeType translatePythonClass(PythonCompiledClass pythonComp signature, null); isJavaType = false; } - for (var annotation : typeHint.annotationList()) { - annotation.addAnnotationTo(fieldVisitor); - } fieldVisitor.visitEnd(); + createJavaGetterSetter(classWriter, preparedClassInfo, + attributeName, + Type.getType(javaFieldTypeDescriptor), + signature, + typeHint); FieldDescriptor fieldDescriptor = new FieldDescriptor(attributeName, getJavaFieldName(attributeName), internalClassName, javaFieldTypeDescriptor, type, true, isJavaType); @@ -761,6 +764,85 @@ private static PythonLikeFunction createConstructor(String classInternalName, } } + private static void createJavaGetterSetter(ClassWriter classWriter, + PreparedClassInfo preparedClassInfo, + String attributeName, Type attributeType, + String signature, + TypeHint typeHint) { + createJavaGetter(classWriter, preparedClassInfo, attributeName, attributeType, signature, typeHint); + createJavaSetter(classWriter, preparedClassInfo, attributeName, attributeType, signature, typeHint); + } + + private static void createJavaGetter(ClassWriter classWriter, PreparedClassInfo preparedClassInfo, String attributeName, + Type attributeType, String signature, TypeHint typeHint) { + var getterName = "get" + attributeName.substring(0, 1).toUpperCase() + attributeName.substring(1); + if (signature != null) { + signature = "()" + signature; + } + var getterVisitor = classWriter.visitMethod(Modifier.PUBLIC, getterName, Type.getMethodDescriptor(attributeType), + signature, null); + var maxStack = 1; + + for (var annotation : typeHint.annotationList()) { + annotation.addAnnotationTo(getterVisitor); + } + + getterVisitor.visitCode(); + getterVisitor.visitVarInsn(Opcodes.ALOAD, 0); + getterVisitor.visitFieldInsn(Opcodes.GETFIELD, preparedClassInfo.classInternalName, + attributeName, attributeType.getDescriptor()); + if (typeHint.type().isInstance(PythonNone.INSTANCE)) { + maxStack = 3; + getterVisitor.visitInsn(Opcodes.DUP); + PythonConstantsImplementor.loadNone(getterVisitor); + Label returnLabel = new Label(); + getterVisitor.visitJumpInsn(Opcodes.IF_ACMPNE, returnLabel); + // field is None, so we want Java to see it as null + getterVisitor.visitInsn(Opcodes.POP); + getterVisitor.visitInsn(Opcodes.ACONST_NULL); + getterVisitor.visitLabel(returnLabel); + // If branch is taken, stack is field + // If branch is not taken, stack is null + } + getterVisitor.visitInsn(Opcodes.ARETURN); + getterVisitor.visitMaxs(maxStack, 0); + getterVisitor.visitEnd(); + } + + private static void createJavaSetter(ClassWriter classWriter, PreparedClassInfo preparedClassInfo, String attributeName, + Type attributeType, String signature, TypeHint typeHint) { + var setterName = "set" + attributeName.substring(0, 1).toUpperCase() + attributeName.substring(1); + if (signature != null) { + signature = "(" + signature + ")V"; + } + var setterVisitor = classWriter.visitMethod(Modifier.PUBLIC, setterName, Type.getMethodDescriptor(Type.VOID_TYPE, + attributeType), + signature, null); + var maxStack = 2; + setterVisitor.visitCode(); + setterVisitor.visitVarInsn(Opcodes.ALOAD, 0); + setterVisitor.visitVarInsn(Opcodes.ALOAD, 1); + if (typeHint.type().isInstance(PythonNone.INSTANCE)) { + maxStack = 4; + // We want to replace null with None + setterVisitor.visitInsn(Opcodes.DUP); + setterVisitor.visitInsn(Opcodes.ACONST_NULL); + Label setFieldLabel = new Label(); + setterVisitor.visitJumpInsn(Opcodes.IF_ACMPNE, setFieldLabel); + // set value is null, so we want Python to see it as None + setterVisitor.visitInsn(Opcodes.POP); + PythonConstantsImplementor.loadNone(setterVisitor); + setterVisitor.visitLabel(setFieldLabel); + // If branch is taken, stack is (non-null instance) + // If branch is not taken, stack is None + } + setterVisitor.visitFieldInsn(Opcodes.PUTFIELD, preparedClassInfo.classInternalName, + attributeName, attributeType.getDescriptor()); + setterVisitor.visitInsn(Opcodes.RETURN); + setterVisitor.visitMaxs(maxStack, 0); + setterVisitor.visitEnd(); + } + private static void addAnnotationsToMethod(PythonCompiledFunction function, MethodVisitor methodVisitor) { var returnTypeHint = function.typeAnnotations.get("return"); if (returnTypeHint != null) { @@ -956,15 +1038,9 @@ public static void createGetAttribute(ClassWriter classWriter, String classInter methodVisitor.visitVarInsn(Opcodes.ALOAD, 0); var type = fieldToType.get(field); if (type.getJavaTypeInternalName().equals(Type.getInternalName(JavaObjectWrapper.class))) { - Class fieldType = type.getJavaObjectWrapperType(); methodVisitor.visitFieldInsn(Opcodes.GETFIELD, classInternalName, getJavaFieldName(field), - Type.getDescriptor(fieldType)); - methodVisitor.visitTypeInsn(Opcodes.NEW, Type.getInternalName(JavaObjectWrapper.class)); - methodVisitor.visitInsn(Opcodes.DUP_X1); - methodVisitor.visitInsn(Opcodes.DUP_X1); - methodVisitor.visitInsn(Opcodes.POP); - methodVisitor.visitMethodInsn(Opcodes.INVOKESPECIAL, Type.getInternalName(JavaObjectWrapper.class), - "", Type.getMethodDescriptor(Type.VOID_TYPE, Type.getType(Object.class)), false); + Type.getDescriptor(type.getJavaObjectWrapperType())); + getWrappedJavaObject(methodVisitor); } else { methodVisitor.visitFieldInsn(Opcodes.GETFIELD, classInternalName, getJavaFieldName(field), 'L' + type.getJavaTypeInternalName() + ';'); @@ -984,6 +1060,15 @@ public static void createGetAttribute(ClassWriter classWriter, String classInter methodVisitor.visitEnd(); } + private static void getWrappedJavaObject(MethodVisitor methodVisitor) { + methodVisitor.visitTypeInsn(Opcodes.NEW, Type.getInternalName(JavaObjectWrapper.class)); + methodVisitor.visitInsn(Opcodes.DUP_X1); + methodVisitor.visitInsn(Opcodes.DUP_X1); + methodVisitor.visitInsn(Opcodes.POP); + methodVisitor.visitMethodInsn(Opcodes.INVOKESPECIAL, Type.getInternalName(JavaObjectWrapper.class), + "", Type.getMethodDescriptor(Type.VOID_TYPE, Type.getType(Object.class)), false); + } + public static void createSetAttribute(ClassWriter classWriter, String classInternalName, String superInternalName, Collection instanceAttributes, Map fieldToType) { @@ -1008,10 +1093,7 @@ public static void createSetAttribute(ClassWriter classWriter, String classInter String typeDescriptor = type.getJavaTypeDescriptor(); if (type.getJavaTypeInternalName().equals(Type.getInternalName(JavaObjectWrapper.class))) { // Need to unwrap the object - methodVisitor.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(JavaObjectWrapper.class)); - methodVisitor.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(JavaObjectWrapper.class), - "getWrappedObject", Type.getMethodDescriptor(Type.getType(Object.class)), false); - methodVisitor.visitTypeInsn(Opcodes.CHECKCAST, Type.getType(type.getJavaObjectWrapperType()).getInternalName()); + getUnwrappedJavaObject(methodVisitor, type); typeDescriptor = Type.getDescriptor(type.getJavaObjectWrapperType()); } else { methodVisitor.visitTypeInsn(Opcodes.CHECKCAST, type.getJavaTypeInternalName()); @@ -1035,6 +1117,13 @@ public static void createSetAttribute(ClassWriter classWriter, String classInter methodVisitor.visitEnd(); } + private static void getUnwrappedJavaObject(MethodVisitor methodVisitor, PythonLikeType type) { + methodVisitor.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(JavaObjectWrapper.class)); + methodVisitor.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(JavaObjectWrapper.class), + "getWrappedObject", Type.getMethodDescriptor(Type.getType(Object.class)), false); + methodVisitor.visitTypeInsn(Opcodes.CHECKCAST, Type.getType(type.getJavaObjectWrapperType()).getInternalName()); + } + public static void createDeleteAttribute(ClassWriter classWriter, String classInternalName, String superInternalName, Collection instanceAttributes, Map fieldToType) { diff --git a/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/types/PythonLikeType.java b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/types/PythonLikeType.java index 34179bb..5f94bd4 100644 --- a/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/types/PythonLikeType.java +++ b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/types/PythonLikeType.java @@ -107,7 +107,9 @@ private PythonLikeType(String typeName, String internalName) { } public static PythonLikeType getTypeForNewClass(String typeName, String internalName) { - return new PythonLikeType(typeName, internalName); + var out = new PythonLikeType(typeName, internalName); + out.__dir__.put("__class__", out); + return out; } public void initializeNewType(List superClassTypes) { diff --git a/jpyinterpreter/tests/test_classes.py b/jpyinterpreter/tests/test_classes.py index f620819..ec963f3 100644 --- a/jpyinterpreter/tests/test_classes.py +++ b/jpyinterpreter/tests/test_classes.py @@ -902,7 +902,7 @@ def my_method(self) -> Annotated[str, 'extra', JavaAnnotation(Deprecated, { assert annotations[0].forRemoval() assert annotations[0].since() == '0.0.0' - annotations = translated_class.getField('my_field').getAnnotations() + annotations = translated_class.getMethod('getMy_field').getAnnotations() assert len(annotations) == 2 assert isinstance(annotations[0], Deprecated) assert annotations[0].forRemoval() diff --git a/tests/test_custom_shadow_variables.py b/tests/test_custom_shadow_variables.py index 9761944..b092ab2 100644 --- a/tests/test_custom_shadow_variables.py +++ b/tests/test_custom_shadow_variables.py @@ -48,11 +48,8 @@ class MyPlanningEntity: def my_constraints(constraint_factory: timefold.solver.constraint.ConstraintFactory): return [ constraint_factory.for_each(MyPlanningEntity) - .filter(lambda entity: entity.value is None) - .penalize('Unassigned value', timefold.solver.score.HardSoftScore.ONE_HARD), - constraint_factory.for_each(MyPlanningEntity) - .filter(lambda entity: entity.value is not None and entity.value * 2 == entity.value_squared) - .reward('Double value is value squared', timefold.solver.score.HardSoftScore.ONE_SOFT) + .filter(lambda entity: entity.value * 2 == entity.value_squared) + .reward('Double value is value squared', timefold.solver.score.SimpleScore.ONE) ] @timefold.solver.planning_solution @@ -60,7 +57,7 @@ def my_constraints(constraint_factory: timefold.solver.constraint.ConstraintFact class MySolution: entity_list: Annotated[List[MyPlanningEntity], timefold.solver.PlanningEntityCollectionProperty] value_list: Annotated[List[int], timefold.solver.ValueRangeProvider] - score: Annotated[timefold.solver.score.HardSoftScore, timefold.solver.PlanningScore] = field(default=None) + score: Annotated[timefold.solver.score.SimpleScore, timefold.solver.PlanningScore] = field(default=None) solver_config = timefold.solver.config.SolverConfig( solution_class=MySolution, @@ -69,7 +66,7 @@ class MySolution: constraint_provider_function=my_constraints ), termination_config=timefold.solver.config.TerminationConfig( - best_score_limit='0hard/1soft' + best_score_limit='1' ) ) @@ -77,8 +74,7 @@ class MySolution: solver = solver_factory.build_solver() problem = MySolution([MyPlanningEntity()], [1, 2, 3]) solution: MySolution = solver.solve(problem) - assert solution.score.hard_score() == 0 - assert solution.score.soft_score() == 1 + assert solution.score.score() == 1 assert solution.entity_list[0].value == 2 assert solution.entity_list[0].value_squared == 4 @@ -127,12 +123,9 @@ class MyPlanningEntity: @timefold.solver.constraint_provider def my_constraints(constraint_factory: timefold.solver.constraint.ConstraintFactory): return [ - constraint_factory.for_each(MyPlanningEntity) - .filter(lambda entity: entity.value is None) - .penalize('Unassigned value', timefold.solver.score.HardSoftScore.ONE_HARD), constraint_factory.for_each(MyPlanningEntity) .filter(lambda entity: entity.twice_value == entity.value_squared) - .reward('Double value is value squared', timefold.solver.score.HardSoftScore.ONE_SOFT) + .reward('Double value is value squared', timefold.solver.score.SimpleScore.ONE) ] @timefold.solver.planning_solution @@ -140,7 +133,7 @@ def my_constraints(constraint_factory: timefold.solver.constraint.ConstraintFact class MySolution: entity_list: Annotated[List[MyPlanningEntity], PlanningEntityCollectionProperty] value_list: Annotated[List[int], ValueRangeProvider] - score: Annotated[timefold.solver.score.HardSoftScore, PlanningScore] = field(default=None) + score: Annotated[timefold.solver.score.SimpleScore, PlanningScore] = field(default=None) solver_config = timefold.solver.config.SolverConfig( solution_class=MySolution, @@ -149,7 +142,7 @@ class MySolution: constraint_provider_function=my_constraints ), termination_config=timefold.solver.config.TerminationConfig( - best_score_limit='0hard/1soft' + best_score_limit='1' ) ) @@ -157,7 +150,6 @@ class MySolution: solver = solver_factory.build_solver() problem = MySolution([MyPlanningEntity()], [1, 2, 3]) solution: MySolution = solver.solve(problem) - assert solution.score.hard_score() == 0 - assert solution.score.soft_score() == 1 + assert solution.score.score() == 1 assert solution.entity_list[0].value == 2 assert solution.entity_list[0].value_squared == 4 diff --git a/tests/test_incremental_score_calculator.py b/tests/test_incremental_score_calculator.py index 57e686a..abd0897 100644 --- a/tests/test_incremental_score_calculator.py +++ b/tests/test_incremental_score_calculator.py @@ -43,12 +43,12 @@ class Solution: queen_list: Annotated[List[Queen], timefold.solver.PlanningEntityCollectionProperty] column_list: List[int] row_list: Annotated[List[int], timefold.solver.ValueRangeProvider] - score: Annotated[timefold.solver.score.HardSoftScore, timefold.solver.PlanningScore] = field(default=None) + score: Annotated[timefold.solver.score.SimpleScore, timefold.solver.PlanningScore] = field(default=None) + def test_constraint_match_disabled_incremental_score_calculator(): @timefold.solver.incremental_score_calculator class IncrementalScoreCalculator: - uninit_score: int # Workaround, since None is considered initialized score: int row_index_map: dict ascending_diagonal_index_map: dict @@ -67,7 +67,6 @@ def resetWorkingSolution(self, working_solution: Solution): self.ascending_diagonal_index_map[n - 1 + i] = list() self.descending_diagonal_index_map[-i] = list() self.score = 0 - self.uninit_score = 0 for queen in working_solution.queen_list: self.insert(queen) @@ -101,8 +100,6 @@ def insert(self, queen: Queen): descending_diagonal_index_list = self.descending_diagonal_index_map[queen.getDescendingDiagonalIndex()] self.score -= len(descending_diagonal_index_list) descending_diagonal_index_list.append(queen) - else: - self.uninit_score -= 1 def retract(self, queen: Queen): if queen.row is not None: @@ -116,11 +113,9 @@ def retract(self, queen: Queen): descending_diagonal_index_list = self.descending_diagonal_index_map[queen.getDescendingDiagonalIndex()] descending_diagonal_index_list.remove(queen) self.score += len(descending_diagonal_index_list) - else: - self.uninit_score += 1 def calculateScore(self) -> timefold.solver.score.HardSoftScore: - return timefold.solver.score.HardSoftScore.of(self.uninit_score, self.score) + return timefold.solver.score.SimpleScore.of(self.score) solver_config = timefold.solver.config.SolverConfig( solution_class=Solution, @@ -129,7 +124,7 @@ def calculateScore(self) -> timefold.solver.score.HardSoftScore: incremental_score_calculator_class=IncrementalScoreCalculator ), termination_config=timefold.solver.config.TerminationConfig( - best_score_limit='0hard/0soft' + best_score_limit='0' ) ) problem: Solution = Solution(4, @@ -138,8 +133,7 @@ def calculateScore(self) -> timefold.solver.score.HardSoftScore: [0, 1, 2, 3]) solver = timefold.solver.SolverFactory.create(solver_config).build_solver() solution = solver.solve(problem) - assert solution.score.hard_score() == 0 - assert solution.score.soft_score() == 0 + assert solution.score.score() == 0 for i in range(4): for j in range(i + 1, 4): left_queen = solution.queen_list[i] @@ -157,7 +151,6 @@ def test_constraint_match_enabled_incremental_score_calculator(): @timefold.solver.incremental_score_calculator class IncrementalScoreCalculator: score: int - uninit_score: int row_index_map: dict ascending_diagonal_index_map: dict descending_diagonal_index_map: dict @@ -175,7 +168,6 @@ def resetWorkingSolution(self, working_solution: Solution, constraint_match_enab self.ascending_diagonal_index_map[n - 1 + i] = list() self.descending_diagonal_index_map[-i] = list() self.score = 0 - self.uninit_score = 0 for queen in working_solution.queen_list: self.insert(queen) @@ -210,8 +202,6 @@ def insert(self, queen: Queen): descending_diagonal_index_list = self.descending_diagonal_index_map[queen.getDescendingDiagonalIndex()] self.score -= len(descending_diagonal_index_list) descending_diagonal_index_list.append(queen) - else: - self.uninit_score -= 1 def retract(self, queen: Queen): row = queen.row @@ -226,11 +216,9 @@ def retract(self, queen: Queen): descending_diagonal_index_list = self.descending_diagonal_index_map[queen.getDescendingDiagonalIndex()] descending_diagonal_index_list.remove(queen) self.score += len(descending_diagonal_index_list) - else: - self.uninit_score += 1 def calculateScore(self) -> timefold.solver.score.HardSoftScore: - return timefold.solver.score.HardSoftScore.of(self.uninit_score, self.score) + return timefold.solver.score.SimpleScore.of(self.score) def getConstraintMatchTotals(self): row_conflict_constraint_match_total = timefold.solver.constraint.DefaultConstraintMatchTotal( @@ -276,7 +264,7 @@ def getIndictmentMap(self): incremental_score_calculator_class=IncrementalScoreCalculator ), termination_config=timefold.solver.config.TerminationConfig( - best_score_limit='0hard/0soft' + best_score_limit='0' ) ) problem: Solution = Solution(4, @@ -286,8 +274,7 @@ def getIndictmentMap(self): solver_factory = timefold.solver.SolverFactory.create(solver_config) solver = solver_factory.build_solver() solution = solver.solve(problem) - assert solution.score.hard_score() == 0 - assert solution.score.soft_score() == 0 + assert solution.score.score() == 0 for i in range(4): for j in range(i + 1, 4): left_queen = solution.queen_list[i] @@ -302,23 +289,23 @@ def getIndictmentMap(self): row_conflict = constraint_match_total_map.get('NQueens/Row Conflict') ascending_diagonal_conflict = constraint_match_total_map.get('NQueens/Ascending Diagonal Conflict') descending_diagonal_conflict = constraint_match_total_map.get('NQueens/Descending Diagonal Conflict') - assert row_conflict.score().soft_score() == 0 - assert ascending_diagonal_conflict.score().soft_score() == 0 - assert descending_diagonal_conflict.score().soft_score() == 0 + assert row_conflict.score().score() == 0 + assert ascending_diagonal_conflict.score().score() == 0 + assert descending_diagonal_conflict.score().score() == 0 bad_solution = Solution(4, [Queen('A', 0, 0), Queen('B', 1, 1), Queen('C', 2, 0), Queen('D', 3, 1)], [0, 1, 2, 3], [0, 1, 2, 3]) score_explanation = score_manager.explain(bad_solution) - assert score_explanation.get_score().soft_score() == -5 + assert score_explanation.get_score().score() == -5 constraint_match_total_map = score_explanation.getConstraintMatchTotalMap() row_conflict = constraint_match_total_map.get('NQueens/Row Conflict') ascending_diagonal_conflict = constraint_match_total_map.get('NQueens/Ascending Diagonal Conflict') descending_diagonal_conflict = constraint_match_total_map.get('NQueens/Descending Diagonal Conflict') - assert row_conflict.score().soft_score() == -2 # (A, C), (B, D) - assert ascending_diagonal_conflict.score().soft_score() == -1 # (B, C) - assert descending_diagonal_conflict.score().soft_score() == -2 # (A, B), (C, D) + assert row_conflict.score().score() == -2 # (A, C), (B, D) + assert ascending_diagonal_conflict.score().score() == -1 # (B, C) + assert descending_diagonal_conflict.score().score() == -2 # (A, B), (C, D) indictment_map = score_explanation.getIndictmentMap() assert indictment_map.get(bad_solution.queen_list[0]).getConstraintMatchCount() == 2 assert indictment_map.get(bad_solution.queen_list[1]).getConstraintMatchCount() == 3