mirror of https://gitee.com/dify_ai/dify.git — synced 2024-12-02 11:18:19 +08:00

Feat/zhipuai function calling (#2199)
Co-authored-by: Joel <iamjoel007@gmail.com>
parent bdc5e9ceb0 · commit b921c55677
api/core/model_runtime/model_providers/zhipuai/_client.py (deleted)
@@ -1,61 +0,0 @@
-"""Wrapper around ZhipuAI APIs."""
-from __future__ import annotations
-
-import logging
-import posixpath
-
-from pydantic import BaseModel, Extra
-from zhipuai.model_api.api import InvokeType
-from zhipuai.utils import jwt_token
-from zhipuai.utils.http_client import post, stream
-from zhipuai.utils.sse_client import SSEClient
-
-logger = logging.getLogger(__name__)
-
-
-class ZhipuModelAPI(BaseModel):
-    base_url: str = "https://open.bigmodel.cn/api/paas/v3/model-api"
-    api_key: str
-    api_timeout_seconds = 60
-
-    class Config:
-        """Configuration for this pydantic object."""
-
-        extra = Extra.forbid
-
-    def invoke(self, **kwargs):
-        url = self._build_api_url(kwargs, InvokeType.SYNC)
-        response = post(url, self._generate_token(), kwargs, self.api_timeout_seconds)
-        if not response['success']:
-            raise ValueError(
-                f"Error Code: {response['code']}, Message: {response['msg']} "
-            )
-        return response
-
-    def sse_invoke(self, **kwargs):
-        url = self._build_api_url(kwargs, InvokeType.SSE)
-        data = stream(url, self._generate_token(), kwargs, self.api_timeout_seconds)
-        return SSEClient(data)
-
-    def _build_api_url(self, kwargs, *path):
-        if kwargs:
-            if "model" not in kwargs:
-                raise Exception("model param missed")
-            model = kwargs.pop("model")
-        else:
-            model = "-"
-
-        return posixpath.join(self.base_url, model, *path)
-
-    def _generate_token(self):
-        if not self.api_key:
-            raise Exception(
-                "api_key not provided, you could provide it."
-            )
-
-        try:
-            return jwt_token.generate_token(self.api_key)
-        except Exception:
-            raise ValueError(
-                f"Your api_key is invalid, please check it."
-            )
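The file above is the legacy v3 model-api wrapper that this commit deletes: it signed a JWT per request, joined base_url/model/invoke-type with posixpath, and returned raw dicts (`invoke`) or an `SSEClient` (`sse_invoke`). For contrast with the v4 SDK vendored in below, here is a sketch of the calling convention that disappears here (hypothetical values; not runnable after this commit):

    client = ZhipuModelAPI(api_key="your-id.your-secret")
    response = client.invoke(model="chatglm_turbo",
                             prompt=[{"role": "user", "content": "Hello"}])
    # plain dict out: {'success': ..., 'code': ..., 'msg': ..., 'data': {...}}
    text = "".join(choice["content"] for choice in response["data"]["choices"])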
api/core/model_runtime/model_providers/zhipuai/llm/llm.py
@@ -3,13 +3,15 @@ from typing import Any, Dict, Generator, List, Optional, Union
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageRole,
-                                                          PromptMessageTool, SystemPromptMessage, UserPromptMessage,
+                                                          PromptMessageTool, SystemPromptMessage, UserPromptMessage, ToolPromptMessage,
                                                           TextPromptMessageContent, ImagePromptMessageContent, PromptMessageContentType)
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.utils import helper
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-from core.model_runtime.model_providers.zhipuai._client import ZhipuModelAPI
 from core.model_runtime.model_providers.zhipuai._common import _CommonZhipuaiAI
 
+from core.model_runtime.model_providers.zhipuai.zhipuai_sdk._client import ZhipuAI
+from core.model_runtime.model_providers.zhipuai.zhipuai_sdk.types.chat.chat_completion_chunk import ChatCompletionChunk
+from core.model_runtime.model_providers.zhipuai.zhipuai_sdk.types.chat.chat_completion import Completion
 
 class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
 
@@ -35,7 +37,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
         credentials_kwargs = self._to_credential_kwargs(credentials)
 
         # invoke model
-        return self._generate(model, credentials_kwargs, prompt_messages, model_parameters, stop, stream, user)
+        return self._generate(model, credentials_kwargs, prompt_messages, model_parameters, tools, stop, stream, user)
 
     def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                        tools: Optional[list[PromptMessageTool]] = None) -> int:
@@ -48,7 +50,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
         :param tools: tools for tool calling
         :return:
         """
-        prompt = self._convert_messages_to_prompt(prompt_messages)
+        prompt = self._convert_messages_to_prompt(prompt_messages, tools)
 
         return self._get_num_tokens_by_gpt2(prompt)
 
@@ -72,6 +74,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
                 model_parameters={
                     "temperature": 0.5,
                 },
+                tools=[],
                 stream=False
             )
         except Exception as ex:
@@ -79,6 +82,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
 
     def _generate(self, model: str, credentials_kwargs: dict,
                   prompt_messages: list[PromptMessage], model_parameters: dict,
+                  tools: Optional[list[PromptMessageTool]] = None,
                   stop: Optional[List[str]] = None, stream: bool = True,
                   user: Optional[str] = None) -> Union[LLMResult, Generator]:
         """
@@ -97,7 +101,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
         if stop:
             extra_model_kwargs['stop_sequences'] = stop
 
-        client = ZhipuModelAPI(
+        client = ZhipuAI(
            api_key=credentials_kwargs['api_key']
        )
 
@@ -128,11 +132,17 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
                     # not support image message
                     continue
 
-            if new_prompt_messages and new_prompt_messages[-1].role == PromptMessageRole.USER:
+            if new_prompt_messages and new_prompt_messages[-1].role == PromptMessageRole.USER and \
+                    copy_prompt_message.role == PromptMessageRole.USER:
                 new_prompt_messages[-1].content += "\n\n" + copy_prompt_message.content
             else:
                 if copy_prompt_message.role == PromptMessageRole.USER:
                     new_prompt_messages.append(copy_prompt_message)
+                elif copy_prompt_message.role == PromptMessageRole.TOOL:
+                    new_prompt_messages.append(copy_prompt_message)
                 elif copy_prompt_message.role == PromptMessageRole.SYSTEM:
                     new_prompt_message = SystemPromptMessage(content=copy_prompt_message.content)
                     new_prompt_messages.append(new_prompt_message)
                 else:
                     new_prompt_message = UserPromptMessage(content=copy_prompt_message.content)
                     new_prompt_messages.append(new_prompt_message)
@@ -145,7 +155,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
         if model == 'glm-4v':
             params = {
                 'model': model,
-                'prompt': [{
+                'messages': [{
                     'role': prompt_message.role.value,
                     'content':
                         [
@@ -171,23 +181,63 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
         else:
             params = {
                 'model': model,
-                'prompt': [{
-                    'role': prompt_message.role.value,
-                    'content': prompt_message.content,
-                } for prompt_message in new_prompt_messages],
+                'messages': [],
                 **model_parameters
             }
+            # glm model
+            if not model.startswith('chatglm'):
+                for prompt_message in new_prompt_messages:
+                    if prompt_message.role == PromptMessageRole.TOOL:
+                        params['messages'].append({
+                            'role': 'tool',
+                            'content': prompt_message.content,
+                            'tool_call_id': prompt_message.tool_call_id
+                        })
+                    else:
+                        params['messages'].append({
+                            'role': prompt_message.role.value,
+                            'content': prompt_message.content
+                        })
+            else:
+                # chatglm model
+                for prompt_message in new_prompt_messages:
+                    # merge system message to user message
+                    if prompt_message.role == PromptMessageRole.SYSTEM or \
+                            prompt_message.role == PromptMessageRole.TOOL or \
+                            prompt_message.role == PromptMessageRole.USER:
+                        if len(params['messages']) > 0 and params['messages'][-1]['role'] == 'user':
+                            params['messages'][-1]['content'] += "\n\n" + prompt_message.content
+                        else:
+                            params['messages'].append({
+                                'role': 'user',
+                                'content': prompt_message.content
+                            })
+                    else:
+                        params['messages'].append({
+                            'role': prompt_message.role.value,
+                            'content': prompt_message.content
+                        })
 
+        if tools and len(tools) > 0:
+            params['tools'] = [
+                {
+                    'type': 'function',
+                    'function': helper.dump_model(tool)
+                } for tool in tools
+            ]
 
         if stream:
-            response = client.sse_invoke(incremental=True, **params).events()
-            return self._handle_generate_stream_response(model, credentials_kwargs, response, prompt_messages)
+            response = client.chat.completions.create(stream=stream, **params)
+            return self._handle_generate_stream_response(model, credentials_kwargs, tools, response, prompt_messages)
 
-        response = client.invoke(**params)
-        return self._handle_generate_response(model, credentials_kwargs, response, prompt_messages)
+        response = client.chat.completions.create(**params)
+        return self._handle_generate_response(model, credentials_kwargs, tools, response, prompt_messages)
 
     def _handle_generate_response(self, model: str,
                                   credentials: dict,
-                                  response: Dict[str, Any],
+                                  tools: Optional[list[PromptMessageTool]],
+                                  response: Completion,
                                   prompt_messages: list[PromptMessage]) -> LLMResult:
         """
         Handle llm response
@@ -197,26 +247,39 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
         :param prompt_messages: prompt messages
         :return: llm response
         """
-        data = response["data"]
         text = ''
-        for res in data["choices"]:
-            text += res['content']
+        assistant_tool_calls: List[AssistantPromptMessage.ToolCall] = []
+        for choice in response.choices:
+            if choice.message.tool_calls:
+                for tool_call in choice.message.tool_calls:
+                    if tool_call.type == 'function':
+                        assistant_tool_calls.append(
+                            AssistantPromptMessage.ToolCall(
+                                id=tool_call.id,
+                                type=tool_call.type,
+                                function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                                    name=tool_call.function.name,
+                                    arguments=tool_call.function.arguments,
+                                )
+                            )
+                        )
 
-        token_usage = data.get("usage")
-        if token_usage is not None:
-            if 'prompt_tokens' not in token_usage:
-                token_usage['prompt_tokens'] = 0
-            if 'completion_tokens' not in token_usage:
-                token_usage['completion_tokens'] = token_usage['total_tokens']
+            text += choice.message.content or ''
+
+        prompt_usage = response.usage.prompt_tokens
+        completion_usage = response.usage.completion_tokens
 
         # transform usage
-        usage = self._calc_response_usage(model, credentials, token_usage['prompt_tokens'], token_usage['completion_tokens'])
+        usage = self._calc_response_usage(model, credentials, prompt_usage, completion_usage)
 
         # transform response
         result = LLMResult(
             model=model,
             prompt_messages=prompt_messages,
-            message=AssistantPromptMessage(content=text),
+            message=AssistantPromptMessage(
+                content=text,
+                tool_calls=assistant_tool_calls
+            ),
             usage=usage,
         )
 
@@ -224,7 +287,8 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
 
     def _handle_generate_stream_response(self, model: str,
                                          credentials: dict,
-                                         responses: list[Generator],
+                                         tools: Optional[list[PromptMessageTool]],
+                                         responses: Generator[ChatCompletionChunk, None, None],
                                          prompt_messages: list[PromptMessage]) -> Generator:
         """
         Handle llm stream response
@@ -234,41 +298,66 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
         :param prompt_messages: prompt messages
         :return: llm response chunk generator result
        """
-        for index, event in enumerate(responses):
-            if event.event == "add":
-                yield LLMResultChunk(
-                    prompt_messages=prompt_messages,
-                    model=model,
-                    delta=LLMResultChunkDelta(
-                        index=index,
-                        message=AssistantPromptMessage(content=event.data)
-                    )
-                )
-            elif event.event == "error" or event.event == "interrupted":
-                raise ValueError(
-                    f"{event.data}"
-                )
-            elif event.event == "finish":
-                meta = json.loads(event.meta)
-                token_usage = meta['usage']
-                if token_usage is not None:
-                    if 'prompt_tokens' not in token_usage:
-                        token_usage['prompt_tokens'] = 0
-                    if 'completion_tokens' not in token_usage:
-                        token_usage['completion_tokens'] = token_usage['total_tokens']
+        full_assistant_content = ''
+        for chunk in responses:
+            if len(chunk.choices) == 0:
+                continue
 
-                    usage = self._calc_response_usage(model, credentials, token_usage['prompt_tokens'], token_usage['completion_tokens'])
+            delta = chunk.choices[0]
 
+            if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''):
+                continue
+
+            assistant_tool_calls: List[AssistantPromptMessage.ToolCall] = []
+            for tool_call in delta.delta.tool_calls or []:
+                if tool_call.type == 'function':
+                    assistant_tool_calls.append(
+                        AssistantPromptMessage.ToolCall(
+                            id=tool_call.id,
+                            type=tool_call.type,
+                            function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                                name=tool_call.function.name,
+                                arguments=tool_call.function.arguments,
+                            )
+                        )
+                    )
+
+            # transform assistant message to prompt message
+            assistant_prompt_message = AssistantPromptMessage(
+                content=delta.delta.content if delta.delta.content else '',
+                tool_calls=assistant_tool_calls
+            )
+
+            full_assistant_content += delta.delta.content if delta.delta.content else ''
+
+            if delta.finish_reason is not None and chunk.usage is not None:
+                completion_tokens = chunk.usage.completion_tokens
+                prompt_tokens = chunk.usage.prompt_tokens
+
+                # transform usage
+                usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+
                 yield LLMResultChunk(
-                    model=model,
+                    model=chunk.model,
                     prompt_messages=prompt_messages,
                     system_fingerprint='',
                     delta=LLMResultChunkDelta(
-                        index=index,
-                        message=AssistantPromptMessage(content=event.data),
-                        finish_reason='finish',
+                        index=delta.index,
+                        message=assistant_prompt_message,
+                        finish_reason=delta.finish_reason,
                         usage=usage
                     )
                 )
+            else:
+                yield LLMResultChunk(
+                    model=chunk.model,
+                    prompt_messages=prompt_messages,
+                    system_fingerprint='',
+                    delta=LLMResultChunkDelta(
+                        index=delta.index,
+                        message=assistant_prompt_message,
+                    )
+                )
 
     def _convert_one_message_to_text(self, message: PromptMessage) -> str:
         """
@@ -292,10 +381,9 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
 
         return message_text
 
-    def _convert_messages_to_prompt(self, messages: List[PromptMessage]) -> str:
-        """
-        Format a list of messages into a full prompt for the Anthropic model
-
+    def _convert_messages_to_prompt(self, messages: List[PromptMessage], tools: Optional[list[PromptMessageTool]] = None) -> str:
+        """
         :param messages: List of PromptMessage to combine.
         :return: Combined string with necessary human_prompt and ai_prompt tags.
         """
@@ -306,5 +394,10 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
             for message in messages
         )
 
+        if tools and len(tools) > 0:
+            text += "\n\nTools:"
+            for tool in tools:
+                text += f"\n{tool.json()}"
+
         # trim off the trailing ' ' that might come from the "Assistant: "
         return text.rstrip()
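With the changes above, llm.py forwards dify `PromptMessageTool` definitions to the v4 endpoint as `{'type': 'function', 'function': helper.dump_model(tool)}` and maps returned `tool_calls` back onto `AssistantPromptMessage.ToolCall`. A hedged sketch of the call path, assuming a `ZhipuAILargeLanguageModel` instance named `model_instance` (the instance name and credential values are placeholders, not from this diff):

    from core.model_runtime.entities.message_entities import PromptMessageTool, UserPromptMessage

    weather_tool = PromptMessageTool(
        name="get_weather",
        description="Query current weather for a city",
        parameters={
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    )

    result = model_instance._invoke(
        model="glm-4",
        credentials={"api_key": "your-id.your-secret"},  # placeholder credentials
        prompt_messages=[UserPromptMessage(content="What's the weather in Beijing?")],
        model_parameters={"temperature": 0.5},
        tools=[weather_tool],  # serialized into params['tools'] above
        stream=False,
    )
    print(result.message.tool_calls)  # filled in by _handle_generate_response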
api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py
@@ -5,7 +5,7 @@ from core.model_runtime.entities.model_entities import PriceType
 from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
-from core.model_runtime.model_providers.zhipuai._client import ZhipuModelAPI
+from core.model_runtime.model_providers.zhipuai.zhipuai_sdk._client import ZhipuAI
 from core.model_runtime.model_providers.zhipuai._common import _CommonZhipuaiAI
 from langchain.schema.language_model import _get_token_ids_default_method
 
@@ -28,7 +28,7 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel):
         :return: embeddings result
         """
         credentials_kwargs = self._to_credential_kwargs(credentials)
-        client = ZhipuModelAPI(
+        client = ZhipuAI(
            api_key=credentials_kwargs['api_key']
        )
 
@@ -69,7 +69,7 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel):
         try:
             # transform credentials to kwargs for model instance
             credentials_kwargs = self._to_credential_kwargs(credentials)
-            client = ZhipuModelAPI(
+            client = ZhipuAI(
                api_key=credentials_kwargs['api_key']
            )
 
@@ -82,7 +82,7 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel):
         except Exception as ex:
             raise CredentialsValidateFailedError(str(ex))
 
-    def embed_documents(self, model: str, client: ZhipuModelAPI, texts: List[str]) -> Tuple[List[List[float]], int]:
+    def embed_documents(self, model: str, client: ZhipuAI, texts: List[str]) -> Tuple[List[List[float]], int]:
         """Call out to ZhipuAI's embedding endpoint.
 
         Args:
@@ -91,17 +91,16 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel):
         Returns:
             List of embeddings, one for each text.
         """
 
         embeddings = []
         embedding_used_tokens = 0
 
         for text in texts:
-            response = client.invoke(model=model, prompt=text)
-            data = response["data"]
-            embeddings.append(data.get('embedding'))
+            response = client.embeddings.create(model=model, input=text)
+            data = response.data[0]
+            embeddings.append(data.embedding)
+            embedding_used_tokens += response.usage.total_tokens
 
-        embedding_used_tokens = data.get('usage')
-
-        return [list(map(float, e)) for e in embeddings], embedding_used_tokens['total_tokens'] if embedding_used_tokens else 0
+        return [list(map(float, e)) for e in embeddings], embedding_used_tokens
 
     def embed_query(self, text: str) -> List[float]:
         """Call out to ZhipuAI's embedding endpoint.
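The embedding path now makes one `client.embeddings.create` call per input text and reads token usage from the typed response instead of a raw dict. A minimal sketch (the model name "embedding-2" is an assumption, not taken from this diff):

    from core.model_runtime.model_providers.zhipuai.zhipuai_sdk._client import ZhipuAI

    client = ZhipuAI(api_key="your-id.your-secret")
    resp = client.embeddings.create(model="embedding-2", input="hello world")
    vector = resp.data[0].embedding   # list[float], one per input text
    tokens = resp.usage.total_tokens  # summed into embedding_used_tokens above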
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__init__.py (new file)
@@ -0,0 +1,17 @@
+from ._client import ZhipuAI
+
+from .core._errors import (
+    ZhipuAIError,
+    APIStatusError,
+    APIRequestFailedError,
+    APIAuthenticationError,
+    APIReachLimitError,
+    APIInternalError,
+    APIServerFlowExceedError,
+    APIResponseError,
+    APIResponseValidationError,
+    APITimeoutError,
+)
+
+from .__version__ import __version__
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__version__.py (new file)
@@ -0,0 +1,2 @@
+__version__ = 'v2.0.1'
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/_client.py (new file)
@@ -0,0 +1,71 @@
+from __future__ import annotations
+
+from typing import Union, Mapping
+
+from typing_extensions import override
+
+from .core import _jwt_token
+from .core._errors import ZhipuAIError
+from .core._http_client import HttpClient, ZHIPUAI_DEFAULT_MAX_RETRIES
+from .core._base_type import NotGiven, NOT_GIVEN
+from . import api_resource
+import os
+import httpx
+from httpx import Timeout
+
+
+class ZhipuAI(HttpClient):
+    chat: api_resource.chat
+    api_key: str
+
+    def __init__(
+            self,
+            *,
+            api_key: str | None = None,
+            base_url: str | httpx.URL | None = None,
+            timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN,
+            max_retries: int = ZHIPUAI_DEFAULT_MAX_RETRIES,
+            http_client: httpx.Client | None = None,
+            custom_headers: Mapping[str, str] | None = None
+    ) -> None:
+        # if api_key is None:
+        #     api_key = os.environ.get("ZHIPUAI_API_KEY")
+        if api_key is None:
+            raise ZhipuAIError("No api_key provided; pass it as an argument or via an environment variable")
+        self.api_key = api_key
+
+        if base_url is None:
+            base_url = os.environ.get("ZHIPUAI_BASE_URL")
+        if base_url is None:
+            base_url = f"https://open.bigmodel.cn/api/paas/v4"
+        from .__version__ import __version__
+        super().__init__(
+            version=__version__,
+            base_url=base_url,
+            timeout=timeout,
+            custom_httpx_client=http_client,
+            custom_headers=custom_headers,
+        )
+        self.chat = api_resource.chat.Chat(self)
+        self.images = api_resource.images.Images(self)
+        self.embeddings = api_resource.embeddings.Embeddings(self)
+        self.files = api_resource.files.Files(self)
+        self.fine_tuning = api_resource.fine_tuning.FineTuning(self)
+
+    @property
+    @override
+    def _auth_headers(self) -> dict[str, str]:
+        api_key = self.api_key
+        return {"Authorization": f"{_jwt_token.generate_token(api_key)}"}
+
+    def __del__(self) -> None:
+        if (not hasattr(self, "_has_custom_http_client")
+                or not hasattr(self, "close")
+                or not hasattr(self, "_client")):
+            # if the '__init__' method raised an error, self would not have client attr
+            return
+
+        if self._has_custom_http_client:
+            return
+
+        self.close()
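Construction sketch for the vendored client. The "id.secret" key format is what core/_jwt_token.py splits on, and base_url falls back to ZHIPUAI_BASE_URL and then the v4 endpoint:

    from core.model_runtime.model_providers.zhipuai.zhipuai_sdk import ZhipuAI

    client = ZhipuAI(api_key="your-id.your-secret")
    # resource groups wired up in __init__:
    # client.chat, client.images, client.embeddings, client.files, client.fine_tuning
    client.close()  # __del__ also closes, unless a custom httpx.Client was supplied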
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/__init__.py (new file)
@@ -0,0 +1,5 @@
+from .chat import chat
+from .images import Images
+from .embeddings import Embeddings
+from .files import Files
+from .fine_tuning import fine_tuning
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/async_completions.py (new file)
@@ -0,0 +1,87 @@
+from __future__ import annotations
+
+from typing import Union, List, Optional, TYPE_CHECKING
+
+import httpx
+from typing_extensions import Literal
+
+from ...core._base_api import BaseAPI
+from ...core._base_type import NotGiven, NOT_GIVEN, Headers
+from ...core._http_client import make_user_request_input
+from ...types.chat.async_chat_completion import AsyncTaskStatus, AsyncCompletion
+
+if TYPE_CHECKING:
+    from ..._client import ZhipuAI
+
+
+class AsyncCompletions(BaseAPI):
+    def __init__(self, client: "ZhipuAI") -> None:
+        super().__init__(client)
+
+    def create(
+            self,
+            *,
+            model: str,
+            request_id: Optional[str] | NotGiven = NOT_GIVEN,
+            do_sample: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
+            temperature: Optional[float] | NotGiven = NOT_GIVEN,
+            top_p: Optional[float] | NotGiven = NOT_GIVEN,
+            max_tokens: int | NotGiven = NOT_GIVEN,
+            seed: int | NotGiven = NOT_GIVEN,
+            messages: Union[str, List[str], List[int], List[List[int]], None],
+            stop: Optional[Union[str, List[str], None]] | NotGiven = NOT_GIVEN,
+            sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN,
+            tools: Optional[object] | NotGiven = NOT_GIVEN,
+            tool_choice: str | NotGiven = NOT_GIVEN,
+            extra_headers: Headers | None = None,
+            disable_strict_validation: Optional[bool] | None = None,
+            timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+    ) -> AsyncTaskStatus:
+        _cast_type = AsyncTaskStatus
+
+        if disable_strict_validation:
+            _cast_type = object
+        return self._post(
+            "/async/chat/completions",
+            body={
+                "model": model,
+                "request_id": request_id,
+                "temperature": temperature,
+                "top_p": top_p,
+                "do_sample": do_sample,
+                "max_tokens": max_tokens,
+                "seed": seed,
+                "messages": messages,
+                "stop": stop,
+                "sensitive_word_check": sensitive_word_check,
+                "tools": tools,
+                "tool_choice": tool_choice,
+            },
+            options=make_user_request_input(
+                extra_headers=extra_headers, timeout=timeout
+            ),
+            cast_type=_cast_type,
+            enable_stream=False,
+        )
+
+    def retrieve_completion_result(
+            self,
+            id: str,
+            extra_headers: Headers | None = None,
+            disable_strict_validation: Optional[bool] | None = None,
+            timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+    ) -> Union[AsyncCompletion, AsyncTaskStatus]:
+        _cast_type = Union[AsyncCompletion, AsyncTaskStatus]
+        if disable_strict_validation:
+            _cast_type = object
+        return self._get(
+            path=f"/async-result/{id}",
+            cast_type=_cast_type,
+            options=make_user_request_input(
+                extra_headers=extra_headers,
+                timeout=timeout
+            )
+        )
|
||||
from typing import TYPE_CHECKING
|
||||
from .completions import Completions
|
||||
from .async_completions import AsyncCompletions
|
||||
from ...core._base_api import BaseAPI
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..._client import ZhipuAI
|
||||
|
||||
|
||||
class Chat(BaseAPI):
|
||||
completions: Completions
|
||||
|
||||
def __init__(self, client: "ZhipuAI") -> None:
|
||||
super().__init__(client)
|
||||
self.completions = Completions(client)
|
||||
self.asyncCompletions = AsyncCompletions(client)
|
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/completions.py (new file)
@@ -0,0 +1,71 @@
+from __future__ import annotations
+
+from typing import Union, List, Optional, TYPE_CHECKING
+
+import httpx
+from typing_extensions import Literal
+
+from ...core._base_api import BaseAPI
+from ...core._base_type import NotGiven, NOT_GIVEN, Headers
+from ...core._http_client import make_user_request_input
+from ...core._sse_client import StreamResponse
+from ...types.chat.chat_completion import Completion
+from ...types.chat.chat_completion_chunk import ChatCompletionChunk
+
+if TYPE_CHECKING:
+    from ..._client import ZhipuAI
+
+
+class Completions(BaseAPI):
+    def __init__(self, client: "ZhipuAI") -> None:
+        super().__init__(client)
+
+    def create(
+            self,
+            *,
+            model: str,
+            request_id: Optional[str] | NotGiven = NOT_GIVEN,
+            do_sample: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
+            stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
+            temperature: Optional[float] | NotGiven = NOT_GIVEN,
+            top_p: Optional[float] | NotGiven = NOT_GIVEN,
+            max_tokens: int | NotGiven = NOT_GIVEN,
+            seed: int | NotGiven = NOT_GIVEN,
+            messages: Union[str, List[str], List[int], object, None],
+            stop: Optional[Union[str, List[str], None]] | NotGiven = NOT_GIVEN,
+            sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN,
+            tools: Optional[object] | NotGiven = NOT_GIVEN,
+            tool_choice: str | NotGiven = NOT_GIVEN,
+            extra_headers: Headers | None = None,
+            disable_strict_validation: Optional[bool] | None = None,
+            timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+    ) -> Completion | StreamResponse[ChatCompletionChunk]:
+        _cast_type = Completion
+        _stream_cls = StreamResponse[ChatCompletionChunk]
+        if disable_strict_validation:
+            _cast_type = object
+            _stream_cls = StreamResponse[object]
+        return self._post(
+            "/chat/completions",
+            body={
+                "model": model,
+                "request_id": request_id,
+                "temperature": temperature,
+                "top_p": top_p,
+                "do_sample": do_sample,
+                "max_tokens": max_tokens,
+                "seed": seed,
+                "messages": messages,
+                "stop": stop,
+                "sensitive_word_check": sensitive_word_check,
+                "stream": stream,
+                "tools": tools,
+                "tool_choice": tool_choice,
+            },
+            options=make_user_request_input(
+                extra_headers=extra_headers,
+            ),
+            cast_type=_cast_type,
+            enable_stream=stream or False,
+            stream_cls=_stream_cls,
+        )
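Completions.create is the synchronous entry point the new llm.py calls: tools and tool_choice ride along in the JSON body, and enable_stream=stream or False switches the return type between Completion and StreamResponse[ChatCompletionChunk]. An illustrative request, assuming a constructed `client` as above:

    response = client.chat.completions.create(
        model="glm-4",
        messages=[{"role": "user", "content": "What's the weather in Beijing?"}],
        tools=[{
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Query current weather for a city",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            },
        }],
        tool_choice="auto",
    )
    print(response.choices[0].message.tool_calls)

    # streaming variant yields ChatCompletionChunk objects
    for chunk in client.chat.completions.create(
            model="glm-4", stream=True,
            messages=[{"role": "user", "content": "Hi"}]):
        print(chunk.choices[0].delta.content or "", end="")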
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/embeddings.py (new file)
@@ -0,0 +1,49 @@
+from __future__ import annotations
+
+from typing import Union, List, Optional, TYPE_CHECKING
+
+import httpx
+
+from ..core._base_api import BaseAPI
+from ..core._base_type import NotGiven, NOT_GIVEN, Headers
+from ..core._http_client import make_user_request_input
+from ..types.embeddings import EmbeddingsResponded
+
+if TYPE_CHECKING:
+    from .._client import ZhipuAI
+
+
+class Embeddings(BaseAPI):
+    def __init__(self, client: "ZhipuAI") -> None:
+        super().__init__(client)
+
+    def create(
+            self,
+            *,
+            input: Union[str, List[str], List[int], List[List[int]]],
+            model: Union[str],
+            encoding_format: str | NotGiven = NOT_GIVEN,
+            user: str | NotGiven = NOT_GIVEN,
+            sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN,
+            extra_headers: Headers | None = None,
+            disable_strict_validation: Optional[bool] | None = None,
+            timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+    ) -> EmbeddingsResponded:
+        _cast_type = EmbeddingsResponded
+        if disable_strict_validation:
+            _cast_type = object
+        return self._post(
+            "/embeddings",
+            body={
+                "input": input,
+                "model": model,
+                "encoding_format": encoding_format,
+                "user": user,
+                "sensitive_word_check": sensitive_word_check,
+            },
+            options=make_user_request_input(
+                extra_headers=extra_headers, timeout=timeout
+            ),
+            cast_type=_cast_type,
+            enable_stream=False,
+        )
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/files.py (new file)
@@ -0,0 +1,78 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import httpx
+
+from ..core._base_api import BaseAPI
+from ..core._base_type import NOT_GIVEN, Body, Query, Headers, NotGiven, FileTypes
+from ..core._files import is_file_content
+from ..core._http_client import (
+    make_user_request_input,
+)
+from ..types.file_object import FileObject, ListOfFileObject
+
+if TYPE_CHECKING:
+    from .._client import ZhipuAI
+
+__all__ = ["Files"]
+
+
+class Files(BaseAPI):
+
+    def __init__(self, client: "ZhipuAI") -> None:
+        super().__init__(client)
+
+    def create(
+            self,
+            *,
+            file: FileTypes,
+            purpose: str,
+            extra_headers: Headers | None = None,
+            timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+    ) -> FileObject:
+        if not is_file_content(file):
+            prefix = f"Expected file input `{file!r}`"
+            raise RuntimeError(
+                f"{prefix} to be bytes, an io.IOBase instance, PathLike or a tuple but received {type(file)} instead."
+            ) from None
+        files = [("file", file)]
+
+        extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})}
+
+        return self._post(
+            "/files",
+            body={
+                "purpose": purpose,
+            },
+            files=files,
+            options=make_user_request_input(
+                extra_headers=extra_headers, timeout=timeout
+            ),
+            cast_type=FileObject,
+        )
+
+    def list(
+            self,
+            *,
+            purpose: str | NotGiven = NOT_GIVEN,
+            limit: int | NotGiven = NOT_GIVEN,
+            after: str | NotGiven = NOT_GIVEN,
+            order: str | NotGiven = NOT_GIVEN,
+            extra_headers: Headers | None = None,
+            timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+    ) -> ListOfFileObject:
+        return self._get(
+            "/files",
+            cast_type=ListOfFileObject,
+            options=make_user_request_input(
+                extra_headers=extra_headers,
+                timeout=timeout,
+                query={
+                    "purpose": purpose,
+                    "limit": limit,
+                    "after": after,
+                    "order": order,
+                },
+            ),
+        )
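Files.create accepts anything is_file_content recognizes (bytes, an IO object, a PathLike, or a tuple) and forces a multipart/form-data content type that _prepare_request later strips and re-encodes. Upload sketch; the `purpose` value is an assumption modeled on the OpenAI-style API this mirrors:

    from pathlib import Path

    file_obj = client.files.create(
        file=Path("training_data.jsonl"),  # PathLike -> (name, bytes) via core/_files.py
        purpose="fine-tune",
    )
    listing = client.files.list(limit=10)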
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/fine_tuning.py (new file)
@@ -0,0 +1,15 @@
+from typing import TYPE_CHECKING
+from .jobs import Jobs
+from ...core._base_api import BaseAPI
+
+if TYPE_CHECKING:
+    from ..._client import ZhipuAI
+
+
+class FineTuning(BaseAPI):
+    jobs: Jobs
+
+    def __init__(self, client: "ZhipuAI") -> None:
+        super().__init__(client)
+        self.jobs = Jobs(client)
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/jobs.py (new file)
@@ -0,0 +1,115 @@
+from __future__ import annotations
+
+from typing import Optional, TYPE_CHECKING
+
+import httpx
+
+from ...core._base_api import BaseAPI
+from ...core._base_type import NOT_GIVEN, Headers, NotGiven
+from ...core._http_client import (
+    make_user_request_input,
+)
+from ...types.fine_tuning import (
+    FineTuningJob,
+    job_create_params,
+    ListOfFineTuningJob,
+    FineTuningJobEvent,
+)
+
+if TYPE_CHECKING:
+    from ..._client import ZhipuAI
+
+__all__ = ["Jobs"]
+
+
+class Jobs(BaseAPI):
+
+    def __init__(self, client: "ZhipuAI") -> None:
+        super().__init__(client)
+
+    def create(
+            self,
+            *,
+            model: str,
+            training_file: str,
+            hyperparameters: job_create_params.Hyperparameters | NotGiven = NOT_GIVEN,
+            suffix: Optional[str] | NotGiven = NOT_GIVEN,
+            request_id: Optional[str] | NotGiven = NOT_GIVEN,
+            validation_file: Optional[str] | NotGiven = NOT_GIVEN,
+            extra_headers: Headers | None = None,
+            timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+    ) -> FineTuningJob:
+        return self._post(
+            "/fine_tuning/jobs",
+            body={
+                "model": model,
+                "training_file": training_file,
+                "hyperparameters": hyperparameters,
+                "suffix": suffix,
+                "validation_file": validation_file,
+                "request_id": request_id,
+            },
+            options=make_user_request_input(
+                extra_headers=extra_headers, timeout=timeout
+            ),
+            cast_type=FineTuningJob,
+        )
+
+    def retrieve(
+            self,
+            fine_tuning_job_id: str,
+            *,
+            extra_headers: Headers | None = None,
+            timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+    ) -> FineTuningJob:
+        return self._get(
+            f"/fine_tuning/jobs/{fine_tuning_job_id}",
+            options=make_user_request_input(
+                extra_headers=extra_headers, timeout=timeout
+            ),
+            cast_type=FineTuningJob,
+        )
+
+    def list(
+            self,
+            *,
+            after: str | NotGiven = NOT_GIVEN,
+            limit: int | NotGiven = NOT_GIVEN,
+            extra_headers: Headers | None = None,
+            timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+    ) -> ListOfFineTuningJob:
+        return self._get(
+            "/fine_tuning/jobs",
+            cast_type=ListOfFineTuningJob,
+            options=make_user_request_input(
+                extra_headers=extra_headers,
+                timeout=timeout,
+                query={
+                    "after": after,
+                    "limit": limit,
+                },
+            ),
+        )
+
+    def list_events(
+            self,
+            fine_tuning_job_id: str,
+            *,
+            after: str | NotGiven = NOT_GIVEN,
+            limit: int | NotGiven = NOT_GIVEN,
+            extra_headers: Headers | None = None,
+            timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+    ) -> FineTuningJobEvent:
+
+        return self._get(
+            f"/fine_tuning/jobs/{fine_tuning_job_id}/events",
+            cast_type=FineTuningJobEvent,
+            options=make_user_request_input(
+                extra_headers=extra_headers,
+                timeout=timeout,
+                query={
+                    "after": after,
+                    "limit": limit,
+                },
+            ),
+        )
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/images.py (new file)
@@ -0,0 +1,55 @@
+from __future__ import annotations
+
+from typing import Union, List, Optional, TYPE_CHECKING
+
+import httpx
+
+from ..core._base_api import BaseAPI
+from ..core._base_type import NotGiven, NOT_GIVEN, Headers
+from ..core._http_client import make_user_request_input
+from ..types.image import ImagesResponded
+
+if TYPE_CHECKING:
+    from .._client import ZhipuAI
+
+
+class Images(BaseAPI):
+    def __init__(self, client: "ZhipuAI") -> None:
+        super().__init__(client)
+
+    def generations(
+            self,
+            *,
+            prompt: str,
+            model: str | NotGiven = NOT_GIVEN,
+            n: Optional[int] | NotGiven = NOT_GIVEN,
+            quality: Optional[str] | NotGiven = NOT_GIVEN,
+            response_format: Optional[str] | NotGiven = NOT_GIVEN,
+            size: Optional[str] | NotGiven = NOT_GIVEN,
+            style: Optional[str] | NotGiven = NOT_GIVEN,
+            user: str | NotGiven = NOT_GIVEN,
+            extra_headers: Headers | None = None,
+            disable_strict_validation: Optional[bool] | None = None,
+            timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
+    ) -> ImagesResponded:
+        _cast_type = ImagesResponded
+        if disable_strict_validation:
+            _cast_type = object
+        return self._post(
+            "/images/generations",
+            body={
+                "prompt": prompt,
+                "model": model,
+                "n": n,
+                "quality": quality,
+                "response_format": response_format,
+                "size": size,
+                "style": style,
+                "user": user,
+            },
+            options=make_user_request_input(
+                extra_headers=extra_headers, timeout=timeout
+            ),
+            cast_type=_cast_type,
+            enable_stream=False,
+        )
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_api.py (new file)
@@ -0,0 +1,17 @@
+from __future__ import annotations
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from .._client import ZhipuAI
+
+
+class BaseAPI:
+    _client: ZhipuAI
+
+    def __init__(self, client: ZhipuAI) -> None:
+        self._client = client
+        self._delete = client.delete
+        self._get = client.get
+        self._post = client.post
+        self._put = client.put
+        self._patch = client.patch
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_type.py (new file)
@@ -0,0 +1,115 @@
+from __future__ import annotations
+
+from os import PathLike
+from typing import (
+    TYPE_CHECKING,
+    Type,
+    Union,
+    Mapping,
+    TypeVar, IO, Tuple, Sequence, Any, List,
+)
+
+import pydantic
+from typing_extensions import (
+    Literal,
+    override,
+)
+
+Query = Mapping[str, object]
+Body = object
+AnyMapping = Mapping[str, object]
+PrimitiveData = Union[str, int, float, bool, None]
+Data = Union[PrimitiveData, List[Any], Tuple[Any], "Mapping[str, Any]"]
+ModelT = TypeVar("ModelT", bound=pydantic.BaseModel)
+_T = TypeVar("_T")
+
+if TYPE_CHECKING:
+    NoneType: Type[None]
+else:
+    NoneType = type(None)
+
+
+# Sentinel class used until PEP 0661 is accepted
+class NotGiven(pydantic.BaseModel):
+    """
+    A sentinel singleton class used to distinguish omitted keyword arguments
+    from those passed in with the value None (which may have different behavior).
+
+    For example:
+
+    ```py
+    def get(timeout: Union[int, NotGiven, None] = NotGiven()) -> Response: ...
+
+    get(timeout=1)  # 1s timeout
+    get(timeout=None)  # No timeout
+    get()  # Default timeout behavior, which may not be statically known at the method definition.
+    ```
+    """
+
+    def __bool__(self) -> Literal[False]:
+        return False
+
+    @override
+    def __repr__(self) -> str:
+        return "NOT_GIVEN"
+
+
+NotGivenOr = Union[_T, NotGiven]
+NOT_GIVEN = NotGiven()
+
+
+class Omit(pydantic.BaseModel):
+    """In certain situations you need to be able to represent a case where a default value has
+    to be explicitly removed and `None` is not an appropriate substitute, for example:
+
+    ```py
+    # as the default `Content-Type` header is `application/json` that will be sent
+    client.post('/upload/files', files={'file': b'my raw file content'})
+
+    # you can't explicitly override the header as it has to be dynamically generated
+    # to look something like: 'multipart/form-data; boundary=0d8382fcf5f8c3be01ca2e11002d2983'
+    client.post(..., headers={'Content-Type': 'multipart/form-data'})
+
+    # instead you can remove the default `application/json` header by passing Omit
+    client.post(..., headers={'Content-Type': Omit()})
+    ```
+    """
+
+    def __bool__(self) -> Literal[False]:
+        return False
+
+
+Headers = Mapping[str, Union[str, Omit]]
+
+ResponseT = TypeVar(
+    "ResponseT",
+    bound="Union[str, None, BaseModel, List[Any], Dict[str, Any], Response, UnknownResponse, ModelBuilderProtocol, BinaryResponseContent]",
+)
+
+# for user input files
+if TYPE_CHECKING:
+    FileContent = Union[IO[bytes], bytes, PathLike[str]]
+else:
+    FileContent = Union[IO[bytes], bytes, PathLike]
+
+FileTypes = Union[
+    FileContent,  # file content
+    Tuple[str, FileContent],  # (filename, file)
+    Tuple[str, FileContent, str],  # (filename, file, content_type)
+    Tuple[str, FileContent, str, Mapping[str, str]],  # (filename, file, content_type, headers)
+]
+
+RequestFiles = Union[Mapping[str, FileTypes], Sequence[Tuple[str, FileTypes]]]
+
+# for httpx client supported files
+
+HttpxFileContent = Union[bytes, IO[bytes]]
+HttpxFileTypes = Union[
+    FileContent,  # file content
+    Tuple[str, HttpxFileContent],  # (filename, file)
+    Tuple[str, HttpxFileContent, str],  # (filename, file, content_type)
+    Tuple[str, HttpxFileContent, str, Mapping[str, str]],  # (filename, file, content_type, headers)
+]
+
+HttpxRequestFiles = Union[Mapping[str, HttpxFileTypes], Sequence[Tuple[str, HttpxFileTypes]]]
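The NotGiven sentinel is what lets every create(...) signature above distinguish "argument omitted" from "argument explicitly None": bool(NOT_GIVEN) is False and isinstance checks pick it out. A self-contained illustration:

    def get(timeout: float | None | NotGiven = NOT_GIVEN) -> None:
        if isinstance(timeout, NotGiven):
            print("use the client default")
        elif timeout is None:
            print("explicitly no timeout")
        else:
            print(f"use {timeout}s")

    get()      # use the client default
    get(None)  # explicitly no timeout
    get(8.0)   # use 8.0s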
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_errors.py (new file)
@@ -0,0 +1,90 @@
+from __future__ import annotations
+
+import httpx
+
+__all__ = [
+    "ZhipuAIError",
+    "APIStatusError",
+    "APIRequestFailedError",
+    "APIAuthenticationError",
+    "APIReachLimitError",
+    "APIInternalError",
+    "APIServerFlowExceedError",
+    "APIResponseError",
+    "APIResponseValidationError",
+    "APITimeoutError",
+]
+
+
+class ZhipuAIError(Exception):
+    def __init__(self, message: str, ) -> None:
+        super().__init__(message)
+
+
+class APIStatusError(Exception):
+    response: httpx.Response
+    status_code: int
+
+    def __init__(self, message: str, *, response: httpx.Response) -> None:
+        super().__init__(message)
+        self.response = response
+        self.status_code = response.status_code
+
+
+class APIRequestFailedError(APIStatusError):
+    ...
+
+
+class APIAuthenticationError(APIStatusError):
+    ...
+
+
+class APIReachLimitError(APIStatusError):
+    ...
+
+
+class APIInternalError(APIStatusError):
+    ...
+
+
+class APIServerFlowExceedError(APIStatusError):
+    ...
+
+
+class APIResponseError(Exception):
+    message: str
+    request: httpx.Request
+    json_data: object
+
+    def __init__(self, message: str, request: httpx.Request, json_data: object):
+        self.message = message
+        self.request = request
+        self.json_data = json_data
+        super().__init__(message)
+
+
+class APIResponseValidationError(APIResponseError):
+    status_code: int
+    response: httpx.Response
+
+    def __init__(
+            self,
+            response: httpx.Response,
+            json_data: object | None, *,
+            message: str | None = None
+    ) -> None:
+        super().__init__(
+            message=message or "Data returned by API invalid for expected schema.",
+            request=response.request,
+            json_data=json_data
+        )
+        self.response = response
+        self.status_code = response.status_code
+
+
+class APITimeoutError(Exception):
+    request: httpx.Request
+
+    def __init__(self, request: httpx.Request):
+        self.request = request
+        super().__init__("Request Timeout")
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_files.py (new file)
@@ -0,0 +1,46 @@
+from __future__ import annotations
+
+import io
+import os
+from pathlib import Path
+from typing import Mapping, Sequence
+
+from ._base_type import (
+    FileTypes,
+    HttpxFileTypes,
+    HttpxRequestFiles,
+    RequestFiles,
+)
+
+
+def is_file_content(obj: object) -> bool:
+    return isinstance(obj, (bytes, tuple, io.IOBase, os.PathLike))
+
+
+def _transform_file(file: FileTypes) -> HttpxFileTypes:
+    if is_file_content(file):
+        if isinstance(file, os.PathLike):
+            path = Path(file)
+            return path.name, path.read_bytes()
+        else:
+            return file
+    if isinstance(file, tuple):
+        if isinstance(file[1], os.PathLike):
+            return (file[0], Path(file[1]).read_bytes(), *file[2:])
+        else:
+            return (file[0], file[1], *file[2:])
+    else:
+        raise TypeError(f"Unexpected input file with type {type(file)}, expected FileContent type or tuple type")
+
+
+def make_httpx_files(files: RequestFiles | None) -> HttpxRequestFiles | None:
+    if files is None:
+        return None
+
+    if isinstance(files, Mapping):
+        files = {key: _transform_file(file) for key, file in files.items()}
+    elif isinstance(files, Sequence):
+        files = [(key, _transform_file(file)) for key, file in files]
+    else:
+        raise TypeError(f"Unexpected input file with type {type(files)}, expected Mapping or Sequence")
+    return files
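make_httpx_files normalizes the user-facing FileTypes union into what httpx can send; in particular a PathLike is eagerly read into a (filename, bytes) pair. Sketch:

    from pathlib import Path

    httpx_files = make_httpx_files({"file": Path("report.pdf")})
    # -> {"file": ("report.pdf", b"<file bytes>")}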
@ -0,0 +1,377 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
from typing import (
|
||||
Any,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
Mapping,
|
||||
)
|
||||
|
||||
import httpx
|
||||
import pydantic
|
||||
from httpx import URL, Timeout
|
||||
|
||||
from . import _errors
|
||||
from ._base_type import NotGiven, ResponseT, Body, Headers, NOT_GIVEN, RequestFiles, Query, Data
|
||||
from ._errors import APIResponseValidationError, APIStatusError, APITimeoutError
|
||||
from ._files import make_httpx_files
|
||||
from ._request_opt import ClientRequestParam, UserRequestInput
|
||||
from ._response import HttpResponse
|
||||
from ._sse_client import StreamResponse
|
||||
from ._utils import flatten
|
||||
|
||||
headers = {
|
||||
"Accept": "application/json",
|
||||
"Content-Type": "application/json; charset=UTF-8",
|
||||
}
|
||||
|
||||
|
||||
def _merge_map(map1: Mapping, map2: Mapping) -> Mapping:
|
||||
merged = {**map1, **map2}
|
||||
return {key: val for key, val in merged.items() if val is not None}
|
||||
|
||||
|
||||
from httpx._config import DEFAULT_TIMEOUT_CONFIG as HTTPX_DEFAULT_TIMEOUT
|
||||
|
||||
ZHIPUAI_DEFAULT_TIMEOUT = httpx.Timeout(timeout=300.0, connect=8.0)
|
||||
ZHIPUAI_DEFAULT_MAX_RETRIES = 3
|
||||
ZHIPUAI_DEFAULT_LIMITS = httpx.Limits(max_connections=50, max_keepalive_connections=10)
|
||||
|
||||
|
||||
class HttpClient:
|
||||
_client: httpx.Client
|
||||
_version: str
|
||||
_base_url: URL
|
||||
|
||||
timeout: Union[float, Timeout, None]
|
||||
_limits: httpx.Limits
|
||||
_has_custom_http_client: bool
|
||||
_default_stream_cls: type[StreamResponse[Any]] | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
version: str,
|
||||
base_url: URL,
|
||||
timeout: Union[float, Timeout, None],
|
||||
custom_httpx_client: httpx.Client | None = None,
|
||||
custom_headers: Mapping[str, str] | None = None,
|
||||
) -> None:
|
||||
if timeout is None or isinstance(timeout, NotGiven):
|
||||
if custom_httpx_client and custom_httpx_client.timeout != HTTPX_DEFAULT_TIMEOUT:
|
||||
timeout = custom_httpx_client.timeout
|
||||
else:
|
||||
timeout = ZHIPUAI_DEFAULT_TIMEOUT
|
||||
self.timeout = cast(Timeout, timeout)
|
||||
self._has_custom_http_client = bool(custom_httpx_client)
|
||||
self._client = custom_httpx_client or httpx.Client(
|
||||
base_url=base_url,
|
||||
timeout=self.timeout,
|
||||
limits=ZHIPUAI_DEFAULT_LIMITS,
|
||||
)
|
||||
self._version = version
|
||||
url = URL(url=base_url)
|
||||
if not url.raw_path.endswith(b"/"):
|
||||
url = url.copy_with(raw_path=url.raw_path + b"/")
|
||||
self._base_url = url
|
||||
self._custom_headers = custom_headers or {}
|
||||
|
||||
def _prepare_url(self, url: str) -> URL:
|
||||
|
||||
sub_url = URL(url)
|
||||
if sub_url.is_relative_url:
|
||||
request_raw_url = self._base_url.raw_path + sub_url.raw_path.lstrip(b"/")
|
||||
return self._base_url.copy_with(raw_path=request_raw_url)
|
||||
|
||||
return sub_url
|
||||
|
||||
@property
|
||||
def _default_headers(self):
|
||||
return \
|
||||
{
|
||||
"Accept": "application/json",
|
||||
"Content-Type": "application/json; charset=UTF-8",
|
||||
"ZhipuAI-SDK-Ver": self._version,
|
||||
"source_type": "zhipu-sdk-python",
|
||||
"x-request-sdk": "zhipu-sdk-python",
|
||||
**self._auth_headers,
|
||||
**self._custom_headers,
|
||||
}
|
||||
|
||||
@property
|
||||
def _auth_headers(self):
|
||||
return {}
|
||||
|
||||
def _prepare_headers(self, request_param: ClientRequestParam) -> httpx.Headers:
|
||||
custom_headers = request_param.headers or {}
|
||||
headers_dict = _merge_map(self._default_headers, custom_headers)
|
||||
|
||||
httpx_headers = httpx.Headers(headers_dict)
|
||||
|
||||
return httpx_headers
|
||||
|
||||
def _prepare_request(
|
||||
self,
|
||||
request_param: ClientRequestParam
|
||||
) -> httpx.Request:
|
||||
kwargs: dict[str, Any] = {}
|
||||
json_data = request_param.json_data
|
||||
headers = self._prepare_headers(request_param)
|
||||
url = self._prepare_url(request_param.url)
|
||||
json_data = request_param.json_data
|
||||
if headers.get("Content-Type") == "multipart/form-data":
|
||||
headers.pop("Content-Type")
|
||||
|
||||
if json_data:
|
||||
kwargs["data"] = self._make_multipartform(json_data)
|
||||
|
||||
return self._client.build_request(
|
||||
headers=headers,
|
||||
timeout=self.timeout if isinstance(request_param.timeout, NotGiven) else request_param.timeout,
|
||||
method=request_param.method,
|
||||
url=url,
|
||||
json=json_data,
|
||||
files=request_param.files,
|
||||
params=request_param.params,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _object_to_formfata(self, key: str, value: Data | Mapping[object, object]) -> list[tuple[str, str]]:
|
||||
items = []
|
||||
|
||||
if isinstance(value, Mapping):
|
||||
for k, v in value.items():
|
||||
items.extend(self._object_to_formfata(f"{key}[{k}]", v))
|
||||
return items
|
||||
if isinstance(value, (list, tuple)):
|
||||
for v in value:
|
||||
items.extend(self._object_to_formfata(key + "[]", v))
|
||||
return items
|
||||
|
||||
def _primitive_value_to_str(val) -> str:
|
||||
# copied from httpx
|
||||
if val is True:
|
||||
return "true"
|
||||
elif val is False:
|
||||
return "false"
|
||||
elif val is None:
|
||||
return ""
|
||||
return str(val)
|
||||
|
||||
str_data = _primitive_value_to_str(value)
|
||||
|
||||
if not str_data:
|
||||
return []
|
||||
return [(key, str_data)]
|
||||
|
||||
def _make_multipartform(self, data: Mapping[object, object]) -> dict[str, object]:
|
||||
|
||||
items = flatten([self._object_to_formfata(k, v) for k, v in data.items()])
|
||||
|
||||
serialized: dict[str, object] = {}
|
||||
for key, value in items:
|
||||
if key in serialized:
|
||||
raise ValueError(f"存在重复的键: {key};")
|
||||
serialized[key] = value
|
||||
return serialized
|
||||
|
||||
def _parse_response(
|
||||
self,
|
||||
*,
|
||||
cast_type: Type[ResponseT],
|
||||
response: httpx.Response,
|
||||
enable_stream: bool,
|
||||
request_param: ClientRequestParam,
|
||||
stream_cls: type[StreamResponse[Any]] | None = None,
|
||||
) -> HttpResponse:
|
||||
|
||||
http_response = HttpResponse(
|
||||
raw_response=response,
|
||||
cast_type=cast_type,
|
||||
client=self,
|
||||
enable_stream=enable_stream,
|
||||
stream_cls=stream_cls
|
||||
)
|
||||
return http_response.parse()
|
||||
|
||||
def _process_response_data(
|
||||
self,
|
||||
*,
|
||||
data: object,
|
||||
cast_type: type[ResponseT],
|
||||
response: httpx.Response,
|
||||
) -> ResponseT:
|
||||
if data is None:
|
||||
return cast(ResponseT, None)
|
||||
|
||||
try:
|
||||
if inspect.isclass(cast_type) and issubclass(cast_type, pydantic.BaseModel):
|
||||
return cast(ResponseT, cast_type.validate(data))
|
||||
|
||||
return cast(ResponseT, pydantic.TypeAdapter(cast_type).validate_python(data))
|
||||
except pydantic.ValidationError as err:
|
||||
raise APIResponseValidationError(response=response, json_data=data) from err
|
||||
|
||||
def is_closed(self) -> bool:
|
||||
return self._client.is_closed
|
||||
|
||||
def close(self):
|
||||
self._client.close()
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.close()
|
||||
|
||||
def request(
|
||||
self,
        *,
        cast_type: Type[ResponseT],
        params: ClientRequestParam,
        enable_stream: bool = False,
        stream_cls: type[StreamResponse[Any]] | None = None,
    ) -> ResponseT | StreamResponse:
        request = self._prepare_request(params)

        try:
            response = self._client.send(
                request,
                stream=enable_stream,
            )
            response.raise_for_status()
        except httpx.TimeoutException as err:
            raise APITimeoutError(request=request) from err
        except httpx.HTTPStatusError as err:
            err.response.read()
            # raise err
            raise self._make_status_error(err.response) from None
        except Exception as err:
            raise err

        return self._parse_response(
            cast_type=cast_type,
            request_param=params,
            response=response,
            enable_stream=enable_stream,
            stream_cls=stream_cls,
        )

    def get(
        self,
        path: str,
        *,
        cast_type: Type[ResponseT],
        options: UserRequestInput = {},
        enable_stream: bool = False,
    ) -> ResponseT | StreamResponse:
        opts = ClientRequestParam.construct(method="get", url=path, **options)
        return self.request(
            cast_type=cast_type, params=opts,
            enable_stream=enable_stream
        )

    def post(
        self,
        path: str,
        *,
        body: Body | None = None,
        cast_type: Type[ResponseT],
        options: UserRequestInput = {},
        files: RequestFiles | None = None,
        enable_stream: bool = False,
        stream_cls: type[StreamResponse[Any]] | None = None,
    ) -> ResponseT | StreamResponse:
        opts = ClientRequestParam.construct(method="post", json_data=body, files=make_httpx_files(files), url=path,
                                            **options)

        return self.request(
            cast_type=cast_type, params=opts,
            enable_stream=enable_stream,
            stream_cls=stream_cls
        )

    def patch(
        self,
        path: str,
        *,
        body: Body | None = None,
        cast_type: Type[ResponseT],
        options: UserRequestInput = {},
    ) -> ResponseT:
        opts = ClientRequestParam.construct(method="patch", url=path, json_data=body, **options)

        return self.request(
            cast_type=cast_type, params=opts,
        )

    def put(
        self,
        path: str,
        *,
        body: Body | None = None,
        cast_type: Type[ResponseT],
        options: UserRequestInput = {},
        files: RequestFiles | None = None,
    ) -> ResponseT | StreamResponse:
        opts = ClientRequestParam.construct(method="put", url=path, json_data=body, files=make_httpx_files(files),
                                            **options)

        return self.request(
            cast_type=cast_type, params=opts,
        )

    def delete(
        self,
        path: str,
        *,
        body: Body | None = None,
        cast_type: Type[ResponseT],
        options: UserRequestInput = {},
    ) -> ResponseT | StreamResponse:
        opts = ClientRequestParam.construct(method="delete", url=path, json_data=body, **options)

        return self.request(
            cast_type=cast_type, params=opts,
        )

    def _make_status_error(self, response) -> APIStatusError:
        response_text = response.text.strip()
        status_code = response.status_code
        error_msg = f"Error code: {status_code}, with error text {response_text}"

        if status_code == 400:
            return _errors.APIRequestFailedError(message=error_msg, response=response)
        elif status_code == 401:
            return _errors.APIAuthenticationError(message=error_msg, response=response)
        elif status_code == 429:
            return _errors.APIReachLimitError(message=error_msg, response=response)
        elif status_code == 500:
            return _errors.APIInternalError(message=error_msg, response=response)
        elif status_code == 503:
            return _errors.APIServerFlowExceedError(message=error_msg, response=response)
        return APIStatusError(message=error_msg, response=response)


def make_user_request_input(
    max_retries: int | None = None,
    timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
    extra_headers: Headers = None,
    query: Query | None = None,
) -> UserRequestInput:
    options: UserRequestInput = {}

    if extra_headers is not None:
        options["headers"] = extra_headers
    if max_retries is not None:
        options["max_retries"] = max_retries
    if not isinstance(timeout, NotGiven):
        options['timeout'] = timeout
    if query is not None:
        options["params"] = query

    return options
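A minimal usage sketch of the verb helpers above. Everything named here is illustrative and not part of this commit: "client" stands in for a constructed HttpClient, and the endpoint path and "Completion" pydantic model are placeholders.

# Hypothetical sketch: "client" is an HttpClient, "Completion" a pydantic model.
opts = make_user_request_input(max_retries=2, timeout=30.0)
completion = client.post(
    "/chat/completions",   # placeholder path
    body={"model": "glm-4", "messages": [{"role": "user", "content": "Hi"}]},
    cast_type=Completion,  # _parse_response builds this model from the JSON body
    options=opts,
    enable_stream=False,
)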
@ -0,0 +1,30 @@
# -*- coding:utf-8 -*-
import time

import cachetools.func
import jwt

API_TOKEN_TTL_SECONDS = 3 * 60

CACHE_TTL_SECONDS = API_TOKEN_TTL_SECONDS - 30


@cachetools.func.ttl_cache(maxsize=10, ttl=CACHE_TTL_SECONDS)
def generate_token(apikey: str):
    try:
        api_key, secret = apikey.split(".")
    except Exception as e:
        raise Exception("invalid api_key", e)

    payload = {
        "api_key": api_key,
        "exp": int(round(time.time() * 1000)) + API_TOKEN_TTL_SECONDS * 1000,
        "timestamp": int(round(time.time() * 1000)),
    }
    ret = jwt.encode(
        payload,
        secret,
        algorithm="HS256",
        headers={"alg": "HS256", "sign_type": "SIGN"},
    )
    return ret
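A small sketch of how the cached generator behaves. The key is a dummy value in the documented "<id>.<secret>" shape; the decode step just inspects the claims with the same PyJWT library.

import jwt  # PyJWT, as imported above

token_a = generate_token("my-key-id.my-secret")  # dummy key, illustrative only
token_b = generate_token("my-key-id.my-secret")
assert token_a == token_b  # served from the ttl_cache for CACHE_TTL_SECONDS (150s)

claims = jwt.decode(token_a, "my-secret", algorithms=["HS256"])
# "exp" and "timestamp" are in milliseconds, matching the payload built above
assert claims["exp"] - claims["timestamp"] == API_TOKEN_TTL_SECONDS * 1000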
@ -0,0 +1,54 @@
from __future__ import annotations

from typing import Union, Any, cast

import pydantic.generics
from httpx import Timeout
from pydantic import ConfigDict
from typing_extensions import (
    Unpack, ClassVar, TypedDict
)

from ._base_type import Body, NotGiven, Headers, HttpxRequestFiles, Query
from ._utils import remove_notgiven_indict


class UserRequestInput(TypedDict, total=False):
    max_retries: int
    timeout: float | Timeout | None
    headers: Headers
    params: Query | None


class ClientRequestParam():
    method: str
    url: str
    max_retries: Union[int, NotGiven] = NotGiven()
    timeout: Union[float, NotGiven] = NotGiven()
    headers: Union[Headers, NotGiven] = NotGiven()
    json_data: Union[Body, None] = None
    files: Union[HttpxRequestFiles, None] = None
    params: Query = {}
    model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True)

    def get_max_retries(self, max_retries) -> int:
        if isinstance(self.max_retries, NotGiven):
            return max_retries
        return self.max_retries

    @classmethod
    def construct(  # type: ignore
        cls,
        _fields_set: set[str] | None = None,
        **values: Unpack[UserRequestInput],
    ) -> ClientRequestParam:
        kwargs: dict[str, Any] = {
            key: remove_notgiven_indict(value) for key, value in values.items()
        }
        client = cls()
        client.__dict__.update(kwargs)

        return client

    model_construct = construct
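A sketch of how construct() is meant to be fed. Note that it deliberately skips pydantic validation and simply copies the kwargs onto a bare instance, leaving the NotGiven class defaults in place for anything omitted:

opts = ClientRequestParam.construct(method="get", url="/models", max_retries=2)
assert opts.method == "get"
assert opts.get_max_retries(3) == 2   # explicit value wins

bare = ClientRequestParam.construct(method="get", url="/models")
assert bare.get_max_retries(3) == 3   # NotGiven falls back to the caller's default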
@ -0,0 +1,121 @@
from __future__ import annotations

import datetime
from typing import TypeVar, Generic, cast, Any, TYPE_CHECKING

import httpx
import pydantic
from typing_extensions import ParamSpec, get_origin, get_args

from ._base_type import NoneType
from ._sse_client import StreamResponse

if TYPE_CHECKING:
    from ._http_client import HttpClient

P = ParamSpec("P")
R = TypeVar("R")


class HttpResponse(Generic[R]):
    _cast_type: type[R]
    _client: "HttpClient"
    _parsed: R | None
    _enable_stream: bool
    _stream_cls: type[StreamResponse[Any]]
    http_response: httpx.Response

    def __init__(
        self,
        *,
        raw_response: httpx.Response,
        cast_type: type[R],
        client: "HttpClient",
        enable_stream: bool = False,
        stream_cls: type[StreamResponse[Any]] | None = None,
    ) -> None:
        self._cast_type = cast_type
        self._client = client
        self._parsed = None
        self._stream_cls = stream_cls
        self._enable_stream = enable_stream
        self.http_response = raw_response

    def parse(self) -> R:
        self._parsed = self._parse()
        return self._parsed

    def _parse(self) -> R:
        if self._enable_stream:
            self._parsed = cast(
                R,
                self._stream_cls(
                    cast_type=cast(type, get_args(self._stream_cls)[0]),
                    response=self.http_response,
                    client=self._client
                )
            )
            return self._parsed
        cast_type = self._cast_type
        if cast_type is NoneType:
            return cast(R, None)
        http_response = self.http_response
        if cast_type == str:
            return cast(R, http_response.text)

        content_type, *_ = http_response.headers.get("content-type", "application/json").split(";")
        origin = get_origin(cast_type) or cast_type
        if content_type != "application/json":
            if issubclass(origin, pydantic.BaseModel):
                data = http_response.json()
                return self._client._process_response_data(
                    data=data,
                    cast_type=cast_type,  # type: ignore
                    response=http_response,
                )

            return http_response.text

        data = http_response.json()

        return self._client._process_response_data(
            data=data,
            cast_type=cast_type,  # type: ignore
            response=http_response,
        )

    @property
    def headers(self) -> httpx.Headers:
        return self.http_response.headers

    @property
    def http_request(self) -> httpx.Request:
        return self.http_response.request

    @property
    def status_code(self) -> int:
        return self.http_response.status_code

    @property
    def url(self) -> httpx.URL:
        return self.http_response.url

    @property
    def method(self) -> str:
        return self.http_request.method

    @property
    def content(self) -> bytes:
        return self.http_response.content

    @property
    def text(self) -> str:
        return self.http_response.text

    @property
    def http_version(self) -> str:
        return self.http_response.http_version

    @property
    def elapsed(self) -> datetime.timedelta:
        return self.http_response.elapsed
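Roughly, parse() dispatches on cast_type: streaming wraps the raw response in stream_cls, NoneType and str short-circuit, and JSON bodies are handed to the client's _process_response_data. A hedged sketch, assuming "raw" is an already-received httpx.Response, "client" an HttpClient, and "Completion" a pydantic model — all placeholders, not part of this commit:

resp = HttpResponse(
    raw_response=raw,
    cast_type=Completion,  # JSON body -> pydantic model via the client
    client=client,
)
completion = resp.parse()
print(resp.status_code, resp.elapsed)  # thin pass-throughs to the httpx.Response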
@ -0,0 +1,149 @@
# -*- coding:utf-8 -*-
from __future__ import annotations

import json
from typing import Generic, Iterator, TYPE_CHECKING, Mapping

import httpx

from ._base_type import ResponseT
from ._errors import APIResponseError

_FIELD_SEPARATOR = ":"

if TYPE_CHECKING:
    from ._http_client import HttpClient


class StreamResponse(Generic[ResponseT]):

    response: httpx.Response
    _cast_type: type[ResponseT]

    def __init__(
        self,
        *,
        cast_type: type[ResponseT],
        response: httpx.Response,
        client: HttpClient,
    ) -> None:
        self.response = response
        self._cast_type = cast_type
        self._data_process_func = client._process_response_data
        self._stream_chunks = self.__stream__()

    def __next__(self) -> ResponseT:
        return self._stream_chunks.__next__()

    def __iter__(self) -> Iterator[ResponseT]:
        for item in self._stream_chunks:
            yield item

    def __stream__(self) -> Iterator[ResponseT]:

        sse_line_parser = SSELineParser()
        iterator = sse_line_parser.iter_lines(self.response.iter_lines())

        for sse in iterator:
            if sse.data.startswith("[DONE]"):
                break

            if sse.event is None:
                data = sse.json_data()
                if isinstance(data, Mapping) and data.get("error"):
                    raise APIResponseError(
                        message="An error occurred during streaming",
                        request=self.response.request,
                        json_data=data["error"],
                    )

                yield self._data_process_func(data=data, cast_type=self._cast_type, response=self.response)
        # drain any remaining events after [DONE]
        for sse in iterator:
            pass


class Event(object):
    def __init__(
        self,
        event: str | None = None,
        data: str | None = None,
        id: str | None = None,
        retry: int | None = None
    ):
        self._event = event
        self._data = data
        self._id = id
        self._retry = retry

    def __repr__(self):
        data_len = len(self._data) if self._data else 0
        return f"Event(event={self._event}, data={self._data}, data_length={data_len}, id={self._id}, retry={self._retry})"

    @property
    def event(self): return self._event

    @property
    def data(self): return self._data

    def json_data(self): return json.loads(self._data)

    @property
    def id(self): return self._id

    @property
    def retry(self): return self._retry


class SSELineParser:
    _data: list[str]
    _event: str | None
    _retry: int | None
    _id: str | None

    def __init__(self):
        self._event = None
        self._data = []
        self._id = None
        self._retry = None

    def iter_lines(self, lines: Iterator[str]) -> Iterator[Event]:
        for line in lines:
            line = line.rstrip('\n')
            if not line:
                if self._event is None and \
                        not self._data and \
                        self._id is None and \
                        self._retry is None:
                    continue
                sse_event = Event(
                    event=self._event,
                    data='\n'.join(self._data),
                    id=self._id,
                    retry=self._retry
                )
                self._event = None
                self._data = []
                self._id = None
                self._retry = None

                yield sse_event
            self.decode_line(line)

    def decode_line(self, line: str):
        if line.startswith(":") or not line:
            return

        field, _p, value = line.partition(":")

        if value.startswith(' '):
            value = value[1:]
        if field == "data":
            self._data.append(value)
        elif field == "event":
            self._event = value
        elif field == "retry":
            try:
                self._retry = int(value)
            except (TypeError, ValueError):
                pass
        return
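A self-contained sketch of the line parser on a canned SSE payload; blank lines flush the buffered fields into an Event:

raw_lines = iter([
    "event: message",
    'data: {"choices": []}',
    "",            # blank line emits the buffered event
    "data: [DONE]",
    "",
])
for event in SSELineParser().iter_lines(raw_lines):
    print(event.event, event.data)
# -> message {"choices": []}
# -> None [DONE]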
@ -0,0 +1,18 @@
from __future__ import annotations

from typing import Mapping, Iterable, TypeVar

from ._base_type import NotGiven


def remove_notgiven_indict(obj):
    if obj is None or (not isinstance(obj, Mapping)):
        return obj
    return {key: value for key, value in obj.items() if not isinstance(value, NotGiven)}


_T = TypeVar("_T")


def flatten(t: Iterable[Iterable[_T]]) -> list[_T]:
    return [item for sublist in t for item in sublist]
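Both helpers are trivial but worth a quick sanity sketch (assuming NotGiven() is constructible, as its use above suggests):

assert remove_notgiven_indict({"timeout": 30, "headers": NotGiven()}) == {"timeout": 30}
assert flatten([[1, 2], [3]]) == [1, 2, 3]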
@ -0,0 +1,23 @@
from typing import List, Optional

from pydantic import BaseModel

from .chat_completion import CompletionChoice, CompletionUsage

__all__ = ["AsyncTaskStatus"]


class AsyncTaskStatus(BaseModel):
    id: Optional[str] = None
    request_id: Optional[str] = None
    model: Optional[str] = None
    task_status: Optional[str] = None


class AsyncCompletion(BaseModel):
    id: Optional[str] = None
    request_id: Optional[str] = None
    model: Optional[str] = None
    task_status: str
    choices: List[CompletionChoice]
    usage: CompletionUsage
@ -0,0 +1,45 @@
from typing import List, Optional

from pydantic import BaseModel

__all__ = ["Completion", "CompletionUsage"]


class Function(BaseModel):
    arguments: str
    name: str


class CompletionMessageToolCall(BaseModel):
    id: str
    function: Function
    type: str


class CompletionMessage(BaseModel):
    content: Optional[str] = None
    role: str
    tool_calls: Optional[List[CompletionMessageToolCall]] = None


class CompletionUsage(BaseModel):
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


class CompletionChoice(BaseModel):
    index: int
    finish_reason: str
    message: CompletionMessage


class Completion(BaseModel):
    model: Optional[str] = None
    created: Optional[int] = None
    choices: List[CompletionChoice]
    request_id: Optional[str] = None
    id: Optional[str] = None
    usage: CompletionUsage
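A sketch of the payload shape these models accept (all values invented for illustration):

payload = {
    "id": "chatcmpl-123",
    "model": "glm-4",
    "choices": [{
        "index": 0,
        "finish_reason": "stop",
        "message": {"role": "assistant", "content": "Hello!"},
    }],
    "usage": {"prompt_tokens": 5, "completion_tokens": 2, "total_tokens": 7},
}
completion = Completion(**payload)
print(completion.choices[0].message.content)  # Hello!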
@ -0,0 +1,55 @@
from typing import List, Optional

from pydantic import BaseModel

__all__ = [
    "ChatCompletionChunk",
    "Choice",
    "ChoiceDelta",
    "ChoiceDeltaFunctionCall",
    "ChoiceDeltaToolCall",
    "ChoiceDeltaToolCallFunction",
]


class ChoiceDeltaFunctionCall(BaseModel):
    arguments: Optional[str] = None
    name: Optional[str] = None


class ChoiceDeltaToolCallFunction(BaseModel):
    arguments: Optional[str] = None
    name: Optional[str] = None


class ChoiceDeltaToolCall(BaseModel):
    index: int
    id: Optional[str] = None
    function: Optional[ChoiceDeltaToolCallFunction] = None
    type: Optional[str] = None


class ChoiceDelta(BaseModel):
    content: Optional[str] = None
    role: Optional[str] = None
    tool_calls: Optional[List[ChoiceDeltaToolCall]] = None


class Choice(BaseModel):
    delta: ChoiceDelta
    finish_reason: Optional[str] = None
    index: int


class CompletionUsage(BaseModel):
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


class ChatCompletionChunk(BaseModel):
    id: Optional[str] = None
    choices: List[Choice]
    created: Optional[int] = None
    model: Optional[str] = None
    usage: Optional[CompletionUsage] = None
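Since streamed deltas arrive piecewise, a consumer typically stitches them back together. A minimal sketch, assuming "stream" yields ChatCompletionChunk objects as defined above:

def collect_text(stream) -> str:
    parts: list[str] = []
    for chunk in stream:
        for choice in chunk.choices:
            if choice.delta.content:
                parts.append(choice.delta.content)
    return "".join(parts)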
@ -0,0 +1,8 @@
from typing import Optional

from typing_extensions import TypedDict


class Reference(TypedDict, total=False):
    enable: Optional[bool]
    search_query: Optional[str]
@ -0,0 +1,20 @@
from __future__ import annotations

from typing import Optional, List

from pydantic import BaseModel
from .chat.chat_completion import CompletionUsage

__all__ = ["Embedding", "EmbeddingsResponded"]


class Embedding(BaseModel):
    object: str
    index: Optional[int] = None
    embedding: List[float]


class EmbeddingsResponded(BaseModel):
    object: str
    data: List[Embedding]
    model: str
    usage: CompletionUsage
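An illustrative construction (values invented; CompletionUsage comes from the chat_completion module as imported above):

resp = EmbeddingsResponded(
    object="list",
    model="embedding-2",
    data=[Embedding(object="embedding", index=0, embedding=[0.1, 0.2, 0.3])],
    usage=CompletionUsage(prompt_tokens=3, completion_tokens=0, total_tokens=3),
)
vector = resp.data[0].embedding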
@ -0,0 +1,24 @@
from typing import Optional, List

from pydantic import BaseModel

__all__ = ["FileObject"]


class FileObject(BaseModel):
    id: Optional[str] = None
    bytes: Optional[int] = None
    created_at: Optional[int] = None
    filename: Optional[str] = None
    object: Optional[str] = None
    purpose: Optional[str] = None
    status: Optional[str] = None
    status_details: Optional[str] = None


class ListOfFileObject(BaseModel):
    object: Optional[str] = None
    data: List[FileObject]
    has_more: Optional[bool] = None
@ -0,0 +1,5 @@
from __future__ import annotations

from .fine_tuning_job import FineTuningJob as FineTuningJob
from .fine_tuning_job import ListOfFineTuningJob as ListOfFineTuningJob
from .fine_tuning_job_event import FineTuningJobEvent as FineTuningJobEvent
@ -0,0 +1,52 @@
from typing import List, Union, Optional
from typing_extensions import Literal

from pydantic import BaseModel

__all__ = ["FineTuningJob", "Error", "Hyperparameters", "ListOfFineTuningJob"]


class Error(BaseModel):
    code: str
    message: str
    param: Optional[str] = None


class Hyperparameters(BaseModel):
    n_epochs: Union[str, int, None] = None


class FineTuningJob(BaseModel):
    id: Optional[str] = None
    request_id: Optional[str] = None
    created_at: Optional[int] = None
    error: Optional[Error] = None
    fine_tuned_model: Optional[str] = None
    finished_at: Optional[int] = None
    hyperparameters: Optional[Hyperparameters] = None
    model: Optional[str] = None
    object: Optional[str] = None
    result_files: List[str]
    status: str
    trained_tokens: Optional[int] = None
    training_file: str
    validation_file: Optional[str] = None


class ListOfFineTuningJob(BaseModel):
    object: Optional[str] = None
    data: List[FineTuningJob]
    has_more: Optional[bool] = None
@ -0,0 +1,36 @@
from typing import List, Union, Optional
from typing_extensions import Literal

from pydantic import BaseModel

__all__ = ["FineTuningJobEvent", "Metric", "JobEvent"]


class Metric(BaseModel):
    epoch: Optional[Union[str, int, float]] = None
    current_steps: Optional[int] = None
    total_steps: Optional[int] = None
    elapsed_time: Optional[str] = None
    remaining_time: Optional[str] = None
    trained_tokens: Optional[int] = None
    loss: Optional[Union[str, int, float]] = None
    eval_loss: Optional[Union[str, int, float]] = None
    acc: Optional[Union[str, int, float]] = None
    eval_acc: Optional[Union[str, int, float]] = None
    learning_rate: Optional[Union[str, int, float]] = None


class JobEvent(BaseModel):
    object: Optional[str] = None
    id: Optional[str] = None
    type: Optional[str] = None
    created_at: Optional[int] = None
    level: Optional[str] = None
    message: Optional[str] = None
    data: Optional[Metric] = None


class FineTuningJobEvent(BaseModel):
    object: Optional[str] = None
    data: List[JobEvent]
    has_more: Optional[bool] = None
@ -0,0 +1,15 @@
from __future__ import annotations

from typing import Union

from typing_extensions import Literal, TypedDict

__all__ = ["Hyperparameters"]


class Hyperparameters(TypedDict, total=False):
    batch_size: Union[Literal["auto"], int]
    learning_rate_multiplier: Union[Literal["auto"], float]
    n_epochs: Union[Literal["auto"], int]
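Being a total=False TypedDict, callers pass a plain dict with any subset of keys, mixing "auto" with numeric values:

hp: Hyperparameters = {"n_epochs": "auto", "batch_size": 8}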
@ -0,0 +1,18 @@
from __future__ import annotations

from typing import Optional, List

from pydantic import BaseModel

__all__ = ["GeneratedImage", "ImagesResponded"]


class GeneratedImage(BaseModel):
    b64_json: Optional[str] = None
    url: Optional[str] = None
    revised_prompt: Optional[str] = None


class ImagesResponded(BaseModel):
    created: int
    data: List[GeneratedImage]
@ -3,7 +3,8 @@ from typing import Generator
import pytest
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage
from core.model_runtime.entities.message_entities import (AssistantPromptMessage, SystemPromptMessage,
                                                          UserPromptMessage, PromptMessageTool)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.zhipuai.llm.llm import ZhipuAILargeLanguageModel
@ -102,3 +103,48 @@ def test_get_num_tokens():
    )

    assert num_tokens == 14


def test_get_tools_num_tokens():
    model = ZhipuAILargeLanguageModel()

    num_tokens = model.get_num_tokens(
        model='tools',
        credentials={
            'api_key': os.environ.get('ZHIPUAI_API_KEY')
        },
        tools=[
            PromptMessageTool(
                name='get_current_weather',
                description='Get the current weather in a given location',
                parameters={
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and state e.g. San Francisco, CA"
                        },
                        "unit": {
                            "type": "string",
                            "enum": [
                                "c",
                                "f"
                            ]
                        }
                    },
                    "required": [
                        "location"
                    ]
                }
            )
        ],
        prompt_messages=[
            SystemPromptMessage(
                content='You are a helpful AI assistant.',
            ),
            UserPromptMessage(
                content='Hello World!'
            )
        ]
    )

    assert num_tokens == 108
@ -42,7 +42,7 @@ def test_invoke_model():
    assert isinstance(result, TextEmbeddingResult)
    assert len(result.embeddings) == 2
    assert result.usage.total_tokens == 2
    assert result.usage.total_tokens > 0


def test_get_num_tokens():
@ -229,7 +229,7 @@ const Answer: FC<IAnswerProps> = ({
            <Thought
              thought={item}
              allToolIcons={allToolIcons || {}}
              isFinished={!!item.observation}
              isFinished={!!item.observation || !isResponsing}
            />
          )}
@ -43,7 +43,7 @@ import { fetchDatasets } from '@/service/datasets'
import { fetchDatasets } from '@/service/datasets'
import { useProviderContext } from '@/context/provider-context'
import { AgentStrategy, AppType, ModelModeType, RETRIEVE_TYPE, Resolution, TransferMethod } from '@/types/app'
import { PromptMode } from '@/models/debug'
import { ANNOTATION_DEFAULT, DEFAULT_AGENT_SETTING, DEFAULT_CHAT_PROMPT_CONFIG, DEFAULT_COMPLETION_PROMPT_CONFIG } from '@/config'
import { ANNOTATION_DEFAULT, DEFAULT_AGENT_SETTING, DEFAULT_CHAT_PROMPT_CONFIG, DEFAULT_COMPLETION_PROMPT_CONFIG, supportFunctionCallModels } from '@/config'
import SelectDataSet from '@/app/components/app/configuration/dataset-config/select-dataset'
import I18n from '@/context/i18n'
import { useModalContext } from '@/context/modal-context'
@ -163,8 +163,7 @@ const Configuration: FC = () => {
    doSetModelConfig(newModelConfig)
  }
  const isOpenAI = modelConfig.provider === 'openai'
  const isFunctionCall = isOpenAI && modelConfig.mode === ModelModeType.chat

  const isFunctionCall = (isOpenAI && modelConfig.mode === ModelModeType.chat) || supportFunctionCallModels.includes(modelConfig.model_id)
  const [collectionList, setCollectionList] = useState<Collection[]>([])
  useEffect(() => {
@ -160,6 +160,8 @@ export const DEFAULT_AGENT_SETTING = {
  tools: [],
}

export const supportFunctionCallModels = ['glm-3-turbo', 'glm-4']

export const DEFAULT_AGENT_PROMPT = {
  chat: `Respond to the human as helpfully and accurately as possible.