Skip to content

Commit

Permalink
[experiment] Verify command injection when starting experiments async…
Browse files Browse the repository at this point in the history
…hronously (#3685)

# Description

Added command parameter detection before executing commands to avoid
risky characters.

# All Promptflow Contribution checklist:
- [ ] **The pull request does not introduce [breaking changes].**
- [ ] **CHANGELOG is updated for new features, bug fixes or other
significant changes.**
- [ ] **I have read the [contribution
guidelines](https://github.com/microsoft/promptflow/blob/main/CONTRIBUTING.md).**
- [ ] **I confirm that all new dependencies are compatible with the MIT
license.**
- [ ] **Create an issue and link to the pull request to get dedicated
review from promptflow team. Learn more: [suggested
workflow](../CONTRIBUTING.md#suggested-workflow).**

## General Guidelines and Best Practices
- [ ] Title of the pull request is clear and informative.
- [ ] There are a small number of commits, each of which have an
informative message. This means that previously merged commits do not
appear in the history of the PR. For more information on cleaning up the
commits in your PR, [see this
page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md).

### Testing Guidelines
- [ ] Pull request includes test coverage for the included changes.
  • Loading branch information
lalala123123 authored Aug 23, 2024
1 parent cef8ee5 commit f9efeaa
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import json
import os
import platform
import re
import signal
import subprocess
import sys
Expand Down Expand Up @@ -414,6 +415,14 @@ def async_start(self, executable_path=None, nodes=None, from_nodes=None, attempt
:return: Experiment info.
:rtype: ~promptflow.entities.Experiment
"""
def _params_inject_validation(params, param_name):
# Verify that the command is injected in the parameters.
# parameters can only consist of numeric, alphabetic parameters, strikethrough and dash.
pattern = r'^[a-zA-Z0-9 _\-]*$'
for item in params:
if not bool(re.match(pattern, item)):
raise ExperimentValueError(f"Invalid character found in the parameter {params} of {param_name}.")

# Setup file handler
file_handler, index = _set_up_experiment_log_handler(experiment_path=self.experiment._output_dir, index=attempt)
logger.addHandler(file_handler._stream_handler)
Expand All @@ -423,10 +432,13 @@ def async_start(self, executable_path=None, nodes=None, from_nodes=None, attempt
executable_path = executable_path or sys.executable
args = [executable_path, __file__, "start", "--experiment", self.experiment.name]
if nodes:
_params_inject_validation(nodes, "nodes")
args = args + ["--nodes"] + nodes
if from_nodes:
_params_inject_validation(from_nodes, "from-nodes")
args = args + ["--from-nodes"] + from_nodes
if kwargs.get("session"):
_params_inject_validation(kwargs.get("session"), "session")
args = args + ["--session", kwargs.get("session")]
args = args + ["--attempt", str(index)]
# Start an orchestrator process using detach mode
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,29 @@ def test_experiment_start_from_nodes(self):
assert len(exp.node_runs["main"]) == 3
assert len(exp.node_runs["echo"]) == 2

@pytest.mark.usefixtures("use_secrets_config_file", "recording_injection", "setup_local_connection")
def test_experiment_start_with_command_injection(self):
template_path = EXP_ROOT / "basic-script-template" / "basic-script.exp.yaml"
# Load template and create experiment
template = load_common(ExperimentTemplate, source=template_path)
experiment = Experiment.from_template(template)
client = PFClient()
exp = client._experiments.create_or_update(experiment)

# Test start experiment with injection command
injection_command = ";bad command;"
with pytest.raises(ExperimentValueError) as error:
client._experiments.start(exp, nodes=[injection_command])
assert "Invalid character found" in str(error.value)

with pytest.raises(ExperimentValueError):
client._experiments.start(exp, from_nodes=[injection_command])
assert "Invalid character found" in str(error.value)

with pytest.raises(ExperimentValueError):
client._experiments.start(exp, session=injection_command)
assert "Invalid character found" in str(error.value)

@pytest.mark.skipif(condition=not pytest.is_live, reason="Injection cannot passed to detach process.")
def test_cancel_experiment(self):
template_path = EXP_ROOT / "command-node-exp-template" / "basic-command.exp.yaml"
Expand Down

0 comments on commit f9efeaa

Please sign in to comment.