add finish_reason to the LLM node output (#7498)

This commit is contained in:
orangeclk 2024-08-21 17:29:30 +08:00 committed by GitHub
parent 784b11ce19
commit f53454f81d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 15 additions and 7 deletions

View File

@ -428,7 +428,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
if new_tool_call.function.arguments: if new_tool_call.function.arguments:
tool_call.function.arguments += new_tool_call.function.arguments tool_call.function.arguments += new_tool_call.function.arguments
finish_reason = 'Unknown' finish_reason = None # The default value of finish_reason is None
for chunk in response.iter_lines(decode_unicode=True, delimiter=delimiter): for chunk in response.iter_lines(decode_unicode=True, delimiter=delimiter):
chunk = chunk.strip() chunk = chunk.strip()
@ -437,6 +437,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
if chunk.startswith(':'): if chunk.startswith(':'):
continue continue
decoded_chunk = chunk.strip().lstrip('data: ').lstrip() decoded_chunk = chunk.strip().lstrip('data: ').lstrip()
if decoded_chunk == '[DONE]': # Some providers return "data: [DONE]"
continue
try: try:
chunk_json = json.loads(decoded_chunk) chunk_json = json.loads(decoded_chunk)

View File

@ -113,7 +113,7 @@ class LLMNode(BaseNode):
} }
# handle invoke result # handle invoke result
result_text, usage = self._invoke_llm( result_text, usage, finish_reason = self._invoke_llm(
node_data_model=node_data.model, node_data_model=node_data.model,
model_instance=model_instance, model_instance=model_instance,
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
@ -129,7 +129,8 @@ class LLMNode(BaseNode):
outputs = { outputs = {
'text': result_text, 'text': result_text,
'usage': jsonable_encoder(usage) 'usage': jsonable_encoder(usage),
'finish_reason': finish_reason
} }
return NodeRunResult( return NodeRunResult(
@ -167,14 +168,14 @@ class LLMNode(BaseNode):
) )
# handle invoke result # handle invoke result
text, usage = self._handle_invoke_result( text, usage, finish_reason = self._handle_invoke_result(
invoke_result=invoke_result invoke_result=invoke_result
) )
# deduct quota # deduct quota
self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage) self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
return text, usage return text, usage, finish_reason
def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage]: def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage]:
""" """
@ -186,6 +187,7 @@ class LLMNode(BaseNode):
prompt_messages = [] prompt_messages = []
full_text = '' full_text = ''
usage = None usage = None
finish_reason = None
for result in invoke_result: for result in invoke_result:
text = result.delta.message.content text = result.delta.message.content
full_text += text full_text += text
@ -201,10 +203,13 @@ class LLMNode(BaseNode):
if not usage and result.delta.usage: if not usage and result.delta.usage:
usage = result.delta.usage usage = result.delta.usage
if not finish_reason and result.delta.finish_reason:
finish_reason = result.delta.finish_reason
if not usage: if not usage:
usage = LLMUsage.empty_usage() usage = LLMUsage.empty_usage()
return full_text, usage return full_text, usage, finish_reason
def _transform_chat_messages(self, def _transform_chat_messages(self,
messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate

View File

@ -63,7 +63,7 @@ class QuestionClassifierNode(LLMNode):
) )
# handle invoke result # handle invoke result
result_text, usage = self._invoke_llm( result_text, usage, finish_reason = self._invoke_llm(
node_data_model=node_data.model, node_data_model=node_data.model,
model_instance=model_instance, model_instance=model_instance,
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
@ -93,6 +93,7 @@ class QuestionClassifierNode(LLMNode):
prompt_messages=prompt_messages prompt_messages=prompt_messages
), ),
'usage': jsonable_encoder(usage), 'usage': jsonable_encoder(usage),
'finish_reason': finish_reason
} }
outputs = { outputs = {
'class_name': category_name 'class_name': category_name