diff --git a/airflow/serialization/serde.py b/airflow/serialization/serde.py index c512116cde702..dd4913e5670cc 100644 --- a/airflow/serialization/serde.py +++ b/airflow/serialization/serde.py @@ -19,6 +19,7 @@ import dataclasses import enum +import functools import logging import re import sys @@ -30,6 +31,7 @@ import airflow.serialization.serializers from airflow.configuration import conf +from airflow.stats import Stats from airflow.utils.module_loading import import_string, iter_namespace, qualname log = logging.getLogger(__name__) @@ -58,7 +60,6 @@ _primitives = (int, bool, float, str) _builtin_collections = (frozenset, list, set, tuple) # dict is treated specially. -_patterns: list[re.Pattern] = [] def encode(cls: str, version: int, data: T) -> dict[str, str | int | T]: @@ -253,7 +254,7 @@ def _convert(old: dict) -> dict: def _match(classname: str) -> bool: - return any(p.match(classname) is not None for p in _patterns) + return any(p.match(classname) is not None for p in _get_patterns()) def _stringify(classname: str, version: int, value: T | None) -> str: @@ -275,33 +276,32 @@ def _register(): _serializers.clear() _deserializers.clear() - for _, name, _ in iter_namespace(airflow.serialization.serializers): - name = import_module(name) - for s in getattr(name, "serializers", list()): - if not isinstance(s, str): - s = qualname(s) - if s in _serializers and _serializers[s] != name: - raise AttributeError(f"duplicate {s} for serialization in {name} and {_serializers[s]}") - log.debug("registering %s for serialization") - _serializers[s] = name - for d in getattr(name, "deserializers", list()): - if not isinstance(d, str): - d = qualname(d) - if d in _deserializers and _deserializers[d] != name: - raise AttributeError(f"duplicate {d} for deserialization in {name} and {_serializers[d]}") - log.debug("registering %s for deserialization", d) - _deserializers[d] = name - _extra_allowed.add(d) - - -def _compile_patterns(): + with Stats.timer("serde.load_serializers") as timer: + for _, name, _ in iter_namespace(airflow.serialization.serializers): + name = import_module(name) + for s in getattr(name, "serializers", list()): + if not isinstance(s, str): + s = qualname(s) + if s in _serializers and _serializers[s] != name: + raise AttributeError(f"duplicate {s} for serialization in {name} and {_serializers[s]}") + log.debug("registering %s for serialization", s) + _serializers[s] = name + for d in getattr(name, "deserializers", list()): + if not isinstance(d, str): + d = qualname(d) + if d in _deserializers and _deserializers[d] != name: + raise AttributeError(f"duplicate {d} for deserialization in {name} and {_serializers[d]}") + log.debug("registering %s for deserialization", d) + _deserializers[d] = name + _extra_allowed.add(d) + + log.info("loading serializers took %.3f seconds", timer.duration) + + +@functools.lru_cache(maxsize=None) +def _get_patterns() -> list[re.Pattern]: patterns = conf.get("core", "allowed_deserialization_classes").split() - - _patterns.clear() # ensure to reinit - for p in patterns: - p = re.sub(r"(\w)\.", r"\1\..", p) - _patterns.append(re.compile(p)) + return [re.compile(re.sub(r"(\w)\.", r"\1\..", p)) for p in patterns] _register() -_compile_patterns() diff --git a/airflow/serialization/serializers/bignum.py b/airflow/serialization/serializers/bignum.py index 649c83dbd1dbe..769e78491e9e4 100644 --- a/airflow/serialization/serializers/bignum.py +++ b/airflow/serialization/serializers/bignum.py @@ -17,22 +17,25 @@ # under the License. from __future__ import annotations -from decimal import Decimal from typing import TYPE_CHECKING from airflow.utils.module_loading import qualname if TYPE_CHECKING: + import decimal + from airflow.serialization.serde import U -serializers = [Decimal] +serializers = ["decimal.Decimal"] deserializers = serializers __version__ = 1 def serialize(o: object) -> tuple[U, str, int, bool]: + from decimal import Decimal + if not isinstance(o, Decimal): return "", "", 0, False name = qualname(o) @@ -44,7 +47,9 @@ def serialize(o: object) -> tuple[U, str, int, bool]: return float(o), name, __version__, True -def deserialize(classname: str, version: int, data: object) -> Decimal: +def deserialize(classname: str, version: int, data: object) -> decimal.Decimal: + from decimal import Decimal + if version > __version__: raise TypeError(f"serialized {version} of {classname} > {__version__}") diff --git a/airflow/serialization/serializers/datetime.py b/airflow/serialization/serializers/datetime.py index b400258838a0d..bdb9a6cb6c5a6 100644 --- a/airflow/serialization/serializers/datetime.py +++ b/airflow/serialization/serializers/datetime.py @@ -17,21 +17,19 @@ # under the License. from __future__ import annotations -from datetime import date, datetime, timedelta from typing import TYPE_CHECKING -from pendulum import DateTime -from pendulum.tz import timezone - from airflow.utils.module_loading import qualname from airflow.utils.timezone import convert_to_utc, is_naive if TYPE_CHECKING: + import datetime + from airflow.serialization.serde import U __version__ = 1 -serializers = [date, datetime, timedelta, DateTime] +serializers = ["datetime.date", "datetime.datetime", "datetime.timedelta", "pendulum.datetime.DateTime"] deserializers = serializers TIMESTAMP = "timestamp" @@ -39,7 +37,9 @@ def serialize(o: object) -> tuple[U, str, int, bool]: - if isinstance(o, DateTime) or isinstance(o, datetime): + from datetime import date, datetime, timedelta + + if isinstance(o, datetime): qn = qualname(o) if is_naive(o): o = convert_to_utc(o) @@ -57,17 +57,22 @@ def serialize(o: object) -> tuple[U, str, int, bool]: return "", "", 0, False -def deserialize(classname: str, version: int, data: dict | str) -> datetime | timedelta | date: - if classname == qualname(datetime) and isinstance(data, dict): - return datetime.fromtimestamp(float(data[TIMESTAMP]), tz=timezone(data[TIMEZONE])) +def deserialize(classname: str, version: int, data: dict | str) -> datetime.date | datetime.timedelta: + import datetime + + from pendulum import DateTime + from pendulum.tz import timezone + + if classname == qualname(datetime.datetime) and isinstance(data, dict): + return datetime.datetime.fromtimestamp(float(data[TIMESTAMP]), tz=timezone(data[TIMEZONE])) if classname == qualname(DateTime) and isinstance(data, dict): return DateTime.fromtimestamp(float(data[TIMESTAMP]), tz=timezone(data[TIMEZONE])) - if classname == qualname(timedelta) and isinstance(data, (str, float)): - return timedelta(seconds=float(data)) + if classname == qualname(datetime.timedelta) and isinstance(data, (str, float)): + return datetime.timedelta(seconds=float(data)) - if classname == qualname(date) and isinstance(data, str): - return date.fromisoformat(data) + if classname == qualname(datetime.date) and isinstance(data, str): + return datetime.date.fromisoformat(data) raise TypeError(f"unknown date/time format {classname}") diff --git a/airflow/serialization/serializers/kubernetes.py b/airflow/serialization/serializers/kubernetes.py index a8f0c0f333de5..0ed9c96f71860 100644 --- a/airflow/serialization/serializers/kubernetes.py +++ b/airflow/serialization/serializers/kubernetes.py @@ -22,19 +22,15 @@ from airflow.utils.module_loading import qualname -serializers = [] - -try: - from kubernetes.client import models as k8s - - serializers = [k8s.v1_pod.V1Pod, k8s.V1ResourceRequirements] -except ImportError: - k8s = None +# lazy loading for performance reasons +serializers = [ + "kubernetes.client.models.v1_resource_requirements.V1ResourceRequirements", + "kubernetes.client.models.v1_pod.V1Pod", +] if TYPE_CHECKING: from airflow.serialization.serde import U - __version__ = 1 deserializers: list[type[object]] = [] @@ -42,6 +38,8 @@ def serialize(o: object) -> tuple[U, str, int, bool]: + from kubernetes.client import models as k8s + if not k8s: return "", "", 0, False diff --git a/airflow/serialization/serializers/numpy.py b/airflow/serialization/serializers/numpy.py index 4dea70aa13742..603f4df44a44c 100644 --- a/airflow/serialization/serializers/numpy.py +++ b/airflow/serialization/serializers/numpy.py @@ -19,47 +19,36 @@ from typing import TYPE_CHECKING, Any -from airflow.utils.module_loading import qualname - -serializers = [] - -try: - import numpy as np - - serializers = [ - np.int_, - np.intc, - np.intp, - np.int8, - np.int16, - np.int32, - np.int64, - np.uint8, - np.uint16, - np.uint32, - np.uint64, - np.bool_, - np.float_, - np.float16, - np.float64, - np.complex_, - np.complex64, - np.complex128, - ] -except ImportError: - np = None # type: ignore - +from airflow.utils.module_loading import import_string, qualname + +# lazy loading for performance reasons +serializers = [ + "numpy.int8", + "numpy.int16", + "numpy.int32", + "numpy.int64", + "numpy.uint8", + "numpy.uint16", + "numpy.uint32", + "numpy.uint64", + "numpy.bool_", + "numpy.float64", + "numpy.float16", + "numpy.complex128", + "numpy.complex64", +] if TYPE_CHECKING: from airflow.serialization.serde import U -deserializers: list = serializers -_deserializers: dict[str, type[object]] = {qualname(x): x for x in deserializers} +deserializers = serializers __version__ = 1 def serialize(o: object) -> tuple[U, str, int, bool]: + import numpy as np + if np is None: return "", "", 0, False @@ -97,8 +86,7 @@ def deserialize(classname: str, version: int, data: str) -> Any: if version > __version__: raise TypeError("serialized version is newer than class version") - f = _deserializers.get(classname, None) - if callable(f): - return f(data) # type: ignore [call-arg] + if classname not in deserializers: + raise TypeError(f"unsupported {classname} found for numpy deserialization") - raise TypeError(f"unsupported {classname} found for numpy deserialization") + return import_string(classname)(data) diff --git a/airflow/serialization/serializers/timezone.py b/airflow/serialization/serializers/timezone.py index 0a0bcc222c611..b55b51610b41b 100644 --- a/airflow/serialization/serializers/timezone.py +++ b/airflow/serialization/serializers/timezone.py @@ -19,16 +19,15 @@ from typing import TYPE_CHECKING -import pendulum -from pendulum.tz.timezone import FixedTimezone, Timezone - from airflow.utils.module_loading import qualname if TYPE_CHECKING: + from pendulum.tz.timezone import Timezone + from airflow.serialization.serde import U -serializers = [FixedTimezone, Timezone] +serializers = ["pendulum.tz.timezone.FixedTimezone", "pendulum.tz.timezone.Timezone"] deserializers = serializers __version__ = 1 @@ -44,6 +43,8 @@ def serialize(o: object) -> tuple[U, str, int, bool]: 0 without the special case), but passing 0 into ``pendulum.timezone`` does not give us UTC (but ``+00:00``). """ + from pendulum.tz.timezone import FixedTimezone, Timezone + name = qualname(o) if isinstance(o, FixedTimezone): if o.offset == 0: @@ -57,6 +58,8 @@ def serialize(o: object) -> tuple[U, str, int, bool]: def deserialize(classname: str, version: int, data: object) -> Timezone: + from pendulum.tz import fixed_timezone, timezone + if not isinstance(data, (str, int)): raise TypeError(f"{data} is not of type int or str but of {type(data)}") @@ -64,6 +67,6 @@ def deserialize(classname: str, version: int, data: object) -> Timezone: raise TypeError(f"serialized {version} of {classname} > {__version__}") if isinstance(data, int): - return pendulum.tz.fixed_timezone(data) + return fixed_timezone(data) - return pendulum.tz.timezone(data) + return timezone(data) diff --git a/tests/serialization/test_serde.py b/tests/serialization/test_serde.py index 475525cf104b9..03aea4f7b8697 100644 --- a/tests/serialization/test_serde.py +++ b/tests/serialization/test_serde.py @@ -19,6 +19,7 @@ import datetime import enum from dataclasses import dataclass +from importlib import import_module from typing import ClassVar import attr @@ -30,15 +31,20 @@ DATA, SCHEMA_ID, VERSION, - _compile_patterns, + _get_patterns, _match, deserialize, serialize, ) -from airflow.utils.module_loading import qualname +from airflow.utils.module_loading import import_string, iter_namespace, qualname from tests.test_utils.config import conf_vars +@pytest.fixture() +def recalculate_patterns(): + _get_patterns.cache_clear() + + class Z: __version__: ClassVar[int] = 1 @@ -74,12 +80,8 @@ class W: x: int +@pytest.mark.usefixtures("recalculate_patterns") class TestSerDe: - @pytest.fixture(autouse=True) - def ensure_clean_allow_list(self): - _compile_patterns() - yield - def test_ser_primitives(self): i = 10 e = serialize(i) @@ -173,8 +175,8 @@ def test_serder_dataclass(self): ("core", "allowed_deserialization_classes"): "airflow[.].*", } ) + @pytest.mark.usefixtures("recalculate_patterns") def test_allow_list_for_imports(self): - _compile_patterns() i = Z(10) e = serialize(i) with pytest.raises(ImportError) as ex: @@ -187,8 +189,8 @@ def test_allow_list_for_imports(self): ("core", "allowed_deserialization_classes"): "tests.*", } ) + @pytest.mark.usefixtures("recalculate_patterns") def test_allow_list_replace(self): - _compile_patterns() assert _match("tests.airflow.deep") assert _match("testsfault") is False @@ -232,3 +234,17 @@ def test_encode_dataset(self): dataset = Dataset("mytest://dataset") obj = deserialize(serialize(dataset)) assert dataset.uri == obj.uri + + def test_serializers_importable_and_str(self): + """test if all distributed serializers are lazy loading and can be imported""" + import airflow.serialization.serializers + + for _, name, _ in iter_namespace(airflow.serialization.serializers): + mod = import_module(name) + for s in getattr(mod, "serializers", list()): + if not isinstance(s, str): + raise TypeError(f"{s} is not of type str. This is required for lazy loading") + try: + import_string(s) + except ImportError: + raise AttributeError(f"{s} cannot be imported (located in {name})")