style(src): run ruff formatter
abdullah-alnahas committed Nov 15, 2024
1 parent 42e142e commit db9de8e
Showing 23 changed files with 245 additions and 265 deletions.
7 changes: 4 additions & 3 deletions data/mawsuah/strip_tashkeel.py
@@ -1,7 +1,7 @@
 from pathlib import Path, PurePath
 
-import pyarabic.araby as araby
 import textract
+from pyarabic import araby
 from tqdm.auto import tqdm
 
 from ansari.ansari_logger import get_logger
@@ -22,14 +22,15 @@ def strip_tashkeel_from_doc(input_file, output_file):
 path_components = list(input_dir.parts)
 path_components[-1] = "txt"
 output_dir = PurePath(
-    *path_components
+    *path_components,
 )  # --> "/path/to/The Kuwaiti Encyclopaedia of Islamic Jurisprudence/txt"
 
 # iterate over all files in the directory
 for input_file in tqdm(input_dir.glob("*.doc")):
     if input_file.is_file() and input_file.suffix == ".doc":
         logger.info(f"Processing {input_file.name}...")
         strip_tashkeel_from_doc(
-            input_file, output_dir.joinpath(input_file.with_suffix(".txt").name)
+            input_file,
+            output_dir.joinpath(input_file.with_suffix(".txt").name),
         )
         logger.info(f"Done processing {input_file.name}")
2 changes: 1 addition & 1 deletion setup_database.py
@@ -27,7 +27,7 @@ def import_sql_files(directory, db_url):
         logger.info(f"Importing: {file_path}")
 
         # Read the SQL file
-        with open(file_path, "r") as f:
+        with open(file_path) as f:
             sql_query = f.read()
         try:
             # Execute the SQL query
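
The only change here drops the redundant mode argument: "r" (read, text mode) is already open()'s default, so both forms behave identically. This matches ruff's redundant-open-mode cleanup (rule UP015). A minimal illustration with a hypothetical file name:

# Equivalent calls; the explicit "r" adds no information.
with open("schema.sql") as f:       # preferred after the cleanup
    sql_query = f.read()

with open("schema.sql", "r") as f:  # redundant mode argument
    sql_query = f.read()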
3 changes: 2 additions & 1 deletion src/ansari/__init__.py
@@ -1,4 +1,5 @@
 # This file marks the directory as a Python package.
 from .config import Settings, get_settings
+import ansari_logger
 
-__all__ = ["Settings", "get_settings"]
+__all__ = ["Settings", "get_settings", "ansari_logger"]
54 changes: 19 additions & 35 deletions src/ansari/agents/ansari.py
@@ -4,7 +4,6 @@
 import time
 import traceback
 from datetime import date, datetime
-from typing import Union
 
 import litellm
 from langfuse.decorators import langfuse_context, observe
@@ -41,9 +40,7 @@ def __init__(self, settings, message_logger=None, json_format=False):
         self.model = settings.MODEL
         self.pm = PromptMgr(src_dir=settings.PROMPT_PATH)
         self.sys_msg = self.pm.bind(settings.SYSTEM_PROMPT_FILE_NAME).render()
-        self.tools = [
-            x.get_tool_description() for x in self.tool_name_to_instance.values()
-        ]
+        self.tools = [x.get_tool_description() for x in self.tool_name_to_instance.values()]
         self.message_history = [{"role": "system", "content": self.sys_msg}]
 
     def set_message_logger(self, message_logger):
@@ -73,7 +70,7 @@ def log(self):
     @observe()
     def replace_message_history(self, message_history, use_tool=True, stream=True):
         self.message_history = [
-            {"role": "system", "content": self.sys_msg}
+            {"role": "system", "content": self.sys_msg},
         ] + message_history
         logger.info(f"Original trace is {self.message_logger.trace_id}")
         logger.info(f"Id 1 is {langfuse_context.get_current_trace_id()}")
@@ -101,25 +98,19 @@ def process_message_history(self, use_tool=True, stream=True):
         self.start_time = datetime.now()
         count = 0
         failures = 0
-        while (
-            self.message_history[-1]["role"] != "assistant"
-            or "tool_call_id" in self.message_history[-1]
-        ):
+        while self.message_history[-1]["role"] != "assistant" or "tool_call_id" in self.message_history[-1]:
             try:
                 logger.info(
                     f"Process attempt #{count+failures+1} of this message history:\n"
                     + "-" * 60
                     + f"\n{self.message_history}\n"
-                    + "-" * 60
+                    + "-" * 60,
                 )
                 # This is pretty complicated so leaving a comment.
                 # We want to yield from so that we can send the sequence through the input
-                # Also use tools only if we haven't tried too many times (failure) and if the last message was not from the tool (success!)
-                use_tool = (
-                    use_tool
-                    and (count < self.settings.MAX_TOOL_TRIES)
-                    and self.message_history[-1]["role"] != "tool"
-                )
+                # Also use tools only if we haven't tried too many times (failure)
+                # and if the last message was not from the tool (success!)
+                use_tool = use_tool and (count < self.settings.MAX_TOOL_TRIES) and self.message_history[-1]["role"] != "tool"
                 if not use_tool:
                     status_msg = (
                         "Not using tools -- tries exceeded"
@@ -132,7 +123,7 @@ def process_message_history(self, use_tool=True, stream=True):
             except Exception as e:
                 failures += 1
                 logger.warning(
-                    f"Exception occurred in process_message_history: \n{e}\n"
+                    f"Exception occurred in process_message_history: \n{e}\n",
                 )
                 logger.warning(traceback.format_exc())
                 logger.warning("Retrying in 5 seconds...")
@@ -165,21 +156,15 @@ def process_one_round(self, use_tool=True, stream=True):
         try:
             params = {
                 **common_params,
-                **(
-                    {"tools": self.tools, "tool_choice": "auto"} if use_tool else {}
-                ),
-                **(
-                    {"response_format": {"type": "json_object"}}
-                    if self.json_format
-                    else {}
-                ),
+                **({"tools": self.tools, "tool_choice": "auto"} if use_tool else {}),
+                **({"response_format": {"type": "json_object"}} if self.json_format else {}),
             }
             response = self.get_completion(**params)
 
         except Exception as e:
             failures += 1
             logger.warning(
-                f"Exception occurred in process_one_round function: \n{e}\n"
+                f"Exception occurred in process_one_round function: \n{e}\n",
             )
             logger.warning(traceback.format_exc())
             logger.warning("Retrying in 5 seconds...")
@@ -211,7 +196,7 @@ def process_one_round(self, use_tool=True, stream=True):
                             "id": "",
                             "type": "function",
                             "function": {"name": "", "arguments": ""},
-                        }
+                        },
                     )
                 tc = tool_calls[tcchunk.index]
 
@@ -225,7 +210,8 @@ def process_one_round(self, use_tool=True, stream=True):
         if response_mode == "words":
            self.message_history.append({"role": "assistant", "content": words})
            langfuse_context.update_current_observation(
-                output=words, metadata={"delta": delta}
+                output=words,
+                metadata={"delta": delta},
            )
            if self.message_logger:
                self.message_logger.log("assistant", words)
@@ -245,7 +231,7 @@ def process_one_round(self, use_tool=True, stream=True):
                     )
                 except json.JSONDecodeError:
                     logger.warning(
-                        f"Failed to parse tool arguments: {tc['function']['arguments']}"
+                        f"Failed to parse tool arguments: {tc['function']['arguments']}",
                     )
 
             else:
@@ -264,9 +250,7 @@ def process_tool_call(
             return
 
         query: str = tool_arguments["query"]
-        tool_instance: Union[SearchQuran, SearchHadith] = self.tool_name_to_instance[
-            tool_name
-        ]
+        tool_instance: SearchQuran | SearchHadith = self.tool_name_to_instance[tool_name]
         results = tool_instance.run_as_list(query)
 
         # we have to first add this message before any tool response, as mentioned in this source:
@@ -276,9 +260,9 @@ def process_tool_call(
                 "role": "assistant",
                 "content": "",
                 "tool_calls": [
-                    {"type": "function", "id": tool_id, "function": tool_definition}
+                    {"type": "function", "id": tool_id, "function": tool_definition},
                 ],
-            }
+            },
         )
 
         if len(results) == 0:
@@ -295,5 +279,5 @@ def process_tool_call(
         # Now we have to pass the results back in
         results_str = msg_prefix + "\nAnother relevant ayah:\n".join(results)
         self.message_history.append(
-            {"role": "tool", "content": results_str, "tool_call_id": tool_id}
+            {"role": "tool", "content": results_str, "tool_call_id": tool_id},
        )
19 changes: 10 additions & 9 deletions src/ansari/agents/ansari_workflow.py
@@ -21,8 +21,7 @@
 
 
 class AnsariWorkflow:
-    """
-    AnsariWorkflow manages the execution of modular workflow steps for processing user queries.
+    """AnsariWorkflow manages the execution of modular workflow steps for processing user queries.
 
     This class provides a flexible framework for generating queries, performing searches,
     and generating answers based on the results. It supports customizable workflow steps
@@ -50,6 +49,7 @@ class AnsariWorkflow:
         ("gen_answer", {"input": "How is zakat calculated?", "search_results_indices": [0]})
     ]
     results = ansari.execute_workflow(workflow_steps)
+    """
 
     def __init__(self, settings, message_logger=None, json_format=False):
@@ -81,9 +81,7 @@ def __init__(self, settings, message_logger=None, json_format=False):
         self.model = settings.MODEL
         self.pm = PromptMgr()
         self.sys_msg = self.pm.bind(settings.SYSTEM_PROMPT_FILE_NAME).render()
-        self.tools = [
-            x.get_tool_description() for x in self.tool_name_to_instance.values()
-        ]
+        self.tools = [x.get_tool_description() for x in self.tool_name_to_instance.values()]
         self.json_format = json_format
         self.message_logger = message_logger
 
@@ -107,7 +105,8 @@ def _execute_search_step(self, step_params, prev_outputs):
         tool = self.tool_name_to_instance[step_params["tool_name"]]
         if "query" in step_params:
             results = tool.run_as_string(
-                step_params["query"], metadata_filter=step_params.get("metadata_filter")
+                step_params["query"],
+                metadata_filter=step_params.get("metadata_filter"),
             )
         elif "query_from_prev_output_index" in step_params:
             results = tool.run_as_string(
@@ -116,12 +115,14 @@
             )
         else:
             raise ValueError(
-                "search step must have either query or query_from_prev_output_index"
+                "search step must have either query or query_from_prev_output_index",
             )
         return results
 
     def _execute_gen_query_step(self, step_params, prev_outputs):
-        prompt = f"""Generate 3-5 key terms or phrases for searching the input: '{step_params["input"]}' in the '{step_params["target_corpus"]}' corpus. These search terms should be:
+        prompt = f"""Generate 3-5 key terms or phrases for searching the input:
+        '{step_params["input"]}' in the '{step_params["target_corpus"]}' corpus.
+        These search terms should be:
         - Relevant words/phrases that appear in or closely match content in '{step_params["target_corpus"]}'
         - Usable for both keyword and semantic search
 
@@ -145,7 +146,7 @@ def _execute_gen_query_step(self, step_params, prev_outputs):
     def _execute_gen_answer_step(self, step_params, prev_outputs):
         if step_params.get("search_results_indices"):
             search_results = "\n---\n".join(
-                [prev_outputs[i] for i in step_params["search_results_indices"]]
+                [prev_outputs[i] for i in step_params["search_results_indices"]],
            )
             prompt = f"""Using {search_results}, compose a response that:
             1. Directly answers the query of the user
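
The docstring example above pairs each step name with a parameter dict, and each name corresponds to one of the _execute_*_step methods in these hunks. A minimal sketch of how execute_workflow could dispatch them, assuming each handler takes (step_params, prev_outputs) and returns that step's output (an illustration of the shape, not the repository's exact implementation):

def execute_workflow(self, workflow_steps):
    # Map step names from the docstring example to their handlers.
    handlers = {
        "gen_query": self._execute_gen_query_step,
        "search": self._execute_search_step,
        "gen_answer": self._execute_gen_answer_step,
    }
    prev_outputs = []
    for step_name, step_params in workflow_steps:
        # Later steps can reference earlier outputs, e.g. via
        # query_from_prev_output_index or search_results_indices.
        output = handlers[step_name](step_params, prev_outputs)
        prev_outputs.append(output)
    return prev_outputs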
