diff --git a/CHANGES.md b/CHANGES.md index 57d729ac5f5d..1b0e6ec6f0df 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -77,6 +77,7 @@ * Redis cache support added to RequestResponseIO and Enrichment transform (Python) ([#30307](https://github.com/apache/beam/pull/30307)) * Merged sdks/java/fn-execution and runners/core-construction-java into the main SDK. These artifacts were never meant for users, but noting that they no longer exist. These are steps to bring portability into the core SDK alongside all other core functionality. +* Added Vertex AI Feature Store handler for Enrichment transform (Python) ([#30388](https://github.com/apache/beam/pull/30388)) ## Breaking Changes diff --git a/sdks/python/apache_beam/transforms/enrichment.py b/sdks/python/apache_beam/transforms/enrichment.py index 93344835e930..ddfbba5337fb 100644 --- a/sdks/python/apache_beam/transforms/enrichment.py +++ b/sdks/python/apache_beam/transforms/enrichment.py @@ -161,8 +161,10 @@ def expand(self, throttler=self._throttler) # EnrichmentSourceHandler returns a tuple of (request,response). - return fetched_data | beam.Map( - lambda x: self._join_fn(x[0]._asdict(), x[1]._asdict())) + return ( + fetched_data + | "enrichment_join" >> + beam.Map(lambda x: self._join_fn(x[0]._asdict(), x[1]._asdict()))) def with_redis_cache( self, diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable.py b/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable.py index 943000a9f6bb..af35f91a42f3 100644 --- a/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable.py +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/bigtable.py @@ -15,7 +15,6 @@ # limitations under the License. # import logging -from enum import Enum from typing import Any from typing import Dict from typing import Optional @@ -28,30 +27,15 @@ import apache_beam as beam from apache_beam.transforms.enrichment import EnrichmentSourceHandler +from apache_beam.transforms.enrichment_handlers.utils import ExceptionLevel __all__ = [ 'BigTableEnrichmentHandler', - 'ExceptionLevel', ] _LOGGER = logging.getLogger(__name__) -class ExceptionLevel(Enum): - """ExceptionLevel defines the exception level options to either - log a warning, or raise an exception, or do nothing when a BigTable query - returns an empty row. - - Members: - - RAISE: Raise the exception. - - WARN: Log a warning for exception without raising it. - - QUIET: Neither log nor raise the exception. - """ - RAISE = 0 - WARN = 1 - QUIET = 2 - - class BigTableEnrichmentHandler(EnrichmentSourceHandler[beam.Row, beam.Row]): """A handler for :class:`apache_beam.transforms.enrichment.Enrichment` transform to interact with GCP BigTable. @@ -70,7 +54,7 @@ class BigTableEnrichmentHandler(EnrichmentSourceHandler[beam.Row, beam.Row]): encoding (str): encoding type to convert the string to bytes and vice-versa from BigTable. Default is `utf-8`. exception_level: a `enum.Enum` value from - ``apache_beam.transforms.enrichment_handlers.bigtable.ExceptionLevel`` + ``apache_beam.transforms.enrichment_handlers.utils.ExceptionLevel`` to set the level when an empty row is returned from the BigTable query. Defaults to ``ExceptionLevel.WARN``. include_timestamp (bool): If enabled, the timestamp associated with the diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/utils.py b/sdks/python/apache_beam/transforms/enrichment_handlers/utils.py new file mode 100644 index 000000000000..c61671402576 --- /dev/null +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/utils.py @@ -0,0 +1,38 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from enum import Enum + +__all__ = [ + 'ExceptionLevel', +] + + +class ExceptionLevel(Enum): + """Options to set the severity of exceptions. + + You can set the exception level option to either + log a warning, or raise an exception, or do nothing when an empty row + is fetched from the external service. + + Members: + - RAISE: Raise the exception. + - WARN: Log a warning for exception without raising it. + - QUIET: Neither log nor raise the exception. + """ + RAISE = 0 + WARN = 1 + QUIET = 2 diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/vertex_ai_feature_store.py b/sdks/python/apache_beam/transforms/enrichment_handlers/vertex_ai_feature_store.py new file mode 100644 index 000000000000..b135739ef59c --- /dev/null +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/vertex_ai_feature_store.py @@ -0,0 +1,306 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import logging +from typing import List + +import proto +from google.api_core.exceptions import NotFound +from google.cloud import aiplatform + +import apache_beam as beam +from apache_beam.transforms.enrichment import EnrichmentSourceHandler +from apache_beam.transforms.enrichment_handlers.utils import ExceptionLevel + +__all__ = [ + 'VertexAIFeatureStoreEnrichmentHandler', + 'VertexAIFeatureStoreLegacyEnrichmentHandler' +] + +_LOGGER = logging.getLogger(__name__) + + +def _not_found_err_message( + feature_store_name: str, feature_view_name: str, entity_id: str) -> str: + """returns a string formatted with given parameters""" + return ( + "make sure the Feature Store: %s with Feature View " + "%s has entity_id: %s" % + (feature_store_name, feature_view_name, entity_id)) + + +class VertexAIFeatureStoreEnrichmentHandler(EnrichmentSourceHandler[beam.Row, + beam.Row]): + """Enrichment handler to interact with Vertex AI Feature Store. + + Use this handler with :class:`apache_beam.transforms.enrichment.Enrichment` + transform when the Vertex AI Feature Store is set up for + Bigtable Online serving. + + With the Bigtable Online serving approach, the client fetches all the + available features for an entity-id. The entity-id is extracted from the + `row_key` field in the input `beam.Row` object. To filter the features to + enrich, use the `join_fn` param in + :class:`apache_beam.transforms.enrichment.Enrichment`. + + **NOTE:** The default severity to report exceptions is logging a warning. For + this handler, Vertex AI client returns the same exception + `Requested entity was not found` even though the feature store doesn't + exist. So make sure the feature store instance exists or set + `exception_level` as `ExceptionLevel.RAISE`. + """ + def __init__( + self, + project: str, + location: str, + api_endpoint: str, + feature_store_name: str, + feature_view_name: str, + row_key: str, + *, + exception_level: ExceptionLevel = ExceptionLevel.WARN, + **kwargs, + ): + """Initializes an instance of `VertexAIFeatureStoreEnrichmentHandler`. + + Args: + project (str): The GCP project-id for the Vertex AI Feature Store. + location (str): The region for the Vertex AI Feature Store. + api_endpoint (str): The API endpoint for the Vertex AI Feature Store. + feature_store_name (str): The name of the Vertex AI Feature Store. + feature_view_name (str): The name of the feature view within the + Feature Store. + row_key (str): The row key field name containing the unique id + for the feature values. + exception_level: a `enum.Enum` value from + `apache_beam.transforms.enrichment_handlers.utils.ExceptionLevel` + to set the level when an empty row is returned from the BigTable query. + Defaults to `ExceptionLevel.WARN`. + kwargs: Optional keyword arguments to configure the + `aiplatform.gapic.FeatureOnlineStoreServiceClient`. + """ + self.project = project + self.location = location + self.api_endpoint = api_endpoint + self.feature_store_name = feature_store_name + self.feature_view_name = feature_view_name + self.row_key = row_key + self.exception_level = exception_level + self.kwargs = kwargs if kwargs else {} + if 'client_options' in self.kwargs: + if not self.kwargs['client_options']['api_endpoint']: + self.kwargs['client_options']['api_endpoint'] = self.api_endpoint + elif self.kwargs['client_options']['api_endpoint'] != self.api_endpoint: + raise ValueError( + 'Multiple values received for api_endpoint in ' + 'api_endpoint and client_options parameters.') + else: + self.kwargs['client_options'] = {"api_endpoint": self.api_endpoint} + + def __enter__(self): + """Connect with the Vertex AI Feature Store.""" + self.client = aiplatform.gapic.FeatureOnlineStoreServiceClient( + **self.kwargs) + self.feature_view_path = self.client.feature_view_path( + self.project, + self.location, + self.feature_store_name, + self.feature_view_name) + + def __call__(self, request: beam.Row, *args, **kwargs): + """Fetches feature value for an entity-id from Vertex AI Feature Store. + + Args: + request: the input `beam.Row` to enrich. + """ + try: + entity_id = request._asdict()[self.row_key] + except KeyError: + raise KeyError( + "Enrichment requests to Vertex AI Feature Store should " + "contain a field: %s in the input `beam.Row` to join " + "the input with fetched response. This is used as the " + "`FeatureViewDataKey` to fetch feature values " + "corresponding to this key." % self.row_key) + try: + response = self.client.fetch_feature_values( + request=aiplatform.gapic.FetchFeatureValuesRequest( + data_key=aiplatform.gapic.FeatureViewDataKey(key=entity_id), + feature_view=self.feature_view_path, + data_format=aiplatform.gapic.FeatureViewDataFormat.PROTO_STRUCT, + )) + except NotFound: + if self.exception_level == ExceptionLevel.WARN: + _LOGGER.warning( + _not_found_err_message( + self.feature_store_name, self.feature_view_name, entity_id)) + return request, beam.Row() + elif self.exception_level == ExceptionLevel.RAISE: + raise ValueError( + _not_found_err_message( + self.feature_store_name, self.feature_view_name, entity_id)) + response_dict = dict(response.proto_struct) + return request, beam.Row(**response_dict) + + def __exit__(self, exc_type, exc_val, exc_tb): + """Clean the instantiated Vertex AI client.""" + self.client = None + + def get_cache_key(self, request: beam.Row) -> str: + """Returns a string formatted with unique entity-id for the feature values. + """ + return 'entity_id: %s' % request._asdict()[self.row_key] + + +class VertexAIFeatureStoreLegacyEnrichmentHandler(EnrichmentSourceHandler): + """Enrichment handler to interact with Vertex AI Feature Store (Legacy). + + Use this handler with :class:`apache_beam.transforms.enrichment.Enrichment` + transform for the Vertex AI Feature Store (Legacy). + + By default, it fetches all the features values for an entity-id. The + entity-id is extracted from the `row_key` field in the input `beam.Row` + object.You can specify the features names using `feature_ids` to fetch + specific features. + """ + def __init__( + self, + project: str, + location: str, + api_endpoint: str, + feature_store_id: str, + entity_type_id: str, + feature_ids: List[str], + row_key: str, + *, + exception_level: ExceptionLevel = ExceptionLevel.WARN, + **kwargs, + ): + """Initializes an instance of `VertexAIFeatureStoreLegacyEnrichmentHandler`. + + Args: + project (str): The GCP project for the Vertex AI Feature Store (Legacy). + location (str): The region for the Vertex AI Feature Store (Legacy). + api_endpoint (str): The API endpoint for the + Vertex AI Feature Store (Legacy). + feature_store_id (str): The id of the Vertex AI Feature Store (Legacy). + entity_type_id (str): The entity type of the feature store. + feature_ids (List[str]): A list of feature-ids to fetch + from the Feature Store. + row_key (str): The row key field name containing the entity id + for the feature values. + exception_level: a `enum.Enum` value from + `apache_beam.transforms.enrichment_handlers.utils.ExceptionLevel` + to set the level when an empty row is returned from the BigTable query. + Defaults to `ExceptionLevel.WARN`. + kwargs: Optional keyword arguments to configure the + `aiplatform.gapic.FeaturestoreOnlineServingServiceClient`. + """ + self.project = project + self.location = location + self.api_endpoint = api_endpoint + self.feature_store_id = feature_store_id + self.entity_type_id = entity_type_id + self.feature_ids = feature_ids + self.row_key = row_key + self.exception_level = exception_level + self.kwargs = kwargs if kwargs else {} + if 'client_options' in self.kwargs: + if not self.kwargs['client_options']['api_endpoint']: + self.kwargs['client_options']['api_endpoint'] = self.api_endpoint + elif self.kwargs['client_options']['api_endpoint'] != self.api_endpoint: + raise ValueError( + 'Multiple values received for api_endpoint in ' + 'api_endpoint and client_options parameters.') + else: + self.kwargs['client_options'] = {"api_endpoint": self.api_endpoint} + + def __enter__(self): + """Connect with the Vertex AI Feature Store (Legacy).""" + try: + # checks if feature store exists + _ = aiplatform.Featurestore( + featurestore_name=self.feature_store_id, + project=self.project, + location=self.location, + credentials=self.kwargs.get('credentials'), + ) + self.client = aiplatform.gapic.FeaturestoreOnlineServingServiceClient( + **self.kwargs) + self.entity_type_path = self.client.entity_type_path( + self.project, + self.location, + self.feature_store_id, + self.entity_type_id) + except NotFound: + raise ValueError( + 'Vertex AI Feature Store %s does not exist' % self.feature_store_id) + + def __call__(self, request: beam.Row, *args, **kwargs): + """Fetches feature value for an entity-id from + Vertex AI Feature Store (Legacy). + + Args: + request: the input `beam.Row` to enrich. + """ + try: + entity_id = request._asdict()[self.row_key] + except KeyError: + raise KeyError( + "Enrichment requests to Vertex AI Feature Store should " + "contain a field: %s in the input `beam.Row` to join " + "the input with fetched response. This is used as the " + "`FeatureViewDataKey` to fetch feature values " + "corresponding to this key." % self.row_key) + + try: + selector = aiplatform.gapic.FeatureSelector( + id_matcher=aiplatform.gapic.IdMatcher(ids=self.feature_ids)) + response = self.client.read_feature_values( + request=aiplatform.gapic.ReadFeatureValuesRequest( + entity_type=self.entity_type_path, + entity_id=entity_id, + feature_selector=selector)) + except NotFound: + raise ValueError( + _not_found_err_message( + self.feature_store_id, self.entity_type_id, entity_id)) + + response_dict = {} + proto_to_dict = proto.Message.to_dict(response.entity_view) + for key, msg in zip(response.header.feature_descriptors, + proto_to_dict['data']): + if msg and 'value' in msg: + response_dict[key.id] = list(msg['value'].values())[0] + # skip fetching the metadata + elif self.exception_level == ExceptionLevel.RAISE: + raise ValueError( + _not_found_err_message( + self.feature_store_id, self.entity_type_id, entity_id)) + elif self.exception_level == ExceptionLevel.WARN: + _LOGGER.warning( + _not_found_err_message( + self.feature_store_id, self.entity_type_id, entity_id)) + return request, beam.Row(**response_dict) + + def __exit__(self, exc_type, exc_val, exc_tb): + """Clean the instantiated Vertex AI client.""" + self.client = None + + def get_cache_key(self, request: beam.Row) -> str: + """Returns a string formatted with unique entity-id for the feature values. + """ + return 'entity_id: %s' % request._asdict()[self.row_key] diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/vertex_ai_feature_store_it_test.py b/sdks/python/apache_beam/transforms/enrichment_handlers/vertex_ai_feature_store_it_test.py new file mode 100644 index 000000000000..d4224be060e9 --- /dev/null +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/vertex_ai_feature_store_it_test.py @@ -0,0 +1,287 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import logging +import unittest +from unittest.mock import MagicMock + +import pytest + +import apache_beam as beam +from apache_beam.coders import coders +from apache_beam.testing.test_pipeline import TestPipeline +from apache_beam.testing.util import BeamAssertException + +# pylint: disable=ungrouped-imports +try: + from testcontainers.redis import RedisContainer + from apache_beam.transforms.enrichment import Enrichment + from apache_beam.transforms.enrichment_handlers.utils import ExceptionLevel + from apache_beam.transforms.enrichment_handlers.vertex_ai_feature_store import \ + VertexAIFeatureStoreEnrichmentHandler + from apache_beam.transforms.enrichment_handlers.vertex_ai_feature_store import \ + VertexAIFeatureStoreLegacyEnrichmentHandler +except ImportError: + raise unittest.SkipTest( + 'VertexAI Feature Store test dependencies ' + 'are not installed.') + +_LOGGER = logging.getLogger(__name__) + + +class ValidateResponse(beam.DoFn): + """ValidateResponse validates if a PCollection of `beam.Row` + has the required fields.""" + def __init__(self, expected_fields): + self.expected_fields = expected_fields + + def process(self, element: beam.Row, *args, **kwargs): + element_dict = element.as_dict() + if len(self.expected_fields) != len(element_dict.keys()): + raise BeamAssertException( + "Expected %d fields in enriched PCollection:" % + len(self.expected_fields)) + for field in self.expected_fields: + if field not in element_dict: + raise BeamAssertException( + f"Expected to fetch field: {field}" + f"from feature store") + + +@pytest.mark.uses_redis +class TestVertexAIFeatureStoreHandler(unittest.TestCase): + def setUp(self) -> None: + self.project = 'apache-beam-testing' + self.location = 'us-central1' + self.feature_store_name = "the_look_demo_unique" + self.feature_view_name = "registry_product" + self.entity_type_name = "entity_id" + self.api_endpoint = "us-central1-aiplatform.googleapis.com" + self.feature_ids = ['title', 'genres'] + + self._start_container() + + def _start_container(self): + for i in range(3): + try: + self.container = RedisContainer(image='redis:7.2.4') + self.container.start() + self.host = self.container.get_container_host_ip() + self.port = self.container.get_exposed_port(6379) + self.client = self.container.get_client() + break + except Exception as e: + if i == self.retries - 1: + _LOGGER.error('Unable to start redis container for RRIO tests.') + raise e + + def tearDown(self) -> None: + self.container.stop() + self.client = None + + def test_vertex_ai_feature_store_bigtable_serving_enrichment(self): + requests = [ + beam.Row(entity_id="847", name='cardigan jacket'), + beam.Row(entity_id="16050", name='stripe t-shirt'), + ] + expected_fields = [ + 'entity_id', + 'bad_order_count', + 'good_order_count', + 'feature_timestamp', + 'category', + 'cost', + 'brand', + 'retail_price', + 'name' + ] + handler = VertexAIFeatureStoreEnrichmentHandler( + project=self.project, + location=self.location, + api_endpoint=self.api_endpoint, + feature_store_name=self.feature_store_name, + feature_view_name=self.feature_view_name, + row_key=self.entity_type_name, + ) + + with TestPipeline(is_integration_test=True) as test_pipeline: + _ = ( + test_pipeline + | beam.Create(requests) + | Enrichment(handler) + | beam.ParDo(ValidateResponse(expected_fields))) + + def test_vertex_ai_feature_store_bigtable_serving_enrichment_bad(self): + requests = [ + beam.Row(entity_id="ui", name="fred perry men\'s sharp stripe t-shirt") + ] + handler = VertexAIFeatureStoreEnrichmentHandler( + project=self.project, + location=self.location, + api_endpoint=self.api_endpoint, + feature_store_name=self.feature_store_name, + feature_view_name=self.feature_view_name, + row_key=self.entity_type_name, + exception_level=ExceptionLevel.RAISE, + ) + with self.assertRaises(ValueError): + test_pipeline = beam.Pipeline() + _ = ( + test_pipeline + | "Create" >> beam.Create(requests) + | "Enrich w/ VertexAI" >> Enrichment(handler)) + res = test_pipeline.run() + res.wait_until_finish() + + def test_vertex_ai_legacy_feature_store_enrichment(self): + requests = [ + beam.Row(entity_id="movie_02", title="The Shining"), + beam.Row(entity_id="movie_04", title='The Dark Knight'), + ] + expected_fields = ['entity_id', 'title', 'genres'] + feature_store_id = "movie_prediction_unique" + entity_type_id = "movies" + handler = VertexAIFeatureStoreLegacyEnrichmentHandler( + project=self.project, + location=self.location, + api_endpoint=self.api_endpoint, + feature_store_id=feature_store_id, + entity_type_id=entity_type_id, + feature_ids=self.feature_ids, + row_key=self.entity_type_name, + ) + + with TestPipeline(is_integration_test=True) as test_pipeline: + _ = ( + test_pipeline + | beam.Create(requests) + | Enrichment(handler) + | beam.ParDo(ValidateResponse(expected_fields))) + + def test_vertex_ai_legacy_feature_store_enrichment_bad(self): + requests = [ + beam.Row(entity_id="12345", title="The Shining"), + ] + feature_store_id = "movie_prediction_unique" + entity_type_id = "movies" + handler = VertexAIFeatureStoreLegacyEnrichmentHandler( + project=self.project, + location=self.location, + api_endpoint=self.api_endpoint, + feature_store_id=feature_store_id, + entity_type_id=entity_type_id, + feature_ids=self.feature_ids, + row_key=self.entity_type_name, + exception_level=ExceptionLevel.RAISE, + ) + + with self.assertRaises(ValueError): + test_pipeline = beam.Pipeline() + _ = ( + test_pipeline + | "Create" >> beam.Create(requests) + | "Enrichment" >> Enrichment(handler)) + res = test_pipeline.run() + res.wait_until_finish() + + def test_vertex_ai_legacy_feature_store_invalid_featurestore(self): + requests = [ + beam.Row(entity_id="movie_02", title="The Shining"), + ] + feature_store_id = "invalid_name" + entity_type_id = "movies" + handler = VertexAIFeatureStoreLegacyEnrichmentHandler( + project=self.project, + location=self.location, + api_endpoint=self.api_endpoint, + feature_store_id=feature_store_id, + entity_type_id=entity_type_id, + feature_ids=self.feature_ids, + row_key=self.entity_type_name, + exception_level=ExceptionLevel.RAISE, + ) + + with self.assertRaises(ValueError): + test_pipeline = beam.Pipeline() + _ = ( + test_pipeline + | "Create" >> beam.Create(requests) + | "Enrichment" >> Enrichment(handler)) + res = test_pipeline.run() + res.wait_until_finish() + + def test_feature_store_enrichment_with_redis(self): + """ + In this test, we run two pipelines back to back. + + In the first pipeline, we run a simple feature store enrichment pipeline + with zero cache records. Therefore, it makes call to the source + and ultimately writes to the cache with a TTL of 300 seconds. + + For the second pipeline, we mock the + `VertexAIFeatureStoreEnrichmentHandler`'s `__call__` method to always + return a `None` response. However, this change won't impact the second + pipeline because the Enrichment transform first checks the cache to fulfill + requests. Since all requests are cached, it will return from there without + making calls to the feature store instance. + """ + expected_fields = ['entity_id', 'title', 'genres'] + requests = [ + beam.Row(entity_id="movie_02", title="The Shining"), + beam.Row(entity_id="movie_04", title="The Dark Knight"), + ] + feature_store_id = "movie_prediction_unique" + entity_type_id = "movies" + handler = VertexAIFeatureStoreLegacyEnrichmentHandler( + project=self.project, + location=self.location, + api_endpoint=self.api_endpoint, + feature_store_id=feature_store_id, + entity_type_id=entity_type_id, + feature_ids=self.feature_ids, + row_key=self.entity_type_name, + ) + with TestPipeline(is_integration_test=True) as test_pipeline: + _ = ( + test_pipeline + | beam.Create(requests) + | Enrichment(handler).with_redis_cache(self.host, self.port) + | beam.ParDo(ValidateResponse(expected_fields))) + + # manually check cache entry + c = coders.StrUtf8Coder() + for req in requests: + key = handler.get_cache_key(req) + response = self.client.get(c.encode(key)) + if not response: + raise ValueError("No cache entry found for %s" % key) + + actual = VertexAIFeatureStoreLegacyEnrichmentHandler.__call__ + VertexAIFeatureStoreLegacyEnrichmentHandler.__call__ = MagicMock( + return_value=( + beam.Row(entity_id="movie_02", title="The Shining"), beam.Row())) + + with TestPipeline(is_integration_test=True) as test_pipeline: + _ = ( + test_pipeline + | beam.Create(requests) + | Enrichment(handler).with_redis_cache(self.host, self.port) + | beam.ParDo(ValidateResponse(expected_fields))) + VertexAIFeatureStoreLegacyEnrichmentHandler.__call__ = actual + + +if __name__ == '__main__': + unittest.main() diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/vertex_ai_feature_store_test.py b/sdks/python/apache_beam/transforms/enrichment_handlers/vertex_ai_feature_store_test.py new file mode 100644 index 000000000000..352146ecc078 --- /dev/null +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/vertex_ai_feature_store_test.py @@ -0,0 +1,58 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import unittest + +try: + from apache_beam.transforms.enrichment_handlers.vertex_ai_feature_store \ + import VertexAIFeatureStoreEnrichmentHandler + from apache_beam.transforms.enrichment_handlers.vertex_ai_feature_store \ + import VertexAIFeatureStoreLegacyEnrichmentHandler +except ImportError: + raise unittest.SkipTest( + 'VertexAI Feature Store test dependencies ' + 'are not installed.') + + +class TestVertexAIFeatureStoreHandlerInit(unittest.TestCase): + def test_raise_error_duplicate_api_endpoint_online_store(self): + with self.assertRaises(ValueError): + _ = VertexAIFeatureStoreEnrichmentHandler( + project='project', + location='location', + api_endpoint='location@google.com', + feature_store_name='feature_store', + feature_view_name='feature_view', + row_key='row_key', + client_options={'api_endpoint': 'region@google.com'}, + ) + + def test_raise_error_duplicate_api_endpoint_legacy_store(self): + with self.assertRaises(ValueError): + _ = VertexAIFeatureStoreLegacyEnrichmentHandler( + project='project', + location='location', + api_endpoint='location@google.com', + feature_store_id='feature_store', + entity_type_id='entity_id', + feature_ids=['feature1', 'feature2'], + row_key='row_key', + client_options={'api_endpoint': 'region@google.com'}, + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/sdks/python/apache_beam/transforms/enrichment_handlers/vertex_ai_feature_store_tests_requirement.txt b/sdks/python/apache_beam/transforms/enrichment_handlers/vertex_ai_feature_store_tests_requirement.txt new file mode 100644 index 000000000000..cd74683a51c1 --- /dev/null +++ b/sdks/python/apache_beam/transforms/enrichment_handlers/vertex_ai_feature_store_tests_requirement.txt @@ -0,0 +1,18 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +redis +google-cloud-aiplatform diff --git a/sdks/python/test-suites/direct/common.gradle b/sdks/python/test-suites/direct/common.gradle index b5680c2e1e9a..d0f0ca85bbe3 100644 --- a/sdks/python/test-suites/direct/common.gradle +++ b/sdks/python/test-suites/direct/common.gradle @@ -368,7 +368,7 @@ task transformersInferenceTest { task enrichmentRedisTest { dependsOn 'installGcpTest' dependsOn ':sdks:python:sdist' - def requirementsFile = "${rootDir}/sdks/python/apache_beam/io/requestresponse_tests_requirements.txt" + def requirementsFile = "${rootDir}/sdks/python/apache_beam/transforms/enrichment_handlers/vertex_ai_feature_store_tests_requirement.txt" doFirst { exec { executable 'sh' @@ -381,7 +381,8 @@ task enrichmentRedisTest { "test_opts": testOpts, "suite": "postCommitIT-direct-py${pythonVersionSuffix}", "collect": "uses_redis", - "runner": "TestDirectRunner" + "runner": "TestDirectRunner", + "region": "us-central1", ] def cmdArgs = mapToArgString(argMap) exec {