Skip to content

Commit

Permalink
[SDK/CLI] Validate flow schema when creating flow to azure ai (#1547)
Browse files Browse the repository at this point in the history
- Validate flow schema when creating flow to azure ai
- Remove flow portal url related logic

# All Promptflow Contribution checklist:
- [ ] **The pull request does not introduce breaking changes.**
- [ ] **CHANGELOG is updated for new features, bug fixes or other
significant changes.**
- [ ] **I have read the [contribution guidelines](../CONTRIBUTING.md).**
- [ ] **Create an issue and link to the pull request to get dedicated
review from promptflow team. Learn more: [suggested
workflow](../CONTRIBUTING.md#suggested-workflow).**

## General Guidelines and Best Practices
- [ ] Title of the pull request is clear and informative.
- [ ] There are a small number of commits, each of which has an
informative message. This means that previously merged commits do not
appear in the history of the PR. For more information on cleaning up the
commits in your PR, [see this
page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md).

### Testing Guidelines
- [ ] Pull request includes test coverage for the included changes.
  • Loading branch information
0mza987 authored Dec 22, 2023
1 parent df29265 commit 6ff0eeb
Show file tree
Hide file tree
Showing 10 changed files with 392 additions and 403 deletions.
27 changes: 2 additions & 25 deletions src/promptflow/promptflow/_cli/_pf_azure/_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@ def add_parser_flow(subparsers):
add_parser_flow_create(flow_subparsers)
add_parser_flow_show(flow_subparsers)
add_parser_flow_list(flow_subparsers)
# add_parser_flow_delete(flow_subparsers)
# add_parser_flow_download(flow_subparsers)
flow_parser.set_defaults(action="flow")


Expand Down Expand Up @@ -188,13 +186,7 @@ def create_flow(args: argparse.Namespace):
"""Create a flow for promptflow."""
pf = _get_azure_pf_client(args.subscription, args.resource_group, args.workspace_name, debug=args.debug)
params = _parse_flow_metadata_args(args.params_override)
pf.flows.create_or_update(
flow=args.flow,
display_name=params.get("display_name", None),
type=params.get("type", None),
description=params.get("description", None),
tags=params.get("tags", None),
)
pf.flows.create_or_update(flow=args.flow, **params)


@exception_handler("Show flow")
Expand All @@ -218,19 +210,6 @@ def list_flows(args: argparse.Namespace):
_output_result_list_with_format(flow_list, args.output)


def download_flow(
    source: str,
    destination: str,
    workspace_name: str,
    resource_group: str,
    subscription_id: str,
):
    """Download a flow from the workspace file share to a local directory.

    :param source: File share path of the remote flow to download.
    :param destination: Local directory to download the flow into.
    :param workspace_name: Name of the workspace that hosts the flow.
    :param resource_group: Resource group of the workspace.
    :param subscription_id: Subscription id of the workspace.
    """
    flow_operations = _get_flow_operation(subscription_id, resource_group, workspace_name)
    flow_operations.download(source, destination)
    # fix: message previously read "Successfully download ..." (wrong tense)
    print(f"Successfully downloaded flow from file share path {source!r} to {destination!r}.")


def _parse_flow_metadata_args(params: List[Dict[str, str]]) -> Dict:
result, tags = {}, {}
if not params:
Expand All @@ -241,8 +220,6 @@ def _parse_flow_metadata_args(params: List[Dict[str, str]]) -> Dict:
tag_key = k.replace("tags.", "")
tags[tag_key] = v
continue
# replace "-" with "_" to handle the usage for both "-" and "_" in the command key
normalized_key = k.replace("-", "_")
result[normalized_key] = v
result[k] = v
result["tags"] = tags
return result
29 changes: 19 additions & 10 deletions src/promptflow/promptflow/_sdk/entities/_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,25 @@
import json
from os import PathLike
from pathlib import Path
from typing import Dict, Tuple, Union
from typing import Dict, Optional, Tuple, Union

import yaml
from marshmallow import Schema

from promptflow._constants import LANGUAGE_KEY, FlowLanguage
from promptflow._sdk._constants import BASE_PATH_CONTEXT_KEY, DEFAULT_ENCODING, FLOW_TOOLS_JSON, PROMPT_FLOW_DIR_NAME
from promptflow._sdk._constants import (
BASE_PATH_CONTEXT_KEY,
DAG_FILE_NAME,
DEFAULT_ENCODING,
FLOW_TOOLS_JSON,
PROMPT_FLOW_DIR_NAME,
)
from promptflow._sdk.entities._connection import _Connection
from promptflow._sdk.entities._validation import SchemaValidatableMixin
from promptflow._utils.flow_utils import resolve_flow_path
from promptflow._utils.logger_utils import get_cli_sdk_logger
from promptflow.exceptions import ErrorTarget, UserErrorException

from ..._utils.flow_utils import resolve_flow_path
from ..._utils.logger_utils import get_cli_sdk_logger
from .._constants import DAG_FILE_NAME
from ._connection import _Connection
from ._validation import SchemaValidatableMixin

logger = get_cli_sdk_logger()


Expand Down Expand Up @@ -204,12 +208,14 @@ class ProtectedFlow(Flow, SchemaValidatableMixin):
def __init__(
    self,
    code: str,
    params_override: Optional[Dict] = None,
    **kwargs,
):
    """Initialize the flow entity from a local flow source directory.

    :param code: Path to the flow source directory.
    :param params_override: Optional metadata overrides merged into the
        DAG when the flow is dumped for validation.
    :param kwargs: Forwarded unchanged to the base ``Flow`` initializer.
    """
    super().__init__(code=code, **kwargs)
    # Resolve the flow directory and DAG file name once, up front.
    flow_dir, dag_file_name = self._get_flow_definition(self.code)
    self._flow_dir = flow_dir
    self._dag_file_name = dag_file_name
    # Executable representation is built lazily elsewhere.
    self._executable = None
    self._params_override = params_override

@property
def flow_dag_path(self) -> Path:
Expand All @@ -221,7 +227,7 @@ def name(self) -> str:

@property
def display_name(self) -> str:
    """Display name of the flow; falls back to the flow name when the
    DAG does not define one.

    :return: The ``display_name`` entry of the DAG, or ``self.name``.
    """
    # fix: removed an unreachable duplicate `return` (diff residue); the
    # surviving behavior defaults to self.name rather than None.
    return self.dag.get("display_name", self.name)

@property
def language(self) -> str:
Expand Down Expand Up @@ -267,7 +273,10 @@ def _create_validation_error(self, message, no_personal_data_message=None):

def _dump_for_validation(self) -> Dict:
    """Load the DAG yaml from disk and apply any parameter overrides.

    :return: The flow DAG as a dict, with ``self._params_override``
        (when it is a dict) merged on top before schema validation.
    """
    # fix: removed a stale duplicate `return` (diff residue) that made
    # the override-merging code below unreachable.
    # Flow is read-only in control plane, so we always dump the flow from file
    data = yaml.safe_load(self.flow_dag_path.read_text(encoding=DEFAULT_ENCODING))
    if isinstance(self._params_override, dict):
        data.update(self._params_override)
    return data

# endregion

Expand Down
10 changes: 9 additions & 1 deletion src/promptflow/promptflow/_sdk/schemas/_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

from marshmallow import fields
from marshmallow import fields, validate

from promptflow._sdk._constants import FlowType
from promptflow._sdk.schemas._base import PatchedSchemaMeta, YamlFileSchema
from promptflow._sdk.schemas._fields import NestedField

Expand Down Expand Up @@ -39,3 +40,10 @@ class FlowSchema(YamlFileSchema):
nodes = fields.List(fields.Dict())
node_variants = fields.Dict(keys=fields.Str(), values=fields.Dict())
environment = fields.Dict()

# metadata
type = fields.Str(validate=validate.OneOf(FlowType.get_all_values()))
language = fields.Str()
description = fields.Str()
display_name = fields.Str()
tags = fields.Dict(keys=fields.Str(), values=fields.Str())
1 change: 1 addition & 0 deletions src/promptflow/promptflow/azure/_entities/_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ def _from_pf_service(cls, rest_object: FlowDto):
owner=rest_object.owner.as_dict(),
is_archived=rest_object.is_archived,
created_date=rest_object.created_date,
flow_portal_url=rest_object.studio_portal_endpoint,
)

@classmethod
Expand Down
132 changes: 60 additions & 72 deletions src/promptflow/promptflow/azure/operations/_flow_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
# pylint: disable=protected-access
import copy
import json
import os
import re
Expand All @@ -24,6 +25,7 @@
from azure.core.exceptions import HttpResponseError

from promptflow._sdk._constants import (
BASE_PATH_CONTEXT_KEY,
CLIENT_FLOW_TYPE_2_SERVICE_FLOW_TYPE,
DAG_FILE_NAME,
FLOW_TOOLS_JSON,
Expand All @@ -35,7 +37,7 @@
)
from promptflow._sdk._errors import FlowOperationError
from promptflow._sdk._telemetry import ActivityType, WorkspaceTelemetryMixin, monitor_operation
from promptflow._sdk._utils import PromptflowIgnoreFile, generate_flow_tools_json
from promptflow._sdk._utils import PromptflowIgnoreFile, generate_flow_tools_json, load_from_dict
from promptflow._sdk._vendor._asset_utils import traverse_directory
from promptflow._utils.logger_utils import get_cli_sdk_logger
from promptflow.azure._constants._flow import DEFAULT_STORAGE
Expand All @@ -44,7 +46,7 @@
from promptflow.azure._restclient.flow_service_caller import FlowServiceCaller
from promptflow.azure.operations._artifact_utilities import _get_datastore_name, get_datastore_info
from promptflow.azure.operations._fileshare_storeage_helper import FlowFileStorageClient
from promptflow.exceptions import SystemErrorException
from promptflow.exceptions import SystemErrorException, UserErrorException

logger = get_cli_sdk_logger()

Expand Down Expand Up @@ -91,44 +93,6 @@ def _index_service_endpoint_url(self):
endpoint = self._service_caller._service_endpoint
return endpoint + "index/v1.0" + self._service_caller._common_azure_url_pattern

def _get_flow_portal_url_from_resource_id(self, flow_resource_id: str):
"""Get the portal url for the run."""
match = self._FLOW_RESOURCE_PATTERN.match(flow_resource_id)
if not match or len(match.groups()) != 2:
logger.warning("Failed to parse flow resource id '%s'", flow_resource_id)
return None
experiment_id, flow_id = match.groups()
return self._get_flow_portal_url(experiment_id, flow_id)

def _get_flow_portal_url_from_index_entity(self, entity: Dict):
"""Enrich the index entity with flow portal url."""
result = None
experiment_id = entity["properties"].get("experimentId", None)
flow_id = entity["properties"].get("flowId", None)

if experiment_id and flow_id:
result = self._get_flow_portal_url(experiment_id, flow_id)
return result

def _get_flow_portal_url(self, experiment_id, flow_id):
"""Get the portal url for the run."""
# TODO[2785705]: Handle the case when endpoint is other clouds
workspace_kind = str(self._workspace._kind).lower()
# default refers to azure machine learning studio
if workspace_kind == "default":
return (
f"https://ml.azure.com/prompts/flow/{experiment_id}/{flow_id}/"
f"details?wsid={self._service_caller._common_azure_url_pattern}"
)
# project refers to azure ai studio
elif workspace_kind == "project":
return (
f"https://ai.azure.com/projectflows/{flow_id}/{experiment_id}/"
f"details/Flow?wsid={self._service_caller._common_azure_url_pattern}"
)
else:
raise FlowOperationError(f"Workspace kind {workspace_kind!r} is not supported for promptflow operations.")

@monitor_operation(activity_name="pfazure.flows.create_or_update", activity_type=ActivityType.PUBLICAPI)
def create_or_update(self, flow: Union[str, Path], display_name=None, type=None, **kwargs) -> Flow:
"""Create a flow to remote from local source.
Expand All @@ -147,7 +111,7 @@ def create_or_update(self, flow: Union[str, Path], display_name=None, type=None,
:type tags: Dict[str, str]
"""
# validate the parameters
azure_flow, flow_display_name, flow_type, kwargs = self._validate_flow_creation_parameters(
azure_flow, flow_display_name, flow_type, kwargs = FlowOperations._validate_flow_creation_parameters(
flow, display_name, type, **kwargs
)
# upload to file share
Expand All @@ -166,59 +130,83 @@ def create_or_update(self, flow: Union[str, Path], display_name=None, type=None,
**kwargs,
)
result_flow = Flow._from_pf_service(rest_flow)
result_flow.flow_portal_url = self._get_flow_portal_url_from_resource_id(rest_flow.flow_resource_id)
flow_dict = result_flow._to_dict()
print(f"Flow created successfully:\n{json.dumps(flow_dict, indent=4)}")

return result_flow

@staticmethod
def _validate_flow_creation_parameters(source, flow_display_name, flow_type, **kwargs):
    """Validate the parameters for flow creation operation.

    Validates the flow source folder and its schema, then resolves the
    final display name, type, description and tags from the (override-
    merged) DAG dict.

    :param source: Path to the local flow source directory.
    :param flow_display_name: Optional display-name override from the caller.
    :param flow_type: Optional flow-type override from the caller.
    :return: Tuple of (flow entity, display name, flow type, kwargs).
    :raises UserErrorException: If the source folder lacks the DAG yaml
        or the flow fails schema validation.
    """
    # NOTE(review): reconstructed from an interleaved diff hunk that mixed
    # pre- and post-commit lines (duplicate `def`, conflicting branches);
    # confirm against the upstream commit.
    # validate the source folder
    logger.info("Validating flow source.")
    if not Path(source, DAG_FILE_NAME).exists():
        raise UserErrorException(
            f"Flow source must be a directory with flow definition yaml '{DAG_FILE_NAME}'. "
            f"Got {Path(source).resolve().as_posix()!r}."
        )

    # validate flow source with flow schema
    logger.info("Validating flow schema.")
    flow_dict = FlowOperations._validate_flow_schema(source, flow_display_name, flow_type, **kwargs)

    logger.info("Validating flow creation parameters.")
    flow = load_flow(source)
    # if no flow name specified, use "flow name + timestamp"
    flow_display_name = flow_dict.get("display_name", None)
    if not flow_display_name:
        flow_display_name = f"{Path(source).name}-{datetime.now().strftime('%m-%d-%Y-%H-%M-%S')}"

    # if no flow type specified, use default flow type "standard"
    flow_type = flow_dict.get("type", None)
    if not flow_type:
        flow_type = FlowType.STANDARD

    # update description and tags to be the final value
    description = flow_dict.get("description", None)
    if isinstance(description, str):
        kwargs["description"] = description

    tags = flow_dict.get("tags", None)
    if tags:
        kwargs["tags"] = tags

    return flow, flow_display_name, flow_type, kwargs

@staticmethod
def _validate_flow_schema(source, display_name=None, type=None, **kwargs):
    """Validate the flow DAG against the flow schema.

    ``display_name``, ``type`` and any extra kwargs act as overrides on
    top of the on-disk DAG before validation.

    :return: The override-merged flow DAG as a dict.
    :raises UserErrorException: When schema validation fails.
    """
    from marshmallow import ValidationError

    from promptflow._sdk.entities._flow import ProtectedFlow
    from promptflow._sdk.schemas._flow import FlowSchema

    # Build the override set without mutating the caller's kwargs.
    overrides = copy.deepcopy(kwargs)
    for key, value in (("display_name", display_name), ("type", type)):
        if value:
            overrides[key] = value

    entity = ProtectedFlow.load(source=source, params_override=overrides)
    dag_dict = entity._dump_for_validation()
    try:
        load_from_dict(
            schema=FlowSchema,
            data=dag_dict,
            context={BASE_PATH_CONTEXT_KEY: Path(source)},
        )
    except ValidationError as e:
        raise UserErrorException(f"Failed to validate flow schema due to: {str(e)}") from e

    return dag_dict

def _resolve_flow_code_and_upload_to_file_share(
self, flow: Flow, flow_display_name: str, ignore_tools_json=False
) -> str:
ops = OperationOrchestrator(self._all_operations, self._operation_scope, self._operation_config)
file_share_flow_path = ""

logger.info("Building flow code.")
with flow._build_code() as code:
if code is None:
raise FlowOperationError("Failed to build flow code.")
Expand All @@ -238,25 +226,28 @@ def _resolve_flow_code_and_upload_to_file_share(
datastore_name = _get_datastore_name(datastore_name=DEFAULT_STORAGE)
datastore_operation = ops._code_assets._datastore_operation
datastore_info = get_datastore_info(datastore_operation, datastore_name)

logger.debug("Creating storage client for uploading flow to file share.")
storage_client = FlowFileStorageClient(
credential=datastore_info["credential"],
file_share_name=datastore_info["container_name"],
account_url=datastore_info["account_url"],
azure_cred=datastore_operation._credential,
)
logger.debug("Created storage client for uploading flow to file share.")

# set storage client to flow operation, can be used in test case
self._storage_client = storage_client

# check if the file share directory exists
logger.debug("Checking if the file share directory exists.")
if storage_client._check_file_share_directory_exist(flow_display_name):
raise FlowOperationError(
f"Remote flow folder {flow_display_name!r} already exists under "
f"'{storage_client.file_share_prefix}'. Please change the flow folder name and try again."
)

try:
logger.info("Uploading flow directory to file share.")
storage_client.upload_dir(
source=code.path,
dest=flow_display_name,
Expand Down Expand Up @@ -316,7 +307,6 @@ def get(self, name: str) -> Flow:
raise FlowOperationError(f"Failed to get flow {name!r} due to: {str(e)}.") from e

flow = Flow._from_pf_service(rest_flow)
flow.flow_portal_url = self._get_flow_portal_url_from_resource_id(rest_flow.flow_resource_id)
return flow

@monitor_operation(activity_name="pfazure.flows.list", activity_type=ActivityType.PUBLICAPI)
Expand Down Expand Up @@ -407,8 +397,6 @@ def list(
flow_instances = []
for entity in flow_entities:
flow = Flow._from_index_service(entity)
# add flow portal url
flow.flow_portal_url = self._get_flow_portal_url_from_index_entity(entity)
flow_instances.append(flow)

return flow_instances
Expand Down
Loading

0 comments on commit 6ff0eeb

Please sign in to comment.