# The wildcard import is expected to supply this module's shared dependencies:
# the typing helpers (List, Dict, Any, Union, Iterable, TypedDict), re, time,
# ThreadPoolExecutor and wait, the LangChain message types, the Task type, and
# the as_runnable decorator.
from imports import *
import globals_  # shared module-level state


def _get_observations(messages: List[BaseMessage]) -> Dict[int, Any]:
    """Collect all previous tool responses, keyed by task index."""
    results = {}
    for message in messages[::-1]:
        if isinstance(message, FunctionMessage):
            results[int(message.additional_kwargs["idx"])] = message.content
    return results
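

# For reference (a sketch, not executed here): the messages this function
# expects are the FunctionMessages built at the bottom of this file, e.g.
#
#   FunctionMessage(name="search", content="72 and sunny",
#                   additional_kwargs={"idx": 1, "args": "weather in SF"})
#
# which _get_observations maps to {1: "72 and sunny"}.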


class SchedulerInput(TypedDict):
    messages: List[BaseMessage]
    tasks: Iterable[Task]


def _execute_task(task: Task, observations: Dict[int, Any], config):
    """Resolve a task's arguments against prior observations and invoke its tool."""
    tool_to_use = task["tool"]
    if isinstance(tool_to_use, str):
        # Placeholder tasks (e.g. the "join" step) carry a plain string
        # instead of a tool; there is nothing to invoke.
        return tool_to_use
    args = task["args"]
    try:
        if isinstance(args, str):
            resolved_args = _resolve_arg(args, observations)
        elif isinstance(args, dict):
            resolved_args = {
                key: _resolve_arg(val, observations) for key, val in args.items()
            }
        else:
            # Unexpected args type; pass it through and let the tool reject it.
            resolved_args = args
    except Exception as e:
        return (
            f"ERROR(Failed to call {tool_to_use.name} with args {args}."
            f" Args could not be resolved. Error: {repr(e)})"
        )
    try:
        return tool_to_use.invoke(resolved_args, config)
    except Exception as e:
        return (
            f"ERROR(Failed to call {tool_to_use.name} with args {args}."
            f" Args resolved to {resolved_args}. Error: {repr(e)})"
        )


def _resolve_arg(arg: Union[str, Any], observations: Dict[int, Any]):
    """Substitute $N / ${N} placeholders with the observation for task N."""
    # Matches "$1" or "${1}" and captures the task index ("1").
    ID_PATTERN = r"\$\{?(\d+)\}?"

    def replace_match(match):
        # For "${123}", match.group(0) is "${123}" and match.group(1) is "123".
        # Look up the observation for that index; if it is missing, leave the
        # placeholder untouched.
        idx = int(match.group(1))
        return str(observations.get(idx, match.group(0)))

    # For dependencies on other tasks
    if isinstance(arg, str):
        return re.sub(ID_PATTERN, replace_match, arg)
    elif isinstance(arg, list):
        return [_resolve_arg(a, observations) for a in arg]
    else:
        return str(arg)
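

# Illustrative examples (a sketch, not executed here): with
# observations = {1: "Denver"}, both placeholder spellings resolve, while an
# index with no recorded observation is left verbatim:
#
#   _resolve_arg("weather in $1", {1: "Denver"})    -> "weather in Denver"
#   _resolve_arg("weather in ${1}", {1: "Denver"})  -> "weather in Denver"
#   _resolve_arg("weather in $2", {1: "Denver"})    -> "weather in $2"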


@as_runnable
def schedule_task(task_inputs, config):
    task: Task = task_inputs["task"]
    observations: Dict[int, Any] = task_inputs["observations"]
    try:
        observation = _execute_task(task, observations, config)
    except Exception:
        import traceback

        # Store the formatted traceback as the observation so downstream
        # tasks and the joiner can see what went wrong.
        observation = traceback.format_exc()
    observations[task["idx"]] = observation


def schedule_pending_task(
    task: Task, observations: Dict[int, Any], retry_after: float = 0.2
):
    """Poll until all of the task's dependencies have observations, then run it."""
    while True:
        deps = task["dependencies"]
        if deps and any(dep not in observations for dep in deps):
            # Dependencies not yet satisfied
            time.sleep(retry_after)
            continue
        schedule_task.invoke({"task": task, "observations": observations})
        break
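

# Design note: rather than building an explicit DAG executor, a pending task
# simply polls the shared observations dict every `retry_after` seconds from
# its own worker thread. For the small task graphs a planner typically emits,
# this is a simple, adequate substitute for condition-variable signaling.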


@as_runnable
def schedule_tasks(scheduler_input: SchedulerInput) -> List[FunctionMessage]:
    """Group the tasks into a DAG schedule."""
    # For streaming, we make a few simplifying assumptions:
    # 1. The LLM does not create cyclic dependencies.
    # 2. The LLM does not emit tasks that depend on tasks it has not yet
    #    generated.
    # If these stop being good assumptions, either switch to a proper
    # (non-streaming) topological sort or use a more sophisticated data
    # structure.
    tasks = scheduler_input["tasks"]
    args_for_tasks = {}
    messages = scheduler_input["messages"]
    # If we are re-planning, we may have calls that depend on previous
    # plans. Start with those observations.
    observations = _get_observations(messages)
    task_names = {}
    # Record which observation keys existed before this round so we can tell
    # which entries the new tasks add. We assume each task writes a distinct
    # key, which avoids races on the shared dict.
    originals = set(observations)
    futures = []
    retry_after = 0.25  # Retry every quarter second
    with ThreadPoolExecutor() as executor:
        for task in tasks:
            deps = task["dependencies"]
            task_names[task["idx"]] = (
                task["tool"] if isinstance(task["tool"], str) else task["tool"].name
            )
            args_for_tasks[task["idx"]] = task["args"]
            if deps and any(dep not in observations for dep in deps):
                # Depends on tasks that have not finished yet: poll for the
                # dependencies in a worker thread.
                futures.append(
                    executor.submit(
                        schedule_pending_task, task, observations, retry_after
                    )
                )
            else:
                # No deps, or all deps already satisfied: run it now.
                schedule_task.invoke(dict(task=task, observations=observations))
        # All tasks have been submitted or enqueued; wait for them to finish.
        wait(futures)
    # Convert the new observations into tool messages to add to the state.
    new_observations = {
        k: (task_names[k], args_for_tasks[k], observations[k])
        for k in sorted(observations.keys() - originals)
    }
    tool_messages = [
        FunctionMessage(
            name=name,
            content=str(obs),
            additional_kwargs={"idx": k, "args": task_args},
            tool_call_id=k,
        )
        for k, (name, task_args, obs) in new_observations.items()
    ]
    return tool_messages
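

# A minimal usage sketch (an illustration under assumptions, not part of the
# module): it presumes Task is a TypedDict with "idx", "tool", "args", and
# "dependencies" keys, and that `search` is some LangChain tool defined
# elsewhere. Task 2 references task 1's output via the "$1" placeholder.
#
#   tasks = [
#       {"idx": 1, "tool": search, "args": "weather in SF", "dependencies": []},
#       {"idx": 2, "tool": search, "args": "events near $1", "dependencies": [1]},
#   ]
#   tool_messages = schedule_tasks.invoke({"messages": [], "tasks": tasks})
#   # -> one FunctionMessage per completed task, ready to append to the state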