fix: minimax streaming function_call message (#4271)

Authored by Weaxs on 2024-05-11 21:07:22 +08:00; committed by GitHub
parent a80fe20456
commit 8cc492721b


@@ -20,16 +20,16 @@ class MinimaxChatCompletionPro:
         Minimax Chat Completion Pro API, supports function calling
         however, we do not have enough time and energy to implement it, but the parameters are reserved
     """
     def generate(self, model: str, api_key: str, group_id: str,
                  prompt_messages: list[MinimaxMessage], model_parameters: dict,
                  tools: list[dict[str, Any]], stop: list[str] | None, stream: bool, user: str) \
         -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]:
         """
             generate chat completion
         """
         if not api_key or not group_id:
             raise InvalidAPIKeyError('Invalid API key or group ID')
 
         url = f'https://api.minimax.chat/v1/text/chatcompletion_pro?GroupId={group_id}'
 
         extra_kwargs = {}
@@ -42,7 +42,7 @@ class MinimaxChatCompletionPro:
         if 'top_p' in model_parameters and type(model_parameters['top_p']) == float:
             extra_kwargs['top_p'] = model_parameters['top_p']
 
         if 'plugin_web_search' in model_parameters and model_parameters['plugin_web_search']:
             extra_kwargs['plugins'] = [
                 'plugin_web_search'
@@ -61,7 +61,7 @@ class MinimaxChatCompletionPro:
         # check if there is a system message
         if len(prompt_messages) == 0:
             raise BadRequestError('At least one message is required')
 
         if prompt_messages[0].role == MinimaxMessage.Role.SYSTEM.value:
             if prompt_messages[0].content:
                 bot_setting['content'] = prompt_messages[0].content
@@ -70,7 +70,7 @@ class MinimaxChatCompletionPro:
         # check if there is a user message
         if len(prompt_messages) == 0:
             raise BadRequestError('At least one user message is required')
 
         messages = [message.to_dict() for message in prompt_messages]
 
         headers = {
@@ -89,21 +89,21 @@ class MinimaxChatCompletionPro:
         if tools:
             body['functions'] = tools
-            body['function_call'] = { 'type': 'auto' }
+            body['function_call'] = {'type': 'auto'}
 
         try:
             response = post(
                 url=url, data=dumps(body), headers=headers, stream=stream, timeout=(10, 300))
         except Exception as e:
             raise InternalServerError(e)
 
         if response.status_code != 200:
             raise InternalServerError(response.text)
 
         if stream:
             return self._handle_stream_chat_generate_response(response)
         return self._handle_chat_generate_response(response)
 
     def _handle_error(self, code: int, msg: str):
         if code == 1000 or code == 1001 or code == 1013 or code == 1027:
             raise InternalServerError(msg)
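
For context on the hunk above, the snippet below sketches roughly what the request body looks like once tools are attached. Only the `functions` and `function_call` keys come from this diff; every other field name and value (`model`, `stream`, `bot_setting`, `messages` and their sub-keys, plus the `get_weather` tool) is an assumption about the chatcompletion_pro payload, shown purely for illustration.

# Illustrative sketch only: approximate shape of the body POSTed to
# chatcompletion_pro when `tools` is non-empty. Field names other than
# 'functions' / 'function_call' are assumptions, not taken from this diff.
body = {
    'model': 'abab5.5-chat',  # assumed model name
    'stream': True,
    'bot_setting': [{'bot_name': 'Assistant', 'content': 'You are a helpful assistant.'}],
    'messages': [{'sender_type': 'USER', 'sender_name': 'User', 'text': 'What is the weather in Berlin?'}],
    # function calling: pass the tool schemas through and let the model decide
    'functions': [{
        'name': 'get_weather',  # hypothetical tool
        'description': 'Look up the current weather for a city',
        'parameters': {
            'type': 'object',
            'properties': {'city': {'type': 'string'}},
            'required': ['city'],
        },
    }],
    'function_call': {'type': 'auto'},
}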
@@ -127,7 +127,7 @@ class MinimaxChatCompletionPro:
             code = response['base_resp']['status_code']
             msg = response['base_resp']['status_msg']
             self._handle_error(code, msg)
 
         message = MinimaxMessage(
             content=response['reply'],
             role=MinimaxMessage.Role.ASSISTANT.value
@@ -144,7 +144,6 @@ class MinimaxChatCompletionPro:
         """
             handle stream chat generate response
         """
-        function_call_storage = None
         for line in response.iter_lines():
             if not line:
                 continue
@@ -158,54 +157,41 @@ class MinimaxChatCompletionPro:
                 msg = data['base_resp']['status_msg']
                 self._handle_error(code, msg)
 
+            # final chunk
             if data['reply'] or 'usage' in data and data['usage']:
                 total_tokens = data['usage']['total_tokens']
-                message = MinimaxMessage(
+                minimax_message = MinimaxMessage(
                     role=MinimaxMessage.Role.ASSISTANT.value,
                     content=''
                 )
-                message.usage = {
+                minimax_message.usage = {
                     'prompt_tokens': 0,
                     'completion_tokens': total_tokens,
                     'total_tokens': total_tokens
                 }
-                message.stop_reason = data['choices'][0]['finish_reason']
+                minimax_message.stop_reason = data['choices'][0]['finish_reason']
 
-                if function_call_storage:
-                    function_call_message = MinimaxMessage(content='', role=MinimaxMessage.Role.ASSISTANT.value)
-                    function_call_message.function_call = function_call_storage
-                    yield function_call_message
+                choices = data.get('choices', [])
+                if len(choices) > 0:
+                    for choice in choices:
+                        message = choice['messages'][0]
+                        # append function_call message
+                        if 'function_call' in message:
+                            function_call_message = MinimaxMessage(content='', role=MinimaxMessage.Role.ASSISTANT.value)
+                            function_call_message.function_call = message['function_call']
+                            yield function_call_message
 
-                yield message
+                yield minimax_message
                 return
 
+            # partial chunk
             choices = data.get('choices', [])
             if len(choices) == 0:
                 continue
 
             for choice in choices:
                 message = choice['messages'][0]
-                # append text message
-                if 'function_call' in message:
-                    if not function_call_storage:
-                        function_call_storage = message['function_call']
-                        if 'arguments' not in function_call_storage or not function_call_storage['arguments']:
-                            function_call_storage['arguments'] = ''
-                        continue
-                    else:
-                        function_call_storage['arguments'] += message['function_call']['arguments']
-                        continue
-                else:
-                    if function_call_storage:
-                        message['function_call'] = function_call_storage
-                        function_call_storage = None
-
-                minimax_message = MinimaxMessage(content='', role=MinimaxMessage.Role.ASSISTANT.value)
-
-                if 'function_call' in message:
-                    minimax_message.function_call = message['function_call']
-
                 if 'text' in message:
-                    minimax_message.content = message['text']
-
-                yield minimax_message
+                    minimax_message = MinimaxMessage(content=message['text'], role=MinimaxMessage.Role.ASSISTANT.value)
+                    yield minimax_message
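
Read back after the change, the streaming handler's contract is: partial chunks only ever yield text deltas, and any function_call is surfaced from the final chunk (the one carrying reply/usage), immediately before the usage/stop message, instead of being accumulated in function_call_storage across chunks. Below is a minimal, self-contained sketch of that flow; it uses plain dicts in place of MinimaxMessage and two hard-coded chunks in place of a live SSE response, so everything outside the chunk-handling logic is an illustrative assumption.

import json
from collections.abc import Generator


def parse_minimax_stream(lines: list[str]) -> Generator[dict, None, None]:
    """Mirror of the fixed streaming flow, with plain dicts standing in for MinimaxMessage."""
    for raw in lines:
        if not raw:
            continue
        data = json.loads(raw.removeprefix('data: '))

        # final chunk: carries reply/usage; surface any function_call first, then the usage message
        if data.get('reply') or data.get('usage'):
            for choice in data.get('choices', []):
                message = choice['messages'][0]
                if 'function_call' in message:
                    yield {'role': 'assistant', 'content': '', 'function_call': message['function_call']}
            total_tokens = data['usage']['total_tokens']
            yield {
                'role': 'assistant',
                'content': '',
                'usage': {'prompt_tokens': 0, 'completion_tokens': total_tokens, 'total_tokens': total_tokens},
                'stop_reason': data['choices'][0]['finish_reason'],
            }
            return

        # partial chunk: only text deltas are yielded
        for choice in data.get('choices', []):
            message = choice['messages'][0]
            if 'text' in message:
                yield {'role': 'assistant', 'content': message['text']}


# Hypothetical chunks shaped like chatcompletion_pro SSE payloads (illustrative, not captured traffic)
chunks = [
    'data: {"choices": [{"messages": [{"text": "Checking the weather"}]}]}',
    'data: {"reply": "", "usage": {"total_tokens": 42}, "choices": [{"finish_reason": "function_call", '
    '"messages": [{"function_call": {"name": "get_weather", "arguments": "{\\"city\\": \\"Berlin\\"}"}}]}]}',
]

for message in parse_minimax_stream(chunks):
    print(message)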