Skip to content

Commit

Permalink
Add Structure.run_stream() method
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Dec 20, 2024
1 parent 2cc47a7 commit 7301a5a
Show file tree
Hide file tree
Showing 8 changed files with 86 additions and 107 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Support for `BranchTask` in `StructureVisualizer`.
- `EvalEngine` for evaluating the performance of an LLM's output against a given input.
- `BaseFileLoader.save()` method for saving an Artifact to a destination.
- `Structure.run_stream()` for streaming Events from a Structure as an iterator.

### Changed

Expand Down
8 changes: 8 additions & 0 deletions docs/griptape-framework/misc/events.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,14 @@ Handler 1 <class 'griptape.events.finish_structure_run_event.FinishStructureRunE
Handler 2 <class 'griptape.events.finish_structure_run_event.FinishStructureRunEvent'>
```

## Stream Iterator

You can use `Structure.run_stream()` for streaming Events from the `Structure` in the form of an iterator.

```python
--8<-- "docs/griptape-framework/misc/src/events_streaming.py"
```

## Context Managers

You can also use [EventListener](../../reference/griptape/events/event_listener.md)s as a Python Context Manager.
Expand Down
7 changes: 7 additions & 0 deletions docs/griptape-framework/misc/src/events_streaming.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from griptape.events import BaseEvent
from griptape.structures import Agent

agent = Agent(stream=True)

for event in agent.run_stream("Hi!", event_types=[BaseEvent]): # All Events
print(type(event))
27 changes: 27 additions & 0 deletions griptape/structures/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,27 @@

import uuid
from abc import ABC, abstractmethod
from queue import Queue
from threading import Thread
from typing import TYPE_CHECKING, Any, Literal, Optional, Union

from attrs import Factory, define, field

from griptape.common import observable
from griptape.events import EventBus, FinishStructureRunEvent, StartStructureRunEvent
from griptape.events.base_event import BaseEvent
from griptape.events.event_listener import EventListener
from griptape.memory import TaskMemory
from griptape.memory.meta import MetaMemory
from griptape.memory.structure import ConversationMemory, Run
from griptape.mixins.rule_mixin import RuleMixin
from griptape.mixins.runnable_mixin import RunnableMixin
from griptape.mixins.serializable_mixin import SerializableMixin
from griptape.utils.contextvars_utils import with_contextvars

if TYPE_CHECKING:
from collections.abc import Iterator

from griptape.artifacts import BaseArtifact
from griptape.memory.structure import BaseConversationMemory
from griptape.tasks import BaseTask
Expand All @@ -42,6 +49,7 @@ class Structure(RuleMixin, SerializableMixin, RunnableMixin["Structure"], ABC):
meta_memory: MetaMemory = field(default=Factory(lambda: MetaMemory()), kw_only=True)
fail_fast: bool = field(default=True, kw_only=True, metadata={"serializable": True})
_execution_args: tuple = ()
_event_queue: Queue[BaseEvent] = field(default=Factory(lambda: Queue()), init=False)

def __attrs_post_init__(self) -> None:
tasks = self._tasks.copy()
Expand Down Expand Up @@ -198,5 +206,24 @@ def run(self, *args) -> Structure:

return result

@observable
def run_stream(self, *args, event_types: Optional[list[type[BaseEvent]]] = None) -> Iterator[BaseEvent]:
if event_types is None:
event_types = [BaseEvent]
else:
if FinishStructureRunEvent not in event_types:
event_types = [*event_types, FinishStructureRunEvent]

with EventListener(self._event_queue.put, event_types=event_types):
t = Thread(target=with_contextvars(self.run), args=args)
t.start()

while True:
event = self._event_queue.get()
yield event
if isinstance(event, FinishStructureRunEvent):
break
t.join()

@abstractmethod
def try_run(self, *args) -> Structure: ...
2 changes: 0 additions & 2 deletions griptape/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from .reference_utils import references_from_artifacts
from .file_utils import get_mime_type
from .contextvars_utils import with_contextvars
from .events import Events


def minify_json(value: str) -> str:
Expand Down Expand Up @@ -50,5 +49,4 @@ def minify_json(value: str) -> str:
"references_from_artifacts",
"get_mime_type",
"with_contextvars",
"Events",
]
68 changes: 0 additions & 68 deletions griptape/utils/events.py

This file was deleted.

43 changes: 6 additions & 37 deletions griptape/utils/stream.py
Original file line number Diff line number Diff line change
@@ -1,56 +1,40 @@
from __future__ import annotations

import json
from queue import Queue
from threading import Thread
from typing import TYPE_CHECKING

from attrs import Factory, define, field
from attrs import define, field

from griptape.artifacts.text_artifact import TextArtifact
from griptape.events import (
ActionChunkEvent,
BaseChunkEvent,
EventBus,
EventListener,
FinishPromptEvent,
FinishStructureRunEvent,
TextChunkEvent,
)
from griptape.utils.contextvars_utils import with_contextvars

if TYPE_CHECKING:
from collections.abc import Iterator

from griptape.events.base_event import BaseEvent
from griptape.structures import Structure


@define
class Stream:
"""A wrapper for Structures that converts `BaseChunkEvent`s into an iterator of TextArtifacts.
It achieves this by running the Structure in a separate thread, listening for events from the Structure,
and yielding those events.
See relevant Stack Overflow post: https://stackoverflow.com/questions/9968592/turn-functions-with-a-callback-into-python-generators
"""A wrapper for Structures filters Events relevant to text output and converts them to TextArtifacts.
Attributes:
structure: The Structure to wrap.
_event_queue: A queue to hold events from the Structure.
"""

structure: Structure = field()

_event_queue: Queue[BaseEvent] = field(default=Factory(lambda: Queue()))

def run(self, *args) -> Iterator[TextArtifact]:
t = Thread(target=with_contextvars(self._run_structure), args=args)
t.start()

action_str = ""
while True:
event = self._event_queue.get()

for event in self.structure.run_stream(
*args, event_types=[TextChunkEvent, ActionChunkEvent, FinishPromptEvent, FinishStructureRunEvent]
):
if isinstance(event, FinishStructureRunEvent):
break
elif isinstance(event, FinishPromptEvent):
Expand All @@ -67,18 +51,3 @@ def run(self, *args) -> Iterator[TextArtifact]:
action_str = ""
except Exception:
pass
t.join()

def _run_structure(self, *args) -> None:
def event_handler(event: BaseEvent) -> None:
self._event_queue.put(event)

stream_event_listener = EventListener(
on_event=event_handler,
event_types=[BaseChunkEvent, FinishPromptEvent, FinishStructureRunEvent],
)
EventBus.add_event_listener(stream_event_listener)

self.structure.run(*args)

EventBus.remove_event_listener(stream_event_listener)
37 changes: 37 additions & 0 deletions tests/unit/structures/test_structure.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest

from griptape.events import FinishStructureRunEvent, FinishTaskEvent, StartTaskEvent
from griptape.structures import Agent, Pipeline
from griptape.tasks import PromptTask
from tests.mocks.mock_prompt_driver import MockPromptDriver
Expand Down Expand Up @@ -130,3 +131,39 @@ def test_from_dict(self):

assert len(deserialized_agent.task_outputs) == 1
assert deserialized_agent.task_outputs[task.id].value == "mock output"

def test_run_stream(self):
from griptape.events import (
EventBus,
FinishPromptEvent,
FinishStructureRunEvent,
StartPromptEvent,
StartStructureRunEvent,
)

agent = Agent()
event_types = [
StartStructureRunEvent,
StartTaskEvent,
StartPromptEvent,
FinishPromptEvent,
FinishTaskEvent,
FinishStructureRunEvent,
]
events = agent.run_stream()

for idx, event in enumerate(events):
assert isinstance(event, event_types[idx])
assert len(EventBus.event_listeners) == 0

def test_run_stream_custom_event_types(self):
from griptape.events import EventBus, FinishPromptEvent, StartPromptEvent, StartStructureRunEvent

agent = Agent()
event_types = [StartStructureRunEvent, StartPromptEvent, FinishPromptEvent]
expected_event_types = [StartStructureRunEvent, StartPromptEvent, FinishPromptEvent, FinishStructureRunEvent]
events = agent.run_stream(event_types=event_types)

for idx, event in enumerate(events):
assert isinstance(event, expected_event_types[idx])
assert len(EventBus.event_listeners) == 0

0 comments on commit 7301a5a

Please sign in to comment.