diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index a03d45d8..a001991c 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,2 +1,2 @@ /client/ @joshuagraber -* @mbodeantor +* @josh-chamberlain diff --git a/.github/workflows/bandit.yaml b/.github/workflows/bandit.yaml new file mode 100644 index 00000000..ed64327c --- /dev/null +++ b/.github/workflows/bandit.yaml @@ -0,0 +1,32 @@ +name: Bandit Security Linting + +on: [pull_request] + +jobs: + bandit: + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.11' + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install bandit + + - name: Run Bandit + run: | + bandit -r middleware resources app.py + + - name: Upload Bandit results + uses: actions/upload-artifact@v2 + with: + name: bandit-report + path: bandit_output.txt + diff --git a/.github/workflows/pull.yaml b/.github/workflows/pull.yaml index 4c76ca61..219d9bff 100644 --- a/.github/workflows/pull.yaml +++ b/.github/workflows/pull.yaml @@ -48,7 +48,7 @@ jobs: - name: Test with pytest run: | pip install pytest pytest-cov - pytest app_test.py --doctest-modules --junitxml=junit/test-results.xml --cov=com --cov-report=xml --cov-report=html + pytest tests/resources/app_test.py --doctest-modules --junitxml=junit/test-results.xml --cov=com --cov-report=xml --cov-report=html setup_client: defaults: diff --git a/.github/workflows/python_checks.yml b/.github/workflows/python_checks.yml index 4afec55f..5062258a 100644 --- a/.github/workflows/python_checks.yml +++ b/.github/workflows/python_checks.yml @@ -18,4 +18,5 @@ jobs: uses: reviewdog/action-flake8@v3 with: github_token: ${{ secrets.GITHUB_TOKEN }} + flake8_args: --ignore E501,W291 # Does not check for max line exceed or trailing whitespace level: warning \ No newline at end of file diff --git a/.github/workflows/test_api.yml 
b/.github/workflows/test_api.yml new file mode 100644 index 00000000..db8f60b1 --- /dev/null +++ b/.github/workflows/test_api.yml @@ -0,0 +1,25 @@ +#name: Test API using Pytest +# +#on: +# pull_request: +# +#jobs: +# test_api: +# env: +# SECRET_KEY: ${{ secrets.SECRET_KEY }} +# DEV_DB_CONN_STRING: ${{secrets.DEV_DB_CONN_STRING}} +# name: Test API +# runs-on: ubuntu-latest +# steps: +# - uses: actions/checkout@v4 +# - uses: actions/setup-python@v4 +# with: +# python-version: '3.11' +# - name: Install dependencies +# run: | +# python -m pip install --upgrade pip +# pip install -r requirements.txt +# python -m spacy download en_core_web_sm +# pip install pytest pytest-cov +# - name: Run tests +# run: pytest tests --doctest-modules --junitxml=junit/test-results.xml --cov=com --cov-report=xml --cov-report=html diff --git a/README.md b/README.md index cc0837c7..28a038fc 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,6 @@ -# data-sources-app +# data-sources-app-v2 + +Development of the next big iteration of the data sources app according to https://github.com/Police-Data-Accessibility-Project/data-sources-app/issues/248 An API and UI for searching, using, and maintaining Data Sources. @@ -102,9 +104,18 @@ npm run dev ## Testing -All unit tests for the API live in the app_test.py file. It is best practice to add tests for any new feature to ensure it is working as expected and that any future code changes do not affect its functionality. All tests will be automatically run when a PR into dev is opened in order to ensure any changes do not break current app functionality. If a test fails, it is a sign that the new code should be checked or possibly that the test needs to be updated. Tests are currently run with pytest and can be run locally with the `pytest` command. 
+### Location +All unit and integration tests for the API live in the `tests` folder + +It is best practice to add tests for any new feature to ensure it is working as expected and that any future code changes do not affect its functionality. All tests will be automatically run when a PR into dev is opened in order to ensure any changes do not break current app functionality. If a test fails, it is a sign that the new code should be checked or possibly that the test needs to be updated. + + +### How to run tests +Some tests involve interfacing with the development database, which copies the production database's data and schema daily. + +To ensure such tests properly connect to the database, create or amend an `.env` file in the root directory of the project with the environment variable `DEV_DB_CONN_STRING`. Provide as a value a connection string giving you access to the `data_sources_app` user. If you do not have this connection string, DM a database administrator. -Endpoints are structured for simplified testing and debugging. Code for interacting with the database is contained in a function suffixed with "_results" and tested against a local sqlite database instance. Limited rows (stored in the DATA_SOURCES_ROWS and AGENCIES_ROWS variables in app_test_data.py) are inserted into this local instance on setup, you may need to add additional rows to test other functionality fully. +Tests are currently run with pytest and can be run locally with the `pytest` command. Remaining API code is stored in functions suffixed with "_query" tested against static query results stored in app_test_data.py. Tests for hitting the endpoint directly should be included in regular_api_checks.py; make sure to add the test function name in the list at the bottom so it is included in the GitHub actions run every 15 minutes. 
diff --git a/app.py b/app.py index f3a8f341..14b487ee 100644 --- a/app.py +++ b/app.py @@ -25,9 +25,7 @@ def add_resource(api, resource, endpoint, **kwargs): api.add_resource(resource, endpoint, resource_class_kwargs=kwargs) -def create_app() -> Flask: - psycopg2_connection = initialize_psycopg2_connection() - +def create_app(psycopg2_connection) -> Flask: app = Flask(__name__) api = Api(app) CORS(app) @@ -57,5 +55,5 @@ def create_app() -> Flask: if __name__ == "__main__": - app = create_app() + app = create_app(initialize_psycopg2_connection()) app.run(debug=True, host="0.0.0.0") diff --git a/app_test.py b/app_test.py deleted file mode 100644 index def9aead..00000000 --- a/app_test.py +++ /dev/null @@ -1,424 +0,0 @@ -import pytest -import os -from app import create_app -from flask_restful import Api -from middleware.quick_search_query import ( - unaltered_search_query, - quick_search_query, - QUICK_SEARCH_COLUMNS, -) -from middleware.data_source_queries import ( - data_sources_query, - needs_identification_data_sources, - data_source_by_id_query, - data_source_by_id_results, - DATA_SOURCES_APPROVED_COLUMNS, - get_approved_data_sources, - get_data_sources_for_map, -) -from middleware.user_queries import ( - user_post_results, - user_check_email, -) -from middleware.login_queries import ( - login_results, - create_session_token, - token_results, - is_admin, -) -from middleware.archives_queries import ( - archives_get_results, - archives_get_query, - archives_put_broken_as_of_results, - archives_put_last_cached_results, - ARCHIVES_GET_COLUMNS, -) -from middleware.reset_token_queries import ( - check_reset_token, - add_reset_token, - delete_reset_token, -) -from app_test_data import ( - DATA_SOURCES_ROWS, - DATA_SOURCE_QUERY_RESULTS, - QUICK_SEARCH_QUERY_RESULTS, - AGENCIES_ROWS, - DATA_SOURCES_ID_QUERY_RESULTS, - ARCHIVES_GET_QUERY_RESULTS, -) -import datetime -import sqlite3 -import pytest -from resources.ApiKey import ( - ApiKey, -) # Adjust the import according to 
your project structure -from werkzeug.security import check_password_hash -from unittest.mock import patch, MagicMock - -api_key = os.getenv("VUE_APP_PDAP_API_KEY") -HEADERS = {"Authorization": f"Bearer {api_key}"} -current_datetime = datetime.datetime.now() -DATETIME_STRING = current_datetime.strftime("%Y-%m-%d %H:%M:%S") - - -@pytest.fixture() -def test_app(): - app = create_app() - yield app - - -@pytest.fixture() -def client(test_app): - return test_app.test_client() - - -@pytest.fixture() -def runner(test_app): - return test_app.test_cli_runner() - - -@pytest.fixture() -def test_app_with_mock(): - # Patch the initialize_psycopg2_connection function so it returns a MagicMock - with patch("app.initialize_psycopg2_connection") as mock_init: - mock_connection = MagicMock() - mock_init.return_value = mock_connection - - app = create_app() - # If your app stores the connection in a global or app context, - # you can also directly assign the mock_connection there - - # Provide access to the mock within the app for assertions in tests - app.mock_connection = mock_connection - - yield app - - -@pytest.fixture() -def client_with_mock(test_app_with_mock): - # Use the app with the mocked database connection to get the test client - return test_app_with_mock.test_client() - - -@pytest.fixture() -def runner_with_mock(test_app_with_mock): - # Use the app with the mocked database connection for the test CLI runner - return test_app_with_mock.test_cli_runner() - - -@pytest.fixture -def session(): - connection = sqlite3.connect("file::memory:?cache=shared", uri=True) - db_session = connection.cursor() - with open("do_db_ddl_clean.sql", "r") as f: - sql_file = f.read() - sql_queries = sql_file.split(";") - for query in sql_queries: - db_session.execute(query.replace("\n", "")) - - for row in DATA_SOURCES_ROWS: - # valid_row = {k: v for k, v in row.items() if k in all_columns} - # clean_row = [r if r is not None else "" for r in row] - fully_clean_row = [str(r) for r in row] - 
fully_clean_row_str = "'" + "', '".join(fully_clean_row) + "'" - db_session.execute(f"insert into data_sources values ({fully_clean_row_str})") - db_session.execute( - "update data_sources set broken_source_url_as_of = null where broken_source_url_as_of = 'NULL'" - ) - - for row in AGENCIES_ROWS: - clean_row = [r if r is not None else "" for r in row] - fully_clean_row = [str(r) for r in clean_row] - fully_clean_row_str = "'" + "', '".join(fully_clean_row) + "'" - db_session.execute(f"insert into agencies values ({fully_clean_row_str})") - - # sql_query_log = f"INSERT INTO quick_search_query_logs (id, search, location, results, result_count, datetime_of_request, created_at) VALUES (1, 'test', 'test', '', 0, '{DATETIME_STRING}', '{DATETIME_STRING}')" - # db_session.execute(sql_query_log) - - yield connection - connection.close() - - -# unit tests -def test_unaltered_search_query(session): - response = unaltered_search_query(session.cursor(), "calls", "chicago") - - assert response - - -def test_data_sources(session): - response = get_approved_data_sources(conn=session) - - assert response - - -def test_needs_identification(session): - response = needs_identification_data_sources(conn=session) - - assert response - - -def test_data_sources_approved(session): - response = get_approved_data_sources(conn=session) - - assert ( - len([d for d in response if "https://joinstatepolice.ny.gov/15-mile-run" in d]) - == 0 - ) - - -def test_data_source_by_id_results(session): - response = data_source_by_id_results( - data_source_id="rec00T2YLS2jU7Tbn", conn=session - ) - - assert response - - -def test_data_source_by_id_approved(session): - response = data_source_by_id_results( - data_source_id="rec013MFNfBnrTpZj", conn=session - ) - - assert not response - - -def test_data_sources(session): - response = get_data_sources_for_map(conn=session) - - assert response - - -def test_user_post_query(session): - curs = session.cursor() - user_post_results(curs, "unit_test", "unit_test") - 
- email_check = curs.execute( - f"SELECT email FROM users WHERE email = 'unit_test'" - ).fetchone()[0] - - assert email_check == "unit_test" - - -def test_login_query(session): - curs = session.cursor() - user_data = login_results(curs, "test") - - assert user_data["password_digest"] - - -def test_create_session_token_results(session): - curs = session.cursor() - token = create_session_token(curs, 1, "test") - - curs = session.cursor() - new_token = token_results(curs, token) - - assert new_token["email"] - - -def test_is_admin(session): - curs = session.cursor() - admin = is_admin(curs, "mbodenator@gmail.com") - - assert admin - - -def test_not_admin(session): - curs = session.cursor() - admin = is_admin(curs, "test") - - assert not admin - - -def test_user_check_email(session): - curs = session.cursor() - user_data = user_check_email(curs, "test") - print(user_data) - - assert user_data["id"] - - -def test_check_reset_token(session): - curs = session.cursor() - reset_token = check_reset_token(curs, "test") - print(reset_token) - - assert reset_token["id"] - - -def test_add_reset_token(session): - curs = session.cursor() - add_reset_token(curs, "unit_test", "unit_test") - - email_check = curs.execute( - f"SELECT email FROM reset_tokens WHERE email = 'unit_test'" - ).fetchone()[0] - - assert email_check == "unit_test" - - -def test_delete_reset_token(session): - curs = session.cursor() - delete_reset_token(curs, "test", "test") - - email_check = curs.execute( - f"SELECT email FROM reset_tokens WHERE email = 'test'" - ).fetchone() - - assert not email_check - - -def test_archives_get_results(session): - response = archives_get_results(conn=session) - - assert response - - -def test_archives_put_broken_as_of(session): - archives_put_broken_as_of_results( - id="rec00T2YLS2jU7Tbn", - broken_as_of=DATETIME_STRING, - last_cached=DATETIME_STRING, - conn=session, - ) - curs = session.cursor() - broken_check, last_check = curs.execute( - f"SELECT broken_source_url_as_of, 
last_cached FROM data_sources WHERE airtable_uid = 'rec00T2YLS2jU7Tbn'" - ).fetchone() - - assert broken_check == DATETIME_STRING - assert last_check == DATETIME_STRING - - -def test_archives_put_last_cached(session): - archives_put_last_cached_results( - id="recUGIoPQbJ6laBmr", last_cached=DATETIME_STRING, conn=session - ) - curs = session.cursor() - last_check = curs.execute( - f"SELECT last_cached FROM data_sources WHERE airtable_uid = 'recUGIoPQbJ6laBmr'" - ).fetchone()[0] - - assert last_check == DATETIME_STRING - - -# quick-search -def test_quicksearch_columns(): - response = quick_search_query( - search="", location="", test_query_results=QUICK_SEARCH_QUERY_RESULTS - ) - - assert not set(QUICK_SEARCH_COLUMNS).difference(response["data"][0].keys()) - assert type(response["data"][1]["record_format"]) == list - - -# data-sources -def test_data_sources_columns(): - response = data_sources_query(conn={}, test_query_results=DATA_SOURCE_QUERY_RESULTS) - - assert not set(DATA_SOURCES_APPROVED_COLUMNS).difference(response[0].keys()) - - -def test_data_source_by_id_columns(): - response = data_source_by_id_query("", DATA_SOURCES_ID_QUERY_RESULTS, {}) - - assert not set(DATA_SOURCES_APPROVED_COLUMNS).difference(response.keys()) - - -# user - - -# def test_post_user(client): -# response = client.post( -# "/user", headers=HEADERS, json={"email": "test", "password": "test"} -# ) - -# # with initialize_psycopg2_connection() as psycopg2_connection: -# # cursor = psycopg2_connection.cursor() -# # cursor.execute(f"DELETE FROM users WHERE email = 'test'") -# # psycopg2_connection.commit() - -# assert response.json["data"] == "Successfully added user" - - -# archives -def test_archives_get_columns(): - response = archives_get_query( - test_query_results=ARCHIVES_GET_QUERY_RESULTS, conn={} - ) - - assert not set(ARCHIVES_GET_COLUMNS).difference(response[0].keys()) - - -# def test_put_archives(client): -# current_datetime = datetime.datetime.now() -# datetime_string = 
current_datetime.strftime("%Y-%m-%d %H:%M:%S") -# response = client.put( -# "/archives", -# headers=HEADERS, -# json=json.dumps( -# { -# "id": "test", -# "last_cached": datetime_string, -# "broken_source_url_as_of": "", -# } -# ), -# ) - -# assert response.json["status"] == "success" - - -# def test_put_archives_brokenasof(client): -# current_datetime = datetime.datetime.now() -# datetime_string = current_datetime.strftime("%Y-%m-%d") -# response = client.put( -# "/archives", -# headers=HEADERS, -# json=json.dumps( -# { -# "id": "test", -# "last_cached": datetime_string, -# "broken_source_url_as_of": datetime_string, -# } -# ), -# ) - -# assert response.json["status"] == "success" - - -# # agencies -# def test_agencies(client): -# response = client.get("/agencies/1", headers=HEADERS) - -# assert len(response.json["data"]) > 0 - - -# def test_agencies_pagination(client): -# response1 = client.get("/agencies/1", headers=HEADERS) -# response2 = client.get("/agencies/2", headers=HEADERS) - -# assert response1 != response2 - -# region Resources - - -def test_get_api_key(client_with_mock, mocker, test_app_with_mock): - mock_request_data = {"email": "user@example.com", "password": "password"} - mock_user_data = {"id": 1, "password_digest": "hashed_password"} - - # Mock login_results function to return mock_user_data - mocker.patch("resources.ApiKey.login_results", return_value=mock_user_data) - # Mock check_password_hash based on the valid_login parameter - mocker.patch("resources.ApiKey.check_password_hash", return_value=True) - - with client_with_mock: - response = client_with_mock.get("/api_key", json=mock_request_data) - json_data = response.get_json() - assert "api_key" in json_data - assert response.status_code == 200 - test_app_with_mock.mock_connection.cursor().execute.assert_called_once() - test_app_with_mock.mock_connection.commit.assert_called_once() - - -# endregion diff --git a/conftest.py b/conftest.py new file mode 100644 index 00000000..99a0a6c0 --- 
/dev/null +++ b/conftest.py @@ -0,0 +1,44 @@ +import os + +import dotenv +import pytest +from sqlalchemy.orm import sessionmaker, scoped_session + +from middleware.models import db + +from app import create_app + + +# Load environment variables +dotenv.load_dotenv() + + +@pytest.fixture(scope="module") +def test_client(): + app = create_app() + app.config["SQLALCHEMY_DATABASE_URI"] = os.getenv( + "DEV_DB_CONN_STRING" + ) # Connect to pre-existing test database + app.config["TESTING"] = True + + db.init_app(app) + + with app.test_client() as testing_client: + with app.app_context(): + yield testing_client + + +@pytest.fixture +def session(): + connection = db.engine.connect() + transaction = connection.begin() + session = scoped_session(sessionmaker(bind=connection)) + + # Overwrite the db.session with the scoped session + db.session = session + + yield session + + session.close() + transaction.rollback() + connection.close() diff --git a/middleware/access_token_logic.py b/middleware/access_token_logic.py new file mode 100644 index 00000000..ceef6808 --- /dev/null +++ b/middleware/access_token_logic.py @@ -0,0 +1,11 @@ +import datetime +import uuid + + +def insert_access_token(cursor): + token = uuid.uuid4().hex + expiration = datetime.datetime.now() + datetime.timedelta(minutes=5) + cursor.execute( + f"insert into access_tokens (token, expiration_date) values (%s, %s)", + (token, expiration), + ) diff --git a/middleware/archives_queries.py b/middleware/archives_queries.py index 22cc6226..f350468a 100644 --- a/middleware/archives_queries.py +++ b/middleware/archives_queries.py @@ -37,19 +37,15 @@ def archives_get_results(conn: PgConnection) -> list[tuple[Any, ...]]: def archives_get_query( - test_query_results: Optional[List[Dict[str, Any]]] = None, conn: Optional[PgConnection] = None, ) -> List[Dict[str, Any]]: """ - Processes the archives get results, either from the database or a provided set of test results, and converts dates to strings. 
+ Processes the archives get results, either from the database and converts dates to strings. - :param test_query_results: A list of dictionaries representing test query results, if any. :param conn: A psycopg2 connection object to a PostgreSQL database. :return: A list of dictionaries with the query results after processing and date conversion. """ - results = ( - archives_get_results(conn) if not test_query_results else test_query_results - ) + results = archives_get_results(conn) archives_combined_results = [ dict(zip(ARCHIVES_GET_COLUMNS, result)) for result in results ] @@ -72,8 +68,15 @@ def archives_put_broken_as_of_results( :param conn: A psycopg2 connection object to a PostgreSQL database. """ cursor = conn.cursor() - sql_query = "UPDATE data_sources SET url_status = 'broken', broken_source_url_as_of = '{0}', last_cached = '{1}' WHERE airtable_uid = '{2}'" - cursor.execute(sql_query.format(broken_as_of, last_cached, id)) + sql_query = """ + UPDATE data_sources + SET + url_status = 'broken', + broken_source_url_as_of = %s, + last_cached = %s + WHERE airtable_uid = %s + """ + cursor.execute(sql_query, (broken_as_of, last_cached, id)) cursor.close() @@ -88,28 +91,6 @@ def archives_put_last_cached_results( :param conn: A psycopg2 connection object to a PostgreSQL database. """ cursor = conn.cursor() - sql_query = "UPDATE data_sources SET last_cached = '{0}' WHERE airtable_uid = '{1}'" - cursor.execute(sql_query.format(last_cached, id)) + sql_query = "UPDATE data_sources SET last_cached = %s WHERE airtable_uid = %s" + cursor.execute(sql_query, (last_cached, id)) cursor.close() - - -def archives_put_query( - id: str = "", - broken_as_of: str = "", - last_cached: str = "", - conn: Optional[PgConnection] = None, -) -> None: - """ - Updates the data_sources table based on the provided parameters, marking sources as broken or updating the last cached date. - - :param id: The airtable_uid of the data source. 
- :param broken_as_of: The date when the source was identified as broken, if applicable. - :param last_cached: The last cached date to be updated. - :param conn: A psycopg2 connection object to a PostgreSQL database. - """ - if broken_as_of: - archives_put_broken_as_of_results(id, broken_as_of, last_cached, conn) - else: - archives_put_last_cached_results(id, last_cached, conn) - - conn.commit() diff --git a/middleware/custom_exceptions.py b/middleware/custom_exceptions.py new file mode 100644 index 00000000..ec9f86a0 --- /dev/null +++ b/middleware/custom_exceptions.py @@ -0,0 +1,15 @@ +class UserNotFoundError(Exception): + """Exception raised for errors in the input.""" + + def __init__(self, email, message=""): + if message == "": + message = f"User with email {email} not found" + self.email = email + self.message = message.format(email=self.email) + super().__init__(self.message) + + +class TokenNotFoundError(Exception): + """Raised when the token is not found in the database.""" + + pass diff --git a/middleware/data_source_queries.py b/middleware/data_source_queries.py index 1c411c8f..ff823bce 100644 --- a/middleware/data_source_queries.py +++ b/middleware/data_source_queries.py @@ -1,4 +1,8 @@ from typing import List, Dict, Any, Optional, Tuple, Union + +from flask import make_response, Response +from sqlalchemy.dialects.postgresql import psycopg2 + from utilities.common import convert_dates_to_strings, format_arrays from psycopg2.extensions import connection as PgConnection @@ -46,6 +50,8 @@ "last_cached", ] +DATA_SOURCES_OUTPUT_COLUMNS = DATA_SOURCES_APPROVED_COLUMNS + ["agency_name"] + AGENCY_APPROVED_COLUMNS = [ "homepage_url", "count_data_sources", @@ -72,6 +78,51 @@ "defunct_year", ] +DATA_SOURCES_MAP_COLUMN = [ + "data_source_id", + "name", + "agency_id", + "agency_name", + "state_iso", + "municipality", + "county_name", + "record_type", + "lat", + "lng", +] + + +def get_approved_data_sources_wrapper(conn: PgConnection): + data_source_matches = 
get_approved_data_sources(conn) + + return make_response( + { + "count": len(data_source_matches), + "data": data_source_matches, + }, + 200, + ) + + +def data_source_by_id_wrapper(arg, conn: PgConnection) -> Response: + data_source_details = data_source_by_id_query(data_source_id=arg, conn=conn) + if data_source_details: + return make_response(data_source_details, 200) + + else: + return make_response({"message": "Data source not found."}, 200) + + +def get_data_sources_for_map_wrapper(conn: PgConnection): + data_source_details = get_data_sources_for_map(conn) + return make_response( + { + "count": len(data_source_details), + "data": data_source_details, + }, + 200, + ) + def data_source_by_id_results( conn: PgConnection, data_source_id: str @@ -108,12 +159,12 @@ def data_source_by_id_results( INNER JOIN agencies ON agency_source_link.agency_described_linked_uid = agencies.airtable_uid WHERE - data_sources.approval_status = 'approved' AND data_sources.airtable_uid = '{1}' + data_sources.approval_status = 'approved' AND data_sources.airtable_uid = %s """.format( - joined_column_names, data_source_id + joined_column_names ) - cursor.execute(sql_query) + cursor.execute(sql_query, (data_source_id,)) result = cursor.fetchone() cursor.close() @@ -122,35 +173,28 @@ def data_source_by_id_results( def data_source_by_id_query( data_source_id: str = "", - test_query_results: Optional[List[Dict[str, Any]]] = None, conn: Optional[PgConnection] = None, ) -> Dict[str, Any]: """ - Processes a request to fetch data source details by ID, either from the database or provided test results. + Processes a request to fetch data source details by ID from the database :param data_source_id: The unique identifier for the data source. - :param test_query_results: A list of dictionaries representing test query results, if provided. :param conn: A psycopg2 connection object to a PostgreSQL database. :return: A dictionary with the data source details after processing. 
""" - if conn: - result = data_source_by_id_results(conn, data_source_id) - else: - result = test_query_results - - if result: - data_source_and_agency_columns = ( - DATA_SOURCES_APPROVED_COLUMNS + AGENCY_APPROVED_COLUMNS - ) - data_source_and_agency_columns.append("data_source_id") - data_source_and_agency_columns.append("agency_id") - data_source_and_agency_columns.append("agency_name") - data_source_details = dict(zip(data_source_and_agency_columns, result)) - data_source_details = convert_dates_to_strings(data_source_details) - data_source_details = format_arrays(data_source_details) + result = data_source_by_id_results(conn, data_source_id) + if not result: + return [] - else: - data_source_details = [] + data_source_and_agency_columns = ( + DATA_SOURCES_APPROVED_COLUMNS + AGENCY_APPROVED_COLUMNS + ) + data_source_and_agency_columns.append("data_source_id") + data_source_and_agency_columns.append("agency_id") + data_source_and_agency_columns.append("agency_name") + data_source_details = dict(zip(data_source_and_agency_columns, result)) + data_source_details = convert_dates_to_strings(data_source_details) + data_source_details = format_arrays(data_source_details) return data_source_details @@ -189,7 +233,7 @@ def get_approved_data_sources(conn: PgConnection) -> list[tuple[Any, ...]]: results = cursor.fetchall() cursor.close() - return results + return convert_data_source_matches(DATA_SOURCES_OUTPUT_COLUMNS, results) def needs_identification_data_sources(conn) -> list: @@ -213,7 +257,7 @@ def needs_identification_data_sources(conn) -> list: results = cursor.fetchall() cursor.close() - return results + return convert_data_source_matches(DATA_SOURCES_OUTPUT_COLUMNS, results) def get_data_sources_for_map(conn) -> list: @@ -246,55 +290,25 @@ def get_data_sources_for_map(conn) -> list: results = cursor.fetchall() cursor.close() - return results + return convert_data_source_matches(DATA_SOURCES_MAP_COLUMN, results) -def data_sources_query( - conn: 
Optional[PgConnection] = None, - test_query_results: Optional[List[Dict[str, Any]]] = None, - approval_status: str = "approved", - for_map: bool = False, -) -> List[Dict[str, Any]]: +def convert_data_source_matches( + data_source_output_columns: list[str], results: list[tuple] +) -> dict: """ - Processes and formats a list of approved data sources, with an option to use test query results. - - :param approval_status: The approval status of the data sources to query. - :param conn: Optional psycopg2 connection object to a PostgreSQL database. - :param test_query_results: Optional list of test query results to use instead of querying the database. - :return: A list of dictionaries, each formatted with details of a data source and its associated agency. + Combine a list of output columns with a list of results, + and produce a list of dictionaries where the keys correspond + to the output columns and the values correspond to the results + :param data_source_output_columns: + :param results: + :return: """ - if for_map: - results = get_data_sources_for_map(conn) - elif conn and approval_status == "approved": - results = get_approved_data_sources(conn) - elif conn and not for_map: - results = needs_identification_data_sources(conn) - else: - results = test_query_results - - if not for_map: - data_source_output_columns = DATA_SOURCES_APPROVED_COLUMNS + ["agency_name"] - else: - data_source_output_columns = [ - "data_source_id", - "name", - "agency_id", - "agency_name", - "state_iso", - "municipality", - "county_name", - "record_type", - "lat", - "lng", - ] - data_source_matches = [ dict(zip(data_source_output_columns, result)) for result in results ] data_source_matches_converted = [] - for data_source_match in data_source_matches: data_source_match = convert_dates_to_strings(data_source_match) data_source_matches_converted.append(format_arrays(data_source_match)) - return data_source_matches_converted diff --git a/middleware/initialize_psycopg2_connection.py 
b/middleware/initialize_psycopg2_connection.py index 6b6e1966..7d465091 100644 --- a/middleware/initialize_psycopg2_connection.py +++ b/middleware/initialize_psycopg2_connection.py @@ -4,9 +4,17 @@ from typing import Union, Dict, List -def initialize_psycopg2_connection() -> ( - Union[PgConnection, Dict[str, Union[int, List]]] -): +class DatabaseInitializationError(Exception): + """ + Custom Exception to be raised when psycopg2 connection initialization fails. + """ + + def __init__(self, message="Failed to initialize psycopg2 connection."): + self.message = message + super().__init__(self.message) + + +def initialize_psycopg2_connection() -> PgConnection: """ Initializes a connection to a PostgreSQL database using psycopg2 with connection parameters obtained from an environment variable. If the connection fails, it returns a default dictionary @@ -27,8 +35,5 @@ def initialize_psycopg2_connection() -> ( keepalives_count=5, ) - except: - print("Error while initializing the DigitalOcean client with psycopg2.") - data_sources = {"count": 0, "data": []} - - return data_sources + except psycopg2.OperationalError as e: + raise DatabaseInitializationError(e) from e diff --git a/middleware/login_queries.py b/middleware/login_queries.py index 77c41d57..679525c9 100644 --- a/middleware/login_queries.py +++ b/middleware/login_queries.py @@ -1,9 +1,14 @@ +from collections import namedtuple +from datetime import datetime as dt + import jwt import os import datetime from typing import Union, Dict from psycopg2.extensions import cursor as PgCursor +from middleware.custom_exceptions import UserNotFoundError, TokenNotFoundError + def login_results(cursor: PgCursor, email: str) -> Dict[str, Union[int, str]]: """ @@ -14,21 +19,19 @@ def login_results(cursor: PgCursor, email: str) -> Dict[str, Union[int, str]]: :return: A dictionary containing user data or an error message. 
""" cursor.execute( - f"select id, password_digest, api_key from users where email = '{email}'" + f"select id, password_digest, api_key from users where email = %s", (email,) ) results = cursor.fetchall() - if len(results) > 0: - user_data = { - "id": results[0][0], - "password_digest": results[0][1], - "api_key": results[0][2], - } - return user_data - else: - return {"error": "no match"} - - -def is_admin(cursor: PgCursor, email: str) -> Union[bool, Dict[str, str]]: + if len(results) == 0: + raise UserNotFoundError(email) + return { + "id": results[0][0], + "password_digest": results[0][1], + "api_key": results[0][2], + } + + +def is_admin(cursor: PgCursor, email: str) -> bool: """ Checks if a user has an admin role. @@ -36,16 +39,15 @@ def is_admin(cursor: PgCursor, email: str) -> Union[bool, Dict[str, str]]: :param email: User's email. :return: True if user is an admin, False if not, or an error message. """ - cursor.execute(f"select role from users where email = '{email}'") + cursor.execute(f"select role from users where email = %s", (email,)) results = cursor.fetchall() - if len(results) > 0: + try: role = results[0][0] if role == "admin": return True return False - - else: - return {"error": "no match"} + except IndexError: + raise UserNotFoundError(email) def create_session_token(cursor: PgCursor, user_id: int, email: str) -> str: @@ -65,27 +67,30 @@ def create_session_token(cursor: PgCursor, user_id: int, email: str) -> str: } session_token = jwt.encode(payload, os.getenv("SECRET_KEY"), algorithm="HS256") cursor.execute( - f"insert into session_tokens (token, email, expiration_date) values ('{session_token}', '{email}', '{expiration}')" + f"insert into session_tokens (token, email, expiration_date) values (%s, %s, %s)", + (session_token, email, expiration), ) return session_token -def token_results(cursor: PgCursor, token: str) -> Dict[str, Union[int, str]]: +SessionTokenUserData = namedtuple("SessionTokenUserData", ["id", "email"]) + + +def 
get_session_token_user_data(cursor: PgCursor, token: str) -> SessionTokenUserData: """ Retrieves session token data. :param cursor: A cursor object from a psycopg2 connection. :param token: The session token. - :return: A dictionary containing session token data or an error message. + :return: Session token data or an error message. """ - cursor.execute(f"select id, email from session_tokens where token = '{token}'") + cursor.execute(f"select id, email from session_tokens where token = %s", (token,)) results = cursor.fetchall() - if len(results) > 0: - user_data = { - "id": results[0][0], - "email": results[0][1], - } - return user_data - else: - return {"error": "no match"} + if len(results) == 0: + raise TokenNotFoundError("The specified token was not found.") + return SessionTokenUserData(id=results[0][0], email=results[0][1]) + + +def delete_session_token(cursor, old_token): + cursor.execute(f"delete from session_tokens where token = '{old_token}'") diff --git a/middleware/models.py b/middleware/models.py new file mode 100644 index 00000000..c9b2480e --- /dev/null +++ b/middleware/models.py @@ -0,0 +1,27 @@ +from flask_sqlalchemy import SQLAlchemy +from sqlalchemy import ( + Column, + BigInteger, + text, + Text, + String, +) +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.dialects.postgresql import TIMESTAMP + +db = SQLAlchemy() + +Base = declarative_base() + + +class User(Base): + __tablename__ = "users" + __table_args__ = {"schema": "public"} + + id = Column(BigInteger, primary_key=True, autoincrement=True) + created_at = Column(TIMESTAMP(timezone=True), server_default=text("now()")) + updated_at = Column(TIMESTAMP(timezone=True), server_default=text("now()")) + email = Column(Text, nullable=False, unique=True) + password_digest = Column(Text) + api_key = Column(String) + role = Column(Text) diff --git a/middleware/quick_search_query.py b/middleware/quick_search_query.py index 4584c097..dc09d385 100644 --- 
a/middleware/quick_search_query.py +++ b/middleware/quick_search_query.py @@ -1,6 +1,11 @@ import spacy import json import datetime + +from flask import make_response, Response +from sqlalchemy.dialects.postgresql import psycopg2 + +from middleware.webhook_logic import post_to_webhook from utilities.common import convert_dates_to_strings, format_arrays from typing import List, Dict, Any, Optional from psycopg2.extensions import connection as PgConnection, cursor as PgCursor @@ -52,7 +57,7 @@ """ -INSERT_LOG_QUERY = "INSERT INTO quick_search_query_logs (search, location, results, result_count, created_at, datetime_of_request) VALUES ('{0}', '{1}', '{2}', '{3}', '{4}', '{4}')" +INSERT_LOG_QUERY = "INSERT INTO quick_search_query_logs (search, location, results, result_count) VALUES ('{0}', '{1}', '{2}', '{3}')" def unaltered_search_query( @@ -105,18 +110,14 @@ def spacy_search_query( def quick_search_query( search: str = "", location: str = "", - test_query_results: Optional[List[Dict[str, Any]]] = None, conn: Optional[PgConnection] = None, - test: bool = False, ) -> Dict[str, Any]: """ Performs a quick search using both unaltered and lemmatized search terms, returning the more fruitful result set. :param search: The search term. :param location: The location term. - :param test_query_results: Predefined results for testing purposes. :param conn: A psycopg2 connection to the database. - :param test: Flag indicating whether the function is being called in a test context. :return: A dictionary with the count of results and the data itself. 
""" data_sources = {"count": 0, "data": []} @@ -129,16 +130,8 @@ def quick_search_query( if conn: cursor = conn.cursor() - unaltered_results = ( - unaltered_search_query(cursor, search, location) - if not test_query_results - else test_query_results - ) - spacy_results = ( - spacy_search_query(cursor, search, location) - if not test_query_results - else test_query_results - ) + unaltered_results = unaltered_search_query(cursor, search, location) + spacy_results = spacy_search_query(cursor, search, location) # Compare altered search term results with unaltered search term results, return the longer list results = ( @@ -160,18 +153,34 @@ def quick_search_query( "data": data_source_matches_converted, } - if not test_query_results and not test: - current_datetime = datetime.datetime.now() - datetime_string = current_datetime.strftime("%Y-%m-%d %H:%M:%S") - - query_results = json.dumps(data_sources["data"]).replace("'", "") + query_results = json.dumps(data_sources["data"]).replace("'", "") - cursor.execute( - INSERT_LOG_QUERY.format( - search, location, query_results, data_sources["count"], datetime_string - ), - ) - conn.commit() - cursor.close() + cursor.execute( + INSERT_LOG_QUERY.format(search, location, query_results, data_sources["count"]), + ) + conn.commit() + cursor.close() return data_sources + + +def quick_search_query_wrapper(arg1, arg2, conn: PgConnection) -> Response: + try: + data_sources = quick_search_query(search=arg1, location=arg2, conn=conn) + + return make_response(data_sources, 200) + + except Exception as e: + conn.rollback() + user_message = "There was an error during the search operation" + message = { + "content": user_message + + ": " + + str(e) + + "\n" + + f"Search term: {arg1}\n" + + f"Location: {arg2}" + } + post_to_webhook(json.dumps(message)) + + return make_response({"count": 0, "message": user_message}, 500) diff --git a/middleware/reset_token_queries.py b/middleware/reset_token_queries.py index bc0a4762..573ec5af 100644 --- 
a/middleware/reset_token_queries.py +++ b/middleware/reset_token_queries.py @@ -1,6 +1,8 @@ from psycopg2.extensions import cursor as PgCursor from typing import Dict, Union +from middleware.custom_exceptions import TokenNotFoundError + def check_reset_token(cursor: PgCursor, token: str) -> Dict[str, Union[int, str]]: """ @@ -11,18 +13,16 @@ def check_reset_token(cursor: PgCursor, token: str) -> Dict[str, Union[int, str] :return: A dictionary containing the user's ID, token creation date, and email if the token exists; otherwise, an error message. """ cursor.execute( - f"select id, create_date, email from reset_tokens where token = '{token}'" + f"select id, create_date, email from reset_tokens where token = %s", (token,) ) results = cursor.fetchall() - if len(results) > 0: - user_data = { - "id": results[0][0], - "create_date": results[0][1], - "email": results[0][2], - } - return user_data - else: - return {"error": "no match"} + if len(results) == 0: + raise TokenNotFoundError("The specified token was not found.") + return { + "id": results[0][0], + "create_date": results[0][1], + "email": results[0][2], + } def add_reset_token(cursor: PgCursor, email: str, token: str) -> None: @@ -34,7 +34,7 @@ def add_reset_token(cursor: PgCursor, email: str, token: str) -> None: :param token: The reset token to add. """ cursor.execute( - f"insert into reset_tokens (email, token) values ('{email}', '{token}')" + f"insert into reset_tokens (email, token) values (%s, %s)", (email, token) ) return @@ -49,7 +49,7 @@ def delete_reset_token(cursor: PgCursor, email: str, token: str) -> None: :param token: The reset token to delete. 
""" cursor.execute( - f"delete from reset_tokens where email = '{email}' and token = '{token}'" + f"delete from reset_tokens where email = %s and token = %s", (email, token) ) return diff --git a/middleware/security.py b/middleware/security.py index 3aa0e430..c57731a9 100644 --- a/middleware/security.py +++ b/middleware/security.py @@ -1,16 +1,33 @@ import functools -from hmac import compare_digest -from flask import request, jsonify +from collections import namedtuple + +from http import HTTPStatus +from flask import request from middleware.initialize_psycopg2_connection import initialize_psycopg2_connection from datetime import datetime as dt from middleware.login_queries import is_admin -import os -from typing import Tuple -from flask.wrappers import Response -from psycopg2.extensions import cursor as PgCursor +from typing import Tuple, Optional + +APIKeyStatus = namedtuple("APIKeyStatus", ["is_valid", "is_expired"]) + + +class NoAPIKeyError(Exception): + pass + + +class ExpiredAPIKeyError(Exception): + pass + +class InvalidAPIKeyError(Exception): + pass -def is_valid(api_key: str, endpoint: str, method: str) -> Tuple[bool, bool]: + +class InvalidRoleError(Exception): + pass + + +def validate_api_key(api_key: str, endpoint: str, method: str): """ Validates the API key and checks if the user has the required role to access a specific endpoint. @@ -19,52 +36,135 @@ def is_valid(api_key: str, endpoint: str, method: str) -> Tuple[bool, bool]: :param method: The HTTP method of the request. :return: A tuple (isValid, isExpired) indicating whether the API key is valid and not expired. 
""" - if not api_key: - return False, False psycopg2_connection = initialize_psycopg2_connection() cursor = psycopg2_connection.cursor() - cursor.execute(f"select id, api_key, role from users where api_key = '{api_key}'") - results = cursor.fetchall() - if len(results) > 0: - role = results[0][2] + role = get_role(api_key, cursor) + if role: + validate_role(role, endpoint, method) + return - if not results: - cursor.execute( - f"select email, expiration_date from session_tokens where token = '{api_key}'" - ) - results = cursor.fetchall() - if len(results) > 0: - email = results[0][0] - expiration_date = results[0][1] - print(expiration_date, dt.utcnow()) - - if expiration_date < dt.utcnow(): - return False, True - - if is_admin(cursor, email): - role = "admin" - - if not results: - cursor.execute(f"select id, token from access_tokens where token = '{api_key}'") - results = cursor.fetchall() - cursor.execute( - f"delete from access_tokens where expiration_date < '{dt.utcnow()}'" - ) - psycopg2_connection.commit() + session_token_results = get_session_token(api_key, cursor) + if session_token_results: + + if session_token_results.expiration_date < dt.utcnow(): + raise ExpiredAPIKeyError("Session token expired") + + if is_admin(cursor, session_token_results.email): + validate_role(role="admin", endpoint=endpoint, method=method) + return + + if not session_token_results: + delete_expired_access_tokens(cursor, psycopg2_connection) + access_token = get_access_token(api_key, cursor) role = "user" - if not results: - return False, False + if not access_token: + raise InvalidAPIKeyError("API Key not found") - if endpoint in ("datasources", "datasourcebyid") and method in ("PUT", "POST"): - if role != "admin": - return False, False + validate_role(role, endpoint, method) + +def validate_role(role: str, endpoint: str, method: str): # Compare the API key in the user table to the API in the request header and proceed # through the protected route if it's valid. 
Otherwise, compare_digest will return False # and api_required will send an error message to provide a valid API key - return True, False + if is_admin_only_action(endpoint, method) and role != "admin": + raise InvalidRoleError("You do not have permission to access this endpoint") + + + +def get_role(api_key, cursor): + cursor.execute(f"select id, api_key, role from users where api_key = '{api_key}'") + user_results = cursor.fetchall() + if len(user_results) > 0: + role = user_results[0][2] + if role is None: + return "user" + return role + return None + + +SessionTokenResults = namedtuple("SessionTokenResults", ["email", "expiration_date"]) + + +def get_session_token(api_key, cursor) -> Optional[SessionTokenResults]: + cursor.execute( + f"select email, expiration_date from session_tokens where token = %s", + (api_key,), + ) + session_token_results = cursor.fetchall() + if len(session_token_results) > 0: + return SessionTokenResults( + email=session_token_results[0][0], + expiration_date=session_token_results[0][1], + ) + return None + + +def get_access_token(api_key, cursor): + cursor.execute(f"select id, token from access_tokens where token = %s", (api_key,)) + results = cursor.fetchone() + if results: + return results[1] + return None + + +def delete_expired_access_tokens(cursor, psycopg2_connection): + cursor.execute(f"delete from access_tokens where expiration_date < NOW()") + psycopg2_connection.commit() + + +def is_admin_only_action(endpoint, method): + return endpoint in ("datasources", "datasourcebyid") and method in ("PUT", "POST") + + +class InvalidHeader(Exception): + + def __init__(self, message: str): + super().__init__(message) + + +def validate_header() -> str: + """ + Validates the API key and checks if the user has the required role to access a specific endpoint. 
+ :return: + """ + if not request.headers or "Authorization" not in request.headers: + raise InvalidHeader( + "Please provide an 'Authorization' key in the request header" + ) + + authorization_header = request.headers["Authorization"].split(" ") + if len(authorization_header) < 2 or authorization_header[0] != "Bearer": + raise InvalidHeader( + "Please provide a properly formatted bearer token and API key" + ) + + api_key = authorization_header[1] + if api_key == "undefined": + raise InvalidHeader("Please provide an API key") + return api_key + + +def validate_token() -> Optional[Tuple[dict, int]]: + """ + Validates the API key and checks if the user has the required role to access a specific endpoint. + :return: + """ + try: + api_key = validate_header() + except InvalidHeader as e: + return {"message": str(e)}, HTTPStatus.BAD_REQUEST.value + # Check if API key is correct and valid + try: + validate_api_key(api_key, request.endpoint, request.method) + except ExpiredAPIKeyError as e: + return {"message": str(e)}, HTTPStatus.UNAUTHORIZED.value + except InvalidRoleError as e: + return {"message": str(e)}, HTTPStatus.FORBIDDEN.value + + return None def api_required(func): @@ -77,28 +177,9 @@ def api_required(func): @functools.wraps(func) def decorator(*args, **kwargs): - api_key = None - if request.headers and "Authorization" in request.headers: - authorization_header = request.headers["Authorization"].split(" ") - if len(authorization_header) >= 2 and authorization_header[0] == "Bearer": - api_key = request.headers["Authorization"].split(" ")[1] - if api_key == "undefined": - return {"message": "Please provide an API key"}, 400 - else: - return { - "message": "Please provide a properly formatted bearer token and API key" - }, 400 - else: - return { - "message": "Please provide an 'Authorization' key in the request header" - }, 400 - # Check if API key is correct and valid - valid, expired = is_valid(api_key, request.endpoint, request.method) - if valid: - return 
func(*args, **kwargs) - else: - if expired: - return {"message": "The provided API key has expired"}, 401 - return {"message": "The provided API key is not valid"}, 403 + validation_error = validate_token() + if validation_error: + return validation_error + return func(*args, **kwargs) return decorator diff --git a/middleware/token_management.py b/middleware/token_management.py new file mode 100644 index 00000000..1e22cd96 --- /dev/null +++ b/middleware/token_management.py @@ -0,0 +1,12 @@ +import datetime +import uuid + + +def insert_new_access_token(cursor): + token = uuid.uuid4().hex + expiration = datetime.datetime.now() + datetime.timedelta(minutes=5) + cursor.execute( + f"insert into access_tokens (token, expiration_date) values (%s, %s)", + (token, expiration), + ) + return token diff --git a/middleware/user_queries.py b/middleware/user_queries.py index be050fe4..388d962a 100644 --- a/middleware/user_queries.py +++ b/middleware/user_queries.py @@ -2,6 +2,8 @@ from psycopg2.extensions import cursor as PgCursor from typing import Dict +from middleware.custom_exceptions import UserNotFoundError + def user_check_email(cursor: PgCursor, email: str) -> Dict[str, str]: """ @@ -11,13 +13,11 @@ def user_check_email(cursor: PgCursor, email: str) -> Dict[str, str]: :param email: The email address to check against the users in the database. :return: A dictionary with the user's ID if found, otherwise an error message. 
""" - cursor.execute(f"select id from users where email = '{email}'") + cursor.execute(f"select id from users where email = %s", (email,)) results = cursor.fetchall() - if len(results) > 0: - user_data = {"id": results[0][0]} - return user_data - else: - return {"error": "no match"} + if len(results) == 0: + raise UserNotFoundError(email) + return {"id": results[0][0]} def user_post_results(cursor: PgCursor, email: str, password: str) -> None: @@ -30,7 +30,8 @@ def user_post_results(cursor: PgCursor, email: str, password: str) -> None: """ password_digest = generate_password_hash(password) cursor.execute( - f"insert into users (email, password_digest) values ('{email}', '{password_digest}')" + f"insert into users (email, password_digest) values (%s, %s)", + (email, password_digest), ) return diff --git a/middleware/webhook_logic.py b/middleware/webhook_logic.py new file mode 100644 index 00000000..b30a022b --- /dev/null +++ b/middleware/webhook_logic.py @@ -0,0 +1,14 @@ +import json +import os + +import requests + + +def post_to_webhook(data: str): + webhook_url = os.getenv("WEBHOOK_URL") + + requests.post( + webhook_url, + data=data, + headers={"Content-Type": "application/json"}, + ) diff --git a/requirements.txt b/requirements.txt index bd3aca88..7f29cd45 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,6 +20,7 @@ exceptiongroup==1.1.3 Flask==2.3.2 Flask-Cors==4.0.0 Flask-RESTful==0.3.10 +Flask-SQLAlchemy~=3.1.1 gotrue==1.0.3 gunicorn==21.2.0 h11==0.14.0 @@ -40,7 +41,7 @@ pathy==0.10.2 pluggy==1.2.0 postgrest==0.10.8 preshed==3.0.8 -psycopg2==2.9.7 +psycopg2-binary==2.9.7 py==1.11.0 pycparser==2.21 pydantic==2.2.1 @@ -75,4 +76,4 @@ wasabi==1.1.2 websockets==10.4 Werkzeug==3.0.1 zipp==3.16.2 -pytest-mock~=3.12.0 \ No newline at end of file +pytest-mock~=3.12.0 diff --git a/resources/Agencies.py b/resources/Agencies.py index 9e88f86c..76acbcac 100644 --- a/resources/Agencies.py +++ b/resources/Agencies.py @@ -1,5 +1,5 @@ from middleware.security import 
api_required -from resources.PsycopgResource import PsycopgResource +from resources.PsycopgResource import PsycopgResource, handle_exceptions from utilities.common import convert_dates_to_strings from typing import Dict, Any @@ -36,6 +36,7 @@ class Agencies(PsycopgResource): """Represents a resource for fetching approved agency data from the database.""" + @handle_exceptions @api_required def get(self, page: str) -> Dict[str, Any]: """ @@ -47,26 +48,18 @@ def get(self, page: str) -> Dict[str, Any]: Returns: - dict: A dictionary containing the count of returned agencies and their data. """ - try: - cursor = self.psycopg2_connection.cursor() - joined_column_names = ", ".join(approved_columns) - offset = (int(page) - 1) * 1000 - cursor.execute( - f"select {joined_column_names} from agencies where approved = 'TRUE' limit 1000 offset {offset}" - ) - results = cursor.fetchall() - agencies_matches = [ - dict(zip(approved_columns, result)) for result in results - ] + cursor = self.psycopg2_connection.cursor() + joined_column_names = ", ".join(approved_columns) + offset = (int(page) - 1) * 1000 + cursor.execute( + f"select {joined_column_names} from agencies where approved = 'TRUE' limit 1000 offset {offset}" + ) + results = cursor.fetchall() + agencies_matches = [dict(zip(approved_columns, result)) for result in results] - for item in agencies_matches: - convert_dates_to_strings(item) + for item in agencies_matches: + convert_dates_to_strings(item) - agencies = {"count": len(agencies_matches), "data": agencies_matches} + agencies = {"count": len(agencies_matches), "data": agencies_matches} - return agencies - - except Exception as e: - self.psycopg2_connection.rollback() - print(str(e)) - return "There has been an error pulling data!" 
+ return agencies diff --git a/resources/ApiKey.py b/resources/ApiKey.py index 520fe416..bd1a044b 100644 --- a/resources/ApiKey.py +++ b/resources/ApiKey.py @@ -4,12 +4,13 @@ import uuid from typing import Dict, Any, Optional -from resources.PsycopgResource import PsycopgResource +from resources.PsycopgResource import PsycopgResource, handle_exceptions class ApiKey(PsycopgResource): """Represents a resource for generating an API key for authenticated users.""" + @handle_exceptions def get(self) -> Optional[Dict[str, Any]]: """ Authenticates a user based on provided credentials and generates an API key. @@ -20,24 +21,18 @@ def get(self) -> Optional[Dict[str, Any]]: Returns: - dict: A dictionary containing the generated API key, or None if an error occurs. """ - try: - data = request.get_json() - email = data.get("email") - password = data.get("password") - cursor = self.psycopg2_connection.cursor() - user_data = login_results(cursor, email) - - if check_password_hash(user_data["password_digest"], password): - api_key = uuid.uuid4().hex - user_id = str(user_data["id"]) - cursor.execute( - "UPDATE users SET api_key = %s WHERE id = %s", (api_key, user_id) - ) - payload = {"api_key": api_key} - self.psycopg2_connection.commit() - return payload - - except Exception as e: - self.psycopg2_connection.rollback() - print(str(e)) - return {"message": str(e)} + data = request.get_json() + email = data.get("email") + password = data.get("password") + cursor = self.psycopg2_connection.cursor() + user_data = login_results(cursor, email) + + if check_password_hash(user_data["password_digest"], password): + api_key = uuid.uuid4().hex + user_id = str(user_data["id"]) + cursor.execute( + "UPDATE users SET api_key = %s WHERE id = %s", (api_key, user_id) + ) + payload = {"api_key": api_key} + self.psycopg2_connection.commit() + return payload diff --git a/resources/Archives.py b/resources/Archives.py index 51201b44..7e11d39d 100644 --- a/resources/Archives.py +++ 
b/resources/Archives.py @@ -1,11 +1,15 @@ from middleware.security import api_required -from middleware.archives_queries import archives_get_query, archives_put_query +from middleware.archives_queries import ( + archives_get_query, + archives_put_broken_as_of_results, + archives_put_last_cached_results, +) from flask_restful import request import json from typing import Dict, Any -from resources.PsycopgResource import PsycopgResource +from resources.PsycopgResource import PsycopgResource, handle_exceptions class Archives(PsycopgResource): @@ -13,6 +17,7 @@ class Archives(PsycopgResource): A resource for managing archive data, allowing retrieval and update of archived data sources. """ + @handle_exceptions @api_required def get(self) -> Any: """ @@ -23,18 +28,13 @@ def get(self) -> Any: Returns: - Any: The cleaned results of archives combined from the database query, or an error message if an exception occurs. """ - try: - archives_combined_results_clean = archives_get_query( - test_query_results=[], conn=self.psycopg2_connection - ) - - return archives_combined_results_clean + archives_combined_results_clean = archives_get_query( + conn=self.psycopg2_connection + ) - except Exception as e: - self.psycopg2_connection.rollback() - print(str(e)) - return "There has been an error pulling data!" + return archives_combined_results_clean + @handle_exceptions @api_required def put(self) -> Dict[str, str]: """ @@ -45,27 +45,21 @@ def put(self) -> Dict[str, str]: Returns: - dict: A status message indicating success or an error message if an exception occurs. 
""" - try: - json_data = request.get_json() - data = json.loads(json_data) - id = data["id"] if "id" in data else None - broken_as_of = ( - data["broken_source_url_as_of"] - if "broken_source_url_as_of" in data - else None - ) - last_cached = data["last_cached"] if "last_cached" in data else None + json_data = request.get_json() + data = json.loads(json_data) + id = data["id"] if "id" in data else None + last_cached = data["last_cached"] if "last_cached" in data else None - archives_put_query( + if "broken_source_url_as_of" in data: + archives_put_broken_as_of_results( id=id, - broken_as_of=broken_as_of, + broken_as_of=data["broken_source_url_as_of"], last_cached=last_cached, conn=self.psycopg2_connection, ) + else: + archives_put_last_cached_results(id, last_cached, self.psycopg2_connection) - return {"status": "success"} + self.psycopg2_connection.commit() - except Exception as e: - self.psycopg2_connection.rollback() - print(str(e)) - return {"error": str(e)} + return {"status": "success"} diff --git a/resources/DataSources.py b/resources/DataSources.py index 50fdb48e..230baded 100644 --- a/resources/DataSources.py +++ b/resources/DataSources.py @@ -1,12 +1,17 @@ from flask import request from middleware.security import api_required -from middleware.data_source_queries import data_source_by_id_query, data_sources_query +from middleware.data_source_queries import ( + needs_identification_data_sources, + get_approved_data_sources_wrapper, + data_source_by_id_wrapper, + get_data_sources_for_map_wrapper, +) from datetime import datetime import uuid from typing import Dict, Any, Tuple -from resources.PsycopgResource import PsycopgResource +from resources.PsycopgResource import PsycopgResource, handle_exceptions class DataSourceById(PsycopgResource): @@ -15,6 +20,7 @@ class DataSourceById(PsycopgResource): Provides methods for retrieving and updating data source details. 
""" + @handle_exceptions @api_required def get(self, data_source_id: str) -> Tuple[Dict[str, Any], int]: """ @@ -26,23 +32,9 @@ def get(self, data_source_id: str) -> Tuple[Dict[str, Any], int]: Returns: - Tuple containing the response message with data source details if found, and the HTTP status code. """ - try: - data_source_details = data_source_by_id_query( - conn=self.psycopg2_connection, data_source_id=data_source_id - ) - if data_source_details: - return { - "message": "Successfully found data source", - "data": data_source_details, - } - - else: - return {"message": "Data source not found."}, 404 - - except Exception as e: - print(str(e)) - return {"message": "There has been an error pulling data!"}, 500 + return data_source_by_id_wrapper(data_source_id, self.psycopg2_connection) + @handle_exceptions @api_required def put(self, data_source_id: str) -> Dict[str, str]: """ @@ -54,43 +46,38 @@ def put(self, data_source_id: str) -> Dict[str, str]: Returns: - A dictionary containing a message about the update operation. 
""" - try: - data = request.get_json() + data = request.get_json() - restricted_columns = [ - "rejection_note", - "data_source_request", - "approval_status", - "airtable_uid", - "airtable_source_last_modified", - ] + restricted_columns = [ + "rejection_note", + "data_source_request", + "approval_status", + "airtable_uid", + "airtable_source_last_modified", + ] - data_to_update = "" + data_to_update = "" - for key, value in data.items(): - if key not in restricted_columns: - if type(value) == str: - data_to_update += f"{key} = '{value}', " - else: - data_to_update += f"{key} = {value}, " + for key, value in data.items(): + if key not in restricted_columns: + if type(value) == str: + data_to_update += f"{key} = '{value}', " + else: + data_to_update += f"{key} = {value}, " - data_to_update = data_to_update[:-2] + data_to_update = data_to_update[:-2] - cursor = self.psycopg2_connection.cursor() + cursor = self.psycopg2_connection.cursor() - sql_query = f""" - UPDATE data_sources - SET {data_to_update} - WHERE airtable_uid = '{data_source_id}' - """ - - cursor.execute(sql_query) - self.psycopg2_connection.commit() - return {"message": "Data source updated successfully."} + sql_query = f""" + UPDATE data_sources + SET {data_to_update} + WHERE airtable_uid = '{data_source_id}' + """ - except Exception as e: - print(str(e)) - return {"message": "There has been an error updating the data source"}, 500 + cursor.execute(sql_query) + self.psycopg2_connection.commit() + return {"message": "Data source updated successfully."} class DataSources(PsycopgResource): @@ -99,6 +86,7 @@ class DataSources(PsycopgResource): Provides methods for retrieving all data sources and adding new ones. """ + @handle_exceptions @api_required def get(self) -> Dict[str, Any]: """ @@ -107,23 +95,9 @@ def get(self) -> Dict[str, Any]: Returns: - A dictionary containing the count of data sources and their details. 
""" - try: - data_source_matches = data_sources_query( - self.psycopg2_connection, [], "approved" - ) - - data_sources = { - "count": len(data_source_matches), - "data": data_source_matches, - } - - return data_sources - - except Exception as e: - self.psycopg2_connection.rollback() - print(str(e)) - return {"message": "There has been an error pulling data!"}, 500 + return get_approved_data_sources_wrapper(self.psycopg2_connection) + @handle_exceptions @api_required def post(self) -> Dict[str, str]: """ @@ -132,69 +106,56 @@ def post(self) -> Dict[str, str]: Returns: - A dictionary containing a message about the addition operation. """ - try: - data = request.get_json() - cursor = self.psycopg2_connection.cursor() - - restricted_columns = [ - "rejection_note", - "data_source_request", - "approval_status", - "airtable_uid", - "airtable_source_last_modified", - ] + data = request.get_json() + cursor = self.psycopg2_connection.cursor() - column_names = "" - column_values = "" - for key, value in data.items(): - if key not in restricted_columns: - column_names += f"{key}, " - if type(value) == str: - column_values += f"'{value}', " - else: - column_values += f"{value}, " + restricted_columns = [ + "rejection_note", + "data_source_request", + "approval_status", + "airtable_uid", + "airtable_source_last_modified", + ] - now = datetime.now().strftime("%Y-%m-%d") - airtable_uid = str(uuid.uuid4()) + column_names = "" + column_values = "" + for key, value in data.items(): + if key not in restricted_columns: + column_names += f"{key}, " + if type(value) == str: + column_values += f"'{value}', " + else: + column_values += f"{value}, " - column_names += ( - "approval_status, url_status, data_source_created, airtable_uid" - ) - column_values += f"False, '[\"ok\"]', '{now}', '{airtable_uid}'" + now = datetime.now().strftime("%Y-%m-%d") + airtable_uid = str(uuid.uuid4()) - sql_query = f"INSERT INTO data_sources ({column_names}) VALUES ({column_values}) RETURNING *" + column_names 
+= "approval_status, url_status, data_source_created, airtable_uid" + column_values += f"False, '[\"ok\"]', '{now}', '{airtable_uid}'" - cursor.execute(sql_query) - self.psycopg2_connection.commit() + sql_query = f"INSERT INTO data_sources ({column_names}) VALUES ({column_values}) RETURNING *" - return {"message": "Data source added successfully."} + cursor.execute(sql_query) + self.psycopg2_connection.commit() - except Exception as e: - self.psycopg2_connection.rollback() - print(str(e)) - return {"message": "There has been an error adding the data source"}, 500 + return {"message": "Data source added successfully."} class DataSourcesNeedsIdentification(PsycopgResource): + @handle_exceptions @api_required def get(self): - try: - data_source_matches = data_sources_query( - self.psycopg2_connection, [], "needs_identification" - ) - - data_sources = { - "count": len(data_source_matches), - "data": data_source_matches, - } + data_source_matches = needs_identification_data_sources( + self.psycopg2_connection + ) - return data_sources + data_sources = { + "count": len(data_source_matches), + "data": data_source_matches, + } - except Exception as e: - self.psycopg2_connection.rollback() - print(str(e)) - return {"message": "There has been an error pulling data!"}, 500 + return data_sources class DataSourcesMap(PsycopgResource): @@ -203,6 +164,7 @@ class DataSourcesMap(PsycopgResource): Provides a method for retrieving all data sources. """ + @handle_exceptions @api_required def get(self) -> Dict[str, Any]: """ @@ -211,19 +173,4 @@ def get(self) -> Dict[str, Any]: Returns: - A dictionary containing the count of data sources and their details. 
""" - try: - data_source_matches = data_sources_query( - self.psycopg2_connection, [], "approved", True - ) - - data_sources = { - "count": len(data_source_matches), - "data": data_source_matches, - } - - return data_sources - - except Exception as e: - self.psycopg2_connection.rollback() - print(str(e)) - return {"message": "There has been an error pulling data!"}, 500 + return get_data_sources_for_map_wrapper(self.psycopg2_connection) diff --git a/resources/Login.py b/resources/Login.py index 92e076c8..b5c2e7f9 100644 --- a/resources/Login.py +++ b/resources/Login.py @@ -1,7 +1,7 @@ from werkzeug.security import check_password_hash from flask import request from middleware.login_queries import login_results, create_session_token -from resources.PsycopgResource import PsycopgResource +from resources.PsycopgResource import PsycopgResource, handle_exceptions class Login(PsycopgResource): @@ -9,6 +9,7 @@ class Login(PsycopgResource): A resource for authenticating users. Allows users to log in using their email and password. """ + @handle_exceptions def post(self): """ Processes the login request. Validates user credentials against the stored hashed password and, @@ -17,27 +18,21 @@ def post(self): Returns: - A dictionary containing a message of success or failure, and the session token if successful. 
""" - try: - data = request.get_json() - email = data.get("email") - password = data.get("password") - cursor = self.psycopg2_connection.cursor() - - user_data = login_results(cursor, email) - - if "password_digest" in user_data and check_password_hash( - user_data["password_digest"], password - ): - token = create_session_token(cursor, user_data["id"], email) - self.psycopg2_connection.commit() - return { - "message": "Successfully logged in", - "data": token, - } - - return {"message": "Invalid email or password"}, 401 - - except Exception as e: - self.psycopg2_connection.rollback() - print(str(e)) - return {"message": str(e)}, 500 + data = request.get_json() + email = data.get("email") + password = data.get("password") + cursor = self.psycopg2_connection.cursor() + + user_data = login_results(cursor, email) + + if "password_digest" in user_data and check_password_hash( + user_data["password_digest"], password + ): + token = create_session_token(cursor, user_data["id"], email) + self.psycopg2_connection.commit() + return { + "message": "Successfully logged in", + "data": token, + } + + return {"message": "Invalid email or password"}, 401 diff --git a/resources/PsycopgResource.py b/resources/PsycopgResource.py index c6b84803..0282df15 100644 --- a/resources/PsycopgResource.py +++ b/resources/PsycopgResource.py @@ -1,6 +1,47 @@ +import functools +from typing import Callable, Any, Union, Tuple, Dict + +from flask import make_response from flask_restful import Resource +def handle_exceptions( + func: Callable[..., Any] +) -> Callable[..., Union[Any, Tuple[Dict[str, str], int]]]: + """ + A decorator to handle exceptions raised by a function. + + :param func: The function to be decorated. + :return: The decorated function. + + The decorated function handles any exceptions raised + by the original function. 
If an exception occurs, the + decorator performs a rollback on the psycopg2 connection, + prints the error message, and returns a dictionary with + the error message and an HTTP status code of 500. + + Example usage: + ``` + @handle_exceptions + def my_function(): + # code goes here + ``` + """ + + @functools.wraps(func) + def wrapper( + self, *args: Any, **kwargs: Any + ) -> Union[Any, Tuple[Dict[str, str], int]]: + try: + return func(self, *args, **kwargs) + except Exception as e: + self.psycopg2_connection.rollback() + print(str(e)) + return make_response({"message": str(e)}, 500) + + return wrapper + + class PsycopgResource(Resource): def __init__(self, **kwargs): """ @@ -8,15 +49,3 @@ def __init__(self, **kwargs): - kwargs (dict): Keyword arguments containing 'psycopg2_connection' for database connection. """ self.psycopg2_connection = kwargs["psycopg2_connection"] - - def get(self): - """ - Base implementation of GET. Override in subclasses as needed. - """ - raise NotImplementedError("This method should be overridden by subclasses") - - def post(self): - """ - Base implementation of POST. Override in subclasses as needed. 
- """ - raise NotImplementedError("This method should be overridden by subclasses") diff --git a/resources/QuickSearch.py b/resources/QuickSearch.py index ee85f0ac..5c071a47 100644 --- a/resources/QuickSearch.py +++ b/resources/QuickSearch.py @@ -1,10 +1,6 @@ from middleware.security import api_required -from middleware.quick_search_query import quick_search_query -import requests -import json -import os -from middleware.initialize_psycopg2_connection import initialize_psycopg2_connection -from flask import request +from middleware.quick_search_query import quick_search_query_wrapper + from typing import Dict, Any from resources.PsycopgResource import PsycopgResource @@ -32,51 +28,4 @@ def get(self, search: str, location: str) -> Dict[str, Any]: Returns: - A dictionary containing a message about the search results and the data found, if any. """ - try: - data = request.get_json() - test = data.get("test_flag") - except: - test = False - - try: - data_sources = quick_search_query( - search, location, [], self.psycopg2_connection, test - ) - - if data_sources["count"] == 0: - self.psycopg2_connection = initialize_psycopg2_connection() - data_sources = quick_search_query( - search, location, [], self.psycopg2_connection - ) - - if data_sources["count"] == 0: - return { - "count": 0, - "message": "No results found. 
Please considering requesting a new data source.", - }, 404 - - return { - "message": "Results for search successfully retrieved", - "data": data_sources, - } - - except Exception as e: - self.psycopg2_connection.rollback() - print(str(e)) - webhook_url = os.getenv("WEBHOOK_URL") - user_message = "There was an error during the search operation" - message = { - "content": user_message - + ": " - + str(e) - + "\n" - + f"Search term: {search}\n" - + f"Location: {location}" - } - requests.post( - webhook_url, - data=json.dumps(message), - headers={"Content-Type": "application/json"}, - ) - - return {"count": 0, "message": user_message}, 500 + return quick_search_query_wrapper(search, location, self.psycopg2_connection) diff --git a/resources/RefreshSession.py b/resources/RefreshSession.py index df03eb06..a8ea5958 100644 --- a/resources/RefreshSession.py +++ b/resources/RefreshSession.py @@ -1,9 +1,10 @@ from flask import request -from middleware.login_queries import token_results, create_session_token -from datetime import datetime as dt + +from middleware.custom_exceptions import TokenNotFoundError +from middleware.login_queries import get_session_token_user_data, create_session_token, delete_session_token from typing import Dict, Any -from resources.PsycopgResource import PsycopgResource +from resources.PsycopgResource import PsycopgResource, handle_exceptions class RefreshSession(PsycopgResource): @@ -12,6 +13,7 @@ class RefreshSession(PsycopgResource): If the provided session token is valid and not expired, it is replaced with a new one. """ + @handle_exceptions def post(self) -> Dict[str, Any]: """ Processes the session token refresh request. If the provided session token is valid, @@ -20,29 +22,17 @@ def post(self) -> Dict[str, Any]: Returns: - A dictionary containing a message of success or failure, and the new session token if successful. 
""" + data = request.get_json() + old_token = data.get("session_token") + cursor = self.psycopg2_connection.cursor() try: - data = request.get_json() - old_token = data.get("session_token") - cursor = self.psycopg2_connection.cursor() - user_data = token_results(cursor, old_token) - cursor.execute( - f"delete from session_tokens where token = '{old_token}' and expiration_date < '{dt.utcnow()}'" - ) - self.psycopg2_connection.commit() - - if "id" in user_data: - token = create_session_token( - cursor, user_data["id"], user_data["email"] - ) - self.psycopg2_connection.commit() - return { - "message": "Successfully refreshed session token", - "data": token, - } - + user_data = get_session_token_user_data(cursor, old_token) + except TokenNotFoundError: return {"message": "Invalid session token"}, 403 - - except Exception as e: - self.psycopg2_connection.rollback() - print(str(e)) - return {"message": str(e)}, 500 + delete_session_token(cursor, old_token) + token = create_session_token(cursor, user_data.id, user_data.email) + self.psycopg2_connection.commit() + return { + "message": "Successfully refreshed session token", + "data": token, + } diff --git a/resources/RequestResetPassword.py b/resources/RequestResetPassword.py index 373b6756..46e6c007 100644 --- a/resources/RequestResetPassword.py +++ b/resources/RequestResetPassword.py @@ -6,7 +6,7 @@ import requests from typing import Dict, Any -from resources.PsycopgResource import PsycopgResource +from resources.PsycopgResource import PsycopgResource, handle_exceptions class RequestResetPassword(PsycopgResource): @@ -15,6 +15,7 @@ class RequestResetPassword(PsycopgResource): and sends an email to the user with instructions on how to reset their password. """ + @handle_exceptions def post(self) -> Dict[str, Any]: """ Processes a password reset request. 
Checks if the user's email exists in the database, @@ -23,34 +24,28 @@ def post(self) -> Dict[str, Any]: Returns: - A dictionary containing a success message and the reset token, or an error message if an exception occurs. """ - try: - data = request.get_json() - email = data.get("email") - cursor = self.psycopg2_connection.cursor() - user_data = user_check_email(cursor, email) - id = user_data["id"] - token = uuid.uuid4().hex - add_reset_token(cursor, email, token) - self.psycopg2_connection.commit() - - body = f"To reset your password, click the following link: {os.getenv('VITE_VUE_APP_BASE_URL')}/reset-password/{token}" - r = requests.post( - "https://api.mailgun.net/v3/mail.pdap.io/messages", - auth=("api", os.getenv("MAILGUN_KEY")), - data={ - "from": "mail@pdap.io", - "to": [email], - "subject": "PDAP Data Sources Reset Password", - "text": body, - }, - ) - - return { - "message": "An email has been sent to your email address with a link to reset your password. It will be valid for 15 minutes.", - "token": token, - } - - except Exception as e: - self.psycopg2_connection.rollback() - print(str(e)) - return {"error": str(e)}, 500 + data = request.get_json() + email = data.get("email") + cursor = self.psycopg2_connection.cursor() + user_data = user_check_email(cursor, email) + id = user_data["id"] + token = uuid.uuid4().hex + add_reset_token(cursor, email, token) + self.psycopg2_connection.commit() + + body = f"To reset your password, click the following link: {os.getenv('VITE_VUE_APP_BASE_URL')}/reset-password/{token}" + r = requests.post( + "https://api.mailgun.net/v3/mail.pdap.io/messages", + auth=("api", os.getenv("MAILGUN_KEY")), + data={ + "from": "mail@pdap.io", + "to": [email], + "subject": "PDAP Data Sources Reset Password", + "text": body, + }, + ) + + return { + "message": "An email has been sent to your email address with a link to reset your password. 
It will be valid for 15 minutes.", + "token": token, + } diff --git a/resources/ResetPassword.py b/resources/ResetPassword.py index b3c64428..efa68f72 100644 --- a/resources/ResetPassword.py +++ b/resources/ResetPassword.py @@ -7,7 +7,7 @@ from datetime import datetime as dt from typing import Dict, Any -from resources.PsycopgResource import PsycopgResource +from resources.PsycopgResource import PsycopgResource, handle_exceptions class ResetPassword(PsycopgResource): @@ -16,6 +16,7 @@ class ResetPassword(PsycopgResource): If the token is valid and not expired, allows the user to set a new password. """ + @handle_exceptions def post(self) -> Dict[str, Any]: """ Processes a password reset request. Validates the provided reset token and, @@ -24,32 +25,26 @@ def post(self) -> Dict[str, Any]: Returns: - A dictionary containing a message indicating whether the password was successfully updated or an error occurred. """ - try: - data = request.get_json() - token = data.get("token") - password = data.get("password") - cursor = self.psycopg2_connection.cursor() - token_data = check_reset_token(cursor, token) - email = token_data.get("email") - if "create_date" not in token_data: - return {"message": "The submitted token is invalid"}, 400 - - token_create_date = token_data["create_date"] - token_expired = (dt.utcnow() - token_create_date).total_seconds() > 900 - delete_reset_token(cursor, token_data["email"], token) - if token_expired: - return {"message": "The submitted token is invalid"}, 400 - - password_digest = generate_password_hash(password) - cursor = self.psycopg2_connection.cursor() - cursor.execute( - f"update users set password_digest = '{password_digest}' where email = '{email}'" - ) - self.psycopg2_connection.commit() - - return {"message": "Successfully updated password"} - - except Exception as e: - self.psycopg2_connection.rollback() - print(str(e)) - return {"message": str(e)}, 500 + data = request.get_json() + token = data.get("token") + password = 
data.get("password") + cursor = self.psycopg2_connection.cursor() + token_data = check_reset_token(cursor, token) + email = token_data.get("email") + if "create_date" not in token_data: + return {"message": "The submitted token is invalid"}, 400 + + token_create_date = token_data["create_date"] + token_expired = (dt.utcnow() - token_create_date).total_seconds() > 900 + delete_reset_token(cursor, token_data["email"], token) + if token_expired: + return {"message": "The submitted token is invalid"}, 400 + + password_digest = generate_password_hash(password) + cursor = self.psycopg2_connection.cursor() + cursor.execute( + f"update users set password_digest = '{password_digest}' where email = '{email}'" + ) + self.psycopg2_connection.commit() + + return {"message": "Successfully updated password"} diff --git a/resources/ResetTokenValidation.py b/resources/ResetTokenValidation.py index b174b7a0..6a537823 100644 --- a/resources/ResetTokenValidation.py +++ b/resources/ResetTokenValidation.py @@ -4,29 +4,24 @@ ) from datetime import datetime as dt -from resources.PsycopgResource import PsycopgResource +from resources.PsycopgResource import PsycopgResource, handle_exceptions class ResetTokenValidation(PsycopgResource): + @handle_exceptions def post(self): - try: - data = request.get_json() - token = data.get("token") - cursor = self.psycopg2_connection.cursor() - token_data = check_reset_token(cursor, token) - if "create_date" not in token_data: - return {"message": "The submitted token is invalid"}, 400 + data = request.get_json() + token = data.get("token") + cursor = self.psycopg2_connection.cursor() + token_data = check_reset_token(cursor, token) + if "create_date" not in token_data: + return {"message": "The submitted token is invalid"}, 400 - token_create_date = token_data["create_date"] - token_expired = (dt.utcnow() - token_create_date).total_seconds() > 900 + token_create_date = token_data["create_date"] + token_expired = (dt.utcnow() - 
token_create_date).total_seconds() > 900 - if token_expired: - return {"message": "The submitted token is invalid"}, 400 + if token_expired: + return {"message": "The submitted token is invalid"}, 400 - return {"message": "Token is valid"} - - except Exception as e: - self.psycopg2_connection.rollback() - print(str(e)) - return {"message": str(e)}, 500 + return {"message": "Token is valid"} diff --git a/resources/SearchTokens.py b/resources/SearchTokens.py index 19090789..43326ef5 100644 --- a/resources/SearchTokens.py +++ b/resources/SearchTokens.py @@ -1,30 +1,35 @@ -from middleware.quick_search_query import quick_search_query +from middleware.access_token_logic import insert_access_token +from middleware.quick_search_query import quick_search_query_wrapper from middleware.data_source_queries import ( - data_source_by_id_query, - data_sources_query, + get_approved_data_sources_wrapper, + data_source_by_id_wrapper, + get_data_sources_for_map_wrapper, ) -from flask import request -import datetime -import uuid +from flask import request, make_response import os -import requests import sys -import json from typing import Dict, Any -from resources.PsycopgResource import PsycopgResource +from resources.PsycopgResource import PsycopgResource, handle_exceptions sys.path.append("..") BASE_URL = os.getenv("VITE_VUE_API_BASE_URL") +class UnknownEndpointError(Exception): + def __init__(self, endpoint): + self.message = f"Unknown endpoint: {endpoint}" + super().__init__(self.message) + + class SearchTokens(PsycopgResource): """ A resource that provides various search functionalities based on the specified endpoint. It supports quick search, data source retrieval by ID, and listing all data sources. """ + @handle_exceptions def get(self) -> Dict[str, Any]: """ Handles GET requests by performing a search operation based on the specified endpoint and arguments. 
@@ -38,112 +43,24 @@ def get(self) -> Dict[str, Any]: Returns: - A dictionary with the search results or an error message. """ - try: - url_params = request.args - endpoint = url_params.get("endpoint") - arg1 = url_params.get("arg1") - arg2 = url_params.get("arg2") - print(endpoint, arg1, arg2) - data_sources = {"count": 0, "data": []} - if type(self.psycopg2_connection) == dict: - return data_sources - - cursor = self.psycopg2_connection.cursor() - token = uuid.uuid4().hex - expiration = datetime.datetime.now() + datetime.timedelta(minutes=5) - cursor.execute( - f"insert into access_tokens (token, expiration_date) values (%s, %s)", - (token, expiration), - ) - self.psycopg2_connection.commit() - - if endpoint == "quick-search": - try: - data = request.get_json() - test = data.get("test_flag") - except: - test = False - try: - data_sources = quick_search_query( - arg1, arg2, [], self.psycopg2_connection, test - ) - - return data_sources - - except Exception as e: - self.psycopg2_connection.rollback() - print(str(e)) - webhook_url = os.getenv("WEBHOOK_URL") - user_message = "There was an error during the search operation" - message = { - "content": user_message - + ": " - + str(e) - + "\n" - + f"Search term: {arg1}\n" - + f"Location: {arg2}" - } - requests.post( - webhook_url, - data=json.dumps(message), - headers={"Content-Type": "application/json"}, - ) - - return {"count": 0, "message": user_message}, 500 - - elif endpoint == "data-sources": - try: - data_source_matches = data_sources_query(self.psycopg2_connection) - - data_sources = { - "count": len(data_source_matches), - "data": data_source_matches, - } - - return data_sources - - except Exception as e: - self.psycopg2_connection.rollback() - print(str(e)) - return {"message": "There has been an error pulling data!"}, 500 - - elif endpoint == "data-sources-by-id": - try: - data_source_details = data_source_by_id_query( - arg1, [], self.psycopg2_connection - ) - if data_source_details: - return 
data_source_details - - else: - return {"message": "Data source not found."}, 404 - - except Exception as e: - print(str(e)) - return {"message": "There has been an error pulling data!"}, 500 - - elif endpoint == "data-sources-map": - try: - data_source_details = data_sources_query( - self.psycopg2_connection, [], "approved", True - ) - if data_source_details: - data_sources = { - "count": len(data_source_details), - "data": data_source_details, - } - return data_sources - - else: - return {"message": "There has been an error pulling data!"}, 500 - - except Exception as e: - print(str(e)) - return {"message": "There has been an error pulling data!"}, 500 - else: - return {"message": "Unknown endpoint"}, 500 - - except Exception as e: - self.psycopg2_connection.rollback() - print(str(e)) - return {"message": e}, 500 + url_params = request.args + endpoint = url_params.get("endpoint") + arg1 = url_params.get("arg1") + arg2 = url_params.get("arg2") + + cursor = self.psycopg2_connection.cursor() + insert_access_token(cursor) + self.psycopg2_connection.commit() + + return self.perform_endpoint_logic(arg1, arg2, endpoint) + + def perform_endpoint_logic(self, arg1, arg2, endpoint): + if endpoint == "quick-search": + return quick_search_query_wrapper(arg1, arg2, self.psycopg2_connection) + if endpoint == "data-sources": + return get_approved_data_sources_wrapper(self.psycopg2_connection) + if endpoint == "data-sources-by-id": + return data_source_by_id_wrapper(arg1, self.psycopg2_connection) + if endpoint == "data-sources-map": + return get_data_sources_for_map_wrapper(self.psycopg2_connection) + raise UnknownEndpointError(endpoint) diff --git a/resources/User.py b/resources/User.py index 8eb51ee0..ded38d2e 100644 --- a/resources/User.py +++ b/resources/User.py @@ -4,7 +4,7 @@ from middleware.security import api_required from typing import Dict, Any -from resources.PsycopgResource import PsycopgResource +from resources.PsycopgResource import PsycopgResource, 
handle_exceptions class User(PsycopgResource): @@ -12,6 +12,7 @@ class User(PsycopgResource): A resource for user management, allowing new users to sign up and existing users to update their passwords. """ + @handle_exceptions def post(self) -> Dict[str, Any]: """ Allows a new user to sign up by providing an email and password. @@ -22,22 +23,17 @@ def post(self) -> Dict[str, Any]: Returns: - A dictionary containing a success message or an error message if the operation fails. """ - try: - data = request.get_json() - email = data.get("email") - password = data.get("password") - cursor = self.psycopg2_connection.cursor() - user_post_results(cursor, email, password) - self.psycopg2_connection.commit() + data = request.get_json() + email = data.get("email") + password = data.get("password") + cursor = self.psycopg2_connection.cursor() + user_post_results(cursor, email, password) + self.psycopg2_connection.commit() - return {"message": "Successfully added user"} - - except Exception as e: - self.psycopg2_connection.rollback() - print(str(e)) - return {"message": e}, 500 + return {"message": "Successfully added user"} # Endpoint for updating a user's password + @handle_exceptions @api_required def put(self) -> Dict[str, Any]: """ @@ -49,20 +45,13 @@ def put(self) -> Dict[str, Any]: Returns: - A dictionary containing a success message or an error message if the operation fails. 
""" - try: - data = request.get_json() - email = data.get("email") - password = data.get("password") - password_digest = generate_password_hash(password) - cursor = self.psycopg2_connection.cursor() - cursor.execute( - f"update users set password_digest = '{password_digest}' where email = '{email}'" - ) - self.psycopg2_connection.commit() - return {"message": "Successfully updated password"} - - except Exception as e: - self.psycopg2_connection.rollback() - print(str(e)) - return {"message": e}, 500 - return {"message": e}, 500 + data = request.get_json() + email = data.get("email") + password = data.get("password") + password_digest = generate_password_hash(password) + cursor = self.psycopg2_connection.cursor() + cursor.execute( + f"update users set password_digest = '{password_digest}' where email = '{email}'" + ) + self.psycopg2_connection.commit() + return {"message": "Successfully updated password"} diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/fixtures.py b/tests/fixtures.py new file mode 100644 index 00000000..77e2c157 --- /dev/null +++ b/tests/fixtures.py @@ -0,0 +1,106 @@ +"""This module contains pytest fixtures employed by middleware tests.""" + +import os +from collections import namedtuple + +import psycopg2 +import pytest +from dotenv import load_dotenv +from flask.testing import FlaskClient + +from app import create_app +from tests.helper_functions import insert_test_agencies_and_sources + + +@pytest.fixture +def dev_db_connection() -> psycopg2.extensions.cursor: + """ + Create reversible connection to dev database. + + Sets up connection to development database + and creates a session that is rolled back after the test completes + to undo any operations performed during the test. 
+ :return: + """ + load_dotenv() + dev_db_connection_string = os.getenv("DEV_DB_CONN_STRING") + connection = psycopg2.connect( + dev_db_connection_string, + keepalives=1, + keepalives_idle=30, + keepalives_interval=10, + keepalives_count=5, + ) + connection.autocommit = False + + yield connection + + # Rollback any changes made during the tests + connection.rollback() + + connection.close() + + +@pytest.fixture +def db_cursor( + dev_db_connection: psycopg2.extensions.connection, +) -> psycopg2.extensions.cursor: + """ + Create cursor for reversible database operations. + + Create a cursor to execute database operations, with savepoint management. + This is to ensure that changes made during the test can be rolled back. + """ + cur = dev_db_connection.cursor() + + # Start a savepoint + cur.execute("SAVEPOINT test_savepoint") + + yield cur + + # Rollback to the savepoint to ignore commits within the test + cur.execute("ROLLBACK TO SAVEPOINT test_savepoint") + cur.close() + + +@pytest.fixture +def connection_with_test_data( + dev_db_connection: psycopg2.extensions.connection, +) -> psycopg2.extensions.connection: + """ + Insert test agencies and sources into test data. + + Will roll back in case of error. 
+ + :param dev_db_connection: + :return: + """ + try: + insert_test_agencies_and_sources(dev_db_connection.cursor()) + except psycopg2.errors.UniqueViolation: + dev_db_connection.rollback() + return dev_db_connection + +ClientWithMockDB = namedtuple("ClientWithMockDB", ["client", "mock_db"]) +@pytest.fixture +def client_with_mock_db(mocker) -> ClientWithMockDB: + """ + Create a client with a mocked database connection + :param mocker: + :return: + """ + mock_db = mocker.MagicMock() + app = create_app(mock_db) + with app.test_client() as client: + yield ClientWithMockDB(client, mock_db) + +@pytest.fixture +def client_with_db(dev_db_connection: psycopg2.extensions.connection): + """ + Creates a client with database connection + :param dev_db_connection: + :return: + """ + app = create_app(dev_db_connection) + with app.test_client() as client: + yield client \ No newline at end of file diff --git a/tests/helper_functions.py b/tests/helper_functions.py new file mode 100644 index 00000000..cc1cd281 --- /dev/null +++ b/tests/helper_functions.py @@ -0,0 +1,331 @@ +"""This module contains helper functions used by middleware pytests.""" + +import uuid +from collections import namedtuple +from datetime import datetime, timedelta +from typing import Optional + +import psycopg2.extensions +from flask.testing import FlaskClient + +TestTokenInsert = namedtuple("TestTokenInsert", ["id", "email", "token"]) +TestUser = namedtuple("TestUser", ["id", "email", "password_hash"]) + + +def insert_test_agencies_and_sources(cursor: psycopg2.extensions.cursor) -> None: + """ + Insert test agencies and sources into database. 
+ + :param cursor: + :return: + """ + cursor.execute( + """ + INSERT INTO + PUBLIC.DATA_SOURCES ( + airtable_uid, + NAME, + DESCRIPTION, + RECORD_TYPE, + SOURCE_URL, + APPROVAL_STATUS, + URL_STATUS + ) + VALUES + ('SOURCE_UID_1','Source 1','Description of src1', + 'Type A','http://src1.com','approved','available'), + ('SOURCE_UID_2','Source 2','Description of src2', + 'Type B','http://src2.com','needs identification','available'), + ('SOURCE_UID_3','Source 3', 'Description of src3', + 'Type C', 'http://src3.com', 'pending', 'available'); + + INSERT INTO public.agencies + (airtable_uid, name, municipality, state_iso, + county_name, count_data_sources, lat, lng) + VALUES + ('Agency_UID_1', 'Agency A', 'City A', + 'CA', 'County X', 3, 30, 20), + ('Agency_UID_2', 'Agency B', 'City B', + 'NY', 'County Y', 2, 40, 50), + ('Agency_UID_3', 'Agency C', 'City C', + 'TX', 'County Z', 1, 90, 60); + + INSERT INTO public.agency_source_link + (airtable_uid, agency_described_linked_uid) + VALUES + ('SOURCE_UID_1', 'Agency_UID_1'), + ('SOURCE_UID_2', 'Agency_UID_2'), + ('SOURCE_UID_3', 'Agency_UID_3'); + """ + ) + + +def get_reset_tokens_for_email( + db_cursor: psycopg2.extensions.cursor, reset_token_insert: TestTokenInsert +) -> tuple: + """ + Get all reset tokens associated with an email. + + :param db_cursor: + :param reset_token_insert: + :return: + """ + db_cursor.execute( + """ + SELECT email from RESET_TOKENS where email = %s + """, + (reset_token_insert.email,), + ) + results = db_cursor.fetchall() + return results + + +def create_reset_token(cursor: psycopg2.extensions.cursor) -> TestTokenInsert: + """ + Create a test user and associated reset token. 
+ + :param cursor: + :return: + """ + user = create_test_user(cursor) + token = uuid.uuid4().hex + cursor.execute( + """ + INSERT INTO reset_tokens(email, token) + VALUES (%s, %s) + RETURNING id + """, + (user.email, token), + ) + id = cursor.fetchone()[0] + return TestTokenInsert(id=id, email=user.email, token=token) + + +def create_test_user( + cursor, + email="", + password_hash="hashed_password_here", + api_key="api_key_here", + role=None, +) -> TestUser: + """ + Create test user and return the id of the test user. + + :param cursor: + :return: user id + """ + if email == "": + email = uuid.uuid4().hex + "@test.com" + cursor.execute( + """ + INSERT INTO users (email, password_digest, api_key, role) + VALUES + (%s, %s, %s, %s) + RETURNING id; + """, + (email, password_hash, api_key, role), + ) + return TestUser( + id=cursor.fetchone()[0], + email=email, + password_hash=password_hash, + ) + + +QuickSearchQueryLogResult = namedtuple( + "QuickSearchQueryLogResult", ["result_count", "updated_at"] +) + + +def get_most_recent_quick_search_query_log( + cursor: psycopg2.extensions.cursor, search: str, location: str +) -> Optional[QuickSearchQueryLogResult]: + """ + Retrieve most recent quick search query log for a search and location. + + :param cursor: The Cursor object of the database connection. + :param search: The search query string. + :param location: The location string. + :return: A QuickSearchQueryLogResult object + containing the result count and updated timestamp. + """ + cursor.execute( + """ + SELECT RESULT_COUNT, CREATED_AT FROM QUICK_SEARCH_QUERY_LOGS WHERE + search = %s AND location = %s ORDER BY CREATED_AT DESC LIMIT 1 + """, + (search, location), + ) + result = cursor.fetchone() + if result is None: + return result + return QuickSearchQueryLogResult(result_count=result[0], updated_at=result[1]) + + +def has_expected_keys(result_keys: list, expected_keys: list) -> bool: + """ + Check that given result includes expected keys. 
+ + :param result: + :param expected_keys: + :return: True if has expected keys, false otherwise + """ + return not set(expected_keys).difference(result_keys) + + +def get_boolean_dictionary(keys: tuple) -> dict: + """ + Creates dictionary of booleans, all set to false. + + :param keys: + :return: dictionary of booleans + """ + d = {} + for key in keys: + d[key] = False + return d + + +UserInfo = namedtuple("UserInfo", ["email", "password"]) + + +def create_test_user_api(client: FlaskClient) -> UserInfo: + """ + Create a test user through calling the /user endpoint via the Flask API + :param client: + :return: + """ + email = str(uuid.uuid4()) + password = str(uuid.uuid4()) + response = client.post( + "/user", + json={"email": email, "password": password}, + ) + assert response.status_code == 200, "User creation not successful" + return UserInfo(email=email, password=password) + + +def login_and_return_session_token( + client_with_db: FlaskClient, user_info: UserInfo +) -> str: + """ + Login as a given user and return the associated session token, + using the /login endpoint of the Flask API + :param client_with_db: + :param user_info: + :return: + """ + response = client_with_db.post( + "/login", + json={"email": user_info.email, "password": user_info.password}, + ) + assert response.status_code == 200, "User login unsuccessful" + session_token = response.json.get("data") + return session_token + + +def get_user_password_digest(cursor: psycopg2.extensions.cursor, user_info): + """ + Get the associated password digest of a user (given their email) from the database + :param cursor: + :param user_info: + :return: + """ + cursor.execute( + """ + SELECT password_digest from users where email = %s + """, + (user_info.email,), + ) + return cursor.fetchone()[0] + + +def request_reset_password_api(client_with_db, mocker, user_info): + """ + Send a request to reset password via a Flask call to the /request-reset-password endpoint + and return the reset token + :param 
client_with_db: + :param mocker: + :param user_info: + :return: + """ + mocker.patch("resources.RequestResetPassword.requests.post") + response = client_with_db.post( + "/request-reset-password", json={"email": user_info.email} + ) + token = response.json.get("token") + return token + + +def create_api_key(client_with_db, user_info): + """ + Obtain an api key for the given user, via a Flask call to the /api-key endpoint + :param client_with_db: + :param user_info: + :return: + """ + response = client_with_db.get( + "/api_key", json={"email": user_info.email, "password": user_info.password} + ) + assert response.status_code == 200, "API key creation not successful" + api_key = response.json.get("api_key") + return api_key + +def create_api_key_db(cursor, user_id: str): + api_key = uuid.uuid4().hex + cursor.execute( + "UPDATE users SET api_key = %s WHERE id = %s", (api_key, user_id) + ) + return api_key + + +def insert_test_data_source(cursor: psycopg2.extensions.cursor) -> str: + """ + Insert test data source and return id + :param cursor: + :return: randomly generated uuid + """ + test_uid = str(uuid.uuid4()) + cursor.execute( + """ + INSERT INTO + PUBLIC.DATA_SOURCES ( + airtable_uid, + NAME, + DESCRIPTION, + RECORD_TYPE, + SOURCE_URL, + APPROVAL_STATUS, + URL_STATUS + ) + VALUES + (%s,'Example Data Source', 'Example Description', + 'Type A','http://src1.com','approved','available') + """, + (test_uid,), + ) + return test_uid + + +def give_user_admin_role( + connection: psycopg2.extensions.connection, user_info: UserInfo +): + """ + Give the given user an admin role. 
def check_response_status(response, status_code):
    """
    Assert that a Flask test-client response carries the expected HTTP status.

    :param response: Flask test-client response object
    :param status_code: expected numeric HTTP status code
    :raises AssertionError: if the status differs; the message includes the
        actual status and the response body text for easier debugging
    """
    actual = response.status_code
    assert actual == status_code, (
        f"Expected status code {status_code}, got {actual}: {response.text}"
    )
def test_api_key_get(client_with_db, dev_db_connection: psycopg2.extensions.connection):
    """
    Test that a GET call to the /api_key endpoint creates an API key and
    stores that same key on the user's row in the database.
    """
    user_info = create_test_user_api(client_with_db)

    response = client_with_db.get(
        "/api_key",
        json={"email": user_info.email, "password": user_info.password},
    )
    assert response.status_code == 200, "API key creation not successful"

    # The key returned by the endpoint must match the one persisted for
    # the user in the users table.
    cursor = dev_db_connection.cursor()
    cursor.execute(
        """
        SELECT api_key from users where email = %s
        """,
        (user_info.email,),
    )
    db_api_key = cursor.fetchone()[0]
    returned_api_key = response.json.get("api_key")
    assert (
        db_api_key == returned_api_key
    ), "API key returned not aligned with user API key in database"
def test_data_sources_get(
    client_with_db, connection_with_test_data: psycopg2.extensions.connection
):
    """
    Test that a GET call to /data-sources returns the approved test source
    ("Source 1") and excludes the unapproved ones ("Source 2", "Source 3").
    """
    found = get_boolean_dictionary(("Source 1", "Source 2", "Source 3"))
    user_info = create_test_user_api(client_with_db)
    api_key = create_api_key(client_with_db, user_info)
    response = client_with_db.get(
        "/data-sources",
        headers={"Authorization": f"Bearer {api_key}"},
    )
    assert response.status_code == 200

    # Mark every inserted test source that appears in the payload.
    for entry in response.get_json()["data"]:
        entry_name = entry["name"]
        if entry_name in found:
            found[entry_name] = True

    assert found["Source 1"]
    assert not found["Source 2"]
    assert not found["Source 3"]
def test_data_sources_by_id_put(
    client_with_db, connection_with_test_data: psycopg2.extensions.connection
):
    """
    Test that PUT call to /data-sources-by-id/ endpoint successfully updates
    the description of the data source and verifies the change in the database.
    """
    user_info = create_test_user_api(client_with_db)
    give_user_admin_role(connection_with_test_data, user_info)
    api_key = create_api_key(client_with_db, user_info)
    desc = str(uuid.uuid4())
    response = client_with_db.put(
        # Plain literal: the original used an f-string with no placeholders
        # (flake8/ruff F541 smell); the URL text is unchanged.
        "/data-sources-by-id/SOURCE_UID_1",
        headers={"Authorization": f"Bearer {api_key}"},
        json={"description": desc},
    )
    assert response.status_code == 200
    cursor = connection_with_test_data.cursor()
    cursor.execute(
        """
        SELECT description
        FROM data_sources
        WHERE airtable_uid = 'SOURCE_UID_1'
        """
    )
    result = cursor.fetchone()
    assert result[0] == desc, "Description was not updated in the database"
def test_data_sources_needs_identification(
    client_with_db, connection_with_test_data: psycopg2.extensions.connection
):
    """
    Test that GET /data-sources-needs-identification returns only the test
    source marked as needing identification ("Source 2"), and not the
    approved or rejected ones.
    """
    found = get_boolean_dictionary(("Source 1", "Source 2", "Source 3"))
    user_info = create_test_user_api(client_with_db)
    api_key = create_api_key(client_with_db, user_info)
    response = client_with_db.get(
        "/data-sources-needs-identification",
        headers={"Authorization": f"Bearer {api_key}"},
    )
    assert response.status_code == 200

    # Flip the flag for every inserted test source present in the payload.
    for entry in response.json["data"]:
        entry_name = entry["name"]
        if entry_name in found:
            found[entry_name] = True

    assert not found["Source 1"]
    assert found["Source 2"]
    assert not found["Source 3"]
def test_refresh_session_post(
    client_with_db, dev_db_connection: psycopg2.extensions.connection
):
    """
    Test that POST /refresh-session issues a fresh session token, that the
    fresh token differs from the old one, and that the database keeps
    exactly one row for the new token while the old token's row is removed.
    """
    test_user = create_test_user_api(client_with_db)
    old_session_token = login_and_return_session_token(client_with_db, test_user)
    response = client_with_db.post(
        "/refresh-session", json={"session_token": old_session_token}
    )
    assert response.status_code == 200
    new_session_token = response.json.get("data")

    assert (
        old_session_token != new_session_token
    ), "New and old tokens should be different"

    cursor = dev_db_connection.cursor()

    def count_token_rows(token):
        # Number of session_tokens rows holding the given token.
        cursor.execute(
            """
            SELECT * FROM session_tokens where token = %s;
            """,
            (token,),
        )
        return len(cursor.fetchall())

    assert count_token_rows(new_session_token) == 1, (
        "Only one row should exist for the session token in the session_tokens table"
    )
    assert count_token_rows(old_session_token) == 0, (
        "No row should exist for the old session token in the session_tokens table"
    )
def test_reset_password_post(
    client_with_db, dev_db_connection: psycopg2.extensions.connection, mocker
):
    """
    Test that POST /reset-password changes the user's password: after the
    call, the stored password digest must differ from the digest recorded
    before the reset.
    """
    user_info = create_test_user_api(client_with_db)
    cursor = dev_db_connection.cursor()
    digest_before = get_user_password_digest(cursor, user_info)

    reset_token = request_reset_password_api(client_with_db, mocker, user_info)
    new_password = str(uuid.uuid4())
    response = client_with_db.post(
        "/reset-password",
        json={
            "email": user_info.email,
            "token": reset_token,
            "password": new_password,
        },
    )
    assert response.status_code == 200

    digest_after = get_user_password_digest(cursor, user_info)
    assert (
        digest_after != digest_before
    ), "Old and new password digests should be distinct"
def test_reset_token_validation(client_with_db, dev_db_connection, mocker):
    """
    Test that POST /reset-token-validation accepts a freshly issued reset
    token and reports it as valid.
    """
    user_info = create_test_user_api(client_with_db)
    reset_token = request_reset_password_api(client_with_db, mocker, user_info)

    response = client_with_db.post(
        "/reset-token-validation", json={"token": reset_token}
    )

    assert (
        response.status_code == 200
    ), "reset-token-validation endpoint call unsuccessful"
    message = response.json.get("message")
    assert message == "Token is valid", "Message does not return 'Token is valid'"
def test_user_post(client_with_db, dev_db_connection: psycopg2.extensions.connection):
    """
    Test that POST call to /user endpoint successfully creates a new user and
    verifies the user's email and password digest in the database.
    """

    user_info = create_test_user_api(client_with_db)
    cursor = dev_db_connection.cursor()
    # Plain string literal: the original used an f-string with no
    # placeholders (flake8/ruff F541). The email is bound via a psycopg2
    # query parameter, never interpolated into the SQL text.
    cursor.execute(
        "SELECT email, password_digest FROM users WHERE email = %s", (user_info.email,)
    )
    rows = cursor.fetchall()

    assert len(rows) == 1, "One row should be returned by user query"
    email, password_digest = rows[0]
    assert user_info.email == email, "DB user email and original email do not match"
    assert (
        user_info.password != password_digest
    ), "DB user password digest should not match password"
def test_archives_get_results(
    dev_db_connection: psycopg2.extensions.connection,
    db_cursor: psycopg2.extensions.cursor,
) -> None:
    """
    Test archives_get_results by inserting a new row into data_sources and
    verifying that the number of results it returns grows by exactly one.

    :param dev_db_connection: connection to the development database
    :param db_cursor: cursor for executing database queries
    :return: None
    """
    baseline = archives_get_results(dev_db_connection)
    db_cursor.execute(
        """
        INSERT INTO data_sources(airtable_uid, source_url, name, update_frequency, url_status)
        VALUES (%s, %s, %s, %s, %s)
        """,
        (
            "fake_uid",
            "https://www.fake_source_url.com",
            "fake_name",
            "Annually",
            "unbroken",
        ),
    )
    assert len(archives_get_results(dev_db_connection)) == len(baseline) + 1
def test_archives_put_last_cached_results(
    dev_db_connection: psycopg2.extensions.connection,
):
    """
    Test that archives_put_last_cached_results updates only the LAST_CACHED
    column of the test source, leaving URL_STATUS and
    BROKEN_SOURCE_URL_AS_OF untouched.
    """
    cursor = dev_db_connection.cursor()
    test_uid = insert_test_data_source(cursor)

    # Freshly inserted source: available, with no archive dates set.
    status, broken_as_of, cached = get_data_sources_archives_info(cursor, test_uid)
    assert status == "available"
    assert broken_as_of is None
    assert cached is None

    last_cached = datetime.datetime(year=1999, month=5, day=30).strftime("%Y-%m-%d")
    archives_put_last_cached_results(
        id=test_uid, last_cached=last_cached, conn=dev_db_connection
    )

    status, broken_as_of, cached = get_data_sources_archives_info(cursor, test_uid)
    assert status == "available"
    assert broken_as_of is None
    assert str(cached) == last_cached
def test_get_approved_data_sources(
    connection_with_test_data: psycopg2.extensions.connection,
    inserted_data_sources_found: dict[str, bool],
) -> None:
    """
    Test that only the approved test source ("Source 1") is returned by
    get_approved_data_sources.

    :param connection_with_test_data:
    :param inserted_data_sources_found:
    :return:
    """
    for row in get_approved_data_sources(conn=connection_with_test_data):
        row_name = row["name"]
        if row_name in inserted_data_sources_found:
            inserted_data_sources_found[row_name] = True

    assert inserted_data_sources_found["Source 1"]
    assert not inserted_data_sources_found["Source 2"]
    assert not inserted_data_sources_found["Source 3"]
def test_get_data_sources_for_map(
    connection_with_test_data: psycopg2.extensions.connection,
    inserted_data_sources_found: dict[str, bool],
) -> None:
    """
    Test that get_data_sources_for_map includes only the expected source
    ("Source 1") and reports its expected lat/lng coordinates.

    :param connection_with_test_data:
    :param inserted_data_sources_found:
    :return:
    """
    for row in get_data_sources_for_map(conn=connection_with_test_data):
        row_name = row["name"]
        if row_name == "Source 1":
            # Coordinates come from the test fixture's inserted agency.
            assert row["lat"] == 30 and row["lng"] == 20

        if row_name in inserted_data_sources_found:
            inserted_data_sources_found[row_name] = True

    assert inserted_data_sources_found["Source 1"]
    assert not inserted_data_sources_found["Source 2"]
    assert not inserted_data_sources_found["Source 3"]
+ ] + + # Execute the tests + for testcase in testcases: + assert ( + data_source_queries.convert_data_source_matches( + testcase["data_source_output_columns"], testcase["results"] + ) + == testcase["output"] + ) + + +@pytest.fixture +def mock_make_response(monkeypatch): + mock = MagicMock() + monkeypatch.setattr("middleware.data_source_queries.make_response", mock) + return mock + + +@pytest.fixture +def mock_data_source_by_id_query(monkeypatch): + mock = MagicMock() + monkeypatch.setattr("middleware.data_source_queries.data_source_by_id_query", mock) + return mock + + +def test_data_source_by_id_wrapper_data_found(mock_data_source_by_id_query, mock_make_response): + mock_data_source_by_id_query.return_value = {"agency_name": "Agency A"} + mock_conn = MagicMock() + data_source_by_id_wrapper(arg="SOURCE_UID_1", conn=mock_conn) + mock_data_source_by_id_query.assert_called_with( + data_source_id="SOURCE_UID_1", conn=mock_conn + ) + mock_make_response.assert_called_with({"agency_name": "Agency A"}, 200) + +def test_data_source_by_id_wrapper_data_not_found(mock_data_source_by_id_query, mock_make_response): + mock_data_source_by_id_query.return_value = None + mock_conn = MagicMock() + data_source_by_id_wrapper(arg="SOURCE_UID_1", conn=mock_conn) + mock_data_source_by_id_query.assert_called_with( + data_source_id="SOURCE_UID_1", conn=mock_conn + ) + mock_make_response.assert_called_with({"message": "Data source not found."}, 200) \ No newline at end of file diff --git a/tests/middleware/test_initialize_psycopg2_connection.py b/tests/middleware/test_initialize_psycopg2_connection.py new file mode 100644 index 00000000..9525235b --- /dev/null +++ b/tests/middleware/test_initialize_psycopg2_connection.py @@ -0,0 +1,17 @@ +def test_initialize_psycopg2_connection_success(): + """ + Test that function properly initializes psycopg2 connection + and returns valid connection string, + to be tested by executing a simple select query + :return: + """ + pass + + +def 
test_initialize_psycopg2_connection_failure(): + """ + Check that function raises DatabaseInitializationError if + psycopg2.OperationalError occurs. + :return: + """ + pass diff --git a/tests/middleware/test_login_queries.py b/tests/middleware/test_login_queries.py new file mode 100644 index 00000000..8c4da62f --- /dev/null +++ b/tests/middleware/test_login_queries.py @@ -0,0 +1,84 @@ +import uuid +from unittest.mock import patch + +import psycopg2 +import pytest + +from middleware.login_queries import ( + login_results, + create_session_token, + get_session_token_user_data, + is_admin, +) +from middleware.custom_exceptions import UserNotFoundError, TokenNotFoundError +from tests.helper_functions import create_test_user +from tests.fixtures import db_cursor, dev_db_connection + + +def test_login_query(db_cursor: psycopg2.extensions.cursor) -> None: + """ + Test the login query by comparing the password digest for a user retrieved from the database + with the password hash of a test user. + + :param db_cursor: The database cursor to execute the query. + :return: None + """ + test_user = create_test_user(db_cursor) + + user_data = login_results(db_cursor, test_user.email) + + assert user_data["password_digest"] == test_user.password_hash + + +def test_login_results_user_not_found(db_cursor: psycopg2.extensions.cursor) -> None: + """UserNotFoundError should be raised if the user does not exist in the database""" + with pytest.raises(UserNotFoundError): + login_results(cursor=db_cursor, email="nonexistent@example.com") + + +def test_create_session_token_results(db_cursor: psycopg2.extensions.cursor) -> None: + """ + Tests the `create_session_token_results` method properly + creates the expected session token in the database, + associated with the proper user. + + :param db_cursor: The psycopg2 database cursor object. 
+ :return: None + """ + test_user = create_test_user(db_cursor) + with patch("os.getenv", return_value="mysecretkey") as mock_getenv: + token = create_session_token(db_cursor, test_user.id, test_user.email) + new_token = get_session_token_user_data(db_cursor, token) + + assert new_token.email == test_user.email + + +def test_is_admin(db_cursor: psycopg2.extensions.cursor) -> None: + """ + Creates and inserts two users, one an admin and the other not + And then checks to see if the `is_admin` properly + identifies both + :param db_cursor: + """ + regular_user = create_test_user(db_cursor) + admin_user = create_test_user( + cursor=db_cursor, email="admin@admin.com", role="admin" + ) + assert is_admin(db_cursor, admin_user.email) + assert not is_admin(db_cursor, regular_user.email) + + +def test_is_admin_raises_user_not_logged_in_error(db_cursor): + """ + Check that when searching for a user by an email that doesn't exist, + the UserNotFoundError is raised + :return: + """ + with pytest.raises(UserNotFoundError): + is_admin(cursor=db_cursor, email=str(uuid.uuid4())) + + +def test_token_results_raises_token_not_found_error(db_cursor): + """token_results() should raise TokenNotFoundError for nonexistent token""" + with pytest.raises(TokenNotFoundError): + get_session_token_user_data(cursor=db_cursor, token=str(uuid.uuid4())) diff --git a/tests/middleware/test_quick_search_query.py b/tests/middleware/test_quick_search_query.py new file mode 100644 index 00000000..71ead947 --- /dev/null +++ b/tests/middleware/test_quick_search_query.py @@ -0,0 +1,154 @@ +import json +from unittest.mock import MagicMock + +import psycopg2 +import pytest + +from middleware.quick_search_query import ( + unaltered_search_query, + quick_search_query, + QUICK_SEARCH_COLUMNS, + quick_search_query_wrapper, +) +from tests.helper_functions import ( + has_expected_keys, + get_most_recent_quick_search_query_log, +) +from tests.fixtures import connection_with_test_data, dev_db_connection + + +def 
test_unaltered_search_query( + connection_with_test_data: psycopg2.extensions.connection, +) -> None: + """ + :param connection_with_test_data: A connection object that is connected to the test database containing the test data. + :return: None + Test the unaltered_search_query method properly returns only one result + """ + response = unaltered_search_query( + connection_with_test_data.cursor(), search="Source 1", location="City A" + ) + + assert len(response) == 1 + assert response[0][3] == "Type A" # Record Type + + +def test_quick_search_query_logging( + connection_with_test_data: psycopg2.extensions.connection, +) -> None: + """ + Tests that quick_search_query properly creates a log of the query + + :param connection_with_test_data: psycopg2.extensions.connection object representing the connection to the test database. + :return: None + """ + # Get datetime of test + with connection_with_test_data.cursor() as cursor: + cursor.execute("SELECT NOW()") + result = cursor.fetchone() + test_datetime = result[0] + + quick_search_query( + search="Source 1", location="City A", conn=connection_with_test_data + ) + + cursor = connection_with_test_data.cursor() + # Test that query inserted into log + result = get_most_recent_quick_search_query_log(cursor, "Source 1", "City A") + assert result.result_count == 1 + assert result.updated_at >= test_datetime + + +def test_quick_search_query_results( + connection_with_test_data: psycopg2.extensions.connection, +) -> None: + """ + Test the `quick_search_query` method returns expected test data + + :param connection_with_test_data: The connection to the test data database. + :return: None + """ + # TODO: Something about the quick_search_query might be mucking up the savepoints. 
Address once you fix quick_search's logic issues + results = quick_search_query( + search="Source 1", location="City A", conn=connection_with_test_data + ) + # Test that results include expected keys + assert has_expected_keys(results["data"][0].keys(), QUICK_SEARCH_COLUMNS) + assert len(results["data"]) == 1 + assert results["data"][0]["record_type"] == "Type A" + # "Source 3" was listed as pending and shouldn't show up + results = quick_search_query( + search="Source 3", location="City C", conn=connection_with_test_data + ) + assert len(results["data"]) == 0 + + +def test_quick_search_query_no_results( + connection_with_test_data: psycopg2.extensions.connection, +) -> None: + """ + Test the `quick_search_query` method returns no results when there are no matches + + :param connection_with_test_data: The connection to the test data database. + :return: None + """ + results = quick_search_query( + search="Nonexistent Source", + location="Nonexistent Location", + conn=connection_with_test_data, + ) + assert len(results["data"]) == 0 + + +@pytest.fixture +def mock_quick_search_query(monkeypatch): + mock = MagicMock() + monkeypatch.setattr("middleware.quick_search_query.quick_search_query", mock) + return mock + + +@pytest.fixture +def mock_make_response(monkeypatch): + mock = MagicMock() + monkeypatch.setattr("middleware.quick_search_query.make_response", mock) + return mock + + +@pytest.fixture +def mock_post_to_webhook(monkeypatch): + mock = MagicMock() + monkeypatch.setattr("middleware.quick_search_query.post_to_webhook", mock) + return mock + + +def test_quick_search_query_wrapper_happy_path( + mock_quick_search_query, mock_make_response +): + mock_quick_search_query.return_value = [{"record_type": "Type A"}] + mock_conn = MagicMock() + quick_search_query_wrapper(arg1="Source 1", arg2="City A", conn=mock_conn) + mock_quick_search_query.assert_called_with( + search="Source 1", location="City A", conn=mock_conn + ) + 
mock_make_response.assert_called_with([{"record_type": "Type A"}], 200) + + +def test_quick_search_query_wrapper_exception( + mock_quick_search_query, mock_make_response, mock_post_to_webhook +): + mock_quick_search_query.side_effect = Exception("Test Exception") + arg1 = "Source 1" + arg2 = "City A" + mock_conn = MagicMock() + quick_search_query_wrapper(arg1=arg1, arg2=arg2, conn=mock_conn) + mock_quick_search_query.assert_called_with( + search=arg1, location=arg2, conn=mock_conn + ) + mock_conn.rollback.assert_called_once() + user_message = "There was an error during the search operation" + mock_post_to_webhook.assert_called_with( + json.dumps({'content': 'There was an error during the search operation: Test Exception\nSearch term: Source 1\nLocation: City A'}) + ) + mock_make_response.assert_called_with( + {"count": 0, "message": user_message}, 500 + ) diff --git a/tests/middleware/test_reset_token_queries.py b/tests/middleware/test_reset_token_queries.py new file mode 100644 index 00000000..ee22a689 --- /dev/null +++ b/tests/middleware/test_reset_token_queries.py @@ -0,0 +1,69 @@ +import uuid + +import psycopg2.extensions +import pytest + +from middleware.custom_exceptions import TokenNotFoundError +from middleware.reset_token_queries import ( + check_reset_token, + add_reset_token, + delete_reset_token, +) +from tests.helper_functions import ( + create_reset_token, + create_test_user, + get_reset_tokens_for_email, +) +from tests.fixtures import db_cursor, dev_db_connection + + +def test_check_reset_token(db_cursor: psycopg2.extensions.cursor) -> None: + """ + Checks if a token existing in the database + is properly returned by check_reset_token + :param db_cursor: + :return: + """ + test_token_insert = create_reset_token(db_cursor) + + user_data = check_reset_token(db_cursor, test_token_insert.token) + assert test_token_insert.id == user_data["id"] + + +def test_check_reset_token_raises_token_not_found_error( + db_cursor: psycopg2.extensions, +) -> None: + 
with pytest.raises(TokenNotFoundError): + check_reset_token(db_cursor, token=str(uuid.uuid4())) + + +def test_add_reset_token(db_cursor: psycopg2.extensions.cursor) -> None: + """ + Checks if add_reset_token properly inserts a token + for the given email in the database + """ + user = create_test_user(db_cursor) + token = uuid.uuid4().hex + add_reset_token(db_cursor, user.email, token) + db_cursor.execute( + """ + SELECT id, token FROM RESET_TOKENS where email = %s + """, + (user.email,), + ) + results = db_cursor.fetchall() + assert len(results) == 1 + assert results[0][1] == token + + +def test_delete_reset_token(db_cursor: psycopg2.extensions.cursor) -> None: + """ + Checks if token previously inserted is deleted + by the delete_reset_token method + """ + reset_token_insert = create_reset_token(db_cursor) + results = get_reset_tokens_for_email(db_cursor, reset_token_insert) + assert len(results) == 1 + delete_reset_token(db_cursor, reset_token_insert.email, reset_token_insert.token) + results = get_reset_tokens_for_email(db_cursor, reset_token_insert) + assert len(results) == 0 diff --git a/tests/middleware/test_security.py b/tests/middleware/test_security.py new file mode 100644 index 00000000..16b0c86f --- /dev/null +++ b/tests/middleware/test_security.py @@ -0,0 +1,256 @@ +import datetime +import uuid +from http import HTTPStatus +from typing import Callable + +import flask +import pytest +from unittest.mock import patch, MagicMock + +import requests +from flask import Flask + +from middleware import security +from middleware.custom_exceptions import UserNotFoundError +from middleware.login_queries import create_session_token +from middleware.security import ( + validate_api_key, + APIKeyStatus, + api_required, + NoAPIKeyError, + ExpiredAPIKeyError, + InvalidAPIKeyError, + InvalidRoleError, +) +from tests.helper_functions import ( + create_test_user, + UserInfo, + give_user_admin_role, + create_api_key_db, +) +from tests.fixtures import dev_db_connection + + 
def test_api_key_exists_in_users_table_with_admin_role(dev_db_connection):
    """An API key belonging to an admin user validates without error."""
    cursor = dev_db_connection.cursor()
    test_user = create_test_user(cursor)
    give_user_admin_role(dev_db_connection, UserInfo(test_user.email, ""))
    api_key = create_api_key_db(cursor, test_user.id)
    dev_db_connection.commit()
    # validate_api_key raises on failure and returns None on success.
    result = validate_api_key(api_key, "", "")
    assert result is None


def test_api_key_exists_in_users_table_with_non_admin_role(dev_db_connection):
    """An API key belonging to a regular (non-admin) user validates without error."""
    cursor = dev_db_connection.cursor()
    test_user = create_test_user(cursor)
    api_key = create_api_key_db(cursor, test_user.id)
    dev_db_connection.commit()
    result = validate_api_key(api_key, "", "")
    assert result is None


def test_api_key_not_in_users_table_but_in_session_tokens_table(dev_db_connection):
    """A session token is accepted as an API key when present in session_tokens."""
    cursor = dev_db_connection.cursor()
    test_user = create_test_user(cursor)
    token = create_session_token(cursor, test_user.id, test_user.email)
    dev_db_connection.commit()
    result = validate_api_key(token, "", "")
    assert result is None


def test_expired_session_token(dev_db_connection):
    """A session token whose expiration date has passed raises ExpiredAPIKeyError."""
    cursor = dev_db_connection.cursor()
    test_user = create_test_user(cursor)
    token = create_session_token(cursor, test_user.id, test_user.email)
    # Parameterized query instead of f-string interpolation: avoids SQL-injection
    # patterns and lets the driver adapt the date value correctly.
    cursor.execute(
        "UPDATE session_tokens SET expiration_date = %s WHERE token = %s",
        (datetime.date(year=2020, month=3, day=4), token),
    )
    dev_db_connection.commit()
    with pytest.raises(ExpiredAPIKeyError):
        validate_api_key(token, "", "")


def test_session_token_with_admin_role(dev_db_connection):
    """A session token belonging to an admin user validates without error."""
    cursor = dev_db_connection.cursor()
    test_user = create_test_user(cursor)
    give_user_admin_role(dev_db_connection, UserInfo(test_user.email, ""))
    token = create_session_token(cursor, test_user.id, test_user.email)
    dev_db_connection.commit()
    result = validate_api_key(token, "", "")
    assert result is None


def test_api_key_exists_in_access_tokens_table(dev_db_connection):
    """An unexpired token in access_tokens validates without error."""
    cursor = dev_db_connection.cursor()
    token = uuid.uuid4().hex
    expiration = datetime.datetime(year=2030, month=1, day=1)
    cursor.execute(
        "insert into access_tokens (token, expiration_date) values (%s, %s)",
        (token, expiration),
    )
    dev_db_connection.commit()
    result = validate_api_key(token, "", "")
    assert result is None


def test_api_key_not_exist_in_any_table(dev_db_connection):
    """A token present in no table raises InvalidAPIKeyError."""
    token = uuid.uuid4().hex
    with pytest.raises(InvalidAPIKeyError) as e:
        validate_api_key(token, "", "")
    assert "API Key not found" in str(e.value)


def test_expired_access_token_in_access_tokens_table(dev_db_connection):
    """An expired access token is treated as not found (InvalidAPIKeyError)."""
    cursor = dev_db_connection.cursor()
    token = uuid.uuid4().hex
    expiration = datetime.datetime(year=1999, month=1, day=1)
    cursor.execute(
        "insert into access_tokens (token, expiration_date) values (%s, %s)",
        (token, expiration),
    )
    dev_db_connection.commit()
    with pytest.raises(InvalidAPIKeyError) as e:
        validate_api_key(token, "", "")
    assert "API Key not found" in str(e.value)


def test_admin_only_action_with_non_admin_role(dev_db_connection):
    """A non-admin key on an admin-only endpoint/method raises InvalidRoleError."""
    cursor = dev_db_connection.cursor()
    test_user = create_test_user(cursor)
    api_key = create_api_key_db(cursor, test_user.id)
    dev_db_connection.commit()
    with pytest.raises(InvalidRoleError) as e:
        validate_api_key(api_key, "datasources", "PUT")
    assert "You do not have permission to access this endpoint" in str(e.value)


def test_admin_only_action_with_admin_role(dev_db_connection):
    """An admin key on an admin-only endpoint/method validates without error."""
    cursor = dev_db_connection.cursor()
    test_user = create_test_user(cursor)
    give_user_admin_role(dev_db_connection, UserInfo(test_user.email, ""))
    api_key = create_api_key_db(cursor, test_user.id)
    dev_db_connection.commit()
    result = validate_api_key(api_key, "datasources", "PUT")
    assert result is None


@pytest.fixture
def app() -> Flask:
    # Minimal Flask app for request-context tests.
    app = Flask(__name__)
    return app


@pytest.fixture
def client(app: Flask):
    # Test client bound to the minimal app.
    return app.test_client()


@pytest.fixture
def mock_request_headers(monkeypatch):
    # Replace flask.request with a MagicMock so header access can be controlled.
    mock = MagicMock()
    monkeypatch.setattr(flask, "request", mock)
    return mock
MagicMock() + monkeypatch.setattr(flask, "request", mock) + return mock + + +@pytest.fixture +def mock_validate_api_key(monkeypatch): + mock = MagicMock() + monkeypatch.setattr(security, "validate_api_key", mock) + return mock + + +@pytest.fixture +def dummy_route(): + @api_required + def _dummy_route(): + return "This is a protected route", HTTPStatus.OK.value + + return _dummy_route + + +def test_api_required_happy_path( + app, client, mock_request_headers, mock_validate_api_key, dummy_route: Callable +): + mock_validate_api_key.return_value = None + with app.test_request_context(headers={"Authorization": "Bearer valid_api_key"}): + response = dummy_route() + assert response == ("This is a protected route", HTTPStatus.OK.value) + + +def test_api_required_api_key_expired( + app, client, mock_request_headers, mock_validate_api_key, dummy_route: Callable +): + mock_validate_api_key.side_effect = ExpiredAPIKeyError("The provided API key has expired") + with app.test_request_context(headers={"Authorization": "Bearer valid_api_key"}): + response = dummy_route() + assert response == ({"message": "The provided API key has expired"}, HTTPStatus.UNAUTHORIZED.value) + + +def test_api_required_expired_api_key( + app, client, mock_request_headers, mock_validate_api_key, dummy_route: Callable +): + mock_validate_api_key.side_effect = ExpiredAPIKeyError("The provided API key has expired") + with app.test_request_context(headers={"Authorization": "Bearer expired_api_key"}): + response = dummy_route() + assert response == ( + {"message": "The provided API key has expired"}, + HTTPStatus.UNAUTHORIZED.value, + ) + + +def test_api_required_no_api_key_in_request_header( + app, client, mock_request_headers, mock_validate_api_key, dummy_route: Callable +): + with app.test_request_context(headers={"Authorization": "Bearer"}): + response = dummy_route() + assert response == ( + {"message": "Please provide a properly formatted bearer token and API key"}, + HTTPStatus.BAD_REQUEST.value, + 
) + + +def test_api_required_invalid_role( + app, client, mock_request_headers, mock_validate_api_key, dummy_route: Callable +): + mock_validate_api_key.side_effect = InvalidRoleError( + "You do not have permission to access this endpoint" + ) + with app.test_request_context(headers={"Authorization": "Bearer valid_api_key"}): + response = dummy_route() + assert response == ( + {"message": "You do not have permission to access this endpoint"}, + HTTPStatus.FORBIDDEN.value, + ) + + +def test_api_required_not_authorization_key_in_request_header( + app, client, mock_request_headers, mock_validate_api_key, dummy_route: Callable +): + with app.test_request_context(headers={}): + response = dummy_route() + assert response == ( + {"message": "Please provide an 'Authorization' key in the request header"}, + HTTPStatus.BAD_REQUEST.value, + ) + + +def test_api_required_improperly_formatted_authorization_key( + app, client, mock_request_headers, mock_validate_api_key, dummy_route: Callable +): + with app.test_request_context(headers={"Authorization": "Bearer"}): + response = dummy_route() + assert response == ( + {"message": "Please provide a properly formatted bearer token and API key"}, + HTTPStatus.BAD_REQUEST.value, + ) + + +def test_api_required_undefined_api_key( + app, client, mock_request_headers, mock_validate_api_key, dummy_route: Callable +): + with app.test_request_context(headers={"Authorization": "Bearer undefined"}): + response = dummy_route() + assert response == ( + {"message": "Please provide an API key"}, + HTTPStatus.BAD_REQUEST.value, + ) diff --git a/tests/middleware/test_user_queries.py b/tests/middleware/test_user_queries.py new file mode 100644 index 00000000..4e51540d --- /dev/null +++ b/tests/middleware/test_user_queries.py @@ -0,0 +1,42 @@ +import psycopg2 +import pytest + +from middleware.custom_exceptions import UserNotFoundError +from middleware.user_queries import user_post_results, user_check_email +from tests.helper_functions import 
create_test_user  # NOTE(review): continuation of the `tests.helper_functions` import begun on the previous line
from tests.fixtures import db_cursor, dev_db_connection


def test_user_post_query(db_cursor: psycopg2.extensions.cursor) -> None:
    """
    Test the `user_post_results` method, ensuring the inserted user row
    can be read back from the users table.

    :param db_cursor: The database cursor.
    :return: None.
    """
    user_post_results(db_cursor, "unit_test", "unit_test")

    # Plain string: no interpolation needed (was an f-string with no placeholders).
    db_cursor.execute("SELECT email FROM users WHERE email = 'unit_test'")
    email_check = db_cursor.fetchone()[0]

    assert email_check == "unit_test"


def test_user_check_email(db_cursor: psycopg2.extensions.cursor) -> None:
    """
    Verify `user_check_email` returns the record of a previously created user.

    :param db_cursor: A `psycopg2.extensions.cursor` object representing the database cursor.
    :return: None
    """
    user = create_test_user(db_cursor)
    user_data = user_check_email(db_cursor, user.email)
    assert user_data["id"] == user.id


def test_user_check_email_raises_user_not_found_error(
    db_cursor: psycopg2.extensions.cursor,
) -> None:
    """`user_check_email` raises UserNotFoundError for an unknown email."""
    with pytest.raises(UserNotFoundError):
        user_check_email(db_cursor, "nonexistent@example.com")


# --- tests/resources/__init__.py ---
# The below line is required to bypass the api_required decorator,
# and must be positioned prior to other imports in order to work.
+from unittest.mock import patch, MagicMock +patch("middleware.security.api_required", lambda x: x).start() \ No newline at end of file diff --git a/tests/resources/app_test.py b/tests/resources/app_test.py new file mode 100644 index 00000000..73e24e95 --- /dev/null +++ b/tests/resources/app_test.py @@ -0,0 +1,165 @@ +import os +from app import create_app +from tests.resources.app_test_data import ( + DATA_SOURCES_ROWS, + AGENCIES_ROWS, +) +import datetime +import sqlite3 +import pytest +from unittest.mock import patch, MagicMock + +api_key = os.getenv("VUE_APP_PDAP_API_KEY") +HEADERS = {"Authorization": f"Bearer {api_key}"} +current_datetime = datetime.datetime.now() +DATETIME_STRING = current_datetime.strftime("%Y-%m-%d %H:%M:%S") + + +@pytest.fixture() +def test_app(): + app = create_app() + yield app + + +@pytest.fixture() +def client(test_app): + return test_app.test_client() + + +@pytest.fixture() +def runner(test_app): + return test_app.test_cli_runner() + + +@pytest.fixture() +def test_app_with_mock(mocker): + # Patch the initialize_psycopg2_connection function so it returns a MagicMock + yield create_app(mocker.MagicMock()) + + +@pytest.fixture() +def client_with_mock(test_app_with_mock): + # Use the app with the mocked database connection to get the test client + return test_app_with_mock.test_client() + + +@pytest.fixture() +def runner_with_mock(test_app_with_mock): + # Use the app with the mocked database connection for the test CLI runner + return test_app_with_mock.test_cli_runner() + + +@pytest.fixture +def session(): + connection = sqlite3.connect("file::memory:?cache=shared", uri=True) + db_session = connection.cursor() + with open("do_db_ddl_clean.sql", "r") as f: + sql_file = f.read() + sql_queries = sql_file.split(";") + for query in sql_queries: + db_session.execute(query.replace("\n", "")) + + for row in DATA_SOURCES_ROWS: + # valid_row = {k: v for k, v in row.items() if k in all_columns} + # clean_row = [r if r is not None else "" for r in 
row] + fully_clean_row = [str(r) for r in row] + fully_clean_row_str = "'" + "', '".join(fully_clean_row) + "'" + db_session.execute(f"insert into data_sources values ({fully_clean_row_str})") + db_session.execute( + "update data_sources set broken_source_url_as_of = null where broken_source_url_as_of = 'NULL'" + ) + + for row in AGENCIES_ROWS: + clean_row = [r if r is not None else "" for r in row] + fully_clean_row = [str(r) for r in clean_row] + fully_clean_row_str = "'" + "', '".join(fully_clean_row) + "'" + db_session.execute(f"insert into agencies values ({fully_clean_row_str})") + + # sql_query_log = f"INSERT INTO quick_search_query_logs (id, search, location, results, result_count, datetime_of_request, created_at) VALUES (1, 'test', 'test', '', 0, '{DATETIME_STRING}', '{DATETIME_STRING}')" + # db_session.execute(sql_query_log) + + yield connection + connection.close() + + +# def test_post_user(client): +# response = client.post( +# "/user", headers=HEADERS, json={"email": "test", "password": "test"} +# ) + +# # with initialize_psycopg2_connection() as psycopg2_connection: +# # cursor = psycopg2_connection.cursor() +# # cursor.execute(f"DELETE FROM users WHERE email = 'test'") +# # psycopg2_connection.commit() + +# assert response.json["data"] == "Successfully added user" + +# def test_put_archives(client): +# current_datetime = datetime.datetime.now() +# datetime_string = current_datetime.strftime("%Y-%m-%d %H:%M:%S") +# response = client.put( +# "/archives", +# headers=HEADERS, +# json=json.dumps( +# { +# "id": "test", +# "last_cached": datetime_string, +# "broken_source_url_as_of": "", +# } +# ), +# ) + +# assert response.json["status"] == "success" + + +# def test_put_archives_brokenasof(client): +# current_datetime = datetime.datetime.now() +# datetime_string = current_datetime.strftime("%Y-%m-%d") +# response = client.put( +# "/archives", +# headers=HEADERS, +# json=json.dumps( +# { +# "id": "test", +# "last_cached": datetime_string, +# 
"broken_source_url_as_of": datetime_string, +# } +# ), +# ) + +# assert response.json["status"] == "success" + + +# # agencies +# def test_agencies(client): +# response = client.get("/agencies/1", headers=HEADERS) + +# assert len(response.json["data"]) > 0 + + +# def test_agencies_pagination(client): +# response1 = client.get("/agencies/1", headers=HEADERS) +# response2 = client.get("/agencies/2", headers=HEADERS) + +# assert response1 != response2 + +# region Resources + + +def test_get_api_key(client_with_mock, mocker, test_app_with_mock): + mock_request_data = {"email": "user@example.com", "password": "password"} + mock_user_data = {"id": 1, "password_digest": "hashed_password"} + + # Mock login_results function to return mock_user_data + mocker.patch("resources.ApiKey.login_results", return_value=mock_user_data) + # Mock check_password_hash based on the valid_login parameter + mocker.patch("resources.ApiKey.check_password_hash", return_value=True) + + with client_with_mock: + response = client_with_mock.get("/api_key", json=mock_request_data) + json_data = response.get_json() + assert "api_key" in json_data + assert response.status_code == 200 + + +# endregion diff --git a/app_test_data.py b/tests/resources/app_test_data.py similarity index 100% rename from app_test_data.py rename to tests/resources/app_test_data.py diff --git a/do_db_ddl_clean.sql b/tests/resources/do_db_ddl_clean.sql similarity index 100% rename from do_db_ddl_clean.sql rename to tests/resources/do_db_ddl_clean.sql diff --git a/tests/resources/test_DataSources.py b/tests/resources/test_DataSources.py new file mode 100644 index 00000000..bfdc1628 --- /dev/null +++ b/tests/resources/test_DataSources.py @@ -0,0 +1,15 @@ +# The below line is required to bypass the api_required decorator, +# and must be positioned prior to other imports in order to work. 
+from unittest.mock import patch, MagicMock +patch("middleware.security.api_required", lambda x: x).start() +from tests.fixtures import client_with_mock_db + +def test_put_data_source_by_id( + client_with_mock_db, monkeypatch +): + + monkeypatch.setattr("resources.DataSources.request", MagicMock()) + # mock_request.get_json.return_value = {"name": "Updated Data Source"} + response = client_with_mock_db.client.put("/data-sources-by-id/test_id") + assert response.status_code == 200 + assert response.json == {"message": "Data source updated successfully."} diff --git a/tests/resources/test_RefreshSession.py b/tests/resources/test_RefreshSession.py new file mode 100644 index 00000000..839d9972 --- /dev/null +++ b/tests/resources/test_RefreshSession.py @@ -0,0 +1,134 @@ +from unittest.mock import MagicMock + +import pytest + +from middleware.custom_exceptions import TokenNotFoundError +from middleware.login_queries import SessionTokenUserData +from tests.fixtures import client_with_mock_db +from tests.helper_functions import check_response_status + + +@pytest.fixture +def mock_cursor(client_with_mock_db): + return client_with_mock_db.mock_db.cursor.return_value + + +@pytest.fixture +def mock_get_session_token_user_data(monkeypatch): + mock = MagicMock() + monkeypatch.setattr("resources.RefreshSession.get_session_token_user_data", mock) + return mock + + +@pytest.fixture +def mock_delete_session_token(monkeypatch): + mock = MagicMock() + monkeypatch.setattr("resources.RefreshSession.delete_session_token", mock) + return mock + + +@pytest.fixture +def mock_create_session_token(monkeypatch): + mock = MagicMock() + monkeypatch.setattr("resources.RefreshSession.create_session_token", mock) + return mock + + +def test_post_refresh_session_happy_path( + client_with_mock_db, + mock_cursor, + mock_get_session_token_user_data, + mock_delete_session_token, + mock_create_session_token, +): + test_session_token_user_data = SessionTokenUserData( + id="test_id", email="test_email" + ) 
+ mock_get_session_token_user_data.return_value = test_session_token_user_data + mock_create_session_token.return_value = "new_test_session_token" + + response = client_with_mock_db.client.post( + "/refresh-session", + json={ + "session_token": "old_test_session_token", + }, + ) + check_response_status(response, 200) + assert response.json == { + "message": "Successfully refreshed session token", + "data": "new_test_session_token", + } + mock_get_session_token_user_data.assert_called_once_with( + mock_cursor, "old_test_session_token" + ) + mock_delete_session_token.assert_called_once_with( + mock_cursor, "old_test_session_token" + ) + mock_create_session_token.assert_called_once_with( + mock_cursor, test_session_token_user_data.id, test_session_token_user_data.email + ) + client_with_mock_db.mock_db.commit.assert_called_once() + + +def test_post_refresh_session_token_not_found( + client_with_mock_db, + mock_cursor, + mock_get_session_token_user_data, + mock_delete_session_token, + mock_create_session_token, +): + """ + Test that RefreshSessionPost behaves as expected when the session token is not found + :param client_with_mock_db: + :return: + """ + mock_get_session_token_user_data.side_effect = TokenNotFoundError + response = client_with_mock_db.client.post( + "/refresh-session", + json={ + "session_token": "old_test_session_token", + }, + ) + + check_response_status(response, 403) + assert response.json == { + "message": "Invalid session token", + } + mock_get_session_token_user_data.assert_called_once_with( + mock_cursor, "old_test_session_token" + ) + mock_delete_session_token.assert_not_called() + mock_create_session_token.assert_not_called() + client_with_mock_db.mock_db.commit.assert_not_called() + + +def test_post_refresh_session_unexpected_error( + client_with_mock_db, + mock_cursor, + mock_get_session_token_user_data, + mock_delete_session_token, + mock_create_session_token, +): + """ + Test that RefreshSessionPost behaves as expected when there is an 
unexpected error
    :param client_with_mock_db:
    :return:
    """
    mock_get_session_token_user_data.side_effect = Exception("An unexpected error occurred")
    response = client_with_mock_db.client.post(
        "/refresh-session",
        json={
            "session_token": "old_test_session_token",
        },
    )

    # Unexpected exceptions surface as a 500 whose message is the exception text.
    check_response_status(response, 500)
    assert response.json == {
        "message": "An unexpected error occurred",
    }
    mock_get_session_token_user_data.assert_called_once_with(
        mock_cursor, "old_test_session_token"
    )
    # Nothing else should touch the database on the failure path.
    mock_delete_session_token.assert_not_called()
    mock_create_session_token.assert_not_called()
    client_with_mock_db.mock_db.commit.assert_not_called()
diff --git a/tests/resources/test_search_tokens.py b/tests/resources/test_search_tokens.py new file mode 100644 index 00000000..cc5eb4bb --- /dev/null +++ b/tests/resources/test_search_tokens.py @@ -0,0 +1,166 @@
import unittest.mock
from collections import namedtuple

import pytest
from flask import Flask

from resources.SearchTokens import SearchTokens


class MockPsycopgConnection:
    """Minimal stand-in for the psycopg connection handed to SearchTokens."""

    def cursor(self):
        return MockCursor()

    def commit(self):
        pass

    def rollback(self):
        pass


class MockCursor:
    """Cursor stub: execute/fetchall do nothing and return None."""

    def execute(self, query, params=None):
        pass

    def fetchall(self):
        pass


@pytest.fixture
def app():
    # Flask app in TESTING mode for the request contexts used below.
    app = Flask(__name__)
    app.config.update({"TESTING": True})
    return app


@pytest.fixture
def client(app):
    return app.test_client()


@pytest.fixture
def mock_psycopg_connection():
    return MockPsycopgConnection()


@pytest.fixture
def search_tokens(mock_psycopg_connection):
    # Resource under test, wired to the stub database connection.
    return SearchTokens(psycopg2_connection=mock_psycopg_connection)


@pytest.fixture
def mock_dependencies(mocker):
    # Endpoint-name -> patched wrapper mock; keys match the ?endpoint=
    # values dispatched by SearchTokens.get().
    mocks = {
        "insert_access_token": mocker.patch(
            "resources.SearchTokens.insert_access_token", return_value=None
        ),
        "quick-search": mocker.patch(
            "resources.SearchTokens.quick_search_query_wrapper",
            return_value={"result": "quick_search"},
        ),
        "data-sources": mocker.patch(
+ "resources.SearchTokens.get_approved_data_sources_wrapper", + return_value={"result": "data_sources"}, + ), + "data-sources-by-id": mocker.patch( + "resources.SearchTokens.data_source_by_id_wrapper", + return_value={"result": "data_source_by_id"}, + ), + "data-sources-map": mocker.patch( + "resources.SearchTokens.get_data_sources_for_map_wrapper", + return_value={"result": "data_sources_map"}, + ), + } + return mocks + + +def perform_test_search_tokens_endpoint( + search_tokens, + mocker, + app, + endpoint, + expected_response, + params=None, + mocked_dependencies: dict[str, unittest.mock.MagicMock] = None, +): + mock_insert_access_token = mocker.patch( + "resources.SearchTokens.insert_access_token" + ) + url = generate_url(endpoint, params) + + with app.test_request_context(url): + response = search_tokens.get() + assert ( + response == expected_response + ), f"{endpoint} endpoint should call {expected_response}, got {response}" + mock_insert_access_token.assert_called_once() + if endpoint in mocked_dependencies: + # Check parameters properly called + mock_dependency = mocked_dependencies[endpoint] + call_args = tuple(params.values()) if params else () + mock_dependency.assert_called_with( + *call_args, search_tokens.psycopg2_connection + ), f"{mock_dependency._mock_name or 'mock'} was not called with the expected parameters" + + +def generate_url(endpoint, params): + url = f"/?endpoint={endpoint}" + if params: + url += "".join([f"&{key}={value}" for key, value in params.items()]) + return url + + +TestCase = namedtuple("TestCase", ["endpoint", "expected_response", "params"]) + +test_cases = [ + TestCase( + "quick-search", {"result": "quick_search"}, {"arg1": "test1", "arg2": "test2"} + ), + TestCase("data-sources", {"result": "data_sources"}, None), + TestCase("data-sources-by-id", {"result": "data_source_by_id"}, {"arg1": "1"}), + TestCase("data-sources-map", {"result": "data_sources_map"}, None), +] + + +@pytest.mark.parametrize("test_case", test_cases) +def 
test_endpoints(search_tokens, mocker, app, test_case, mock_dependencies):
    """
    Perform test for endpoints, ensuring each provided endpoint calls
    the appropriate wrapper function with the appropriate arguments

    :param search_tokens: The search tokens to be used for the test.
    :param mocker: The mocker object.
    :param app: The application object.
    :param test_case: The test case object.
    :return: None
    """
    perform_test_search_tokens_endpoint(
        search_tokens,
        mocker,
        app,
        test_case.endpoint,
        test_case.expected_response,
        test_case.params,
        mock_dependencies,
    )


def test_search_tokens_unknown_endpoint(app, mocker, search_tokens):
    # An unrecognized ?endpoint= value yields a 500 whose message names
    # the unknown endpoint.
    url = generate_url("test_endpoint", {"test_param": "test_value"})
    with app.test_request_context(url):
        response = search_tokens.get()
        assert response.status_code == 500
        assert response.json == {"message": "Unknown endpoint: test_endpoint"}


def test_search_tokens_get_exception(app, mocker, search_tokens):
    # An exception raised while handling the request surfaces as a 500
    # whose message is the exception text.
    mocker.patch(
        "resources.SearchTokens.insert_access_token",
        side_effect=Exception("Test exception"),
    )

    url = generate_url("test_endpoint", {"test_param": "test_value"})
    with app.test_request_context(url):
        response = search_tokens.get()
        assert response.status_code == 500
        assert response.json == {"message": "Test exception"}
diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py new file mode 100644 index 00000000..d6ab52e4 --- /dev/null +++ b/tests/test_endpoints.py @@ -0,0 +1,99 @@
"""
This module tests the functionality of all endpoints, ensuring that, as designed, they call (or don't call)
the appropriate methods in their supporting classes
"""

from collections import namedtuple

import pytest
from unittest.mock import patch

from flask.testing import FlaskClient
from flask_restful import Resource

from resources.Agencies import Agencies
from resources.ApiKey import ApiKey
from resources.Archives import Archives
from resources.DataSources import (
DataSources, + DataSourcesMap, + DataSourcesNeedsIdentification, + DataSourceById, +) +from resources.Login import Login +from resources.QuickSearch import QuickSearch +from resources.RefreshSession import RefreshSession +from resources.RequestResetPassword import RequestResetPassword +from resources.ResetPassword import ResetPassword +from resources.ResetTokenValidation import ResetTokenValidation +from resources.SearchTokens import SearchTokens +from resources.User import User +from tests.fixtures import client_with_mock_db, ClientWithMockDB + +# Define constants for HTTP methods +GET = "get" +POST = "post" +PUT = "put" +DELETE = "delete" + + +def run_endpoint_tests( + client: FlaskClient, endpoint: str, class_type: Resource, allowed_methods: list[str] +): + methods = [GET, POST, PUT, DELETE] + for method in methods: + if method in allowed_methods: + with patch.object( + class_type, method, return_value="Mocked response" + ) as mock_method: + response = getattr(client, method)(endpoint) + assert ( + response.status_code == 200 + ), f"{method.upper()} {endpoint} failed with status code {response.status_code}, expected 200" + mock_method.assert_called_once(), f"{method.upper()} {endpoint} should have called the {method} method on {class_type.__name__}" + else: + response = getattr(client, method)(endpoint) + assert ( + response.status_code == 405 + ), f"{method.upper()} {endpoint} failed with status code {response.status_code}, expected 405" + + +TestParameters = namedtuple("Resource", ["class_type", "endpoint", "allowed_methods"]) +test_parameters = [ + TestParameters(User, "/user", [POST, PUT]), + TestParameters(Login, "/login", [POST]), + TestParameters(RefreshSession, "/refresh-session", [POST]), + TestParameters(ApiKey, "/api_key", [GET]), + TestParameters(RequestResetPassword, "/request-reset-password", [POST]), + TestParameters(ResetPassword, "/reset-password", [POST]), + TestParameters(ResetTokenValidation, "/reset-token-validation", [POST]), + 
    TestParameters(QuickSearch, "/quick-search//", [GET]),  # NOTE(review): double slash looks like empty path parameters — confirm intended
    TestParameters(Archives, "/archives", [GET, PUT]),
    TestParameters(DataSources, "/data-sources", [GET, POST]),
    TestParameters(DataSourcesMap, "/data-sources-map", [GET]),
    TestParameters(
        DataSourcesNeedsIdentification, "/data-sources-needs-identification", [GET]
    ),
    TestParameters(DataSourceById, "/data-sources-by-id/", [GET, PUT]),
    TestParameters(Agencies, "/agencies/", [GET]),
    TestParameters(SearchTokens, "/search-tokens", [GET]),
]


@pytest.mark.parametrize("test_parameter", test_parameters)
def test_endpoints(client_with_mock_db: ClientWithMockDB, test_parameter) -> None:
    """
    Using the test_parameters list, this tests all endpoints to ensure that
    only the appropriate methods can be called from the endpoints
    :param client_with_mock_db: fixture supplying a test client with a mocked database
    :param test_parameter: a TestParameters tuple (class_type, endpoint, allowed_methods)
    :return: None
    """
    run_endpoint_tests(
        client_with_mock_db.client,
        test_parameter.endpoint,
        test_parameter.class_type,
        test_parameter.allowed_methods,
    )