test,fix: Fix the requests.Session mock in test_eval_run
kdestin committed Aug 21, 2024
1 parent 86639dd commit 9bd8a79
Showing 1 changed file with 62 additions and 75 deletions.
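The commit swaps the transport-level test double: instead of building a MagicMock requests.Session and patching the Session class imported by `promptflow.evals.evaluate._eval_run`, each test now patches `HttpPipeline.request` in `promptflow.evals._http_utils` and hands back the canned response directly. A condensed sketch of the before/after pattern (response fields trimmed to a minimum; not the full test file):

from unittest.mock import MagicMock, patch

mock_response = MagicMock()
mock_response.status_code = 200

# Before: fake the whole Session object and swap out the Session class
# that the module under test instantiates.
mock_session = MagicMock()
mock_session.request.return_value = mock_response
with patch("promptflow.evals.evaluate._eval_run.requests.Session", return_value=mock_session):
    ...  # exercise EvalRun here

# After: patch the pipeline's request method itself; every HTTP call made
# through HttpPipeline then returns the canned response, no session needed.
with patch("promptflow.evals._http_utils.HttpPipeline.request", return_value=mock_response):
    ...  # exercise EvalRun here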
src/promptflow-evals/tests/evals/unittests/test_eval_run.py (137 changes: 62 additions & 75 deletions)
@@ -41,7 +41,7 @@ def _get_mock_create_resonse(self, status=200):
         mock_response = MagicMock()
         mock_response.status_code = status
         if status != 200:
-            mock_response.text = "Mock error"
+            mock_response.text = lambda: "Mock error"
         else:
             mock_response.json.return_value = {
                 "run": {"info": {"run_id": str(uuid4()), "experiment_id": str(uuid4()), "run_name": str(uuid4())}}
@@ -52,17 +52,15 @@ def _get_mock_end_response(self, status=200):
         """Get the mock end run response."""
         mock_response = MagicMock()
         mock_response.status_code = status
-        mock_response.text = "Everything good" if status == 200 else "Everything bad"
+        mock_response.text = lambda: "Everything good" if status == 200 else "Everything bad"
         return mock_response
 
     @pytest.mark.parametrize(
         "status,should_raise", [("KILLED", False), ("WRONG_STATUS", True), ("FINISHED", False), ("FAILED", False)]
     )
     def test_end_raises(self, token_mock, status, should_raise, caplog):
         """Test that end run raises exception if incorrect status is set."""
-        mock_session = MagicMock()
-        mock_session.request.return_value = self._get_mock_create_resonse()
-        with patch("promptflow.evals.evaluate._eval_run.requests.Session", return_value=mock_session):
+        with patch("promptflow.evals._http_utils.HttpPipeline.request", return_value=self._get_mock_create_resonse()):
             with EvalRun(run_name=None, **TestEvalRun._MOCK_CREDS) as run:
                 if should_raise:
                     with pytest.raises(ValueError) as cm:
@@ -74,9 +72,7 @@ def test_end_raises(self, token_mock, status, should_raise, caplog):
 
     def test_run_logs_if_terminated(self, token_mock, caplog):
         """Test that run warn user if we are trying to terminate it twice."""
-        mock_session = MagicMock()
-        mock_session.request.return_value = self._get_mock_create_resonse()
-        with patch("promptflow.evals.evaluate._eval_run.requests.Session", return_value=mock_session):
+        with patch("promptflow.evals._http_utils.HttpPipeline.request", return_value=self._get_mock_create_resonse()):
             logger = logging.getLogger(EvalRun.__module__)
             # All loggers, having promptflow. prefix will have "promptflow" logger
             # as a parent. This logger does not propagate the logs and cannot be
@@ -98,9 +94,10 @@ def test_run_logs_if_terminated(self, token_mock, caplog):
 
     def test_end_logs_if_fails(self, token_mock, caplog):
         """Test that if the terminal status setting was failed, it is logged."""
-        mock_session = MagicMock()
-        mock_session.request.side_effect = [self._get_mock_create_resonse(), self._get_mock_end_response(500)]
-        with patch("promptflow.evals.evaluate._eval_run.requests.Session", return_value=mock_session):
+        with patch(
+            "promptflow.evals._http_utils.HttpPipeline.request",
+            side_effect=[self._get_mock_create_resonse(), self._get_mock_end_response(500)],
+        ):
             logger = logging.getLogger(EvalRun.__module__)
             # All loggers, having promptflow. prefix will have "promptflow" logger
             # as a parent. This logger does not propagate the logs and cannot be
@@ -120,12 +117,10 @@ def test_end_logs_if_fails(self, token_mock, caplog):
 
     def test_start_run_fails(self, token_mock, caplog):
         """Test that there are log messges if run was not started."""
-        mock_session = MagicMock()
         mock_response_start = MagicMock()
         mock_response_start.status_code = 500
-        mock_response_start.text = "Mock internal service error."
-        mock_session.request.return_value = mock_response_start
-        with patch("promptflow.evals.evaluate._eval_run.requests.Session", return_value=mock_session):
+        mock_response_start.text = lambda: "Mock internal service error."
+        with patch("promptflow.evals._http_utils.HttpPipeline.request", return_value=mock_response_start):
             logger = logging.getLogger(EvalRun.__module__)
             # All loggers, having promptflow. prefix will have "promptflow" logger
             # as a parent. This logger does not propagate the logs and cannot be
@@ -142,7 +137,7 @@ def test_start_run_fails(self, token_mock, caplog):
             run._start_run()
             assert len(caplog.records) == 1
             assert "500" in caplog.records[0].message
-            assert mock_response_start.text in caplog.records[0].message
+            assert mock_response_start.text() in caplog.records[0].message
             assert "The results will be saved locally" in caplog.records[0].message
             caplog.clear()
             # Log artifact
@@ -161,57 +156,47 @@ def test_start_run_fails(self, token_mock, caplog):
             assert "Unable to stop run due to Run status=RunStatus.BROKEN." in caplog.records[0].message
             caplog.clear()
 
-    @patch("promptflow.evals.evaluate._eval_run.requests.Session")
-    def test_run_name(self, mock_session_cls, token_mock):
+    def test_run_name(self, token_mock):
         """Test that the run name is the same as ID if name is not given."""
-        mock_session = MagicMock()
         mock_response = self._get_mock_create_resonse()
-        mock_session.request.return_value = mock_response
-        mock_session_cls.return_value = mock_session
-        with EvalRun(
-            run_name=None,
-            tracking_uri="www.microsoft.com",
-            subscription_id="mock",
-            group_name="mock",
-            workspace_name="mock",
-            ml_client=MagicMock(),
-        ) as run:
-            pass
+        with patch("promptflow.evals._http_utils.HttpPipeline.request", return_value=mock_response):
+            with EvalRun(
+                run_name=None,
+                tracking_uri="www.microsoft.com",
+                subscription_id="mock",
+                group_name="mock",
+                workspace_name="mock",
+                ml_client=MagicMock(),
+            ) as run:
+                pass
         assert run.info.run_id == mock_response.json.return_value["run"]["info"]["run_id"]
         assert run.info.experiment_id == mock_response.json.return_value["run"]["info"]["experiment_id"]
         assert run.info.run_name == mock_response.json.return_value["run"]["info"]["run_name"]
 
-    @patch("promptflow.evals.evaluate._eval_run.requests.Session")
-    def test_run_with_name(self, mock_session_cls, token_mock):
+    def test_run_with_name(self, token_mock):
         """Test that the run name is not the same as id if it is given."""
         mock_response = self._get_mock_create_resonse()
         mock_response.json.return_value["run"]["info"]["run_name"] = "test"
-        mock_session = MagicMock()
-        mock_session.request.return_value = mock_response
-        mock_session_cls.return_value = mock_session
-        with EvalRun(
-            run_name="test",
-            tracking_uri="www.microsoft.com",
-            subscription_id="mock",
-            group_name="mock",
-            workspace_name="mock",
-            ml_client=MagicMock(),
-        ) as run:
-            pass
+        with patch("promptflow.evals._http_utils.HttpPipeline.request", return_value=mock_response):
+            with EvalRun(
+                run_name="test",
+                tracking_uri="www.microsoft.com",
+                subscription_id="mock",
+                group_name="mock",
+                workspace_name="mock",
+                ml_client=MagicMock(),
+            ) as run:
+                pass
         assert run.info.run_id == mock_response.json.return_value["run"]["info"]["run_id"]
         assert run.info.experiment_id == mock_response.json.return_value["run"]["info"]["experiment_id"]
         assert run.info.run_name == "test"
         assert run.info.run_name != run.info.run_id
 
-    @patch("promptflow.evals.evaluate._eval_run.requests.Session")
-    def test_get_urls(self, mock_session_cls, token_mock):
+    def test_get_urls(self, token_mock):
         """Test getting url-s from eval run."""
-        mock_response = self._get_mock_create_resonse()
-        mock_session = MagicMock()
-        mock_session.request.return_value = mock_response
-        mock_session_cls.return_value = mock_session
-        with EvalRun(run_name="test", **TestEvalRun._MOCK_CREDS) as run:
-            pass
+        with patch("promptflow.evals._http_utils.HttpPipeline.request", return_value=self._get_mock_create_resonse()):
+            with EvalRun(run_name="test", **TestEvalRun._MOCK_CREDS) as run:
+                pass
         assert run.get_run_history_uri() == (
             "https://region.api.azureml.ms/history/v1.0/subscriptions"
             "/000000-0000-0000-0000-0000000/resourceGroups/mock-rg-region"
@@ -239,19 +224,21 @@ def test_get_urls(self, mock_session_cls, token_mock):
     )
     def test_log_artifacts_logs_error(self, token_mock, tmp_path, caplog, log_function, expected_str):
         """Test that the error is logged."""
-        mock_session = MagicMock()
         mock_response = MagicMock()
         mock_response.status_code = 404
-        mock_response.text = "Mock not found error."
+        mock_response.text = lambda: "Mock not found error."
         if log_function == "log_artifact":
             with open(os.path.join(tmp_path, "test.json"), "w") as fp:
                 json.dump({"f1": 0.5}, fp)
-        mock_session.request.side_effect = [
-            self._get_mock_create_resonse(),
-            mock_response,
-            self._get_mock_end_response(),
-        ]
-        with patch("promptflow.evals.evaluate._eval_run.requests.Session", return_value=mock_session):
+
+        with patch(
+            "promptflow.evals._http_utils.HttpPipeline.request",
+            side_effect=[
+                self._get_mock_create_resonse(),
+                mock_response,
+                self._get_mock_end_response(),
+            ],
+        ):
             logger = logging.getLogger(EvalRun.__module__)
             # All loggers, having promptflow. prefix will have "promptflow" logger
             # as a parent. This logger does not propagate the logs and cannot be
@@ -268,7 +255,7 @@ def test_log_artifacts_logs_error(self, token_mock, tmp_path, caplog, log_function, expected_str):
             with patch("promptflow.evals.evaluate._eval_run.BlobServiceClient", return_value=MagicMock()):
                 fn(**kwargs)
             assert len(caplog.records) == 1
-            assert mock_response.text in caplog.records[0].message
+            assert mock_response.text() in caplog.records[0].message
             assert "404" in caplog.records[0].message
             assert expected_str in caplog.records[0].message
 
@@ -290,9 +277,7 @@ def test_wrong_artifact_path(
         expected_error,
     ):
         """Test that if artifact path is empty, or dies not exist we are logging the error."""
-        mock_session = MagicMock()
-        mock_session.request.return_value = self._get_mock_create_resonse()
-        with patch("promptflow.evals.evaluate._eval_run.requests.Session", return_value=mock_session):
+        with patch("promptflow.evals._http_utils.HttpPipeline.request", return_value=self._get_mock_create_resonse()):
             with EvalRun(run_name="test", **TestEvalRun._MOCK_CREDS) as run:
                 logger = logging.getLogger(EvalRun.__module__)
                 # All loggers, having promptflow. prefix will have "promptflow" logger
@@ -363,9 +348,9 @@ def test_lifecycle(self, token_mock, status_code, pf_run):
         pf_run_mock = MagicMock()
         pf_run_mock.name = "mock_pf_run"
         pf_run_mock._experiment_name = "mock_pf_experiment"
-        mock_session = MagicMock()
-        mock_session.request.return_value = self._get_mock_create_resonse(status_code)
-        with patch("promptflow.evals.evaluate._eval_run.requests.Session", return_value=mock_session):
+        with patch(
+            "promptflow.evals._http_utils.HttpPipeline.request", return_value=self._get_mock_create_resonse(status_code)
+        ):
             run = EvalRun(run_name="test", **TestEvalRun._MOCK_CREDS, promptflow_run=pf_run_mock)
             assert run.status == RunStatus.NOT_STARTED, f"Get {run.status}, expected {RunStatus.NOT_STARTED}"
             run._start_run()
@@ -400,16 +385,17 @@ def test_write_properties(self, token_mock, caplog, status_code):
         """Test writing properties to the evaluate run."""
         mock_write = MagicMock()
         mock_write.status_code = status_code
-        mock_write.text = "Mock error"
-        mock_session = MagicMock()
-        mock_session.request.side_effect = [self._get_mock_create_resonse(), mock_write, self._get_mock_end_response()]
-        with patch("promptflow.evals.evaluate._eval_run.requests.Session", return_value=mock_session):
+        mock_write.text = lambda: "Mock error"
+        with patch(
+            "promptflow.evals._http_utils.HttpPipeline.request",
+            side_effect=[self._get_mock_create_resonse(), mock_write, self._get_mock_end_response()],
+        ):
             with EvalRun(run_name="test", **TestEvalRun._MOCK_CREDS) as run:
                 run.write_properties_to_run_history({"foo": "bar"})
         if status_code != 200:
             assert len(caplog.records) == 1
             assert "Fail writing properties" in caplog.records[0].message
-            assert mock_write.text in caplog.records[0].message
+            assert mock_write.text() in caplog.records[0].message
         else:
             assert len(caplog.records) == 0
 
@@ -461,9 +447,10 @@ def test_logs_if_not_started(self, token_mock, caplog, function_literal, args, e
     def test_starting_started_run(self, token_mock, status):
         """Test exception if the run was already started"""
         run = EvalRun(run_name=None, **TestEvalRun._MOCK_CREDS)
-        mock_session = MagicMock()
-        mock_session.request.return_value = self._get_mock_create_resonse(500 if status == RunStatus.BROKEN else 200)
-        with patch("promptflow.evals.evaluate._eval_run.requests.Session", return_value=mock_session):
+        with patch(
+            "promptflow.evals._http_utils.HttpPipeline.request",
+            return_value=self._get_mock_create_resonse(500 if status == RunStatus.BROKEN else 200),
+        ):
             run._start_run()
             if status == RunStatus.TERMINATED:
                 run._end_run("FINISHED")
