Skip to content

Commit

Permalink
Add format argument to .tokens(); make default behavior same as before:
Browse files Browse the repository at this point in the history
  • Loading branch information
cpsievert committed Dec 19, 2024
1 parent 41d4065 commit 55510c0
Showing 1 changed file with 33 additions and 2 deletions.
35 changes: 33 additions & 2 deletions chatlas/_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
Optional,
Sequence,
TypeVar,
overload,
)

from pydantic import BaseModel
Expand Down Expand Up @@ -176,14 +177,41 @@ def system_prompt(self, value: str | None):
if value is not None:
self._turns.insert(0, Turn("system", value))

def tokens(self) -> list[int]:
@overload
def tokens(
self,
format: Literal["overall"],
) -> list[tuple[int, int] | None]:
pass

@overload
def tokens(
self,
format: Literal["input"]
) -> list[int]:
pass

def tokens(
self,
format: Literal["overall", "input"] = "overall",
) -> list[int] | list[tuple[int, int] | None]:
"""
Get the tokens for each turn in the chat.
Parameters
----------
format
If "overall" (the default), the result can be summed to get the
chat's overall token usage (helpful for computing overall cost of
the chat). If "input", the result can be summed to get the number of
tokens the turns will cost to generate the next response (helpful
for estimating cost of the next response, or for determining if you
are about to exceed the token limit).
Returns
-------
list[int]
A list of token counts for each (non-system )turn in the chat. The
A list of token counts for each (non-system) turn in the chat. The
1st turn includes the tokens count for the system prompt (if any).
Raises
Expand All @@ -198,6 +226,9 @@ def tokens(self) -> list[int]:

turns = self.get_turns(include_system_prompt=False)

if format == "overall":
return [turn.tokens for turn in turns]

if len(turns) == 0:
return []

Expand Down

0 comments on commit 55510c0

Please sign in to comment.