Skip to content

Commit

Permalink
Lazy load serialization modules (apache#30094)
Browse files Browse the repository at this point in the history
Currently all the serde modules are loaded on startup
which are notoriously slow, even if they are never used.

This changes the behavior to lazy loading,
adds functionality to the serde module to report how much time
the serialization classes take to load, and caches the regexp patterns
used for matching against the allow list.

---------

Co-authored-by: Tzu-ping Chung <uranusjr@gmail.com>
  • Loading branch information
bolkedebruin and uranusjr authored Mar 17, 2023
1 parent 9a417a5 commit b749f7f
Show file tree
Hide file tree
Showing 7 changed files with 119 additions and 104 deletions.
56 changes: 28 additions & 28 deletions airflow/serialization/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import dataclasses
import enum
import functools
import logging
import re
import sys
Expand All @@ -30,6 +31,7 @@

import airflow.serialization.serializers
from airflow.configuration import conf
from airflow.stats import Stats
from airflow.utils.module_loading import import_string, iter_namespace, qualname

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -58,7 +60,6 @@

_primitives = (int, bool, float, str)
_builtin_collections = (frozenset, list, set, tuple) # dict is treated specially.
_patterns: list[re.Pattern] = []


def encode(cls: str, version: int, data: T) -> dict[str, str | int | T]:
Expand Down Expand Up @@ -253,7 +254,7 @@ def _convert(old: dict) -> dict:


def _match(classname: str) -> bool:
return any(p.match(classname) is not None for p in _patterns)
return any(p.match(classname) is not None for p in _get_patterns())


def _stringify(classname: str, version: int, value: T | None) -> str:
Expand All @@ -275,33 +276,32 @@ def _register():
_serializers.clear()
_deserializers.clear()

for _, name, _ in iter_namespace(airflow.serialization.serializers):
name = import_module(name)
for s in getattr(name, "serializers", list()):
if not isinstance(s, str):
s = qualname(s)
if s in _serializers and _serializers[s] != name:
raise AttributeError(f"duplicate {s} for serialization in {name} and {_serializers[s]}")
log.debug("registering %s for serialization")
_serializers[s] = name
for d in getattr(name, "deserializers", list()):
if not isinstance(d, str):
d = qualname(d)
if d in _deserializers and _deserializers[d] != name:
raise AttributeError(f"duplicate {d} for deserialization in {name} and {_serializers[d]}")
log.debug("registering %s for deserialization", d)
_deserializers[d] = name
_extra_allowed.add(d)


def _compile_patterns():
with Stats.timer("serde.load_serializers") as timer:
for _, name, _ in iter_namespace(airflow.serialization.serializers):
name = import_module(name)
for s in getattr(name, "serializers", list()):
if not isinstance(s, str):
s = qualname(s)
if s in _serializers and _serializers[s] != name:
raise AttributeError(f"duplicate {s} for serialization in {name} and {_serializers[s]}")
log.debug("registering %s for serialization", s)
_serializers[s] = name
for d in getattr(name, "deserializers", list()):
if not isinstance(d, str):
d = qualname(d)
if d in _deserializers and _deserializers[d] != name:
raise AttributeError(f"duplicate {d} for deserialization in {name} and {_serializers[d]}")
log.debug("registering %s for deserialization", d)
_deserializers[d] = name
_extra_allowed.add(d)

log.info("loading serializers took %.3f seconds", timer.duration)


@functools.lru_cache(maxsize=None)
def _get_patterns() -> list[re.Pattern]:
patterns = conf.get("core", "allowed_deserialization_classes").split()

_patterns.clear() # ensure to reinit
for p in patterns:
p = re.sub(r"(\w)\.", r"\1\..", p)
_patterns.append(re.compile(p))
return [re.compile(re.sub(r"(\w)\.", r"\1\..", p)) for p in patterns]


_register()
_compile_patterns()
11 changes: 8 additions & 3 deletions airflow/serialization/serializers/bignum.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,25 @@
# under the License.
from __future__ import annotations

from decimal import Decimal
from typing import TYPE_CHECKING

from airflow.utils.module_loading import qualname

if TYPE_CHECKING:
import decimal

from airflow.serialization.serde import U


serializers = [Decimal]
serializers = ["decimal.Decimal"]
deserializers = serializers

__version__ = 1


def serialize(o: object) -> tuple[U, str, int, bool]:
from decimal import Decimal

if not isinstance(o, Decimal):
return "", "", 0, False
name = qualname(o)
Expand All @@ -44,7 +47,9 @@ def serialize(o: object) -> tuple[U, str, int, bool]:
return float(o), name, __version__, True


def deserialize(classname: str, version: int, data: object) -> Decimal:
def deserialize(classname: str, version: int, data: object) -> decimal.Decimal:
from decimal import Decimal

if version > __version__:
raise TypeError(f"serialized {version} of {classname} > {__version__}")

Expand Down
31 changes: 18 additions & 13 deletions airflow/serialization/serializers/datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,29 +17,29 @@
# under the License.
from __future__ import annotations

from datetime import date, datetime, timedelta
from typing import TYPE_CHECKING

from pendulum import DateTime
from pendulum.tz import timezone

from airflow.utils.module_loading import qualname
from airflow.utils.timezone import convert_to_utc, is_naive

if TYPE_CHECKING:
import datetime

from airflow.serialization.serde import U

__version__ = 1

serializers = [date, datetime, timedelta, DateTime]
serializers = ["datetime.date", "datetime.datetime", "datetime.timedelta", "pendulum.datetime.DateTime"]
deserializers = serializers

TIMESTAMP = "timestamp"
TIMEZONE = "tz"


def serialize(o: object) -> tuple[U, str, int, bool]:
if isinstance(o, DateTime) or isinstance(o, datetime):
from datetime import date, datetime, timedelta

if isinstance(o, datetime):
qn = qualname(o)
if is_naive(o):
o = convert_to_utc(o)
Expand All @@ -57,17 +57,22 @@ def serialize(o: object) -> tuple[U, str, int, bool]:
return "", "", 0, False


def deserialize(classname: str, version: int, data: dict | str) -> datetime | timedelta | date:
if classname == qualname(datetime) and isinstance(data, dict):
return datetime.fromtimestamp(float(data[TIMESTAMP]), tz=timezone(data[TIMEZONE]))
def deserialize(classname: str, version: int, data: dict | str) -> datetime.date | datetime.timedelta:
import datetime

from pendulum import DateTime
from pendulum.tz import timezone

if classname == qualname(datetime.datetime) and isinstance(data, dict):
return datetime.datetime.fromtimestamp(float(data[TIMESTAMP]), tz=timezone(data[TIMEZONE]))

if classname == qualname(DateTime) and isinstance(data, dict):
return DateTime.fromtimestamp(float(data[TIMESTAMP]), tz=timezone(data[TIMEZONE]))

if classname == qualname(timedelta) and isinstance(data, (str, float)):
return timedelta(seconds=float(data))
if classname == qualname(datetime.timedelta) and isinstance(data, (str, float)):
return datetime.timedelta(seconds=float(data))

if classname == qualname(date) and isinstance(data, str):
return date.fromisoformat(data)
if classname == qualname(datetime.date) and isinstance(data, str):
return datetime.date.fromisoformat(data)

raise TypeError(f"unknown date/time format {classname}")
16 changes: 7 additions & 9 deletions airflow/serialization/serializers/kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,26 +22,24 @@

from airflow.utils.module_loading import qualname

serializers = []

try:
from kubernetes.client import models as k8s

serializers = [k8s.v1_pod.V1Pod, k8s.V1ResourceRequirements]
except ImportError:
k8s = None
# lazy loading for performance reasons
serializers = [
"kubernetes.client.models.v1_resource_requirements.V1ResourceRequirements",
"kubernetes.client.models.v1_pod.V1Pod",
]

if TYPE_CHECKING:
from airflow.serialization.serde import U


__version__ = 1

deserializers: list[type[object]] = []
log = logging.getLogger(__name__)


def serialize(o: object) -> tuple[U, str, int, bool]:
from kubernetes.client import models as k8s

if not k8s:
return "", "", 0, False

Expand Down
60 changes: 24 additions & 36 deletions airflow/serialization/serializers/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,47 +19,36 @@

from typing import TYPE_CHECKING, Any

from airflow.utils.module_loading import qualname

serializers = []

try:
import numpy as np

serializers = [
np.int_,
np.intc,
np.intp,
np.int8,
np.int16,
np.int32,
np.int64,
np.uint8,
np.uint16,
np.uint32,
np.uint64,
np.bool_,
np.float_,
np.float16,
np.float64,
np.complex_,
np.complex64,
np.complex128,
]
except ImportError:
np = None # type: ignore

from airflow.utils.module_loading import import_string, qualname

# lazy loading for performance reasons
serializers = [
"numpy.int8",
"numpy.int16",
"numpy.int32",
"numpy.int64",
"numpy.uint8",
"numpy.uint16",
"numpy.uint32",
"numpy.uint64",
"numpy.bool_",
"numpy.float64",
"numpy.float16",
"numpy.complex128",
"numpy.complex64",
]

if TYPE_CHECKING:
from airflow.serialization.serde import U

deserializers: list = serializers
_deserializers: dict[str, type[object]] = {qualname(x): x for x in deserializers}
deserializers = serializers

__version__ = 1


def serialize(o: object) -> tuple[U, str, int, bool]:
import numpy as np

if np is None:
return "", "", 0, False

Expand Down Expand Up @@ -97,8 +86,7 @@ def deserialize(classname: str, version: int, data: str) -> Any:
if version > __version__:
raise TypeError("serialized version is newer than class version")

f = _deserializers.get(classname, None)
if callable(f):
return f(data) # type: ignore [call-arg]
if classname not in deserializers:
raise TypeError(f"unsupported {classname} found for numpy deserialization")

raise TypeError(f"unsupported {classname} found for numpy deserialization")
return import_string(classname)(data)
15 changes: 9 additions & 6 deletions airflow/serialization/serializers/timezone.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,15 @@

from typing import TYPE_CHECKING

import pendulum
from pendulum.tz.timezone import FixedTimezone, Timezone

from airflow.utils.module_loading import qualname

if TYPE_CHECKING:
from pendulum.tz.timezone import Timezone

from airflow.serialization.serde import U


serializers = [FixedTimezone, Timezone]
serializers = ["pendulum.tz.timezone.FixedTimezone", "pendulum.tz.timezone.Timezone"]
deserializers = serializers

__version__ = 1
Expand All @@ -44,6 +43,8 @@ def serialize(o: object) -> tuple[U, str, int, bool]:
0 without the special case), but passing 0 into ``pendulum.timezone`` does
not give us UTC (but ``+00:00``).
"""
from pendulum.tz.timezone import FixedTimezone, Timezone

name = qualname(o)
if isinstance(o, FixedTimezone):
if o.offset == 0:
Expand All @@ -57,13 +58,15 @@ def serialize(o: object) -> tuple[U, str, int, bool]:


def deserialize(classname: str, version: int, data: object) -> Timezone:
from pendulum.tz import fixed_timezone, timezone

if not isinstance(data, (str, int)):
raise TypeError(f"{data} is not of type int or str but of {type(data)}")

if version > __version__:
raise TypeError(f"serialized {version} of {classname} > {__version__}")

if isinstance(data, int):
return pendulum.tz.fixed_timezone(data)
return fixed_timezone(data)

return pendulum.tz.timezone(data)
return timezone(data)
Loading

0 comments on commit b749f7f

Please sign in to comment.