Skip to content

Commit

Permalink
test,fix: Mock azure.core instead of httpx in test_content_safety_rai…
Browse files Browse the repository at this point in the history
…_script
  • Loading branch information
kdestin committed Aug 21, 2024
1 parent 9bd8a79 commit 750bf29
Showing 1 changed file with 152 additions and 22 deletions.
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import http
import os
import pathlib
from typing import Any, Iterator, MutableMapping, Optional
from unittest.mock import MagicMock, patch

import httpx
import numpy as np
import pytest
from azure.core.exceptions import HttpResponseError
from azure.core.rest import HttpRequest, HttpResponse
from azure.identity import DefaultAzureCredential

from promptflow.evals._common.constants import EvaluationMetrics, HarmSeverityLevel, RAIService
Expand All @@ -26,6 +29,95 @@ def data_file():
return os.path.join(data_path, "evaluate_test_data.jsonl")


class MockHttpResponse(HttpResponse):
"""A mocked implementation of azure.core.rest.HttpResponse."""

def __init__(
self,
status_code: int,
*,
text: Optional[str] = None,
json: Optional[Any] = None,
headers: Optional[MutableMapping[str, str]] = None,
request: Optional[HttpRequest] = None,
content_type: Optional[str] = None,
) -> None:
self._status_code = status_code
self._text = text or ""
self._json = json
self._request = request
self._headers = headers or {}
self._content_type = content_type

def json(self) -> Any:
return self._json

def text(self, encoding: Optional[str] = None) -> str:
return self._text

@property
def status_code(self) -> int:
return self._status_code

@property
def request(self) -> HttpRequest:
return self._request

@property
def reason(self) -> str:
return f"{self.status_code} {http.client.responses[self.status_code]}"

@property
def headers(self) -> MutableMapping[str, str]:
return self._headers

@property
def content_type(self) -> Optional[str]:
return self._content_type

@property
def is_closed(self) -> bool:
return True

@property
def is_stream_consumed(self) -> bool:
return True

@property
def encoding(self) -> Optional[str]:
return None

def raise_for_status(self) -> None:
if self.status_code >= 400:
raise HttpResponseError(response=self)

def close(self) -> None:
pass

def __enter__(self) -> object:
raise NotImplementedError()

def __exit__(self, *args) -> None:
raise NotImplementedError()

@property
def url(self) -> str:
raise NotImplementedError()

@property
def content(self) -> bytes:
raise NotImplementedError()

def read(self) -> bytes:
raise NotImplementedError()

def iter_bytes(self, **kwargs) -> Iterator[bytes]:
raise NotImplementedError()

def iter_raw(self, **kwargs) -> Iterator[bytes]:
raise NotImplementedError()


