
Commit

fix qc
takatost committed Mar 21, 2024
1 parent a4a616b commit 700ac35
Showing 3 changed files with 35 additions and 223 deletions.
36 changes: 18 additions & 18 deletions api/core/workflow/nodes/llm/llm_node.py
@@ -15,12 +15,13 @@
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
+from core.prompt.entities.advanced_prompt_entities import MemoryConfig
 from core.prompt.utils.prompt_message_util import PromptMessageUtil
 from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.nodes.base_node import BaseNode
-from core.workflow.nodes.llm.entities import LLMNodeData
+from core.workflow.nodes.llm.entities import LLMNodeData, ModelConfig
 from extensions.ext_database import db
 from models.model import Conversation
 from models.provider import Provider, ProviderType
@@ -64,10 +65,10 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult:
         node_inputs['#context#'] = context
 
         # fetch model config
-        model_instance, model_config = self._fetch_model_config(node_data)
+        model_instance, model_config = self._fetch_model_config(node_data.model)
 
         # fetch memory
-        memory = self._fetch_memory(node_data, variable_pool, model_instance)
+        memory = self._fetch_memory(node_data.memory, variable_pool, model_instance)
 
         # fetch prompt messages
         prompt_messages, stop = self._fetch_prompt_messages(
@@ -89,7 +90,7 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult:
 
         # handle invoke result
         result_text, usage = self._invoke_llm(
-            node_data=node_data,
+            node_data_model=node_data.model,
             model_instance=model_instance,
             prompt_messages=prompt_messages,
             stop=stop
@@ -119,13 +120,13 @@ def _run(self, variable_pool: VariablePool) -> NodeRunResult:
             }
         )
 
-    def _invoke_llm(self, node_data: LLMNodeData,
+    def _invoke_llm(self, node_data_model: ModelConfig,
                     model_instance: ModelInstance,
                     prompt_messages: list[PromptMessage],
                     stop: list[str]) -> tuple[str, LLMUsage]:
         """
         Invoke large language model
-        :param node_data: node data
+        :param node_data_model: node data model
         :param model_instance: model instance
         :param prompt_messages: prompt messages
         :param stop: stop
@@ -135,7 +136,7 @@ def _invoke_llm(self, node_data: LLMNodeData,
 
         invoke_result = model_instance.invoke_llm(
             prompt_messages=prompt_messages,
-            model_parameters=node_data.model.completion_params,
+            model_parameters=node_data_model.completion_params,
             stop=stop,
             stream=True,
             user=self.user_id,
@@ -286,14 +287,14 @@ def _convert_to_original_retriever_resource(self, context_dict: dict) -> Optiona
 
         return None
 
-    def _fetch_model_config(self, node_data: LLMNodeData) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
+    def _fetch_model_config(self, node_data_model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
         """
         Fetch model config
-        :param node_data: node data
+        :param node_data_model: node data model
         :return:
         """
-        model_name = node_data.model.name
-        provider_name = node_data.model.provider
+        model_name = node_data_model.name
+        provider_name = node_data_model.provider
 
         model_manager = ModelManager()
         model_instance = model_manager.get_model_instance(
@@ -326,14 +327,14 @@ def _fetch_model_config(self, node_data: LLMNodeData) -> tuple[ModelInstance, Mo
             raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
 
         # model config
-        completion_params = node_data.model.completion_params
+        completion_params = node_data_model.completion_params
         stop = []
         if 'stop' in completion_params:
             stop = completion_params['stop']
             del completion_params['stop']
 
         # get model mode
-        model_mode = node_data.model.mode
+        model_mode = node_data_model.mode
         if not model_mode:
             raise ValueError("LLM mode is required.")
 
@@ -356,26 +357,25 @@ def _fetch_model_config(self, node_data: LLMNodeData) -> tuple[ModelInstance, Mo
             stop=stop,
         )
 
-    def _fetch_memory(self, node_data: LLMNodeData,
+    def _fetch_memory(self, node_data_memory: Optional[MemoryConfig],
                       variable_pool: VariablePool,
                       model_instance: ModelInstance) -> Optional[TokenBufferMemory]:
         """
         Fetch memory
-        :param node_data: node data
+        :param node_data_memory: node data memory
         :param variable_pool: variable pool
         :return:
        """
-        if not node_data.memory:
+        if not node_data_memory:
             return None
 
         # get conversation id
-        conversation_id = variable_pool.get_variable_value(['sys', SystemVariable.CONVERSATION])
+        conversation_id = variable_pool.get_variable_value(['sys', SystemVariable.CONVERSATION.value])
         if conversation_id is None:
             return None
 
         # get conversation
         conversation = db.session.query(Conversation).filter(
             Conversation.tenant_id == self.tenant_id,
             Conversation.app_id == self.app_id,
             Conversation.id == conversation_id
         ).first()
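For orientation, the pattern of this refactor is that the private helpers now receive only the sub-config they actually use (node_data.model, node_data.memory) instead of the whole LLMNodeData. Below is a minimal, standalone sketch of that pattern; the dataclasses are simplified stand-ins for the Dify entities (only the fields visible in the diff are used), not the project's real definitions.

    # Standalone sketch (not Dify code): helpers take the specific sub-config,
    # mirroring the signature changes in llm_node.py above.
    from dataclasses import dataclass, field
    from typing import Optional


    @dataclass
    class ModelConfig:                      # stand-in; real class lives in core.workflow.nodes.llm.entities
        provider: str
        name: str
        mode: str
        completion_params: dict = field(default_factory=dict)


    @dataclass
    class MemoryConfig:                     # simplified stand-in for the shared MemoryConfig
        window_size: int = 10


    @dataclass
    class LLMNodeData:                      # stand-in for the node data object
        model: ModelConfig
        memory: Optional[MemoryConfig] = None


    def fetch_model_config(node_data_model: ModelConfig) -> dict:
        # Only the model sub-config is needed, so only that is passed in.
        completion_params = dict(node_data_model.completion_params)
        stop = completion_params.pop('stop', [])
        return {
            'provider': node_data_model.provider,
            'model': node_data_model.name,
            'mode': node_data_model.mode,
            'parameters': completion_params,
            'stop': stop,
        }


    def fetch_memory(node_data_memory: Optional[MemoryConfig]) -> Optional[dict]:
        # Memory is optional; callers pass node_data.memory, which may be None.
        if not node_data_memory:
            return None
        return {'window_size': node_data_memory.window_size}


    if __name__ == '__main__':
        node_data = LLMNodeData(
            model=ModelConfig(provider='openai', name='gpt-4', mode='chat',
                              completion_params={'temperature': 0.7, 'stop': ['\n\n']}),
            memory=None,
        )
        print(fetch_model_config(node_data.model))   # pass the sub-config, not the whole node_data
        print(fetch_memory(node_data.memory))

The narrower signatures make the helpers reusable by other nodes (such as the question classifier below) that have a model or memory config but not an LLMNodeData.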
16 changes: 1 addition & 15 deletions api/core/workflow/nodes/question_classifier/entities.py
@@ -2,6 +2,7 @@
 
 from pydantic import BaseModel
 
+from core.prompt.entities.advanced_prompt_entities import MemoryConfig
 from core.workflow.entities.base_node_data_entities import BaseNodeData
 
 
@@ -23,21 +24,6 @@ class ClassConfig(BaseModel):
     name: str
 
 
-class WindowConfig(BaseModel):
-    """
-    Window Config.
-    """
-    enabled: bool
-    size: int
-
-
-class MemoryConfig(BaseModel):
-    """
-    Memory Config.
-    """
-    window: WindowConfig
-
-
 class QuestionClassifierNodeData(BaseNodeData):
     """
     Knowledge retrieval Node Data.
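The question-classifier node previously defined its own WindowConfig and MemoryConfig; this commit deletes them and imports the shared MemoryConfig from core.prompt.entities.advanced_prompt_entities instead (the same module the LLM node now imports it from). A minimal sketch of what the shared entities might look like, assuming they keep the same window-based shape as the deleted local classes (the real definitions in advanced_prompt_entities.py may differ):

    # Hypothetical sketch modeled on the classes deleted above; not the actual
    # contents of core/prompt/entities/advanced_prompt_entities.py.
    from pydantic import BaseModel


    class WindowConfig(BaseModel):
        """Sliding-window settings for conversation memory."""
        enabled: bool
        size: int


    class MemoryConfig(BaseModel):
        """Memory settings shared by workflow nodes (LLM, question classifier)."""
        window: WindowConfig


    # Consumers now import the shared class rather than redefining it locally:
    # from core.prompt.entities.advanced_prompt_entities import MemoryConfig
    memory = MemoryConfig(window=WindowConfig(enabled=True, size=10))
    print(memory)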

