Add support of tool-call for model provider "hunyuan" (#6656)

Co-authored-by: sun <sun@centen.cn>
This commit is contained in:
Giga Group 2024-07-25 11:27:58 +08:00 committed by GitHub
parent 585444c50c
commit ca696fe94c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -14,6 +14,7 @@ from core.model_runtime.entities.message_entities import (
PromptMessage,
PromptMessageTool,
SystemPromptMessage,
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.errors.invoke import InvokeError
@ -44,6 +45,17 @@ class HunyuanLargeLanguageModel(LargeLanguageModel):
"Stream": stream,
**custom_parameters,
}
# add Tools and ToolChoice
if (tools and len(tools) > 0):
params['ToolChoice'] = "auto"
params['Tools'] = [{
"Type": "function",
"Function": {
"Name": tool.name,
"Description": tool.description,
"Parameters": json.dumps(tool.parameters)
}
} for tool in tools]
request.from_json_string(json.dumps(params))
response = client.ChatCompletions(request)
@ -89,9 +101,43 @@ class HunyuanLargeLanguageModel(LargeLanguageModel):
def _convert_prompt_messages_to_dicts(self, prompt_messages: list[PromptMessage]) -> list[dict]:
"""Convert a list of PromptMessage objects to a list of dictionaries with 'Role' and 'Content' keys."""
return [{"Role": message.role.value, "Content": message.content} for message in prompt_messages]
dict_list = []
for message in prompt_messages:
if isinstance(message, AssistantPromptMessage):
tool_calls = message.tool_calls
if (tool_calls and len(tool_calls) > 0):
dict_tool_calls = [
{
"Id": tool_call.id,
"Type": tool_call.type,
"Function": {
"Name": tool_call.function.name,
"Arguments": tool_call.function.arguments if (tool_call.function.arguments == "") else "{}"
}
} for tool_call in tool_calls]
dict_list.append({
"Role": message.role.value,
# fix set content = "" while tool_call request
# fix [hunyuan] None, [TencentCloudSDKException] code:InvalidParameter message:Messages Content and Contents not allowed empty at the same time.
"Content": " ", # message.content if (message.content is not None) else "",
"ToolCalls": dict_tool_calls
})
else:
dict_list.append({ "Role": message.role.value, "Content": message.content })
elif isinstance(message, ToolPromptMessage):
tool_execute_result = { "result": message.content }
content =json.dumps(tool_execute_result, ensure_ascii=False)
dict_list.append({ "Role": message.role.value, "Content": content, "ToolCallId": message.tool_call_id })
else:
dict_list.append({ "Role": message.role.value, "Content": message.content })
return dict_list
def _handle_stream_chat_response(self, model, credentials, prompt_messages, resp):
tool_call = None
tool_calls = []
for index, event in enumerate(resp):
logging.debug("_handle_stream_chat_response, event: %s", event)
@ -109,20 +155,54 @@ class HunyuanLargeLanguageModel(LargeLanguageModel):
usage = data.get('Usage', {})
prompt_tokens = usage.get('PromptTokens', 0)
completion_tokens = usage.get('CompletionTokens', 0)
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
response_tool_calls = delta.get('ToolCalls')
if (response_tool_calls is not None):
new_tool_calls = self._extract_response_tool_calls(response_tool_calls)
if (len(new_tool_calls) > 0):
new_tool_call = new_tool_calls[0]
if (tool_call is None): tool_call = new_tool_call
elif (tool_call.id != new_tool_call.id):
tool_calls.append(tool_call)
tool_call = new_tool_call
else:
tool_call.function.name += new_tool_call.function.name
tool_call.function.arguments += new_tool_call.function.arguments
if (tool_call is not None and len(tool_call.function.name) > 0 and len(tool_call.function.arguments) > 0):
tool_calls.append(tool_call)
tool_call = None
assistant_prompt_message = AssistantPromptMessage(
content=message_content,
tool_calls=[]
)
# rewrite content = "" while tool_call to avoid show content on web page
if (len(tool_calls) > 0): assistant_prompt_message.content = ""
# add tool_calls to assistant_prompt_message
if (finish_reason == 'tool_calls'):
assistant_prompt_message.tool_calls = tool_calls
tool_call = None
tool_calls = []
delta_chunk = LLMResultChunkDelta(
index=index,
role=delta.get('Role', 'assistant'),
message=assistant_prompt_message,
usage=usage,
finish_reason=finish_reason,
)
if (len(finish_reason) > 0):
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
delta_chunk = LLMResultChunkDelta(
index=index,
role=delta.get('Role', 'assistant'),
message=assistant_prompt_message,
usage=usage,
finish_reason=finish_reason,
)
tool_call = None
tool_calls = []
else:
delta_chunk = LLMResultChunkDelta(
index=index,
message=assistant_prompt_message,
)
yield LLMResultChunk(
model=model,
@ -177,12 +257,15 @@ class HunyuanLargeLanguageModel(LargeLanguageModel):
"""
human_prompt = "\n\nHuman:"
ai_prompt = "\n\nAssistant:"
tool_prompt = "\n\nTool:"
content = message.content
if isinstance(message, UserPromptMessage):
message_text = f"{human_prompt} {content}"
elif isinstance(message, AssistantPromptMessage):
message_text = f"{ai_prompt} {content}"
elif isinstance(message, ToolPromptMessage):
message_text = f"{tool_prompt} {content}"
elif isinstance(message, SystemPromptMessage):
message_text = content
else:
@ -203,3 +286,30 @@ class HunyuanLargeLanguageModel(LargeLanguageModel):
return {
InvokeError: [TencentCloudSDKException],
}
def _extract_response_tool_calls(self,
response_tool_calls: list[dict]) \
-> list[AssistantPromptMessage.ToolCall]:
"""
Extract tool calls from response
:param response_tool_calls: response tool calls
:return: list of tool calls
"""
tool_calls = []
if response_tool_calls:
for response_tool_call in response_tool_calls:
response_function = response_tool_call.get('Function', {})
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
name=response_function.get('Name', ''),
arguments=response_function.get('Arguments', '')
)
tool_call = AssistantPromptMessage.ToolCall(
id=response_tool_call.get('Id', 0),
type='function',
function=function
)
tool_calls.append(tool_call)
return tool_calls