From 6282186d63fdcd0dc448b3fe752d8eaf77965381 Mon Sep 17 00:00:00 2001 From: Christopher Chianelli Date: Wed, 13 Mar 2024 20:15:04 -0400 Subject: [PATCH] feat: Allow jpyinterpreter to create annotations on classes, fields and methods - Fields annotations are created using Annotated in type hints - Method annotations are created using Annotated in a method's return type hints - Class annotations are created using a decorator which stores the annotation in a dict - Changed values in type hint dictionary from PythonLikeType to TypeHint - TypeHint is a record containing the type and any Java annotations - Java annotations are stored as instances of AnnotationMetadata, a record containing the annotation type and value of its attributes --- create-stubs.py | 4 +- .../jpyinterpreter/AnnotationMetadata.java | 74 ++++++ ...ythonBytecodeToJavaBytecodeTranslator.java | 6 +- .../jpyinterpreter/PythonClassTranslator.java | 59 ++++- .../jpyinterpreter/PythonCompiledClass.java | 10 +- .../PythonCompiledFunction.java | 15 +- .../ai/timefold/jpyinterpreter/TypeHint.java | 12 + jpyinterpreter/src/main/python/__init__.py | 14 +- .../python_to_java_bytecode_translator.py | 226 +++++++++++++++--- .../PythonClassTranslatorTest.java | 9 +- .../util/PythonFunctionBuilder.java | 4 +- jpyinterpreter/tests/test_classes.py | 46 ++++ 12 files changed, 416 insertions(+), 63 deletions(-) create mode 100644 jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/AnnotationMetadata.java create mode 100644 jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/TypeHint.java diff --git a/create-stubs.py b/create-stubs.py index 3bf090a7..888d7d22 100644 --- a/create-stubs.py +++ b/create-stubs.py @@ -9,10 +9,12 @@ import jpype.imports # noqa import ai.timefold.solver.core.api # noqa import ai.timefold.solver.core.config # noqa +import ai.timefold.jpyinterpreter # noqa import java.lang # noqa import java.time # noqa import java.util # noqa -stubgenj.generateJavaStubs([java.lang, java.time, java.util, ai.timefold.solver.core.api, ai.timefold.solver.core.config], +stubgenj.generateJavaStubs([java.lang, java.time, java.util, ai.timefold.solver.core.api, + ai.timefold.solver.core.config, ai.timefold.jpyinterpreter], useStubsSuffix=True) diff --git a/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/AnnotationMetadata.java b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/AnnotationMetadata.java new file mode 100644 index 00000000..240d7ff2 --- /dev/null +++ b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/AnnotationMetadata.java @@ -0,0 +1,74 @@ +package ai.timefold.jpyinterpreter; + +import java.lang.annotation.Annotation; +import java.lang.reflect.Array; +import java.util.Map; + +import org.objectweb.asm.AnnotationVisitor; +import org.objectweb.asm.ClassVisitor; +import org.objectweb.asm.FieldVisitor; +import org.objectweb.asm.MethodVisitor; +import org.objectweb.asm.Type; + +public record AnnotationMetadata(Class annotationType, Map annotationValueMap) { + public void addAnnotationTo(ClassVisitor classVisitor) { + visitAnnotation(classVisitor.visitAnnotation(Type.getDescriptor(annotationType), true)); + } + + public void addAnnotationTo(FieldVisitor fieldVisitor) { + visitAnnotation(fieldVisitor.visitAnnotation(Type.getDescriptor(annotationType), true)); + } + + public void addAnnotationTo(MethodVisitor methodVisitor) { + visitAnnotation(methodVisitor.visitAnnotation(Type.getDescriptor(annotationType), true)); + } + + private void visitAnnotation(AnnotationVisitor annotationVisitor) { + for (var entry : annotationValueMap.entrySet()) { + var annotationAttributeName = entry.getKey(); + var annotationAttributeValue = entry.getValue(); + + visitAnnotationAttribute(annotationVisitor, annotationAttributeName, annotationAttributeValue); + } + annotationVisitor.visitEnd(); + } + + private void visitAnnotationAttribute(AnnotationVisitor annotationVisitor, String attributeName, Object attributeValue) { + if (attributeValue instanceof Number + || attributeValue instanceof Boolean + || attributeValue instanceof Character + || attributeValue instanceof String) { + annotationVisitor.visit(attributeName, attributeValue); + return; + } + + if (attributeValue instanceof Class clazz) { + annotationVisitor.visit(attributeName, Type.getType(clazz)); + return; + } + + if (attributeValue instanceof AnnotationMetadata annotationMetadata) { + annotationMetadata.visitAnnotation( + annotationVisitor.visitAnnotation(attributeName, Type.getDescriptor(annotationMetadata.annotationType))); + return; + } + + if (attributeValue instanceof Enum enumValue) { + annotationVisitor.visitEnum(attributeName, Type.getDescriptor(enumValue.getClass()), + enumValue.name()); + return; + } + + if (attributeValue.getClass().isArray()) { + var arrayAnnotationVisitor = annotationVisitor.visitArray(attributeName); + var arrayLength = Array.getLength(attributeValue); + for (int i = 0; i < arrayLength; i++) { + visitAnnotationAttribute(arrayAnnotationVisitor, attributeName, Array.get(attributeValue, i)); + } + arrayAnnotationVisitor.visitEnd(); + return; + } + throw new IllegalArgumentException("Annotation of type %s has an illegal value %s for attribute %s." + .formatted(annotationType, attributeValue, attributeName)); + } +} diff --git a/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/PythonBytecodeToJavaBytecodeTranslator.java b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/PythonBytecodeToJavaBytecodeTranslator.java index 8472ca05..dc535308 100644 --- a/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/PythonBytecodeToJavaBytecodeTranslator.java +++ b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/PythonBytecodeToJavaBytecodeTranslator.java @@ -146,7 +146,7 @@ public static T translatePythonBytecode(PythonCompiledFunction pythonCompile PythonLikeTuple annotationTuple = pythonCompiledFunction.typeAnnotations.entrySet() .stream() .map(entry -> PythonLikeTuple.fromList(List.of(PythonString.valueOf(entry.getKey()), - entry.getValue() != null ? entry.getValue() : BuiltinTypes.BASE_TYPE))) + entry.getValue() != null ? entry.getValue().type() : BuiltinTypes.BASE_TYPE))) .collect(Collectors.toCollection(PythonLikeTuple::new)); return FunctionImplementor.createInstance(pythonCompiledFunction.defaultPositionalArguments, pythonCompiledFunction.defaultKeywordArguments, @@ -161,7 +161,7 @@ public static T translatePythonBytecode(PythonCompiledFunction pythonCompile translatePythonBytecodeToClass(pythonCompiledFunction, javaFunctionalInterfaceType, genericTypeArgumentList); PythonLikeTuple annotationTuple = pythonCompiledFunction.typeAnnotations.entrySet() .stream() - .map(entry -> PythonLikeTuple.fromList(List.of(PythonString.valueOf(entry.getKey()), entry.getValue()))) + .map(entry -> PythonLikeTuple.fromList(List.of(PythonString.valueOf(entry.getKey()), entry.getValue().type()))) .collect(Collectors.toCollection(PythonLikeTuple::new)); return FunctionImplementor.createInstance(pythonCompiledFunction.defaultPositionalArguments, pythonCompiledFunction.defaultKeywordArguments, @@ -216,7 +216,7 @@ public static T translatePythonBytecodeToInstance(PythonCompiledFunction pyt Class compiledClass = translatePythonBytecodeToClass(pythonCompiledFunction, methodDescriptor, isVirtual); PythonLikeTuple annotationTuple = pythonCompiledFunction.typeAnnotations.entrySet() .stream() - .map(entry -> PythonLikeTuple.fromList(List.of(PythonString.valueOf(entry.getKey()), entry.getValue()))) + .map(entry -> PythonLikeTuple.fromList(List.of(PythonString.valueOf(entry.getKey()), entry.getValue().type()))) .collect(Collectors.toCollection(PythonLikeTuple::new)); return FunctionImplementor.createInstance(pythonCompiledFunction.defaultPositionalArguments, pythonCompiledFunction.defaultKeywordArguments, diff --git a/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/PythonClassTranslator.java b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/PythonClassTranslator.java index b88329d0..bd90598c 100644 --- a/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/PythonClassTranslator.java +++ b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/PythonClassTranslator.java @@ -56,6 +56,8 @@ public class PythonClassTranslator { // $ is illegal in variables/methods in Python public static String TYPE_FIELD_NAME = "$TYPE"; public static String CPYTHON_TYPE_FIELD_NAME = "$CPYTHON_TYPE"; + private static String JAVA_FIELD_PREFIX = "$field$"; + private static String JAVA_METHOD_PREFIX = "$method$"; public static PythonLikeType translatePythonClass(PythonCompiledClass pythonCompiledClass) { String maybeClassName = @@ -144,6 +146,10 @@ public static PythonLikeType translatePythonClass(PythonCompiledClass pythonComp classWriter.visit(Opcodes.V11, Modifier.PUBLIC, internalClassName, null, superClassType.getJavaTypeInternalName(), interfaces); + for (var annotation : pythonCompiledClass.annotations) { + annotation.addAnnotationTo(classWriter); + } + pythonCompiledClass.staticAttributeNameToObject.forEach(pythonLikeType::$setAttribute); classWriter.visitField(Modifier.PUBLIC | Modifier.STATIC, TYPE_FIELD_NAME, Type.getDescriptor(PythonLikeType.class), @@ -157,15 +163,28 @@ public static PythonLikeType translatePythonClass(PythonCompiledClass pythonComp pythonLikeType.$setAttribute(staticAttributeEntry.getKey(), staticAttributeEntry.getValue()); } + for (var attributeName : pythonCompiledClass.typeAnnotations.keySet()) { + if (pythonLikeType.$getAttributeOrNull(attributeName) == null) { + instanceAttributeSet.add(attributeName); + } + } + Map attributeNameToTypeMap = new HashMap<>(); for (String attributeName : instanceAttributeSet) { - PythonLikeType type = pythonCompiledClass.typeAnnotations.getOrDefault(attributeName, BuiltinTypes.BASE_TYPE); + var typeHint = pythonCompiledClass.typeAnnotations.getOrDefault(attributeName, + TypeHint.withoutAnnotations(BuiltinTypes.BASE_TYPE)); + PythonLikeType type = typeHint.type(); if (type == null) { // null might be in __annotations__ type = BuiltinTypes.BASE_TYPE; } String javaFieldTypeDescriptor = 'L' + type.getJavaTypeInternalName() + ';'; attributeNameToTypeMap.put(attributeName, type); - classWriter.visitField(Modifier.PUBLIC, getJavaFieldName(attributeName), javaFieldTypeDescriptor, null, null); + var fieldVisitor = classWriter.visitField(Modifier.PUBLIC, getJavaFieldName(attributeName), javaFieldTypeDescriptor, + null, null); + for (var annotation : typeHint.annotationList()) { + annotation.addAnnotationTo(fieldVisitor); + } + fieldVisitor.visitEnd(); FieldDescriptor fieldDescriptor = new FieldDescriptor(attributeName, getJavaFieldName(attributeName), internalClassName, javaFieldTypeDescriptor, type, true); @@ -264,7 +283,7 @@ public static PythonLikeType translatePythonClass(PythonCompiledClass pythonComp pythonLikeType.$setAttribute("__module__", PythonString.valueOf(pythonCompiledClass.module)); PythonLikeDict annotations = new PythonLikeDict(); - pythonCompiledClass.typeAnnotations.forEach((name, type) -> annotations.put(PythonString.valueOf(name), type)); + pythonCompiledClass.typeAnnotations.forEach((name, type) -> annotations.put(PythonString.valueOf(name), type.type())); pythonLikeType.$setAttribute("__annotations__", annotations); PythonLikeTuple mro = new PythonLikeTuple(); @@ -347,19 +366,19 @@ public static void setSelfStaticInstances(PythonCompiledClass pythonCompiledClas } public static String getJavaFieldName(String pythonFieldName) { - return "$field$" + pythonFieldName; + return JAVA_FIELD_PREFIX + pythonFieldName; } public static String getPythonFieldName(String javaFieldName) { - return javaFieldName.substring("$field$".length()); + return javaFieldName.substring(JAVA_FIELD_PREFIX.length()); } public static String getJavaMethodName(String pythonMethodName) { - return "$method$" + pythonMethodName; + return JAVA_METHOD_PREFIX + pythonMethodName; } public static String getPythonMethodName(String javaMethodName) { - return javaMethodName.substring("$method$".length()); + return javaMethodName.substring(JAVA_METHOD_PREFIX.length()); } private static Class createBytecodeForMethodAndSetOnClass(String className, PythonLikeType pythonLikeType, @@ -673,6 +692,15 @@ private static PythonLikeFunction createConstructor(String classInternalName, } } + private static void addAnnotationsToMethod(PythonCompiledFunction function, MethodVisitor methodVisitor) { + var returnTypeHint = function.typeAnnotations.get("return"); + if (returnTypeHint != null) { + for (var annotation : returnTypeHint.annotationList()) { + annotation.addAnnotationTo(methodVisitor); + } + } + } + private static void createInstanceMethod(PythonLikeType pythonLikeType, ClassWriter classWriter, String internalClassName, String methodName, PythonCompiledFunction function) { InterfaceDeclaration interfaceDeclaration = getInterfaceForInstancePythonFunction(internalClassName, function); @@ -693,7 +721,8 @@ private static void createInstanceMethod(PythonLikeType pythonLikeType, ClassWri MethodVisitor methodVisitor = classWriter.visitMethod(Modifier.PUBLIC, javaMethodName, javaMethodDescriptor, null, null); - createMethodBody(internalClassName, javaMethodName, javaParameterTypes, interfaceDeclaration.methodDescriptor, function, + createInstanceOrStaticMethodBody(internalClassName, javaMethodName, javaParameterTypes, + interfaceDeclaration.methodDescriptor, function, interfaceDeclaration.interfaceName, interfaceDescriptor, methodVisitor); pythonLikeType.addMethod(methodName, @@ -722,7 +751,9 @@ private static void createStaticMethod(PythonLikeType pythonLikeType, ClassWrite for (int i = 0; i < function.totalArgCount(); i++) { javaParameterTypes[i] = Type.getType('L' + parameterPythonTypeList.get(i).getJavaTypeInternalName() + ';'); } - createMethodBody(internalClassName, javaMethodName, javaParameterTypes, interfaceDeclaration.methodDescriptor, function, + + createInstanceOrStaticMethodBody(internalClassName, javaMethodName, javaParameterTypes, + interfaceDeclaration.methodDescriptor, function, interfaceDeclaration.interfaceName, interfaceDescriptor, methodVisitor); pythonLikeType.addMethod(methodName, @@ -746,8 +777,10 @@ private static void createClassMethod(PythonLikeType pythonLikeType, ClassWriter classWriter.visitMethod(Modifier.PUBLIC | Modifier.STATIC, javaMethodName, javaMethodDescriptor, null, null); for (int i = 0; i < function.getParameterTypes().size(); i++) { - methodVisitor.visitParameter("parameter" + i, 0); + methodVisitor.visitParameter(function.co_varnames.get(i), 0); } + + addAnnotationsToMethod(function, methodVisitor); methodVisitor.visitCode(); methodVisitor.visitFieldInsn(Opcodes.GETSTATIC, internalClassName, javaMethodName, interfaceDescriptor); @@ -775,13 +808,15 @@ private static void createClassMethod(PythonLikeType pythonLikeType, ClassWriter parameterTypes)); } - private static void createMethodBody(String internalClassName, String javaMethodName, Type[] javaParameterTypes, + private static void createInstanceOrStaticMethodBody(String internalClassName, String javaMethodName, + Type[] javaParameterTypes, String methodDescriptorString, PythonCompiledFunction function, String interfaceInternalName, String interfaceDescriptor, MethodVisitor methodVisitor) { for (int i = 0; i < javaParameterTypes.length; i++) { - methodVisitor.visitParameter("parameter" + i, 0); + methodVisitor.visitParameter(function.co_varnames.get(i), 0); } + addAnnotationsToMethod(function, methodVisitor); methodVisitor.visitCode(); methodVisitor.visitFieldInsn(Opcodes.GETSTATIC, internalClassName, javaMethodName, interfaceDescriptor); diff --git a/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/PythonCompiledClass.java b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/PythonCompiledClass.java index 7d46d04b..44c0f102 100644 --- a/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/PythonCompiledClass.java +++ b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/PythonCompiledClass.java @@ -21,10 +21,15 @@ public class PythonCompiledClass { public String className; + /** + * The annotations on the type + */ + public List annotations; + /** * Type annotations for fields */ - public Map typeAnnotations; + public Map typeAnnotations; /** * The binary type of this PythonCompiledClass; @@ -47,6 +52,9 @@ public class PythonCompiledClass { */ public Map staticAttributeNameToClassInstance; + public PythonCompiledClass() { + } + public String getGeneratedClassBaseName() { if (module == null || module.isEmpty()) { return JavaIdentifierUtils.sanitizeClassName((qualifiedName != null) ? qualifiedName : "PythonClass"); diff --git a/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/PythonCompiledFunction.java b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/PythonCompiledFunction.java index 24459daf..18bb4f04 100644 --- a/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/PythonCompiledFunction.java +++ b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/PythonCompiledFunction.java @@ -46,7 +46,7 @@ public class PythonCompiledFunction { * Type annotations for the parameters and return. * (return is stored under the "return" key). */ - public Map typeAnnotations; + public Map typeAnnotations; /** * Default positional arguments @@ -155,9 +155,10 @@ public List getParameterTypes() { for (int i = 0; i < totalArgCount(); i++) { String parameterName = co_varnames.get(i); - PythonLikeType parameterType = typeAnnotations.get(parameterName); - if (parameterType == null) { // map may have nulls - parameterType = defaultType; + var parameterTypeHint = typeAnnotations.get(parameterName); + PythonLikeType parameterType = defaultType; + if (parameterTypeHint != null) { + parameterType = parameterTypeHint.type(); } out.add(parameterType); } @@ -165,7 +166,11 @@ public List getParameterTypes() { } public Optional getReturnType() { - return Optional.ofNullable(typeAnnotations.get("return")); + var returnTypeHint = typeAnnotations.get("return"); + if (returnTypeHint == null) { + return Optional.empty(); + } + return Optional.of(returnTypeHint.type()); } public String getAsmMethodDescriptorString() { diff --git a/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/TypeHint.java b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/TypeHint.java new file mode 100644 index 00000000..ae676308 --- /dev/null +++ b/jpyinterpreter/src/main/java/ai/timefold/jpyinterpreter/TypeHint.java @@ -0,0 +1,12 @@ +package ai.timefold.jpyinterpreter; + +import java.util.List; + +import ai.timefold.jpyinterpreter.types.PythonLikeType; + +public record TypeHint(PythonLikeType type, List annotationList) { + public static TypeHint withoutAnnotations(PythonLikeType type) { + return new TypeHint(type, List.of()); + } + +} diff --git a/jpyinterpreter/src/main/python/__init__.py b/jpyinterpreter/src/main/python/__init__.py index 9cbcaae1..4dbcad9e 100644 --- a/jpyinterpreter/src/main/python/__init__.py +++ b/jpyinterpreter/src/main/python/__init__.py @@ -2,8 +2,12 @@ This module acts as an interface to the Python bytecode to Java bytecode interpreter """ from .jvm_setup import init, set_class_output_directory -from .python_to_java_bytecode_translator import translate_python_bytecode_to_java_bytecode, \ - translate_python_class_to_java_class, convert_to_java_python_like_object, force_update_type, \ - get_java_type_for_python_type, unwrap_python_like_object, as_java, as_untyped_java, as_typed_java, is_c_native, \ - is_current_python_version_supported, check_current_python_version_supported, is_python_version_supported, \ - _force_as_java_generator +from .python_to_java_bytecode_translator import (JavaAnnotation, add_class_annotation, + translate_python_bytecode_to_java_bytecode, + translate_python_class_to_java_class, + convert_to_java_python_like_object, force_update_type, + get_java_type_for_python_type, unwrap_python_like_object, as_java, + as_untyped_java, as_typed_java, is_c_native, + is_current_python_version_supported, + check_current_python_version_supported, is_python_version_supported, + _force_as_java_generator) diff --git a/jpyinterpreter/src/main/python/python_to_java_bytecode_translator.py b/jpyinterpreter/src/main/python/python_to_java_bytecode_translator.py index aff06036..3006b20e 100644 --- a/jpyinterpreter/src/main/python/python_to_java_bytecode_translator.py +++ b/jpyinterpreter/src/main/python/python_to_java_bytecode_translator.py @@ -4,7 +4,10 @@ import inspect import sys import abc -from typing import Union +from dataclasses import dataclass +from types import FunctionType +from typing import TypeVar, Any, List, Tuple, Dict, Union, Annotated, Type, Callable, \ + get_origin, get_args, get_type_hints from jpype import JInt, JLong, JDouble, JBoolean, JProxy, JClass, JArray @@ -14,10 +17,36 @@ global_dict_to_instance = dict() global_dict_to_key_set = dict() type_to_compiled_java_class = dict() +type_to_annotations = dict() + function_interface_pair_to_instance = dict() function_interface_pair_to_class = dict() +@dataclass +class JavaAnnotation: + annotation_type: JClass + annotation_values: Dict[str, Any] + + +T = TypeVar('T') + + +def add_class_annotation(annotation_type, /, **annotation_values: Any) -> Callable[[Type[T]], Type[T]]: + def decorator(_cls: Type[T]) -> Type[T]: + global type_to_compiled_java_class + global type_to_annotations + if _cls in type_to_compiled_java_class: + raise RuntimeError('Cannot add an annotation after a class been compiled.') + annotations = type_to_annotations.get(_cls, []) + annotation = JavaAnnotation(annotation_type, annotation_values) + annotations.append(annotation) + type_to_annotations[_cls] = annotations + return _cls + + return decorator + + def is_python_version_supported(python_version): python_version_major_minor = python_version[0:2] return MINIMUM_SUPPORTED_PYTHON_VERSION <= python_version_major_minor <= MAXIMUM_SUPPORTED_PYTHON_VERSION @@ -662,39 +691,178 @@ def get_default_args(func): } -def copy_type_annotations(annotations_dict, default_args, vargs_name, kwargs_name): - from java.util import HashMap +def copy_type_annotations(hinted_object, default_args, vargs_name, kwargs_name): + from java.util import List, HashMap from java.lang import Class as JavaClass + from ai.timefold.jpyinterpreter import TypeHint from ai.timefold.jpyinterpreter.types.wrappers import OpaquePythonReference, JavaObjectWrapper, CPythonType # noqa global type_to_compiled_java_class out = HashMap() - if annotations_dict is None or not isinstance(annotations_dict, dict): - return out + type_hints = get_type_hints(hinted_object, include_extras=True) - for name, value in annotations_dict.items(): + for name, type_hint in type_hints.items(): if not isinstance(name, str): continue if name == vargs_name: - out.put(name, type_to_compiled_java_class[tuple]) + out.put(name, TypeHint.withoutAnnotations(type_to_compiled_java_class[tuple])) continue if name == kwargs_name: - out.put(name, type_to_compiled_java_class[dict]) + out.put(name, TypeHint.withoutAnnotations(type_to_compiled_java_class[dict])) continue + hint_type = type_hint + hint_annotations = List.of() + if get_origin(type_hint) is Annotated: + hint_type = get_args(type_hint)[0] + hint_annotations = get_java_annotations(type_hint.__metadata__) + if name in default_args: - value = Union[value, type(default_args[name])] - if value in type_to_compiled_java_class: - out.put(name, type_to_compiled_java_class[value]) - elif isinstance(value, (JClass, JavaClass)): - java_type = JavaObjectWrapper.getPythonTypeForClass(value) - type_to_compiled_java_class[value] = java_type - out.put(name, java_type) - elif isinstance(value, (type, str)): - out.put(name, get_java_type_for_python_type(value)) + hint_type = Union[hint_type, type(default_args[name])] + + java_type = None + if hint_type in type_to_compiled_java_class: + java_type = type_to_compiled_java_class[hint_type] + elif isinstance(hint_type, (JClass, JavaClass)): + java_type = JavaObjectWrapper.getPythonTypeForClass(hint_type) + type_to_compiled_java_class[hint_type] = java_type + elif isinstance(hint_type, (type, str)): + java_type = get_java_type_for_python_type(hint_type) + + if java_type is not None: + out.put(name, TypeHint(java_type, hint_annotations)) + return out + + +def get_java_annotations(annotated_metadata: List[Any]): + from java.util import ArrayList + out = ArrayList() + for metadata in annotated_metadata: + if not isinstance(metadata, JavaAnnotation): + continue + out.add(convert_java_annotation(metadata)) + return out + + +def convert_java_annotation(java_annotation: JavaAnnotation): + from java.util import HashMap + from ai.timefold.jpyinterpreter import AnnotationMetadata + annotation_values = HashMap() + for attribute_name, attribute_value in java_annotation.annotation_values.items(): + annotation_method = java_annotation.annotation_type.class_.getDeclaredMethod(attribute_name) + attribute_type = annotation_method.getReturnType() + java_attribute_value = convert_annotation_value(java_annotation.annotation_type, attribute_type, + attribute_name, attribute_value) + annotation_values.put(attribute_name, java_attribute_value) + return AnnotationMetadata(java_annotation.annotation_type.class_, annotation_values) + + +def convert_annotation_value(annotation_type: JClass, attribute_type: JClass, attribute_name: str, attribute_value: Any): + from jpype import JBoolean, JByte, JChar, JShort, JInt, JLong, JFloat, JDouble, JString, JArray + # See 9.6.1 of the Java spec for possible element values of annotations + if attribute_type == JClass('boolean').class_: + return JBoolean(attribute_value) + elif attribute_type == JClass('byte').class_: + return JByte(attribute_value) + elif attribute_type == JClass('char').class_: + return JChar(attribute_value) + elif attribute_type == JClass('short').class_: + return JShort(attribute_value) + elif attribute_type == JClass('int').class_: + return JInt(attribute_value) + elif attribute_type == JClass('long').class_: + return JLong(attribute_value) + elif attribute_type == JClass('float').class_: + return JFloat(attribute_value) + elif attribute_type == JClass('double').class_: + return JDouble(attribute_value) + elif attribute_type == JClass('java.lang.String').class_: + return JString(attribute_value) + elif attribute_type == JClass('java.lang.Class').class_: + if isinstance(attribute_value, type): + return get_java_type_for_python_type(attribute_type) + elif isinstance(attribute_value, FunctionType): + generic_type = annotation_type.class_.getDeclaredMethod(attribute_name).getGenericReturnType() + function_type_and_generic_args = resolve_java_function_type_as_tuple(generic_type) + return translate_python_bytecode_to_java_bytecode(attribute_value, *function_type_and_generic_args) + else: + raise ValueError(f'Illegal value for {attribute_name} in annotation {annotation_type}: {attribute_value}') + elif attribute_type.isEnum(): + return attribute_value + elif attribute_type.isArray(): + dimensions = get_dimensions(attribute_type) + component_type = get_component_type(attribute_type) + return JArray(component_type, dims=dimensions)(convert_annotation_array_elements(annotation_type, + component_type.class_, + attribute_name, + attribute_value)) + elif JClass('java.lang.Annotation').class_.isAssignableFrom(attribute_type): + if not isinstance(attribute_value, JavaAnnotation): + raise ValueError(f'Illegal value for {attribute_name} in annotation {annotation_type}: {attribute_value}') + return convert_java_annotation(attribute_value) + else: + raise ValueError(f'Illegal type for annotation element {attribute_type} for element named ' + f'{attribute_name} on annotation type {annotation_type}.') + + +def resolve_java_function_type_as_tuple(function_class) -> Tuple[JClass]: + from java.lang.reflect import ParameterizedType, WildcardType + if isinstance(function_class, WildcardType): + return resolve_java_type_as_tuple(function_class.getUpperBounds()[0]) + elif isinstance(function_class, ParameterizedType): + return resolve_java_type_as_tuple(function_class.getActualTypeArguments()[0]) + else: + raise ValueError(f'Unable to determine interface for type {function_class}') + +def resolve_java_type_as_tuple(generic_type) -> Tuple[JClass]: + from java.lang.reflect import ParameterizedType, WildcardType + if isinstance(generic_type, WildcardType): + return (*map(resolve_java_type_as_tuple, generic_type.getUpperBounds()),) + elif isinstance(generic_type, ParameterizedType): + return resolve_raw_types(generic_type.getRawType(), *generic_type.getActualTypeArguments()) + elif isinstance(generic_type, JClass): + return (generic_type,) + else: + raise ValueError(f'Unable to determine interface for type {generic_type}') + + +def resolve_raw_types(*type_arguments) -> Tuple[JClass]: + return (*map(resolve_raw_type, type_arguments),) + + +def resolve_raw_type(type_argument) -> JClass: + from java.lang.reflect import ParameterizedType + if isinstance(type_argument, ParameterizedType): + return resolve_raw_type(type_argument.getRawType()) + elif isinstance(type_argument, JClass): + return type_argument + else: + raise ValueError(f'Unable to determine raw type for type {type_argument}') + + +def convert_annotation_array_elements(annotation_type: JClass, component_type: JClass, attribute_name: str, + array_elements: List) -> List: + out = [] + for item in array_elements: + if isinstance(item, (list, tuple)): + out.append(convert_annotation_array_elements(annotation_type, component_type, attribute_name, item)) + else: + out.append(convert_annotation_value(annotation_type, component_type, attribute_name, item)) return out +def get_dimensions(array_type: JClass) -> int: + if array_type.getComponentType() is None: + return 0 + return get_dimensions(array_type.getComponentType()) + 1 + + +def get_component_type(array_type: JClass) -> JClass: + if not array_type.getComponentType().isArray(): + return JClass(array_type.getComponentType().getCanonicalName()) + return get_component_type(array_type.getComponentType()) + + def copy_constants(constants_iterable): from java.util import ArrayList from ai.timefold.jpyinterpreter import CPythonBackedPythonInterpreter @@ -834,7 +1002,7 @@ def get_function_bytecode_object(python_function): python_compiled_function.co_kwonlyargcount = python_function.__code__.co_kwonlyargcount python_compiled_function.closure = copy_closure(python_function.__closure__) python_compiled_function.globalsMap = copy_globals(python_function.__globals__, python_function.__code__.co_names) - python_compiled_function.typeAnnotations = copy_type_annotations(python_function.__annotations__, + python_compiled_function.typeAnnotations = copy_type_annotations(python_function, get_default_args(python_function), inspect.getfullargspec(python_function).varargs, inspect.getfullargspec(python_function).varkw) @@ -1065,7 +1233,7 @@ def erase_generic_args(python_type): def translate_python_class_to_java_class(python_class): from java.lang import Class as JavaClass from java.util import ArrayList, HashMap - from ai.timefold.jpyinterpreter import PythonCompiledClass, PythonClassTranslator, CPythonBackedPythonInterpreter # noqa + from ai.timefold.jpyinterpreter import AnnotationMetadata, PythonCompiledClass, PythonClassTranslator, CPythonBackedPythonInterpreter # noqa from ai.timefold.jpyinterpreter.types import BuiltinTypes from ai.timefold.jpyinterpreter.types.wrappers import JavaObjectWrapper, OpaquePythonReference, CPythonType # noqa @@ -1179,24 +1347,20 @@ def translate_python_class_to_java_class(python_class): static_attributes_map.put(attribute[0], convert_to_java_python_like_object(attribute[1])) - python_compiled_class = PythonCompiledClass() + python_compiled_class.annotations = ArrayList() + for annotation in type_to_annotations.get(python_class, []): + python_compiled_class.annotations.add(convert_java_annotation(annotation)) + python_compiled_class.binaryType = CPythonType.getType(JProxy(OpaquePythonReference, inst=python_class, convert=True)) python_compiled_class.module = python_class.__module__ python_compiled_class.qualifiedName = python_class.__qualname__ python_compiled_class.className = python_class.__name__ - if hasattr(python_class, '__annotations__'): - python_compiled_class.typeAnnotations = copy_type_annotations(python_class.__annotations__, - dict(), - None, - None) - else: - python_compiled_class.typeAnnotations = copy_type_annotations(None, - dict(), - None, - None) - + python_compiled_class.typeAnnotations = copy_type_annotations(python_class, + dict(), + None, + None) python_compiled_class.superclassList = superclass_list python_compiled_class.instanceFunctionNameToPythonBytecode = instance_method_map python_compiled_class.staticFunctionNameToPythonBytecode = static_method_map diff --git a/jpyinterpreter/src/test/java/ai/timefold/jpyinterpreter/PythonClassTranslatorTest.java b/jpyinterpreter/src/test/java/ai/timefold/jpyinterpreter/PythonClassTranslatorTest.java index a844d32d..f6b94d94 100644 --- a/jpyinterpreter/src/test/java/ai/timefold/jpyinterpreter/PythonClassTranslatorTest.java +++ b/jpyinterpreter/src/test/java/ai/timefold/jpyinterpreter/PythonClassTranslatorTest.java @@ -41,11 +41,12 @@ public void testPythonClassTranslation() throws ClassNotFoundException, NoSuchMe .op(ControlOpDescriptor.RETURN_VALUE) .build(); + compiledClass.annotations = List.of(); compiledClass.className = "MyClass"; compiledClass.superclassList = List.of(BuiltinTypes.BASE_TYPE); compiledClass.staticAttributeNameToObject = Map.of("type_variable", new PythonString("type_value")); compiledClass.staticAttributeNameToClassInstance = Map.of(); - compiledClass.typeAnnotations = Map.of("age", BuiltinTypes.INT_TYPE); + compiledClass.typeAnnotations = Map.of("age", TypeHint.withoutAnnotations(BuiltinTypes.INT_TYPE)); compiledClass.instanceFunctionNameToPythonBytecode = Map.of("__init__", initFunction, "get_age", ageFunction); compiledClass.staticFunctionNameToPythonBytecode = Map.of("hello_world", helloWorldFunction); @@ -93,11 +94,12 @@ public void testPythonClassComparable() throws ClassNotFoundException { PythonCompiledFunction comparisonFunction = getCompareFunction.apply(compareOp); PythonCompiledClass compiledClass = new PythonCompiledClass(); + compiledClass.annotations = List.of(); compiledClass.className = "MyClass"; compiledClass.superclassList = List.of(BuiltinTypes.BASE_TYPE); compiledClass.staticAttributeNameToObject = Map.of(); compiledClass.staticAttributeNameToClassInstance = Map.of(); - compiledClass.typeAnnotations = Map.of("key", BuiltinTypes.INT_TYPE); + compiledClass.typeAnnotations = Map.of("key", TypeHint.withoutAnnotations(BuiltinTypes.INT_TYPE)); compiledClass.instanceFunctionNameToPythonBytecode = Map.of("__init__", initFunction, compareOp.dunderMethod, comparisonFunction); compiledClass.staticFunctionNameToPythonBytecode = Map.of(); @@ -161,11 +163,12 @@ public void testPythonClassEqualsAndHashCode() throws ClassNotFoundException { .build(); PythonCompiledClass compiledClass = new PythonCompiledClass(); + compiledClass.annotations = List.of(); compiledClass.className = "MyClass"; compiledClass.superclassList = List.of(BuiltinTypes.BASE_TYPE); compiledClass.staticAttributeNameToObject = Map.of(); compiledClass.staticAttributeNameToClassInstance = Map.of(); - compiledClass.typeAnnotations = Map.of("key", BuiltinTypes.INT_TYPE); + compiledClass.typeAnnotations = Map.of("key", TypeHint.withoutAnnotations(BuiltinTypes.INT_TYPE)); compiledClass.instanceFunctionNameToPythonBytecode = Map.of("__init__", initFunction, "__eq__", equalsFunction, "__hash__", hashFunction); diff --git a/jpyinterpreter/src/test/java/ai/timefold/jpyinterpreter/util/PythonFunctionBuilder.java b/jpyinterpreter/src/test/java/ai/timefold/jpyinterpreter/util/PythonFunctionBuilder.java index d84be9fe..c78d1322 100644 --- a/jpyinterpreter/src/test/java/ai/timefold/jpyinterpreter/util/PythonFunctionBuilder.java +++ b/jpyinterpreter/src/test/java/ai/timefold/jpyinterpreter/util/PythonFunctionBuilder.java @@ -14,6 +14,7 @@ import ai.timefold.jpyinterpreter.PythonExceptionTable; import ai.timefold.jpyinterpreter.PythonLikeObject; import ai.timefold.jpyinterpreter.PythonVersion; +import ai.timefold.jpyinterpreter.TypeHint; import ai.timefold.jpyinterpreter.implementors.JavaPythonTypeConversionImplementor; import ai.timefold.jpyinterpreter.opcodes.descriptor.CollectionOpDescriptor; import ai.timefold.jpyinterpreter.opcodes.descriptor.ControlOpDescriptor; @@ -26,7 +27,6 @@ import ai.timefold.jpyinterpreter.opcodes.descriptor.OpcodeDescriptor; import ai.timefold.jpyinterpreter.opcodes.descriptor.StackOpDescriptor; import ai.timefold.jpyinterpreter.opcodes.descriptor.VariableOpDescriptor; -import ai.timefold.jpyinterpreter.types.PythonLikeType; /** * A builder for Python bytecode. @@ -65,7 +65,7 @@ public class PythonFunctionBuilder { Map globalsMap = new HashMap<>(); - Map typeAnnotations = new HashMap<>(); + Map typeAnnotations = new HashMap<>(); int co_argcount = 0; int co_kwonlyargcount = 0; diff --git a/jpyinterpreter/tests/test_classes.py b/jpyinterpreter/tests/test_classes.py index 7b13f1f3..51b90188 100644 --- a/jpyinterpreter/tests/test_classes.py +++ b/jpyinterpreter/tests/test_classes.py @@ -869,3 +869,49 @@ def my_function_instance(a, b, c): verifier.verify(1, 2, 3, expected_result=16) verifier.verify(2, 4, 6, expected_result=22) verifier.verify(1, 1, 1, expected_result=13) + + +def test_class_annotations(): + from typing import Annotated + from java.lang import Deprecated + from java.lang.annotation import Target, ElementType + from jpyinterpreter import add_class_annotation, JavaAnnotation, translate_python_class_to_java_class + + @add_class_annotation(Deprecated, + forRemoval=True, + since='0.0.0') + class A: + my_field: Annotated[int, JavaAnnotation(Deprecated, { + 'forRemoval': True, + 'since': '1.0.0' + }), 'extra metadata', + JavaAnnotation(Target, { + 'value': [ElementType.CONSTRUCTOR, ElementType.METHOD] + })] + + def my_method(self) -> Annotated[str, 'extra', JavaAnnotation(Deprecated, { + 'forRemoval': False, + 'since': '2.0.0' + })]: + return 'hello world' + + translated_class = translate_python_class_to_java_class(A).getJavaClass() + annotations = translated_class.getAnnotations() + assert len(annotations) == 1 + assert isinstance(annotations[0], Deprecated) + assert annotations[0].forRemoval() + assert annotations[0].since() == '0.0.0' + + annotations = translated_class.getField('$field$my_field').getAnnotations() + assert len(annotations) == 2 + assert isinstance(annotations[0], Deprecated) + assert annotations[0].forRemoval() + assert annotations[0].since() == '1.0.0' + assert isinstance(annotations[1], Target) + assert list(annotations[1].value()) == [ElementType.CONSTRUCTOR, ElementType.METHOD] + + annotations = translated_class.getMethod('$method$my_method').getAnnotations() + assert len(annotations) == 1 + assert isinstance(annotations[0], Deprecated) + assert annotations[0].forRemoval() is False + assert annotations[0].since() == '2.0.0'