fix: resolve issue with cot_agent_runner not analyzing user-uploaded images correctly (#5360)

Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
This commit is contained in:
Xiao Ley 2024-06-18 18:15:41 +08:00 committed by GitHub
parent 4e3d76a1d1
commit 369a395ee9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 23 additions and 9 deletions

View File

@ -61,8 +61,6 @@ class CotAgentRunner(BaseAgentRunner, ABC):
# convert tools into ModelRuntime Tool format
tool_instances, self._prompt_messages_tools = self._init_prompt_tools()
prompt_messages = self._organize_prompt_messages()
function_call_state = True
llm_usage = {
'usage': None

View File

@ -5,6 +5,7 @@ from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
SystemPromptMessage,
TextPromptMessageContent,
UserPromptMessage,
)
from core.model_runtime.utils.encoders import jsonable_encoder
@ -25,6 +26,21 @@ class CotChatAgentRunner(CotAgentRunner):
return SystemPromptMessage(content=system_prompt)
def _organize_user_query(self, query, prompt_messages: list[PromptMessage] = None) -> list[PromptMessage]:
"""
Organize user query
"""
if self.files:
prompt_message_contents = [TextPromptMessageContent(data=query)]
for file_obj in self.files:
prompt_message_contents.append(file_obj.prompt_message_content)
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
else:
prompt_messages.append(UserPromptMessage(content=query))
return prompt_messages
def _organize_prompt_messages(self) -> list[PromptMessage]:
"""
Organize
@ -51,27 +67,27 @@ class CotChatAgentRunner(CotAgentRunner):
assistant_messages = [assistant_message]
# query messages
query_messages = UserPromptMessage(content=self._query)
query_messages = self._organize_user_query(self._query, [])
if assistant_messages:
# organize historic prompt messages
historic_messages = self._organize_historic_prompt_messages([
system_message,
query_messages,
*query_messages,
*assistant_messages,
UserPromptMessage(content='continue')
])
])
messages = [
system_message,
*historic_messages,
query_messages,
*query_messages,
*assistant_messages,
UserPromptMessage(content='continue')
]
else:
# organize historic prompt messages
historic_messages = self._organize_historic_prompt_messages([system_message, query_messages])
messages = [system_message, *historic_messages, query_messages]
historic_messages = self._organize_historic_prompt_messages([system_message, *query_messages])
messages = [system_message, *historic_messages, *query_messages]
# join all messages
return messages
return messages