Skip to content

Commit

Permalink
Support driver and executor volumes in Airflow operator
Browse files Browse the repository at this point in the history
  • Loading branch information
hussein-awala committed Feb 18, 2024
1 parent 39ce270 commit 7773688
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 3 deletions.
13 changes: 13 additions & 0 deletions spark_on_k8s/airflow/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from typing import Literal

import jinja2
from kubernetes import client as k8s

from airflow.utils.context import Context
from spark_on_k8s.client import ExecutorInstances, PodResources
Expand Down Expand Up @@ -61,6 +62,9 @@ class SparkOnK8SOperator(BaseOperator):
executor_instances (ExecutorInstances, optional): Executor instances. Defaults to None.
secret_values (dict[str, str], optional): Dictionary of secret values to pass to the application
as environment variables. Defaults to None.
volumes: List of volumes to mount to the driver and/or executors.
driver_volume_mounts: List of volume mounts to mount to the driver.
executor_volume_mounts: List of volume mounts to mount to the executors.
kubernetes_conn_id (str, optional): Kubernetes connection ID. Defaults to
"kubernetes_default".
poll_interval (int, optional): Poll interval for checking the Spark application status.
Expand Down Expand Up @@ -109,6 +113,9 @@ def __init__(
executor_resources: PodResources | None = None,
executor_instances: ExecutorInstances | None = None,
secret_values: dict[str, str] | None = None,
volumes: list[k8s.V1Volume] | None = None,
driver_volume_mounts: list[k8s.V1VolumeMount] | None = None,
executor_volume_mounts: list[k8s.V1VolumeMount] | None = None,
kubernetes_conn_id: str = "kubernetes_default",
poll_interval: int = 10,
deferrable: bool = False,
Expand All @@ -131,6 +138,9 @@ def __init__(
self.executor_resources = executor_resources
self.executor_instances = executor_instances
self.secret_values = secret_values
self.volumes = volumes
self.driver_volume_mounts = driver_volume_mounts
self.executor_volume_mounts = executor_volume_mounts
self.kubernetes_conn_id = kubernetes_conn_id
self.poll_interval = poll_interval
self.deferrable = deferrable
Expand Down Expand Up @@ -217,6 +227,9 @@ def execute(self, context):
executor_resources=self.executor_resources,
executor_instances=self.executor_instances,
secret_values=self.secret_values,
volumes=self.volumes,
driver_volume_mounts=self.driver_volume_mounts,
executor_volume_mounts=self.executor_volume_mounts,
)
if self.app_waiter == "no_wait":
return
Expand Down
6 changes: 3 additions & 3 deletions spark_on_k8s/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,11 +232,11 @@ def submit_app(
driver_env_vars_from_secrets = Configuration.SPARK_ON_K8S_DRIVER_ENV_VARS_FROM_SECRET
if driver_env_vars_from_secrets:
env_from_secrets.extend(driver_env_vars_from_secrets)
if volumes is NOTSET:
if volumes is NOTSET or volumes is None:
volumes = []
if driver_volume_mounts is NOTSET:
if driver_volume_mounts is NOTSET or driver_volume_mounts is None:
driver_volume_mounts = []
if executor_volume_mounts is NOTSET:
if executor_volume_mounts is NOTSET or executor_volume_mounts is None:
executor_volume_mounts = []

spark_conf = spark_conf or {}
Expand Down
6 changes: 6 additions & 0 deletions tests/airflow/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ def test_execute(self, mock_submit_app):
spark_conf=None,
class_name=None,
secret_values=None,
volumes=None,
driver_volume_mounts=None,
executor_volume_mounts=None,
)

@mock.patch("spark_on_k8s.client.SparkOnK8S.submit_app")
Expand Down Expand Up @@ -125,4 +128,7 @@ def test_rendering_templates(self, mock_submit_app):
"KEY2": "value from connection",
},
class_name=None,
volumes=None,
driver_volume_mounts=None,
executor_volume_mounts=None,
)

0 comments on commit 7773688

Please sign in to comment.