mirror of
https://gitee.com/dify_ai/dify.git
synced 2024-12-04 04:07:47 +08:00
fix: minimax streaming function_call message (#4271)
This commit is contained in:
parent
a80fe20456
commit
8cc492721b
@ -20,16 +20,16 @@ class MinimaxChatCompletionPro:
|
|||||||
Minimax Chat Completion Pro API, supports function calling
|
Minimax Chat Completion Pro API, supports function calling
|
||||||
however, we do not have enough time and energy to implement it, but the parameters are reserved
|
however, we do not have enough time and energy to implement it, but the parameters are reserved
|
||||||
"""
|
"""
|
||||||
def generate(self, model: str, api_key: str, group_id: str,
|
def generate(self, model: str, api_key: str, group_id: str,
|
||||||
prompt_messages: list[MinimaxMessage], model_parameters: dict,
|
prompt_messages: list[MinimaxMessage], model_parameters: dict,
|
||||||
tools: list[dict[str, Any]], stop: list[str] | None, stream: bool, user: str) \
|
tools: list[dict[str, Any]], stop: list[str] | None, stream: bool, user: str) \
|
||||||
-> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]:
|
-> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]:
|
||||||
"""
|
"""
|
||||||
generate chat completion
|
generate chat completion
|
||||||
"""
|
"""
|
||||||
if not api_key or not group_id:
|
if not api_key or not group_id:
|
||||||
raise InvalidAPIKeyError('Invalid API key or group ID')
|
raise InvalidAPIKeyError('Invalid API key or group ID')
|
||||||
|
|
||||||
url = f'https://api.minimax.chat/v1/text/chatcompletion_pro?GroupId={group_id}'
|
url = f'https://api.minimax.chat/v1/text/chatcompletion_pro?GroupId={group_id}'
|
||||||
|
|
||||||
extra_kwargs = {}
|
extra_kwargs = {}
|
||||||
@ -42,7 +42,7 @@ class MinimaxChatCompletionPro:
|
|||||||
|
|
||||||
if 'top_p' in model_parameters and type(model_parameters['top_p']) == float:
|
if 'top_p' in model_parameters and type(model_parameters['top_p']) == float:
|
||||||
extra_kwargs['top_p'] = model_parameters['top_p']
|
extra_kwargs['top_p'] = model_parameters['top_p']
|
||||||
|
|
||||||
if 'plugin_web_search' in model_parameters and model_parameters['plugin_web_search']:
|
if 'plugin_web_search' in model_parameters and model_parameters['plugin_web_search']:
|
||||||
extra_kwargs['plugins'] = [
|
extra_kwargs['plugins'] = [
|
||||||
'plugin_web_search'
|
'plugin_web_search'
|
||||||
@ -61,7 +61,7 @@ class MinimaxChatCompletionPro:
|
|||||||
# check if there is a system message
|
# check if there is a system message
|
||||||
if len(prompt_messages) == 0:
|
if len(prompt_messages) == 0:
|
||||||
raise BadRequestError('At least one message is required')
|
raise BadRequestError('At least one message is required')
|
||||||
|
|
||||||
if prompt_messages[0].role == MinimaxMessage.Role.SYSTEM.value:
|
if prompt_messages[0].role == MinimaxMessage.Role.SYSTEM.value:
|
||||||
if prompt_messages[0].content:
|
if prompt_messages[0].content:
|
||||||
bot_setting['content'] = prompt_messages[0].content
|
bot_setting['content'] = prompt_messages[0].content
|
||||||
@ -70,7 +70,7 @@ class MinimaxChatCompletionPro:
|
|||||||
# check if there is a user message
|
# check if there is a user message
|
||||||
if len(prompt_messages) == 0:
|
if len(prompt_messages) == 0:
|
||||||
raise BadRequestError('At least one user message is required')
|
raise BadRequestError('At least one user message is required')
|
||||||
|
|
||||||
messages = [message.to_dict() for message in prompt_messages]
|
messages = [message.to_dict() for message in prompt_messages]
|
||||||
|
|
||||||
headers = {
|
headers = {
|
||||||
@ -89,21 +89,21 @@ class MinimaxChatCompletionPro:
|
|||||||
|
|
||||||
if tools:
|
if tools:
|
||||||
body['functions'] = tools
|
body['functions'] = tools
|
||||||
body['function_call'] = { 'type': 'auto' }
|
body['function_call'] = {'type': 'auto'}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = post(
|
response = post(
|
||||||
url=url, data=dumps(body), headers=headers, stream=stream, timeout=(10, 300))
|
url=url, data=dumps(body), headers=headers, stream=stream, timeout=(10, 300))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise InternalServerError(e)
|
raise InternalServerError(e)
|
||||||
|
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
raise InternalServerError(response.text)
|
raise InternalServerError(response.text)
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
return self._handle_stream_chat_generate_response(response)
|
return self._handle_stream_chat_generate_response(response)
|
||||||
return self._handle_chat_generate_response(response)
|
return self._handle_chat_generate_response(response)
|
||||||
|
|
||||||
def _handle_error(self, code: int, msg: str):
|
def _handle_error(self, code: int, msg: str):
|
||||||
if code == 1000 or code == 1001 or code == 1013 or code == 1027:
|
if code == 1000 or code == 1001 or code == 1013 or code == 1027:
|
||||||
raise InternalServerError(msg)
|
raise InternalServerError(msg)
|
||||||
@ -127,7 +127,7 @@ class MinimaxChatCompletionPro:
|
|||||||
code = response['base_resp']['status_code']
|
code = response['base_resp']['status_code']
|
||||||
msg = response['base_resp']['status_msg']
|
msg = response['base_resp']['status_msg']
|
||||||
self._handle_error(code, msg)
|
self._handle_error(code, msg)
|
||||||
|
|
||||||
message = MinimaxMessage(
|
message = MinimaxMessage(
|
||||||
content=response['reply'],
|
content=response['reply'],
|
||||||
role=MinimaxMessage.Role.ASSISTANT.value
|
role=MinimaxMessage.Role.ASSISTANT.value
|
||||||
@ -144,7 +144,6 @@ class MinimaxChatCompletionPro:
|
|||||||
"""
|
"""
|
||||||
handle stream chat generate response
|
handle stream chat generate response
|
||||||
"""
|
"""
|
||||||
function_call_storage = None
|
|
||||||
for line in response.iter_lines():
|
for line in response.iter_lines():
|
||||||
if not line:
|
if not line:
|
||||||
continue
|
continue
|
||||||
@ -158,54 +157,41 @@ class MinimaxChatCompletionPro:
|
|||||||
msg = data['base_resp']['status_msg']
|
msg = data['base_resp']['status_msg']
|
||||||
self._handle_error(code, msg)
|
self._handle_error(code, msg)
|
||||||
|
|
||||||
|
# final chunk
|
||||||
if data['reply'] or 'usage' in data and data['usage']:
|
if data['reply'] or 'usage' in data and data['usage']:
|
||||||
total_tokens = data['usage']['total_tokens']
|
total_tokens = data['usage']['total_tokens']
|
||||||
message = MinimaxMessage(
|
minimax_message = MinimaxMessage(
|
||||||
role=MinimaxMessage.Role.ASSISTANT.value,
|
role=MinimaxMessage.Role.ASSISTANT.value,
|
||||||
content=''
|
content=''
|
||||||
)
|
)
|
||||||
message.usage = {
|
minimax_message.usage = {
|
||||||
'prompt_tokens': 0,
|
'prompt_tokens': 0,
|
||||||
'completion_tokens': total_tokens,
|
'completion_tokens': total_tokens,
|
||||||
'total_tokens': total_tokens
|
'total_tokens': total_tokens
|
||||||
}
|
}
|
||||||
message.stop_reason = data['choices'][0]['finish_reason']
|
minimax_message.stop_reason = data['choices'][0]['finish_reason']
|
||||||
|
|
||||||
if function_call_storage:
|
choices = data.get('choices', [])
|
||||||
function_call_message = MinimaxMessage(content='', role=MinimaxMessage.Role.ASSISTANT.value)
|
if len(choices) > 0:
|
||||||
function_call_message.function_call = function_call_storage
|
for choice in choices:
|
||||||
yield function_call_message
|
message = choice['messages'][0]
|
||||||
|
# append function_call message
|
||||||
|
if 'function_call' in message:
|
||||||
|
function_call_message = MinimaxMessage(content='', role=MinimaxMessage.Role.ASSISTANT.value)
|
||||||
|
function_call_message.function_call = message['function_call']
|
||||||
|
yield function_call_message
|
||||||
|
|
||||||
yield message
|
yield minimax_message
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# partial chunk
|
||||||
choices = data.get('choices', [])
|
choices = data.get('choices', [])
|
||||||
if len(choices) == 0:
|
if len(choices) == 0:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for choice in choices:
|
for choice in choices:
|
||||||
message = choice['messages'][0]
|
message = choice['messages'][0]
|
||||||
|
# append text message
|
||||||
if 'function_call' in message:
|
|
||||||
if not function_call_storage:
|
|
||||||
function_call_storage = message['function_call']
|
|
||||||
if 'arguments' not in function_call_storage or not function_call_storage['arguments']:
|
|
||||||
function_call_storage['arguments'] = ''
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
function_call_storage['arguments'] += message['function_call']['arguments']
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
if function_call_storage:
|
|
||||||
message['function_call'] = function_call_storage
|
|
||||||
function_call_storage = None
|
|
||||||
|
|
||||||
minimax_message = MinimaxMessage(content='', role=MinimaxMessage.Role.ASSISTANT.value)
|
|
||||||
|
|
||||||
if 'function_call' in message:
|
|
||||||
minimax_message.function_call = message['function_call']
|
|
||||||
|
|
||||||
if 'text' in message:
|
if 'text' in message:
|
||||||
minimax_message.content = message['text']
|
minimax_message = MinimaxMessage(content=message['text'], role=MinimaxMessage.Role.ASSISTANT.value)
|
||||||
|
yield minimax_message
|
||||||
yield minimax_message
|
|
||||||
|
Loading…
Reference in New Issue
Block a user