Skip to content

Commit

Permalink
_execute_query & Loguru Updates based on PR
Browse files Browse the repository at this point in the history
Based on this PR:
#99
  • Loading branch information
OdyAsh committed Dec 19, 2024
1 parent 2f9eeb2 commit 268618d
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 43 deletions.
7 changes: 6 additions & 1 deletion .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -57,4 +57,9 @@ export ZROK_SHARE_TOKEN="<<THE-ZROK-SHARE-TOKEN-CURRENTLY-USED-IN-META'S-CALLBAC
# Leave the values below when locally debugging the application
# In production, don't add them to environment variables, or add them as "INFO"/"False" respectively
export LOGGING_LEVEL="DEBUG"
export DEBUG_MODE="True"
export DEBUG_MODE="True"

# To get rid of .py[cod] files (This key should NOT be set in production!)
# This is only to de-clutter your local development environment
# Details: https://docs.python-guide.org/writing/gotchas/#disabling-bytecode-pyc-files
PYTHONDONTWRITEBYTECODE=1
53 changes: 24 additions & 29 deletions src/ansari/ansari_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def _execute_query(
params: Union[tuple, list[tuple]],
which_fetch: Union[Literal["one", "all"], list[Literal["one", "all"]]] = "",
commit_after: Literal["each", "all"] = "each",
) -> Union[tuple, Optional[list[tuple]], list[Optional[list[tuple]]]]:
) -> list[Optional[any]]:
"""
Executes one or more SQL queries with the provided parameters and fetch types.
Expand All @@ -134,32 +134,27 @@ def _execute_query(
or only after all of them are executed.
Returns:
Union[Optional[List], List[Optional[List]]]:
- If a single query is executed:
- Returns a single result if which_fetch is "one"
- Returns a single list of rows if which_fetch is "all".
- Else, returns None.
- If multiple queries are executed:
- Returns a list of results, where each result is:
List[Optional[Any]]:
- When single or multiple queries are executed:
- Returns a list of results, where each "result" is:
- A single result if which_fetch is "one".
- A list of rows if which_fetch is "all".
- A list of results if which_fetch is "all".
- Else, returns None.
Note: the "single result" here could be a tuple if more than 1 column is selected in the query.
Note: the word "result" means a row in the DB,
which could be a tuple if more than 1 column is selected in the query.
Raises:
ValueError: If an invalid fetch type is provided.
"""
# If query is a single string, we assume that params and which_fetch are also non-list values
if isinstance(query, str):
requested_single_query = True
query = [query]
params = [params]
which_fetch = [which_fetch]
# else, we assume that params and which_fetch are lists of the same length
# and do a list-conversion just in case they are strings
else:
requested_single_query = False
if isinstance(params, str):
params = [params] * len(query)
if isinstance(which_fetch, str):
Expand All @@ -180,7 +175,7 @@ def _execute_query(
result = cur.fetchall()

# Remove possible SQL comments at the start of the q variable
q = re.sub(r"^\s*--.*\n", "", q, flags=re.MULTILINE)
q = re.sub(r"^\s*--.*\n|^\s*---.*\n", "", q, flags=re.MULTILINE)

if not q.strip().lower().startswith("select") and commit_after.lower() == "each":
conn.commit()
Expand All @@ -190,16 +185,15 @@ def _execute_query(
if commit_after.lower() == "all":
conn.commit()

# If only 1 query was to be executed, return it (or None if it was a non-fetch query)
if requested_single_query:
return results[0]
# Else, multiple queries were executed, so return all results
# Return a list when 1 or more queries are executed \
# (or a list of a single None if it was a non-fetch query)
return results

def _validate_token_in_db(self, user_id: str, token: str, table: str) -> bool:
try:
select_cmd = f"SELECT user_id FROM {table} WHERE user_id = %s AND token = %s;"
result = self._execute_query(select_cmd, (user_id, token), "one")
# Note: the "[0]" is added here because `select_cmd` is not a list
result = self._execute_query(select_cmd, (user_id, token), "one")[0]
return result is not None
except Exception:
logger.exception("Database error during token validation")
Expand Down Expand Up @@ -252,10 +246,10 @@ def register(self, email, first_name, last_name, password_hash):
logger.warning(f"Error is {e}")
return {"status": "failure", "error": str(e)}

def account_exists(self, email: str, phone_num: str = ""):
def account_exists(self, email):
try:
select_cmd = """SELECT id FROM users WHERE email = %s;"""
result = self._execute_query(select_cmd, (email,), "one")
result = self._execute_query(select_cmd, (email,), "one")[0]
return result is not None
except Exception as e:
logger.warning(f"Error is {e}")
Expand All @@ -264,7 +258,7 @@ def account_exists(self, email: str, phone_num: str = ""):
def save_access_token(self, user_id, token):
try:
insert_cmd = "INSERT INTO access_tokens (user_id, token) VALUES (%s, %s) RETURNING id;"
result = self._execute_query(insert_cmd, (user_id, token), "one")
result = self._execute_query(insert_cmd, (user_id, token), "one")[0]
inserted_id = result[0] if result else None
return {
"status": "success",
Expand Down Expand Up @@ -299,7 +293,7 @@ def save_reset_token(self, user_id, token):
def retrieve_user_info(self, email):
try:
select_cmd = "SELECT id, password_hash, first_name, last_name FROM users WHERE email = %s;"
result = self._execute_query(select_cmd, (email,), "one")
result = self._execute_query(select_cmd, (email,), "one")[0]
if result:
user_id, existing_hash, first_name, last_name = result
return user_id, existing_hash, first_name, last_name
Expand All @@ -322,7 +316,7 @@ def add_feedback(self, user_id, thread_id, message_id, feedback_class, comment):
def create_thread(self, user_id):
try:
insert_cmd = """INSERT INTO threads (user_id) values (%s) RETURNING id;"""
result = self._execute_query(insert_cmd, (user_id,), "one")
result = self._execute_query(insert_cmd, (user_id,), "one")[0]
inserted_id = result[0] if result else None
return {"status": "success", "thread_id": inserted_id}
except Exception as e:
Expand All @@ -332,7 +326,7 @@ def create_thread(self, user_id):
def get_all_threads(self, user_id):
try:
select_cmd = """SELECT id, name, updated_at FROM threads WHERE user_id = %s;"""
result = self._execute_query(select_cmd, (user_id,), "all")
result = self._execute_query(select_cmd, (user_id,), "all")[0]
return [{"thread_id": x[0], "thread_name": x[1], "updated_at": x[2]} for x in result] if result else []
except Exception as e:
logger.warning(f"Error is {e}")
Expand Down Expand Up @@ -389,6 +383,7 @@ def get_thread(self, thread_id, user_id):
select_cmd_2 = "SELECT name FROM threads WHERE id = %s AND user_id = %s;"
params = (thread_id, user_id)

# Note: we don't add "[0]" here since the first arg. below is a list
result, thread_name_result = self._execute_query([select_cmd_1, select_cmd_2], [params, params], ["all", "one"])

if not thread_name_result:
Expand Down Expand Up @@ -451,7 +446,7 @@ def snapshot_thread(self, thread_id, user_id):
# Now we create a new thread
insert_cmd = """INSERT INTO share (content) values (%s) RETURNING id;"""
thread_as_json = json.dumps(thread)
result = self._execute_query(insert_cmd, (thread_as_json,), "one")
result = self._execute_query(insert_cmd, (thread_as_json,), "one")[0]
logger.info(f"Result is {result}")
return result[0] if result else None
except Exception as e:
Expand All @@ -462,7 +457,7 @@ def get_snapshot(self, share_uuid):
"""Retrieve a snapshot of a thread."""
try:
select_cmd = """SELECT content FROM share WHERE id = %s;"""
result = self._execute_query(select_cmd, (share_uuid,), "one")
result = self._execute_query(select_cmd, (share_uuid,), "one")[0]
if result:
# Deserialize json string
return json.loads(result[0])
Expand Down Expand Up @@ -499,7 +494,7 @@ def delete_access_refresh_tokens_pair(self, refresh_token):
try:
# Retrieve the associated access_token_id
select_cmd = """SELECT access_token_id FROM refresh_tokens WHERE token = %s;"""
result = self._execute_query(select_cmd, (refresh_token,), "one")
result = self._execute_query(select_cmd, (refresh_token,), "one")[0]
if result is None:
raise HTTPException(
status_code=401,
Expand Down Expand Up @@ -544,7 +539,7 @@ def set_pref(self, user_id, key, value):

def get_prefs(self, user_id):
select_cmd = """SELECT pref_key, pref_value FROM preferences WHERE user_id = %s;"""
result = self._execute_query(select_cmd, (user_id,), "all")
result = self._execute_query(select_cmd, (user_id,), "all")[0]
retval = {}
for x in result:
retval[x[0]] = x[1]
Expand Down Expand Up @@ -605,7 +600,7 @@ def get_quran_answer(
ORDER BY created_at DESC, id DESC
LIMIT 1;
"""
result = self._execute_query(select_cmd, (surah, ayah, question), "one")
result = self._execute_query(select_cmd, (surah, ayah, question), "one")[0]
if result:
return result[0]
return None
Expand Down
21 changes: 11 additions & 10 deletions src/ansari/ansari_logger.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,27 @@
import copy
import os
import sys

from loguru import logger
from loguru._logger import Logger

from ansari.config import get_settings


# Using loguru for logging, check below resources for reasons/details:
# https://nikhilakki.in/loguru-logging-in-python-made-fun-and-easy#heading-why-use-loguru-over-the-std-logging-module
# https://loguru.readthedocs.io/en/stable/resources/migration.html
# https://loguru.readthedocs.io/en/stable/resources/recipes.html#creating-independent-loggers-with-separate-set-of-handlers


def get_logger(
logging_level: str = None,
debug_mode: bool = None,
) -> Logger:
"""Creates and returns a logger instance for the specified caller file.
Args:
caller_file_name (str): The name of the file requesting the logger.
logging_level (Optional[str]): The logging level to be set for the logger.
If None, it defaults to the LOGGING_LEVEL from settings.
debug_mode (Optional[bool]): If True, adds a console handler to the logger.
If None, it defaults to the DEBUG_MODE from settings.
Returns:
logger: Configured logger instance.
Expand All @@ -37,11 +38,11 @@ def get_logger(
)

logger.remove()
logger.add(
sys.stdout,
level=logging_level,
format=log_format,
enqueue=True,
cur_logger = copy.deepcopy(logger)

# In colorize, If None, the choice is automatically made based on the sink being a tty or not.
cur_logger.add(
sys.stdout, level=logging_level, format=log_format, enqueue=True, colorize=os.getenv("GITHUB_ACTIONS", None)
)

return logger
return cur_logger
2 changes: 1 addition & 1 deletion src/ansari/app/main_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,7 +694,7 @@ async def complete(request: Request, cors_ok: bool = Depends(validate_cors)):
if not cors_ok:
raise HTTPException(status_code=403, detail="CORS not permitted")

logger.info(f"Raw request is {request.headers}")
logger.debug(f"Raw request is {request.headers}")
body = await request.json()
logger.info(f"Request received > {body}.")
return presenter.complete(body)
Expand Down
2 changes: 1 addition & 1 deletion src/ansari/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pydantic import DirectoryPath, Field, PostgresDsn, SecretStr, field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict

# # Can't use get_logger() here due to circular import
# Can't use get_logger() here due to circular import
# logger = get_logger()


Expand Down
2 changes: 1 addition & 1 deletion src/ansari/util/fastapi_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# Defined in a separate file to avoid circular imports between main_*.py files
def validate_cors(request: Request, settings: Settings = Depends(get_settings)) -> bool:
try:
logger.info(f"Headers of raw request are: {request.headers}")
logger.debug(f"Headers of raw request are: {request.headers}")
origins = get_settings().ORIGINS
incoming_origin = [
request.headers.get("origin", ""), # If coming from ansari's frontend website
Expand Down

0 comments on commit 268618d

Please sign in to comment.