Skip to content

Commit

Permalink
Add general-purpose "notifier" concept to DAGs (#28569)
Browse files Browse the repository at this point in the history
* Add general-purpose "notifier" concept to DAGs

This makes it easy for users to setup notifications for their DAGs using on_*_callbacks
It's extensible and we can add it to more providers. Implemented a SlackNotifier in this
phase.

In the course of this, I extracted a 'Templater' class from AbstractBaseOperator
and have both the Notifier & ABO inherit from it. This is necessary in other to avoid
code repetition.

* Renames and a fixup not to require a call to super in subclasses

* Raise compat exception and add docs

* Ignore import error due to optional provider feature

* fixup! Ignore import error due to optional provider feature

* Apply suggestions from code review

Co-authored-by: Tzu-ping Chung <uranusjr@gmail.com>

* fixup! Apply suggestions from code review

Co-authored-by: Tzu-ping Chung <uranusjr@gmail.com>
  • Loading branch information
ephraimbuddy and uranusjr authored Jan 4, 2023
1 parent 24af35b commit a7e1cb2
Show file tree
Hide file tree
Showing 18 changed files with 790 additions and 132 deletions.
146 changes: 14 additions & 132 deletions airflow/models/abstractoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,8 @@
from airflow.exceptions import AirflowException
from airflow.models.expandinput import NotFullyPopulated
from airflow.models.taskmixin import DAGNode
from airflow.template.templater import Templater
from airflow.utils.context import Context
from airflow.utils.helpers import render_template_as_native, render_template_to_string
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.mixins import ResolveMixin
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.sqlalchemy import skip_locked, with_row_locks
from airflow.utils.state import State, TaskInstanceState
Expand Down Expand Up @@ -76,7 +74,7 @@ class NotMapped(Exception):
"""Raise if a task is neither mapped nor has any parent mapped groups."""


class AbstractOperator(LoggingMixin, DAGNode):
class AbstractOperator(Templater, DAGNode):
"""Common implementation for operators, including unmapped and mapped.
This base class is more about sharing implementations, not defining a common
Expand All @@ -96,10 +94,6 @@ class AbstractOperator(LoggingMixin, DAGNode):

# Defines the operator level extra links.
operator_extra_links: Collection[BaseOperatorLink]
# For derived classes to define which fields will get jinjaified.
template_fields: Collection[str]
# Defines which files extensions to look for in the templated fields.
template_ext: Sequence[str]

owner: str
task_id: str
Expand Down Expand Up @@ -153,48 +147,6 @@ def dag_id(self) -> str:
def node_id(self) -> str:
return self.task_id

def get_template_env(self) -> jinja2.Environment:
"""Fetch a Jinja template environment from the DAG or instantiate empty environment if no DAG."""
# This is imported locally since Jinja2 is heavy and we don't need it
# for most of the functionalities. It is imported by get_template_env()
# though, so we don't need to put this after the 'if dag' check.
from airflow.templates import SandboxedEnvironment

dag = self.get_dag()
if dag:
return dag.get_template_env(force_sandboxed=False)
return SandboxedEnvironment(cache_size=0)

def prepare_template(self) -> None:
"""Hook triggered after the templated fields get replaced by their content.
If you need your operator to alter the content of the file before the
template is rendered, it should override this method to do so.
"""

def resolve_template_files(self) -> None:
"""Getting the content of files for template_field / template_ext."""
if self.template_ext:
for field in self.template_fields:
content = getattr(self, field, None)
if content is None:
continue
elif isinstance(content, str) and any(content.endswith(ext) for ext in self.template_ext):
env = self.get_template_env()
try:
setattr(self, field, env.loader.get_source(env, content)[0]) # type: ignore
except Exception:
self.log.exception("Failed to resolve template field %r", field)
elif isinstance(content, list):
env = self.get_template_env()
for i, item in enumerate(content):
if isinstance(item, str) and any(item.endswith(ext) for ext in self.template_ext):
try:
content[i] = env.loader.get_source(env, item)[0] # type: ignore
except Exception:
self.log.exception("Failed to get source %s", item)
self.prepare_template()

def get_direct_relative_ids(self, upstream: bool = False) -> set[str]:
"""Get direct relative IDs to the current task, upstream or downstream."""
if upstream:
Expand Down Expand Up @@ -580,6 +532,17 @@ def render_template_fields(
"""
raise NotImplementedError()

def _render(self, template, context, dag: DAG | None = None):
if dag is None:
dag = self.get_dag()
return super()._render(template, context, dag=dag)

def get_template_env(self, dag: DAG | None = None) -> jinja2.Environment:
"""Get the template environment for rendering templates."""
if dag is None:
dag = self.get_dag()
return super().get_template_env(dag=dag)

@provide_session
def _do_render_template_fields(
self,
Expand All @@ -591,6 +554,7 @@ def _do_render_template_fields(
*,
session: Session = NEW_SESSION,
) -> None:
"""Override the base to use custom error logging."""
for attr_name in template_fields:
try:
value = getattr(parent, attr_name)
Expand Down Expand Up @@ -618,85 +582,3 @@ def _do_render_template_fields(
raise
else:
setattr(parent, attr_name, rendered_content)

def render_template(
self,
content: Any,
context: Context,
jinja_env: jinja2.Environment | None = None,
seen_oids: set[int] | None = None,
) -> Any:
"""Render a templated string.
If *content* is a collection holding multiple templated strings, strings
in the collection will be templated recursively.
:param content: Content to template. Only strings can be templated (may
be inside a collection).
:param context: Dict with values to apply on templated content
:param jinja_env: Jinja environment. Can be provided to avoid
re-creating Jinja environments during recursion.
:param seen_oids: template fields already rendered (to avoid
*RecursionError* on circular dependencies)
:return: Templated content
"""
# "content" is a bad name, but we're stuck to it being public API.
value = content
del content

if seen_oids is not None:
oids = seen_oids
else:
oids = set()

if id(value) in oids:
return value

if not jinja_env:
jinja_env = self.get_template_env()

if isinstance(value, str):
if any(value.endswith(ext) for ext in self.template_ext): # A filepath.
template = jinja_env.get_template(value)
else:
template = jinja_env.from_string(value)
dag = self.get_dag()
if dag and dag.render_template_as_native_obj:
return render_template_as_native(template, context)
return render_template_to_string(template, context)

if isinstance(value, ResolveMixin):
return value.resolve(context)

# Fast path for common built-in collections.
if value.__class__ is tuple:
return tuple(self.render_template(element, context, jinja_env, oids) for element in value)
elif isinstance(value, tuple): # Special case for named tuples.
return value.__class__(*(self.render_template(el, context, jinja_env, oids) for el in value))
elif isinstance(value, list):
return [self.render_template(element, context, jinja_env, oids) for element in value]
elif isinstance(value, dict):
return {k: self.render_template(v, context, jinja_env, oids) for k, v in value.items()}
elif isinstance(value, set):
return {self.render_template(element, context, jinja_env, oids) for element in value}

# More complex collections.
self._render_nested_template_fields(value, context, jinja_env, oids)
return value

def _render_nested_template_fields(
self,
value: Any,
context: Context,
jinja_env: jinja2.Environment,
seen_oids: set[int],
) -> None:
if id(value) in seen_oids:
return
seen_oids.add(id(value))
try:
nested_template_fields = value.template_fields
except AttributeError:
# content has no inner template fields
return
self._do_render_template_fields(value, nested_template_fields, context, jinja_env, seen_oids)
16 changes: 16 additions & 0 deletions airflow/notifications/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
93 changes: 93 additions & 0 deletions airflow/notifications/basenotifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from abc import abstractmethod
from typing import TYPE_CHECKING, Sequence

import jinja2

from airflow.template.templater import Templater
from airflow.utils.context import Context, context_merge

if TYPE_CHECKING:
from airflow import DAG


class BaseNotifier(Templater):
"""BaseNotifier class for sending notifications"""

template_fields: Sequence[str] = ()
template_ext: Sequence[str] = ()

def __init__(self):
super().__init__()
self.resolve_template_files()

def _update_context(self, context: Context) -> Context:
"""
Add additional context to the context
:param context: The airflow context
"""
additional_context = ((f, getattr(self, f)) for f in self.template_fields)
context_merge(context, additional_context)
return context

def _render(self, template, context, dag: DAG | None = None):
dag = dag or context["dag"]
return super()._render(template, context, dag)

def render_template_fields(
self,
context: Context,
jinja_env: jinja2.Environment | None = None,
) -> None:
"""Template all attributes listed in *self.template_fields*.
This mutates the attributes in-place and is irreversible.
:param context: Context dict with values to apply on content.
:param jinja_env: Jinja environment to use for rendering.
"""
dag = context["dag"]
if not jinja_env:
jinja_env = self.get_template_env(dag=dag)
self._do_render_template_fields(self, self.template_fields, context, jinja_env, set())

@abstractmethod
def notify(self, context: Context) -> None:
"""
Sends a notification
:param context: The airflow context
"""
...

def __call__(self, context: Context) -> None:
"""
Send a notification
:param context: The airflow context
"""
context = self._update_context(context)
self.render_template_fields(context)
try:
self.notify(context)
except Exception as e:
self.log.exception("Failed to send notification: %s", e)
1 change: 1 addition & 0 deletions airflow/providers/slack/CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ Features
~~~~~~~~

* ``Implements SqlToSlackApiFileOperator (#26374)``
* ``Added SlackNotifier (#28569)``

Bug Fixes
~~~~~~~~~
Expand Down
16 changes: 16 additions & 0 deletions airflow/providers/slack/notifications/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
Loading

0 comments on commit a7e1cb2

Please sign in to comment.