Skip to content

Commit

Permalink
Fixed issue with return in async generators
Browse files Browse the repository at this point in the history
  • Loading branch information
AryazE committed Aug 21, 2024
1 parent d5c821f commit c3cb741
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 2 deletions.
17 changes: 15 additions & 2 deletions src/dynapyt/instrument/CodeInstrumenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1217,7 +1217,15 @@ def visit_FunctionDef(self, node: cst.FunctionDef):
params = node.params.params[0].name
elif len(node.params.posonly_params) > 0:
params = node.params.posonly_params[0].name
self.current_function.append({"params": params, "name": node.name, "iid": iid})
self.current_function.append(
{
"params": params,
"name": node.name,
"iid": iid,
"is_async": (node.asynchronous is not None),
"is_generator": (len(m.findall(node, m.Yield())) > 0),
}
)

def leave_FunctionDef(
self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
Expand Down Expand Up @@ -1311,6 +1319,11 @@ def leave_Lambda(self, original_node, updated_node):
return updated_node.with_changes(body=new_stmt)

def leave_Return(self, original_node, updated_node):
if (
self.current_function[-1]["is_generator"]
and self.current_function[-1]["is_async"]
):
return original_node
if "_return" not in self.selected_hooks:
return updated_node
callee_name = cst.Attribute(
Expand Down Expand Up @@ -1342,7 +1355,7 @@ def leave_Return(self, original_node, updated_node):

def leave_Yield(self, original_node, updated_node):
function_metadata = self.current_function[-1]
if "yield" not in self.selected_hooks:
if "_yield" not in self.selected_hooks:
return updated_node
callee_name = cst.Attribute(
value=cst.Name(value="_rt"), attr=cst.Name(value="_yield_")
Expand Down
12 changes: 12 additions & 0 deletions tests/regression/async_return/analysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from dynapyt.analyses.BaseAnalysis import BaseAnalysis


class TestAnalysis(BaseAnalysis):
def begin_execution(self):
print("begin execution")

def function_exit(self, dyn_ast, iid, name, result):
print(f"function {name} exited with result {result}")

def end_execution(self):
print("end execution")
9 changes: 9 additions & 0 deletions tests/regression/async_return/expected.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
begin execution
function foo exited with result 0
0
function foo exited with result 1
1
function foo exited with result 2
2
function main exited with result None
end execution
14 changes: 14 additions & 0 deletions tests/regression/async_return/program.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
async def foo():
for i in range(3):
yield i
return


async def main():
async for i in foo():
print(i)


import asyncio

asyncio.run(main())

0 comments on commit c3cb741

Please sign in to comment.