Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add events util #1480

Merged
merged 2 commits into from
Dec 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Support for `BranchTask` in `StructureVisualizer`.
- `EvalEngine` for evaluating the performance of an LLM's output against a given input.
- `BaseFileLoader.save()` method for saving an Artifact to a destination.
- `Structure.run_stream()` for streaming Events from a Structure as an iterator.

### Changed

Expand Down
11 changes: 11 additions & 0 deletions docs/griptape-framework/misc/events.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,17 @@ Handler 1 <class 'griptape.events.finish_structure_run_event.FinishStructureRunE
Handler 2 <class 'griptape.events.finish_structure_run_event.FinishStructureRunEvent'>
```

## Stream Iterator

You can use `Structure.run_stream()` for streaming Events from the `Structure` in the form of an iterator.
Comment on lines +77 to +79
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would prefer a EventStream util like Stream that currently exists and punt more core API changes to the Structure api to a later time, but not going to block the PR based on that.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The downside of this approach is that EventStream would need to support anything that might need this functionality (Structures, Tasks, Drivers, Tools, etc).

I'm just as nervous about modifying Structure's API, but this implementation does at least seem consistent with other similar frameworks.


!!! tip
Set `stream=True` on your [Prompt Driver](../drivers/prompt-drivers.md) in order to receive completion chunk events.

```python
--8<-- "docs/griptape-framework/misc/src/events_streaming.py"
```

## Context Managers

You can also use [EventListener](../../reference/griptape/events/event_listener.md)s as a Python Context Manager.
Expand Down
7 changes: 7 additions & 0 deletions docs/griptape-framework/misc/src/events_streaming.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from griptape.events import BaseEvent
from griptape.structures import Agent

agent = Agent()

for event in agent.run_stream("Hi!", event_types=[BaseEvent]): # All Events
print(type(event))
28 changes: 28 additions & 0 deletions griptape/structures/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,27 @@

import uuid
from abc import ABC, abstractmethod
from queue import Queue
from threading import Thread
from typing import TYPE_CHECKING, Any, Literal, Optional, Union

from attrs import Factory, define, field

from griptape.common import observable
from griptape.events import EventBus, FinishStructureRunEvent, StartStructureRunEvent
from griptape.events.base_event import BaseEvent
from griptape.events.event_listener import EventListener
from griptape.memory import TaskMemory
from griptape.memory.meta import MetaMemory
from griptape.memory.structure import ConversationMemory, Run
from griptape.mixins.rule_mixin import RuleMixin
from griptape.mixins.runnable_mixin import RunnableMixin
from griptape.mixins.serializable_mixin import SerializableMixin
from griptape.utils.contextvars_utils import with_contextvars

if TYPE_CHECKING:
from collections.abc import Iterator

from griptape.artifacts import BaseArtifact
from griptape.memory.structure import BaseConversationMemory
from griptape.tasks import BaseTask
Expand All @@ -42,6 +49,7 @@ class Structure(RuleMixin, SerializableMixin, RunnableMixin["Structure"], ABC):
meta_memory: MetaMemory = field(default=Factory(lambda: MetaMemory()), kw_only=True)
fail_fast: bool = field(default=True, kw_only=True, metadata={"serializable": True})
_execution_args: tuple = ()
_event_queue: Queue[BaseEvent] = field(default=Factory(lambda: Queue()), init=False)

def __attrs_post_init__(self) -> None:
tasks = self._tasks.copy()
Expand Down Expand Up @@ -198,5 +206,25 @@ def run(self, *args) -> Structure:

return result

@observable
def run_stream(self, *args, event_types: Optional[list[type[BaseEvent]]] = None) -> Iterator[BaseEvent]:
if event_types is None:
event_types = [BaseEvent]
else:
if FinishStructureRunEvent not in event_types:
event_types = [*event_types, FinishStructureRunEvent]

with EventListener(self._event_queue.put, event_types=event_types):
t = Thread(target=with_contextvars(self.run), args=args)
t.start()

while True:
event = self._event_queue.get()
if isinstance(event, FinishStructureRunEvent):
break
else:
yield event
t.join()

@abstractmethod
def try_run(self, *args) -> Structure: ...
43 changes: 6 additions & 37 deletions griptape/utils/stream.py
Original file line number Diff line number Diff line change
@@ -1,56 +1,40 @@
from __future__ import annotations

import json
from queue import Queue
from threading import Thread
from typing import TYPE_CHECKING

from attrs import Factory, define, field
from attrs import define, field

from griptape.artifacts.text_artifact import TextArtifact
from griptape.events import (
ActionChunkEvent,
BaseChunkEvent,
EventBus,
EventListener,
FinishPromptEvent,
FinishStructureRunEvent,
TextChunkEvent,
)
from griptape.utils.contextvars_utils import with_contextvars

if TYPE_CHECKING:
from collections.abc import Iterator

from griptape.events.base_event import BaseEvent
from griptape.structures import Structure


@define
class Stream:
"""A wrapper for Structures that converts `BaseChunkEvent`s into an iterator of TextArtifacts.
It achieves this by running the Structure in a separate thread, listening for events from the Structure,
and yielding those events.
See relevant Stack Overflow post: https://stackoverflow.com/questions/9968592/turn-functions-with-a-callback-into-python-generators
"""A wrapper for Structures filters Events relevant to text output and converts them to TextArtifacts.
Attributes:
structure: The Structure to wrap.
_event_queue: A queue to hold events from the Structure.
"""

structure: Structure = field()

_event_queue: Queue[BaseEvent] = field(default=Factory(lambda: Queue()))

def run(self, *args) -> Iterator[TextArtifact]:
t = Thread(target=with_contextvars(self._run_structure), args=args)
t.start()

action_str = ""
while True:
event = self._event_queue.get()

for event in self.structure.run_stream(
*args, event_types=[TextChunkEvent, ActionChunkEvent, FinishPromptEvent, FinishStructureRunEvent]
):
if isinstance(event, FinishStructureRunEvent):
break
elif isinstance(event, FinishPromptEvent):
Expand All @@ -67,18 +51,3 @@ def run(self, *args) -> Iterator[TextArtifact]:
action_str = ""
except Exception:
pass
t.join()

def _run_structure(self, *args) -> None:
def event_handler(event: BaseEvent) -> None:
self._event_queue.put(event)

stream_event_listener = EventListener(
on_event=event_handler,
event_types=[BaseChunkEvent, FinishPromptEvent, FinishStructureRunEvent],
)
EventBus.add_event_listener(stream_event_listener)

self.structure.run(*args)

EventBus.remove_event_listener(stream_event_listener)
37 changes: 37 additions & 0 deletions tests/unit/structures/test_structure.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest

from griptape.events import FinishStructureRunEvent, FinishTaskEvent, StartTaskEvent
from griptape.structures import Agent, Pipeline
from griptape.tasks import PromptTask
from tests.mocks.mock_prompt_driver import MockPromptDriver
Expand Down Expand Up @@ -130,3 +131,39 @@ def test_from_dict(self):

assert len(deserialized_agent.task_outputs) == 1
assert deserialized_agent.task_outputs[task.id].value == "mock output"

def test_run_stream(self):
from griptape.events import (
EventBus,
FinishPromptEvent,
FinishStructureRunEvent,
StartPromptEvent,
StartStructureRunEvent,
)

agent = Agent()
event_types = [
StartStructureRunEvent,
StartTaskEvent,
StartPromptEvent,
FinishPromptEvent,
FinishTaskEvent,
FinishStructureRunEvent,
]
events = agent.run_stream()

for idx, event in enumerate(events):
assert isinstance(event, event_types[idx])
assert len(EventBus.event_listeners) == 0

def test_run_stream_custom_event_types(self):
from griptape.events import EventBus, FinishPromptEvent, StartPromptEvent, StartStructureRunEvent

agent = Agent()
event_types = [StartStructureRunEvent, StartPromptEvent, FinishPromptEvent]
expected_event_types = [StartStructureRunEvent, StartPromptEvent, FinishPromptEvent, FinishStructureRunEvent]
events = agent.run_stream(event_types=event_types)

for idx, event in enumerate(events):
assert isinstance(event, expected_event_types[idx])
assert len(EventBus.event_listeners) == 0
2 changes: 1 addition & 1 deletion tests/unit/utils/test_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,6 @@ def test_init(self, agent):
with pytest.raises(StopIteration):
next(chat_stream_run)
else:
next(chat_stream.run())
assert next(chat_stream.run()).value == "\n"
with pytest.raises(StopIteration):
next(chat_stream.run())
Loading