From f53454f81d03e19387c7b230d664db13e8d09ff4 Mon Sep 17 00:00:00 2001 From: orangeclk Date: Wed, 21 Aug 2024 17:29:30 +0800 Subject: [PATCH] add finish_reason to the LLM node output (#7498) --- .../openai_api_compatible/llm/llm.py | 4 +++- api/core/workflow/nodes/llm/llm_node.py | 15 ++++++++++----- .../question_classifier_node.py | 3 ++- 3 files changed, 15 insertions(+), 7 deletions(-) diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py index e5cc884b6..753dc6cb2 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py @@ -428,7 +428,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): if new_tool_call.function.arguments: tool_call.function.arguments += new_tool_call.function.arguments - finish_reason = 'Unknown' + finish_reason = None # The default value of finish_reason is None for chunk in response.iter_lines(decode_unicode=True, delimiter=delimiter): chunk = chunk.strip() @@ -437,6 +437,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): if chunk.startswith(':'): continue decoded_chunk = chunk.strip().lstrip('data: ').lstrip() + if decoded_chunk == '[DONE]': # Some provider returns "data: [DONE]" + continue try: chunk_json = json.loads(decoded_chunk) diff --git a/api/core/workflow/nodes/llm/llm_node.py b/api/core/workflow/nodes/llm/llm_node.py index c3e494942..eb8921b52 100644 --- a/api/core/workflow/nodes/llm/llm_node.py +++ b/api/core/workflow/nodes/llm/llm_node.py @@ -113,7 +113,7 @@ class LLMNode(BaseNode): } # handle invoke result - result_text, usage = self._invoke_llm( + result_text, usage, finish_reason = self._invoke_llm( node_data_model=node_data.model, model_instance=model_instance, prompt_messages=prompt_messages, @@ -129,7 +129,8 @@ class LLMNode(BaseNode): outputs = { 'text': result_text, - 'usage': jsonable_encoder(usage) + 'usage': jsonable_encoder(usage), + 'finish_reason': finish_reason } return NodeRunResult( @@ -167,14 +168,14 @@ class LLMNode(BaseNode): ) # handle invoke result - text, usage = self._handle_invoke_result( + text, usage, finish_reason = self._handle_invoke_result( invoke_result=invoke_result ) # deduct quota self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage) - return text, usage + return text, usage, finish_reason def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage]: """ @@ -186,6 +187,7 @@ class LLMNode(BaseNode): prompt_messages = [] full_text = '' usage = None + finish_reason = None for result in invoke_result: text = result.delta.message.content full_text += text @@ -201,10 +203,13 @@ class LLMNode(BaseNode): if not usage and result.delta.usage: usage = result.delta.usage + if not finish_reason and result.delta.finish_reason: + finish_reason = result.delta.finish_reason + if not usage: usage = LLMUsage.empty_usage() - return full_text, usage + return full_text, usage, finish_reason def _transform_chat_messages(self, messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index 2e1464efc..f4057d50f 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -63,7 +63,7 @@ class QuestionClassifierNode(LLMNode): ) # handle invoke result - result_text, usage = self._invoke_llm( + result_text, usage, finish_reason = self._invoke_llm( node_data_model=node_data.model, model_instance=model_instance, prompt_messages=prompt_messages, @@ -93,6 +93,7 @@ class QuestionClassifierNode(LLMNode): prompt_messages=prompt_messages ), 'usage': jsonable_encoder(usage), + 'finish_reason': finish_reason } outputs = { 'class_name': category_name