From 55510c00d83541ea37af0c9353c614cfbd5a661e Mon Sep 17 00:00:00 2001
From: Carson
Date: Thu, 19 Dec 2024 16:35:36 -0600
Subject: [PATCH] Add format argument to .tokens(); make default behavior same as before:

---
Note: the "overall" overload declares the same default as the
implementation. Without it, a bare `.tokens()` call matches neither
overload and is rejected by static type checkers, even though it works
at runtime.

 chatlas/_chat.py | 35 +++++++++++++++++++++++++++++++++--
 1 file changed, 33 insertions(+), 2 deletions(-)

diff --git a/chatlas/_chat.py b/chatlas/_chat.py
index fdb16bf..0eb2360 100644
--- a/chatlas/_chat.py
+++ b/chatlas/_chat.py
@@ -16,6 +16,7 @@
     Optional,
     Sequence,
     TypeVar,
+    overload,
 )
 
 from pydantic import BaseModel
@@ -176,14 +177,41 @@ def system_prompt(self, value: str | None):
         if value is not None:
             self._turns.insert(0, Turn("system", value))
 
-    def tokens(self) -> list[int]:
+    @overload
+    def tokens(
+        self,
+        format: Literal["overall"] = "overall",
+    ) -> list[tuple[int, int] | None]:
+        pass
+
+    @overload
+    def tokens(
+        self,
+        format: Literal["input"]
+    ) -> list[int]:
+        pass
+
+    def tokens(
+        self,
+        format: Literal["overall", "input"] = "overall",
+    ) -> list[int] | list[tuple[int, int] | None]:
         """
         Get the tokens for each turn in the chat.
 
+        Parameters
+        ----------
+        format
+            If "overall" (the default), the result can be summed to get the
+            chat's overall token usage (helpful for computing overall cost of
+            the chat). If "input", the result can be summed to get the number of
+            tokens the turns will cost to generate the next response (helpful
+            for estimating cost of the next response, or for determining if you
+            are about to exceed the token limit).
+
         Returns
         -------
         list[int]
-            A list of token counts for each (non-system )turn in the chat. The
+            A list of token counts for each (non-system) turn in the chat. The
             1st turn includes the tokens count for the system prompt (if any).
 
         Raises
@@ -198,6 +226,9 @@ def tokens(self) -> list[int]:
 
         turns = self.get_turns(include_system_prompt=False)
 
+        if format == "overall":
+            return [turn.tokens for turn in turns]
+
         if len(turns) == 0:
             return []
 