From af12168594be177103cc1d6ef75319de15c8a8be Mon Sep 17 00:00:00 2001 From: Brynn Yin <24237253+brynn-code@users.noreply.github.com> Date: Thu, 11 Jan 2024 18:48:50 +0800 Subject: [PATCH] [SDK] Add orchestrator inputs, data entity (#1721) # Description Please add an informative description that covers the changes made by the pull request and link all relevant issues. # All Promptflow Contribution checklist: - [ ] **The pull request does not introduce [breaking changes].** - [ ] **CHANGELOG is updated for new features, bug fixes or other significant changes.** - [ ] **I have read the [contribution guidelines](../CONTRIBUTING.md).** - [ ] **Create an issue and link to the pull request to get dedicated review from promptflow team. Learn more: [suggested workflow](../CONTRIBUTING.md#suggested-workflow).** ## General Guidelines and Best Practices - [ ] Title of the pull request is clear and informative. - [ ] There are a small number of commits, each of which has an informative message. This means that previously merged commits do not appear in the history of the PR. For more information on cleaning up the commits in your PR, [see this page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md). ### Testing Guidelines - [ ] Pull request includes test coverage for the included changes. 
--------- Signed-off-by: Brynn Yin --- .../promptflow/_cli/_pf/_experiment.py | 8 +- src/promptflow/promptflow/_sdk/_errors.py | 18 +++ .../_submitter/experiment_orchestrator.py | 55 ++++++--- .../promptflow/_sdk/entities/_experiment.py | 108 ++++++++++++++++-- .../promptflow/_sdk/schemas/_experiment.py | 24 ++++ .../basic-no-script-template/basic.exp.yaml | 5 + 6 files changed, 185 insertions(+), 33 deletions(-) diff --git a/src/promptflow/promptflow/_cli/_pf/_experiment.py b/src/promptflow/promptflow/_cli/_pf/_experiment.py index 68e71e732d7..f5e76e44b90 100644 --- a/src/promptflow/promptflow/_cli/_pf/_experiment.py +++ b/src/promptflow/promptflow/_cli/_pf/_experiment.py @@ -163,24 +163,24 @@ def create_experiment(args: argparse.Namespace): logger.debug("Creating experiment from template %s", template.name) experiment = Experiment.from_template(template) logger.debug("Creating experiment %s", experiment.name) - exp = _get_pf_client().experiments.create_or_update(experiment) + exp = _get_pf_client()._experiments.create_or_update(experiment) print(json.dumps(exp._to_dict(), indent=4)) @exception_handler("List experiment") def list_experiment(args: argparse.Namespace): list_view_type = get_list_view_type(archived_only=args.archived_only, include_archived=args.include_archived) - results = _get_pf_client().experiments.list(args.max_results, list_view_type=list_view_type) + results = _get_pf_client()._experiments.list(args.max_results, list_view_type=list_view_type) print(json.dumps([result._to_dict() for result in results], indent=4)) @exception_handler("Show experiment") def show_experiment(args: argparse.Namespace): - result = _get_pf_client().experiments.get(args.name) + result = _get_pf_client()._experiments.get(args.name) print(json.dumps(result._to_dict(), indent=4)) @exception_handler("Start experiment") def start_experiment(args: argparse.Namespace): - result = _get_pf_client().experiments.start(args.name) + result = 
_get_pf_client()._experiments.start(args.name) print(json.dumps(result._to_dict(), indent=4)) diff --git a/src/promptflow/promptflow/_sdk/_errors.py b/src/promptflow/promptflow/_sdk/_errors.py index f03a6f2c577..7be2df46111 100644 --- a/src/promptflow/promptflow/_sdk/_errors.py +++ b/src/promptflow/promptflow/_sdk/_errors.py @@ -129,3 +129,21 @@ class ExperimentNotFoundError(SDKError): """Exception raised if experiment cannot be found.""" pass + + +class ExperimentValidationError(SDKError): + """Exception raised if experiment validation failed.""" + + pass + + +class ExperimentValueError(SDKError): + """Exception raised if experiment validation failed.""" + + pass + + +class ExperimentHasCycle(SDKError): + """Exception raised if experiment validation failed.""" + + pass diff --git a/src/promptflow/promptflow/_sdk/_submitter/experiment_orchestrator.py b/src/promptflow/promptflow/_sdk/_submitter/experiment_orchestrator.py index 663859d8b7e..1e6c767e6bc 100644 --- a/src/promptflow/promptflow/_sdk/_submitter/experiment_orchestrator.py +++ b/src/promptflow/promptflow/_sdk/_submitter/experiment_orchestrator.py @@ -6,13 +6,13 @@ from promptflow._sdk._configuration import Configuration from promptflow._sdk._constants import ExperimentNodeType, ExperimentStatus +from promptflow._sdk._errors import ExperimentHasCycle, ExperimentValueError from promptflow._sdk._submitter import RunSubmitter from promptflow._sdk.entities import Run from promptflow._sdk.entities._experiment import Experiment from promptflow._sdk.operations import RunOperations from promptflow._sdk.operations._experiment_operations import ExperimentOperations from promptflow._utils.logger_utils import LoggerFactory -from promptflow.exceptions import UserErrorException logger = LoggerFactory.get_logger(name=__name__) @@ -44,12 +44,11 @@ def start(self, experiment: Experiment, **kwargs): resolved_nodes = self._ensure_nodes_order(experiment.nodes) # Run nodes - data_dict = {data.get("name", None): data for data in 
experiment.data} run_dict = {} try: for node in resolved_nodes: - logger.debug(f"Running node {node.name}.") - run = self._run_node(node, experiment, data_dict, run_dict) + logger.info(f"Running node {node.name}.") + run = self._run_node(node, experiment, run_dict) # Update node run to experiment experiment._append_node_run(node.name, run) self.experiment_operations.create_or_update(experiment) @@ -92,23 +91,23 @@ def _prepare_edges(node): referenced_nodes.discard(node.name) break if not action: - raise UserErrorException(f"Experiment has circular dependency {edges!r}") + raise ExperimentHasCycle(f"Experiment has circular dependency {edges!r}") logger.debug(f"Experiment nodes resolved order: {[node.name for node in resolved_nodes]}") return resolved_nodes - def _run_node(self, node, experiment, data_dict, run_dict) -> Run: + def _run_node(self, node, experiment, run_dict) -> Run: if node.type == ExperimentNodeType.FLOW: - return self._run_flow_node(node, experiment, data_dict, run_dict) + return self._run_flow_node(node, experiment, run_dict) elif node.type == ExperimentNodeType.CODE: return self._run_script_node(node, experiment) - raise UserErrorException(f"Unknown experiment node {node.name!r} type {node.type!r}") + raise ExperimentValueError(f"Unknown experiment node {node.name!r} type {node.type!r}") - def _run_flow_node(self, node, experiment, data_dict, run_dict): + def _run_flow_node(self, node, experiment, run_dict): run_output_path = (Path(experiment._output_dir) / "runs" / node.name).resolve().absolute().as_posix() timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f") run = ExperimentRun( - experiment_data=data_dict, + experiment=experiment, experiment_runs=run_dict, # Use node name as prefix for run name? 
name=f"{node.name}_attempt{timestamp}", @@ -132,10 +131,30 @@ def _run_script_node(self, node, experiment): class ExperimentRun(Run): """Experiment run, includes experiment running context, like data, inputs and runs.""" - def __init__(self, experiment_data, experiment_runs, **kwargs): - self.experiment_data = experiment_data + def __init__(self, experiment, experiment_runs, **kwargs): + self.experiment = experiment + self.experiment_data = {data.name: data for data in experiment.data} + self.experiment_inputs = {input.name: input for input in experiment.inputs} self.experiment_runs = experiment_runs super().__init__(**kwargs) + self._resolve_column_mapping() + + def _resolve_column_mapping(self): + """Resolve column mapping with experiment inputs to constant values.""" + logger.info(f"Start resolve node {self.display_name!r} column mapping.") + resolved_mapping = {} + for name, value in self.column_mapping.items(): + if not value.startswith("${inputs."): + resolved_mapping[name] = value + continue + input_name = value.split(".")[1].replace("}", "") + if input_name not in self.experiment_inputs: + raise ExperimentValueError( + f"Node {self.display_name!r} inputs {value!r} related experiment input {input_name!r} not found." 
+ ) + resolved_mapping[name] = self.experiment_inputs[input_name].default + logger.debug(f"Resolved node {self.display_name!r} column mapping {resolved_mapping}.") + self.column_mapping = resolved_mapping class ExperimentRunSubmitter(RunSubmitter): @@ -155,24 +174,26 @@ def _resolve_input_dirs(self, run: ExperimentRun): for value in inputs_mapping.values(): referenced_data, referenced_run = None, None if value.startswith("${data."): - referenced_data = value.split(".")[1] + referenced_data = value.split(".")[1].replace("}", "") elif value.startswith("${"): referenced_run = value.split(".")[0].replace("${", "") if referenced_data: if data_name and data_name != referenced_data: - raise UserErrorException( + raise ExperimentValueError( f"Experiment has multiple data inputs {data_name!r} and {referenced_data!r}" ) data_name = referenced_data if referenced_run: if run_name and run_name != referenced_run: - raise UserErrorException(f"Experiment has multiple run inputs {run_name!r} and {referenced_run!r}") + raise ExperimentValueError( + f"Experiment has multiple run inputs {run_name!r} and {referenced_run!r}" + ) run_name = referenced_run logger.debug(f"Resolve node {run.name} referenced data {data_name!r}, run {run_name!r}.") # Build inputs from experiment data and run result = {} - if data_name in run.experiment_data and run.experiment_data[data_name].get("path"): - result.update({f"data.{data_name}": run.experiment_data[data_name]["path"]}) + if data_name in run.experiment_data and run.experiment_data[data_name].path: + result.update({f"data.{data_name}": run.experiment_data[data_name].path}) if run_name in run.experiment_runs: result.update( { diff --git a/src/promptflow/promptflow/_sdk/entities/_experiment.py b/src/promptflow/promptflow/_sdk/entities/_experiment.py index 8abc8a79b4e..0b7d0d37d46 100644 --- a/src/promptflow/promptflow/_sdk/entities/_experiment.py +++ b/src/promptflow/promptflow/_sdk/entities/_experiment.py @@ -10,6 +10,8 @@ from pathlib import Path 
from typing import Any, Dict, List, Optional, Union +from marshmallow import Schema + from promptflow._sdk._constants import ( BASE_PATH_CONTEXT_KEY, PARAMS_OVERRIDE_KEY, @@ -18,22 +20,71 @@ ExperimentNodeType, ExperimentStatus, ) +from promptflow._sdk._errors import ExperimentValidationError, ExperimentValueError from promptflow._sdk._orm.experiment import Experiment as ORMExperiment from promptflow._sdk._submitter import remove_additional_includes from promptflow._sdk._utils import _merge_local_code_and_additional_includes, _sanitize_python_variable_name from promptflow._sdk.entities import Run +from promptflow._sdk.entities._validation import MutableValidationResult, SchemaValidatableMixin from promptflow._sdk.entities._yaml_translatable import YAMLTranslatableMixin from promptflow._sdk.schemas._experiment import ( + ExperimentDataSchema, + ExperimentInputSchema, ExperimentSchema, ExperimentTemplateSchema, FlowNodeSchema, ScriptNodeSchema, ) from promptflow._utils.logger_utils import get_cli_sdk_logger +from promptflow.contracts.tool import ValueType logger = get_cli_sdk_logger() +class ExperimentData(YAMLTranslatableMixin): + def __init__(self, name, path, **kwargs): + self.name = name + self.path = path + + @classmethod + def _get_schema_cls(cls): + return ExperimentDataSchema + + +class ExperimentInput(YAMLTranslatableMixin): + def __init__(self, name, default, type, **kwargs): + self.name = name + self.type, self.default = self._resolve_type_and_default(type, default) + + @classmethod + def _get_schema_cls(cls): + return ExperimentInputSchema + + def _resolve_type_and_default(self, typ, default): + supported_types = [ + ValueType.INT, + ValueType.STRING, + ValueType.DOUBLE, + ValueType.LIST, + ValueType.OBJECT, + ValueType.BOOL, + ] + value_type: ValueType = next((i for i in supported_types if typ.lower() == i.value.lower()), None) + if value_type is None: + raise ExperimentValueError(f"Unknown experiment input type {typ!r}, supported are 
{supported_types}.") + return value_type.value, value_type.parse(default) + + @classmethod + def _load_from_dict(cls, data: Dict, context: Dict, additional_message: str = None, **kwargs): + # Override this to avoid 'type' got pop out + schema_cls = cls._get_schema_cls() + try: + loaded_data = schema_cls(context=context).load(data, **kwargs) + except Exception as e: + raise Exception(f"Load experiment input failed with {str(e)}. f{(additional_message or '')}.") + return cls(base_path=context[BASE_PATH_CONTEXT_KEY], **loaded_data) + + class FlowNode(YAMLTranslatableMixin): def __init__( self, @@ -109,7 +160,7 @@ def _save_snapshot(self, target): pass -class ExperimentTemplate(YAMLTranslatableMixin): +class ExperimentTemplate(YAMLTranslatableMixin, SchemaValidatableMixin): def __init__(self, nodes, name=None, description=None, data=None, inputs=None, **kwargs): self._base_path = kwargs.get(BASE_PATH_CONTEXT_KEY, Path(".")) self.name = name or self._generate_name() @@ -171,6 +222,30 @@ def _load_from_dict(cls, data: Dict, context: Dict, additional_message: str = No raise Exception(f"Load experiment template failed with {str(e)}. f{(additional_message or '')}.") return cls(base_path=context[BASE_PATH_CONTEXT_KEY], **loaded_data) + @classmethod + def _create_schema_for_validation(cls, context) -> Schema: + return cls._get_schema_cls()(context=context) + + def _default_context(self) -> dict: + return {BASE_PATH_CONTEXT_KEY: self._base_path} + + @classmethod + def _create_validation_error(cls, message: str, no_personal_data_message: str) -> Exception: + return ExperimentValidationError( + message=message, + no_personal_data_message=no_personal_data_message, + ) + + def _customized_validate(self) -> MutableValidationResult: + """Validate the resource with customized logic. + + Override this method to add customized validation logic. 
+ + :return: The customized validation result + :rtype: MutableValidationResult + """ + pass + class Experiment(ExperimentTemplate): def __init__( @@ -178,6 +253,7 @@ def __init__( nodes, name=None, data=None, + inputs=None, status=ExperimentStatus.NOT_STARTED, node_runs=None, properties=None, @@ -192,7 +268,7 @@ def __init__( self.last_end_time = kwargs.get("last_end_time", None) self.is_archived = kwargs.get("is_archived", False) self._output_dir = Path.home() / PROMPT_FLOW_DIR_NAME / PROMPT_FLOW_EXP_DIR_NAME / self.name - super().__init__(nodes, name=self.name, data=data, **kwargs) + super().__init__(nodes, name=self.name, data=data, inputs=inputs, **kwargs) @classmethod def _get_schema_cls(cls): @@ -240,8 +316,8 @@ def _to_orm_object(self): last_start_time=self.last_start_time, last_end_time=self.last_end_time, properties=json.dumps(self.properties), - data=json.dumps(self.data), - inputs=json.dumps(self.inputs), + data=json.dumps([item._to_dict() for item in self.data]), + inputs=json.dumps([input._to_dict() for input in self.inputs]), nodes=json.dumps([node._to_dict() for node in self.nodes]), node_runs=json.dumps(self.node_runs), ) @@ -252,21 +328,28 @@ def _to_orm_object(self): def _from_orm_object(cls, obj: ORMExperiment) -> "Experiment": """Create a experiment object from ORM object.""" nodes = [] + context = {BASE_PATH_CONTEXT_KEY: "./"} for node_dict in json.loads(obj.nodes): if node_dict["type"] == ExperimentNodeType.FLOW: nodes.append( - FlowNode._load_from_dict( - node_dict, context={BASE_PATH_CONTEXT_KEY: "./"}, additional_message="Failed to load node." - ) + FlowNode._load_from_dict(node_dict, context=context, additional_message="Failed to load node.") ) elif node_dict["type"] == ExperimentNodeType.CODE: nodes.append( - ScriptNode._load_from_dict( - node_dict, context={BASE_PATH_CONTEXT_KEY: "./"}, additional_message="Failed to load node." 
- ) + ScriptNode._load_from_dict(node_dict, context=context, additional_message="Failed to load node.") ) else: raise Exception(f"Unknown node type {node_dict['type']}") + data = [ + ExperimentData._load_from_dict(item, context=context, additional_message="Failed to load experiment data") + for item in json.loads(obj.data) + ] + inputs = [ + ExperimentInput._load_from_dict( + item, context=context, additional_message="Failed to load experiment inputs" + ) + for item in json.loads(obj.inputs) + ] return cls( name=obj.name, @@ -277,8 +360,8 @@ def _from_orm_object(cls, obj: ORMExperiment) -> "Experiment": last_end_time=obj.last_end_time, is_archived=obj.archived, properties=json.loads(obj.properties), - data=json.loads(obj.data), - inputs=json.loads(obj.inputs), + data=data, + inputs=inputs, nodes=nodes, node_runs=json.loads(obj.node_runs), ) @@ -292,6 +375,7 @@ def from_template(cls, template: ExperimentTemplate): name=exp_name, description=template.description, data=copy.deepcopy(template.data), + inputs=copy.deepcopy(template.inputs), nodes=copy.deepcopy(template.nodes), base_path=template._base_path, ) diff --git a/src/promptflow/promptflow/_sdk/schemas/_experiment.py b/src/promptflow/promptflow/_sdk/schemas/_experiment.py index 719d9d09d23..85eb678a464 100644 --- a/src/promptflow/promptflow/_sdk/schemas/_experiment.py +++ b/src/promptflow/promptflow/_sdk/schemas/_experiment.py @@ -72,6 +72,30 @@ def resolve_nodes(self, data, **kwargs): return data + @post_load + def resolve_data_and_inputs(self, data, **kwargs): + from promptflow._sdk.entities._experiment import ExperimentData, ExperimentInput + + def resolve_resource(key, cls): + items = data.get(key, []) + resolved_result = [] + for item in items: + if not isinstance(item, dict): + continue + resolved_result.append( + cls._load_from_dict( + data=item, + context=self.context, + additional_message=f"Failed to load {cls.__name__}", + ) + ) + return resolved_result + + data["data"] = resolve_resource("data", 
ExperimentData) + data["inputs"] = resolve_resource("inputs", ExperimentInput) + + return data + class ExperimentSchema(ExperimentTemplateSchema): node_runs = fields.Dict(keys=fields.Str(), values=fields.Str()) # TODO: Revisit this diff --git a/src/promptflow/tests/test_configs/experiments/basic-no-script-template/basic.exp.yaml b/src/promptflow/tests/test_configs/experiments/basic-no-script-template/basic.exp.yaml index 97058cacfb9..c5010c1b6b4 100644 --- a/src/promptflow/tests/test_configs/experiments/basic-no-script-template/basic.exp.yaml +++ b/src/promptflow/tests/test_configs/experiments/basic-no-script-template/basic.exp.yaml @@ -6,6 +6,11 @@ data: - name: my_data path: ../../flows/web_classification/data.jsonl +inputs: + - name: my_input + type: string + default: Hello World! + nodes: - name: main type: flow