Mirror of https://gitee.com/dify_ai/dify.git, synced 2024-12-02 11:18:19 +08:00

Feat/zhipuai function calling (#2199)
Co-authored-by: Joel <iamjoel007@gmail.com>

parent bdc5e9ceb0
commit b921c55677
@@ -1,61 +0,0 @@
-"""Wrapper around ZhipuAI APIs."""
-from __future__ import annotations
-
-import logging
-import posixpath
-
-from pydantic import BaseModel, Extra
-from zhipuai.model_api.api import InvokeType
-from zhipuai.utils import jwt_token
-from zhipuai.utils.http_client import post, stream
-from zhipuai.utils.sse_client import SSEClient
-
-logger = logging.getLogger(__name__)
-
-
-class ZhipuModelAPI(BaseModel):
-    base_url: str = "https://open.bigmodel.cn/api/paas/v3/model-api"
-    api_key: str
-    api_timeout_seconds = 60
-
-    class Config:
-        """Configuration for this pydantic object."""
-
-        extra = Extra.forbid
-
-    def invoke(self, **kwargs):
-        url = self._build_api_url(kwargs, InvokeType.SYNC)
-        response = post(url, self._generate_token(), kwargs, self.api_timeout_seconds)
-        if not response['success']:
-            raise ValueError(
-                f"Error Code: {response['code']}, Message: {response['msg']} "
-            )
-        return response
-
-    def sse_invoke(self, **kwargs):
-        url = self._build_api_url(kwargs, InvokeType.SSE)
-        data = stream(url, self._generate_token(), kwargs, self.api_timeout_seconds)
-        return SSEClient(data)
-
-    def _build_api_url(self, kwargs, *path):
-        if kwargs:
-            if "model" not in kwargs:
-                raise Exception("model param missed")
-            model = kwargs.pop("model")
-        else:
-            model = "-"
-
-        return posixpath.join(self.base_url, model, *path)
-
-    def _generate_token(self):
-        if not self.api_key:
-            raise Exception(
-                "api_key not provided, you could provide it."
-            )
-
-        try:
-            return jwt_token.generate_token(self.api_key)
-        except Exception:
-            raise ValueError(
-                f"Your api_key is invalid, please check it."
-            )
@@ -3,13 +3,15 @@ from typing import Any, Dict, Generator, List, Optional, Union
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageRole,
-                                                          PromptMessageTool, SystemPromptMessage, UserPromptMessage,
+                                                          PromptMessageTool, SystemPromptMessage, UserPromptMessage, ToolPromptMessage,
                                                           TextPromptMessageContent, ImagePromptMessageContent, PromptMessageContentType)
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.utils import helper
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-from core.model_runtime.model_providers.zhipuai._client import ZhipuModelAPI
 from core.model_runtime.model_providers.zhipuai._common import _CommonZhipuaiAI
+from core.model_runtime.model_providers.zhipuai.zhipuai_sdk._client import ZhipuAI
+from core.model_runtime.model_providers.zhipuai.zhipuai_sdk.types.chat.chat_completion_chunk import ChatCompletionChunk
+from core.model_runtime.model_providers.zhipuai.zhipuai_sdk.types.chat.chat_completion import Completion
 
 
 class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
@@ -35,7 +37,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
         credentials_kwargs = self._to_credential_kwargs(credentials)
 
         # invoke model
-        return self._generate(model, credentials_kwargs, prompt_messages, model_parameters, stop, stream, user)
+        return self._generate(model, credentials_kwargs, prompt_messages, model_parameters, tools, stop, stream, user)
 
     def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                        tools: Optional[list[PromptMessageTool]] = None) -> int:
@@ -48,7 +50,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
         :param tools: tools for tool calling
         :return:
         """
-        prompt = self._convert_messages_to_prompt(prompt_messages)
+        prompt = self._convert_messages_to_prompt(prompt_messages, tools)
 
         return self._get_num_tokens_by_gpt2(prompt)
 
@@ -72,6 +74,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
                 model_parameters={
                     "temperature": 0.5,
                 },
+                tools=[],
                 stream=False
            )
        except Exception as ex:
@@ -79,6 +82,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
 
     def _generate(self, model: str, credentials_kwargs: dict,
                   prompt_messages: list[PromptMessage], model_parameters: dict,
+                  tools: Optional[list[PromptMessageTool]] = None,
                   stop: Optional[List[str]] = None, stream: bool = True,
                   user: Optional[str] = None) -> Union[LLMResult, Generator]:
         """
@@ -97,7 +101,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
         if stop:
             extra_model_kwargs['stop_sequences'] = stop
 
-        client = ZhipuModelAPI(
+        client = ZhipuAI(
            api_key=credentials_kwargs['api_key']
        )
 
@@ -128,11 +132,17 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
                     # not support image message
                     continue
 
-            if new_prompt_messages and new_prompt_messages[-1].role == PromptMessageRole.USER:
+            if new_prompt_messages and new_prompt_messages[-1].role == PromptMessageRole.USER and \
+                    copy_prompt_message.role == PromptMessageRole.USER:
                 new_prompt_messages[-1].content += "\n\n" + copy_prompt_message.content
             else:
                 if copy_prompt_message.role == PromptMessageRole.USER:
                     new_prompt_messages.append(copy_prompt_message)
+                elif copy_prompt_message.role == PromptMessageRole.TOOL:
+                    new_prompt_messages.append(copy_prompt_message)
+                elif copy_prompt_message.role == PromptMessageRole.SYSTEM:
+                    new_prompt_message = SystemPromptMessage(content=copy_prompt_message.content)
+                    new_prompt_messages.append(new_prompt_message)
                 else:
                     new_prompt_message = UserPromptMessage(content=copy_prompt_message.content)
                     new_prompt_messages.append(new_prompt_message)
@@ -145,7 +155,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
         if model == 'glm-4v':
             params = {
                 'model': model,
-                'prompt': [{
+                'messages': [{
                     'role': prompt_message.role.value,
                     'content':
                         [
@@ -171,23 +181,63 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
         else:
             params = {
                 'model': model,
-                'prompt': [{
-                    'role': prompt_message.role.value,
-                    'content': prompt_message.content,
-                } for prompt_message in new_prompt_messages],
+                'messages': [],
                 **model_parameters
             }
+            # glm model
+            if not model.startswith('chatglm'):
+                for prompt_message in new_prompt_messages:
+                    if prompt_message.role == PromptMessageRole.TOOL:
+                        params['messages'].append({
+                            'role': 'tool',
+                            'content': prompt_message.content,
+                            'tool_call_id': prompt_message.tool_call_id
+                        })
+                    else:
+                        params['messages'].append({
+                            'role': prompt_message.role.value,
+                            'content': prompt_message.content
+                        })
+            else:
+                # chatglm model
+                for prompt_message in new_prompt_messages:
+                    # merge system message to user message
+                    if prompt_message.role == PromptMessageRole.SYSTEM or \
+                            prompt_message.role == PromptMessageRole.TOOL or \
+                            prompt_message.role == PromptMessageRole.USER:
+                        if len(params['messages']) > 0 and params['messages'][-1]['role'] == 'user':
+                            params['messages'][-1]['content'] += "\n\n" + prompt_message.content
+                        else:
+                            params['messages'].append({
+                                'role': 'user',
+                                'content': prompt_message.content
+                            })
+                    else:
+                        params['messages'].append({
+                            'role': prompt_message.role.value,
+                            'content': prompt_message.content
+                        })
+
+        if tools and len(tools) > 0:
+            params['tools'] = [
+                {
+                    'type': 'function',
+                    'function': helper.dump_model(tool)
+                } for tool in tools
+            ]
 
         if stream:
-            response = client.sse_invoke(incremental=True, **params).events()
-            return self._handle_generate_stream_response(model, credentials_kwargs, response, prompt_messages)
+            response = client.chat.completions.create(stream=stream, **params)
+            return self._handle_generate_stream_response(model, credentials_kwargs, tools, response, prompt_messages)
 
-        response = client.invoke(**params)
-        return self._handle_generate_response(model, credentials_kwargs, response, prompt_messages)
+        response = client.chat.completions.create(**params)
+        return self._handle_generate_response(model, credentials_kwargs, tools, response, prompt_messages)
 
     def _handle_generate_response(self, model: str,
                                   credentials: dict,
-                                  response: Dict[str, Any],
+                                  tools: Optional[list[PromptMessageTool]],
+                                  response: Completion,
                                   prompt_messages: list[PromptMessage]) -> LLMResult:
         """
         Handle llm response
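For orientation, a minimal sketch of the request body `_generate` assembles when tools are supplied. Every value below is illustrative rather than taken from the commit; the tool dict mirrors the `helper.dump_model(tool)` output for a hypothetical PromptMessageTool.

# Illustrative sketch only: the shape of `params` built above for a glm model
# with one function tool. Model name, tool name and schema are hypothetical.
params = {
    'model': 'glm-4',
    'messages': [
        {'role': 'user', 'content': 'What is the weather in Beijing?'},
    ],
    'tools': [
        {
            'type': 'function',
            'function': {                      # serialized PromptMessageTool
                'name': 'get_weather',
                'description': 'Look up the current weather for a city',
                'parameters': {
                    'type': 'object',
                    'properties': {'city': {'type': 'string'}},
                    'required': ['city'],
                },
            },
        },
    ],
    'temperature': 0.5,                        # spread in from model_parameters
}
response = client.chat.completions.create(**params)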
@@ -197,26 +247,39 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
         :param prompt_messages: prompt messages
         :return: llm response
         """
-        data = response["data"]
         text = ''
-        for res in data["choices"]:
-            text += res['content']
+        assistant_tool_calls: List[AssistantPromptMessage.ToolCall] = []
+        for choice in response.choices:
+            if choice.message.tool_calls:
+                for tool_call in choice.message.tool_calls:
+                    if tool_call.type == 'function':
+                        assistant_tool_calls.append(
+                            AssistantPromptMessage.ToolCall(
+                                id=tool_call.id,
+                                type=tool_call.type,
+                                function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                                    name=tool_call.function.name,
+                                    arguments=tool_call.function.arguments,
+                                )
+                            )
+                        )
+
+            text += choice.message.content or ''
 
-        token_usage = data.get("usage")
-        if token_usage is not None:
-            if 'prompt_tokens' not in token_usage:
-                token_usage['prompt_tokens'] = 0
-            if 'completion_tokens' not in token_usage:
-                token_usage['completion_tokens'] = token_usage['total_tokens']
+        prompt_usage = response.usage.prompt_tokens
+        completion_usage = response.usage.completion_tokens
 
         # transform usage
-        usage = self._calc_response_usage(model, credentials, token_usage['prompt_tokens'], token_usage['completion_tokens'])
+        usage = self._calc_response_usage(model, credentials, prompt_usage, completion_usage)
 
         # transform response
         result = LLMResult(
             model=model,
             prompt_messages=prompt_messages,
-            message=AssistantPromptMessage(content=text),
+            message=AssistantPromptMessage(
+                content=text,
+                tool_calls=assistant_tool_calls
+            ),
             usage=usage,
         )
 
@@ -224,7 +287,8 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
 
     def _handle_generate_stream_response(self, model: str,
                                          credentials: dict,
-                                         responses: list[Generator],
+                                         tools: Optional[list[PromptMessageTool]],
+                                         responses: Generator[ChatCompletionChunk, None, None],
                                          prompt_messages: list[PromptMessage]) -> Generator:
         """
         Handle llm stream response
@@ -234,39 +298,64 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
         :param prompt_messages: prompt messages
         :return: llm response chunk generator result
         """
-        for index, event in enumerate(responses):
-            if event.event == "add":
-                yield LLMResultChunk(
-                    prompt_messages=prompt_messages,
-                    model=model,
-                    delta=LLMResultChunkDelta(
-                        index=index,
-                        message=AssistantPromptMessage(content=event.data)
-                    )
-                )
-            elif event.event == "error" or event.event == "interrupted":
-                raise ValueError(
-                    f"{event.data}"
-                )
-            elif event.event == "finish":
-                meta = json.loads(event.meta)
-                token_usage = meta['usage']
-                if token_usage is not None:
-                    if 'prompt_tokens' not in token_usage:
-                        token_usage['prompt_tokens'] = 0
-                    if 'completion_tokens' not in token_usage:
-                        token_usage['completion_tokens'] = token_usage['total_tokens']
-
-                    usage = self._calc_response_usage(model, credentials, token_usage['prompt_tokens'], token_usage['completion_tokens'])
-
-                yield LLMResultChunk(
-                    model=model,
-                    prompt_messages=prompt_messages,
-                    delta=LLMResultChunkDelta(
-                        index=index,
-                        message=AssistantPromptMessage(content=event.data),
-                        finish_reason='finish',
-                        usage=usage
-                    )
-                )
+        full_assistant_content = ''
+        for chunk in responses:
+            if len(chunk.choices) == 0:
+                continue
+
+            delta = chunk.choices[0]
+
+            if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''):
+                continue
+
+            assistant_tool_calls: List[AssistantPromptMessage.ToolCall] = []
+            for tool_call in delta.delta.tool_calls or []:
+                if tool_call.type == 'function':
+                    assistant_tool_calls.append(
+                        AssistantPromptMessage.ToolCall(
+                            id=tool_call.id,
+                            type=tool_call.type,
+                            function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                                name=tool_call.function.name,
+                                arguments=tool_call.function.arguments,
+                            )
+                        )
+                    )
+
+            # transform assistant message to prompt message
+            assistant_prompt_message = AssistantPromptMessage(
+                content=delta.delta.content if delta.delta.content else '',
+                tool_calls=assistant_tool_calls
+            )
+
+            full_assistant_content += delta.delta.content if delta.delta.content else ''
+
+            if delta.finish_reason is not None and chunk.usage is not None:
+                completion_tokens = chunk.usage.completion_tokens
+                prompt_tokens = chunk.usage.prompt_tokens
+
+                # transform usage
+                usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+
+                yield LLMResultChunk(
+                    model=chunk.model,
+                    prompt_messages=prompt_messages,
+                    system_fingerprint='',
+                    delta=LLMResultChunkDelta(
+                        index=delta.index,
+                        message=assistant_prompt_message,
+                        finish_reason=delta.finish_reason,
+                        usage=usage
+                    )
+                )
+            else:
+                yield LLMResultChunk(
+                    model=chunk.model,
+                    prompt_messages=prompt_messages,
+                    system_fingerprint='',
+                    delta=LLMResultChunkDelta(
+                        index=delta.index,
+                        message=assistant_prompt_message,
+                    )
+                )
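As a usage note, a minimal sketch (variable names assumed, not from the commit) of draining the generator this method yields:

# Illustrative sketch only: consuming the LLMResultChunk generator returned by
# _generate(..., stream=True). Accumulates text and any streamed tool calls.
full_text = ''
tool_calls = []
for chunk in generator:
    message = chunk.delta.message
    full_text += message.content or ''
    tool_calls.extend(message.tool_calls)
    if chunk.delta.finish_reason is not None:
        break  # the final chunk also carries usage when the API reports it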
@@ -291,11 +380,10 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
             raise ValueError(f"Got unknown type {message}")
 
         return message_text
 
-    def _convert_messages_to_prompt(self, messages: List[PromptMessage]) -> str:
-        """
-        Format a list of messages into a full prompt for the Anthropic model
-
-
+    def _convert_messages_to_prompt(self, messages: List[PromptMessage], tools: Optional[list[PromptMessageTool]] = None) -> str:
+        """
         :param messages: List of PromptMessage to combine.
         :return: Combined string with necessary human_prompt and ai_prompt tags.
         """
@@ -306,5 +394,10 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
             for message in messages
         )
 
+        if tools and len(tools) > 0:
+            text += "\n\nTools:"
+            for tool in tools:
+                text += f"\n{tool.json()}"
+
         # trim off the trailing ' ' that might come from the "Assistant: "
         return text.rstrip()
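To illustrate, the string this builds for GPT-2 token counting when a tool is present looks roughly like the following (hypothetical conversation and tool, abbreviated):

Human: What is the weather in Beijing?

Assistant:

Tools:
{"name": "get_weather", "description": "Look up the current weather", "parameters": {...}}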
@@ -5,7 +5,7 @@ from core.model_runtime.entities.model_entities import PriceType
 from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
-from core.model_runtime.model_providers.zhipuai._client import ZhipuModelAPI
+from core.model_runtime.model_providers.zhipuai.zhipuai_sdk._client import ZhipuAI
 from core.model_runtime.model_providers.zhipuai._common import _CommonZhipuaiAI
 from langchain.schema.language_model import _get_token_ids_default_method
 
@@ -28,7 +28,7 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel):
         :return: embeddings result
         """
         credentials_kwargs = self._to_credential_kwargs(credentials)
-        client = ZhipuModelAPI(
+        client = ZhipuAI(
            api_key=credentials_kwargs['api_key']
        )
 
@@ -69,7 +69,7 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel):
         try:
             # transform credentials to kwargs for model instance
             credentials_kwargs = self._to_credential_kwargs(credentials)
-            client = ZhipuModelAPI(
+            client = ZhipuAI(
                api_key=credentials_kwargs['api_key']
            )
 
@@ -82,7 +82,7 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel):
         except Exception as ex:
             raise CredentialsValidateFailedError(str(ex))
 
-    def embed_documents(self, model: str, client: ZhipuModelAPI, texts: List[str]) -> Tuple[List[List[float]], int]:
+    def embed_documents(self, model: str, client: ZhipuAI, texts: List[str]) -> Tuple[List[List[float]], int]:
         """Call out to ZhipuAI's embedding endpoint.
 
         Args:
@@ -91,17 +91,16 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel):
         Returns:
             List of embeddings, one for each text.
         """
 
         embeddings = []
+        embedding_used_tokens = 0
+
         for text in texts:
-            response = client.invoke(model=model, prompt=text)
-            data = response["data"]
-            embeddings.append(data.get('embedding'))
+            response = client.embeddings.create(model=model, input=text)
+            data = response.data[0]
+            embeddings.append(data.embedding)
+            embedding_used_tokens += response.usage.total_tokens
 
-        embedding_used_tokens = data.get('usage')
-
-        return [list(map(float, e)) for e in embeddings], embedding_used_tokens['total_tokens'] if embedding_used_tokens else 0
+        return [list(map(float, e)) for e in embeddings], embedding_used_tokens
 
     def embed_query(self, text: str) -> List[float]:
         """Call out to ZhipuAI's embedding endpoint.
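For context, a minimal sketch of the per-text call `embed_documents` now makes; the model name is a placeholder:

# Illustrative sketch only: one round trip of the new embeddings API.
response = client.embeddings.create(model='text_embedding', input='hello world')
vector = response.data[0].embedding        # list of floats
used = response.usage.total_tokens         # accumulated per text above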
@@ -0,0 +1,17 @@
+
+from ._client import ZhipuAI
+
+from .core._errors import (
+    ZhipuAIError,
+    APIStatusError,
+    APIRequestFailedError,
+    APIAuthenticationError,
+    APIReachLimitError,
+    APIInternalError,
+    APIServerFlowExceedError,
+    APIResponseError,
+    APIResponseValidationError,
+    APITimeoutError,
+)
+
+from .__version__ import __version__
@@ -0,0 +1,2 @@
+
+__version__ = 'v2.0.1'
@@ -0,0 +1,71 @@
+from __future__ import annotations
+
+from typing import Union, Mapping
+
+from typing_extensions import override
+
+from .core import _jwt_token
+from .core._errors import ZhipuAIError
+from .core._http_client import HttpClient, ZHIPUAI_DEFAULT_MAX_RETRIES
+from .core._base_type import NotGiven, NOT_GIVEN
+from . import api_resource
+import os
+import httpx
+from httpx import Timeout
+
+
+class ZhipuAI(HttpClient):
+    chat: api_resource.chat
+    api_key: str
+
+    def __init__(
+            self,
+            *,
+            api_key: str | None = None,
+            base_url: str | httpx.URL | None = None,
+            timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN,
+            max_retries: int = ZHIPUAI_DEFAULT_MAX_RETRIES,
+            http_client: httpx.Client | None = None,
+            custom_headers: Mapping[str, str] | None = None
+    ) -> None:
+        # if api_key is None:
+        #     api_key = os.environ.get("ZHIPUAI_API_KEY")
+        if api_key is None:
+            raise ZhipuAIError("No api_key provided; pass it as an argument or supply it via an environment variable")
+        self.api_key = api_key
+
+        if base_url is None:
+            base_url = os.environ.get("ZHIPUAI_BASE_URL")
+        if base_url is None:
+            base_url = f"https://open.bigmodel.cn/api/paas/v4"
+        from .__version__ import __version__
+        super().__init__(
+            version=__version__,
+            base_url=base_url,
+            timeout=timeout,
+            custom_httpx_client=http_client,
+            custom_headers=custom_headers,
+        )
+        self.chat = api_resource.chat.Chat(self)
+        self.images = api_resource.images.Images(self)
+        self.embeddings = api_resource.embeddings.Embeddings(self)
+        self.files = api_resource.files.Files(self)
+        self.fine_tuning = api_resource.fine_tuning.FineTuning(self)
+
+    @property
+    @override
+    def _auth_headers(self) -> dict[str, str]:
+        api_key = self.api_key
+        return {"Authorization": f"{_jwt_token.generate_token(api_key)}"}
+
+    def __del__(self) -> None:
+        if (not hasattr(self, "_has_custom_http_client")
+                or not hasattr(self, "close")
+                or not hasattr(self, "_client")):
+            # if the '__init__' method raised an error, self would not have client attr
+            return
+
+        if self._has_custom_http_client:
+            return
+
+        self.close()
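A minimal usage sketch of this bundled client; the API key and model name are placeholders:

# Illustrative sketch only: construct the SDK client and issue one completion.
from core.model_runtime.model_providers.zhipuai.zhipuai_sdk._client import ZhipuAI

client = ZhipuAI(api_key='your-api-key')
completion = client.chat.completions.create(
    model='glm-4',
    messages=[{'role': 'user', 'content': 'Hello'}],
)
print(completion.choices[0].message.content)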
@@ -0,0 +1,5 @@
+from .chat import chat
+from .images import Images
+from .embeddings import Embeddings
+from .files import Files
+from .fine_tuning import fine_tuning
@@ -0,0 +1,87 @@
+from __future__ import annotations
+
+from typing import Union, List, Optional, TYPE_CHECKING
+
+import httpx
+from typing_extensions import Literal
+
+from ...core._base_api import BaseAPI
+from ...core._base_type import NotGiven, NOT_GIVEN, Headers
+from ...core._http_client import make_user_request_input
+from ...types.chat.async_chat_completion import AsyncTaskStatus, AsyncCompletion
+
+if TYPE_CHECKING:
+    from ..._client import ZhipuAI
+
+
+class AsyncCompletions(BaseAPI):
+    def __init__(self, client: "ZhipuAI") -> None:
+        super().__init__(client)
+
+    def create(
+            self,
+            *,
+            model: str,
+            request_id: Optional[str] | NotGiven = NOT_GIVEN,
+            do_sample: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
+            temperature: Optional[float] | NotGiven = NOT_GIVEN,
+            top_p: Optional[float] | NotGiven = NOT_GIVEN,
+            max_tokens: int | NotGiven = NOT_GIVEN,
+            seed: int | NotGiven = NOT_GIVEN,
+            messages: Union[str, List[str], List[int], List[List[int]], None],
+            stop: Optional[Union[str, List[str], None]] | NotGiven = NOT_GIVEN,
+            sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN,
+            tools: Optional[object] | NotGiven = NOT_GIVEN,
+            tool_choice: str | NotGiven = NOT_GIVEN,
+            extra_headers: Headers | None = None,
+            disable_strict_validation: Optional[bool] | None = None,
+            timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+    ) -> AsyncTaskStatus:
+        _cast_type = AsyncTaskStatus
+
+        if disable_strict_validation:
+            _cast_type = object
+        return self._post(
+            "/async/chat/completions",
+            body={
+                "model": model,
+                "request_id": request_id,
+                "temperature": temperature,
+                "top_p": top_p,
+                "do_sample": do_sample,
+                "max_tokens": max_tokens,
+                "seed": seed,
+                "messages": messages,
+                "stop": stop,
+                "sensitive_word_check": sensitive_word_check,
+                "tools": tools,
+                "tool_choice": tool_choice,
+            },
+            options=make_user_request_input(
+                extra_headers=extra_headers, timeout=timeout
+            ),
+            cast_type=_cast_type,
+            enable_stream=False,
+        )
+
+    def retrieve_completion_result(
+            self,
+            id: str,
+            extra_headers: Headers | None = None,
+            disable_strict_validation: Optional[bool] | None = None,
+            timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+    ) -> Union[AsyncCompletion, AsyncTaskStatus]:
+        _cast_type = Union[AsyncCompletion, AsyncTaskStatus]
+        if disable_strict_validation:
+            _cast_type = object
+        return self._get(
+            path=f"/async-result/{id}",
+            cast_type=_cast_type,
+            options=make_user_request_input(
+                extra_headers=extra_headers,
+                timeout=timeout
+            )
+        )
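These two methods implement a submit-then-poll flow. A minimal sketch follows; the model name is a placeholder and the task id field name is an assumption, since AsyncTaskStatus is not shown in this diff:

# Illustrative sketch only: submit an async completion, then fetch its result.
task = client.chat.asyncCompletions.create(
    model='glm-4',                                    # placeholder model name
    messages=[{'role': 'user', 'content': 'Hello'}],
)
result = client.chat.asyncCompletions.retrieve_completion_result(id=task.id)  # `.id` assumed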
@@ -0,0 +1,16 @@
+from typing import TYPE_CHECKING
+from .completions import Completions
+from .async_completions import AsyncCompletions
+from ...core._base_api import BaseAPI
+
+if TYPE_CHECKING:
+    from ..._client import ZhipuAI
+
+
+class Chat(BaseAPI):
+    completions: Completions
+
+    def __init__(self, client: "ZhipuAI") -> None:
+        super().__init__(client)
+        self.completions = Completions(client)
+        self.asyncCompletions = AsyncCompletions(client)
@@ -0,0 +1,71 @@
+from __future__ import annotations
+
+from typing import Union, List, Optional, TYPE_CHECKING
+
+import httpx
+from typing_extensions import Literal
+
+from ...core._base_api import BaseAPI
+from ...core._base_type import NotGiven, NOT_GIVEN, Headers
+from ...core._http_client import make_user_request_input
+from ...core._sse_client import StreamResponse
+from ...types.chat.chat_completion import Completion
+from ...types.chat.chat_completion_chunk import ChatCompletionChunk
+
+if TYPE_CHECKING:
+    from ..._client import ZhipuAI
+
+
+class Completions(BaseAPI):
+    def __init__(self, client: "ZhipuAI") -> None:
+        super().__init__(client)
+
+    def create(
+            self,
+            *,
+            model: str,
+            request_id: Optional[str] | NotGiven = NOT_GIVEN,
+            do_sample: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
+            stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
+            temperature: Optional[float] | NotGiven = NOT_GIVEN,
+            top_p: Optional[float] | NotGiven = NOT_GIVEN,
+            max_tokens: int | NotGiven = NOT_GIVEN,
+            seed: int | NotGiven = NOT_GIVEN,
+            messages: Union[str, List[str], List[int], object, None],
+            stop: Optional[Union[str, List[str], None]] | NotGiven = NOT_GIVEN,
+            sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN,
+            tools: Optional[object] | NotGiven = NOT_GIVEN,
+            tool_choice: str | NotGiven = NOT_GIVEN,
+            extra_headers: Headers | None = None,
+            disable_strict_validation: Optional[bool] | None = None,
+            timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+    ) -> Completion | StreamResponse[ChatCompletionChunk]:
+        _cast_type = Completion
+        _stream_cls = StreamResponse[ChatCompletionChunk]
+        if disable_strict_validation:
+            _cast_type = object
+            _stream_cls = StreamResponse[object]
+        return self._post(
+            "/chat/completions",
+            body={
+                "model": model,
+                "request_id": request_id,
+                "temperature": temperature,
+                "top_p": top_p,
+                "do_sample": do_sample,
+                "max_tokens": max_tokens,
+                "seed": seed,
+                "messages": messages,
+                "stop": stop,
+                "sensitive_word_check": sensitive_word_check,
+                "stream": stream,
+                "tools": tools,
+                "tool_choice": tool_choice,
+            },
+            options=make_user_request_input(
+                extra_headers=extra_headers,
+            ),
+            cast_type=_cast_type,
+            enable_stream=stream or False,
+            stream_cls=_stream_cls,
+        )
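The return type depends on `stream`. A minimal sketch of both call shapes, with placeholder model and messages:

# Illustrative sketch only: non-streaming returns a Completion; streaming
# returns a StreamResponse that yields ChatCompletionChunk objects.
completion = client.chat.completions.create(model='glm-4', messages=messages)

for chunk in client.chat.completions.create(model='glm-4', messages=messages, stream=True):
    delta = chunk.choices[0].delta
    print(delta.content or '', end='')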
@@ -0,0 +1,49 @@
+from __future__ import annotations
+
+from typing import Union, List, Optional, TYPE_CHECKING
+
+import httpx
+
+from ..core._base_api import BaseAPI
+from ..core._base_type import NotGiven, NOT_GIVEN, Headers
+from ..core._http_client import make_user_request_input
+from ..types.embeddings import EmbeddingsResponded
+
+if TYPE_CHECKING:
+    from .._client import ZhipuAI
+
+
+class Embeddings(BaseAPI):
+    def __init__(self, client: "ZhipuAI") -> None:
+        super().__init__(client)
+
+    def create(
+            self,
+            *,
+            input: Union[str, List[str], List[int], List[List[int]]],
+            model: Union[str],
+            encoding_format: str | NotGiven = NOT_GIVEN,
+            user: str | NotGiven = NOT_GIVEN,
+            sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN,
+            extra_headers: Headers | None = None,
+            disable_strict_validation: Optional[bool] | None = None,
+            timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+    ) -> EmbeddingsResponded:
+        _cast_type = EmbeddingsResponded
+        if disable_strict_validation:
+            _cast_type = object
+        return self._post(
+            "/embeddings",
+            body={
+                "input": input,
+                "model": model,
+                "encoding_format": encoding_format,
+                "user": user,
+                "sensitive_word_check": sensitive_word_check,
+            },
+            options=make_user_request_input(
+                extra_headers=extra_headers, timeout=timeout
+            ),
+            cast_type=_cast_type,
+            enable_stream=False,
+        )
@@ -0,0 +1,78 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import httpx
+
+from ..core._base_api import BaseAPI
+from ..core._base_type import NOT_GIVEN, Body, Query, Headers, NotGiven, FileTypes
+from ..core._files import is_file_content
+from ..core._http_client import (
+    make_user_request_input,
+)
+from ..types.file_object import FileObject, ListOfFileObject
+
+if TYPE_CHECKING:
+    from .._client import ZhipuAI
+
+__all__ = ["Files"]
+
+
+class Files(BaseAPI):
+
+    def __init__(self, client: "ZhipuAI") -> None:
+        super().__init__(client)
+
+    def create(
+            self,
+            *,
+            file: FileTypes,
+            purpose: str,
+            extra_headers: Headers | None = None,
+            timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+    ) -> FileObject:
+        if not is_file_content(file):
+            prefix = f"Expected file input `{file!r}`"
+            raise RuntimeError(
+                f"{prefix} to be bytes, an io.IOBase instance, PathLike or a tuple but received {type(file)} instead."
+            ) from None
+        files = [("file", file)]
+
+        extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})}
+
+        return self._post(
+            "/files",
+            body={
+                "purpose": purpose,
+            },
+            files=files,
+            options=make_user_request_input(
+                extra_headers=extra_headers, timeout=timeout
+            ),
+            cast_type=FileObject,
+        )
+
+    def list(
+            self,
+            *,
+            purpose: str | NotGiven = NOT_GIVEN,
+            limit: int | NotGiven = NOT_GIVEN,
+            after: str | NotGiven = NOT_GIVEN,
+            order: str | NotGiven = NOT_GIVEN,
+            extra_headers: Headers | None = None,
+            timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+    ) -> ListOfFileObject:
+        return self._get(
+            "/files",
+            cast_type=ListOfFileObject,
+            options=make_user_request_input(
+                extra_headers=extra_headers,
+                timeout=timeout,
+                query={
+                    "purpose": purpose,
+                    "limit": limit,
+                    "after": after,
+                    "order": order,
+                },
+            ),
+        )
@@ -0,0 +1,15 @@
+from typing import TYPE_CHECKING
+from .jobs import Jobs
+from ...core._base_api import BaseAPI
+
+if TYPE_CHECKING:
+    from ..._client import ZhipuAI
+
+
+class FineTuning(BaseAPI):
+    jobs: Jobs
+
+    def __init__(self, client: "ZhipuAI") -> None:
+        super().__init__(client)
+        self.jobs = Jobs(client)
@@ -0,0 +1,115 @@
+from __future__ import annotations
+
+from typing import Optional, TYPE_CHECKING
+
+import httpx
+
+from ...core._base_api import BaseAPI
+from ...core._base_type import NOT_GIVEN, Headers, NotGiven
+from ...core._http_client import (
+    make_user_request_input,
+)
+from ...types.fine_tuning import (
+    FineTuningJob,
+    job_create_params,
+    ListOfFineTuningJob,
+    FineTuningJobEvent,
+)
+
+if TYPE_CHECKING:
+    from ..._client import ZhipuAI
+
+__all__ = ["Jobs"]
+
+
+class Jobs(BaseAPI):
+
+    def __init__(self, client: "ZhipuAI") -> None:
+        super().__init__(client)
+
+    def create(
+            self,
+            *,
+            model: str,
+            training_file: str,
+            hyperparameters: job_create_params.Hyperparameters | NotGiven = NOT_GIVEN,
+            suffix: Optional[str] | NotGiven = NOT_GIVEN,
+            request_id: Optional[str] | NotGiven = NOT_GIVEN,
+            validation_file: Optional[str] | NotGiven = NOT_GIVEN,
+            extra_headers: Headers | None = None,
+            timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+    ) -> FineTuningJob:
+        return self._post(
+            "/fine_tuning/jobs",
+            body={
+                "model": model,
+                "training_file": training_file,
+                "hyperparameters": hyperparameters,
+                "suffix": suffix,
+                "validation_file": validation_file,
+                "request_id": request_id,
+            },
+            options=make_user_request_input(
+                extra_headers=extra_headers, timeout=timeout
+            ),
+            cast_type=FineTuningJob,
+        )
+
+    def retrieve(
+            self,
+            fine_tuning_job_id: str,
+            *,
+            extra_headers: Headers | None = None,
+            timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+    ) -> FineTuningJob:
+        return self._get(
+            f"/fine_tuning/jobs/{fine_tuning_job_id}",
+            options=make_user_request_input(
+                extra_headers=extra_headers, timeout=timeout
+            ),
+            cast_type=FineTuningJob,
+        )
+
+    def list(
+            self,
+            *,
+            after: str | NotGiven = NOT_GIVEN,
+            limit: int | NotGiven = NOT_GIVEN,
+            extra_headers: Headers | None = None,
+            timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+    ) -> ListOfFineTuningJob:
+        return self._get(
+            "/fine_tuning/jobs",
+            cast_type=ListOfFineTuningJob,
+            options=make_user_request_input(
+                extra_headers=extra_headers,
+                timeout=timeout,
+                query={
+                    "after": after,
+                    "limit": limit,
+                },
+            ),
+        )
+
+    def list_events(
+            self,
+            fine_tuning_job_id: str,
+            *,
+            after: str | NotGiven = NOT_GIVEN,
+            limit: int | NotGiven = NOT_GIVEN,
+            extra_headers: Headers | None = None,
+            timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+    ) -> FineTuningJobEvent:
+
+        return self._get(
+            f"/fine_tuning/jobs/{fine_tuning_job_id}/events",
+            cast_type=FineTuningJobEvent,
+            options=make_user_request_input(
+                extra_headers=extra_headers,
+                timeout=timeout,
+                query={
+                    "after": after,
+                    "limit": limit,
+                },
+            ),
+        )
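A minimal sketch of driving this resource end to end; the model, file id, and the `.id` field on FineTuningJob are placeholders or assumptions:

# Illustrative sketch only: create a fine-tuning job, re-fetch it, list events.
job = client.fine_tuning.jobs.create(model='chatglm3-6b', training_file='file-abc123')
job = client.fine_tuning.jobs.retrieve(job.id)
events = client.fine_tuning.jobs.list_events(job.id, limit=10)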
@@ -0,0 +1,55 @@
+from __future__ import annotations
+
+from typing import Union, List, Optional, TYPE_CHECKING
+
+import httpx
+
+from ..core._base_api import BaseAPI
+from ..core._base_type import NotGiven, NOT_GIVEN, Headers
+from ..core._http_client import make_user_request_input
+from ..types.image import ImagesResponded
+
+if TYPE_CHECKING:
+    from .._client import ZhipuAI
+
+
+class Images(BaseAPI):
+    def __init__(self, client: "ZhipuAI") -> None:
+        super().__init__(client)
+
+    def generations(
+            self,
+            *,
+            prompt: str,
+            model: str | NotGiven = NOT_GIVEN,
+            n: Optional[int] | NotGiven = NOT_GIVEN,
+            quality: Optional[str] | NotGiven = NOT_GIVEN,
+            response_format: Optional[str] | NotGiven = NOT_GIVEN,
+            size: Optional[str] | NotGiven = NOT_GIVEN,
+            style: Optional[str] | NotGiven = NOT_GIVEN,
+            user: str | NotGiven = NOT_GIVEN,
+            extra_headers: Headers | None = None,
+            disable_strict_validation: Optional[bool] | None = None,
+            timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+    ) -> ImagesResponded:
+        _cast_type = ImagesResponded
+        if disable_strict_validation:
+            _cast_type = object
+        return self._post(
+            "/images/generations",
+            body={
+                "prompt": prompt,
+                "model": model,
+                "n": n,
+                "quality": quality,
+                "response_format": response_format,
+                "size": size,
+                "style": style,
+                "user": user,
+            },
+            options=make_user_request_input(
+                extra_headers=extra_headers, timeout=timeout
+            ),
+            cast_type=_cast_type,
+            enable_stream=False,
+        )
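A minimal call sketch; the model name is a placeholder, and the shape of the response comes from ImagesResponded, which is outside this diff:

# Illustrative sketch only: request one generated image.
resp = client.images.generations(model='cogview-3', prompt='a lighthouse at dawn')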
@@ -0,0 +1,17 @@
+from __future__ import annotations
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from .._client import ZhipuAI
+
+
+class BaseAPI:
+    _client: ZhipuAI
+
+    def __init__(self, client: ZhipuAI) -> None:
+        self._client = client
+        self._delete = client.delete
+        self._get = client.get
+        self._post = client.post
+        self._put = client.put
+        self._patch = client.patch
@@ -0,0 +1,115 @@
+from __future__ import annotations
+
+from os import PathLike
+from typing import (
+    TYPE_CHECKING,
+    Type,
+    Union,
+    Mapping,
+    TypeVar, IO, Tuple, Sequence, Any, List,
+)
+
+import pydantic
+from typing_extensions import (
+    Literal,
+    override,
+)
+
+
+Query = Mapping[str, object]
+Body = object
+AnyMapping = Mapping[str, object]
+PrimitiveData = Union[str, int, float, bool, None]
+Data = Union[PrimitiveData, List[Any], Tuple[Any], "Mapping[str, Any]"]
+ModelT = TypeVar("ModelT", bound=pydantic.BaseModel)
+_T = TypeVar("_T")
+
+if TYPE_CHECKING:
+    NoneType: Type[None]
+else:
+    NoneType = type(None)
+
+
+# Sentinel class used until PEP 0661 is accepted
+class NotGiven(pydantic.BaseModel):
+    """
+    A sentinel singleton class used to distinguish omitted keyword arguments
+    from those passed in with the value None (which may have different behavior).
+
+    For example:
+
+    ```py
+    def get(timeout: Union[int, NotGiven, None] = NotGiven()) -> Response: ...
+
+    get(timeout=1)  # 1s timeout
+    get(timeout=None)  # No timeout
+    get()  # Default timeout behavior, which may not be statically known at the method definition.
+    ```
+    """
+
+    def __bool__(self) -> Literal[False]:
+        return False
+
+    @override
+    def __repr__(self) -> str:
+        return "NOT_GIVEN"
+
+
+NotGivenOr = Union[_T, NotGiven]
+NOT_GIVEN = NotGiven()
+
+
+class Omit(pydantic.BaseModel):
+    """In certain situations you need to be able to represent a case where a default value has
+    to be explicitly removed and `None` is not an appropriate substitute, for example:
+
+    ```py
+    # as the default `Content-Type` header is `application/json` that will be sent
+    client.post('/upload/files', files={'file': b'my raw file content'})
+
+    # you can't explicitly override the header as it has to be dynamically generated
+    # to look something like: 'multipart/form-data; boundary=0d8382fcf5f8c3be01ca2e11002d2983'
+    client.post(..., headers={'Content-Type': 'multipart/form-data'})
+
+    # instead you can remove the default `application/json` header by passing Omit
+    client.post(..., headers={'Content-Type': Omit()})
+    ```
+    """
+
+    def __bool__(self) -> Literal[False]:
+        return False
+
+
+Headers = Mapping[str, Union[str, Omit]]
+
+ResponseT = TypeVar(
+    "ResponseT",
+    bound="Union[str, None, BaseModel, List[Any], Dict[str, Any], Response, UnknownResponse, ModelBuilderProtocol, BinaryResponseContent]",
+)
+
+# for user input files
+if TYPE_CHECKING:
+    FileContent = Union[IO[bytes], bytes, PathLike[str]]
+else:
+    FileContent = Union[IO[bytes], bytes, PathLike]
+
+FileTypes = Union[
+    FileContent,  # file content
+    Tuple[str, FileContent],  # (filename, file)
+    Tuple[str, FileContent, str],  # (filename, file , content_type)
+    Tuple[str, FileContent, str, Mapping[str, str]],  # (filename, file , content_type, headers)
+]
+
+RequestFiles = Union[Mapping[str, FileTypes], Sequence[Tuple[str, FileTypes]]]
+
+# for httpx client supported files
+
+HttpxFileContent = Union[bytes, IO[bytes]]
+HttpxFileTypes = Union[
+    FileContent,  # file content
+    Tuple[str, HttpxFileContent],  # (filename, file)
+    Tuple[str, HttpxFileContent, str],  # (filename, file , content_type)
+    Tuple[str, HttpxFileContent, str, Mapping[str, str]],  # (filename, file , content_type, headers)
+]
+
+HttpxRequestFiles = Union[Mapping[str, HttpxFileTypes], Sequence[Tuple[str, HttpxFileTypes]]]
@@ -0,0 +1,90 @@
+from __future__ import annotations
+
+import httpx
+
+__all__ = [
+    "ZhipuAIError",
+    "APIStatusError",
+    "APIRequestFailedError",
+    "APIAuthenticationError",
+    "APIReachLimitError",
+    "APIInternalError",
+    "APIServerFlowExceedError",
+    "APIResponseError",
+    "APIResponseValidationError",
+    "APITimeoutError",
+]
+
+
+class ZhipuAIError(Exception):
+    def __init__(self, message: str, ) -> None:
+        super().__init__(message)
+
+
+class APIStatusError(Exception):
+    response: httpx.Response
+    status_code: int
+
+    def __init__(self, message: str, *, response: httpx.Response) -> None:
+        super().__init__(message)
+        self.response = response
+        self.status_code = response.status_code
+
+
+class APIRequestFailedError(APIStatusError):
+    ...
+
+
+class APIAuthenticationError(APIStatusError):
+    ...
+
+
+class APIReachLimitError(APIStatusError):
+    ...
+
+
+class APIInternalError(APIStatusError):
+    ...
+
+
+class APIServerFlowExceedError(APIStatusError):
+    ...
+
+
+class APIResponseError(Exception):
+    message: str
+    request: httpx.Request
+    json_data: object
+
+    def __init__(self, message: str, request: httpx.Request, json_data: object):
+        self.message = message
+        self.request = request
+        self.json_data = json_data
+        super().__init__(message)
+
+
+class APIResponseValidationError(APIResponseError):
+    status_code: int
+    response: httpx.Response
+
+    def __init__(
+            self,
+            response: httpx.Response,
+            json_data: object | None, *,
+            message: str | None = None
+    ) -> None:
+        super().__init__(
+            message=message or "Data returned by API invalid for expected schema.",
+            request=response.request,
+            json_data=json_data
+        )
+        self.response = response
+        self.status_code = response.status_code
+
+
+class APITimeoutError(Exception):
+    request: httpx.Request
+
+    def __init__(self, request: httpx.Request):
+        self.request = request
+        super().__init__("Request Timeout")
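A minimal sketch of catching this hierarchy around a call; which subclass is raised for which HTTP status is decided elsewhere in the SDK, so the pairings in the comments are assumptions:

# Illustrative sketch only: the APIStatusError subclasses all carry the httpx
# response and its status code.
try:
    completion = client.chat.completions.create(model='glm-4', messages=messages)
except APIAuthenticationError:
    ...                                    # invalid or expired key (assumed)
except APIStatusError as e:                # any other non-2xx response
    print(e.status_code, e.response.text)
except APITimeoutError:
    ...                                    # request timed out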
@@ -0,0 +1,46 @@
+from __future__ import annotations
+
+import io
+import os
+from pathlib import Path
+from typing import Mapping, Sequence
+
+from ._base_type import (
+    FileTypes,
+    HttpxFileTypes,
+    HttpxRequestFiles,
+    RequestFiles,
+)
+
+
+def is_file_content(obj: object) -> bool:
+    return isinstance(obj, (bytes, tuple, io.IOBase, os.PathLike))
+
+
+def _transform_file(file: FileTypes) -> HttpxFileTypes:
+    if is_file_content(file):
+        if isinstance(file, os.PathLike):
+            path = Path(file)
+            return path.name, path.read_bytes()
+        else:
+            return file
+    if isinstance(file, tuple):
+        if isinstance(file[1], os.PathLike):
+            return (file[0], Path(file[1]).read_bytes(), *file[2:])
+        else:
+            return (file[0], file[1], *file[2:])
+    else:
+        raise TypeError(f"Unexpected input file with type {type(file)},Expected FileContent type or tuple type")
+
+
+def make_httpx_files(files: RequestFiles | None) -> HttpxRequestFiles | None:
+    if files is None:
+        return None
+
+    if isinstance(files, Mapping):
+        files = {key: _transform_file(file) for key, file in files.items()}
+    elif isinstance(files, Sequence):
+        files = [(key, _transform_file(file)) for key, file in files]
+    else:
+        raise TypeError(f"Unexpected input file with type {type(files)}, excepted Mapping or Sequence")
+    return files
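A small sketch of what make_httpx_files normalizes; the path value is illustrative:

# Illustrative sketch only: PathLike entries are read into (name, bytes) pairs
# that httpx accepts directly as its `files=` argument.
from pathlib import Path

httpx_files = make_httpx_files({'file': Path('train.jsonl')})
# -> {'file': ('train.jsonl', b'<file bytes>')}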
@@ -0,0 +1,377 @@
# -*- coding:utf-8 -*-
from __future__ import annotations

import inspect
from typing import (
    Any,
    Type,
    Union,
    cast,
    Mapping,
)

import httpx
import pydantic
from httpx import URL, Timeout

from . import _errors
from ._base_type import NotGiven, ResponseT, Body, Headers, NOT_GIVEN, RequestFiles, Query, Data
from ._errors import APIResponseValidationError, APIStatusError, APITimeoutError
from ._files import make_httpx_files
from ._request_opt import ClientRequestParam, UserRequestInput
from ._response import HttpResponse
from ._sse_client import StreamResponse
from ._utils import flatten

headers = {
    "Accept": "application/json",
    "Content-Type": "application/json; charset=UTF-8",
}


def _merge_map(map1: Mapping, map2: Mapping) -> Mapping:
    # later keys win; None values are dropped so a caller can use None to unset a header
    merged = {**map1, **map2}
    return {key: val for key, val in merged.items() if val is not None}


from httpx._config import DEFAULT_TIMEOUT_CONFIG as HTTPX_DEFAULT_TIMEOUT

ZHIPUAI_DEFAULT_TIMEOUT = httpx.Timeout(timeout=300.0, connect=8.0)
ZHIPUAI_DEFAULT_MAX_RETRIES = 3
ZHIPUAI_DEFAULT_LIMITS = httpx.Limits(max_connections=50, max_keepalive_connections=10)


class HttpClient:
    _client: httpx.Client
    _version: str
    _base_url: URL

    timeout: Union[float, Timeout, None]
    _limits: httpx.Limits
    _has_custom_http_client: bool
    _default_stream_cls: type[StreamResponse[Any]] | None = None

    def __init__(
        self,
        *,
        version: str,
        base_url: URL,
        timeout: Union[float, Timeout, None, NotGiven],
        custom_httpx_client: httpx.Client | None = None,
        custom_headers: Mapping[str, str] | None = None,
    ) -> None:
        if timeout is None or isinstance(timeout, NotGiven):
            if custom_httpx_client and custom_httpx_client.timeout != HTTPX_DEFAULT_TIMEOUT:
                timeout = custom_httpx_client.timeout
            else:
                timeout = ZHIPUAI_DEFAULT_TIMEOUT
        self.timeout = cast(Timeout, timeout)
        self._has_custom_http_client = bool(custom_httpx_client)
        self._client = custom_httpx_client or httpx.Client(
            base_url=base_url,
            timeout=self.timeout,
            limits=ZHIPUAI_DEFAULT_LIMITS,
        )
        self._version = version
        url = URL(url=base_url)
        # ensure the base URL ends with "/" so relative paths join correctly
        if not url.raw_path.endswith(b"/"):
            url = url.copy_with(raw_path=url.raw_path + b"/")
        self._base_url = url
        self._custom_headers = custom_headers or {}

    def _prepare_url(self, url: str) -> URL:
        sub_url = URL(url)
        if sub_url.is_relative_url:
            request_raw_url = self._base_url.raw_path + sub_url.raw_path.lstrip(b"/")
            return self._base_url.copy_with(raw_path=request_raw_url)

        return sub_url

    @property
    def _default_headers(self):
        return {
            "Accept": "application/json",
            "Content-Type": "application/json; charset=UTF-8",
            "ZhipuAI-SDK-Ver": self._version,
            "source_type": "zhipu-sdk-python",
            "x-request-sdk": "zhipu-sdk-python",
            **self._auth_headers,
            **self._custom_headers,
        }

    @property
    def _auth_headers(self):
        return {}

    def _prepare_headers(self, request_param: ClientRequestParam) -> httpx.Headers:
        custom_headers = request_param.headers or {}
        headers_dict = _merge_map(self._default_headers, custom_headers)

        httpx_headers = httpx.Headers(headers_dict)

        return httpx_headers

    def _prepare_request(
        self,
        request_param: ClientRequestParam
    ) -> httpx.Request:
        kwargs: dict[str, Any] = {}
        headers = self._prepare_headers(request_param)
        url = self._prepare_url(request_param.url)
        json_data = request_param.json_data
        if headers.get("Content-Type") == "multipart/form-data":
            # let httpx set the multipart boundary itself
            headers.pop("Content-Type")

            if json_data:
                kwargs["data"] = self._make_multipartform(json_data)

        return self._client.build_request(
            headers=headers,
            timeout=self.timeout if isinstance(request_param.timeout, NotGiven) else request_param.timeout,
            method=request_param.method,
            url=url,
            json=json_data,
            files=request_param.files,
            params=request_param.params,
            **kwargs,
        )

    def _object_to_formdata(self, key: str, value: Data | Mapping[object, object]) -> list[tuple[str, str]]:
        items = []

        if isinstance(value, Mapping):
            for k, v in value.items():
                items.extend(self._object_to_formdata(f"{key}[{k}]", v))
            return items
        if isinstance(value, (list, tuple)):
            for v in value:
                items.extend(self._object_to_formdata(key + "[]", v))
            return items

        def _primitive_value_to_str(val) -> str:
            # copied from httpx
            if val is True:
                return "true"
            elif val is False:
                return "false"
            elif val is None:
                return ""
            return str(val)

        str_data = _primitive_value_to_str(value)

        if not str_data:
            return []
        return [(key, str_data)]

    def _make_multipartform(self, data: Mapping[object, object]) -> dict[str, object]:
        items = flatten([self._object_to_formdata(k, v) for k, v in data.items()])

        serialized: dict[str, object] = {}
        for key, value in items:
            if key in serialized:
                raise ValueError(f"Duplicate key: {key};")
            serialized[key] = value
        return serialized

    def _parse_response(
        self,
        *,
        cast_type: Type[ResponseT],
        response: httpx.Response,
        enable_stream: bool,
        request_param: ClientRequestParam,
        stream_cls: type[StreamResponse[Any]] | None = None,
    ) -> HttpResponse:
        http_response = HttpResponse(
            raw_response=response,
            cast_type=cast_type,
            client=self,
            enable_stream=enable_stream,
            stream_cls=stream_cls
        )
        return http_response.parse()

    def _process_response_data(
        self,
        *,
        data: object,
        cast_type: type[ResponseT],
        response: httpx.Response,
    ) -> ResponseT:
        if data is None:
            return cast(ResponseT, None)

        try:
            if inspect.isclass(cast_type) and issubclass(cast_type, pydantic.BaseModel):
                return cast(ResponseT, cast_type.validate(data))

            return cast(ResponseT, pydantic.TypeAdapter(cast_type).validate_python(data))
        except pydantic.ValidationError as err:
            raise APIResponseValidationError(response=response, json_data=data) from err

    def is_closed(self) -> bool:
        return self._client.is_closed

    def close(self):
        self._client.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def request(
        self,
        *,
        cast_type: Type[ResponseT],
        params: ClientRequestParam,
        enable_stream: bool = False,
        stream_cls: type[StreamResponse[Any]] | None = None,
    ) -> ResponseT | StreamResponse:
        request = self._prepare_request(params)

        try:
            response = self._client.send(
                request,
                stream=enable_stream,
            )
            response.raise_for_status()
        except httpx.TimeoutException as err:
            raise APITimeoutError(request=request) from err
        except httpx.HTTPStatusError as err:
            err.response.read()
            raise self._make_status_error(err.response) from None

        except Exception as err:
            raise err

        return self._parse_response(
            cast_type=cast_type,
            request_param=params,
            response=response,
            enable_stream=enable_stream,
            stream_cls=stream_cls,
        )

    def get(
        self,
        path: str,
        *,
        cast_type: Type[ResponseT],
        options: UserRequestInput = {},
        enable_stream: bool = False,
    ) -> ResponseT | StreamResponse:
        opts = ClientRequestParam.construct(method="get", url=path, **options)
        return self.request(
            cast_type=cast_type, params=opts,
            enable_stream=enable_stream
        )

    def post(
        self,
        path: str,
        *,
        body: Body | None = None,
        cast_type: Type[ResponseT],
        options: UserRequestInput = {},
        files: RequestFiles | None = None,
        enable_stream: bool = False,
        stream_cls: type[StreamResponse[Any]] | None = None,
    ) -> ResponseT | StreamResponse:
        opts = ClientRequestParam.construct(method="post", json_data=body, files=make_httpx_files(files), url=path,
                                            **options)

        return self.request(
            cast_type=cast_type, params=opts,
            enable_stream=enable_stream,
            stream_cls=stream_cls
        )

    def patch(
        self,
        path: str,
        *,
        body: Body | None = None,
        cast_type: Type[ResponseT],
        options: UserRequestInput = {},
    ) -> ResponseT:
        opts = ClientRequestParam.construct(method="patch", url=path, json_data=body, **options)

        return self.request(
            cast_type=cast_type, params=opts,
        )

    def put(
        self,
        path: str,
        *,
        body: Body | None = None,
        cast_type: Type[ResponseT],
        options: UserRequestInput = {},
        files: RequestFiles | None = None,
    ) -> ResponseT | StreamResponse:
        opts = ClientRequestParam.construct(method="put", url=path, json_data=body, files=make_httpx_files(files),
                                            **options)

        return self.request(
            cast_type=cast_type, params=opts,
        )

    def delete(
        self,
        path: str,
        *,
        body: Body | None = None,
        cast_type: Type[ResponseT],
        options: UserRequestInput = {},
    ) -> ResponseT | StreamResponse:
        opts = ClientRequestParam.construct(method="delete", url=path, json_data=body, **options)

        return self.request(
            cast_type=cast_type, params=opts,
        )

    def _make_status_error(self, response) -> APIStatusError:
        response_text = response.text.strip()
        status_code = response.status_code
        error_msg = f"Error code: {status_code}, with error text {response_text}"

        if status_code == 400:
            return _errors.APIRequestFailedError(message=error_msg, response=response)
        elif status_code == 401:
            return _errors.APIAuthenticationError(message=error_msg, response=response)
        elif status_code == 429:
            return _errors.APIReachLimitError(message=error_msg, response=response)
        elif status_code == 500:
            return _errors.APIInternalError(message=error_msg, response=response)
        elif status_code == 503:
            return _errors.APIServerFlowExceedError(message=error_msg, response=response)
        return APIStatusError(message=error_msg, response=response)


def make_user_request_input(
    max_retries: int | None = None,
    timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
    extra_headers: Headers = None,
    query: Query | None = None,
) -> UserRequestInput:
    options: UserRequestInput = {}

    if extra_headers is not None:
        options["headers"] = extra_headers
    if max_retries is not None:
        options["max_retries"] = max_retries
    if not isinstance(timeout, NotGiven):
        options['timeout'] = timeout
    if query is not None:
        options["params"] = query

    return options
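A request flows construct params, build an httpx.Request, send, then parse. A sketch of how a caller might exercise the client; the base URL is an assumption, and since _auth_headers is empty in this base class, a concrete subclass is expected to supply the Authorization header:

client = HttpClient(
    version="v4",  # sent as the ZhipuAI-SDK-Ver header
    base_url=URL("https://open.bigmodel.cn/api/paas/v4/"),  # assumed endpoint
    timeout=ZHIPUAI_DEFAULT_TIMEOUT,
)
opts = make_user_request_input(max_retries=2, timeout=60.0)
result = client.post(
    "chat/completions",
    body={"model": "glm-4", "messages": [{"role": "user", "content": "hi"}]},
    cast_type=dict,  # any pydantic-validatable target works here
    options=opts,
)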
@@ -0,0 +1,30 @@
# -*- coding:utf-8 -*-
import time

import cachetools.func
import jwt

API_TOKEN_TTL_SECONDS = 3 * 60

# cache tokens for slightly less than their lifetime so a cached token
# is never handed out after it has expired
CACHE_TTL_SECONDS = API_TOKEN_TTL_SECONDS - 30


@cachetools.func.ttl_cache(maxsize=10, ttl=CACHE_TTL_SECONDS)
def generate_token(apikey: str):
    # a ZhipuAI api key has the form "<key id>.<secret>"
    try:
        api_key, secret = apikey.split(".")
    except Exception as e:
        raise Exception("invalid api_key", e)

    # exp and timestamp are in milliseconds, per ZhipuAI's token convention
    payload = {
        "api_key": api_key,
        "exp": int(round(time.time() * 1000)) + API_TOKEN_TTL_SECONDS * 1000,
        "timestamp": int(round(time.time() * 1000)),
    }
    ret = jwt.encode(
        payload,
        secret,
        algorithm="HS256",
        headers={"alg": "HS256", "sign_type": "SIGN"},
    )
    return ret
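The token shape can be checked by decoding without verification; the credential below is illustrative:

token = generate_token("my_key_id.my_key_secret")  # hypothetical api key
claims = jwt.decode(token, options={"verify_signature": False})
# unlike a standard JWT exp (seconds), these claims are in milliseconds
assert claims["api_key"] == "my_key_id"
assert claims["exp"] - claims["timestamp"] == API_TOKEN_TTL_SECONDS * 1000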
@@ -0,0 +1,54 @@
from __future__ import annotations

from typing import Union, Any

from httpx import Timeout
from pydantic import ConfigDict
from typing_extensions import (
    Unpack, ClassVar, TypedDict
)

from ._base_type import Body, NotGiven, Headers, HttpxRequestFiles, Query
from ._utils import remove_notgiven_indict


class UserRequestInput(TypedDict, total=False):
    max_retries: int
    timeout: float | Timeout | None
    headers: Headers
    params: Query | None


class ClientRequestParam:
    method: str
    url: str
    max_retries: Union[int, NotGiven] = NotGiven()
    timeout: Union[float, NotGiven] = NotGiven()
    headers: Union[Headers, NotGiven] = NotGiven()
    json_data: Union[Body, None] = None
    files: Union[HttpxRequestFiles, None] = None
    params: Query = {}
    model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True)

    def get_max_retries(self, max_retries) -> int:
        if isinstance(self.max_retries, NotGiven):
            return max_retries
        return self.max_retries

    @classmethod
    def construct(  # type: ignore
        cls,
        _fields_set: set[str] | None = None,
        **values: Unpack[UserRequestInput],
    ) -> ClientRequestParam:
        kwargs: dict[str, Any] = {
            key: remove_notgiven_indict(value) for key, value in values.items()
        }
        client = cls()
        client.__dict__.update(kwargs)

        return client

    model_construct = construct
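construct takes plain kwargs and returns a param object with NotGiven values scrubbed out of nested dicts; the header names below are illustrative:

param = ClientRequestParam.construct(
    method="get",
    url="files",
    headers={"X-Trace": "abc", "X-Skip": NotGiven()},
)
# remove_notgiven_indict dropped the NotGiven entry
assert param.headers == {"X-Trace": "abc"}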
@@ -0,0 +1,121 @@
from __future__ import annotations

import datetime
from typing import TypeVar, Generic, cast, Any, TYPE_CHECKING

import httpx
import pydantic
from typing_extensions import ParamSpec, get_origin, get_args

from ._base_type import NoneType
from ._sse_client import StreamResponse

if TYPE_CHECKING:
    from ._http_client import HttpClient

P = ParamSpec("P")
R = TypeVar("R")


class HttpResponse(Generic[R]):
    _cast_type: type[R]
    _client: "HttpClient"
    _parsed: R | None
    _enable_stream: bool
    _stream_cls: type[StreamResponse[Any]]
    http_response: httpx.Response

    def __init__(
        self,
        *,
        raw_response: httpx.Response,
        cast_type: type[R],
        client: "HttpClient",
        enable_stream: bool = False,
        stream_cls: type[StreamResponse[Any]] | None = None,
    ) -> None:
        self._cast_type = cast_type
        self._client = client
        self._parsed = None
        self._stream_cls = stream_cls
        self._enable_stream = enable_stream
        self.http_response = raw_response

    def parse(self) -> R:
        self._parsed = self._parse()
        return self._parsed

    def _parse(self) -> R:
        if self._enable_stream:
            self._parsed = cast(
                R,
                self._stream_cls(
                    cast_type=cast(type, get_args(self._stream_cls)[0]),
                    response=self.http_response,
                    client=self._client
                )
            )
            return self._parsed
        cast_type = self._cast_type
        if cast_type is NoneType:
            return cast(R, None)
        http_response = self.http_response
        if cast_type == str:
            return cast(R, http_response.text)

        content_type, *_ = http_response.headers.get("content-type", "application/json").split(";")
        origin = get_origin(cast_type) or cast_type
        if content_type != "application/json":
            if issubclass(origin, pydantic.BaseModel):
                data = http_response.json()
                return self._client._process_response_data(
                    data=data,
                    cast_type=cast_type,  # type: ignore
                    response=http_response,
                )

            return http_response.text

        data = http_response.json()

        return self._client._process_response_data(
            data=data,
            cast_type=cast_type,  # type: ignore
            response=http_response,
        )

    @property
    def headers(self) -> httpx.Headers:
        return self.http_response.headers

    @property
    def http_request(self) -> httpx.Request:
        return self.http_response.request

    @property
    def status_code(self) -> int:
        return self.http_response.status_code

    @property
    def url(self) -> httpx.URL:
        return self.http_response.url

    @property
    def method(self) -> str:
        return self.http_request.method

    @property
    def content(self) -> bytes:
        return self.http_response.content

    @property
    def text(self) -> str:
        return self.http_response.text

    @property
    def http_version(self) -> str:
        return self.http_response.http_version

    @property
    def elapsed(self) -> datetime.timedelta:
        return self.http_response.elapsed
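cast_type picks the parse path: NoneType short-circuits, str returns raw text, and anything else goes through pydantic validation on the client. A sketch, assuming client is an HttpClient instance and Completion is the response model added later in this commit:

raw = httpx.Response(
    200,
    json={"choices": [], "usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}},
    request=httpx.Request("POST", "https://example.invalid/chat/completions"),
)
parsed = HttpResponse(raw_response=raw, cast_type=Completion, client=client).parse()
assert parsed.usage.total_tokens == 2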
@@ -0,0 +1,149 @@
# -*- coding:utf-8 -*-
from __future__ import annotations

import json
from typing import Generic, Iterator, TYPE_CHECKING, Mapping

import httpx

from ._base_type import ResponseT
from ._errors import APIResponseError

_FIELD_SEPARATOR = ":"

if TYPE_CHECKING:
    from ._http_client import HttpClient


class StreamResponse(Generic[ResponseT]):

    response: httpx.Response
    _cast_type: type[ResponseT]

    def __init__(
        self,
        *,
        cast_type: type[ResponseT],
        response: httpx.Response,
        client: HttpClient,
    ) -> None:
        self.response = response
        self._cast_type = cast_type
        self._data_process_func = client._process_response_data
        self._stream_chunks = self.__stream__()

    def __next__(self) -> ResponseT:
        return self._stream_chunks.__next__()

    def __iter__(self) -> Iterator[ResponseT]:
        for item in self._stream_chunks:
            yield item

    def __stream__(self) -> Iterator[ResponseT]:
        sse_line_parser = SSELineParser()
        iterator = sse_line_parser.iter_lines(self.response.iter_lines())

        for sse in iterator:
            if sse.data.startswith("[DONE]"):
                break

            if sse.event is None:
                data = sse.json_data()
                if isinstance(data, Mapping) and data.get("error"):
                    raise APIResponseError(
                        message="An error occurred during streaming",
                        request=self.response.request,
                        json_data=data["error"],
                    )

                yield self._data_process_func(data=data, cast_type=self._cast_type, response=self.response)
        # drain any remaining events so the HTTP response is fully consumed
        for sse in iterator:
            pass


class Event:
    def __init__(
        self,
        event: str | None = None,
        data: str | None = None,
        id: str | None = None,
        retry: int | None = None
    ):
        self._event = event
        self._data = data
        self._id = id
        self._retry = retry

    def __repr__(self):
        data_len = len(self._data) if self._data else 0
        return f"Event(event={self._event}, data={self._data}, data_length={data_len}, id={self._id}, retry={self._retry})"

    @property
    def event(self): return self._event

    @property
    def data(self): return self._data

    def json_data(self): return json.loads(self._data)

    @property
    def id(self): return self._id

    @property
    def retry(self): return self._retry


class SSELineParser:
    _data: list[str]
    _event: str | None
    _retry: int | None
    _id: str | None

    def __init__(self):
        self._event = None
        self._data = []
        self._id = None
        self._retry = None

    def iter_lines(self, lines: Iterator[str]) -> Iterator[Event]:
        for line in lines:
            line = line.rstrip('\n')
            if not line:
                # a blank line terminates the current event; skip if nothing accumulated
                if (self._event is None
                        and not self._data
                        and self._id is None
                        and self._retry is None):
                    continue
                sse_event = Event(
                    event=self._event,
                    data='\n'.join(self._data),
                    id=self._id,
                    retry=self._retry
                )
                self._event = None
                self._data = []
                self._id = None
                self._retry = None

                yield sse_event
            self.decode_line(line)

    def decode_line(self, line: str):
        # lines starting with ":" are SSE comments
        if line.startswith(":") or not line:
            return

        field, _p, value = line.partition(_FIELD_SEPARATOR)

        if value.startswith(' '):
            value = value[1:]
        if field == "data":
            self._data.append(value)
        elif field == "event":
            self._event = value
        elif field == "retry":
            try:
                self._retry = int(value)
            except (TypeError, ValueError):
                pass
        return
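The parser is line-oriented SSE: field lines accumulate until a blank line flushes an Event. A self-contained check of that behavior:

lines = iter([
    'event: message',
    'data: {"hello": 1}',
    '',  # the blank line terminates the event
])
events = list(SSELineParser().iter_lines(lines))
assert events[0].event == "message"
assert events[0].json_data() == {"hello": 1}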
@@ -0,0 +1,18 @@
from __future__ import annotations

from typing import Mapping, Iterable, TypeVar

from ._base_type import NotGiven


def remove_notgiven_indict(obj):
    if obj is None or (not isinstance(obj, Mapping)):
        return obj
    return {key: value for key, value in obj.items() if not isinstance(value, NotGiven)}


_T = TypeVar("_T")


def flatten(t: Iterable[Iterable[_T]]) -> list[_T]:
    return [item for sublist in t for item in sublist]
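Both helpers are small enough to verify inline:

assert flatten([[1, 2], [3]]) == [1, 2, 3]
assert remove_notgiven_indict({"a": 1, "b": NotGiven()}) == {"a": 1}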
@@ -0,0 +1,23 @@
from typing import List, Optional

from pydantic import BaseModel

from .chat_completion import CompletionChoice, CompletionUsage

__all__ = ["AsyncTaskStatus", "AsyncCompletion"]


class AsyncTaskStatus(BaseModel):
    id: Optional[str] = None
    request_id: Optional[str] = None
    model: Optional[str] = None
    task_status: Optional[str] = None


class AsyncCompletion(BaseModel):
    id: Optional[str] = None
    request_id: Optional[str] = None
    model: Optional[str] = None
    task_status: str
    choices: List[CompletionChoice]
    usage: CompletionUsage
@@ -0,0 +1,45 @@
from typing import List, Optional

from pydantic import BaseModel

__all__ = ["Completion", "CompletionUsage"]


class Function(BaseModel):
    arguments: str
    name: str


class CompletionMessageToolCall(BaseModel):
    id: str
    function: Function
    type: str


class CompletionMessage(BaseModel):
    content: Optional[str] = None
    role: str
    tool_calls: Optional[List[CompletionMessageToolCall]] = None


class CompletionUsage(BaseModel):
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


class CompletionChoice(BaseModel):
    index: int
    finish_reason: str
    message: CompletionMessage


class Completion(BaseModel):
    model: Optional[str] = None
    created: Optional[int] = None
    choices: List[CompletionChoice]
    request_id: Optional[str] = None
    id: Optional[str] = None
    usage: CompletionUsage
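These models mirror the v4 chat payload, so a function-calling response validates straight from a dict; the payload below is illustrative, not captured from the API:

completion = Completion.model_validate({
    "model": "glm-4",
    "choices": [{
        "index": 0,
        "finish_reason": "tool_calls",
        "message": {
            "role": "assistant",
            "tool_calls": [{
                "id": "call_1",
                "type": "function",
                "function": {"name": "get_current_weather",
                             "arguments": "{\"location\": \"Beijing\"}"},
            }],
        },
    }],
    "usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
})
assert completion.choices[0].message.tool_calls[0].function.name == "get_current_weather"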
@@ -0,0 +1,55 @@
from typing import List, Optional

from pydantic import BaseModel

__all__ = [
    "ChatCompletionChunk",
    "Choice",
    "ChoiceDelta",
    "ChoiceDeltaFunctionCall",
    "ChoiceDeltaToolCall",
    "ChoiceDeltaToolCallFunction",
]


class ChoiceDeltaFunctionCall(BaseModel):
    arguments: Optional[str] = None
    name: Optional[str] = None


class ChoiceDeltaToolCallFunction(BaseModel):
    arguments: Optional[str] = None
    name: Optional[str] = None


class ChoiceDeltaToolCall(BaseModel):
    index: int
    id: Optional[str] = None
    function: Optional[ChoiceDeltaToolCallFunction] = None
    type: Optional[str] = None


class ChoiceDelta(BaseModel):
    content: Optional[str] = None
    role: Optional[str] = None
    tool_calls: Optional[List[ChoiceDeltaToolCall]] = None


class Choice(BaseModel):
    delta: ChoiceDelta
    finish_reason: Optional[str] = None
    index: int


class CompletionUsage(BaseModel):
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


class ChatCompletionChunk(BaseModel):
    id: Optional[str] = None
    choices: List[Choice]
    created: Optional[int] = None
    model: Optional[str] = None
    usage: Optional[CompletionUsage] = None
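When a function call is streamed, the arguments string arrives in fragments across chunks, so a consumer has to concatenate deltas per tool-call index. A sketch of that accumulation (the helper name is ours, not part of the SDK):

def accumulate_tool_args(chunks: list[ChatCompletionChunk]) -> dict[int, str]:
    # concatenate streamed argument fragments, keyed by tool-call index
    args: dict[int, str] = {}
    for chunk in chunks:
        for choice in chunk.choices:
            for tool_call in choice.delta.tool_calls or []:
                if tool_call.function and tool_call.function.arguments:
                    args[tool_call.index] = args.get(tool_call.index, "") + tool_call.function.arguments
    return args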
@@ -0,0 +1,8 @@
from typing import Optional

from typing_extensions import TypedDict


class Reference(TypedDict, total=False):
    enable: Optional[bool]
    search_query: Optional[str]
@@ -0,0 +1,20 @@
from __future__ import annotations

from typing import Optional, List

from pydantic import BaseModel
from .chat.chat_completion import CompletionUsage

__all__ = ["Embedding", "EmbeddingsResponded"]


class Embedding(BaseModel):
    object: str
    index: Optional[int] = None
    embedding: List[float]


class EmbeddingsResponded(BaseModel):
    object: str
    data: List[Embedding]
    model: str
    usage: CompletionUsage
@@ -0,0 +1,24 @@
from typing import Optional, List

from pydantic import BaseModel

__all__ = ["FileObject"]


class FileObject(BaseModel):
    id: Optional[str] = None
    bytes: Optional[int] = None
    created_at: Optional[int] = None
    filename: Optional[str] = None
    object: Optional[str] = None
    purpose: Optional[str] = None
    status: Optional[str] = None
    status_details: Optional[str] = None


class ListOfFileObject(BaseModel):
    object: Optional[str] = None
    data: List[FileObject]
    has_more: Optional[bool] = None
@@ -0,0 +1,5 @@
from __future__ import annotations

from .fine_tuning_job import FineTuningJob as FineTuningJob
from .fine_tuning_job import ListOfFineTuningJob as ListOfFineTuningJob
from .fine_tuning_job_event import FineTuningJobEvent as FineTuningJobEvent
@@ -0,0 +1,52 @@
from typing import List, Union, Optional

from pydantic import BaseModel

__all__ = ["FineTuningJob", "Error", "Hyperparameters", "ListOfFineTuningJob"]


class Error(BaseModel):
    code: str
    message: str
    param: Optional[str] = None


class Hyperparameters(BaseModel):
    n_epochs: Union[str, int, None] = None


class FineTuningJob(BaseModel):
    id: Optional[str] = None
    request_id: Optional[str] = None
    created_at: Optional[int] = None
    error: Optional[Error] = None
    fine_tuned_model: Optional[str] = None
    finished_at: Optional[int] = None
    hyperparameters: Optional[Hyperparameters] = None
    model: Optional[str] = None
    object: Optional[str] = None
    result_files: List[str]
    status: str
    trained_tokens: Optional[int] = None
    training_file: str
    validation_file: Optional[str] = None


class ListOfFineTuningJob(BaseModel):
    object: Optional[str] = None
    data: List[FineTuningJob]
    has_more: Optional[bool] = None
@@ -0,0 +1,36 @@
from typing import List, Union, Optional

from pydantic import BaseModel

__all__ = ["FineTuningJobEvent", "Metric", "JobEvent"]


class Metric(BaseModel):
    epoch: Optional[Union[str, int, float]] = None
    current_steps: Optional[int] = None
    total_steps: Optional[int] = None
    elapsed_time: Optional[str] = None
    remaining_time: Optional[str] = None
    trained_tokens: Optional[int] = None
    loss: Optional[Union[str, int, float]] = None
    eval_loss: Optional[Union[str, int, float]] = None
    acc: Optional[Union[str, int, float]] = None
    eval_acc: Optional[Union[str, int, float]] = None
    learning_rate: Optional[Union[str, int, float]] = None


class JobEvent(BaseModel):
    object: Optional[str] = None
    id: Optional[str] = None
    type: Optional[str] = None
    created_at: Optional[int] = None
    level: Optional[str] = None
    message: Optional[str] = None
    data: Optional[Metric] = None


class FineTuningJobEvent(BaseModel):
    object: Optional[str] = None
    data: List[JobEvent]
    has_more: Optional[bool] = None
@@ -0,0 +1,15 @@
from __future__ import annotations

from typing import Union

from typing_extensions import Literal, TypedDict

__all__ = ["Hyperparameters"]


class Hyperparameters(TypedDict, total=False):
    batch_size: Union[Literal["auto"], int]

    learning_rate_multiplier: Union[Literal["auto"], float]

    n_epochs: Union[Literal["auto"], int]
@@ -0,0 +1,18 @@
from __future__ import annotations

from typing import Optional, List

from pydantic import BaseModel

__all__ = ["GeneratedImage", "ImagesResponded"]


class GeneratedImage(BaseModel):
    b64_json: Optional[str] = None
    url: Optional[str] = None
    revised_prompt: Optional[str] = None


class ImagesResponded(BaseModel):
    created: int
    data: List[GeneratedImage]
@@ -3,7 +3,8 @@ from typing import Generator

 import pytest
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
-from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage
+from core.model_runtime.entities.message_entities import (AssistantPromptMessage, SystemPromptMessage,
+                                                          UserPromptMessage, PromptMessageTool)
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.zhipuai.llm.llm import ZhipuAILargeLanguageModel
@@ -102,3 +103,48 @@ def test_get_num_tokens():
     )

     assert num_tokens == 14
+
+
+def test_get_tools_num_tokens():
+    model = ZhipuAILargeLanguageModel()
+
+    num_tokens = model.get_num_tokens(
+        model='tools',
+        credentials={
+            'api_key': os.environ.get('ZHIPUAI_API_KEY')
+        },
+        tools=[
+            PromptMessageTool(
+                name='get_current_weather',
+                description='Get the current weather in a given location',
+                parameters={
+                    "type": "object",
+                    "properties": {
+                        "location": {
+                            "type": "string",
+                            "description": "The city and state e.g. San Francisco, CA"
+                        },
+                        "unit": {
+                            "type": "string",
+                            "enum": ["c", "f"]
+                        }
+                    },
+                    "required": ["location"]
+                }
+            )
+        ],
+        prompt_messages=[
+            SystemPromptMessage(
+                content='You are a helpful AI assistant.',
+            ),
+            UserPromptMessage(
+                content='Hello World!'
+            )
+        ]
+    )
+
+    assert num_tokens == 108
@@ -42,7 +42,7 @@ def test_invoke_model():
     assert isinstance(result, TextEmbeddingResult)
     assert len(result.embeddings) == 2
-    assert result.usage.total_tokens == 2
+    assert result.usage.total_tokens > 0


 def test_get_num_tokens():
@@ -229,7 +229,7 @@ const Answer: FC<IAnswerProps> = ({
             <Thought
               thought={item}
               allToolIcons={allToolIcons || {}}
-              isFinished={!!item.observation}
+              isFinished={!!item.observation || !isResponsing}
             />
           )}
@@ -43,7 +43,7 @@ import { fetchDatasets } from '@/service/datasets'
 import { useProviderContext } from '@/context/provider-context'
 import { AgentStrategy, AppType, ModelModeType, RETRIEVE_TYPE, Resolution, TransferMethod } from '@/types/app'
 import { PromptMode } from '@/models/debug'
-import { ANNOTATION_DEFAULT, DEFAULT_AGENT_SETTING, DEFAULT_CHAT_PROMPT_CONFIG, DEFAULT_COMPLETION_PROMPT_CONFIG } from '@/config'
+import { ANNOTATION_DEFAULT, DEFAULT_AGENT_SETTING, DEFAULT_CHAT_PROMPT_CONFIG, DEFAULT_COMPLETION_PROMPT_CONFIG, supportFunctionCallModels } from '@/config'
 import SelectDataSet from '@/app/components/app/configuration/dataset-config/select-dataset'
 import I18n from '@/context/i18n'
 import { useModalContext } from '@/context/modal-context'
@@ -163,8 +163,7 @@ const Configuration: FC = () => {
     doSetModelConfig(newModelConfig)
   }
   const isOpenAI = modelConfig.provider === 'openai'
-  const isFunctionCall = isOpenAI && modelConfig.mode === ModelModeType.chat
+  const isFunctionCall = (isOpenAI && modelConfig.mode === ModelModeType.chat) || supportFunctionCallModels.includes(modelConfig.model_id)

   const [collectionList, setCollectionList] = useState<Collection[]>([])
   useEffect(() => {
@@ -160,6 +160,8 @@ export const DEFAULT_AGENT_SETTING = {
   tools: [],
 }

+export const supportFunctionCallModels = ['glm-3-turbo', 'glm-4']
+
 export const DEFAULT_AGENT_PROMPT = {
   chat: `Respond to the human as helpfully and accurately as possible.