diff --git a/.env.example b/.env.example index ccc2421..814943a 100644 --- a/.env.example +++ b/.env.example @@ -11,7 +11,7 @@ export DATABASE_URL="postgresql://user:password@localhost:5432/database_name" export SECRET_KEY="secret" # Secret key for signing tokens # Origins to be allowed by the backend -export ORIGINS="https://beta.ansari.chat,http://beta.ansari.chat,https://ansari.chat,http://ansari.chat,https://hajiansari.ai,http://hajiansari.ai,https://ansari.endeavorpal.com" +export ORIGINS="https://beta.ansari.chat,http://beta.ansari.chat,https://ansari.chat,http://ansari.chat,https://hajiansari.ai,http://hajiansari.ai,https://ansari.endeavorpal.com,https://web.whatsapp.com" # Vectara search engine configuration export PGPASSWORD="" # Password for PostgreSQL database @@ -25,15 +25,41 @@ export template_dir="." # Directory path for templates # Related to WhatsApp Business and Meta (leave empty if you're not planning to use WhatsApp) # Source 1: https://www.youtube.com/watch?v=KP6_BUw3i0U +# Watch Until 32:25 # Source 2: https://glitch.com/edit/#!/insidious-tartan-alvarezsaurus # Source 3: https://developers.facebook.com/blog/post/2022/10/24/sending-messages-with-whatsapp-in-your-python-applications/#u_0_39_8q + +# Moreover, if want to test whatsapp's webhook locally, you can use zrok on a reserved URL with a zrok "share token" +# obtained by contacting its current holder: https://github.com/OdyAsh (source 1, 2 below) +# Alternatively, you can change the webhook url all together (source 3, 4 below) +# Check these sources for more details: +# Source 1: https://dev.to/odyash/quickly-share-your-app-with-zrok-4ihp +# Source 2: https://openziti.discourse.group/t/how-do-i-use-a-reserved-share-on-different-devices/2379/2 +# Source 3: https://youtu.be/KP6_BUw3i0U?t=1294 +# (@21:33 and 25:30, however they use glitch instead of zrok, so you'll just need to change the webhook url to your zrok url) +# Source 4 (where you can change callback url, given that your facebook account gets access by the app's admins): +# https://developers.facebook.com/apps/871020755148175/whatsapp-business/wa-settings/ +# Note 1: Obviously, that `871...175` is the testing app's public id, so if this link still doesn't work even after you gain access, +# then the admins most probably created a new test app instance +# Note 2: If an unexpected 3rd party discovers the ZROK_SHARE_TOKEN, +# a new one will have to be generated, then added to Meta's callback URL of the *testing* app +# (Noting that the *production* app's callback URL will be different anyway, so the 3rd party won't be able to access that app) +# (but we still don't want random calls to be made to our testing app, so that's why we'll still have to change an exposed token :]) + export WHATSAPP_RECIPIENT_WAID="<>" export WHATSAPP_API_VERSION="<>" export WHATSAPP_BUSINESS_PHONE_NUMBER_ID="<>" export WHATSAPP_ACCESS_TOKEN_FROM_SYS_USER="<" -export WHATSAPP_VERIFY_TOKEN_FOR_WEBHOOK="<>" +export WHATSAPP_VERIFY_TOKEN_FOR_WEBHOOK="<>" +export ZROK_SHARE_TOKEN="<>" # Related to internal code logic # Leave the values below when locally debugging the application +# In production, don't add them to environment variables, or add them as "INFO"/"False" respectively export LOGGING_LEVEL="DEBUG" -export DEBUG_MODE="True" \ No newline at end of file +export DEBUG_MODE="True" + +# To get rid of .py[cod] files (This should key should NOT be set in production!) +# This is only to de-clutter your local development environment +# Details: https://docs.python-guide.org/writing/gotchas/#disabling-bytecode-pyc-files +PYTHONDONTWRITEBYTECODE=1 diff --git a/data/mawsuah/strip_tashkeel.py b/data/mawsuah/strip_tashkeel.py index d4fa643..90586e4 100644 --- a/data/mawsuah/strip_tashkeel.py +++ b/data/mawsuah/strip_tashkeel.py @@ -6,7 +6,7 @@ from ansari.ansari_logger import get_logger -logger = get_logger(__name__) +logger = get_logger() def strip_tashkeel_from_doc(input_file, output_file): diff --git a/docs/structure_of_api_responses/ansari_website_api_structure_of_a_request_sent_using_zrok.json b/docs/structure_of_api_responses/ansari_website_api_structure_of_a_request_sent_using_zrok.json new file mode 100644 index 0000000..07d129f --- /dev/null +++ b/docs/structure_of_api_responses/ansari_website_api_structure_of_a_request_sent_using_zrok.json @@ -0,0 +1,76 @@ +{ + "scope": { + "type": "http", + "asgi": { + "version": "3.0", + "spec_version": "2.4" + }, + "http_version": "1.1", + "server": ["127.0.0.1", 8000], // When running locally + "client": ["127.0.0.1", 11563], // The port here changes dynamically + "scheme": "http", + "method": "POST", + "root_path": "", + "path": "/api/v2/users/login", + "raw_path": "/api/v2/users/login", + "query_string": "", + "headers": [ + ["host", "localhost:8000"], + ["connection", "keep-alive"], + ["content-length", "83"], + ["sec-ch-ua-platform", "\"Windows\""], + ["user-agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36"], + ["x-mobile-ansari", "ANSARI"], + ["sec-ch-ua", "\"Google Chrome\";v=\"131\", \"Chromium\";v=\"131\", \"Not_A Brand\";v=\"24\""], + ["content-type", "application/json"], + ["sec-ch-ua-mobile", "?0"], + ["accept", "*/*"], + ["origin", "http://localhost:3000"], + ["sec-fetch-site", "same-site"], + ["sec-fetch-mode", "cors"], + ["sec-fetch-dest", "empty"], + ["referer", "http://localhost:3000/"], + ["accept-encoding", "gzip, deflate, br, zstd"], + ["accept-language", "en-GB,en;q=0.9,ar-EG;q=0.8,ar;q=0.7,en-US;q=0.6"] + ], + "state": {}, + "app": "", + "starlette.exception_handlers": { + "": "", + "": "", + "": "", + "": "" + }, + "router": "", + "endpoint": "", + "path_params": {}, + "route": { + "path": "/api/v2/users/login", + "name": "login_user", + "methods": ["POST"] + } + }, + "_receive": "", + "_send": ".wrapped_app..sender>", + "_stream_consumed": true, + "_is_disconnected": false, + "_form": null, + // _body's value (and other strings) were actually binary strings (i.e., start with b'...') + "_body": "{\"email\":\"guest_<>@endeavorpal.com\",\"password\":\"<>\",\"guest\":true}", + // this is what actually gets returned when accessing headers property (e.g., `request.headers`) + // Check Starlette's implementation (which FastAPI uses) for details: + // https://github.com/encode/starlette/blob/b68a142a356ede730083347f254e1eae8b5c803e/starlette/requests.py#L12 + "_headers": { + "host": "localhost:8000", + "connection": "...", + "...": "..." + // I.e., the value of the `_headers` key is a dictionary of the headers already mentioned above + }, + "_json": { + "email": "guest_<>@endeavorpal.com", + "...": ["..."] + // I.e., the value of the `_json` key is simply the dictionary equivalent of `_body`'s string value + }, + "_query_params": "", + "_cookies": {} +} \ No newline at end of file diff --git a/docs/structure_of_api_responses/meta_whatsapp_api_structure_of_a_request_sent_using_zrok.json b/docs/structure_of_api_responses/meta_whatsapp_api_structure_of_a_request_sent_using_zrok.json new file mode 100644 index 0000000..5a7d1af --- /dev/null +++ b/docs/structure_of_api_responses/meta_whatsapp_api_structure_of_a_request_sent_using_zrok.json @@ -0,0 +1,73 @@ +{ + "scope": { + "type": "http", + "asgi": { + "version": "3.0", + "spec_version": "2.4" + }, + "http_version": "1.1", + "server": ["127.0.0.1", 8000], // When running locally + "client": ["<>", 0], + "scheme": "https", + "method": "POST", + "root_path": "", + "path": "/whatsapp/v1", + "raw_path": "/whatsapp/v1", + "query_string": "", + "headers": [ + ["host", "YOUR_ZROK_SHARE_TOKEN.share.zrok.io"], + ["user-agent", "facebookexternalua"], + ["content-length", "545"], + ["accept", "*/*"], + ["accept-encoding", "deflate, gzip"], + ["content-type", "application/json"], + ["x-amzn-trace-id", "Root=1-674b2035-0f0a8ab27075asce3324dcdb"], // trace value here is fake + ["x-forwarded-for", "173.REST.OF.IP, <>"], + ["x-forwarded-port", "443"], + ["x-forwarded-proto", "https"], + ["x-hub-signature", "sha1=8a3e35da6fb5dfaaf5aaa46c8d059d519e18112d"], // sha1 hash here is fake + ["x-hub-signature-256", "sha256=51d62480d40ffd0f48d1cde1ea47656452fd65b5ac29077fe3c6b4e68d74c827"], // sha256 here is fake + ["x-proxy", "zrok"] + ], + "state": {}, + "app": "", + "starlette.exception_handlers": { + "": "", + "": "", + "": "", + "": "" + }, + "router": "", + "endpoint": "", + "path_params": {}, + "route": { + "path": "/whatsapp/v1", + "name": "main_webhook", + "methods": ["POST"] + } + }, + "_receive": "", + "_send": ".wrapped_app..sender>", + "_stream_consumed": true, + "_is_disconnected": false, + "_form": null, + "_query_params": "", + // this is what actually gets returned when accessing headers property (e.g., `request.headers`) + // Check Starlette's implementation (which FastAPI uses) for details: + // https://github.com/encode/starlette/blob/b68a142a356ede730083347f254e1eae8b5c803e/starlette/requests.py#L125 + "_headers": { + "host": "...", + "user-agent": "...", + "...": "..." + // I.e., the value of the `_headers` key is a dictionary of the headers already mentioned above + }, + "_cookies": {}, + // _body's value (and other strings) were actually binary strings (i.e., start with b'...') + // Also, it contains content mentioned in other `meta_whatsapp_*.json` files + "_body": "{\"object\":\"whatsapp_business_account\", ...}", + "_json": { + "object": "whatsapp_business_account", + "...": ["..."] + // I.e., the value of the `_json` key is simply the dictionary equivalent of `_body`'s string value + } +} \ No newline at end of file diff --git a/docs/structure_of_api_responses/meta_whatsapp_structure_of_a_user_incoming_msg.json b/docs/structure_of_api_responses/meta_whatsapp_api_structure_of_a_user_incoming_msg.json similarity index 98% rename from docs/structure_of_api_responses/meta_whatsapp_structure_of_a_user_incoming_msg.json rename to docs/structure_of_api_responses/meta_whatsapp_api_structure_of_a_user_incoming_msg.json index 7aac800..ddc5e66 100644 --- a/docs/structure_of_api_responses/meta_whatsapp_structure_of_a_user_incoming_msg.json +++ b/docs/structure_of_api_responses/meta_whatsapp_api_structure_of_a_user_incoming_msg.json @@ -9,7 +9,7 @@ "messaging_product": "whatsapp", "metadata": { "display_phone_number": "< 15555555555)>>", - "phone_number_id": "<>" + "phone_number_id": "<>" }, "contacts": [ { diff --git a/requirements.txt b/requirements.txt index 0786598..98417dd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,6 +8,7 @@ jinja2 # should be >= 2 in order for "from langfuse.decorators" to work langfuse>=2.0.0 litellm +loguru openai pandas psycopg2-binary diff --git a/setup_database.py b/setup_database.py index 87695a8..ecf7545 100644 --- a/setup_database.py +++ b/setup_database.py @@ -5,7 +5,7 @@ from ansari.ansari_logger import get_logger from ansari.config import get_settings -logger = get_logger(__name__) +logger = get_logger() def import_sql_files(directory, db_url): diff --git a/sql/01_create_tables.sql b/sql/01_create_tables.sql index 0860dcb..574d4cf 100644 --- a/sql/01_create_tables.sql +++ b/sql/01_create_tables.sql @@ -1,15 +1,16 @@ CREATE TABLE users ( id SERIAL PRIMARY KEY, - email VARCHAR(100) UNIQUE, -- can be null if it is a guest account - password_hash VARCHAR(255), -- can be null if it is a guest account - first_name VARCHAR(50), -- can be null if it is a guest account - last_name VARCHAR(50), -- can be null if it is a guest account + email VARCHAR(100) UNIQUE, -- Can be null if it is a guest account + password_hash VARCHAR(255), -- Can be null if it is a guest account + first_name VARCHAR(50), -- Can be null if it is a guest account + last_name VARCHAR(50), -- Can be null if it is a guest account preferred_language VARCHAR(10) DEFAULT 'en', is_guest BOOLEAN NOT NULL DEFAULT FALSE, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ); + CREATE TABLE preferences ( id SERIAL PRIMARY KEY, user_id INTEGER NOT NULL, @@ -35,7 +36,7 @@ CREATE TABLE messages ( user_id INTEGER NOT NULL, thread_id INTEGER NOT NULL, role TEXT NOT NULL, - -- TODO (odyash): check if "function" can be renamed to "tool" like the rest of the codebase or not + -- #TODO (odyash): check if "function" can be renamed to "tool" like the rest of the codebase or not function_name TEXT, content TEXT NOT NULL, timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP, @@ -44,5 +45,3 @@ CREATE TABLE messages ( FOREIGN KEY (user_id) REFERENCES users(id), FOREIGN KEY (thread_id) REFERENCES threads(id) ); - - diff --git a/src/ansari/agents/ansari.py b/src/ansari/agents/ansari.py index 278d47b..4d315f6 100644 --- a/src/ansari/agents/ansari.py +++ b/src/ansari/agents/ansari.py @@ -14,7 +14,8 @@ from ansari.tools.search_vectara import SearchVectara from ansari.util.prompt_mgr import PromptMgr -logger = get_logger(__name__ + ".Ansari") +# previous logger name: __name__ + ".Ansari" +logger = get_logger() class Ansari: diff --git a/src/ansari/agents/ansari_workflow.py b/src/ansari/agents/ansari_workflow.py index 9b36133..755660b 100644 --- a/src/ansari/agents/ansari_workflow.py +++ b/src/ansari/agents/ansari_workflow.py @@ -17,7 +17,8 @@ else: logging_level = logging.INFO -logger = get_logger(__name__ + ".AnsariWorkflow", logging_level) +# previous logger name: __name__ + ".AnsariWorkflow" +logger = get_logger(logging_level) class AnsariWorkflow: diff --git a/src/ansari/ansari_db.py b/src/ansari/ansari_db.py index e852f8c..de7b12d 100644 --- a/src/ansari/ansari_db.py +++ b/src/ansari/ansari_db.py @@ -1,7 +1,10 @@ +import inspect import json import logging +import re from contextlib import contextmanager from datetime import datetime, timedelta, timezone +from typing import Literal, Optional, Union import bcrypt import jwt @@ -13,7 +16,7 @@ from ansari.ansari_logger import get_logger from ansari.config import Settings, get_settings -logger = get_logger(__name__) +logger = get_logger() class MessageLogger: @@ -109,13 +112,89 @@ def _get_token_from_request(self, request: Request) -> str: detail="Authorization header is malformed", ) + def _execute_query( + self, + query: Union[str, list[str]], + params: Union[tuple, list[tuple]], + which_fetch: Union[Literal["one", "all"], list[Literal["one", "all"]]] = "", + commit_after: Literal["each", "all"] = "each", + ) -> list[Optional[any]]: + """ + Executes one or more SQL queries with the provided parameters and fetch types. + + Args: + query (Union[str, List[str]]): A single SQL query string or a list of SQL query strings. + params (Union[tuple, List[tuple]]): A single tuple of parameters or a list of tuples of parameters. + which_fetch (Union[Literal["one", "all"], List[Literal["one", "all"]]]): + A single fetch type or a list of fetch types. Each fetch type can be: + - "one": Fetch one row. + - "all": Fetch all rows. + - Any other value: Do not fetch any rows. + commit_after (Literal["each", "all"]): Whether to commit the transaction after each query is executed, + or only after all of them are executed. + + Returns: + List[Optional[Any]]: + - When single or multiple queries are executed: + - Returns a list of results, where each "result" is: + - A single result if which_fetch is "one". + - A list of results if which_fetch is "all". + - Else, returns None. + + Note: the word "result" means a row in the DB, + which could be a tuple if more than 1 column is selected in the query. + + Raises: + ValueError: If an invalid fetch type is provided. + """ + # If query is a single string, we assume that params and which_fetch are also non-list values + if isinstance(query, str): + query = [query] + params = [params] + which_fetch = [which_fetch] + # else, we assume that params and which_fetch are lists of the same length + # and do a list-conversion just in case they are strings + else: + if isinstance(params, str): + params = [params] * len(query) + if isinstance(which_fetch, str): + which_fetch = [which_fetch] * len(query) + + caller_function_name = inspect.stack()[1].function + logger.debug(f"Function {caller_function_name}() \nis running queries: \n{query} \nwith params: \n{params}") + + results = [] + with self.get_connection() as conn: + with conn.cursor() as cur: + for q, p, wf in zip(query, params, which_fetch): + cur.execute(q, p) + result = None + if wf.lower() == "one": + result = cur.fetchone() + elif wf.lower() == "all": + result = cur.fetchall() + + # Remove possible SQL comments at the start of the q variable + q = re.sub(r"^\s*--.*\n|^\s*---.*\n", "", q, flags=re.MULTILINE) + + if not q.strip().lower().startswith("select") and commit_after.lower() == "each": + conn.commit() + + results.append(result) + + if commit_after.lower() == "all": + conn.commit() + + # Return a list when 1 or more queries are executed \ + # (or a list of a single None if it was a non-fetch query) + return results + def _validate_token_in_db(self, user_id: str, token: str, table: str) -> bool: try: - with self.get_connection() as conn: - with conn.cursor() as cur: - select_cmd = f"SELECT user_id FROM {table} WHERE user_id = %s AND token = %s;" - cur.execute(select_cmd, (user_id, token)) - return cur.fetchone() is not None + select_cmd = f"SELECT user_id FROM {table} WHERE user_id = %s AND token = %s;" + # Note: the "[0]" is added here because `select_cmd` is not a list + result = self._execute_query(select_cmd, (user_id, token), "one")[0] + return result is not None except Exception: logger.exception("Database error during token validation") raise HTTPException(status_code=500, detail="Internal server error") @@ -160,178 +239,134 @@ def validate_reset_token(self, token: str) -> dict[str, str]: def register(self, email, first_name, last_name, password_hash): try: - with self.get_connection() as conn: - with conn.cursor() as cur: - insert_cmd = """INSERT INTO users (email, password_hash, first_name, last_name) values (%s, %s, %s, %s);""" - cur.execute( - insert_cmd, - (email, password_hash, first_name, last_name), - ) - conn.commit() - return {"status": "success"} + insert_cmd = """INSERT INTO users (email, password_hash, first_name, last_name) values (%s, %s, %s, %s);""" + self._execute_query(insert_cmd, (email, password_hash, first_name, last_name)) + return {"status": "success"} except Exception as e: logger.warning(f"Error is {e}") return {"status": "failure", "error": str(e)} def account_exists(self, email): try: - with self.get_connection() as conn: - with conn.cursor() as cur: - select_cmd = """SELECT id FROM users WHERE email = %s;""" - cur.execute(select_cmd, (email,)) - result = cur.fetchone() - return result is not None + select_cmd = """SELECT id FROM users WHERE email = %s;""" + result = self._execute_query(select_cmd, (email,), "one")[0] + return result is not None except Exception as e: logger.warning(f"Error is {e}") return False def save_access_token(self, user_id, token): try: - with self.get_connection() as conn: - with conn.cursor() as cur: - insert_cmd = "INSERT INTO access_tokens (user_id, token) " + "VALUES (%s, %s) RETURNING id;" - cur.execute(insert_cmd, (user_id, token)) - inserted_id = cur.fetchone()[0] - conn.commit() - return { - "status": "success", - "token": token, - "token_db_id": inserted_id, - } + insert_cmd = "INSERT INTO access_tokens (user_id, token) VALUES (%s, %s) RETURNING id;" + result = self._execute_query(insert_cmd, (user_id, token), "one")[0] + inserted_id = result[0] if result else None + return { + "status": "success", + "token": token, + "token_db_id": inserted_id, + } except Exception as e: logger.warning(f"Error is {e}") return {"status": "failure", "error": str(e)} def save_refresh_token(self, user_id, token, access_token_id): try: - with self.get_connection() as conn: - with conn.cursor() as cur: - insert_cmd = "INSERT INTO refresh_tokens (user_id, token, access_token_id) " + "VALUES (%s, %s, %s);" - cur.execute(insert_cmd, (user_id, token, access_token_id)) - conn.commit() - return {"status": "success", "token": token} + insert_cmd = "INSERT INTO refresh_tokens (user_id, token, access_token_id) VALUES (%s, %s, %s);" + self._execute_query(insert_cmd, (user_id, token, access_token_id)) + return {"status": "success", "token": token} except Exception as e: logger.warning(f"Error is {e}") return {"status": "failure", "error": str(e)} def save_reset_token(self, user_id, token): try: - with self.get_connection() as conn: - with conn.cursor() as cur: - insert_cmd = ( - "INSERT INTO reset_tokens (user_id, token) " - + "VALUES (%s, %s) ON CONFLICT (user_id) DO UPDATE SET token = %s;" - ) - cur.execute(insert_cmd, (user_id, token, token)) - conn.commit() - return {"status": "success", "token": token} + insert_cmd = ( + "INSERT INTO reset_tokens (user_id, token) " + + "VALUES (%s, %s) ON CONFLICT (user_id) DO UPDATE SET token = %s;" + ) + self._execute_query(insert_cmd, (user_id, token, token)) + return {"status": "success", "token": token} except Exception as e: logger.warning(f"Error is {e}") return {"status": "failure", "error": str(e)} def retrieve_user_info(self, email): try: - with self.get_connection() as conn: - with conn.cursor() as cur: - select_cmd = "SELECT id, password_hash, first_name, last_name FROM users WHERE email = %s;" - cur.execute(select_cmd, (email,)) - result = cur.fetchone() - user_id = result[0] - existing_hash = result[1] - first_name = result[2] - last_name = result[3] - return user_id, existing_hash, first_name, last_name + select_cmd = "SELECT id, password_hash, first_name, last_name FROM users WHERE email = %s;" + result = self._execute_query(select_cmd, (email,), "one")[0] + if result: + user_id, existing_hash, first_name, last_name = result + return user_id, existing_hash, first_name, last_name + return None, None, None, None except Exception as e: logger.warning(f"Error is {e}") return None, None, None, None def add_feedback(self, user_id, thread_id, message_id, feedback_class, comment): try: - with self.get_connection() as conn: - with conn.cursor() as cur: - insert_cmd = ( - "INSERT INTO feedback (user_id, thread_id, message_id, class, comment)" - + " VALUES (%s, %s, %s, %s, %s);" - ) - cur.execute( - insert_cmd, - (user_id, thread_id, message_id, feedback_class, comment), - ) - conn.commit() - return {"status": "success"} + insert_cmd = ( + "INSERT INTO feedback (user_id, thread_id, message_id, class, comment)" + " VALUES (%s, %s, %s, %s, %s);" + ) + self._execute_query(insert_cmd, (user_id, thread_id, message_id, feedback_class, comment)) + return {"status": "success"} except Exception as e: logger.warning(f"Error is {e}") return {"status": "failure", "error": str(e)} def create_thread(self, user_id): try: - with self.get_connection() as conn: - with conn.cursor() as cur: - insert_cmd = """INSERT INTO threads (user_id) values (%s) RETURNING id;""" - cur.execute(insert_cmd, (user_id,)) - inserted_id = cur.fetchone()[0] - conn.commit() - return {"status": "success", "thread_id": inserted_id} - + insert_cmd = """INSERT INTO threads (user_id) values (%s) RETURNING id;""" + result = self._execute_query(insert_cmd, (user_id,), "one")[0] + inserted_id = result[0] if result else None + return {"status": "success", "thread_id": inserted_id} except Exception as e: logger.warning(f"Error is {e}") return {"status": "failure", "error": str(e)} def get_all_threads(self, user_id): try: - with self.get_connection() as conn: - with conn.cursor() as cur: - select_cmd = """SELECT id, name, updated_at FROM threads WHERE user_id = %s;""" - cur.execute(select_cmd, (user_id,)) - result = cur.fetchall() - return [{"thread_id": x[0], "thread_name": x[1], "updated_at": x[2]} for x in result] + select_cmd = """SELECT id, name, updated_at FROM threads WHERE user_id = %s;""" + result = self._execute_query(select_cmd, (user_id,), "all")[0] + return [{"thread_id": x[0], "thread_name": x[1], "updated_at": x[2]} for x in result] if result else [] except Exception as e: logger.warning(f"Error is {e}") return [] def set_thread_name(self, thread_id, user_id, thread_name): try: - with self.get_connection() as conn: - with conn.cursor() as cur: - insert_cmd = ( - "INSERT INTO threads (id, user_id, name) " - + "VALUES (%s, %s, %s) ON CONFLICT (id) DO UPDATE SET name = %s;" - ) - cur.execute( - insert_cmd, - ( - thread_id, - user_id, - thread_name[: get_settings().MAX_THREAD_NAME_LENGTH], - thread_name[: get_settings().MAX_THREAD_NAME_LENGTH], - ), - ) - conn.commit() - return {"status": "success"} + insert_cmd = ( + "INSERT INTO threads (id, user_id, name) " + "VALUES (%s, %s, %s) ON CONFLICT (id) DO UPDATE SET name = %s;" + ) + self._execute_query( + insert_cmd, + ( + thread_id, + user_id, + thread_name[: get_settings().MAX_THREAD_NAME_LENGTH], + thread_name[: get_settings().MAX_THREAD_NAME_LENGTH], + ), + ) + return {"status": "success"} except Exception as e: logger.warning(f"Error is {e}") return {"status": "failure", "error": str(e)} def append_message(self, user_id, thread_id, role, content, tool_name=None): try: - with self.get_connection() as conn: - with conn.cursor() as cur: - insert_cmd = ( - # TODO (odyash): check if "function" can be renamed to - # "tool" like the rest of the codebase or not - "INSERT INTO messages (thread_id, user_id, role, content, function_name) " - + "VALUES (%s, %s, %s, %s, %s);" - ) - cur.execute( - insert_cmd, - (thread_id, user_id, role, content, tool_name), - ) - # Appending a message should update the thread's updated_at field. - update_cmd = "UPDATE threads SET updated_at = now() " - "WHERE id = %s AND user_id = %s;" - cur.execute(update_cmd, (thread_id, user_id)) - conn.commit() - return {"status": "success"} + # TODO (odyash): check if "function" can be renamed to + # "tool" like the rest of the codebase or not + insert_cmd = ( + "INSERT INTO messages (thread_id, user_id, role, content, function_name) " + "VALUES (%s, %s, %s, %s, %s);" + ) + params_1 = (thread_id, user_id, role, content, tool_name) + + # Appending a message should update the thread's updated_at field. + update_cmd = "UPDATE threads SET updated_at = now() WHERE id = %s AND user_id = %s;" + params_2 = (thread_id, user_id) + + self._execute_query([insert_cmd, update_cmd], [params_1, params_2], commit_after="all") + + return {"status": "success"} except Exception as e: logger.warning(f"Error is {e}") return {"status": "failure", "error": str(e)} @@ -342,34 +377,27 @@ def get_thread(self, thread_id, user_id): tool messages are not included. """ try: - with self.get_connection() as conn: - with conn.cursor() as cur: - select_cmd = ( - "SELECT id, role, content FROM messages " - + "WHERE thread_id = %s AND user_id = %s ORDER BY updated_at;" - ) - cur.execute(select_cmd, (thread_id, user_id)) - result = cur.fetchall() - select_cmd = "SELECT name FROM threads WHERE id = %s AND user_id = %s;" - cur.execute(select_cmd, (thread_id, user_id)) - if cur.rowcount == 0: - raise HTTPException( - status_code=401, - detail="Incorrect user_id or thread_id.", - ) - thread_name = cur.fetchone()[0] - retval = { - "thread_name": thread_name, - "messages": [ - self.convert_message(x) - for x in result - if x[1] - # TODO (odyash): check if "function" can be renamed to "tool" - # like the rest of the codebase or not - != "function" - ], - } - return retval + select_cmd_1 = ( + "SELECT id, role, content FROM messages " + "WHERE thread_id = %s AND user_id = %s ORDER BY updated_at;" + ) + select_cmd_2 = "SELECT name FROM threads WHERE id = %s AND user_id = %s;" + params = (thread_id, user_id) + + # Note: we don't add "[0]" here since the first arg. below is a list + result, thread_name_result = self._execute_query([select_cmd_1, select_cmd_2], [params, params], ["all", "one"]) + + if not thread_name_result: + raise HTTPException( + status_code=401, + detail="Incorrect user_id or thread_id.", + ) + thread_name = thread_name_result[0] + # TODO (odyash): check if "function" can be renamed to "tool" + retval = { + "thread_name": thread_name, + "messages": [self.convert_message(x) for x in result if x[1] != "function"], + } + return retval except Exception as e: logger.warning(f"Error is {e}") return {} @@ -380,29 +408,28 @@ def get_thread_llm(self, thread_id, user_id): """ try: # We need to check user_id to make sure that the user has access to the thread. - with self.get_connection() as conn: - with conn.cursor() as cur: - select_cmd = ( - # TODO (odyash): check if "function" can be renamed to "tool" like the rest of the codebase or not - "SELECT role, content, function_name FROM messages " - + "WHERE thread_id = %s AND user_id = %s ORDER BY timestamp;" - ) - cur.execute(select_cmd, (thread_id, user_id)) - result = cur.fetchall() - select_cmd = """SELECT name FROM threads WHERE id = %s AND user_id = %s;""" - cur.execute(select_cmd, (thread_id, user_id)) - if cur.rowcount == 0: - raise HTTPException( - status_code=401, - detail="Incorrect user_id or thread_id.", - ) - thread_name = cur.fetchone()[0] - # Now convert into the standard format - retval = { - "thread_name": thread_name, - "messages": [self.convert_message_llm(x) for x in result], - } - return retval + # TODO (odyash): check if "function" can be renamed to "tool" like the rest of the codebase or not + select_cmd_1 = ( + "SELECT role, content, function_name FROM messages " + + "WHERE thread_id = %s AND user_id = %s ORDER BY timestamp;" + ) + select_cmd_2 = """SELECT name FROM threads WHERE id = %s AND user_id = %s;""" + params = (thread_id, user_id) + + result, thread_name_result = self._execute_query([select_cmd_1, select_cmd_2], [params, params], ["all", "one"]) + + if not thread_name_result: + raise HTTPException( + status_code=401, + detail="Incorrect user_id or thread_id.", + ) + thread_name = thread_name_result[0] + # Now convert into the standard format + retval = { + "thread_name": thread_name, + "messages": [self.convert_message_llm(x) for x in result], + } + return retval except Exception as e: logger.warning(f"Error is {e}") return {} @@ -415,18 +442,13 @@ def snapshot_thread(self, thread_id, user_id): try: # First we retrieve the thread. thread = self.get_thread(thread_id, user_id) - logger.info(f"!!!!!! !!!! Thread is {json.dumps(thread)}") + logger.info(f"Thread is {json.dumps(thread)}") # Now we create a new thread - with self.get_connection() as conn: - with conn.cursor() as cur: - insert_cmd = """INSERT INTO share (content) values (%s) RETURNING id;""" - thread_as_json = json.dumps(thread) - cur.execute(insert_cmd, (thread_as_json,)) - result = cur.fetchone()[0] - logger.info(f"Result is {result}") - """Commit Changes to Database""" - conn.commit() - return result + insert_cmd = """INSERT INTO share (content) values (%s) RETURNING id;""" + thread_as_json = json.dumps(thread) + result = self._execute_query(insert_cmd, (thread_as_json,), "one")[0] + logger.info(f"Result is {result}") + return result[0] if result else None except Exception as e: logger.warning(f"Error is {e}") return {"status": "failure", "error": str(e)} @@ -434,14 +456,12 @@ def snapshot_thread(self, thread_id, user_id): def get_snapshot(self, share_uuid): """Retrieve a snapshot of a thread.""" try: - with self.get_connection() as conn: - with conn.cursor() as cur: - select_cmd = """SELECT content FROM share WHERE id = %s;""" - cur.execute(select_cmd, (share_uuid,)) - result = cur.fetchone()[0] - """Deserialize json string""" - result = json.loads(result) - return result + select_cmd = """SELECT content FROM share WHERE id = %s;""" + result = self._execute_query(select_cmd, (share_uuid,), "one")[0] + if result: + # Deserialize json string + return json.loads(result[0]) + return {} except Exception as e: logger.warning(f"Error is {e}") return {} @@ -450,15 +470,11 @@ def delete_thread(self, thread_id, user_id): try: # We need to ensure that the user_id has access to the thread. # We must delete the messages associated with the thread first. - with self.get_connection() as conn: - with conn.cursor() as cur: - delete_cmd = """DELETE FROM messages WHERE thread_id = %s and user_id = %s;""" - cur.execute(delete_cmd, (thread_id, user_id)) - conn.commit() - delete_cmd = """DELETE FROM threads WHERE id = %s AND user_id = %s;""" - cur.execute(delete_cmd, (thread_id, user_id)) - conn.commit() - return {"status": "success"} + delete_cmd_1 = """DELETE FROM messages WHERE thread_id = %s and user_id = %s;""" + delete_cmd_2 = """DELETE FROM threads WHERE id = %s AND user_id = %s;""" + params = (thread_id, user_id) + self._execute_query([delete_cmd_1, delete_cmd_2], [params, params]) + return {"status": "success"} except Exception as e: logger.warning(f"Error is {e}") return {"status": "failure", "error": str(e)} @@ -476,83 +492,64 @@ def delete_access_refresh_tokens_pair(self, refresh_token): """ try: - with self.get_connection() as conn: - with conn.cursor() as cur: - # Retrieve the associated access_token_id - select_cmd = """SELECT access_token_id FROM refresh_tokens WHERE token = %s;""" - cur.execute(select_cmd, (refresh_token,)) - result = cur.fetchone() - if result is None: - raise HTTPException( - status_code=401, - detail="Couldn't find refresh_token in the database.", - ) - access_token_id = result[0] - - # Delete the access token; the refresh token will auto-delete via its foreign key constraint. - delete_cmd = """DELETE FROM access_tokens WHERE id = %s;""" - cur.execute(delete_cmd, (access_token_id,)) - conn.commit() - return {"status": "success"} + # Retrieve the associated access_token_id + select_cmd = """SELECT access_token_id FROM refresh_tokens WHERE token = %s;""" + result = self._execute_query(select_cmd, (refresh_token,), "one")[0] + if result is None: + raise HTTPException( + status_code=401, + detail="Couldn't find refresh_token in the database.", + ) + access_token_id = result[0] + + # Delete the access token; the refresh token will auto-delete via its foreign key constraint. + delete_cmd = """DELETE FROM access_tokens WHERE id = %s;""" + self._execute_query(delete_cmd, (access_token_id,)) + return {"status": "success"} except psycopg2.Error as e: logging.critical(f"Error: {e}") raise HTTPException(status_code=500, detail="Database error") def delete_access_token(self, user_id, token): try: - with self.get_connection() as conn: - with conn.cursor() as cur: - delete_cmd = """DELETE FROM access_tokens WHERE user_id = %s AND token = %s;""" - cur.execute(delete_cmd, (user_id, token)) - conn.commit() - return {"status": "success"} + delete_cmd = """DELETE FROM access_tokens WHERE user_id = %s AND token = %s;""" + self._execute_query(delete_cmd, (user_id, token)) + return {"status": "success"} except Exception as e: logger.warning(f"Error is {e}") return {"status": "failure", "error": str(e)} def logout(self, user_id, token): try: - with self.get_connection() as conn: - with conn.cursor() as cur: - for db_table in ["access_tokens", "refresh_tokens"]: - delete_cmd = f"""DELETE FROM {db_table} WHERE user_id = %s AND token = %s;""" - cur.execute(delete_cmd, (user_id, token)) - conn.commit() - return {"status": "success"} + for db_table in ["access_tokens", "refresh_tokens"]: + delete_cmd = f"""DELETE FROM {db_table} WHERE user_id = %s AND token = %s;""" + self._execute_query(delete_cmd, (user_id, token)) + return {"status": "success"} except Exception as e: logger.warning(f"Error is {e}") return {"status": "failure", "error": str(e)} def set_pref(self, user_id, key, value): - with self.get_connection() as conn: - with conn.cursor() as cur: - insert_cmd = ( - "INSERT INTO preferences (user_id, pref_key, pref_value) " - + "VALUES (%s, %s, %s) ON CONFLICT (user_id, pref_key) DO UPDATE SET pref_value = %s;" - ) - cur.execute(insert_cmd, (user_id, key, value, value)) - conn.commit() - return {"status": "success"} + insert_cmd = ( + "INSERT INTO preferences (user_id, pref_key, pref_value) " + + "VALUES (%s, %s, %s) ON CONFLICT (user_id, pref_key) DO UPDATE SET pref_value = %s;" + ) + self._execute_query(insert_cmd, (user_id, key, value, value)) + return {"status": "success"} def get_prefs(self, user_id): - with self.get_connection() as conn: - with conn.cursor() as cur: - select_cmd = """SELECT pref_key, pref_value FROM preferences WHERE user_id = %s;""" - cur.execute(select_cmd, (user_id,)) - result = cur.fetchall() - retval = {} - for x in result: - retval[x[0]] = x[1] - return retval + select_cmd = """SELECT pref_key, pref_value FROM preferences WHERE user_id = %s;""" + result = self._execute_query(select_cmd, (user_id,), "all")[0] + retval = {} + for x in result: + retval[x[0]] = x[1] + return retval def update_password(self, user_id, new_password_hash): try: - with self.get_connection() as conn: - with conn.cursor() as cur: - update_cmd = """UPDATE users SET password_hash = %s WHERE id = %s;""" - cur.execute(update_cmd, (new_password_hash, user_id)) - conn.commit() - return {"status": "success"} + update_cmd = """UPDATE users SET password_hash = %s WHERE id = %s;""" + self._execute_query(update_cmd, (new_password_hash, user_id)) + return {"status": "success"} except Exception as e: logger.warning(f"Error is {e}") return {"status": "failure", "error": str(e)} @@ -572,16 +569,11 @@ def store_quran_answer( question: str, ansari_answer: str, ): - with self.get_connection() as conn: - with conn.cursor() as cur: - cur.execute( - """ - INSERT INTO quran_answers (surah, ayah, question, ansari_answer, review_result, final_answer) - VALUES (%s, %s, %s, %s, 'pending', NULL) - """, - (surah, ayah, question, ansari_answer), - ) - conn.commit() + insert_cmd = """ + INSERT INTO quran_answers (surah, ayah, question, ansari_answer, review_result, final_answer) + VALUES (%s, %s, %s, %s, 'pending', NULL) + """ + self._execute_query(insert_cmd, (surah, ayah, question, ansari_answer)) def get_quran_answer( self, @@ -601,20 +593,17 @@ def get_quran_answer( """ try: - with self.get_connection() as conn: - with conn.cursor() as cur: - select_cmd = """ - SELECT ansari_answer - FROM quran_answers - WHERE surah = %s AND ayah = %s AND question = %s - ORDER BY created_at DESC, id DESC - LIMIT 1; - """ - cur.execute(select_cmd, (surah, ayah, question)) - result = cur.fetchone() - if result: - return result[0] - return None + select_cmd = """ + SELECT ansari_answer + FROM quran_answers + WHERE surah = %s AND ayah = %s AND question = %s + ORDER BY created_at DESC, id DESC + LIMIT 1; + """ + result = self._execute_query(select_cmd, (surah, ayah, question), "one")[0] + if result: + return result[0] + return None except Exception as e: logger.error(f"Error retrieving Quran answer: {e!s}") return None diff --git a/src/ansari/ansari_logger.py b/src/ansari/ansari_logger.py index 32b2153..7e6a957 100644 --- a/src/ansari/ansari_logger.py +++ b/src/ansari/ansari_logger.py @@ -1,34 +1,48 @@ -import logging +import copy +import os +import sys + +from loguru import logger +from loguru._logger import Logger from ansari.config import get_settings +# Using loguru for logging, check below resources for reasons/details: +# https://nikhilakki.in/loguru-logging-in-python-made-fun-and-easy#heading-why-use-loguru-over-the-std-logging-module +# https://loguru.readthedocs.io/en/stable/resources/migration.html +# https://loguru.readthedocs.io/en/stable/resources/recipes.html#creating-independent-loggers-with-separate-set-of-handlers + def get_logger( - caller_file_name: str, - logging_level=None, - debug_mode=None, -) -> logging.Logger: + logging_level: str = None, +) -> Logger: """Creates and returns a logger instance for the specified caller file. Args: caller_file_name (str): The name of the file requesting the logger. logging_level (Optional[str]): The logging level to be set for the logger. If None, it defaults to the LOGGING_LEVEL from settings. - debug_mode (Optional[bool]): If True, adds a console handler to the logger. - If None, it defaults to the DEBUG_MODE from settings. Returns: - logging.Logger: Configured logger instance. + logger: Configured logger instance. """ - logger = logging.getLogger(caller_file_name) if logging_level is None: logging_level = get_settings().LOGGING_LEVEL.upper() - logger.setLevel(logging_level) - if debug_mode is not False and get_settings().DEBUG_MODE: - console_handler = logging.StreamHandler() - console_handler.setLevel(logging_level) - logger.addHandler(console_handler) + log_format = ( + "{time:YYYY-MM-DD HH:mm:ss.SSSS} | " + + "{level} | " + + "{name}:{function}:{line} | " + + "{message}" + ) + + logger.remove() + cur_logger = copy.deepcopy(logger) + + # In colorize, If None, the choice is automatically made based on the sink being a tty or not. + cur_logger.add( + sys.stdout, level=logging_level, format=log_format, enqueue=True, colorize=os.getenv("GITHUB_ACTIONS", None) + ) - return logger + return cur_logger diff --git a/src/ansari/app/main_api.py b/src/ansari/app/main_api.py index 0ff7495..dd91a12 100644 --- a/src/ansari/app/main_api.py +++ b/src/ansari/app/main_api.py @@ -9,7 +9,6 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import StreamingResponse from jinja2 import Environment, FileSystemLoader -from jwt import PyJWTError from langfuse.decorators import langfuse_context, observe from pydantic import BaseModel from sendgrid import SendGridAPIClient @@ -22,8 +21,9 @@ from ansari.app.main_whatsapp import router as whatsapp_router from ansari.config import Settings, get_settings from ansari.presenters.api_presenter import ApiPresenter +from ansari.util.fastapi_helpers import validate_cors -logger = get_logger(__name__) +logger = get_logger() # Register the UUID type globally @@ -32,21 +32,37 @@ app = FastAPI() -def main(): - add_app_middleware() - - def add_app_middleware(): + origins = get_settings().ORIGINS + + # This if condition only runs in local development + if get_settings().DEBUG_MODE: + # Change "3000" to the port of your frontend server (3000 is the default there) + local_origin = "http://localhost:3000" + origins.append(local_origin) + zrok_origin = "https://" + get_settings().ZROK_SHARE_TOKEN.get_secret_value() + ".share.zrok.io" + origins.append(zrok_origin) + # If we don't do this, we'll get a "400 Bad Request" error when trying to access the API from the local frontend + logger.debug( + f"Added {local_origin} and zrok's origin to the list of allowed origins for debugging purposes " + + "(assuming local frontend's port is 3000)..." + ) + app.add_middleware( CORSMiddleware, - allow_origins=get_settings().ORIGINS, + allow_origins=origins, allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) +def main(): + add_app_middleware() + + main() + db = AnsariDB(get_settings()) ansari = Ansari(get_settings()) @@ -60,33 +76,28 @@ def add_app_middleware(): if __name__ == "__main__" and get_settings().DEBUG_MODE: # Programatically start a Uvicorn server while debugging (development) for easier control/accessibility - # Note: if you instead run + # Note 1: if you instead run # uvicorn main_api:app --host YOUR_HOST --port YOUR_PORT # in the terminal, then this block will be ignored + # Note 2: you have to use zrok to test whatsapp's webhook locally, + # Check the resources at `.env.example` file for more details + # Run commands: + # `zrok enable SECRET_TOKEN_GENERATED_BY_ZROK_FOR_YOUR_DEVICE` (should be run only once) + # `zrok reserve public localhost:8000 -n ZROK_SHARE_TOKEN` (should be run only once) + # (if error occurs, contact odyash on GitHub) + # `zrok share reserved ZROK_SHARE_TOKEN` import uvicorn filename_without_extension = os.path.splitext(os.path.basename(__file__))[0] uvicorn.run( f"{filename_without_extension}:app", - host="127.0.0.1", + host="localhost", port=8000, reload=True, + log_level="debug", ) -def validate_cors(request: Request, settings: Settings = Depends(get_settings)) -> bool: - try: - logger.info(f"Raw request is {request.headers}") - origin = request.headers.get("origin", "") - mobile = request.headers.get("x-mobile-ansari", "") - if origin and origin in settings.ORIGINS or mobile == "ANSARI": - logger.debug("CORS OK") - return True - raise HTTPException(status_code=502, detail="Not Allowed Origin") - except PyJWTError: - raise HTTPException(status_code=403, detail="Could not validate credentials") - - class RegisterRequest(BaseModel): email: str password: str @@ -101,6 +112,9 @@ async def register_user(req: RegisterRequest, cors_ok: bool = Depends(validate_c Returns 200 on success. Returns 400 if the password is too weak. Will include suggestions for a stronger password. """ + if not cors_ok: + raise HTTPException(status_code=403, detail="CORS not permitted") + password_hash = db.hash_password(req.password) logger.info( f"Received request to create account: {req.email} {password_hash} {req.first_name} {req.last_name}", @@ -136,52 +150,56 @@ async def login_user( Returns a token on success. Returns 403 if the password is incorrect or the user doesn't exist. """ - if db.account_exists(req.email): - user_id, existing_hash, first_name, last_name = db.retrieve_user_info(req.email) - if db.check_password(req.password, existing_hash): - # Generate a token and return it - try: - access_token = db.generate_token( - user_id, - token_type="access", - expiry_hours=settings.ACCESS_TOKEN_EXPIRY_HOURS, - ) - refresh_token = db.generate_token( - user_id, - token_type="refresh", - expiry_hours=settings.REFRESH_TOKEN_EXPIRY_HOURS, - ) - access_token_insert_result = db.save_access_token(user_id, access_token) - if access_token_insert_result["status"] != "success": - raise HTTPException( - status_code=500, - detail="Couldn't save access token", - ) - refresh_token_insert_result = db.save_refresh_token( - user_id, - refresh_token, - access_token_insert_result["token_db_id"], - ) - if refresh_token_insert_result["status"] != "success": - raise HTTPException( - status_code=500, - detail="Couldn't save refresh token", - ) - return { - "status": "success", - "access_token": access_token, - "refresh_token": refresh_token, - "first_name": first_name, - "last_name": last_name, - } - except psycopg2.Error as e: - logger.critical(f"Error: {e}") - raise HTTPException(status_code=500, detail="Database error") - else: - raise HTTPException(status_code=403, detail="Invalid username or password") - else: + if not cors_ok: + raise HTTPException(status_code=403, detail="CORS not permitted") + + if not db.account_exists(req.email): raise HTTPException(status_code=403, detail="Invalid username or password") + user_id, existing_hash, first_name, last_name = db.retrieve_user_info(req.email) + + if not db.check_password(req.password, existing_hash): + raise HTTPException(status_code=403, detail="Invalid username or password") + + # Generate a token and return it + try: + access_token = db.generate_token( + user_id, + token_type="access", + expiry_hours=settings.ACCESS_TOKEN_EXPIRY_HOURS, + ) + refresh_token = db.generate_token( + user_id, + token_type="refresh", + expiry_hours=settings.REFRESH_TOKEN_EXPIRY_HOURS, + ) + access_token_insert_result = db.save_access_token(user_id, access_token) + if access_token_insert_result["status"] != "success": + raise HTTPException( + status_code=500, + detail="Couldn't save access token", + ) + refresh_token_insert_result = db.save_refresh_token( + user_id, + refresh_token, + access_token_insert_result["token_db_id"], + ) + if refresh_token_insert_result["status"] != "success": + raise HTTPException( + status_code=500, + detail="Couldn't save refresh token", + ) + return { + "status": "success", + "access_token": access_token, + "refresh_token": refresh_token, + "first_name": first_name, + "last_name": last_name, + } + except psycopg2.Error as e: + logger.critical(f"Error: {e}") + raise HTTPException(status_code=500, detail="Database error") + @app.post("/api/v2/users/refresh_token") async def refresh_token( @@ -201,69 +219,69 @@ async def refresh_token( - 500 if there is an internal server error during token generation or saving. """ - if cors_ok: - old_refresh_token = request.headers.get("Authorization", "").split(" ")[1] - token_params = db.decode_token(old_refresh_token) - - lock_key = f"lock:{token_params['user_id']}" - with Lock(cache, lock_key, expire=3): - # Check cache for existing token pair - cached_tokens = cache.get(old_refresh_token) - if cached_tokens: - return {"status": "success", **cached_tokens} - - # If no cached tokens, proceed to validate and generate new tokens - try: - # Validate the refresh token and delete the old token pair - db.delete_access_refresh_tokens_pair(old_refresh_token) - - # Generate new tokens - new_access_token = db.generate_token( - token_params["user_id"], - token_type="access", - expiry_hours=settings.ACCESS_TOKEN_EXPIRY_HOURS, - ) - new_refresh_token = db.generate_token( - token_params["user_id"], - token_type="refresh", - expiry_hours=settings.REFRESH_TOKEN_EXPIRY_HOURS, - ) + if not cors_ok: + raise HTTPException(status_code=403, detail="CORS not permitted") + + old_refresh_token = request.headers.get("Authorization", "").split(" ")[1] + token_params = db.decode_token(old_refresh_token) - # Save the new access token to the database - access_token_insert_result = db.save_access_token( - token_params["user_id"], - new_access_token, + lock_key = f"lock:{token_params['user_id']}" + with Lock(cache, lock_key, expire=3): + # Check cache for existing token pair + cached_tokens = cache.get(old_refresh_token) + if cached_tokens: + return {"status": "success", **cached_tokens} + + # If no cached tokens, proceed to validate and generate new tokens + try: + # Validate the refresh token and delete the old token pair + db.delete_access_refresh_tokens_pair(old_refresh_token) + + # Generate new tokens + new_access_token = db.generate_token( + token_params["user_id"], + token_type="access", + expiry_hours=settings.ACCESS_TOKEN_EXPIRY_HOURS, + ) + new_refresh_token = db.generate_token( + token_params["user_id"], + token_type="refresh", + expiry_hours=settings.REFRESH_TOKEN_EXPIRY_HOURS, + ) + + # Save the new access token to the database + access_token_insert_result = db.save_access_token( + token_params["user_id"], + new_access_token, + ) + if access_token_insert_result["status"] != "success": + raise HTTPException( + status_code=500, + detail="Couldn't save access token", ) - if access_token_insert_result["status"] != "success": - raise HTTPException( - status_code=500, - detail="Couldn't save access token", - ) - - # Save the new refresh token to the database - refresh_token_insert_result = db.save_refresh_token( - token_params["user_id"], - new_refresh_token, - access_token_insert_result["token_db_id"], + + # Save the new refresh token to the database + refresh_token_insert_result = db.save_refresh_token( + token_params["user_id"], + new_refresh_token, + access_token_insert_result["token_db_id"], + ) + if refresh_token_insert_result["status"] != "success": + raise HTTPException( + status_code=500, + detail="Couldn't save refresh token", ) - if refresh_token_insert_result["status"] != "success": - raise HTTPException( - status_code=500, - detail="Couldn't save refresh token", - ) - - # Cache the new tokens with a short expiry (3 seconds) - new_tokens = { - "access_token": new_access_token, - "refresh_token": new_refresh_token, - } - cache.set(old_refresh_token, new_tokens, expire=3) - return {"status": "success", **new_tokens} - except psycopg2.Error as e: - logger.critical(f"Error: {e}") - raise HTTPException(status_code=500, detail="Database error") - else: - raise HTTPException(status_code=403, detail="Invalid origins") + + # Cache the new tokens with a short expiry (3 seconds) + new_tokens = { + "access_token": new_access_token, + "refresh_token": new_refresh_token, + } + cache.set(old_refresh_token, new_tokens, expire=3) + return {"status": "success", **new_tokens} + except psycopg2.Error as e: + logger.critical(f"Error: {e}") + raise HTTPException(status_code=500, detail="Database error") @app.post("/api/v2/users/logout") @@ -276,15 +294,16 @@ async def logout_user( Deletes all tokens. Returns 403 if the password is incorrect or the user doesn't exist. """ - if cors_ok and token_params: - try: - token = request.headers.get("Authorization", "").split(" ")[1] - db.logout(token_params["user_id"], token) - return {"status": "success"} - except psycopg2.Error as e: - logger.critical(f"Error: {e}") - raise HTTPException(status_code=500, detail="Database error") - raise HTTPException(status_code=403, detail="Invalid username or password") + if not (cors_ok and token_params): + raise HTTPException(status_code=403, detail="Invalid username or password") + + try: + token = request.headers.get("Authorization", "").split(" ")[1] + db.logout(token_params["user_id"], token) + return {"status": "success"} + except psycopg2.Error as e: + logger.critical(f"Error: {e}") + raise HTTPException(status_code=500, detail="Database error") class FeedbackRequest(BaseModel): @@ -300,24 +319,24 @@ async def add_feedback( cors_ok: bool = Depends(validate_cors), token_params: dict = Depends(db.validate_token), ): - if cors_ok and token_params: - logger.info(f"Token_params is {token_params}") - # Now create a thread and return the thread_id - try: - db.add_feedback( - token_params["user_id"], - req.thread_id, - req.message_id, - req.feedback_class, - req.comment, - ) - return {"status": "success"} - except psycopg2.Error as e: - logger.critical(f"Error: {e}") - raise HTTPException(status_code=500, detail="Database error") - else: + if not (cors_ok and token_params): raise HTTPException(status_code=403, detail="CORS not permitted") + logger.info(f"Token_params is {token_params}") + # Now create a thread and return the thread_id + try: + db.add_feedback( + token_params["user_id"], + req.thread_id, + req.message_id, + req.feedback_class, + req.comment, + ) + return {"status": "success"} + except psycopg2.Error as e: + logger.critical(f"Error: {e}") + raise HTTPException(status_code=500, detail="Database error") + @app.post("/api/v2/threads") async def create_thread( @@ -325,19 +344,19 @@ async def create_thread( cors_ok: bool = Depends(validate_cors), token_params: dict = Depends(db.validate_token), ): - if cors_ok and token_params: - logger.info(f"Token_params is {token_params}") - # Now create a thread and return the thread_id - try: - thread_id = db.create_thread(token_params["user_id"]) - print(f"Created thread {thread_id}") - return thread_id - except psycopg2.Error as e: - logger.critical(f"Error: {e}") - raise HTTPException(status_code=500, detail="Database error") - else: + if not (cors_ok and token_params): raise HTTPException(status_code=403, detail="CORS not permitted") + logger.info(f"Token_params is {token_params}") + # Now create a thread and return the thread_id + try: + thread_id = db.create_thread(token_params["user_id"]) + print(f"Created thread {thread_id}") + return thread_id + except psycopg2.Error as e: + logger.critical(f"Error: {e}") + raise HTTPException(status_code=500, detail="Database error") + @app.get("/api/v2/threads") async def get_all_threads( @@ -346,18 +365,18 @@ async def get_all_threads( token_params: dict = Depends(db.validate_token), ): """Retrieve all threads for the user whose id is included in the token.""" - if cors_ok and token_params: - logger.info(f"Token_params is {token_params}") - # Now create a thread and return the thread_id - try: - threads = db.get_all_threads(token_params["user_id"]) - return threads - except psycopg2.Error as e: - logger.critical(f"Error: {e}") - raise HTTPException(status_code=500, detail="Database error") - else: + if not (cors_ok and token_params): raise HTTPException(status_code=403, detail="CORS not permitted") + logger.info(f"Token_params is {token_params}") + # Now create a thread and return the thread_id + try: + threads = db.get_all_threads(token_params["user_id"]) + return threads + except psycopg2.Error as e: + logger.critical(f"Error: {e}") + raise HTTPException(status_code=500, detail="Database error") + class AddMessageRequest(BaseModel): role: str @@ -376,43 +395,43 @@ def add_message( """Adds a message to a thread. If the message is the first message in the thread, we set the name of the thread to the content of the message. """ - if cors_ok and token_params: - logger.info(f"Token_params is {token_params}") + if not (cors_ok and token_params): + raise HTTPException(status_code=403, detail="CORS not permitted") - try: - db.append_message(token_params["user_id"], thread_id, req.role, req.content) - # Now actually use Ansari. - history = db.get_thread_llm(thread_id, token_params["user_id"]) - if history["thread_name"] is None and len(history["messages"]) > 1: - db.set_thread_name( - thread_id, - token_params["user_id"], - history["messages"][0]["content"], - ) - print(f"Added thread {thread_id}") - - langfuse_context.update_current_trace( - session_id=str(thread_id), - user_id=token_params["user_id"], - tags=["debug"], - metadata={ - "db_host": settings.DATABASE_URL.hosts()[0]["host"], - }, - ) - return presenter.complete( - history, - message_logger=MessageLogger( - db, - token_params["user_id"], - thread_id, - langfuse_context.get_current_trace_id(), - ), + logger.info(f"Token_params is {token_params}") + + try: + db.append_message(token_params["user_id"], thread_id, req.role, req.content) + # Now actually use Ansari. + history = db.get_thread_llm(thread_id, token_params["user_id"]) + if history["thread_name"] is None and len(history["messages"]) > 1: + db.set_thread_name( + thread_id, + token_params["user_id"], + history["messages"][0]["content"], ) - except psycopg2.Error as e: - logger.critical(f"Error: {e}") - raise HTTPException(status_code=500, detail="Database error") - else: - raise HTTPException(status_code=403, detail="CORS not permitted") + print(f"Added thread {thread_id}") + + langfuse_context.update_current_trace( + session_id=str(thread_id), + user_id=token_params["user_id"], + tags=["debug"], + metadata={ + "db_host": settings.DATABASE_URL.hosts()[0]["host"], + }, + ) + return presenter.complete( + history, + message_logger=MessageLogger( + db, + token_params["user_id"], + thread_id, + langfuse_context.get_current_trace_id(), + ), + ) + except psycopg2.Error as e: + logger.critical(f"Error: {e}") + raise HTTPException(status_code=500, detail="Database error") @app.post("/api/v2/share/{thread_id}") @@ -422,19 +441,19 @@ def share_thread( token_params: dict = Depends(db.validate_token), ): """Take a snapshot of a thread at this time and make it shareable.""" - if cors_ok and token_params: - logger.info(f"Token_params is {token_params}") - # TODO(mwk): check that the user_id in the token matches the - # user_id associated with the thread_id. - try: - share_uuid = db.snapshot_thread(thread_id, token_params["user_id"]) - return {"status": "success", "share_uuid": share_uuid} - except psycopg2.Error as e: - logger.critical(f"Error: {e}") - raise HTTPException(status_code=500, detail="Database error") - else: + if not (cors_ok and token_params): raise HTTPException(status_code=403, detail="CORS not permitted") + logger.info(f"Token_params is {token_params}") + # TODO(mwk): check that the user_id in the token matches the + # user_id associated with the thread_id. + try: + share_uuid = db.snapshot_thread(thread_id, token_params["user_id"]) + return {"status": "success", "share_uuid": share_uuid} + except psycopg2.Error as e: + logger.critical(f"Error: {e}") + raise HTTPException(status_code=500, detail="Database error") + @app.get("/api/v2/share/{share_uuid_str}") def get_snapshot( @@ -442,7 +461,9 @@ def get_snapshot( cors_ok: bool = Depends(validate_cors), ): """Take a snapshot of a thread at this time and make it shareable.""" - # Note that unlike the other endpoints, we don't need to check the token here. + if not cors_ok: + raise HTTPException(status_code=403, detail="CORS not permitted") + logger.info(f"Incoming share_uuid is {share_uuid_str}") share_uuid = uuid.UUID(share_uuid_str) try: @@ -459,21 +480,21 @@ async def get_thread( cors_ok: bool = Depends(validate_cors), token_params: dict = Depends(db.validate_token), ): - if cors_ok and token_params: - logger.info(f"Token_params is {token_params}") - # TODO(mwk): check that the user_id in the token matches the - # user_id associated with the thread_id. - try: - messages = db.get_thread(thread_id, token_params["user_id"]) - if messages: # return only if the thread exists. else raise 404 - return messages - raise HTTPException(status_code=404, detail="Thread not found") - except psycopg2.Error as e: - logger.critical(f"Error: {e}") - raise HTTPException(status_code=500, detail="Database error") - else: + if not (cors_ok and token_params): raise HTTPException(status_code=403, detail="CORS not permitted") + logger.info(f"Token_params is {token_params}") + # TODO(mwk): check that the user_id in the token matches the + # user_id associated with the thread_id. + try: + messages = db.get_thread(thread_id, token_params["user_id"]) + if messages: # return only if the thread exists. else raise 404 + return messages + raise HTTPException(status_code=404, detail="Thread not found") + except psycopg2.Error as e: + logger.critical(f"Error: {e}") + raise HTTPException(status_code=500, detail="Database error") + @app.delete("/api/v2/threads/{thread_id}") async def delete_thread( @@ -481,18 +502,18 @@ async def delete_thread( cors_ok: bool = Depends(validate_cors), token_params: dict = Depends(db.validate_token), ): - if cors_ok and token_params: - logger.info(f"Token_params is {token_params}") - # TODO(mwk): check that the user_id in the token matches the - # user_id associated with the thread_id. - try: - return db.delete_thread(thread_id, token_params["user_id"]) - except psycopg2.Error as e: - logger.critical(f"Error: {e}") - raise HTTPException(status_code=500, detail="Database error") - else: + if not (cors_ok and token_params): raise HTTPException(status_code=403, detail="CORS not permitted") + logger.info(f"Token_params is {token_params}") + # TODO(mwk): check that the user_id in the token matches the + # user_id associated with the thread_id. + try: + return db.delete_thread(thread_id, token_params["user_id"]) + except psycopg2.Error as e: + logger.critical(f"Error: {e}") + raise HTTPException(status_code=500, detail="Database error") + class ThreadNameRequest(BaseModel): name: str @@ -505,19 +526,19 @@ async def set_thread_name( cors_ok: bool = Depends(validate_cors), token_params: dict = Depends(db.validate_token), ): - if cors_ok and token_params: - logger.info(f"Token_params is {token_params}") - # TODO(mwk): check that the user_id in the token matches the - # user_id associated with the thread_id. - try: - messages = db.set_thread_name(thread_id, token_params["user_id"], req.name) - return messages - except psycopg2.Error as e: - logger.critical(f"Error: {e}") - raise HTTPException(status_code=500, detail="Database error") - else: + if not (cors_ok and token_params): raise HTTPException(status_code=403, detail="CORS not permitted") + logger.info(f"Token_params is {token_params}") + # TODO(mwk): check that the user_id in the token matches the + # user_id associated with the thread_id. + try: + messages = db.set_thread_name(thread_id, token_params["user_id"], req.name) + return messages + except psycopg2.Error as e: + logger.critical(f"Error: {e}") + raise HTTPException(status_code=500, detail="Database error") + class SetPrefRequest(BaseModel): key: str @@ -530,37 +551,35 @@ async def set_pref( cors_ok: bool = Depends(validate_cors), token_params: dict = Depends(db.validate_token), ): - if cors_ok and token_params: - logger.info(f"Token_params is {token_params}") - # Now create a thread and return the thread_id - try: - db.set_pref(token_params["user_id"], req.key, req.value) - - except psycopg2.Error as e: - logger.critical(f"Error: {e}") - raise HTTPException(status_code=500, detail="Database error") - else: + if not (cors_ok and token_params): raise HTTPException(status_code=403, detail="CORS not permitted") + logger.info(f"Token_params is {token_params}") + # Now create a thread and return the thread_id + try: + db.set_pref(token_params["user_id"], req.key, req.value) + except psycopg2.Error as e: + logger.critical(f"Error: {e}") + raise HTTPException(status_code=500, detail="Database error") + @app.get("/api/v2/preferences") async def get_prefs( cors_ok: bool = Depends(validate_cors), token_params: dict = Depends(db.validate_token), ): - if cors_ok and token_params: - logger.info(f"Token_params is {token_params}") - # Now create a thread and return the thread_id - try: - prefs = db.get_prefs(token_params["user_id"]) - return prefs - - except psycopg2.Error as e: - logger.critical(f"Error: {e}") - raise HTTPException(status_code=500, detail="Database error") - else: + if not (cors_ok and token_params): raise HTTPException(status_code=403, detail="CORS not permitted") + logger.info(f"Token_params is {token_params}") + # Now create a thread and return the thread_id + try: + prefs = db.get_prefs(token_params["user_id"]) + return prefs + except psycopg2.Error as e: + logger.critical(f"Error: {e}") + raise HTTPException(status_code=500, detail="Database error") + class ResetPasswordRequest(BaseModel): email: str @@ -572,40 +591,40 @@ async def request_password_reset( cors_ok: bool = Depends(validate_cors), settings: Settings = Depends(get_settings), ): - if cors_ok: - logger.info(f"Request received to reset {req.email}") - if db.account_exists(req.email): - user_id, _, _, _ = db.retrieve_user_info(req.email) - reset_token = db.generate_token(user_id, "reset") - db.save_reset_token(user_id, reset_token) - # shall we also revoke login and refresh tokens? - tenv = Environment(loader=FileSystemLoader(settings.template_dir)) - template = tenv.get_template("password_reset.html") - rendered_template = template.render(reset_token=reset_token) - message = Mail( - from_email="feedback@ansari.chat", - to_emails=f"{req.email}", - subject="Ansari Password Reset", - html_content=rendered_template, - ) + if not cors_ok: + raise HTTPException(status_code=403, detail="CORS not permitted") - try: - if settings.SENDGRID_API_KEY: - sg = SendGridAPIClient(settings.SENDGRID_API_KEY) - response = sg.send(message) - logger.debug(response.status_code) - logger.debug(response.body) - logger.debug(response.headers) - else: - logger.warning("No sendgrid key") - logger.info(f"Would have sent: {message}") - except Exception as e: - print(e.message) - # Even if the email doesn't exist, we return success. - # So this can't be used to work out who is on our system. - return {"status": "success"} + logger.info(f"Request received to reset {req.email}") + if db.account_exists(req.email): + user_id, _, _, _ = db.retrieve_user_info(req.email) + reset_token = db.generate_token(user_id, "reset") + db.save_reset_token(user_id, reset_token) + # shall we also revoke login and refresh tokens? + tenv = Environment(loader=FileSystemLoader(settings.template_dir)) + template = tenv.get_template("password_reset.html") + rendered_template = template.render(reset_token=reset_token) + message = Mail( + from_email="feedback@ansari.chat", + to_emails=f"{req.email}", + subject="Ansari Password Reset", + html_content=rendered_template, + ) - raise HTTPException(status_code=403, detail="CORS note permitted.") + try: + if settings.SENDGRID_API_KEY: + sg = SendGridAPIClient(settings.SENDGRID_API_KEY) + response = sg.send(message) + logger.debug(response.status_code) + logger.debug(response.body) + logger.debug(response.headers) + else: + logger.warning("No sendgrid key") + logger.info(f"Would have sent: {message}") + except Exception as e: + print(e.message) + # Even if the email doesn't exist, we return success. + # So this can't be used to work out who is on our system. + return {"status": "success"} @app.post("/api/v2/update_password") @@ -615,23 +634,23 @@ async def update_password( password: str = None, ): """Update the user's password if you have a valid token""" - if cors_ok and token_params: - logger.info(f"Token_params is {token_params}") - try: - password_hash = db.hash_password(password) - passwd_quality = zxcvbn(password) - if passwd_quality["score"] < 2: - raise HTTPException( - status_code=400, - detail="Password is too weak. Suggestions: " + ",".join(passwd_quality["feedback"]["suggestions"]), - ) - db.update_password(token_params["email"], password_hash) - except psycopg2.Error as e: - logger.critical(f"Error: {e}") - raise HTTPException(status_code=500, detail="Database error") - else: + if not (cors_ok and token_params): raise HTTPException(status_code=403, detail="Invalid username or password") + logger.info(f"Token_params is {token_params}") + try: + password_hash = db.hash_password(password) + passwd_quality = zxcvbn(password) + if passwd_quality["score"] < 2: + raise HTTPException( + status_code=400, + detail="Password is too weak. Suggestions: " + ",".join(passwd_quality["feedback"]["suggestions"]), + ) + db.update_password(token_params["email"], password_hash) + except psycopg2.Error as e: + logger.critical(f"Error: {e}") + raise HTTPException(status_code=500, detail="Database error") + class PasswordReset(BaseModel): reset_token: str @@ -642,24 +661,24 @@ class PasswordReset(BaseModel): async def reset_password(req: PasswordReset, cors_ok: bool = Depends(validate_cors)): """Resets the user's password if you have a reset token.""" token_params = db.validate_reset_token(req.reset_token) - if cors_ok: - logger.info(f"Token_params is {token_params}") - try: - password_hash = db.hash_password(req.new_password) - passwd_quality = zxcvbn(req.new_password) - if passwd_quality["score"] < 2: - raise HTTPException( - status_code=400, - detail="Password is too weak. Suggestions: " + ",".join(passwd_quality["feedback"]["suggestions"]), - ) - db.update_password(token_params["user_id"], password_hash) - return {"status": "success"} - except psycopg2.Error as e: - logger.critical(f"Error: {e}") - raise HTTPException(status_code=500, detail="Database error") - else: + if not cors_ok: raise HTTPException(status_code=403, detail="Invalid username or password") + logger.info(f"Token_params is {token_params}") + try: + password_hash = db.hash_password(req.new_password) + passwd_quality = zxcvbn(req.new_password) + if passwd_quality["score"] < 2: + raise HTTPException( + status_code=400, + detail="Password is too weak. Suggestions: " + ",".join(passwd_quality["feedback"]["suggestions"]), + ) + db.update_password(token_params["user_id"], password_hash) + return {"status": "success"} + except psycopg2.Error as e: + logger.critical(f"Error: {e}") + raise HTTPException(status_code=500, detail="Database error") + @app.post("/api/v1/complete") async def complete(request: Request, cors_ok: bool = Depends(validate_cors)): @@ -672,12 +691,13 @@ async def complete(request: Request, cors_ok: bool = Depends(validate_cors)): It returns a stream of tokens (a token is a part of a word). """ - logger.info(f"Raw request is {request.headers}") - if cors_ok: - body = await request.json() - logger.info(f"Request received > {body}.") - return presenter.complete(body) - raise HTTPException(status_code=403, detail="CORS not permitted") + if not cors_ok: + raise HTTPException(status_code=403, detail="CORS not permitted") + + logger.debug(f"Raw request is {request.headers}") + body = await request.json() + logger.info(f"Request received > {body}.") + return presenter.complete(body) class AyahQuestionRequest(BaseModel): @@ -691,9 +711,13 @@ class AyahQuestionRequest(BaseModel): @app.post("/api/v2/ayah") async def answer_ayah_question( req: AyahQuestionRequest, + cors_ok: bool = Depends(validate_cors), settings: Settings = Depends(get_settings), db: AnsariDB = Depends(lambda: AnsariDB(get_settings())), ): + if not cors_ok: + raise HTTPException(status_code=403, detail="CORS not permitted") + if req.apikey != settings.QURAN_DOT_COM_API_KEY.get_secret_value(): raise HTTPException(status_code=401, detail="Unauthorized") diff --git a/src/ansari/app/main_whatsapp.py b/src/ansari/app/main_whatsapp.py index dbd634b..920259c 100644 --- a/src/ansari/app/main_whatsapp.py +++ b/src/ansari/app/main_whatsapp.py @@ -1,12 +1,13 @@ -from fastapi import APIRouter, HTTPException, Request +from fastapi import APIRouter, Depends, HTTPException, Request from fastapi.responses import HTMLResponse from ansari.agents import Ansari from ansari.ansari_logger import get_logger from ansari.config import get_settings from ansari.presenters.whatsapp_presenter import WhatsAppPresenter +from ansari.util.fastapi_helpers import validate_cors -logger = get_logger(__name__) +logger = get_logger() # Create a router in order to make the FastAPI functions here an extension of the main FastAPI app router = APIRouter() @@ -31,7 +32,7 @@ @router.get("/whatsapp/v1") -async def verification_webhook(request: Request) -> str | None: +async def verification_webhook(request: Request, cors_ok: bool = Depends(validate_cors)) -> str | None: """Handles the WhatsApp webhook verification request. Args: @@ -41,6 +42,9 @@ async def verification_webhook(request: Request) -> str | None: Optional[str]: The challenge string if verification is successful, otherwise raises an HTTPException. """ + if not cors_ok: + raise HTTPException(status_code=403, detail="CORS not permitted") + mode = request.query_params.get("hub.mode") verify_token = request.query_params.get("hub.verify_token") challenge = request.query_params.get("hub.challenge") @@ -57,7 +61,7 @@ async def verification_webhook(request: Request) -> str | None: @router.post("/whatsapp/v1") -async def main_webhook(request: Request) -> None: +async def main_webhook(request: Request, cors_ok: bool = Depends(validate_cors)) -> None: """Handles the incoming WhatsApp webhook message. Args: @@ -67,18 +71,23 @@ async def main_webhook(request: Request) -> None: None """ + if not cors_ok: + raise HTTPException(status_code=403, detail="CORS not permitted") + # Wait for the incoming webhook message to be received as JSON data = await request.json() + # # Logging the origin (host) of the incoming webhook message + # logger.debug(f"ORIGIN of the incoming webhook message: {json.dumps(request, indent=4)}") + # Terminate if incoming webhook message is empty/invalid/msg-status-update(sent,delivered,read) - result = await presenter.extract_relevant_whatsapp_message_details(data) - if isinstance(result, str): - if "error" in result.lower(): - presenter.send_whatsapp_message( - "There's a problem with the server. Kindly send again later...", - ) - return + try: + result = await presenter.extract_relevant_whatsapp_message_details(data) + except Exception: return + else: + if isinstance(result, str): + return # Get relevant info from Meta's API ( @@ -88,6 +97,7 @@ async def main_webhook(request: Request) -> None: incoming_msg_body, ) = result + # TODO (odyash, later): Add support for location type messages if incoming_msg_type != "text": msg_type = incoming_msg_type + "s" if not incoming_msg_type.endswith("s") else incoming_msg_type msg_type = msg_type.replace("unsupporteds", "this media type") @@ -97,15 +107,17 @@ async def main_webhook(request: Request) -> None: ) return + incoming_msg_text = incoming_msg_body["body"] + # Send acknowledgment message if get_settings().DEBUG_MODE: await presenter.send_whatsapp_message( from_whatsapp_number, - f"Ack: {incoming_msg_body}", + f"Ack: {incoming_msg_text}", ) # Actual code to process the incoming message using Ansari agent then reply to the sender await presenter.process_and_reply_to_whatsapp_sender( from_whatsapp_number, - incoming_msg_body, + incoming_msg_text, ) diff --git a/src/ansari/config.py b/src/ansari/config.py index 615cdc2..94ee070 100644 --- a/src/ansari/config.py +++ b/src/ansari/config.py @@ -1,4 +1,3 @@ -import logging from functools import lru_cache from pathlib import Path from typing import Literal @@ -7,7 +6,7 @@ from pydantic_settings import BaseSettings, SettingsConfigDict # Can't use get_logger() here due to circular import -logger = logging.getLogger(__name__) +# logger = get_logger() class Settings(BaseSettings): @@ -53,6 +52,7 @@ def get_resource_path(filename): ACCESS_TOKEN_EXPIRY_HOURS: int = Field(default=2) REFRESH_TOKEN_EXPIRY_HOURS: int = Field(default=24 * 90) + # Will later contain ZROK_SHARE_TOKEN and localhost origins ORIGINS: str | list[str] = Field( default=["https://ansari.chat", "http://ansari.chat"], ) @@ -108,8 +108,7 @@ def get_resource_path(filename): { "name": "query", "type": "string", - "description": "The topic to search for in Tafsir Ibn Kathir. " - "You will translate this query into English.", + "description": "The topic to search for in Tafsir Ibn Kathir. " "You will translate this query into English.", }, ], ) @@ -127,7 +126,7 @@ def get_resource_path(filename): WHATSAPP_TEST_BUSINESS_PHONE_NUMBER_ID: SecretStr | None = Field(default=None) WHATSAPP_ACCESS_TOKEN_FROM_SYS_USER: SecretStr | None = Field(default=None) WHATSAPP_VERIFY_TOKEN_FOR_WEBHOOK: SecretStr | None = Field(default=None) - + ZROK_SHARE_TOKEN: SecretStr = Field(default="") template_dir: DirectoryPath = Field(default=get_resource_path("templates")) diskcache_dir: str = Field(default="diskcache_dir") diff --git a/src/ansari/presenters/whatsapp_presenter.py b/src/ansari/presenters/whatsapp_presenter.py index 7d64e7a..4eea2a9 100644 --- a/src/ansari/presenters/whatsapp_presenter.py +++ b/src/ansari/presenters/whatsapp_presenter.py @@ -2,26 +2,10 @@ from typing import Any import httpx -from fastapi import FastAPI -from fastapi.middleware.cors import CORSMiddleware from ansari.ansari_logger import get_logger -from ansari.config import get_settings -logger = get_logger(__name__) - - -# Initialize FastAPI app -app = FastAPI() - -# Add CORS middleware -app.add_middleware( - CORSMiddleware, - allow_origins=get_settings().ORIGINS, - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) +logger = get_logger() class WhatsAppPresenter: @@ -91,21 +75,31 @@ async def extract_relevant_whatsapp_message_details( and (entry := body.get("entry", [])) and (changes := entry[0].get("changes", [])) and (value := changes[0].get("value", {})) - and (messages := value.get("messages", [])) - and (incoming_msg := messages[0]) ): + error_msg = f"Invalid received payload from WhatsApp user and/or problem with Meta's API :\n{body}" logger.error( - f"Invalid received payload from WhatsApp user and/or problem with Meta's API :\n{body}", + error_msg, ) - return "error" + raise Exception(error_msg) + if "statuses" in value: status = value["statuses"]["status"] timestamp = value["statuses"]["timestamp"] + # This log isn't important if we don't want to track when an Ansari's replied message is + # delivered to or read by the recipient logger.debug( f"WhatsApp status update received:\n({status} at {timestamp}.)", ) return "status update" + elif "messages" not in value: + error_msg = f"Unsupported message type received from WhatsApp user:\n{body}" + logger.error( + error_msg, + ) + raise Exception(error_msg) + logger.info(f"Received payload from WhatsApp user:\n{body}") + incoming_msg = value["messages"][0] # Extract the business phone number ID from the webhook payload business_phone_number_id = value["metadata"]["phone_number_id"] diff --git a/src/ansari/util/fastapi_helpers.py b/src/ansari/util/fastapi_helpers.py new file mode 100644 index 0000000..390b4cc --- /dev/null +++ b/src/ansari/util/fastapi_helpers.py @@ -0,0 +1,26 @@ +from fastapi import Depends, HTTPException, Request +from jwt import PyJWTError + +from ansari.ansari_logger import get_logger +from ansari.config import Settings, get_settings + +logger = get_logger() + + +# Defined in a separate file to avoid circular imports between main_*.py files +def validate_cors(request: Request, settings: Settings = Depends(get_settings)) -> bool: + try: + logger.debug(f"Headers of raw request are: {request.headers}") + origins = get_settings().ORIGINS + incoming_origin = [ + request.headers.get("origin", ""), # If coming from ansari's frontend website + request.headers.get("host", ""), # If coming from Meta's WhatsApp API + ] + + mobile = request.headers.get("x-mobile-ansari", "") + if any(i_o in origins for i_o in incoming_origin) or mobile == "ANSARI": + logger.debug("CORS OK") + return True + raise HTTPException(status_code=502, detail="Not Allowed Origin") + except PyJWTError: + raise HTTPException(status_code=403, detail="Could not validate credentials") diff --git a/tests/test_answer_quality.py b/tests/test_answer_quality.py index ebf6150..794e90e 100644 --- a/tests/test_answer_quality.py +++ b/tests/test_answer_quality.py @@ -8,7 +8,7 @@ from ansari.ansari_logger import get_logger from ansari.config import get_settings -logger = get_logger(__name__) +logger = get_logger() @pytest.fixture(scope="module") diff --git a/tests/test_main_api.py b/tests/test_main_api.py index 4991f39..729444a 100644 --- a/tests/test_main_api.py +++ b/tests/test_main_api.py @@ -10,7 +10,7 @@ from ansari.app.main_api import app from ansari.config import get_settings -logger = get_logger(__name__) +logger = get_logger() client = TestClient(app)