Skip to content

Commit

Permalink
updated with return type
Browse files Browse the repository at this point in the history
  • Loading branch information
eavanvalkenburg committed Dec 5, 2024
1 parent 2ccf793 commit 41911cd
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

import httpx
from ollama import AsyncClient
from ollama._types import Message
from ollama._types import ChatResponse, Message
from pydantic import ValidationError

from semantic_kernel.connectors.ai.chat_completion_client_base import ChatCompletionClientBase
Expand Down Expand Up @@ -243,10 +243,30 @@ def _create_chat_message_content(self, response: Mapping[str, Any], metadata: di
)

def _create_streaming_chat_message_content(
self, part: Mapping[str, Any], metadata: dict[str, Any]
self, part: Mapping[str, Any] | ChatResponse, metadata: dict[str, Any]
) -> StreamingChatMessageContent:
"""Create a streaming chat message content from the response part."""
items: list[STREAMING_ITEM_TYPES] = []
if isinstance(part, ChatResponse):
if part.message is None:
raise ServiceInvalidResponseError("No message content found in response part.")
if part.message.content:
items.append(
StreamingTextContent(
choice_index=0,
text=part.message.content,
inner_content=part.message,
)
)
return StreamingChatMessageContent(
role=AuthorRole.ASSISTANT,
choice_index=0,
items=items,
inner_content=part,
ai_model_id=self.ai_model_id,
metadata=metadata,
)

if not (message := part.get("message", None)):
raise ServiceInvalidResponseError("No message content found in response part.")

Expand All @@ -268,9 +288,19 @@ def _create_streaming_chat_message_content(
metadata=metadata,
)

def _get_metadata_from_response(self, response: Mapping[str, Any]) -> dict[str, Any]:
def _get_metadata_from_response(self, response: Mapping[str, Any] | ChatResponse) -> dict[str, Any]:
"""Get metadata from the response."""
metadata = {
if isinstance(response, ChatResponse):
metadata: dict[str, Any] = {
"model": response.model,
}
if response.prompt_eval_count and response.eval_count:
metadata["usage"] = CompletionUsage(
prompt_tokens=response.prompt_eval_count,
completion_tokens=response.eval_count,
)
return metadata
metadata: dict[str, Any] = {
"model": response.get("model"),
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,6 @@ def update_settings_from_function_choice_configuration(
for f in function_choice_configuration.available_functions
]
try:
settings.tools = tools
settings.tools = tools # type: ignore
except Exception:
settings.extension_data["tools"] = tools

0 comments on commit 41911cd

Please sign in to comment.