add finish_reason to the LLM node output (#7498)

orangeclk 2024-08-21 17:29:30 +08:00 committed by GitHub
parent 784b11ce19
commit f53454f81d
3 changed files with 15 additions and 7 deletions


@@ -428,7 +428,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
                     if new_tool_call.function.arguments:
                         tool_call.function.arguments += new_tool_call.function.arguments

-        finish_reason = 'Unknown'
+        finish_reason = None  # The default value of finish_reason is None

         for chunk in response.iter_lines(decode_unicode=True, delimiter=delimiter):
             chunk = chunk.strip()
@@ -437,6 +437,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
             if chunk.startswith(':'):
                 continue
             decoded_chunk = chunk.strip().lstrip('data: ').lstrip()
+            if decoded_chunk == '[DONE]':  # Some providers return "data: [DONE]"
+                continue
             try:
                 chunk_json = json.loads(decoded_chunk)
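Why the [DONE] guard matters: OpenAI-compatible providers stream server-sent events, and some terminate the stream with a literal "data: [DONE]" line that is not JSON, so passing it to json.loads() raises a JSONDecodeError. Below is a minimal, self-contained sketch of the parsing loop these two hunks patch; RAW_LINES and parse_stream are invented stand-ins for the real response stream and method, not Dify's actual code.

import json

# Stand-in for what response.iter_lines() might yield from an SSE stream.
RAW_LINES = [
    'data: {"choices": [{"delta": {"content": "Hel"}, "finish_reason": null}]}',
    'data: {"choices": [{"delta": {"content": "lo"}, "finish_reason": "stop"}]}',
    ': keep-alive comment',
    'data: [DONE]',
]

def parse_stream(lines):
    """Return (chunks, finish_reason) from an OpenAI-style SSE stream."""
    finish_reason = None  # default is None, not the string 'Unknown'
    chunks = []
    for chunk in lines:
        chunk = chunk.strip()
        if not chunk or chunk.startswith(':'):  # skip blanks and SSE comments
            continue
        decoded_chunk = chunk.strip().lstrip('data: ').lstrip()
        if decoded_chunk == '[DONE]':  # sentinel is not JSON; json.loads would raise
            continue
        chunk_json = json.loads(decoded_chunk)
        finish_reason = chunk_json['choices'][0].get('finish_reason') or finish_reason
        chunks.append(chunk_json)
    return chunks, finish_reason

chunks, reason = parse_stream(RAW_LINES)
assert reason == 'stop' and len(chunks) == 2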


@@ -113,7 +113,7 @@ class LLMNode(BaseNode):
         }

         # handle invoke result
-        result_text, usage = self._invoke_llm(
+        result_text, usage, finish_reason = self._invoke_llm(
             node_data_model=node_data.model,
             model_instance=model_instance,
             prompt_messages=prompt_messages,
@@ -129,7 +129,8 @@ class LLMNode(BaseNode):
         outputs = {
             'text': result_text,
-            'usage': jsonable_encoder(usage)
+            'usage': jsonable_encoder(usage),
+            'finish_reason': finish_reason
         }

         return NodeRunResult(
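With the hunk above, the LLM node's outputs dictionary gains a finish_reason key next to text and usage. An illustrative shape (all values invented) and the kind of downstream check the new key enables:

# Illustrative shape of the node outputs after this change (values invented):
outputs = {
    'text': 'Paris is the capital of France.',
    'usage': {'prompt_tokens': 12, 'completion_tokens': 8, 'total_tokens': 20},
    'finish_reason': 'stop',  # stays None if the provider never reported one
}

# A workflow step can now distinguish a natural stop from truncation:
if outputs['finish_reason'] == 'length':
    print('Completion hit max_tokens; the text above is likely cut off.')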
@@ -167,14 +168,14 @@ class LLMNode(BaseNode):
         )

         # handle invoke result
-        text, usage = self._handle_invoke_result(
+        text, usage, finish_reason = self._handle_invoke_result(
             invoke_result=invoke_result
         )

         # deduct quota
         self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)

-        return text, usage
+        return text, usage, finish_reason

     def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage]:
         """
@@ -186,6 +187,7 @@ class LLMNode(BaseNode):
         prompt_messages = []
         full_text = ''
         usage = None
+        finish_reason = None
         for result in invoke_result:
             text = result.delta.message.content
             full_text += text
@@ -201,10 +203,13 @@ class LLMNode(BaseNode):
             if not usage and result.delta.usage:
                 usage = result.delta.usage

+            if not finish_reason and result.delta.finish_reason:
+                finish_reason = result.delta.finish_reason
+
         if not usage:
             usage = LLMUsage.empty_usage()

-        return full_text, usage
+        return full_text, usage, finish_reason

     def _transform_chat_messages(self,
         messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
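The hunks above thread finish_reason through _invoke_llm and _handle_invoke_result: the stream is consumed once, keeping the first usage and the first finish_reason reported by any delta. A self-contained sketch of that accumulation pattern, with FakeDelta as a hypothetical stand-in for result.delta and a plain dict standing in for LLMUsage:

from dataclasses import dataclass
from typing import Iterable, Optional

@dataclass
class FakeDelta:
    """Hypothetical stand-in for result.delta in the loop above."""
    content: str = ''
    usage: Optional[dict] = None
    finish_reason: Optional[str] = None

def handle_invoke_result(deltas: Iterable[FakeDelta]) -> tuple[str, dict, Optional[str]]:
    full_text = ''
    usage = None
    finish_reason = None
    for delta in deltas:
        full_text += delta.content
        if not usage and delta.usage:  # keep the first usage report
            usage = delta.usage
        if not finish_reason and delta.finish_reason:  # and the first finish_reason
            finish_reason = delta.finish_reason
    if not usage:
        usage = {}  # stands in for LLMUsage.empty_usage()
    return full_text, usage, finish_reason

text, usage, reason = handle_invoke_result(
    [FakeDelta('Hi'), FakeDelta(' there', finish_reason='stop')]
)
assert (text, reason) == ('Hi there', 'stop')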


@@ -63,7 +63,7 @@ class QuestionClassifierNode(LLMNode):
         )

         # handle invoke result
-        result_text, usage = self._invoke_llm(
+        result_text, usage, finish_reason = self._invoke_llm(
             node_data_model=node_data.model,
             model_instance=model_instance,
             prompt_messages=prompt_messages,
@@ -93,6 +93,7 @@ class QuestionClassifierNode(LLMNode):
                 prompt_messages=prompt_messages
             ),
             'usage': jsonable_encoder(usage),
+            'finish_reason': finish_reason
         }
         outputs = {
             'class_name': category_name
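QuestionClassifierNode changes for the same reason as the LLM node: _invoke_llm now returns a 3-tuple, so any caller still unpacking two values raises a ValueError at runtime. A tiny sketch of that contract change (invoke_llm_sketch is a hypothetical stand-in, not the real method):

def invoke_llm_sketch():  # hypothetical stand-in for LLMNode._invoke_llm
    return 'category: greeting', {'total_tokens': 42}, 'stop'

result_text, usage, finish_reason = invoke_llm_sketch()  # new three-value unpack

try:
    result_text, usage = invoke_llm_sketch()  # old two-value call sites now fail
except ValueError as exc:
    print(exc)  # "too many values to unpack (expected 2)"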