Skip to content

Commit

Permalink
Improvements and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
cpsievert committed Dec 19, 2024
1 parent 55510c0 commit 7e7c42c
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 11 deletions.
23 changes: 12 additions & 11 deletions chatlas/_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,33 +177,34 @@ def system_prompt(self, value: str | None):
if value is not None:
self._turns.insert(0, Turn("system", value))

@overload
def tokens(self) -> list[tuple[int, int] | None]: ...

@overload
def tokens(
self,
format: Literal["overall"],
) -> list[tuple[int, int] | None]:
pass
values: Literal["cumulative"],
) -> list[tuple[int, int] | None]: ...

@overload
def tokens(
self,
format: Literal["input"]
) -> list[int]:
pass
values: Literal["discrete"],
) -> list[int]: ...

def tokens(
self,
format: Literal["overall", "input"] = "overall",
values: Literal["cumulative", "discrete"] = "discrete",
) -> list[int] | list[tuple[int, int] | None]:
"""
Get the tokens for each turn in the chat.
Parameters
----------
format
If "overall" (the default), the result can be summed to get the
values
If "cumulative" (the default), the result can be summed to get the
chat's overall token usage (helpful for computing overall cost of
the chat). If "input", the result can be summed to get the number of
the chat). If "discrete", the result can be summed to get the number of
tokens the turns will cost to generate the next response (helpful
for estimating cost of the next response, or for determining if you
are about to exceed the token limit).
Expand All @@ -226,7 +227,7 @@ def tokens(

turns = self.get_turns(include_system_prompt=False)

if format == "overall":
if values == "cumulative":
return [turn.tokens for turn in turns]

if len(turns) == 0:
Expand Down
28 changes: 28 additions & 0 deletions tests/test_tokens.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,35 @@
from chatlas import ChatOpenAI, Turn
from chatlas._openai import OpenAIAzureProvider, OpenAIProvider
from chatlas._tokens import token_usage, tokens_log, tokens_reset


def test_tokens_method():
chat = ChatOpenAI()
assert chat.tokens(values="discrete") == []

chat = ChatOpenAI(
turns=[
Turn(role="user", contents="Hi"),
Turn(role="assistant", contents="Hello", tokens=(2, 10)),
]
)

assert chat.tokens(values="discrete") == [2, 10]

chat = ChatOpenAI(
turns=[
Turn(role="user", contents="Hi"),
Turn(role="assistant", contents="Hello", tokens=(2, 10)),
Turn(role="user", contents="Hi"),
Turn(role="assistant", contents="Hello", tokens=(14, 10)),
]
)

assert chat.tokens(values="discrete") == [2, 10, 2, 10]

assert chat.tokens(values="cumulative") == [None, (2, 10), None, (14, 10)]


def test_usage_is_none():
tokens_reset()
assert token_usage() is None
Expand Down

0 comments on commit 7e7c42c

Please sign in to comment.