Skip to content
This repository has been archived by the owner on Jul 17, 2024. It is now read-only.

Commit

Permalink
fix: Generate getters/setters that replace null with None (or vice ve…
Browse files Browse the repository at this point in the history
…rsa) (#29)

- Timefold considers None to be initialized, although it represents
  an uninitialized value

- Instead of annotating the fields directly, we could annotate
  getters/setters instead

- If None is assignable to the field type, the getter is
  `return this.field != None? this.field : null`;
  otherwise it's `return this.field`

- If None is assignable to the field type, the setter is
  `this.field =(value != null)? value : None`;
  otherwise it's `this.field = value`

- Assign the `__class__` field of a new class to itself to
  match CPython.
  • Loading branch information
Christopher-Chianelli authored Apr 8, 2024
1 parent 850ff44 commit d6ad7a1
Show file tree
Hide file tree
Showing 5 changed files with 133 additions and 63 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import ai.timefold.jpyinterpreter.implementors.JavaEqualsImplementor;
import ai.timefold.jpyinterpreter.implementors.JavaHashCodeImplementor;
import ai.timefold.jpyinterpreter.implementors.JavaInterfaceImplementor;
import ai.timefold.jpyinterpreter.implementors.PythonConstantsImplementor;
import ai.timefold.jpyinterpreter.opcodes.AbstractOpcode;
import ai.timefold.jpyinterpreter.opcodes.Opcode;
import ai.timefold.jpyinterpreter.opcodes.SelfOpcodeWithoutSource;
Expand Down Expand Up @@ -205,14 +206,14 @@ public static PythonLikeType translatePythonClass(PythonCompiledClass pythonComp
attributeNameToTypeMap.put(attributeName, type);
FieldVisitor fieldVisitor;
String javaFieldTypeDescriptor;
String signature = null;
boolean isJavaType;
if (type.getJavaTypeInternalName().equals(Type.getInternalName(JavaObjectWrapper.class))) {
javaFieldTypeDescriptor = Type.getDescriptor(type.getJavaObjectWrapperType());
fieldVisitor = classWriter.visitField(Modifier.PUBLIC, getJavaFieldName(attributeName), javaFieldTypeDescriptor,
null, null);
isJavaType = true;
} else {
String signature = null;
if (typeHint.genericArgs() != null) {
var signatureWriter = new SignatureWriter();
visitSignature(typeHint, signatureWriter);
Expand All @@ -223,10 +224,12 @@ public static PythonLikeType translatePythonClass(PythonCompiledClass pythonComp
signature, null);
isJavaType = false;
}
for (var annotation : typeHint.annotationList()) {
annotation.addAnnotationTo(fieldVisitor);
}
fieldVisitor.visitEnd();
createJavaGetterSetter(classWriter, preparedClassInfo,
attributeName,
Type.getType(javaFieldTypeDescriptor),
signature,
typeHint);
FieldDescriptor fieldDescriptor =
new FieldDescriptor(attributeName, getJavaFieldName(attributeName), internalClassName,
javaFieldTypeDescriptor, type, true, isJavaType);
Expand Down Expand Up @@ -761,6 +764,85 @@ private static PythonLikeFunction createConstructor(String classInternalName,
}
}

private static void createJavaGetterSetter(ClassWriter classWriter,
PreparedClassInfo preparedClassInfo,
String attributeName, Type attributeType,
String signature,
TypeHint typeHint) {
createJavaGetter(classWriter, preparedClassInfo, attributeName, attributeType, signature, typeHint);
createJavaSetter(classWriter, preparedClassInfo, attributeName, attributeType, signature, typeHint);
}

private static void createJavaGetter(ClassWriter classWriter, PreparedClassInfo preparedClassInfo, String attributeName,
Type attributeType, String signature, TypeHint typeHint) {
var getterName = "get" + attributeName.substring(0, 1).toUpperCase() + attributeName.substring(1);
if (signature != null) {
signature = "()" + signature;
}
var getterVisitor = classWriter.visitMethod(Modifier.PUBLIC, getterName, Type.getMethodDescriptor(attributeType),
signature, null);
var maxStack = 1;

for (var annotation : typeHint.annotationList()) {
annotation.addAnnotationTo(getterVisitor);
}

getterVisitor.visitCode();
getterVisitor.visitVarInsn(Opcodes.ALOAD, 0);
getterVisitor.visitFieldInsn(Opcodes.GETFIELD, preparedClassInfo.classInternalName,
attributeName, attributeType.getDescriptor());
if (typeHint.type().isInstance(PythonNone.INSTANCE)) {
maxStack = 3;
getterVisitor.visitInsn(Opcodes.DUP);
PythonConstantsImplementor.loadNone(getterVisitor);
Label returnLabel = new Label();
getterVisitor.visitJumpInsn(Opcodes.IF_ACMPNE, returnLabel);
// field is None, so we want Java to see it as null
getterVisitor.visitInsn(Opcodes.POP);
getterVisitor.visitInsn(Opcodes.ACONST_NULL);
getterVisitor.visitLabel(returnLabel);
// If branch is taken, stack is field
// If branch is not taken, stack is null
}
getterVisitor.visitInsn(Opcodes.ARETURN);
getterVisitor.visitMaxs(maxStack, 0);
getterVisitor.visitEnd();
}

private static void createJavaSetter(ClassWriter classWriter, PreparedClassInfo preparedClassInfo, String attributeName,
Type attributeType, String signature, TypeHint typeHint) {
var setterName = "set" + attributeName.substring(0, 1).toUpperCase() + attributeName.substring(1);
if (signature != null) {
signature = "(" + signature + ")V";
}
var setterVisitor = classWriter.visitMethod(Modifier.PUBLIC, setterName, Type.getMethodDescriptor(Type.VOID_TYPE,
attributeType),
signature, null);
var maxStack = 2;
setterVisitor.visitCode();
setterVisitor.visitVarInsn(Opcodes.ALOAD, 0);
setterVisitor.visitVarInsn(Opcodes.ALOAD, 1);
if (typeHint.type().isInstance(PythonNone.INSTANCE)) {
maxStack = 4;
// We want to replace null with None
setterVisitor.visitInsn(Opcodes.DUP);
setterVisitor.visitInsn(Opcodes.ACONST_NULL);
Label setFieldLabel = new Label();
setterVisitor.visitJumpInsn(Opcodes.IF_ACMPNE, setFieldLabel);
// set value is null, so we want Python to see it as None
setterVisitor.visitInsn(Opcodes.POP);
PythonConstantsImplementor.loadNone(setterVisitor);
setterVisitor.visitLabel(setFieldLabel);
// If branch is taken, stack is (non-null instance)
// If branch is not taken, stack is None
}
setterVisitor.visitFieldInsn(Opcodes.PUTFIELD, preparedClassInfo.classInternalName,
attributeName, attributeType.getDescriptor());
setterVisitor.visitInsn(Opcodes.RETURN);
setterVisitor.visitMaxs(maxStack, 0);
setterVisitor.visitEnd();
}

private static void addAnnotationsToMethod(PythonCompiledFunction function, MethodVisitor methodVisitor) {
var returnTypeHint = function.typeAnnotations.get("return");
if (returnTypeHint != null) {
Expand Down Expand Up @@ -956,15 +1038,9 @@ public static void createGetAttribute(ClassWriter classWriter, String classInter
methodVisitor.visitVarInsn(Opcodes.ALOAD, 0);
var type = fieldToType.get(field);
if (type.getJavaTypeInternalName().equals(Type.getInternalName(JavaObjectWrapper.class))) {
Class<?> fieldType = type.getJavaObjectWrapperType();
methodVisitor.visitFieldInsn(Opcodes.GETFIELD, classInternalName, getJavaFieldName(field),
Type.getDescriptor(fieldType));
methodVisitor.visitTypeInsn(Opcodes.NEW, Type.getInternalName(JavaObjectWrapper.class));
methodVisitor.visitInsn(Opcodes.DUP_X1);
methodVisitor.visitInsn(Opcodes.DUP_X1);
methodVisitor.visitInsn(Opcodes.POP);
methodVisitor.visitMethodInsn(Opcodes.INVOKESPECIAL, Type.getInternalName(JavaObjectWrapper.class),
"<init>", Type.getMethodDescriptor(Type.VOID_TYPE, Type.getType(Object.class)), false);
Type.getDescriptor(type.getJavaObjectWrapperType()));
getWrappedJavaObject(methodVisitor);
} else {
methodVisitor.visitFieldInsn(Opcodes.GETFIELD, classInternalName, getJavaFieldName(field),
'L' + type.getJavaTypeInternalName() + ';');
Expand All @@ -984,6 +1060,15 @@ public static void createGetAttribute(ClassWriter classWriter, String classInter
methodVisitor.visitEnd();
}

private static void getWrappedJavaObject(MethodVisitor methodVisitor) {
methodVisitor.visitTypeInsn(Opcodes.NEW, Type.getInternalName(JavaObjectWrapper.class));
methodVisitor.visitInsn(Opcodes.DUP_X1);
methodVisitor.visitInsn(Opcodes.DUP_X1);
methodVisitor.visitInsn(Opcodes.POP);
methodVisitor.visitMethodInsn(Opcodes.INVOKESPECIAL, Type.getInternalName(JavaObjectWrapper.class),
"<init>", Type.getMethodDescriptor(Type.VOID_TYPE, Type.getType(Object.class)), false);
}

public static void createSetAttribute(ClassWriter classWriter, String classInternalName, String superInternalName,
Collection<String> instanceAttributes,
Map<String, PythonLikeType> fieldToType) {
Expand All @@ -1008,10 +1093,7 @@ public static void createSetAttribute(ClassWriter classWriter, String classInter
String typeDescriptor = type.getJavaTypeDescriptor();
if (type.getJavaTypeInternalName().equals(Type.getInternalName(JavaObjectWrapper.class))) {
// Need to unwrap the object
methodVisitor.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(JavaObjectWrapper.class));
methodVisitor.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(JavaObjectWrapper.class),
"getWrappedObject", Type.getMethodDescriptor(Type.getType(Object.class)), false);
methodVisitor.visitTypeInsn(Opcodes.CHECKCAST, Type.getType(type.getJavaObjectWrapperType()).getInternalName());
getUnwrappedJavaObject(methodVisitor, type);
typeDescriptor = Type.getDescriptor(type.getJavaObjectWrapperType());
} else {
methodVisitor.visitTypeInsn(Opcodes.CHECKCAST, type.getJavaTypeInternalName());
Expand All @@ -1035,6 +1117,13 @@ public static void createSetAttribute(ClassWriter classWriter, String classInter
methodVisitor.visitEnd();
}

private static void getUnwrappedJavaObject(MethodVisitor methodVisitor, PythonLikeType type) {
methodVisitor.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(JavaObjectWrapper.class));
methodVisitor.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(JavaObjectWrapper.class),
"getWrappedObject", Type.getMethodDescriptor(Type.getType(Object.class)), false);
methodVisitor.visitTypeInsn(Opcodes.CHECKCAST, Type.getType(type.getJavaObjectWrapperType()).getInternalName());
}

public static void createDeleteAttribute(ClassWriter classWriter, String classInternalName, String superInternalName,
Collection<String> instanceAttributes,
Map<String, PythonLikeType> fieldToType) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,9 @@ private PythonLikeType(String typeName, String internalName) {
}

public static PythonLikeType getTypeForNewClass(String typeName, String internalName) {
return new PythonLikeType(typeName, internalName);
var out = new PythonLikeType(typeName, internalName);
out.__dir__.put("__class__", out);
return out;
}

public void initializeNewType(List<PythonLikeType> superClassTypes) {
Expand Down
2 changes: 1 addition & 1 deletion jpyinterpreter/tests/test_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -902,7 +902,7 @@ def my_method(self) -> Annotated[str, 'extra', JavaAnnotation(Deprecated, {
assert annotations[0].forRemoval()
assert annotations[0].since() == '0.0.0'

annotations = translated_class.getField('my_field').getAnnotations()
annotations = translated_class.getMethod('getMy_field').getAnnotations()
assert len(annotations) == 2
assert isinstance(annotations[0], Deprecated)
assert annotations[0].forRemoval()
Expand Down
26 changes: 9 additions & 17 deletions tests/test_custom_shadow_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,19 +48,16 @@ class MyPlanningEntity:
def my_constraints(constraint_factory: timefold.solver.constraint.ConstraintFactory):
return [
constraint_factory.for_each(MyPlanningEntity)
.filter(lambda entity: entity.value is None)
.penalize('Unassigned value', timefold.solver.score.HardSoftScore.ONE_HARD),
constraint_factory.for_each(MyPlanningEntity)
.filter(lambda entity: entity.value is not None and entity.value * 2 == entity.value_squared)
.reward('Double value is value squared', timefold.solver.score.HardSoftScore.ONE_SOFT)
.filter(lambda entity: entity.value * 2 == entity.value_squared)
.reward('Double value is value squared', timefold.solver.score.SimpleScore.ONE)
]

@timefold.solver.planning_solution
@dataclass
class MySolution:
entity_list: Annotated[List[MyPlanningEntity], timefold.solver.PlanningEntityCollectionProperty]
value_list: Annotated[List[int], timefold.solver.ValueRangeProvider]
score: Annotated[timefold.solver.score.HardSoftScore, timefold.solver.PlanningScore] = field(default=None)
score: Annotated[timefold.solver.score.SimpleScore, timefold.solver.PlanningScore] = field(default=None)

solver_config = timefold.solver.config.SolverConfig(
solution_class=MySolution,
Expand All @@ -69,16 +66,15 @@ class MySolution:
constraint_provider_function=my_constraints
),
termination_config=timefold.solver.config.TerminationConfig(
best_score_limit='0hard/1soft'
best_score_limit='1'
)
)

solver_factory = timefold.solver.SolverFactory.create(solver_config)
solver = solver_factory.build_solver()
problem = MySolution([MyPlanningEntity()], [1, 2, 3])
solution: MySolution = solver.solve(problem)
assert solution.score.hard_score() == 0
assert solution.score.soft_score() == 1
assert solution.score.score() == 1
assert solution.entity_list[0].value == 2
assert solution.entity_list[0].value_squared == 4

Expand Down Expand Up @@ -127,20 +123,17 @@ class MyPlanningEntity:
@timefold.solver.constraint_provider
def my_constraints(constraint_factory: timefold.solver.constraint.ConstraintFactory):
return [
constraint_factory.for_each(MyPlanningEntity)
.filter(lambda entity: entity.value is None)
.penalize('Unassigned value', timefold.solver.score.HardSoftScore.ONE_HARD),
constraint_factory.for_each(MyPlanningEntity)
.filter(lambda entity: entity.twice_value == entity.value_squared)
.reward('Double value is value squared', timefold.solver.score.HardSoftScore.ONE_SOFT)
.reward('Double value is value squared', timefold.solver.score.SimpleScore.ONE)
]

@timefold.solver.planning_solution
@dataclass
class MySolution:
entity_list: Annotated[List[MyPlanningEntity], PlanningEntityCollectionProperty]
value_list: Annotated[List[int], ValueRangeProvider]
score: Annotated[timefold.solver.score.HardSoftScore, PlanningScore] = field(default=None)
score: Annotated[timefold.solver.score.SimpleScore, PlanningScore] = field(default=None)

solver_config = timefold.solver.config.SolverConfig(
solution_class=MySolution,
Expand All @@ -149,15 +142,14 @@ class MySolution:
constraint_provider_function=my_constraints
),
termination_config=timefold.solver.config.TerminationConfig(
best_score_limit='0hard/1soft'
best_score_limit='1'
)
)

solver_factory = timefold.solver.SolverFactory.create(solver_config)
solver = solver_factory.build_solver()
problem = MySolution([MyPlanningEntity()], [1, 2, 3])
solution: MySolution = solver.solve(problem)
assert solution.score.hard_score() == 0
assert solution.score.soft_score() == 1
assert solution.score.score() == 1
assert solution.entity_list[0].value == 2
assert solution.entity_list[0].value_squared == 4
Loading

0 comments on commit d6ad7a1

Please sign in to comment.