diff --git a/middleware/dynamic_request_logic/get_related_resource_logic.py b/middleware/dynamic_request_logic/get_related_resource_logic.py index c5f42a00..5bfe1664 100644 --- a/middleware/dynamic_request_logic/get_related_resource_logic.py +++ b/middleware/dynamic_request_logic/get_related_resource_logic.py @@ -18,13 +18,13 @@ @dataclass class GetRelatedResourcesParameters: - db_client: DatabaseClient dto: GetByIDBaseDTO db_client_method: callable primary_relation: Relations related_relation: Relations linking_column: str metadata_count_name: str + db_client: DatabaseClient = DatabaseClient() resource_name: str = "resource" diff --git a/middleware/enums.py b/middleware/enums.py index f661b40d..f39bdb8d 100644 --- a/middleware/enums.py +++ b/middleware/enums.py @@ -42,6 +42,11 @@ class AccessTypeEnum(Enum): NO_AUTH = auto() +class OutputFormatEnum(Enum): + JSON = "json" + CSV = "csv" + + class Relations(Enum): """ A list of valid relations for the database diff --git a/middleware/location_logic.py b/middleware/location_logic.py index 1b9105d0..fca1e616 100644 --- a/middleware/location_logic.py +++ b/middleware/location_logic.py @@ -29,8 +29,7 @@ def get_location_id( # In the case of a nonexistent locality, this can be added, # provided the rest of the location is valid - county_id = _get_county_id(db_client, location_info_dict) - + county_id = _get_county_id(location_info_dict) # If this exists, locality does not yet exist in database and should be added. Add and return location id db_client.create_locality( column_value_mappings={ @@ -47,13 +46,13 @@ def _raise_if_not_locality(location_info, location_info_dict): raise InvalidLocationError(f"{location_info_dict} is not a valid location") -def _get_county_id(db_client, location_info_dict) -> int: +def _get_county_id(location_info_dict: dict) -> int: county_dict = { "county_fips": location_info_dict["county_fips"], "state_iso": location_info_dict["state_iso"], "type": LocationType.COUNTY, } - results = db_client._select_from_relation( + results = DatabaseClient()._select_from_relation( relation_name=Relations.LOCATIONS_EXPANDED.value, columns=["county_id"], where_mappings=WhereMapping.from_dict(county_dict), diff --git a/middleware/primary_resource_logic/agencies.py b/middleware/primary_resource_logic/agencies.py index 5d65d7e4..af74d504 100644 --- a/middleware/primary_resource_logic/agencies.py +++ b/middleware/primary_resource_logic/agencies.py @@ -47,7 +47,6 @@ def get_agencies( """ return get_many( middleware_parameters=MiddlewareParameters( - db_client=db_client, access_info=access_info, entry_name="agencies", relation=Relations.AGENCIES_EXPANDED.value, @@ -69,7 +68,6 @@ def get_agency_by_id( ) -> Response: return get_by_id( middleware_parameters=MiddlewareParameters( - db_client=db_client, access_info=access_info, entry_name="agency", relation=Relations.AGENCIES_EXPANDED.value, @@ -151,11 +149,7 @@ def create_agency( ) return post_entry( - middleware_parameters=MiddlewareParameters( - entry_name="agency", - relation=Relations.AGENCIES.value, - db_client_method=DatabaseClient.create_agency, - ), + middleware_parameters=AGENCY_POST_MIDDLEWARE_PARAMETERS, entry=entry_data, pre_insertion_function_with_parameters=pre_insertion_function, check_for_permission=False, @@ -209,7 +203,6 @@ def update_agency( return put_entry( middleware_parameters=MiddlewareParameters( - db_client=db_client, access_info=access_info, entry_name="agency", relation=Relations.AGENCIES.value, @@ -226,7 +219,6 @@ def delete_agency( ) -> Response: return delete_entry( middleware_parameters=MiddlewareParameters( - db_client=db_client, access_info=access_info, entry_name="agency", relation=Relations.AGENCIES.value, diff --git a/middleware/primary_resource_logic/batch_logic.py b/middleware/primary_resource_logic/batch_logic.py index 4fb23029..15a159ec 100644 --- a/middleware/primary_resource_logic/batch_logic.py +++ b/middleware/primary_resource_logic/batch_logic.py @@ -3,6 +3,7 @@ from io import BytesIO from marshmallow import Schema, ValidationError +from werkzeug.datastructures import FileStorage from database_client.database_client import DatabaseClient from middleware.dynamic_request_logic.supporting_classes import ( @@ -25,14 +26,14 @@ from middleware.schema_and_dto_logic.dynamic_logic.dynamic_schema_request_content_population import ( setup_dto_class, ) -from middleware.schema_and_dto_logic.primary_resource_dtos.agencies_dtos import ( - AgenciesPostDTO, -) + from middleware.schema_and_dto_logic.primary_resource_dtos.batch_dtos import ( BatchRequestDTO, ) from csv import DictReader +from middleware.util import bytes_to_text_iter, read_from_csv + def replace_empty_strings_with_none(row: dict): for key, value in row.items(): @@ -48,12 +49,9 @@ def _get_raw_rows_from_csv( return raw_rows -def _get_raw_rows(file: BytesIO): +def _get_raw_rows(file: FileStorage): try: - text_file = (line.decode("utf-8") for line in file) - reader = DictReader(text_file) - rows = list(reader) - return rows + return read_from_csv(file) except Exception as e: FlaskResponseManager.abort( code=HTTPStatus.BAD_REQUEST, message=f"Error reading csv file: {e}" diff --git a/middleware/primary_resource_logic/data_requests.py b/middleware/primary_resource_logic/data_requests.py index 3a609aea..5799aba2 100644 --- a/middleware/primary_resource_logic/data_requests.py +++ b/middleware/primary_resource_logic/data_requests.py @@ -194,7 +194,6 @@ def get_data_requests_wrapper( } return get_many( middleware_parameters=MiddlewareParameters( - db_client=db_client, access_info=access_info, entry_name="data requests", relation=Relations.DATA_REQUESTS_EXPANDED.value, @@ -251,7 +250,6 @@ def delete_data_request_wrapper( """ return delete_entry( middleware_parameters=MiddlewareParameters( - db_client=db_client, access_info=access_info, entry_name="Data request", relation=RELATION, @@ -299,7 +297,6 @@ def update_data_request_wrapper( entry_dict = created_filtered_entry_dict(dto) return put_entry( middleware_parameters=MiddlewareParameters( - db_client=db_client, access_info=access_info, entry_name="Data request", relation=RELATION, @@ -366,7 +363,6 @@ def get_data_request_related_sources(db_client: DatabaseClient, dto: GetByIDBase return get_related_resource( get_related_resources_parameters=GetRelatedResourcesParameters( - db_client=db_client, dto=dto, db_client_method=DatabaseClient.get_data_requests, primary_relation=Relations.DATA_REQUESTS, @@ -383,7 +379,6 @@ def get_data_request_related_locations( ) -> Response: return get_related_resource( get_related_resources_parameters=GetRelatedResourcesParameters( - db_client=db_client, dto=dto, db_client_method=DatabaseClient.get_data_requests, primary_relation=Relations.DATA_REQUESTS, @@ -435,7 +430,6 @@ def create_data_request_related_source( ): post_logic = CreateDataRequestRelatedSourceLogic( middleware_parameters=MiddlewareParameters( - db_client=db_client, access_info=access_info, entry_name="Request-Source association", relation=RELATED_SOURCES_RELATION, @@ -458,7 +452,6 @@ def delete_data_request_related_source( ): return delete_entry( middleware_parameters=MiddlewareParameters( - db_client=db_client, access_info=access_info, entry_name="Request-Source association", relation=RELATED_SOURCES_RELATION, @@ -483,7 +476,6 @@ def create_data_request_related_location( ): post_logic = CreateDataRequestRelatedLocationLogic( middleware_parameters=MiddlewareParameters( - db_client=db_client, access_info=access_info, entry_name="Request-Location association", relation=Relations.LINK_LOCATIONS_DATA_REQUESTS.value, @@ -508,7 +500,6 @@ def delete_data_request_related_location( ): return delete_entry( middleware_parameters=MiddlewareParameters( - db_client=db_client, access_info=access_info, entry_name="Request-Location association", relation=Relations.LINK_LOCATIONS_DATA_REQUESTS.value, diff --git a/middleware/primary_resource_logic/data_sources_logic.py b/middleware/primary_resource_logic/data_sources_logic.py index 014f0e4c..b9411a2e 100644 --- a/middleware/primary_resource_logic/data_sources_logic.py +++ b/middleware/primary_resource_logic/data_sources_logic.py @@ -102,7 +102,6 @@ def data_source_by_id_wrapper( access_info=access_info, relation=Relations.DATA_SOURCES_EXPANDED.value, db_client_method=DatabaseClient.get_data_sources, - db_client=db_client, entry_name="data source", subquery_parameters=SUBQUERY_PARAMS, ), @@ -132,7 +131,6 @@ def delete_data_source_wrapper( access_info=access_info, relation=RELATION, db_client_method=DatabaseClient.delete_data_source, - db_client=db_client, entry_name="data source", ), id_info=IDInfo( @@ -160,7 +158,6 @@ def update_data_source_wrapper( optionally_add_last_approval_editor(entry_data, access_info) return put_entry( middleware_parameters=MiddlewareParameters( - db_client=db_client, entry_name="Data source", relation=RELATION, db_client_method=DatabaseClient.update_data_source, @@ -305,7 +302,6 @@ def create_data_source_related_agency( ) -> Response: post_logic = CreateDataSourceRelatedAgenciesLogic( middleware_parameters=MiddlewareParameters( - db_client=db_client, access_info=access_info, entry_name="Data source-agency association", relation=RELATION, @@ -324,7 +320,6 @@ def delete_data_source_related_agency( ) -> Response: return delete_entry( middleware_parameters=MiddlewareParameters( - db_client=db_client, access_info=access_info, entry_name="Data source-agency association", relation=Relations.LINK_AGENCIES_DATA_SOURCES.value, diff --git a/middleware/primary_resource_logic/search_logic.py b/middleware/primary_resource_logic/search_logic.py index 36d4f5d0..bdeb7ec9 100644 --- a/middleware/primary_resource_logic/search_logic.py +++ b/middleware/primary_resource_logic/search_logic.py @@ -1,7 +1,10 @@ +from csv import DictWriter from http import HTTPStatus +from io import BytesIO, StringIO from typing import Optional -from flask import Response, make_response +from flask import Response, make_response, send_file +from pydantic import BaseModel from database_client.database_client import DatabaseClient from database_client.db_client_dataclasses import WhereMapping @@ -12,12 +15,13 @@ MiddlewareParameters, IDInfo, ) -from middleware.enums import JurisdictionSimplified, Relations +from middleware.enums import JurisdictionSimplified, Relations, OutputFormatEnum from middleware.flask_response_manager import FlaskResponseManager from middleware.schema_and_dto_logic.primary_resource_schemas.search_schemas import ( SearchRequests, ) from middleware.common_response_formatting import message_response +from middleware.util import get_datetime_now, write_to_csv, find_root_directory from utilities.enums import RecordCategories @@ -68,57 +72,99 @@ def format_search_results(search_results: list[dict]) -> dict: response = {"count": 0, "data": {}} + data = response["data"] # Create sub-dictionary for each jurisdiction for jurisdiction in [j.value for j in JurisdictionSimplified]: - response["data"][jurisdiction] = {"count": 0, "results": []} + data[jurisdiction] = {"count": 0, "results": []} for result in search_results: jurisdiction_str = result.get("jurisdiction_type") jurisdiction = get_jurisdiction_type_enum(jurisdiction_str) - response["data"][jurisdiction.value]["count"] += 1 - response["data"][jurisdiction.value]["results"].append(result) + data[jurisdiction.value]["count"] += 1 + data[jurisdiction.value]["results"].append(result) response["count"] += 1 return response +def format_as_csv(ld: list[dict]) -> BytesIO: + string_output = StringIO() + writer = DictWriter(string_output, fieldnames=list(ld[0].keys())) + writer.writeheader() + writer.writerows(ld) + string_output.seek(0) + bytes_output = string_output.getvalue().encode("utf-8") + return BytesIO(bytes_output) + + def search_wrapper( db_client: DatabaseClient, access_info: AccessInfoPrimary, dto: SearchRequests, ) -> Response: + create_search_record(access_info, db_client, dto) + explicit_record_categories = get_explicit_record_categories(dto.record_categories) + search_results = db_client.search_with_location_and_record_type( + state=dto.state, + # Pass modified record categories, which breaks down ALL into individual categories + record_categories=explicit_record_categories, + county=dto.county, + locality=dto.locality, + ) + return send_search_results( + search_results=search_results, + output_format=dto.output_format, + ) + + +def create_search_record(access_info, db_client, dto): location_id = try_getting_location_id_and_raise_error_if_not_found( db_client=db_client, dto=dto, ) - explicit_record_categories = get_explicit_record_categories(dto.record_categories) db_client.create_search_record( user_id=access_info.get_user_id(), location_id=location_id, # Pass originally provided record categories record_categories=dto.record_categories, ) - search_results = db_client.search_with_location_and_record_type( - state=dto.state, - # Pass modified record categories, which breaks down ALL into individual categories - record_categories=explicit_record_categories, - county=dto.county, - locality=dto.locality, - ) + + +def send_search_results(search_results: list[dict], output_format: OutputFormatEnum): + if output_format == OutputFormatEnum.JSON: + return send_as_json(search_results) + elif output_format == OutputFormatEnum.CSV: + return send_as_csv(search_results) + else: + FlaskResponseManager.abort( + message="Invalid output format.", + code=HTTPStatus.BAD_REQUEST, + ) + + +def send_as_json(search_results): formatted_search_results = format_search_results(search_results) return make_response(formatted_search_results, HTTPStatus.OK) +def send_as_csv(search_results): + filename = f"search_results-{get_datetime_now()}.csv" + csv_stream = format_as_csv(ld=search_results) + return send_file( + csv_stream, download_name=filename, mimetype="text/csv", as_attachment=True + ) + + def get_explicit_record_categories( record_categories=list[RecordCategories], ) -> list[RecordCategories]: - if record_categories == [RecordCategories.ALL]: + if RecordCategories.ALL in record_categories: + if len(record_categories) > 1: + FlaskResponseManager.abort( + message="ALL cannot be provided with other record categories.", + code=HTTPStatus.BAD_REQUEST, + ) return [rc for rc in RecordCategories if rc != RecordCategories.ALL] - elif len(record_categories) > 1 and RecordCategories.ALL in record_categories: - FlaskResponseManager.abort( - message="ALL cannot be provided with other record categories.", - code=HTTPStatus.BAD_REQUEST, - ) return record_categories @@ -181,12 +227,25 @@ def make_response(self) -> Response: ) -def create_followed_search( +def get_link_id_and_raise_error_if_not_found( + db_client: DatabaseClient, access_info: AccessInfoPrimary, dto: SearchRequests +): + location_id = try_getting_location_id_and_raise_error_if_not_found( + db_client=db_client, + dto=dto, + ) + return get_user_followed_search_link( + db_client=db_client, + access_info=access_info, + location_id=location_id, + ) + + +def get_location_link_and_raise_error_if_not_found( db_client: DatabaseClient, access_info: AccessInfoPrimary, dto: SearchRequests, -) -> Response: - # Get location id. If not found, not a valid location. Raise error +): location_id = try_getting_location_id_and_raise_error_if_not_found( db_client=db_client, dto=dto, @@ -196,14 +255,30 @@ def create_followed_search( access_info=access_info, location_id=location_id, ) - if link_id is not None: + return LocationLink(link_id=link_id, location_id=location_id) + + +class LocationLink(BaseModel): + link_id: Optional[int] + location_id: int + + +def create_followed_search( + db_client: DatabaseClient, + access_info: AccessInfoPrimary, + dto: SearchRequests, +) -> Response: + # Get location id. If not found, not a valid location. Raise error + location_link = get_location_link_and_raise_error_if_not_found( + db_client=db_client, access_info=access_info, dto=dto + ) + if location_link.link_id is not None: return message_response( message="Location already followed.", ) return post_entry( middleware_parameters=MiddlewareParameters( - db_client=db_client, entry_name="followed search", relation=Relations.LINK_USER_FOLLOWED_LOCATION.value, db_client_method=DatabaseClient.create_followed_search, @@ -211,7 +286,7 @@ def create_followed_search( ), entry={ "user_id": access_info.get_user_id(), - "location_id": location_id, + "location_id": location_link.location_id, }, check_for_permission=False, post_logic_class=FollowedSearchPostLogic, @@ -224,24 +299,17 @@ def delete_followed_search( dto: SearchRequests, ) -> Response: # Get location id. If not found, not a valid location. Raise error - location_id = try_getting_location_id_and_raise_error_if_not_found( - db_client=db_client, - dto=dto, - ) - link_id = get_user_followed_search_link( - db_client=db_client, - access_info=access_info, - location_id=location_id, + location_link = get_location_link_and_raise_error_if_not_found( + db_client=db_client, access_info=access_info, dto=dto ) # Check if search is followed. If not, end early . - if link_id is None: + if location_link.link_id is None: return message_response( message="Location not followed.", ) return delete_entry( middleware_parameters=MiddlewareParameters( - db_client=db_client, entry_name="Followed search", relation=Relations.LINK_USER_FOLLOWED_LOCATION.value, db_client_method=DatabaseClient.delete_followed_search, @@ -249,6 +317,6 @@ def delete_followed_search( ), id_info=IDInfo( id_column_name="id", - id_column_value=link_id, + id_column_value=location_link.link_id, ), ) diff --git a/middleware/schema_and_dto_logic/primary_resource_schemas/search_schemas.py b/middleware/schema_and_dto_logic/primary_resource_schemas/search_schemas.py index da34f316..4e504ecb 100644 --- a/middleware/schema_and_dto_logic/primary_resource_schemas/search_schemas.py +++ b/middleware/schema_and_dto_logic/primary_resource_schemas/search_schemas.py @@ -4,6 +4,7 @@ from marshmallow import Schema, fields, validates_schema, ValidationError from pydantic import BaseModel +from middleware.enums import OutputFormatEnum from middleware.schema_and_dto_logic.schema_helpers import create_get_many_schema from middleware.schema_and_dto_logic.util import get_json_metadata from utilities.common import get_enums_from_string @@ -55,6 +56,17 @@ class SearchRequestSchema(Schema): "location": ParserLocation.QUERY.value, }, ) + output_format = fields.Enum( + required=False, + enum=OutputFormatEnum, + by_value=fields.Str, + load_default=OutputFormatEnum.JSON.value, + metadata={ + "description": "The output format of the search.", + "source": SourceMappingEnum.QUERY_ARGS, + "location": ParserLocation.QUERY.value, + }, + ) @validates_schema def validate_location_info(self, data, **kwargs): @@ -203,3 +215,4 @@ class SearchRequests(BaseModel): record_categories: Optional[list[RecordCategories]] = None county: Optional[str] = None locality: Optional[str] = None + output_format: Optional[OutputFormatEnum] = None diff --git a/middleware/util.py b/middleware/util.py index 052e1ba7..1367a2f6 100644 --- a/middleware/util.py +++ b/middleware/util.py @@ -1,10 +1,14 @@ +import csv import os from dataclasses import is_dataclass, asdict +from datetime import datetime from enum import Enum -from typing import Any, Dict +from io import BytesIO, StringIO +from typing import Any, Dict, TextIO, Generator from dotenv import dotenv_values, find_dotenv from pydantic import BaseModel +from werkzeug.datastructures import FileStorage def get_env_variable(name: str) -> str: @@ -24,6 +28,10 @@ def get_env_variable(name: str) -> str: return value +def get_datetime_now() -> str: + return datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + + def create_web_app_url(endpoint: str) -> str: return f"{get_env_variable('VITE_VUE_APP_BASE_URL')}/{endpoint}" @@ -32,6 +40,32 @@ def get_enum_values(en: type[Enum]) -> list[str]: return [e.value for e in en] +def write_to_csv(file_path: str, data: list[dict[str, Any]], fieldnames: list[str]): + with open(file_path, "w+", newline="", encoding="utf-8") as f: + writer = csv.DictWriter(f, fieldnames=fieldnames) + writer.writeheader() + writer.writerows(data) + f.close() + + +def bytes_to_text_iter(file: BytesIO | FileStorage) -> Generator[str, Any, None]: + """ + Convert BytesIO file to text iterator + """ + return (line.decode("utf-8") for line in file) + + +def read_from_csv(file: str | FileStorage | bytes) -> list[dict[str, Any]]: + if isinstance(file, FileStorage): + file = bytes_to_text_iter(file) + elif isinstance(file, str): + file = open(file, "r", newline="", encoding="utf-8") + elif isinstance(file, bytes): + content = file.decode("utf-8") + file = StringIO(content) + return list(csv.DictReader(file)) + + def dict_enums_to_values(d: dict[str, Any]) -> dict[str, Any]: """ Convert enums within a dictionary to their values. @@ -50,7 +84,7 @@ def dataclass_to_filtered_dict(instance: Any) -> Dict[str, Any]: """ if is_dataclass(instance): d = asdict(instance) - if isinstance(instance, BaseModel): + elif isinstance(instance, BaseModel): d = dict(instance) else: raise TypeError( @@ -73,3 +107,38 @@ def update_if_not_none(dict_to_update: Dict[str, Any], secondary_dict: Dict[str, for key, value in secondary_dict.items(): if value is not None: dict_to_update[key] = value + + +def find_root_directory(start_path=None, target_file="app.py"): + """ + Travels upward from the given starting directory (or current working directory) + until it finds the directory containing the specified target file. + + Parameters: + - start_path (str): The directory to start searching from. Defaults to the current working directory. + - target_file (str): The file that identifies the root directory. Defaults to 'app.py'. + + Returns: + - str: The absolute path to the root directory containing the target file. + + Raises: + - FileNotFoundError: If the target file is not found in any parent directory. + """ + current_path = os.path.abspath(start_path or os.getcwd()) + + while True: + if os.path.isfile(os.path.join(current_path, target_file)): + return current_path + + parent_path = os.path.dirname(current_path) + if current_path == parent_path: # Reached the root of the filesystem + break + current_path = parent_path + + raise FileNotFoundError(f"Could not find '{target_file}' in any parent directory.") + + +def get_temp_directory() -> str: + # Go to root directory + + return os.path.join(os.getcwd(), "temp") diff --git a/tests/helper_scripts/helper_classes/RequestValidator.py b/tests/helper_scripts/helper_classes/RequestValidator.py index 4c7682e1..3570d0f7 100644 --- a/tests/helper_scripts/helper_classes/RequestValidator.py +++ b/tests/helper_scripts/helper_classes/RequestValidator.py @@ -11,6 +11,7 @@ from marshmallow import Schema from database_client.enums import SortOrder, RequestStatus +from middleware.enums import OutputFormatEnum from middleware.util import update_if_not_none from resources.endpoint_schema_config import SchemaConfigs from tests.helper_scripts.constants import ( @@ -271,6 +272,7 @@ def search( record_categories: Optional[list[RecordCategories]] = None, county: Optional[str] = None, locality: Optional[str] = None, + format: Optional[OutputFormatEnum] = OutputFormatEnum.JSON, ): endpoint_base = "/search/search-location-and-record-type" query_params = self._get_search_query_params( @@ -279,14 +281,17 @@ def search( record_categories=record_categories, state=state, ) + query_params.update({} if format is None else {"output_format": format.value}) endpoint = add_query_params( url=endpoint_base, params=query_params, ) + kwargs = {"return_json": True if format == OutputFormatEnum.JSON else False} return self.get( endpoint=endpoint, headers=headers, expected_schema=SchemaConfigs.SEARCH_LOCATION_AND_RECORD_TYPE_GET.value.primary_output_schema, + **kwargs, ) @staticmethod diff --git a/tests/helper_scripts/run_and_validate_request.py b/tests/helper_scripts/run_and_validate_request.py index d75d6a04..815c9b23 100644 --- a/tests/helper_scripts/run_and_validate_request.py +++ b/tests/helper_scripts/run_and_validate_request.py @@ -19,6 +19,7 @@ def run_and_validate_request( expected_schema: Optional[Union[Type[Schema], Schema]] = None, query_parameters: Optional[dict] = None, file: Optional[TextIO] = None, + return_json: bool = True, **request_kwargs, ): """ @@ -30,6 +31,8 @@ def run_and_validate_request( :param expected_json_content: The expected json content of the response :param expected_schema: An optional Marshmallow schema to validate the response against :param query_parameters: Query parameters, if any, to add to the endpoint + :param file: The file to send, if any + :param return_json: Whether to return the json content of the response, or the raw response data :param request_kwargs: Additional keyword arguments to add to the request :return: The json content of the response """ @@ -53,6 +56,9 @@ def run_and_validate_request( response = flask_client.open(endpoint, method=http_method, **request_kwargs) check_response_status(response, expected_response_status.value) + if not return_json: + return response.data + # All of our requests should return some json message providing information. assert response.json is not None diff --git a/tests/integration/test_search.py b/tests/integration/test_search.py index acf1cd72..8a9663de 100644 --- a/tests/integration/test_search.py +++ b/tests/integration/test_search.py @@ -1,8 +1,12 @@ +import csv from http import HTTPStatus from typing import Optional from marshmallow import Schema +from database_client.enums import LocationType +from middleware.enums import OutputFormatEnum, JurisdictionSimplified +from middleware.util import bytes_to_text_iter, read_from_csv, get_enum_values from resources.endpoint_schema_config import SchemaConfigs from tests.helper_scripts.helper_classes.TestDataCreatorFlask import ( TestDataCreatorFlask, @@ -19,36 +23,42 @@ run_and_validate_request, http_methods, ) -from tests.conftest import flask_client_with_db, bypass_api_key_required from conftest import test_data_creator_flask, monkeysession from utilities.enums import RecordCategories ENDPOINT_SEARCH_LOCATION_AND_RECORD_TYPE = "/search/search-location-and-record-type" -def test_search_get( - test_data_creator_flask: TestDataCreatorFlask, bypass_api_key_required -): +TEST_STATE = "Pennsylvania" +TEST_COUNTY = "Allegheny" +TEST_LOCALITY = "Pittsburgh" + + +def test_search_get(test_data_creator_flask: TestDataCreatorFlask): tdc = test_data_creator_flask tus = tdc.standard_user() - data = tdc.request_validator.search( - headers=tus.api_authorization_header, - state="Pennsylvania", - county="Allegheny", - locality="Pittsburgh", - record_categories=[RecordCategories.POLICE], - ) + def search(record_format: Optional[OutputFormatEnum] = OutputFormatEnum.JSON): + return tdc.request_validator.search( + headers=tus.api_authorization_header, + state=TEST_STATE, + county=TEST_COUNTY, + locality=TEST_LOCALITY, + record_categories=[RecordCategories.POLICE], + format=record_format, + ) + + json_data = search() - jurisdictions = ["federal", "state", "county", "locality"] + jurisdictions = get_enum_values(JurisdictionSimplified) - assert data["count"] > 0 + assert json_data["count"] > 0 jurisdiction_count = 0 for jurisdiction in jurisdictions: - jurisdiction_count += data["data"][jurisdiction]["count"] + jurisdiction_count += json_data["data"][jurisdiction]["count"] - assert jurisdiction_count == data["count"] + assert jurisdiction_count == json_data["count"] # Check that search shows up in user's recent searches data = run_and_validate_request( @@ -63,15 +73,38 @@ def test_search_get( assert data["data"][0] == { "state_iso": "PA", - "county_name": "Allegheny", - "locality_name": "Pittsburgh", - "location_type": "Locality", - "record_categories": ["Police & Public Interactions"], + "county_name": TEST_COUNTY, + "locality_name": TEST_LOCALITY, + "location_type": LocationType.LOCALITY.value, + "record_categories": [RecordCategories.POLICE.value], } + csv_data = search(record_format=OutputFormatEnum.CSV) + + results = read_from_csv(csv_data) + + assert len(results) == json_data["count"] + + # Flatten json data for comparison + flat_json_data = [] + for jurisdiction in jurisdictions: + if json_data["data"][jurisdiction]["count"] == 0: + continue + for result in json_data["data"][jurisdiction]["results"]: + flat_json_data.append(result) + + # Sort both the flat json data and the csv results for comparison + # Due to differences in how CSV and JSON results are formatted, compare only ids + json_ids = sorted([result["id"] for result in flat_json_data]) + csv_ids = sorted( + [int(result["id"]) for result in results] + ) # CSV ids are formatted as strings + + assert json_ids == csv_ids + def test_search_get_record_categories_all( - test_data_creator_flask: TestDataCreatorFlask, bypass_api_key_required + test_data_creator_flask: TestDataCreatorFlask, ): """ All record categories can be provided in one of two ways: @@ -84,9 +117,9 @@ def test_search_get_record_categories_all( def run_search(record_categories: list[RecordCategories]) -> dict: return tdc.request_validator.search( headers=tus.api_authorization_header, - state="Pennsylvania", - county="Allegheny", - locality="Pittsburgh", + state=TEST_STATE, + county=TEST_COUNTY, + locality=TEST_LOCALITY, record_categories=record_categories if len(record_categories) > 0 else None, ) @@ -113,9 +146,9 @@ def test_search_follow(test_data_creator_flask): tus_1 = tdc.standard_user() location_to_follow = { - "state": "Pennsylvania", - "county": "Allegheny", - "locality": "Pittsburgh", + "state": TEST_STATE, + "county": TEST_COUNTY, + "locality": TEST_LOCALITY, } url_for_following = add_query_params( SEARCH_FOLLOW_BASE_ENDPOINT, location_to_follow @@ -179,8 +212,8 @@ def call_follow_get( # User should try to follow a nonexistent location and be denied tdc.request_validator.follow_search( headers=tus_1.jwt_authorization_header, - state="Pennsylvania", - county="Allegheny", + state=TEST_STATE, + county=TEST_COUNTY, locality="Purtsburgh", expected_response_status=HTTPStatus.BAD_REQUEST, expected_json_content={"message": "Location not found."},