From 46d8ab895e44c1ad5681a85b0d20f8497cd0a66c Mon Sep 17 00:00:00 2001
From: Carson Sievert
Date: Tue, 26 Nov 2024 13:50:46 -0600
Subject: [PATCH] `ui.Chat()` now correctly handles new `ollama.chat()` return
 value introduced in ollama 0.4 (#1787)

---
 CHANGELOG.md                                |  6 ++++++
 .../chat/hello-providers/ollama/app.py      |  2 +-
 shiny/ui/_chat.py                           |  9 --------
 shiny/ui/_chat_normalize.py                 | 16 ++++++++++----
 shiny/ui/_chat_tokenizer.py                 | 21 ++++++++++++++-----
 tests/pytest/test_chat.py                   | 18 +++++++++++++---
 6 files changed, 50 insertions(+), 22 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index d0689bafe..af2367286 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,12 @@ All notable changes to Shiny for Python will be documented in this file.
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
+## [UNRELEASED]
+
+### Bug fixes
+
+* `ui.Chat()` now correctly handles new `ollama.chat()` return value introduced in `ollama` v0.4. (#1787)
+
 ## [1.2.1] - 2024-11-14
 
 ### Bug fixes

diff --git a/shiny/templates/chat/hello-providers/ollama/app.py b/shiny/templates/chat/hello-providers/ollama/app.py
index a60250bcf..25d1e37ff 100644
--- a/shiny/templates/chat/hello-providers/ollama/app.py
+++ b/shiny/templates/chat/hello-providers/ollama/app.py
@@ -29,7 +29,7 @@ async def _():
     # Create a response message stream
     # Assumes you've run `ollama run llama3` to start the server
     response = ollama.chat(
-        model="llama3",
+        model="llama3.2",
         messages=messages,
         stream=True,
     )

diff --git a/shiny/ui/_chat.py b/shiny/ui/_chat.py
index 0a1a9b187..b72995d94 100644
--- a/shiny/ui/_chat.py
+++ b/shiny/ui/_chat.py
@@ -914,15 +914,6 @@ def _get_token_count(
         if self._tokenizer is None:
             self._tokenizer = get_default_tokenizer()
 
-        if self._tokenizer is None:
-            raise ValueError(
-                "A tokenizer is required to impose `token_limits` on messages. "
-                "To get a generic default tokenizer, install the `tokenizers` "
-                "package (`pip install tokenizers`). "
-                "To get a more precise token count, provide a specific tokenizer "
-                "to the `Chat` constructor."
-            )
-
         encoded = self._tokenizer.encode(content)
         if isinstance(encoded, TokenizersEncoding):
             return len(encoded.ids)

diff --git a/shiny/ui/_chat_normalize.py b/shiny/ui/_chat_normalize.py
index 7bec8102c..5a2edefc3 100644
--- a/shiny/ui/_chat_normalize.py
+++ b/shiny/ui/_chat_normalize.py
@@ -231,11 +231,19 @@ def normalize_chunk(self, chunk: "dict[str, Any]") -> ChatMessage:
         return super().normalize_chunk(msg)
 
     def can_normalize(self, message: Any) -> bool:
-        if not isinstance(message, dict):
-            return False
-        if "message" not in message:
+        try:
+            from ollama import ChatResponse
+
+            # Ollama<0.4 used TypedDict (now it uses pydantic)
+            # https://github.com/ollama/ollama-python/pull/276
+            if isinstance(ChatResponse, dict):
+                return "message" in message and super().can_normalize(
+                    message["message"]
+                )
+            else:
+                return isinstance(message, ChatResponse)
+        except Exception:
             return False
-        return super().can_normalize(message["message"])
 
     def can_normalize_chunk(self, chunk: Any) -> bool:
         return self.can_normalize(chunk)

diff --git a/shiny/ui/_chat_tokenizer.py b/shiny/ui/_chat_tokenizer.py
index eabf83179..3e0fc6fb7 100644
--- a/shiny/ui/_chat_tokenizer.py
+++ b/shiny/ui/_chat_tokenizer.py
@@ -45,12 +45,23 @@ def encode(
 TokenEncoding = Union[TiktokenEncoding, TokenizersTokenizer]
 
 
-def get_default_tokenizer() -> TokenizersTokenizer | None:
+def get_default_tokenizer() -> TokenizersTokenizer:
     try:
         from tokenizers import Tokenizer
 
         return Tokenizer.from_pretrained("bert-base-cased")  # type: ignore
-    except Exception:
-        pass
-
-    return None
+    except ImportError:
+        raise ImportError(
+            "Failed to download a default tokenizer. "
+            "A tokenizer is required to impose `token_limits` on `chat.messages()`. "
+            "To get a generic default tokenizer, install the `tokenizers` "
+            "package (`pip install tokenizers`). "
+        )
+    except Exception as e:
+        raise RuntimeError(
+            "Failed to download a default tokenizer. "
+            "A tokenizer is required to impose `token_limits` on `chat.messages()`. "
+            "Try manually downloading a tokenizer using "
+            "`tokenizers.Tokenizer.from_pretrained()` and passing it to `ui.Chat()`. "
+            f"Error: {e}"
+        ) from e

diff --git a/tests/pytest/test_chat.py b/tests/pytest/test_chat.py
index c4d9d3e14..540b1cda7 100644
--- a/tests/pytest/test_chat.py
+++ b/tests/pytest/test_chat.py
@@ -333,6 +333,20 @@ def test_openai_normalization():
     assert msg == {"content": "Hello ", "role": "assistant"}
 
 
+def test_ollama_normalization():
+    from ollama import ChatResponse
+    from ollama import Message as OllamaMessage
+
+    # Mock return object from ollama.chat()
+    msg = ChatResponse(
+        message=OllamaMessage(content="Hello world!", role="assistant"),
+    )
+
+    msg_dict = {"content": "Hello world!", "role": "assistant"}
+    assert normalize_message(msg) == msg_dict
+    assert normalize_message_chunk(msg) == msg_dict
+
+
 # ------------------------------------------------------------------------------------
 # Unit tests for as_provider_message()
 #
@@ -462,9 +462,7 @@ def test_as_ollama_message():
     import ollama
     from ollama import Message as OllamaMessage
 
-    assert "typing.Sequence[ollama._types.Message]" in str(
-        ollama.chat.__annotations__["messages"]
-    )
+    assert "ollama._types.Message" in str(ollama.chat.__annotations__["messages"])
 
     from shiny.ui._chat_provider_types import as_ollama_message
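
Reviewer note (not part of the commit): the sketch below illustrates the behavior change this patch accommodates. It assumes `ollama >= 0.4` is installed and constructs the response object directly, mirroring the new test above, so no running Ollama server is needed.

```python
from ollama import ChatResponse
from ollama import Message as OllamaMessage

# As of ollama 0.4, ollama.chat() returns a pydantic model rather than a
# TypedDict (https://github.com/ollama/ollama-python/pull/276).
response = ChatResponse(
    message=OllamaMessage(content="Hello world!", role="assistant"),
)

# The pre-patch guard rejected anything that was not a plain dict, so
# pydantic responses were never normalized:
assert not isinstance(response, dict)

# The patched check matches on the ChatResponse type instead:
assert isinstance(response, ChatResponse)
print(response.message.content)  # Hello world!
```

The `except Exception: return False` fallback in the patched `can_normalize()` also keeps the check safe when `ollama` is not installed at all, since `from ollama import ChatResponse` fails in that case.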