From 0e507e6f96635fb73afaf1d1031ac9580dc684df Mon Sep 17 00:00:00 2001 From: Max Chis Date: Wed, 11 Dec 2024 10:34:01 -0500 Subject: [PATCH 1/4] Begin work on agencies matching endpoint --- app.py | 2 + database_client/database_client.py | 1 + .../primary_resource_logic/match_logic.py | 130 ++++++++++++++++++ .../primary_resource_logic/search_logic.py | 2 +- .../primary_resource_dtos/match_dtos.py | 14 ++ .../primary_resource_schemas/match_schemas.py | 17 +++ requirements.txt | 3 +- resources/Match.py | 31 +++++ .../complex_test_data_creation_functions.py | 11 ++ .../helper_classes/SimpleTempFile.py | 22 +++ .../helper_classes/TestCSVCreator.py | 20 +++ tests/integration/test_batch.py | 67 +-------- tests/integration/test_match.py | 99 +++++++++++++ utilities/namespace.py | 1 + 14 files changed, 356 insertions(+), 64 deletions(-) create mode 100644 middleware/primary_resource_logic/match_logic.py create mode 100644 middleware/schema_and_dto_logic/primary_resource_dtos/match_dtos.py create mode 100644 middleware/schema_and_dto_logic/primary_resource_schemas/match_schemas.py create mode 100644 resources/Match.py create mode 100644 tests/helper_scripts/helper_classes/SimpleTempFile.py create mode 100644 tests/helper_scripts/helper_classes/TestCSVCreator.py create mode 100644 tests/integration/test_match.py diff --git a/app.py b/app.py index 48382791..92eb14c4 100644 --- a/app.py +++ b/app.py @@ -13,6 +13,7 @@ from resources.LinkToGithub import namespace_link_to_github from resources.LoginWithGithub import namespace_login_with_github from resources.Map import namespace_map +from resources.Match import namespace_match from resources.Notifications import namespace_notifications from resources.OAuth import namespace_oauth from resources.Permissions import namespace_permissions @@ -66,6 +67,7 @@ namespace_map, namespace_signup, namespace_batch, + namespace_match, ] MY_PREFIX = "/api" diff --git a/database_client/database_client.py b/database_client/database_client.py index bfee419a..d4e7c105 100644 --- a/database_client/database_client.py +++ b/database_client/database_client.py @@ -538,6 +538,7 @@ def search_with_location_and_record_type( :param locality: The locality to search for data sources in. If None, all data sources will be searched for. :return: A list of dictionaries. """ + optional_kwargs = {} query = DynamicQueryConstructor.create_search_query( state=state, record_categories=record_categories, diff --git a/middleware/primary_resource_logic/match_logic.py b/middleware/primary_resource_logic/match_logic.py new file mode 100644 index 00000000..6b05cb22 --- /dev/null +++ b/middleware/primary_resource_logic/match_logic.py @@ -0,0 +1,130 @@ +from enum import Enum +from typing import Optional, List + +from pydantic import BaseModel + +from database_client.database_client import DatabaseClient +from database_client.db_client_dataclasses import WhereMapping +from middleware.schema_and_dto_logic.primary_resource_dtos.match_dtos import ( + AgencyMatchOuterDTO, + AgencyMatchInnerDTO, +) +from rapidfuzz import fuzz + + +class AgencyMatchStatus(Enum): + EXACT = "Exact Match" + PARTIAL = "Partial Matches" + LOCATION = "Location Only" + NO_MATCH = "No Match" + + +SIMILARITY_THRESHOLD = 80 + + +def get_agency_match_message(status: AgencyMatchStatus): + match status: + case AgencyMatchStatus.EXACT: + return "Exact match found." + case AgencyMatchStatus.PARTIAL: + return "Partial matches found." + case AgencyMatchStatus.LOCATION: + return "Matches found only on location." + case AgencyMatchStatus.NO_MATCH: + return "No matches found." + + +class AgencyMatchResponse: + + def __init__(self, status: AgencyMatchStatus, agencies: Optional[list] = None): + self.status = status + self.agencies = agencies + self.message = get_agency_match_message(status) + + +def match_agencies(db_client: DatabaseClient, dto: AgencyMatchOuterDTO): + amrs: List[AgencyMatchResponse] = [] + for entry in dto.entries: + amr: AgencyMatchResponse = try_matching_agency(db_client=db_client, dto=entry) + amrs.append(amr) + + +def try_getting_exact_match_agency(dto: AgencyMatchInnerDTO, agencies: list[dict]): + for agency in agencies: + if agency["submitted_name"] == dto.name: + return agency + + +def try_getting_partial_match_agencies(dto: AgencyMatchInnerDTO, agencies: list[dict]): + partial_matches = [] + for agency in agencies: + if fuzz.ratio(dto.name, agency["submitted_name"]) >= SIMILARITY_THRESHOLD: + partial_matches.append(agency) + + return partial_matches + + +def try_matching_agency( + db_client: DatabaseClient, dto: AgencyMatchInnerDTO +) -> AgencyMatchResponse: + + location_id = _get_location_id(db_client, dto) + if location_id is None: + return _no_match_response() + + agencies = _get_agencies(db_client, location_id) + if len(agencies) == 0: + return _no_match_response() + + exact_match_agency = try_getting_exact_match_agency(dto=dto, agencies=agencies) + if exact_match_agency is not None: + return _exact_match_response(exact_match_agency) + + partial_match_agencies = try_getting_partial_match_agencies( + dto=dto, agencies=agencies + ) + if len(partial_match_agencies) > 0: + return _partial_match_response(partial_match_agencies) + + return _location_match_response(agencies) + + +def _location_match_response(agencies): + return AgencyMatchResponse(status=AgencyMatchStatus.LOCATION, agencies=agencies) + + +def _partial_match_response(partial_match_agencies): + return AgencyMatchResponse( + status=AgencyMatchStatus.PARTIAL, agencies=partial_match_agencies + ) + + +def _exact_match_response(exact_match_agency): + return AgencyMatchResponse( + status=AgencyMatchStatus.EXACT, agencies=[exact_match_agency] + ) + + +def _no_match_response(): + return AgencyMatchResponse( + status=AgencyMatchStatus.NO_MATCH, + ) + + +def _get_agencies(db_client, location_id): + return db_client.get_agencies( + columns=["id", "submitted_name"], + where_mappings=WhereMapping.from_dict({"location_id": location_id}), + ) + + +def _get_location_id(db_client, dto): + return db_client.get_location_id( + where_mappings=WhereMapping.from_dict( + { + "state_name": dto.state, + "county_name": dto.county, + "locality_name": dto.locality, + } + ) + ) diff --git a/middleware/primary_resource_logic/search_logic.py b/middleware/primary_resource_logic/search_logic.py index bdeb7ec9..26e77e70 100644 --- a/middleware/primary_resource_logic/search_logic.py +++ b/middleware/primary_resource_logic/search_logic.py @@ -105,9 +105,9 @@ def search_wrapper( create_search_record(access_info, db_client, dto) explicit_record_categories = get_explicit_record_categories(dto.record_categories) search_results = db_client.search_with_location_and_record_type( + record_categories=explicit_record_categories, state=dto.state, # Pass modified record categories, which breaks down ALL into individual categories - record_categories=explicit_record_categories, county=dto.county, locality=dto.locality, ) diff --git a/middleware/schema_and_dto_logic/primary_resource_dtos/match_dtos.py b/middleware/schema_and_dto_logic/primary_resource_dtos/match_dtos.py new file mode 100644 index 00000000..9692cbda --- /dev/null +++ b/middleware/schema_and_dto_logic/primary_resource_dtos/match_dtos.py @@ -0,0 +1,14 @@ +from typing import Optional + +from pydantic import BaseModel + + +class AgencyMatchInnerDTO(BaseModel): + name: str + state: str + county: Optional[str] + locality: Optional[str] + + +class AgencyMatchOuterDTO(BaseModel): + entries: list[AgencyMatchInnerDTO] diff --git a/middleware/schema_and_dto_logic/primary_resource_schemas/match_schemas.py b/middleware/schema_and_dto_logic/primary_resource_schemas/match_schemas.py new file mode 100644 index 00000000..22f6794f --- /dev/null +++ b/middleware/schema_and_dto_logic/primary_resource_schemas/match_schemas.py @@ -0,0 +1,17 @@ +from marshmallow import Schema, fields + +from middleware.schema_and_dto_logic.util import get_json_metadata + + +class AgencyMatchInnerSchema(Schema): + name = fields.String(metadata=get_json_metadata("The name of the agency")) + state = fields.String(metadata=get_json_metadata("The state of the agency")) + county = fields.String(metadata=get_json_metadata("The county of the agency")) + locality = fields.String(metadata=get_json_metadata("The locality of the agency")) + + +class AgencyMatterOuterSchema(Schema): + entries = fields.List( + fields.Nested(AgencyMatchInnerSchema()), + metadata=get_json_metadata("The proposed agencies to find matches for."), + ) diff --git a/requirements.txt b/requirements.txt index b063c953..a9a606f6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -70,4 +70,5 @@ PyJWT~=2.9.0 marshmallow~=3.22.0 PyGithub~=2.4.0 dominate~=2.9.1 -pre-commit~=4.0.1 \ No newline at end of file +pre-commit~=4.0.1 +RapidFuzz~=3.10.1 \ No newline at end of file diff --git a/resources/Match.py b/resources/Match.py new file mode 100644 index 00000000..e5d03576 --- /dev/null +++ b/resources/Match.py @@ -0,0 +1,31 @@ +from flask import Response + +from middleware.access_logic import STANDARD_JWT_AUTH_INFO, AccessInfoPrimary +from middleware.decorators import endpoint_info +from middleware.primary_resource_logic.match_logic import match_agencies +from resources.PsycopgResource import PsycopgResource +from resources.endpoint_schema_config import SchemaConfigs +from resources.resource_helpers import ResponseInfo +from utilities.namespace import create_namespace, AppNamespaces + +namespace_match = create_namespace(AppNamespaces.MATCH) + + +@namespace_match.route("/agencies") +class MatchAgencies(PsycopgResource): + + @endpoint_info( + namespace=namespace_match, + auth_info=STANDARD_JWT_AUTH_INFO, + schema_config=SchemaConfigs.MATCH_AGENCIES, + response_info=ResponseInfo( + success_message="Found any possible matches for the search criteria." + ), + description="Returns agencies, if any, that match or partially match the search criteria", + ) + def post(self, access_info: AccessInfoPrimary) -> Response: + return self.run_endpoint( + wrapper_function=match_agencies, + schema_populate_parameters=SchemaConfigs.MATCH_AGENCIES.value.get_schema_populate_parameters(), + access_info=access_info, + ) diff --git a/tests/helper_scripts/complex_test_data_creation_functions.py b/tests/helper_scripts/complex_test_data_creation_functions.py index dfd7907a..100568e1 100644 --- a/tests/helper_scripts/complex_test_data_creation_functions.py +++ b/tests/helper_scripts/complex_test_data_creation_functions.py @@ -158,6 +158,17 @@ def create_test_agency(flask_client: FlaskClient, jwt_authorization_header: dict return TestAgencyInfo(id=json["id"], submitted_name=submitted_name) +def get_sample_location_info(locality_name: Optional[str] = None) -> dict: + if locality_name is None: + locality_name = get_test_name() + return { + "type": "Locality", + "state_iso": "PA", + "county_fips": "42003", + "locality_name": locality_name, + } + + def get_sample_agency_post_parameters( submitted_name, locality_name, diff --git a/tests/helper_scripts/helper_classes/SimpleTempFile.py b/tests/helper_scripts/helper_classes/SimpleTempFile.py new file mode 100644 index 00000000..239f9c6d --- /dev/null +++ b/tests/helper_scripts/helper_classes/SimpleTempFile.py @@ -0,0 +1,22 @@ +import os +import tempfile + + +class SimpleTempFile: + + def __init__(self, suffix: str = ".csv"): + self.suffix = suffix + self.temp_file = None + + def __enter__(self): + self.temp_file = tempfile.NamedTemporaryFile( + mode="w+", encoding="utf-8", suffix=self.suffix, delete=False + ) + return self.temp_file + + def __exit__(self, exc_type, exc_value, traceback): + try: + self.temp_file.close() + os.unlink(self.temp_file.name) + except Exception as e: + print(f"Error cleaning up temporary file {self.temp_file.name}: {e}") diff --git a/tests/helper_scripts/helper_classes/TestCSVCreator.py b/tests/helper_scripts/helper_classes/TestCSVCreator.py new file mode 100644 index 00000000..0731cae7 --- /dev/null +++ b/tests/helper_scripts/helper_classes/TestCSVCreator.py @@ -0,0 +1,20 @@ +from csv import DictWriter +from typing import TextIO + +from marshmallow import Schema + + +class TestCSVCreator: + + def __init__(self, schema: Schema): + self.fields = schema.fields + + def create_csv(self, file: TextIO, rows: list[dict]): + header = list(self.fields.keys()) + writer = DictWriter(file, fieldnames=header) + # Write header row using header + writer.writerow({field: field for field in header}) + + for row in rows: + writer.writerow(row) + file.close() diff --git a/tests/integration/test_batch.py b/tests/integration/test_batch.py index 46b05ad1..a96f0e5e 100644 --- a/tests/integration/test_batch.py +++ b/tests/integration/test_batch.py @@ -1,16 +1,12 @@ -import os from dataclasses import dataclass -from enum import Enum from http import HTTPStatus -from typing import TextIO, Optional, Annotated, Callable +from typing import Optional, Annotated, Callable import pytest from marshmallow import Schema -from csv import DictWriter -import tempfile -from conftest import test_data_creator_flask, monkeysession +from conftest import test_data_creator_flask from database_client.enums import LocationType from middleware.primary_resource_logic.batch_logic import listify_strings from middleware.schema_and_dto_logic.common_response_schemas import MessageSchema @@ -31,47 +27,14 @@ from tests.helper_scripts.helper_classes.SchemaTestDataGenerator import ( generate_test_data_from_schema, ) +from tests.helper_scripts.helper_classes.SimpleTempFile import SimpleTempFile +from tests.helper_scripts.helper_classes.TestCSVCreator import TestCSVCreator from tests.helper_scripts.helper_classes.TestDataCreatorFlask import ( TestDataCreatorFlask, ) -SUFFIX_ARRAY = [".csv", ".csv", ".inv", ".csv"] - -# Helper scripts -# -class ResourceType(Enum): - AGENCY = "agencies" - SOURCE = "data_sources" - REQUEST = "data_requests" - - -def get_endpoint_nomenclature( - resource_type: ResourceType, -) -> str: - """ - Get a string version of the resource, appropriate for endpoints. - """ - return resource_type.value.replace("_", "-") - - -class TestCSVCreator: - - def __init__(self, schema: Schema): - self.fields = schema.fields - - def create_csv(self, file: TextIO, rows: list[dict]): - header = list(self.fields.keys()) - writer = DictWriter(file, fieldnames=header) - # Write header row using header - writer.writerow({field: field for field in header}) - - for row in rows: - writer.writerow(row) - file.close() - - -def stringify_list_of_ints(l: list): +def stringify_list_of_ints(l: list[int]): for i in range(len(l)): l[i] = str(l[i]) return l @@ -149,26 +112,6 @@ def data_sources_put_runner(runner: BatchTestRunner): ) -class SimpleTempFile: - - def __init__(self, suffix: str = ".csv"): - self.suffix = suffix - self.temp_file = None - - def __enter__(self): - self.temp_file = tempfile.NamedTemporaryFile( - mode="w+", encoding="utf-8", suffix=self.suffix, delete=False - ) - return self.temp_file - - def __exit__(self, exc_type, exc_value, traceback): - try: - self.temp_file.close() - os.unlink(self.temp_file.name) - except Exception as e: - print(f"Error cleaning up temporary file {self.temp_file.name}: {e}") - - def generate_agencies_locality_data(): locality_name = get_test_name() return { diff --git a/tests/integration/test_match.py b/tests/integration/test_match.py new file mode 100644 index 00000000..9ed92012 --- /dev/null +++ b/tests/integration/test_match.py @@ -0,0 +1,99 @@ +from dataclasses import dataclass +from typing import Optional + +import pytest +from conftest import test_data_creator_flask, monkeysession +from middleware.primary_resource_logic.match_logic import ( + try_matching_agency, + AgencyMatchStatus, +) +from middleware.schema_and_dto_logic.primary_resource_dtos.match_dtos import ( + AgencyMatchInnerDTO, +) +from tests.helper_scripts.common_test_data import get_test_name +from tests.helper_scripts.complex_test_data_creation_functions import ( + get_sample_location_info, +) +from tests.helper_scripts.helper_classes.TestDataCreatorFlask import ( + TestDataCreatorFlask, +) + + +@dataclass +class TestMatchAgencySetup: + tdc: TestDataCreatorFlask + location_kwargs: dict + agency_name: str + + def additional_agency(self, locality_name: Optional[str] = None): + return self.tdc.agency( + location_info=get_sample_location_info(locality_name=locality_name) + ) + + +@pytest.fixture +def match_agency_setup( + test_data_creator_flask: TestDataCreatorFlask, +) -> TestMatchAgencySetup: + tdc = test_data_creator_flask + tdc.clear_test_data() + locality_name = get_test_name() + locality_id = tdc.locality(locality_name=locality_name) + location_info = get_sample_location_info(locality_name=locality_name) + agency = tdc.agency(location_info=location_info) + return TestMatchAgencySetup( + tdc=tdc, + location_kwargs={ + "state": "Pennsylvania", + "county": "Allegheny", + "locality": locality_name, + }, + agency_name=agency.submitted_name, + ) + + +def test_agency_match_exact_match( + match_agency_setup: TestMatchAgencySetup, +): + mas = match_agency_setup + + dto = AgencyMatchInnerDTO(name=mas.agency_name, **mas.location_kwargs) + amr = try_matching_agency(db_client=mas.tdc.db_client, dto=dto) + assert amr.status == AgencyMatchStatus.EXACT + + +def test_agency_match_partial_match(match_agency_setup: TestMatchAgencySetup): + mas = match_agency_setup + + # Add an additional agency with the same location information but different name + # This should not be picked up. + mas.additional_agency() + modified_agency_name = mas.agency_name + "1" + dto = AgencyMatchInnerDTO(name=modified_agency_name, **mas.location_kwargs) + amr = try_matching_agency(db_client=mas.tdc.db_client, dto=dto) + assert amr.status == AgencyMatchStatus.PARTIAL + assert len(amr.agencies) == 1 + assert mas.agency_name in amr.agencies[0]["submitted_name"] + + +def test_agency_match_location_match(match_agency_setup: TestMatchAgencySetup): + mas = match_agency_setup + mas.additional_agency(locality_name=mas.location_kwargs["locality"]) + + dto = AgencyMatchInnerDTO(name=get_test_name(), **mas.location_kwargs) + amr = try_matching_agency(db_client=mas.tdc.db_client, dto=dto) + assert amr.status == AgencyMatchStatus.LOCATION + assert len(amr.agencies) == 2 + + +def test_agency_match_no_match(match_agency_setup: TestMatchAgencySetup): + mas = match_agency_setup + + dto = AgencyMatchInnerDTO( + name=get_test_name(), + state="New York", + county="New York", + locality=get_test_name(), + ) + amr = try_matching_agency(db_client=mas.tdc.db_client, dto=dto) + assert amr.status == AgencyMatchStatus.NO_MATCH diff --git a/utilities/namespace.py b/utilities/namespace.py index f6eba507..e2450309 100644 --- a/utilities/namespace.py +++ b/utilities/namespace.py @@ -26,6 +26,7 @@ class AppNamespaces(Enum): ) MAP = NamespaceAttributes(path="map", description="Map Namespace") BATCH = NamespaceAttributes(path="batch", description="Batch Namespace") + MATCH = NamespaceAttributes(path="match", description="Match Namespace") def create_namespace( From 066a60afc63a3ddabcd39d5c875c509245f2612d Mon Sep 17 00:00:00 2001 From: Max Chis Date: Wed, 11 Dec 2024 11:10:05 -0500 Subject: [PATCH 2/4] Continue work on agencies matching endpoint --- resources/endpoint_schema_config.py | 3 ++ ...on_test_functions.py => common_asserts.py} | 24 +++++----- .../complex_test_data_creation_functions.py | 42 ------------------ tests/helper_scripts/helper_functions.py | 15 ------- .../run_and_validate_request.py | 4 +- .../simple_result_validators.py | 24 ---------- tests/integration/test_agencies.py | 2 +- tests/integration/test_api_doc_load.py | 4 +- tests/integration/test_batch.py | 2 +- tests/integration/test_data_sources.py | 2 +- tests/integration/test_github_oauth.py | 2 +- tests/integration/test_login.py | 2 +- tests/integration/test_match.py | 44 +++++++++++++++---- tests/resources/test_Search.py | 6 +-- tests/utilities/test_rate_limiter.py | 6 +-- 15 files changed, 65 insertions(+), 117 deletions(-) rename tests/helper_scripts/{common_test_functions.py => common_asserts.py} (72%) delete mode 100644 tests/helper_scripts/simple_result_validators.py diff --git a/resources/endpoint_schema_config.py b/resources/endpoint_schema_config.py index 929cc2ef..999296e9 100644 --- a/resources/endpoint_schema_config.py +++ b/resources/endpoint_schema_config.py @@ -519,5 +519,8 @@ class SchemaConfigs(Enum): input_dto_class=AgenciesPutDTO, primary_output_schema=BatchPutResponseSchema(), ) + # endregion + # region Match + MATCH_AGENCIES = EndpointSchemaConfig() # endregion diff --git a/tests/helper_scripts/common_test_functions.py b/tests/helper_scripts/common_asserts.py similarity index 72% rename from tests/helper_scripts/common_test_functions.py rename to tests/helper_scripts/common_asserts.py index fe018e66..8f94ea83 100644 --- a/tests/helper_scripts/common_test_functions.py +++ b/tests/helper_scripts/common_asserts.py @@ -4,21 +4,12 @@ from http import HTTPStatus +from flask import Response from flask_jwt_extended import decode_token from database_client.constants import PAGE_SIZE from database_client.database_client import DatabaseClient - -from tests.helper_scripts.simple_result_validators import ( - check_response_status, - assert_is_oauth_redirect_link, -) - - -def assert_expected_pre_callback_response(response): - check_response_status(response, HTTPStatus.FOUND) - response_text = response.text - assert_is_oauth_redirect_link(response_text) +from tests.helper_scripts.constants import TEST_RESPONSE def assert_api_key_exists_for_email(db_client: DatabaseClient, email: str, api_key): @@ -50,3 +41,14 @@ def assert_contains_key_value_pairs( assert key in dict_to_check, f"Expected {key} to be in {dict_to_check}" dict_value = dict_to_check[key] assert dict_value == value, f"Expected {key} to be {value}, was {dict_value}" + + +def assert_is_test_response(response): + assert_response_status(response, TEST_RESPONSE.status_code) + assert response.json == TEST_RESPONSE.response + + +def assert_response_status(response: Response, status_code): + assert ( + response.status_code == status_code + ), f"{response.request.base_url}: Expected status code {status_code}, got {response.status_code}: {response.text}" diff --git a/tests/helper_scripts/complex_test_data_creation_functions.py b/tests/helper_scripts/complex_test_data_creation_functions.py index 100568e1..c3ef00b1 100644 --- a/tests/helper_scripts/complex_test_data_creation_functions.py +++ b/tests/helper_scripts/complex_test_data_creation_functions.py @@ -55,28 +55,6 @@ def insert_test_column_permission_data(db_client: DatabaseClient): pass # Already added -def create_agency_entry_for_search_cache(db_client: DatabaseClient) -> str: - """ - Create an entry in `Agencies` guaranteed to appear in the search cache functionality - :param db_client: - :return: - """ - submitted_name = "TEST SEARCH CACHE NAME" - db_client._create_entry_in_table( - table_name="agencies", - column_value_mappings={ - "submitted_name": submitted_name, - "name": submitted_name, - "airtable_uid": uuid.uuid4().hex[:15], - "count_data_sources": 2000, # AKA, an absurdly high number to guarantee it's the first result - "approved": True, - "homepage_url": None, - "jurisdiction_type": JurisdictionType.FEDERAL.value, - }, - ) - return submitted_name - - def create_data_source_entry_for_url_duplicate_checking( db_client: DatabaseClient, ) -> str: @@ -138,26 +116,6 @@ def create_test_data_request( return TestDataRequestInfo(id=json["id"], submission_notes=submission_notes) -def create_test_agency(flask_client: FlaskClient, jwt_authorization_header: dict): - submitted_name = get_test_name() - locality_name = get_test_name() - sample_agency_post_parameters = get_sample_agency_post_parameters( - submitted_name=submitted_name, - locality_name=locality_name, - jurisdiction_type=JurisdictionType.LOCAL, - ) - - json = run_and_validate_request( - flask_client=flask_client, - http_method="post", - endpoint=AGENCIES_BASE_ENDPOINT, - headers=jwt_authorization_header, - json=sample_agency_post_parameters, - ) - - return TestAgencyInfo(id=json["id"], submitted_name=submitted_name) - - def get_sample_location_info(locality_name: Optional[str] = None) -> dict: if locality_name is None: locality_name = get_test_name() diff --git a/tests/helper_scripts/helper_functions.py b/tests/helper_scripts/helper_functions.py index 40649724..3117e47d 100644 --- a/tests/helper_scripts/helper_functions.py +++ b/tests/helper_scripts/helper_functions.py @@ -1,11 +1,9 @@ """This module contains helper functions used by middleware pytests.""" -import uuid from collections import namedtuple from datetime import datetime, timezone, timedelta from typing import Optional from http import HTTPStatus -from unittest.mock import MagicMock from urllib.parse import urlparse, parse_qs, urlencode, urlunparse import psycopg @@ -15,21 +13,13 @@ from database_client.database_client import DatabaseClient from database_client.db_client_dataclasses import WhereMapping -from middleware.custom_dataclasses import ( - GithubUserInfo, - OAuthCallbackInfo, - FlaskSessionCallbackInfo, -) from middleware.enums import ( - CallbackFunctionsEnum, PermissionsEnum, Relations, JurisdictionType, ) from resources.ApiKeyResource import API_KEY_ROUTE from tests.helper_scripts.common_test_data import get_test_name, get_test_email -from tests.helper_scripts.constants import TEST_RESPONSE -from tests.helper_scripts.simple_result_validators import check_response_status from tests.helper_scripts.helper_classes.TestUserSetup import TestUserSetup from tests.helper_scripts.helper_classes.UserInfo import UserInfo @@ -37,11 +27,6 @@ TestUser = namedtuple("TestUser", ["id", "email", "password_hash"]) -def check_is_test_response(response): - check_response_status(response, TEST_RESPONSE.status_code) - assert response.json == TEST_RESPONSE.response - - def create_test_user_db_client(db_client: DatabaseClient) -> UserInfo: email = get_test_email() password = get_test_name() diff --git a/tests/helper_scripts/run_and_validate_request.py b/tests/helper_scripts/run_and_validate_request.py index 815c9b23..7218008c 100644 --- a/tests/helper_scripts/run_and_validate_request.py +++ b/tests/helper_scripts/run_and_validate_request.py @@ -5,7 +5,7 @@ from marshmallow import Schema from tests.helper_scripts.helper_functions import add_query_params -from tests.helper_scripts.simple_result_validators import check_response_status +from tests.helper_scripts.common_asserts import assert_response_status http_methods = Literal["get", "post", "put", "patch", "delete"] @@ -54,7 +54,7 @@ def run_and_validate_request( ) else: response = flask_client.open(endpoint, method=http_method, **request_kwargs) - check_response_status(response, expected_response_status.value) + assert_response_status(response, expected_response_status.value) if not return_json: return response.data diff --git a/tests/helper_scripts/simple_result_validators.py b/tests/helper_scripts/simple_result_validators.py deleted file mode 100644 index 8b8045b0..00000000 --- a/tests/helper_scripts/simple_result_validators.py +++ /dev/null @@ -1,24 +0,0 @@ -from flask import Response - - -def has_expected_keys(result_keys: list, expected_keys: list) -> bool: - """ - Check that given result includes expected keys. - - :param result: - :param expected_keys: - :return: True if has expected keys, false otherwise - """ - return not set(expected_keys).difference(result_keys) - - -def check_response_status(response: Response, status_code): - assert ( - response.status_code == status_code - ), f"{response.request.base_url}: Expected status code {status_code}, got {response.status_code}: {response.text}" - - -def assert_is_oauth_redirect_link(text: str): - assert "https://github.com/login/oauth/authorize?response_type=code" in text, ( - "Expected OAuth authorize link, got: " + text - ) diff --git a/tests/integration/test_agencies.py b/tests/integration/test_agencies.py index d3d85d97..d9b0631e 100644 --- a/tests/integration/test_agencies.py +++ b/tests/integration/test_agencies.py @@ -26,7 +26,7 @@ ) from tests.helper_scripts.constants import AGENCIES_BASE_ENDPOINT -from tests.helper_scripts.common_test_functions import ( +from tests.helper_scripts.common_asserts import ( assert_expected_get_many_result, assert_contains_key_value_pairs, ) diff --git a/tests/integration/test_api_doc_load.py b/tests/integration/test_api_doc_load.py index 0ba067ec..e1b83f3d 100644 --- a/tests/integration/test_api_doc_load.py +++ b/tests/integration/test_api_doc_load.py @@ -2,7 +2,7 @@ from tests.conftest import dev_db_client, flask_client_with_db from tests.helper_scripts.run_and_validate_request import run_and_validate_request -from tests.helper_scripts.simple_result_validators import check_response_status +from tests.helper_scripts.common_asserts import assert_response_status def test_api_doc_load(flask_client_with_db): @@ -15,5 +15,5 @@ def test_api_doc_load(flask_client_with_db): response = flask_client_with_db.open( "/api/swagger.json", method="get", follow_redirects=True ) - check_response_status(response, HTTPStatus.OK) + assert_response_status(response, HTTPStatus.OK) print(response) diff --git a/tests/integration/test_batch.py b/tests/integration/test_batch.py index a96f0e5e..9ab2f80c 100644 --- a/tests/integration/test_batch.py +++ b/tests/integration/test_batch.py @@ -22,7 +22,7 @@ DataSourcesPutBatchRequestSchema, ) from tests.helper_scripts.common_test_data import get_test_name -from tests.helper_scripts.common_test_functions import assert_contains_key_value_pairs +from tests.helper_scripts.common_asserts import assert_contains_key_value_pairs from tests.helper_scripts.helper_classes.RequestValidator import RequestValidator from tests.helper_scripts.helper_classes.SchemaTestDataGenerator import ( generate_test_data_from_schema, diff --git a/tests/integration/test_data_sources.py b/tests/integration/test_data_sources.py index b79b6757..570f2ae3 100644 --- a/tests/integration/test_data_sources.py +++ b/tests/integration/test_data_sources.py @@ -29,7 +29,7 @@ from tests.helper_scripts.helper_classes.TestDataCreatorFlask import ( TestDataCreatorFlask, ) -from tests.helper_scripts.common_test_functions import assert_contains_key_value_pairs +from tests.helper_scripts.common_asserts import assert_contains_key_value_pairs from tests.helper_scripts.run_and_validate_request import run_and_validate_request from tests.helper_scripts.constants import ( DATA_SOURCES_BASE_ENDPOINT, diff --git a/tests/integration/test_github_oauth.py b/tests/integration/test_github_oauth.py index f7503f74..1bbcd465 100644 --- a/tests/integration/test_github_oauth.py +++ b/tests/integration/test_github_oauth.py @@ -15,7 +15,7 @@ from tests.helper_scripts.helper_classes.TestDataCreatorFlask import ( TestDataCreatorFlask, ) -from tests.helper_scripts.common_test_functions import ( +from tests.helper_scripts.common_asserts import ( assert_jwt_token_matches_user_email, ) from tests.helper_scripts.constants import ( diff --git a/tests/integration/test_login.py b/tests/integration/test_login.py index c7625036..73519949 100644 --- a/tests/integration/test_login.py +++ b/tests/integration/test_login.py @@ -5,7 +5,7 @@ from tests.helper_scripts.helper_classes.TestDataCreatorFlask import ( TestDataCreatorFlask, ) -from tests.helper_scripts.common_test_functions import ( +from tests.helper_scripts.common_asserts import ( assert_jwt_token_matches_user_email, ) from conftest import monkeysession, test_data_creator_flask diff --git a/tests/integration/test_match.py b/tests/integration/test_match.py index 9ed92012..07e5f7e9 100644 --- a/tests/integration/test_match.py +++ b/tests/integration/test_match.py @@ -19,6 +19,19 @@ ) +class TestMatchLocationInfo: + def __init__(self, tdc: TestDataCreatorFlask): + self.tdc = tdc + self.locality_name = get_test_name() + self.locality_id = self.tdc.locality(locality_name=self.locality_name) + self.location_info = get_sample_location_info(locality_name=self.locality_name) + self.location_kwargs = { + "state": "Pennsylvania", + "county": "Allegheny", + "locality": self.locality_name, + } + + @dataclass class TestMatchAgencySetup: tdc: TestDataCreatorFlask @@ -37,17 +50,11 @@ def match_agency_setup( ) -> TestMatchAgencySetup: tdc = test_data_creator_flask tdc.clear_test_data() - locality_name = get_test_name() - locality_id = tdc.locality(locality_name=locality_name) - location_info = get_sample_location_info(locality_name=locality_name) - agency = tdc.agency(location_info=location_info) + loc_info: TestMatchLocationInfo = TestMatchLocationInfo(tdc) + agency = tdc.agency(location_info=loc_info.location_info) return TestMatchAgencySetup( tdc=tdc, - location_kwargs={ - "state": "Pennsylvania", - "county": "Allegheny", - "locality": locality_name, - }, + location_kwargs=loc_info.location_kwargs, agency_name=agency.submitted_name, ) @@ -97,3 +104,22 @@ def test_agency_match_no_match(match_agency_setup: TestMatchAgencySetup): ) amr = try_matching_agency(db_client=mas.tdc.db_client, dto=dto) assert amr.status == AgencyMatchStatus.NO_MATCH + + +# region Test Full Integration + + +def test_agency_match_full_integration(test_data_creator_flask: TestDataCreatorFlask): + + location_1_info = get_sample_location_info() + + # Create a csv of possible agencies + # One an exact match + # One a partial match + # One a location match + # One a no match + + # Submit and confirm json received for each. + + +# endregion diff --git a/tests/resources/test_Search.py b/tests/resources/test_Search.py index b11fd4d2..4c97b5a6 100644 --- a/tests/resources/test_Search.py +++ b/tests/resources/test_Search.py @@ -7,9 +7,7 @@ ) from tests.conftest import client_with_mock_db, bypass_authentication_required from tests.helper_scripts.constants import TEST_RESPONSE -from tests.helper_scripts.helper_functions import ( - check_is_test_response, -) +from tests.helper_scripts.common_asserts import assert_is_test_response from utilities.enums import RecordCategories @@ -80,4 +78,4 @@ def test_search_get_parameters( monkeypatch.setattr("resources.Search.search_wrapper", mock_search_wrapper_function) response = client_with_mock_db.client.get(url) - check_is_test_response(response) + assert_is_test_response(response) diff --git a/tests/utilities/test_rate_limiter.py b/tests/utilities/test_rate_limiter.py index c53d8d09..ccd6609f 100644 --- a/tests/utilities/test_rate_limiter.py +++ b/tests/utilities/test_rate_limiter.py @@ -6,7 +6,7 @@ from config import limiter from tests.conftest import client_with_mock_db, bypass_jwt_required from tests.helper_scripts.constants import TEST_RESPONSE -from tests.helper_scripts.simple_result_validators import check_response_status +from tests.helper_scripts.common_asserts import assert_response_status def post_login_request(client_with_mock_db, ip_address="127.0.0.1"): @@ -55,14 +55,14 @@ def test_rate_limiter_explicit_limit( for i in range(5): response = post_login_request(client_with_mock_db) - check_response_status(response, TEST_RESPONSE.status_code) + assert_response_status(response, TEST_RESPONSE.status_code) response = post_login_request(client_with_mock_db) assert "5 per 1 minute" in response.json["message"] # Test that a different IP address still works response = post_login_request(client_with_mock_db, ip_address="237.84.2.178") - check_response_status(response, TEST_RESPONSE.status_code) + assert_response_status(response, TEST_RESPONSE.status_code) def test_rate_limiter_default_limit( From 93236cd457543c60ab7a99dffdaf24e846cfb63c Mon Sep 17 00:00:00 2001 From: maxachis Date: Wed, 11 Dec 2024 13:22:15 -0500 Subject: [PATCH 3/4] Complete draft of agencies match endpoint --- .../primary_resource_logic/match_logic.py | 36 +++++++---- .../permissions_logic.py | 2 +- .../primary_resource_dtos/match_dtos.py | 4 +- .../primary_resource_schemas/match_schemas.py | 26 ++++++-- resources/Match.py | 11 ++-- resources/endpoint_schema_config.py | 9 ++- .../helper_classes/RequestValidator.py | 24 +++++++ .../helper_classes/TestDataCreatorFlask.py | 2 + tests/integration/test_batch.py | 2 +- tests/integration/test_match.py | 63 +++++++++---------- 10 files changed, 117 insertions(+), 62 deletions(-) diff --git a/middleware/primary_resource_logic/match_logic.py b/middleware/primary_resource_logic/match_logic.py index 6b05cb22..6ad62fa3 100644 --- a/middleware/primary_resource_logic/match_logic.py +++ b/middleware/primary_resource_logic/match_logic.py @@ -1,21 +1,23 @@ from enum import Enum from typing import Optional, List -from pydantic import BaseModel +from flask import Response from database_client.database_client import DatabaseClient from database_client.db_client_dataclasses import WhereMapping +from middleware.flask_response_manager import FlaskResponseManager from middleware.schema_and_dto_logic.primary_resource_dtos.match_dtos import ( AgencyMatchOuterDTO, - AgencyMatchInnerDTO, + AgencyMatchDTO, ) from rapidfuzz import fuzz +from middleware.util import update_if_not_none + class AgencyMatchStatus(Enum): EXACT = "Exact Match" PARTIAL = "Partial Matches" - LOCATION = "Location Only" NO_MATCH = "No Match" @@ -28,8 +30,6 @@ def get_agency_match_message(status: AgencyMatchStatus): return "Exact match found." case AgencyMatchStatus.PARTIAL: return "Partial matches found." - case AgencyMatchStatus.LOCATION: - return "Matches found only on location." case AgencyMatchStatus.NO_MATCH: return "No matches found." @@ -49,13 +49,13 @@ def match_agencies(db_client: DatabaseClient, dto: AgencyMatchOuterDTO): amrs.append(amr) -def try_getting_exact_match_agency(dto: AgencyMatchInnerDTO, agencies: list[dict]): +def try_getting_exact_match_agency(dto: AgencyMatchDTO, agencies: list[dict]): for agency in agencies: if agency["submitted_name"] == dto.name: return agency -def try_getting_partial_match_agencies(dto: AgencyMatchInnerDTO, agencies: list[dict]): +def try_getting_partial_match_agencies(dto: AgencyMatchDTO, agencies: list[dict]): partial_matches = [] for agency in agencies: if fuzz.ratio(dto.name, agency["submitted_name"]) >= SIMILARITY_THRESHOLD: @@ -63,9 +63,23 @@ def try_getting_partial_match_agencies(dto: AgencyMatchInnerDTO, agencies: list[ return partial_matches +def format_response(amr: AgencyMatchResponse) -> Response: + data = { + "status": amr.status.value, + "message": amr.message, + } + update_if_not_none(dict_to_update=data, secondary_dict={"agencies": amr.agencies}) + return FlaskResponseManager.make_response( + data=data, + ) + +def match_agency_wrapper(db_client: DatabaseClient, dto: AgencyMatchOuterDTO): + result = try_matching_agency(db_client=db_client, dto=dto) + return format_response(result) + def try_matching_agency( - db_client: DatabaseClient, dto: AgencyMatchInnerDTO + db_client: DatabaseClient, dto: AgencyMatchDTO ) -> AgencyMatchResponse: location_id = _get_location_id(db_client, dto) @@ -86,11 +100,7 @@ def try_matching_agency( if len(partial_match_agencies) > 0: return _partial_match_response(partial_match_agencies) - return _location_match_response(agencies) - - -def _location_match_response(agencies): - return AgencyMatchResponse(status=AgencyMatchStatus.LOCATION, agencies=agencies) + return _no_match_response() def _partial_match_response(partial_match_agencies): diff --git a/middleware/primary_resource_logic/permissions_logic.py b/middleware/primary_resource_logic/permissions_logic.py index dfdcc4a1..23683832 100644 --- a/middleware/primary_resource_logic/permissions_logic.py +++ b/middleware/primary_resource_logic/permissions_logic.py @@ -75,7 +75,7 @@ def __init__(self, db_client: DatabaseClient, user_email: str): try: user_info = db_client.get_user_info(user_email) except UserNotFoundError: - abort(HTTPStatus.NOT_FOUND, "User not found") + abort(HTTPStatus.BAD_REQUEST, "User not found") return self.db_client = db_client self.user_email = user_email diff --git a/middleware/schema_and_dto_logic/primary_resource_dtos/match_dtos.py b/middleware/schema_and_dto_logic/primary_resource_dtos/match_dtos.py index 9692cbda..f809a69f 100644 --- a/middleware/schema_and_dto_logic/primary_resource_dtos/match_dtos.py +++ b/middleware/schema_and_dto_logic/primary_resource_dtos/match_dtos.py @@ -3,7 +3,7 @@ from pydantic import BaseModel -class AgencyMatchInnerDTO(BaseModel): +class AgencyMatchDTO(BaseModel): name: str state: str county: Optional[str] @@ -11,4 +11,4 @@ class AgencyMatchInnerDTO(BaseModel): class AgencyMatchOuterDTO(BaseModel): - entries: list[AgencyMatchInnerDTO] + entries: list[AgencyMatchDTO] diff --git a/middleware/schema_and_dto_logic/primary_resource_schemas/match_schemas.py b/middleware/schema_and_dto_logic/primary_resource_schemas/match_schemas.py index 22f6794f..c0137a58 100644 --- a/middleware/schema_and_dto_logic/primary_resource_schemas/match_schemas.py +++ b/middleware/schema_and_dto_logic/primary_resource_schemas/match_schemas.py @@ -1,17 +1,33 @@ from marshmallow import Schema, fields +from middleware.primary_resource_logic.match_logic import AgencyMatchStatus +from middleware.schema_and_dto_logic.common_response_schemas import MessageSchema from middleware.schema_and_dto_logic.util import get_json_metadata -class AgencyMatchInnerSchema(Schema): +class AgencyMatchSchema(Schema): name = fields.String(metadata=get_json_metadata("The name of the agency")) state = fields.String(metadata=get_json_metadata("The state of the agency")) county = fields.String(metadata=get_json_metadata("The county of the agency")) locality = fields.String(metadata=get_json_metadata("The locality of the agency")) +class MatchAgenciesResultSchema(Schema): + submitted_name = fields.String(metadata=get_json_metadata("The name of the agency")) + id = fields.Integer(metadata=get_json_metadata("The id of the agency")) -class AgencyMatterOuterSchema(Schema): - entries = fields.List( - fields.Nested(AgencyMatchInnerSchema()), - metadata=get_json_metadata("The proposed agencies to find matches for."), +class MatchAgencyResponseSchema(MessageSchema): + + status = fields.Enum( + enum=AgencyMatchStatus, + by_value=fields.Str, + required=True, + metadata=get_json_metadata("The status of the match") + ) + agencies = fields.List( + fields.Nested( + MatchAgenciesResultSchema(), + metadata=get_json_metadata("The list of results, if any") + ), + required=False, + metadata=get_json_metadata("The list of results, if any") ) diff --git a/resources/Match.py b/resources/Match.py index e5d03576..c41a11f1 100644 --- a/resources/Match.py +++ b/resources/Match.py @@ -2,7 +2,7 @@ from middleware.access_logic import STANDARD_JWT_AUTH_INFO, AccessInfoPrimary from middleware.decorators import endpoint_info -from middleware.primary_resource_logic.match_logic import match_agencies +from middleware.primary_resource_logic.match_logic import try_matching_agency, match_agency_wrapper from resources.PsycopgResource import PsycopgResource from resources.endpoint_schema_config import SchemaConfigs from resources.resource_helpers import ResponseInfo @@ -11,13 +11,13 @@ namespace_match = create_namespace(AppNamespaces.MATCH) -@namespace_match.route("/agencies") +@namespace_match.route("/agency") class MatchAgencies(PsycopgResource): @endpoint_info( namespace=namespace_match, auth_info=STANDARD_JWT_AUTH_INFO, - schema_config=SchemaConfigs.MATCH_AGENCIES, + schema_config=SchemaConfigs.MATCH_AGENCY, response_info=ResponseInfo( success_message="Found any possible matches for the search criteria." ), @@ -25,7 +25,6 @@ class MatchAgencies(PsycopgResource): ) def post(self, access_info: AccessInfoPrimary) -> Response: return self.run_endpoint( - wrapper_function=match_agencies, - schema_populate_parameters=SchemaConfigs.MATCH_AGENCIES.value.get_schema_populate_parameters(), - access_info=access_info, + wrapper_function=match_agency_wrapper, + schema_populate_parameters=SchemaConfigs.MATCH_AGENCY.value.get_schema_populate_parameters(), ) diff --git a/resources/endpoint_schema_config.py b/resources/endpoint_schema_config.py index 999296e9..5805fc85 100644 --- a/resources/endpoint_schema_config.py +++ b/resources/endpoint_schema_config.py @@ -15,6 +15,7 @@ from middleware.schema_and_dto_logic.primary_resource_dtos.batch_dtos import ( BatchRequestDTO, ) +from middleware.schema_and_dto_logic.primary_resource_dtos.match_dtos import AgencyMatchDTO from middleware.schema_and_dto_logic.primary_resource_dtos.reset_token_dtos import ( ResetPasswordDTO, ) @@ -31,6 +32,8 @@ DataSourcesPostBatchRequestSchema, DataSourcesPutBatchRequestSchema, ) +from middleware.schema_and_dto_logic.primary_resource_schemas.match_schemas import AgencyMatchSchema, \ + MatchAgencyResponseSchema from middleware.schema_and_dto_logic.primary_resource_schemas.reset_token_schemas import ( ResetPasswordSchema, ) @@ -521,6 +524,10 @@ class SchemaConfigs(Enum): ) # endregion # region Match - MATCH_AGENCIES = EndpointSchemaConfig() + MATCH_AGENCY = EndpointSchemaConfig( + input_schema=AgencyMatchSchema(), + input_dto_class=AgencyMatchDTO, + primary_output_schema=MatchAgencyResponseSchema(), + ) # endregion diff --git a/tests/helper_scripts/helper_classes/RequestValidator.py b/tests/helper_scripts/helper_classes/RequestValidator.py index 3570d0f7..5c0b4bc6 100644 --- a/tests/helper_scripts/helper_classes/RequestValidator.py +++ b/tests/helper_scripts/helper_classes/RequestValidator.py @@ -538,3 +538,27 @@ def get_data_source_by_id(self, headers: dict, id: int): headers=headers, expected_schema=SchemaConfigs.DATA_SOURCES_GET_BY_ID.value.primary_output_schema, ) + + def match_agency( + self, + headers: dict, + name: str, + state: str, + county: str, + locality: str): + data = { + "name": name, + "state": state, + } + update_if_not_none( + dict_to_update=data, + secondary_dict={ + "county": county, + "locality": locality, + }) + return self.post( + endpoint="/api/match/agency", + headers=headers, + json=data, + expected_schema=SchemaConfigs.MATCH_AGENCY.value.primary_output_schema, + ) diff --git a/tests/helper_scripts/helper_classes/TestDataCreatorFlask.py b/tests/helper_scripts/helper_classes/TestDataCreatorFlask.py index f25edda0..c07a844b 100644 --- a/tests/helper_scripts/helper_classes/TestDataCreatorFlask.py +++ b/tests/helper_scripts/helper_classes/TestDataCreatorFlask.py @@ -76,6 +76,8 @@ def data_source(self, location_info: Optional[dict] = None) -> CreatedDataSource def clear_test_data(self): tdc_db = TestDataCreatorDBClient() tdc_db.clear_test_data() + # Recreate admin user + self.admin_tus = create_admin_test_user_setup(self.flask_client) def data_request( self, user_tus: Optional[TestUserSetup] = None diff --git a/tests/integration/test_batch.py b/tests/integration/test_batch.py index 9ab2f80c..9ace80ca 100644 --- a/tests/integration/test_batch.py +++ b/tests/integration/test_batch.py @@ -6,7 +6,7 @@ from marshmallow import Schema -from conftest import test_data_creator_flask +from conftest import test_data_creator_flask, monkeysession from database_client.enums import LocationType from middleware.primary_resource_logic.batch_logic import listify_strings from middleware.schema_and_dto_logic.common_response_schemas import MessageSchema diff --git a/tests/integration/test_match.py b/tests/integration/test_match.py index 07e5f7e9..bfaf0c58 100644 --- a/tests/integration/test_match.py +++ b/tests/integration/test_match.py @@ -8,7 +8,7 @@ AgencyMatchStatus, ) from middleware.schema_and_dto_logic.primary_resource_dtos.match_dtos import ( - AgencyMatchInnerDTO, + AgencyMatchDTO, ) from tests.helper_scripts.common_test_data import get_test_name from tests.helper_scripts.complex_test_data_creation_functions import ( @@ -37,6 +37,7 @@ class TestMatchAgencySetup: tdc: TestDataCreatorFlask location_kwargs: dict agency_name: str + jwt_authorization_header: dict def additional_agency(self, locality_name: Optional[str] = None): return self.tdc.agency( @@ -44,7 +45,7 @@ def additional_agency(self, locality_name: Optional[str] = None): ) -@pytest.fixture +@pytest.fixture() def match_agency_setup( test_data_creator_flask: TestDataCreatorFlask, ) -> TestMatchAgencySetup: @@ -56,6 +57,7 @@ def match_agency_setup( tdc=tdc, location_kwargs=loc_info.location_kwargs, agency_name=agency.submitted_name, + jwt_authorization_header=tdc.get_admin_tus().jwt_authorization_header, ) @@ -64,9 +66,14 @@ def test_agency_match_exact_match( ): mas = match_agency_setup - dto = AgencyMatchInnerDTO(name=mas.agency_name, **mas.location_kwargs) - amr = try_matching_agency(db_client=mas.tdc.db_client, dto=dto) - assert amr.status == AgencyMatchStatus.EXACT + data = mas.tdc.request_validator.match_agency( + headers=mas.jwt_authorization_header, + name=mas.agency_name, + **mas.location_kwargs + ) + + assert data["status"] == AgencyMatchStatus.EXACT.value + assert mas.agency_name in data["agencies"][0]["submitted_name"] def test_agency_match_partial_match(match_agency_setup: TestMatchAgencySetup): @@ -76,50 +83,40 @@ def test_agency_match_partial_match(match_agency_setup: TestMatchAgencySetup): # This should not be picked up. mas.additional_agency() modified_agency_name = mas.agency_name + "1" - dto = AgencyMatchInnerDTO(name=modified_agency_name, **mas.location_kwargs) - amr = try_matching_agency(db_client=mas.tdc.db_client, dto=dto) - assert amr.status == AgencyMatchStatus.PARTIAL - assert len(amr.agencies) == 1 - assert mas.agency_name in amr.agencies[0]["submitted_name"] + data = mas.tdc.request_validator.match_agency( + headers=mas.jwt_authorization_header, + name=modified_agency_name, + **mas.location_kwargs + ) + + assert data["status"] == AgencyMatchStatus.PARTIAL.value + assert len(data["agencies"]) == 1 + assert mas.agency_name in data["agencies"][0]["submitted_name"] def test_agency_match_location_match(match_agency_setup: TestMatchAgencySetup): mas = match_agency_setup mas.additional_agency(locality_name=mas.location_kwargs["locality"]) - dto = AgencyMatchInnerDTO(name=get_test_name(), **mas.location_kwargs) - amr = try_matching_agency(db_client=mas.tdc.db_client, dto=dto) - assert amr.status == AgencyMatchStatus.LOCATION - assert len(amr.agencies) == 2 + data = mas.tdc.request_validator.match_agency( + headers=mas.jwt_authorization_header, + name=get_test_name(), + **mas.location_kwargs + ) + assert data["status"] == AgencyMatchStatus.NO_MATCH.value def test_agency_match_no_match(match_agency_setup: TestMatchAgencySetup): mas = match_agency_setup - - dto = AgencyMatchInnerDTO( + data = mas.tdc.request_validator.match_agency( + headers=mas.jwt_authorization_header, name=get_test_name(), state="New York", county="New York", locality=get_test_name(), ) - amr = try_matching_agency(db_client=mas.tdc.db_client, dto=dto) - assert amr.status == AgencyMatchStatus.NO_MATCH -# region Test Full Integration - - -def test_agency_match_full_integration(test_data_creator_flask: TestDataCreatorFlask): - - location_1_info = get_sample_location_info() - - # Create a csv of possible agencies - # One an exact match - # One a partial match - # One a location match - # One a no match - - # Submit and confirm json received for each. +# region Test Full Integration -# endregion From 528562d9db0f9ade6ecf86e3e3e5c303ce1007e9 Mon Sep 17 00:00:00 2001 From: maxachis Date: Wed, 11 Dec 2024 13:44:07 -0500 Subject: [PATCH 4/4] Fix test_permissions_manager_init_user_not_found --- tests/middleware/test_permissions_logic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/middleware/test_permissions_logic.py b/tests/middleware/test_permissions_logic.py index ec91e933..ef2cfe1d 100644 --- a/tests/middleware/test_permissions_logic.py +++ b/tests/middleware/test_permissions_logic.py @@ -36,7 +36,7 @@ def test_permissions_manager_init_user_not_found(mock): mock.db_client.get_user_info.side_effect = UserNotFoundError("User not found") PermissionsManager(mock.db_client, mock.user_email) mock.db_client.get_user_info.assert_called_once_with(mock.user_email) - mock.abort.assert_called_once_with(HTTPStatus.NOT_FOUND, "User not found") + mock.abort.assert_called_once_with(HTTPStatus.BAD_REQUEST, "User not found") mock.db_client.get_user_permissions.assert_not_called()