feat: add o1-series models support in Agent App (ReACT only) (#8350)

This commit is contained in:
takatost 2024-09-13 13:08:27 +08:00 committed by GitHub
parent 8d2269f762
commit 4637ddaa7f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -620,6 +620,9 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
if "stream_options" in extra_model_kwargs:
del extra_model_kwargs["stream_options"]
if "stop" in extra_model_kwargs:
del extra_model_kwargs["stop"]
# chat model
response = client.chat.completions.create(
messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages],
@ -635,7 +638,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
block_result = self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools)
if block_as_stream:
return self._handle_chat_block_as_stream_response(block_result, prompt_messages)
return self._handle_chat_block_as_stream_response(block_result, prompt_messages, stop)
return block_result
@ -643,6 +646,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
self,
block_result: LLMResult,
prompt_messages: list[PromptMessage],
stop: Optional[list[str]] = None,
) -> Generator[LLMResultChunk, None, None]:
"""
Handle llm chat response
@ -652,15 +656,22 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
:param response: response
:param prompt_messages: prompt messages
:param tools: tools for tool calling
:param stop: stop words
:return: llm response chunk generator
"""
text = block_result.message.content
text = cast(str, text)
if stop:
text = self.enforce_stop_tokens(text, stop)
yield LLMResultChunk(
model=block_result.model,
prompt_messages=prompt_messages,
system_fingerprint=block_result.system_fingerprint,
delta=LLMResultChunkDelta(
index=0,
message=block_result.message,
message=AssistantPromptMessage(content=text),
finish_reason="stop",
usage=block_result.usage,
),
@ -912,6 +923,20 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
]
)
if model.startswith("o1"):
system_message_count = len([m for m in prompt_messages if isinstance(m, SystemPromptMessage)])
if system_message_count > 0:
new_prompt_messages = []
for prompt_message in prompt_messages:
if isinstance(prompt_message, SystemPromptMessage):
prompt_message = UserPromptMessage(
content=prompt_message.content,
name=prompt_message.name,
)
new_prompt_messages.append(prompt_message)
prompt_messages = new_prompt_messages
return prompt_messages
def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: