add finish_reason to the LLM node output (#7498)
commit f53454f81d
parent 784b11ce19
@@ -428,7 +428,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
                 if new_tool_call.function.arguments:
                     tool_call.function.arguments += new_tool_call.function.arguments

-        finish_reason = 'Unknown'
+        finish_reason = None  # The default value of finish_reason is None

         for chunk in response.iter_lines(decode_unicode=True, delimiter=delimiter):
             chunk = chunk.strip()
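Switching the default from the string 'Unknown' to None matters for the "first reported value wins" checks this commit adds further down: a falsy default lets the real finish reason through, while a truthy placeholder string would block it. A minimal illustration (values are hypothetical):

finish_reason = None  # nothing reported by the provider yet

# Only the first truthy value reported by the stream is kept:
for reported in (None, None, 'stop'):
    if not finish_reason and reported:
        finish_reason = reported

print(finish_reason)  # -> 'stop'
# Had the default stayed the truthy string 'Unknown', the check
# `not finish_reason` would never pass and 'stop' would be dropped.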
@@ -437,6 +437,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
             if chunk.startswith(':'):
                 continue
             decoded_chunk = chunk.strip().lstrip('data: ').lstrip()
+            if decoded_chunk == '[DONE]':  # Some provider returns "data: [DONE]"
+                continue

             try:
                 chunk_json = json.loads(decoded_chunk)
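For context, the loop being patched consumes an OpenAI-compatible server-sent-events stream: comment lines start with ':', payload lines start with 'data: ', and some providers end the stream with a literal [DONE] sentinel rather than just closing the connection. A self-contained sketch of that parsing, with illustrative sample chunks (the function name and sample data are not from the Dify codebase):

import json

def parse_sse_chunks(chunks):
    """Yield decoded JSON payloads from an OpenAI-compatible SSE stream."""
    for chunk in chunks:
        chunk = chunk.strip()
        if not chunk or chunk.startswith(':'):   # skip keep-alive comments
            continue
        decoded = chunk.strip().lstrip('data: ').lstrip()
        if decoded == '[DONE]':                  # some providers send a sentinel
            continue
        try:
            yield json.loads(decoded)
        except json.JSONDecodeError:
            continue                             # ignore malformed chunks

# Illustrative stream, including a comment line and the [DONE] sentinel:
sample = [
    ': keep-alive',
    'data: {"choices": [{"delta": {"content": "Hi"}, "finish_reason": null}]}',
    'data: {"choices": [{"delta": {}, "finish_reason": "stop"}]}',
    'data: [DONE]',
]
for payload in parse_sse_chunks(sample):
    print(payload['choices'][0].get('finish_reason'))  # None, then 'stop'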
@@ -113,7 +113,7 @@ class LLMNode(BaseNode):
         }

         # handle invoke result
-        result_text, usage = self._invoke_llm(
+        result_text, usage, finish_reason = self._invoke_llm(
             node_data_model=node_data.model,
             model_instance=model_instance,
             prompt_messages=prompt_messages,
@@ -129,7 +129,8 @@ class LLMNode(BaseNode):

         outputs = {
             'text': result_text,
-            'usage': jsonable_encoder(usage)
+            'usage': jsonable_encoder(usage),
+            'finish_reason': finish_reason
         }

         return NodeRunResult(
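With this hunk, the LLM node's outputs carry the finish reason next to the text and usage, so downstream workflow steps can tell how generation ended. A hypothetical example of the shape (field values are illustrative, not from the source):

outputs = {
    'text': 'The capital of France is Paris.',
    'usage': {'prompt_tokens': 12, 'completion_tokens': 8, 'total_tokens': 20},
    'finish_reason': 'stop',  # e.g. 'stop', 'length', or None if not reported
}

# A caller can now tell whether generation was truncated:
if outputs['finish_reason'] == 'length':
    print('Output hit the max-token limit and may be incomplete.')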
@@ -167,14 +168,14 @@ class LLMNode(BaseNode):
         )

         # handle invoke result
-        text, usage = self._handle_invoke_result(
+        text, usage, finish_reason = self._handle_invoke_result(
             invoke_result=invoke_result
         )

         # deduct quota
         self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)

-        return text, usage
+        return text, usage, finish_reason

     def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage]:
         """
@@ -186,6 +187,7 @@ class LLMNode(BaseNode):
         prompt_messages = []
         full_text = ''
         usage = None
+        finish_reason = None
         for result in invoke_result:
             text = result.delta.message.content
             full_text += text
@@ -201,10 +203,13 @@ class LLMNode(BaseNode):
             if not usage and result.delta.usage:
                 usage = result.delta.usage

+            if not finish_reason and result.delta.finish_reason:
+                finish_reason = result.delta.finish_reason
+
         if not usage:
             usage = LLMUsage.empty_usage()

-        return full_text, usage
+        return full_text, usage, finish_reason

     def _transform_chat_messages(self,
         messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
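The pattern in _handle_invoke_result is "first non-empty value wins": usage and finish_reason are each captured once from the stream and never overwritten. A minimal standalone sketch of the same accumulation, with a stand-in Delta type in place of Dify's result objects (illustrative, not the project's API):

from dataclasses import dataclass
from typing import Iterator, Optional

@dataclass
class Delta:
    content: str
    usage: Optional[dict] = None
    finish_reason: Optional[str] = None

def accumulate(stream: Iterator[Delta]) -> tuple[str, dict, Optional[str]]:
    """Mirror of the 'first non-empty value wins' pattern above."""
    full_text = ''
    usage = None
    finish_reason = None
    for delta in stream:
        full_text += delta.content
        if not usage and delta.usage:                  # capture usage once
            usage = delta.usage
        if not finish_reason and delta.finish_reason:  # capture finish reason once
            finish_reason = delta.finish_reason
    return full_text, usage or {}, finish_reason

text, usage, reason = accumulate(iter([
    Delta('Hello, '),
    Delta('world.', usage={'total_tokens': 5}, finish_reason='stop'),
]))
print(text, usage, reason)  # Hello, world. {'total_tokens': 5} stop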
@@ -63,7 +63,7 @@ class QuestionClassifierNode(LLMNode):
         )

         # handle invoke result
-        result_text, usage = self._invoke_llm(
+        result_text, usage, finish_reason = self._invoke_llm(
             node_data_model=node_data.model,
             model_instance=model_instance,
             prompt_messages=prompt_messages,
@@ -93,6 +93,7 @@ class QuestionClassifierNode(LLMNode):
                 prompt_messages=prompt_messages
             ),
             'usage': jsonable_encoder(usage),
+            'finish_reason': finish_reason
         }
         outputs = {
             'class_name': category_name
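Note the split this hunk preserves in the question classifier: usage and finish_reason land in the node's process data (diagnostics), while outputs stays limited to the classification itself. Illustrative shapes only; any key or value beyond those visible in the diff is hypothetical:

process_data = {
    'usage': {'total_tokens': 42},   # jsonable_encoder(usage), value illustrative
    'finish_reason': 'stop',
}
outputs = {
    'class_name': 'billing_question',  # category_name, value illustrative
}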