@pytest.mark.usefixtures("mock_project_scope")
@pytest.mark.unittest
class TestContentSafetyEvaluator:
Expand All @@ -43,26 +135,33 @@ def test_rai_subscript_functions(self):
ensure_service_availability()"""

@pytest.mark.asyncio
@patch("httpx.AsyncClient.get", return_value=httpx.Response(200, json={}))
@patch("promptflow.evals._http_utils.AsyncHttpPipeline.get", return_value=MockHttpResponse(200, json={}))
async def test_ensure_service_availability(self, client_mock):
_ = await ensure_service_availability("dummy_url", "dummy_token")
client_mock.return_value.status_code = 9001
assert client_mock._mock_await_count == 1

@pytest.mark.asyncio
@patch("promptflow.evals._http_utils.AsyncHttpPipeline.get", return_value=MockHttpResponse(9001, json={}))
async def test_ensure_service_availability_service_unavailable(self, client_mock):
with pytest.raises(Exception) as exc_info:
_ = await ensure_service_availability("dummy_url", "dummy_token")
assert "RAI service is not available in this region. Status Code: 9001" in str(exc_info._excinfo[1])
client_mock.return_value.status_code = 200
assert client_mock._mock_await_count == 1

@pytest.mark.asyncio
@patch("promptflow.evals._http_utils.AsyncHttpPipeline.get", return_value=MockHttpResponse(200, json={}))
async def test_ensure_service_availability_exception_capability_unavailable(self, client_mock):
with pytest.raises(Exception) as exc_info:
_ = await ensure_service_availability("dummy_url", "dummy_token", capability="does not exist")
assert "Capability 'does not exist' is not available in this region" in str(exc_info._excinfo[1])
assert client_mock._mock_await_count == 3
assert client_mock._mock_await_count == 1

@pytest.mark.asyncio
@patch(
"httpx.AsyncClient.post",
return_value=httpx.Response(
"promptflow.evals._http_utils.AsyncHttpPipeline.post",
return_value=MockHttpResponse(
202,
json={"location": "this/is/the/dummy-operation-id"},
request=httpx.Request("POST", "test"),
),
)
async def test_submit_request(self, client_mock):
Expand All @@ -74,17 +173,26 @@ async def test_submit_request(self, client_mock):
token="dummy",
)
assert result == "dummy-operation-id"
client_mock.return_value.status_code = 404
with pytest.raises(httpx.HTTPStatusError) as exc_info:

@pytest.mark.asyncio
@patch(
"promptflow.evals._http_utils.AsyncHttpPipeline.post",
return_value=MockHttpResponse(
404,
json={"location": "this/is/the/dummy-operation-id"},
content_type="application/json",
),
)
async def test_submit_request_not_found(self, client_mock):
with pytest.raises(HttpResponseError) as exc_info:
_ = await submit_request(
question="What is the meaning of life",
answer="42",
metric="points",
rai_svc_url="www.notarealurl.com",
token="dummy",
)
assert "Client error '404 Not Found' for url 'test'" in str(exc_info._excinfo[1])
assert client_mock._mock_await_count == 2
assert "Operation returned an invalid status '404 Not Found'" in str(exc_info._excinfo[1])

@pytest.mark.usefixtures("mock_token")
@pytest.mark.usefixtures("mock_expired_token")
Expand All @@ -102,7 +210,10 @@ async def test_fetch_or_reuse_token(self, mock_token, mock_expired_token):
res = await fetch_or_reuse_token(credential=mock_cred, token="not-a-token")
assert res == 100

@patch("httpx.AsyncClient.get", return_value=httpx.Response(200, json={"result": "stuff"}))
@patch(
"promptflow.evals._http_utils.AsyncHttpPipeline.get",
return_value=MockHttpResponse(200, json={"result": "stuff"}),
)
@patch("promptflow.evals._common.constants.RAIService.TIMEOUT", 1)
@patch("promptflow.evals._common.constants.RAIService.SLEEP_TIME", 1.2)
@pytest.mark.usefixtures("mock_token")
Expand All @@ -118,13 +229,21 @@ async def test_fetch_result(self, client_mock, mock_token):
assert client_mock._mock_await_count == 1
assert res["result"] == "stuff"

client_mock.return_value.status_code = 404
@patch(
"promptflow.evals._http_utils.AsyncHttpPipeline.get",
return_value=MockHttpResponse(404, json={"result": "stuff"}),
)
@patch("promptflow.evals._common.constants.RAIService.TIMEOUT", 1)
@patch("promptflow.evals._common.constants.RAIService.SLEEP_TIME", 1.2)
@pytest.mark.usefixtures("mock_token")
@pytest.mark.asyncio
async def test_fetch_result_timeout(self, client_mock, mock_token):
with pytest.raises(TimeoutError) as exc_info:
_ = await fetch_result(
operation_id="op-id", rai_svc_url="www.notarealurl.com", credential=None, token=mock_token
)
# We expect 2 more calls; the initial call, then one more ~2 seconds later.
assert client_mock._mock_await_count == 3
# We expect 2 calls; the initial call, then one more ~2 seconds later.
assert client_mock._mock_await_count == 2
# Don't bother checking exact time beyond seconds, that's never going to be consistent across machines.
assert "Fetching annotation result 2 times out after 1" in str(exc_info._excinfo[1])

Expand Down Expand Up @@ -201,8 +320,8 @@ def test_parse_response(self):

@pytest.mark.asyncio
@patch(
"httpx.AsyncClient.get",
return_value=httpx.Response(200, json={"properties": {"discoveryUrl": "https://www.url.com:123/thePath"}}),
"promptflow.evals._http_utils.AsyncHttpPipeline.get",
return_value=MockHttpResponse(200, json={"properties": {"discoveryUrl": "https://www.url.com:123/thePath"}}),
)
async def test_get_service_discovery_url(self, client_mock):

Expand All @@ -216,16 +335,27 @@ async def test_get_service_discovery_url(self, client_mock):
url = await _get_service_discovery_url(azure_ai_project=azure_ai_project, token=token)
assert url == "https://www.url.com:123"

client_mock.return_value.status_code = 201
@pytest.mark.asyncio
@patch(
"promptflow.evals._http_utils.AsyncHttpPipeline.get",
return_value=MockHttpResponse(201, json={"properties": {"discoveryUrl": "https://www.url.com:123/thePath"}}),
)
async def test_get_service_discovery_url_exception(self, client_mock):
token = "fake-token"
azure_ai_project = {
"subscription_id": "fake-id",
"project_name": "fake-name",
"resource_group_name": "fake-group",
}

with pytest.raises(Exception) as exc_info:
_ = await _get_service_discovery_url(azure_ai_project=azure_ai_project, token=token)
assert "Failed to retrieve the discovery service URL" in str(exc_info._excinfo[1])
assert client_mock._mock_await_count == 2

@pytest.mark.asyncio
@patch(
"httpx.AsyncClient.get",
return_value=httpx.Response(200, json={"properties": {"discoveryUrl": "https://www.url.com:123/thePath"}}),
"promptflow.evals._http_utils.AsyncHttpPipeline.get",
return_value=MockHttpResponse(200, json={"properties": {"discoveryUrl": "https://www.url.com:123/thePath"}}),
)
@patch(
"promptflow.evals._common.rai_service._get_service_discovery_url",
Expand Down

0 comments on commit 750bf29

Please sign in to comment.