Skip to content

Commit

Permalink
Add deferrable mode in RedshiftPauseClusterOperator (#28850)
Browse files Browse the repository at this point in the history
In this PR, I'm adding `aiobotocore` as an additional dependency until aio-libs/aiobotocore#976 has been resolved.

I'm adding `AwsBaseAsyncHook` a basic async AWS  hook. This will at present support the default `botocore` auth i'e if airflow connection is not provided then auth using ENV and if airflow connection is provided then basic auth with secret-key/access-key-id/profile/token and arn-method. maybe we can support the other auth incrementally depending on the community interest.

Because the dependency making things a little bit complicated so I have created a new test module `deferrable` inside AWS provider tests and I'm keeping the async-related test in this particular module. I have also added a new CI job to test/run the AWS deferable operator tests and I'm ignoring the deferable tests in other CI job test runs.

Add a trigger class which will wait until the Redshift pause request reaches a terminal state i.e paused or fail, We also have retry logic like the sync operator.

This PR donates the following developed RedshiftPauseClusterOperatorAsync` in [astronomer-providers](https://github.com/astronomer/astronomer-providers) repo to apache airflow.

GitOrigin-RevId: cf77c3b96609aa8c260566274d54b06eb38c8100
  • Loading branch information
pankajastro authored and Cloud Composer Team committed Jun 12, 2023
1 parent 83323ae commit 67ab460
Show file tree
Hide file tree
Showing 20 changed files with 821 additions and 18 deletions.
24 changes: 24 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -799,6 +799,30 @@ jobs:
run: breeze ci fix-ownership
if: always()

tests-aws-async-provider:
timeout-minutes: 50
name: "Pytest for AWS Async Provider"
runs-on: "${{needs.build-info.outputs.runs-on}}"
needs: [build-info, wait-for-ci-images]
if: needs.build-info.outputs.run-tests == 'true'
steps:
- name: Cleanup repo
shell: bash
run: docker run -v "${GITHUB_WORKSPACE}:/workspace" -u 0:0 bash -c "rm -rf /workspace/*"
- name: "Checkout ${{ github.ref }} ( ${{ github.sha }} )"
uses: actions/checkout@v3
with:
persist-credentials: false
- name: "Prepare breeze & CI image"
uses: ./.github/actions/prepare_breeze_and_image
- name: "Run AWS Async Test"
run: "breeze shell \
'pip install aiobotocore>=2.1.1 && pytest /opt/airflow/tests/providers/amazon/aws/deferrable'"
- name: "Post Tests"
uses: ./.github/actions/post_tests
- name: "Fix ownership"
run: breeze ci fix-ownership
if: always()

tests-helm:
timeout-minutes: 80
Expand Down
131 changes: 129 additions & 2 deletions airflow/providers/amazon/aws/hooks/base_aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,10 @@

from airflow.compat.functools import cached_property
from airflow.configuration import conf
from airflow.exceptions import AirflowException, AirflowNotFoundException
from airflow.exceptions import (
AirflowException,
AirflowNotFoundException,
)
from airflow.hooks.base import BaseHook
from airflow.providers.amazon.aws.utils.connection_wrapper import AwsConnectionWrapper
from airflow.providers.amazon.aws.waiters.base_waiter import BaseBotoWaiter
Expand All @@ -62,7 +65,6 @@

BaseAwsConnection = TypeVar("BaseAwsConnection", bound=Union[boto3.client, boto3.resource])


if TYPE_CHECKING:
from airflow.models.connection import Connection # Avoid circular imports.

Expand Down Expand Up @@ -877,3 +879,128 @@ def _parse_s3_config(config_file_name: str, config_format: str | None = "boto",
config_format=config_format,
profile=profile,
)


try:
import aiobotocore.credentials
from aiobotocore.session import AioSession, get_session
except ImportError:
pass


class BaseAsyncSessionFactory(BaseSessionFactory):
"""
Base AWS Session Factory class to handle aiobotocore session creation.
It currently, handles ENV, AWS secret key and STS client method ``assume_role``
provided in Airflow connection
"""

async def get_role_credentials(self) -> dict:
"""Get the role_arn, method credentials from connection details and get the role credentials detail"""
async with self._basic_session.create_client("sts", region_name=self.region_name) as client:
response = await client.assume_role(
RoleArn=self.role_arn,
RoleSessionName=self._strip_invalid_session_name_characters(f"Airflow_{self.conn.conn_id}"),
**self.conn.assume_role_kwargs,
)
return response["Credentials"]

async def _get_refresh_credentials(self) -> dict[str, Any]:
self.log.debug("Refreshing credentials")
assume_role_method = self.conn.assume_role_method
if assume_role_method != "assume_role":
raise NotImplementedError(f"assume_role_method={assume_role_method} not expected")

credentials = await self.get_role_credentials()

expiry_time = credentials["Expiration"].isoformat()
self.log.debug("New credentials expiry_time: %s", expiry_time)
credentials = {
"access_key": credentials.get("AccessKeyId"),
"secret_key": credentials.get("SecretAccessKey"),
"token": credentials.get("SessionToken"),
"expiry_time": expiry_time,
}
return credentials

def _get_session_with_assume_role(self) -> AioSession:

assume_role_method = self.conn.assume_role_method
if assume_role_method != "assume_role":
raise NotImplementedError(f"assume_role_method={assume_role_method} not expected")

credentials = aiobotocore.credentials.AioRefreshableCredentials.create_from_metadata(
metadata=self._get_refresh_credentials(),
refresh_using=self._get_refresh_credentials,
method="sts-assume-role",
)

session = aiobotocore.session.get_session()
session._credentials = credentials
return session

@cached_property
def _basic_session(self) -> AioSession:
"""Cached property with basic aiobotocore.session.AioSession."""
session_kwargs = self.conn.session_kwargs
aws_access_key_id = session_kwargs.get("aws_access_key_id")
aws_secret_access_key = session_kwargs.get("aws_secret_access_key")
aws_session_token = session_kwargs.get("aws_session_token")
region_name = session_kwargs.get("region_name")
profile_name = session_kwargs.get("profile_name")

aio_session = get_session()
if profile_name is not None:
aio_session.set_config_variable("profile", profile_name)
if aws_access_key_id or aws_secret_access_key or aws_session_token:
aio_session.set_credentials(
access_key=aws_access_key_id,
secret_key=aws_secret_access_key,
token=aws_session_token,
)
if region_name is not None:
aio_session.set_config_variable("region", region_name)
return aio_session

def create_session(self) -> AioSession:
"""Create aiobotocore Session from connection and config."""
if not self._conn:
self.log.info("No connection ID provided. Fallback on boto3 credential strategy")
return get_session()
elif not self.role_arn:
return self._basic_session
return self._get_session_with_assume_role()


class AwsBaseAsyncHook(AwsBaseHook):
"""
Interacts with AWS using aiobotocore asynchronously.
:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is None or empty then the default botocore behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
empty, then default botocore configuration would be used (and must be
maintained on each worker node).
:param verify: Whether to verify SSL certificates.
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:param client_type: boto3.client client_type. Eg 's3', 'emr' etc
:param resource_type: boto3.resource resource_type. Eg 'dynamodb' etc
:param config: Configuration for botocore client.
"""

def get_async_session(self) -> AioSession:
"""Get the underlying aiobotocore.session.AioSession(...)."""
return BaseAsyncSessionFactory(
conn=self.conn_config, region_name=self.region_name, config=self.config
).create_session()

async def get_client_async(self):
"""Get the underlying aiobotocore client using aiobotocore session"""
return self.get_async_session().create_client(
self.client_type,
region_name=self.region_name,
verify=self.verify,
endpoint_url=self.conn_config.endpoint_url,
config=self.config,
)
83 changes: 82 additions & 1 deletion airflow/providers/amazon/aws/hooks/redshift_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@
# under the License.
from __future__ import annotations

import asyncio
import warnings
from typing import Any, Sequence

import botocore.exceptions
from botocore.exceptions import ClientError

from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseAsyncHook, AwsBaseHook


class RedshiftHook(AwsBaseHook):
Expand Down Expand Up @@ -200,3 +202,82 @@ def get_cluster_snapshot_status(self, snapshot_identifier: str, cluster_identifi
return snapshot_status
except self.get_conn().exceptions.ClusterSnapshotNotFoundFault:
return None


class RedshiftAsyncHook(AwsBaseAsyncHook):
"""Interact with AWS Redshift using aiobotocore library"""

def __init__(self, *args: Any, **kwargs: Any) -> None:
kwargs["client_type"] = "redshift"
super().__init__(*args, **kwargs)

async def cluster_status(self, cluster_identifier: str, delete_operation: bool = False) -> dict[str, Any]:
"""
Connects to the AWS redshift cluster via aiobotocore and get the status
and returns the status of the cluster based on the cluster_identifier passed
:param cluster_identifier: unique identifier of a cluster
:param delete_operation: whether the method has been called as part of delete cluster operation
"""
async with await self.get_client_async() as client:
try:
response = await client.describe_clusters(ClusterIdentifier=cluster_identifier)
cluster_state = (
response["Clusters"][0]["ClusterStatus"] if response and response["Clusters"] else None
)
return {"status": "success", "cluster_state": cluster_state}
except botocore.exceptions.ClientError as error:
if delete_operation and error.response.get("Error", {}).get("Code", "") == "ClusterNotFound":
return {"status": "success", "cluster_state": "cluster_not_found"}
return {"status": "error", "message": str(error)}

async def pause_cluster(self, cluster_identifier: str, poll_interval: float = 5.0) -> dict[str, Any]:
"""
Connects to the AWS redshift cluster via aiobotocore and
pause the cluster based on the cluster_identifier passed
:param cluster_identifier: unique identifier of a cluster
:param poll_interval: polling period in seconds to check for the status
"""
try:
async with await self.get_client_async() as client:
response = await client.pause_cluster(ClusterIdentifier=cluster_identifier)
status = response["Cluster"]["ClusterStatus"] if response and response["Cluster"] else None
if status == "pausing":
flag = asyncio.Event()
while True:
expected_response = await asyncio.create_task(
self.get_cluster_status(cluster_identifier, "paused", flag)
)
await asyncio.sleep(poll_interval)
if flag.is_set():
return expected_response
return {"status": "error", "cluster_state": status}
except botocore.exceptions.ClientError as error:
return {"status": "error", "message": str(error)}

async def get_cluster_status(
self,
cluster_identifier: str,
expected_state: str,
flag: asyncio.Event,
delete_operation: bool = False,
) -> dict[str, Any]:
"""
check for expected Redshift cluster state
:param cluster_identifier: unique identifier of a cluster
:param expected_state: expected_state example("available", "pausing", "paused"")
:param flag: asyncio even flag set true if success and if any error
:param delete_operation: whether the method has been called as part of delete cluster operation
"""
try:
response = await self.cluster_status(cluster_identifier, delete_operation=delete_operation)
if ("cluster_state" in response and response["cluster_state"] == expected_state) or response[
"status"
] == "error":
flag.set()
return response
except botocore.exceptions.ClientError as error:
flag.set()
return {"status": "error", "message": str(error)}
60 changes: 48 additions & 12 deletions airflow/providers/amazon/aws/operators/redshift_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftHook
from airflow.providers.amazon.aws.triggers.redshift_cluster import RedshiftClusterTrigger

if TYPE_CHECKING:
from airflow.utils.context import Context
Expand Down Expand Up @@ -447,6 +448,7 @@ class RedshiftPauseClusterOperator(BaseOperator):
:param cluster_identifier: id of the AWS Redshift Cluster
:param aws_conn_id: aws connection to use
:param deferrable: Run operator in the deferrable mode. This mode requires an additional aiobotocore>=
"""

template_fields: Sequence[str] = ("cluster_identifier",)
Expand All @@ -458,11 +460,15 @@ def __init__(
*,
cluster_identifier: str,
aws_conn_id: str = "aws_default",
deferrable: bool = False,
poll_interval: int = 10,
**kwargs,
):
super().__init__(**kwargs)
self.cluster_identifier = cluster_identifier
self.aws_conn_id = aws_conn_id
self.deferrable = deferrable
self.poll_interval = poll_interval
# These parameters are added to address an issue with the boto3 API where the API
# prematurely reports the cluster as available to receive requests. This causes the cluster
# to reject initial attempts to pause the cluster despite reporting the correct state.
Expand All @@ -472,18 +478,48 @@ def __init__(
def execute(self, context: Context):
redshift_hook = RedshiftHook(aws_conn_id=self.aws_conn_id)

while self._attempts >= 1:
try:
redshift_hook.get_conn().pause_cluster(ClusterIdentifier=self.cluster_identifier)
return
except redshift_hook.get_conn().exceptions.InvalidClusterStateFault as error:
self._attempts = self._attempts - 1

if self._attempts > 0:
self.log.error("Unable to pause cluster. %d attempts remaining.", self._attempts)
time.sleep(self._attempt_interval)
else:
raise error
if self.deferrable:
self.defer(
timeout=self.execution_timeout,
trigger=RedshiftClusterTrigger(
task_id=self.task_id,
poll_interval=self.poll_interval,
aws_conn_id=self.aws_conn_id,
cluster_identifier=self.cluster_identifier,
attempts=self._attempts,
operation_type="pause_cluster",
),
method_name="execute_complete",
)
else:
while self._attempts >= 1:
try:
redshift_hook.get_conn().pause_cluster(ClusterIdentifier=self.cluster_identifier)
return
except redshift_hook.get_conn().exceptions.InvalidClusterStateFault as error:
self._attempts = self._attempts - 1

if self._attempts > 0:
self.log.error("Unable to pause cluster. %d attempts remaining.", self._attempts)
time.sleep(self._attempt_interval)
else:
raise error

def execute_complete(self, context: Context, event: Any = None) -> None:
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
successful.
"""
if event:
if "status" in event and event["status"] == "error":
msg = f"{event['status']}: {event['message']}"
raise AirflowException(msg)
elif "status" in event and event["status"] == "success":
self.log.info("%s completed successfully.", self.task_id)
self.log.info("Paused cluster successfully")
else:
raise AirflowException("No event received from trigger")


class RedshiftDeleteClusterOperator(BaseOperator):
Expand Down
17 changes: 17 additions & 0 deletions airflow/providers/amazon/aws/triggers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
Loading

0 comments on commit 67ab460

Please sign in to comment.