Feat/zhipuai function calling (#2199)

Co-authored-by: Joel <iamjoel007@gmail.com>
Yeuoly 2024-01-25 16:29:35 +08:00 committed by GitHub
parent bdc5e9ceb0
commit b921c55677
46 changed files with 2115 additions and 138 deletions
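
For orientation, a minimal sketch of the function-calling flow this commit enables, assuming a valid key; the model name, tool schema, and API key below are illustrative, not part of the diff:

```python
from core.model_runtime.model_providers.zhipuai.zhipuai_sdk._client import ZhipuAI

# ZhipuAI API keys have the form "<id>.<secret>" (see _jwt_token.py below)
client = ZhipuAI(api_key="my-key-id.my-key-secret")

response = client.chat.completions.create(
    model="glm-4",  # illustrative model name
    messages=[{"role": "user", "content": "What's the weather in Beijing?"}],
    tools=[{
        "type": "function",
        "function": {
            "name": "get_weather",  # hypothetical tool
            "description": "Look up the current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }],
)

# tool calls come back on the assistant message, mirroring what
# _handle_generate_response() below maps into AssistantPromptMessage
for choice in response.choices:
    for tool_call in choice.message.tool_calls or []:
        print(tool_call.function.name, tool_call.function.arguments)
```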

View File

@ -1,61 +0,0 @@
"""Wrapper around ZhipuAI APIs."""
from __future__ import annotations
import logging
import posixpath
from pydantic import BaseModel, Extra
from zhipuai.model_api.api import InvokeType
from zhipuai.utils import jwt_token
from zhipuai.utils.http_client import post, stream
from zhipuai.utils.sse_client import SSEClient
logger = logging.getLogger(__name__)
class ZhipuModelAPI(BaseModel):
base_url: str = "https://open.bigmodel.cn/api/paas/v3/model-api"
api_key: str
api_timeout_seconds = 60
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
def invoke(self, **kwargs):
url = self._build_api_url(kwargs, InvokeType.SYNC)
response = post(url, self._generate_token(), kwargs, self.api_timeout_seconds)
if not response['success']:
raise ValueError(
f"Error Code: {response['code']}, Message: {response['msg']} "
)
return response
def sse_invoke(self, **kwargs):
url = self._build_api_url(kwargs, InvokeType.SSE)
data = stream(url, self._generate_token(), kwargs, self.api_timeout_seconds)
return SSEClient(data)
def _build_api_url(self, kwargs, *path):
if kwargs:
if "model" not in kwargs:
raise Exception("model param missed")
model = kwargs.pop("model")
else:
model = "-"
return posixpath.join(self.base_url, model, *path)
def _generate_token(self):
if not self.api_key:
raise Exception(
"api_key not provided, you could provide it."
)
try:
return jwt_token.generate_token(self.api_key)
except Exception:
raise ValueError(
f"Your api_key is invalid, please check it."
)

View File

@ -3,13 +3,15 @@ from typing import Any, Dict, Generator, List, Optional, Union
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageRole,
PromptMessageTool, SystemPromptMessage, UserPromptMessage,
PromptMessageTool, SystemPromptMessage, UserPromptMessage, ToolPromptMessage,
TextPromptMessageContent, ImagePromptMessageContent, PromptMessageContentType)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.utils import helper
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.zhipuai._client import ZhipuModelAPI
from core.model_runtime.model_providers.zhipuai._common import _CommonZhipuaiAI
from core.model_runtime.model_providers.zhipuai.zhipuai_sdk._client import ZhipuAI
from core.model_runtime.model_providers.zhipuai.zhipuai_sdk.types.chat.chat_completion_chunk import ChatCompletionChunk
from core.model_runtime.model_providers.zhipuai.zhipuai_sdk.types.chat.chat_completion import Completion
class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
@ -35,7 +37,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
credentials_kwargs = self._to_credential_kwargs(credentials)
# invoke model
return self._generate(model, credentials_kwargs, prompt_messages, model_parameters, stop, stream, user)
return self._generate(model, credentials_kwargs, prompt_messages, model_parameters, tools, stop, stream, user)
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
@ -48,7 +50,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
:param tools: tools for tool calling
:return:
"""
prompt = self._convert_messages_to_prompt(prompt_messages)
prompt = self._convert_messages_to_prompt(prompt_messages, tools)
return self._get_num_tokens_by_gpt2(prompt)
@ -72,6 +74,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
model_parameters={
"temperature": 0.5,
},
tools=[],
stream=False
)
except Exception as ex:
@ -79,6 +82,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
def _generate(self, model: str, credentials_kwargs: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
stop: Optional[List[str]] = None, stream: bool = True,
user: Optional[str] = None) -> Union[LLMResult, Generator]:
"""
@ -97,7 +101,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
if stop:
extra_model_kwargs['stop_sequences'] = stop
client = ZhipuModelAPI(
client = ZhipuAI(
api_key=credentials_kwargs['api_key']
)
@ -128,11 +132,17 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
# not support image message
continue
if new_prompt_messages and new_prompt_messages[-1].role == PromptMessageRole.USER:
if new_prompt_messages and new_prompt_messages[-1].role == PromptMessageRole.USER and \
copy_prompt_message.role == PromptMessageRole.USER:
new_prompt_messages[-1].content += "\n\n" + copy_prompt_message.content
else:
if copy_prompt_message.role == PromptMessageRole.USER:
new_prompt_messages.append(copy_prompt_message)
elif copy_prompt_message.role == PromptMessageRole.TOOL:
new_prompt_messages.append(copy_prompt_message)
elif copy_prompt_message.role == PromptMessageRole.SYSTEM:
new_prompt_message = SystemPromptMessage(content=copy_prompt_message.content)
new_prompt_messages.append(new_prompt_message)
else:
new_prompt_message = UserPromptMessage(content=copy_prompt_message.content)
new_prompt_messages.append(new_prompt_message)
@ -145,7 +155,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
if model == 'glm-4v':
params = {
'model': model,
'prompt': [{
'messages': [{
'role': prompt_message.role.value,
'content':
[
@ -171,23 +181,63 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
else:
params = {
'model': model,
'prompt': [{
'role': prompt_message.role.value,
'content': prompt_message.content,
} for prompt_message in new_prompt_messages],
'messages': [],
**model_parameters
}
# glm model
if not model.startswith('chatglm'):
for prompt_message in new_prompt_messages:
if prompt_message.role == PromptMessageRole.TOOL:
params['messages'].append({
'role': 'tool',
'content': prompt_message.content,
'tool_call_id': prompt_message.tool_call_id
})
else:
params['messages'].append({
'role': prompt_message.role.value,
'content': prompt_message.content
})
else:
# chatglm model
for prompt_message in new_prompt_messages:
# merge system message to user message
if prompt_message.role == PromptMessageRole.SYSTEM or \
prompt_message.role == PromptMessageRole.TOOL or \
prompt_message.role == PromptMessageRole.USER:
if len(params['messages']) > 0 and params['messages'][-1]['role'] == 'user':
params['messages'][-1]['content'] += "\n\n" + prompt_message.content
else:
params['messages'].append({
'role': 'user',
'content': prompt_message.content
})
else:
params['messages'].append({
'role': prompt_message.role.value,
'content': prompt_message.content
})
if tools and len(tools) > 0:
params['tools'] = [
{
'type': 'function',
'function': helper.dump_model(tool)
} for tool in tools
]
if stream:
response = client.sse_invoke(incremental=True, **params).events()
return self._handle_generate_stream_response(model, credentials_kwargs, response, prompt_messages)
response = client.chat.completions.create(stream=stream, **params)
return self._handle_generate_stream_response(model, credentials_kwargs, tools, response, prompt_messages)
response = client.invoke(**params)
return self._handle_generate_response(model, credentials_kwargs, response, prompt_messages)
response = client.chat.completions.create(**params)
return self._handle_generate_response(model, credentials_kwargs, tools, response, prompt_messages)
def _handle_generate_response(self, model: str,
credentials: dict,
response: Dict[str, Any],
tools: Optional[list[PromptMessageTool]],
response: Completion,
prompt_messages: list[PromptMessage]) -> LLMResult:
"""
Handle llm response
@ -197,26 +247,39 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
:param prompt_messages: prompt messages
:return: llm response
"""
data = response["data"]
text = ''
for res in data["choices"]:
text += res['content']
assistant_tool_calls: List[AssistantPromptMessage.ToolCall] = []
for choice in response.choices:
if choice.message.tool_calls:
for tool_call in choice.message.tool_calls:
if tool_call.type == 'function':
assistant_tool_calls.append(
AssistantPromptMessage.ToolCall(
id=tool_call.id,
type=tool_call.type,
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=tool_call.function.name,
arguments=tool_call.function.arguments,
)
)
)
text += choice.message.content or ''
token_usage = data.get("usage")
if token_usage is not None:
if 'prompt_tokens' not in token_usage:
token_usage['prompt_tokens'] = 0
if 'completion_tokens' not in token_usage:
token_usage['completion_tokens'] = token_usage['total_tokens']
prompt_usage = response.usage.prompt_tokens
completion_usage = response.usage.completion_tokens
# transform usage
usage = self._calc_response_usage(model, credentials, token_usage['prompt_tokens'], token_usage['completion_tokens'])
usage = self._calc_response_usage(model, credentials, prompt_usage, completion_usage)
# transform response
result = LLMResult(
model=model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(content=text),
message=AssistantPromptMessage(
content=text,
tool_calls=assistant_tool_calls
),
usage=usage,
)
@ -224,7 +287,8 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
def _handle_generate_stream_response(self, model: str,
credentials: dict,
responses: list[Generator],
tools: Optional[list[PromptMessageTool]],
responses: Generator[ChatCompletionChunk, None, None],
prompt_messages: list[PromptMessage]) -> Generator:
"""
Handle llm stream response
@ -234,39 +298,64 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
:param prompt_messages: prompt messages
:return: llm response chunk generator result
"""
for index, event in enumerate(responses):
if event.event == "add":
full_assistant_content = ''
for chunk in responses:
if len(chunk.choices) == 0:
continue
delta = chunk.choices[0]
if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''):
continue
assistant_tool_calls: List[AssistantPromptMessage.ToolCall] = []
for tool_call in delta.delta.tool_calls or []:
if tool_call.type == 'function':
assistant_tool_calls.append(
AssistantPromptMessage.ToolCall(
id=tool_call.id,
type=tool_call.type,
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=tool_call.function.name,
arguments=tool_call.function.arguments,
)
)
)
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=delta.delta.content if delta.delta.content else '',
tool_calls=assistant_tool_calls
)
full_assistant_content += delta.delta.content if delta.delta.content else ''
if delta.finish_reason is not None and chunk.usage is not None:
completion_tokens = chunk.usage.completion_tokens
prompt_tokens = chunk.usage.prompt_tokens
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
yield LLMResultChunk(
model=chunk.model,
prompt_messages=prompt_messages,
model=model,
system_fingerprint='',
delta=LLMResultChunkDelta(
index=index,
message=AssistantPromptMessage(content=event.data)
index=delta.index,
message=assistant_prompt_message,
finish_reason=delta.finish_reason,
usage=usage
)
)
elif event.event == "error" or event.event == "interrupted":
raise ValueError(
f"{event.data}"
)
elif event.event == "finish":
meta = json.loads(event.meta)
token_usage = meta['usage']
if token_usage is not None:
if 'prompt_tokens' not in token_usage:
token_usage['prompt_tokens'] = 0
if 'completion_tokens' not in token_usage:
token_usage['completion_tokens'] = token_usage['total_tokens']
usage = self._calc_response_usage(model, credentials, token_usage['prompt_tokens'], token_usage['completion_tokens'])
else:
yield LLMResultChunk(
model=model,
model=chunk.model,
prompt_messages=prompt_messages,
system_fingerprint='',
delta=LLMResultChunkDelta(
index=index,
message=AssistantPromptMessage(content=event.data),
finish_reason='finish',
usage=usage
index=delta.index,
message=assistant_prompt_message,
)
)
@ -291,11 +380,10 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
raise ValueError(f"Got unknown type {message}")
return message_text
def _convert_messages_to_prompt(self, messages: List[PromptMessage]) -> str:
"""
Format a list of messages into a full prompt for the Anthropic model
def _convert_messages_to_prompt(self, messages: List[PromptMessage], tools: Optional[list[PromptMessageTool]] = None) -> str:
"""
:param messages: List of PromptMessage to combine.
:return: Combined string with necessary human_prompt and ai_prompt tags.
"""
@ -306,5 +394,10 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
for message in messages
)
if tools and len(tools) > 0:
text += "\n\nTools:"
for tool in tools:
text += f"\n{tool.json()}"
# trim off the trailing ' ' that might come from the "Assistant: "
return text.rstrip()
return text.rstrip()

View File

@ -5,7 +5,7 @@ from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.zhipuai._client import ZhipuModelAPI
from core.model_runtime.model_providers.zhipuai.zhipuai_sdk._client import ZhipuAI
from core.model_runtime.model_providers.zhipuai._common import _CommonZhipuaiAI
from langchain.schema.language_model import _get_token_ids_default_method
@ -28,7 +28,7 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel):
:return: embeddings result
"""
credentials_kwargs = self._to_credential_kwargs(credentials)
client = ZhipuModelAPI(
client = ZhipuAI(
api_key=credentials_kwargs['api_key']
)
@ -69,7 +69,7 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel):
try:
# transform credentials to kwargs for model instance
credentials_kwargs = self._to_credential_kwargs(credentials)
client = ZhipuModelAPI(
client = ZhipuAI(
api_key=credentials_kwargs['api_key']
)
@ -82,7 +82,7 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel):
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
def embed_documents(self, model: str, client: ZhipuModelAPI, texts: List[str]) -> Tuple[List[List[float]], int]:
def embed_documents(self, model: str, client: ZhipuAI, texts: List[str]) -> Tuple[List[List[float]], int]:
"""Call out to ZhipuAI's embedding endpoint.
Args:
@ -91,17 +91,16 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel):
Returns:
List of embeddings, one for each text.
"""
embeddings = []
embedding_used_tokens = 0
for text in texts:
response = client.invoke(model=model, prompt=text)
data = response["data"]
embeddings.append(data.get('embedding'))
response = client.embeddings.create(model=model, input=text)
data = response.data[0]
embeddings.append(data.embedding)
embedding_used_tokens += response.usage.total_tokens
embedding_used_tokens = data.get('usage')
return [list(map(float, e)) for e in embeddings], embedding_used_tokens['total_tokens'] if embedding_used_tokens else 0
return [list(map(float, e)) for e in embeddings], embedding_used_tokens
def embed_query(self, text: str) -> List[float]:
"""Call out to ZhipuAI's embedding endpoint.

View File

@ -0,0 +1,17 @@
from ._client import ZhipuAI
from .core._errors import (
ZhipuAIError,
APIStatusError,
APIRequestFailedError,
APIAuthenticationError,
APIReachLimitError,
APIInternalError,
APIServerFlowExceedError,
APIResponseError,
APIResponseValidationError,
APITimeoutError,
)
from .__version__ import __version__

View File

@ -0,0 +1,2 @@
__version__ = 'v2.0.1'

View File

@ -0,0 +1,71 @@
from __future__ import annotations
from typing import Union, Mapping
from typing_extensions import override
from .core import _jwt_token
from .core._errors import ZhipuAIError
from .core._http_client import HttpClient, ZHIPUAI_DEFAULT_MAX_RETRIES
from .core._base_type import NotGiven, NOT_GIVEN
from . import api_resource
import os
import httpx
from httpx import Timeout
class ZhipuAI(HttpClient):
chat: api_resource.chat
api_key: str
def __init__(
self,
*,
api_key: str | None = None,
base_url: str | httpx.URL | None = None,
timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN,
max_retries: int = ZHIPUAI_DEFAULT_MAX_RETRIES,
http_client: httpx.Client | None = None,
custom_headers: Mapping[str, str] | None = None
) -> None:
# if api_key is None:
# api_key = os.environ.get("ZHIPUAI_API_KEY")
if api_key is None:
raise ZhipuAIError("未提供api_key请通过参数或环境变量提供")
self.api_key = api_key
if base_url is None:
base_url = os.environ.get("ZHIPUAI_BASE_URL")
if base_url is None:
base_url = f"https://open.bigmodel.cn/api/paas/v4"
from .__version__ import __version__
super().__init__(
version=__version__,
base_url=base_url,
timeout=timeout,
custom_httpx_client=http_client,
custom_headers=custom_headers,
)
self.chat = api_resource.chat.Chat(self)
self.images = api_resource.images.Images(self)
self.embeddings = api_resource.embeddings.Embeddings(self)
self.files = api_resource.files.Files(self)
self.fine_tuning = api_resource.fine_tuning.FineTuning(self)
@property
@override
def _auth_headers(self) -> dict[str, str]:
api_key = self.api_key
return {"Authorization": f"{_jwt_token.generate_token(api_key)}"}
def __del__(self) -> None:
if (not hasattr(self, "_has_custom_http_client")
or not hasattr(self, "close")
or not hasattr(self, "_client")):
# if the '__init__' method raised an error, self would not have client attr
return
if self._has_custom_http_client:
return
self.close()
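
A construction sketch for this client, assuming a valid key; note that when a custom httpx.Client is supplied, `__del__` leaves it open (`_has_custom_http_client` above), so the caller owns its lifecycle:

```python
import httpx

# Illustrative key; timeout falls back to ZHIPUAI_DEFAULT_TIMEOUT when omitted.
client = ZhipuAI(
    api_key="my-key-id.my-key-secret",
    timeout=httpx.Timeout(60.0, connect=5.0),
    http_client=httpx.Client(),  # optional: caller owns and closes this client
)
```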

View File

@ -0,0 +1,5 @@
from .chat import chat
from .images import Images
from .embeddings import Embeddings
from .files import Files
from .fine_tuning import fine_tuning

View File

@ -0,0 +1,87 @@
from __future__ import annotations
from typing import Union, List, Optional, TYPE_CHECKING
import httpx
from typing_extensions import Literal
from ...core._base_api import BaseAPI
from ...core._base_type import NotGiven, NOT_GIVEN, Headers
from ...core._http_client import make_user_request_input
from ...types.chat.async_chat_completion import AsyncTaskStatus, AsyncCompletion
if TYPE_CHECKING:
from ..._client import ZhipuAI
class AsyncCompletions(BaseAPI):
def __init__(self, client: "ZhipuAI") -> None:
super().__init__(client)
def create(
self,
*,
model: str,
request_id: Optional[str] | NotGiven = NOT_GIVEN,
do_sample: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
temperature: Optional[float] | NotGiven = NOT_GIVEN,
top_p: Optional[float] | NotGiven = NOT_GIVEN,
max_tokens: int | NotGiven = NOT_GIVEN,
seed: int | NotGiven = NOT_GIVEN,
messages: Union[str, List[str], List[int], List[List[int]], None],
stop: Optional[Union[str, List[str], None]] | NotGiven = NOT_GIVEN,
sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN,
tools: Optional[object] | NotGiven = NOT_GIVEN,
tool_choice: str | NotGiven = NOT_GIVEN,
extra_headers: Headers | None = None,
disable_strict_validation: Optional[bool] | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> AsyncTaskStatus:
_cast_type = AsyncTaskStatus
if disable_strict_validation:
_cast_type = object
return self._post(
"/async/chat/completions",
body={
"model": model,
"request_id": request_id,
"temperature": temperature,
"top_p": top_p,
"do_sample": do_sample,
"max_tokens": max_tokens,
"seed": seed,
"messages": messages,
"stop": stop,
"sensitive_word_check": sensitive_word_check,
"tools": tools,
"tool_choice": tool_choice,
},
options=make_user_request_input(
extra_headers=extra_headers, timeout=timeout
),
cast_type=_cast_type,
enable_stream=False,
)
def retrieve_completion_result(
self,
id: str,
extra_headers: Headers | None = None,
disable_strict_validation: Optional[bool] | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> Union[AsyncCompletion, AsyncTaskStatus]:
_cast_type = Union[AsyncCompletion,AsyncTaskStatus]
if disable_strict_validation:
_cast_type = object
return self._get(
path=f"/async-result/{id}",
cast_type=_cast_type,
options=make_user_request_input(
extra_headers=extra_headers,
timeout=timeout
)
)
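
A hedged sketch of the async flow: `create()` returns an AsyncTaskStatus, and `retrieve_completion_result()` is polled until an AsyncCompletion comes back; the model name and polling interval are illustrative:

```python
import time

from core.model_runtime.model_providers.zhipuai.zhipuai_sdk._client import ZhipuAI
from core.model_runtime.model_providers.zhipuai.zhipuai_sdk.types.chat.async_chat_completion import AsyncCompletion

client = ZhipuAI(api_key="my-key-id.my-key-secret")  # fake key
task = client.chat.asyncCompletions.create(
    model="glm-4",  # illustrative model name
    messages=[{"role": "user", "content": "Summarize SSE in one sentence."}],
)

while True:
    result = client.chat.asyncCompletions.retrieve_completion_result(id=task.id)
    if isinstance(result, AsyncCompletion):  # finished: choices are populated
        print(result.choices[0].message.content)
        break
    time.sleep(2)  # still an AsyncTaskStatus; poll again
```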

View File

@ -0,0 +1,16 @@
from typing import TYPE_CHECKING
from .completions import Completions
from .async_completions import AsyncCompletions
from ...core._base_api import BaseAPI
if TYPE_CHECKING:
from ..._client import ZhipuAI
class Chat(BaseAPI):
completions: Completions
def __init__(self, client: "ZhipuAI") -> None:
super().__init__(client)
self.completions = Completions(client)
self.asyncCompletions = AsyncCompletions(client)

View File

@ -0,0 +1,71 @@
from __future__ import annotations
from typing import Union, List, Optional, TYPE_CHECKING
import httpx
from typing_extensions import Literal
from ...core._base_api import BaseAPI
from ...core._base_type import NotGiven, NOT_GIVEN, Headers
from ...core._http_client import make_user_request_input
from ...core._sse_client import StreamResponse
from ...types.chat.chat_completion import Completion
from ...types.chat.chat_completion_chunk import ChatCompletionChunk
if TYPE_CHECKING:
from ..._client import ZhipuAI
class Completions(BaseAPI):
def __init__(self, client: "ZhipuAI") -> None:
super().__init__(client)
def create(
self,
*,
model: str,
request_id: Optional[str] | NotGiven = NOT_GIVEN,
do_sample: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
temperature: Optional[float] | NotGiven = NOT_GIVEN,
top_p: Optional[float] | NotGiven = NOT_GIVEN,
max_tokens: int | NotGiven = NOT_GIVEN,
seed: int | NotGiven = NOT_GIVEN,
messages: Union[str, List[str], List[int], object, None],
stop: Optional[Union[str, List[str], None]] | NotGiven = NOT_GIVEN,
sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN,
tools: Optional[object] | NotGiven = NOT_GIVEN,
tool_choice: str | NotGiven = NOT_GIVEN,
extra_headers: Headers | None = None,
disable_strict_validation: Optional[bool] | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> Completion | StreamResponse[ChatCompletionChunk]:
_cast_type = Completion
_stream_cls = StreamResponse[ChatCompletionChunk]
if disable_strict_validation:
_cast_type = object
_stream_cls = StreamResponse[object]
return self._post(
"/chat/completions",
body={
"model": model,
"request_id": request_id,
"temperature": temperature,
"top_p": top_p,
"do_sample": do_sample,
"max_tokens": max_tokens,
"seed": seed,
"messages": messages,
"stop": stop,
"sensitive_word_check": sensitive_word_check,
"stream": stream,
"tools": tools,
"tool_choice": tool_choice,
},
options=make_user_request_input(
extra_headers=extra_headers,
),
cast_type=_cast_type,
enable_stream=stream or False,
stream_cls=_stream_cls,
)
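
With `stream=True` the call returns a StreamResponse[ChatCompletionChunk] that is iterated directly, which is exactly how llm.py above consumes it; a minimal sketch (model name illustrative):

```python
# `client` is a ZhipuAI instance, constructed as in the first sketch above
for chunk in client.chat.completions.create(
    model="glm-4",
    messages=[{"role": "user", "content": "Count to five."}],
    stream=True,
):
    # chunks may arrive with empty choices; guard before indexing
    if chunk.choices and chunk.choices[0].delta.content:
        print(chunk.choices[0].delta.content, end="", flush=True)
```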

View File

@ -0,0 +1,49 @@
from __future__ import annotations
from typing import Union, List, Optional, TYPE_CHECKING
import httpx
from ..core._base_api import BaseAPI
from ..core._base_type import NotGiven, NOT_GIVEN, Headers
from ..core._http_client import make_user_request_input
from ..types.embeddings import EmbeddingsResponded
if TYPE_CHECKING:
from .._client import ZhipuAI
class Embeddings(BaseAPI):
def __init__(self, client: "ZhipuAI") -> None:
super().__init__(client)
def create(
self,
*,
input: Union[str, List[str], List[int], List[List[int]]],
model: Union[str],
encoding_format: str | NotGiven = NOT_GIVEN,
user: str | NotGiven = NOT_GIVEN,
sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN,
extra_headers: Headers | None = None,
disable_strict_validation: Optional[bool] | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> EmbeddingsResponded:
_cast_type = EmbeddingsResponded
if disable_strict_validation:
_cast_type = object
return self._post(
"/embeddings",
body={
"input": input,
"model": model,
"encoding_format": encoding_format,
"user": user,
"sensitive_word_check": sensitive_word_check,
},
options=make_user_request_input(
extra_headers=extra_headers, timeout=timeout
),
cast_type=_cast_type,
enable_stream=False,
)
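
A usage sketch mirroring `embed_documents()` in the model runtime above; "embedding-2" is an assumed model identifier:

```python
# `client` is a ZhipuAI instance, constructed as in the first sketch above
response = client.embeddings.create(model="embedding-2", input="hello world")
vector = response.data[0].embedding  # List[float]
print(len(vector), response.usage.total_tokens)
```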

View File

@ -0,0 +1,78 @@
from __future__ import annotations
from typing import TYPE_CHECKING
import httpx
from ..core._base_api import BaseAPI
from ..core._base_type import NOT_GIVEN, Body, Query, Headers, NotGiven, FileTypes
from ..core._files import is_file_content
from ..core._http_client import (
make_user_request_input,
)
from ..types.file_object import FileObject, ListOfFileObject
if TYPE_CHECKING:
from .._client import ZhipuAI
__all__ = ["Files"]
class Files(BaseAPI):
def __init__(self, client: "ZhipuAI") -> None:
super().__init__(client)
def create(
self,
*,
file: FileTypes,
purpose: str,
extra_headers: Headers | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> FileObject:
if not is_file_content(file):
prefix = f"Expected file input `{file!r}`"
raise RuntimeError(
f"{prefix} to be bytes, an io.IOBase instance, PathLike or a tuple but received {type(file)} instead."
) from None
files = [("file", file)]
extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})}
return self._post(
"/files",
body={
"purpose": purpose,
},
files=files,
options=make_user_request_input(
extra_headers=extra_headers, timeout=timeout
),
cast_type=FileObject,
)
def list(
self,
*,
purpose: str | NotGiven = NOT_GIVEN,
limit: int | NotGiven = NOT_GIVEN,
after: str | NotGiven = NOT_GIVEN,
order: str | NotGiven = NOT_GIVEN,
extra_headers: Headers | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> ListOfFileObject:
return self._get(
"/files",
cast_type=ListOfFileObject,
options=make_user_request_input(
extra_headers=extra_headers,
timeout=timeout,
query={
"purpose": purpose,
"limit": limit,
"after": after,
"order": order,
},
),
)
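
An upload-and-list sketch; the file path and the "fine-tune" purpose string are assumptions (PathLike input is converted to bytes by _files.py below):

```python
from pathlib import Path

# `client` is a ZhipuAI instance, constructed as in the first sketch above
file_object = client.files.create(file=Path("train.jsonl"), purpose="fine-tune")
listing = client.files.list(limit=10)
for f in listing.data:
    print(f.id, f.filename, f.status)
```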

View File

@ -0,0 +1,15 @@
from typing import TYPE_CHECKING
from .jobs import Jobs
from ...core._base_api import BaseAPI
if TYPE_CHECKING:
from ..._client import ZhipuAI
class FineTuning(BaseAPI):
jobs: Jobs
def __init__(self, client: "ZhipuAI") -> None:
super().__init__(client)
self.jobs = Jobs(client)

View File

@ -0,0 +1,115 @@
from __future__ import annotations
from typing import Optional, TYPE_CHECKING
import httpx
from ...core._base_api import BaseAPI
from ...core._base_type import NOT_GIVEN, Headers, NotGiven
from ...core._http_client import (
make_user_request_input,
)
from ...types.fine_tuning import (
FineTuningJob,
job_create_params,
ListOfFineTuningJob,
FineTuningJobEvent,
)
if TYPE_CHECKING:
from ..._client import ZhipuAI
__all__ = ["Jobs"]
class Jobs(BaseAPI):
def __init__(self, client: "ZhipuAI") -> None:
super().__init__(client)
def create(
self,
*,
model: str,
training_file: str,
hyperparameters: job_create_params.Hyperparameters | NotGiven = NOT_GIVEN,
suffix: Optional[str] | NotGiven = NOT_GIVEN,
request_id: Optional[str] | NotGiven = NOT_GIVEN,
validation_file: Optional[str] | NotGiven = NOT_GIVEN,
extra_headers: Headers | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> FineTuningJob:
return self._post(
"/fine_tuning/jobs",
body={
"model": model,
"training_file": training_file,
"hyperparameters": hyperparameters,
"suffix": suffix,
"validation_file": validation_file,
"request_id": request_id,
},
options=make_user_request_input(
extra_headers=extra_headers, timeout=timeout
),
cast_type=FineTuningJob,
)
def retrieve(
self,
fine_tuning_job_id: str,
*,
extra_headers: Headers | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> FineTuningJob:
return self._get(
f"/fine_tuning/jobs/{fine_tuning_job_id}",
options=make_user_request_input(
extra_headers=extra_headers, timeout=timeout
),
cast_type=FineTuningJob,
)
def list(
self,
*,
after: str | NotGiven = NOT_GIVEN,
limit: int | NotGiven = NOT_GIVEN,
extra_headers: Headers | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> ListOfFineTuningJob:
return self._get(
"/fine_tuning/jobs",
cast_type=ListOfFineTuningJob,
options=make_user_request_input(
extra_headers=extra_headers,
timeout=timeout,
query={
"after": after,
"limit": limit,
},
),
)
def list_events(
self,
fine_tuning_job_id: str,
*,
after: str | NotGiven = NOT_GIVEN,
limit: int | NotGiven = NOT_GIVEN,
extra_headers: Headers | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> FineTuningJobEvent:
return self._get(
f"/fine_tuning/jobs/{fine_tuning_job_id}/events",
cast_type=FineTuningJobEvent,
options=make_user_request_input(
extra_headers=extra_headers,
timeout=timeout,
query={
"after": after,
"limit": limit,
},
),
)
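
A job-lifecycle sketch tying these four methods together; the base model name and file id are placeholders:

```python
# `client` is a ZhipuAI instance, constructed as in the first sketch above
job = client.fine_tuning.jobs.create(
    model="chatglm3-6b",          # placeholder base model
    training_file="file-abc123",  # placeholder id from client.files.create()
)
job = client.fine_tuning.jobs.retrieve(job.id)
events = client.fine_tuning.jobs.list_events(job.id, limit=20)
for event in events.data:
    print(event.created_at, event.level, event.message)
```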

View File

@ -0,0 +1,55 @@
from __future__ import annotations
from typing import Union, List, Optional, TYPE_CHECKING
import httpx
from ..core._base_api import BaseAPI
from ..core._base_type import NotGiven, NOT_GIVEN, Headers
from ..core._http_client import make_user_request_input
from ..types.image import ImagesResponded
if TYPE_CHECKING:
from .._client import ZhipuAI
class Images(BaseAPI):
def __init__(self, client: "ZhipuAI") -> None:
super().__init__(client)
def generations(
self,
*,
prompt: str,
model: str | NotGiven = NOT_GIVEN,
n: Optional[int] | NotGiven = NOT_GIVEN,
quality: Optional[str] | NotGiven = NOT_GIVEN,
response_format: Optional[str] | NotGiven = NOT_GIVEN,
size: Optional[str] | NotGiven = NOT_GIVEN,
style: Optional[str] | NotGiven = NOT_GIVEN,
user: str | NotGiven = NOT_GIVEN,
extra_headers: Headers | None = None,
disable_strict_validation: Optional[bool] | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> ImagesResponded:
_cast_type = ImagesResponded
if disable_strict_validation:
_cast_type = object
return self._post(
"/images/generations",
body={
"prompt": prompt,
"model": model,
"n": n,
"quality": quality,
"response_format": response_format,
"size": size,
"style": style,
"user": user,
},
options=make_user_request_input(
extra_headers=extra_headers, timeout=timeout
),
cast_type=_cast_type,
enable_stream=False,
)
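
A generation sketch; "cogview-3" is an assumed model identifier, and the response schema is the ImagesResponded model defined in types/image.py (not shown in this diff):

```python
# `client` is a ZhipuAI instance, constructed as in the first sketch above
result = client.images.generations(
    model="cogview-3",  # assumed model identifier
    prompt="a watercolor painting of a red bicycle",
)
```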

View File

@ -0,0 +1,17 @@
from __future__ import annotations
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from .._client import ZhipuAI
class BaseAPI:
_client: ZhipuAI
def __init__(self, client: ZhipuAI) -> None:
self._client = client
self._delete = client.delete
self._get = client.get
self._post = client.post
self._put = client.put
self._patch = client.patch

View File

@ -0,0 +1,115 @@
from __future__ import annotations
from os import PathLike
from typing import (
TYPE_CHECKING,
Type,
Union,
Mapping,
TypeVar, IO, Tuple, Sequence, Any, List,
)
import pydantic
from typing_extensions import (
Literal,
override,
)
Query = Mapping[str, object]
Body = object
AnyMapping = Mapping[str, object]
PrimitiveData = Union[str, int, float, bool, None]
Data = Union[PrimitiveData, List[Any], Tuple[Any], "Mapping[str, Any]"]
ModelT = TypeVar("ModelT", bound=pydantic.BaseModel)
_T = TypeVar("_T")
if TYPE_CHECKING:
NoneType: Type[None]
else:
NoneType = type(None)
# Sentinel class used until PEP 0661 is accepted
class NotGiven(pydantic.BaseModel):
"""
A sentinel singleton class used to distinguish omitted keyword arguments
from those passed in with the value None (which may have different behavior).
For example:
```py
def get(timeout: Union[int, NotGiven, None] = NotGiven()) -> Response: ...
get(timeout=1) # 1s timeout
get(timeout=None) # No timeout
get() # Default timeout behavior, which may not be statically known at the method definition.
```
"""
def __bool__(self) -> Literal[False]:
return False
@override
def __repr__(self) -> str:
return "NOT_GIVEN"
NotGivenOr = Union[_T, NotGiven]
NOT_GIVEN = NotGiven()
class Omit(pydantic.BaseModel):
"""In certain situations you need to be able to represent a case where a default value has
to be explicitly removed and `None` is not an appropriate substitute, for example:
```py
# as the default `Content-Type` header is `application/json` that will be sent
client.post('/upload/files', files={'file': b'my raw file content'})
# you can't explicitly override the header as it has to be dynamically generated
# to look something like: 'multipart/form-data; boundary=0d8382fcf5f8c3be01ca2e11002d2983'
client.post(..., headers={'Content-Type': 'multipart/form-data'})
# instead you can remove the default `application/json` header by passing Omit
client.post(..., headers={'Content-Type': Omit()})
```
"""
def __bool__(self) -> Literal[False]:
return False
Headers = Mapping[str, Union[str, Omit]]
ResponseT = TypeVar(
"ResponseT",
bound="Union[str, None, BaseModel, List[Any], Dict[str, Any], Response, UnknownResponse, ModelBuilderProtocol, BinaryResponseContent]",
)
# for user input files
if TYPE_CHECKING:
FileContent = Union[IO[bytes], bytes, PathLike[str]]
else:
FileContent = Union[IO[bytes], bytes, PathLike]
FileTypes = Union[
FileContent, # file content
Tuple[str, FileContent], # (filename, file)
Tuple[str, FileContent, str], # (filename, file , content_type)
Tuple[str, FileContent, str, Mapping[str, str]], # (filename, file , content_type, headers)
]
RequestFiles = Union[Mapping[str, FileTypes], Sequence[Tuple[str, FileTypes]]]
# for httpx client supported files
HttpxFileContent = Union[bytes, IO[bytes]]
HttpxFileTypes = Union[
FileContent, # file content
Tuple[str, HttpxFileContent], # (filename, file)
Tuple[str, HttpxFileContent, str], # (filename, file , content_type)
Tuple[str, HttpxFileContent, str, Mapping[str, str]], # (filename, file , content_type, headers)
]
HttpxRequestFiles = Union[Mapping[str, HttpxFileTypes], Sequence[Tuple[str, HttpxFileTypes]]]

View File

@ -0,0 +1,90 @@
from __future__ import annotations
import httpx
__all__ = [
"ZhipuAIError",
"APIStatusError",
"APIRequestFailedError",
"APIAuthenticationError",
"APIReachLimitError",
"APIInternalError",
"APIServerFlowExceedError",
"APIResponseError",
"APIResponseValidationError",
"APITimeoutError",
]
class ZhipuAIError(Exception):
def __init__(self, message: str, ) -> None:
super().__init__(message)
class APIStatusError(Exception):
response: httpx.Response
status_code: int
def __init__(self, message: str, *, response: httpx.Response) -> None:
super().__init__(message)
self.response = response
self.status_code = response.status_code
class APIRequestFailedError(APIStatusError):
...
class APIAuthenticationError(APIStatusError):
...
class APIReachLimitError(APIStatusError):
...
class APIInternalError(APIStatusError):
...
class APIServerFlowExceedError(APIStatusError):
...
class APIResponseError(Exception):
message: str
request: httpx.Request
json_data: object
def __init__(self, message: str, request: httpx.Request, json_data: object):
self.message = message
self.request = request
self.json_data = json_data
super().__init__(message)
class APIResponseValidationError(APIResponseError):
status_code: int
response: httpx.Response
def __init__(
self,
response: httpx.Response,
json_data: object | None, *,
message: str | None = None
) -> None:
super().__init__(
message=message or "Data returned by API invalid for expected schema.",
request=response.request,
json_data=json_data
)
self.response = response
self.status_code = response.status_code
class APITimeoutError(Exception):
request: httpx.Request
def __init__(self, request: httpx.Request):
self.request = request
super().__init__("Request Timeout")

View File

@ -0,0 +1,46 @@
from __future__ import annotations
import io
import os
from pathlib import Path
from typing import Mapping, Sequence
from ._base_type import (
FileTypes,
HttpxFileTypes,
HttpxRequestFiles,
RequestFiles,
)
def is_file_content(obj: object) -> bool:
return isinstance(obj, (bytes, tuple, io.IOBase, os.PathLike))
def _transform_file(file: FileTypes) -> HttpxFileTypes:
if is_file_content(file):
if isinstance(file, os.PathLike):
path = Path(file)
return path.name, path.read_bytes()
else:
return file
if isinstance(file, tuple):
if isinstance(file[1], os.PathLike):
return (file[0], Path(file[1]).read_bytes(), *file[2:])
else:
return (file[0], file[1], *file[2:])
else:
raise TypeError(f"Unexpected input file with type {type(file)},Expected FileContent type or tuple type")
def make_httpx_files(files: RequestFiles | None) -> HttpxRequestFiles | None:
if files is None:
return None
if isinstance(files, Mapping):
files = {key: _transform_file(file) for key, file in files.items()}
elif isinstance(files, Sequence):
files = [(key, _transform_file(file)) for key, file in files]
else:
raise TypeError(f"Unexpected input file with type {type(files)}, excepted Mapping or Sequence")
return files

View File

@ -0,0 +1,377 @@
# -*- coding:utf-8 -*-
from __future__ import annotations
import inspect
from typing import (
Any,
Type,
Union,
cast,
Mapping,
)
import httpx
import pydantic
from httpx import URL, Timeout
from . import _errors
from ._base_type import NotGiven, ResponseT, Body, Headers, NOT_GIVEN, RequestFiles, Query, Data
from ._errors import APIResponseValidationError, APIStatusError, APITimeoutError
from ._files import make_httpx_files
from ._request_opt import ClientRequestParam, UserRequestInput
from ._response import HttpResponse
from ._sse_client import StreamResponse
from ._utils import flatten
headers = {
"Accept": "application/json",
"Content-Type": "application/json; charset=UTF-8",
}
def _merge_map(map1: Mapping, map2: Mapping) -> Mapping:
merged = {**map1, **map2}
return {key: val for key, val in merged.items() if val is not None}
from httpx._config import DEFAULT_TIMEOUT_CONFIG as HTTPX_DEFAULT_TIMEOUT
ZHIPUAI_DEFAULT_TIMEOUT = httpx.Timeout(timeout=300.0, connect=8.0)
ZHIPUAI_DEFAULT_MAX_RETRIES = 3
ZHIPUAI_DEFAULT_LIMITS = httpx.Limits(max_connections=50, max_keepalive_connections=10)
class HttpClient:
_client: httpx.Client
_version: str
_base_url: URL
timeout: Union[float, Timeout, None]
_limits: httpx.Limits
_has_custom_http_client: bool
_default_stream_cls: type[StreamResponse[Any]] | None = None
def __init__(
self,
*,
version: str,
base_url: URL,
timeout: Union[float, Timeout, None],
custom_httpx_client: httpx.Client | None = None,
custom_headers: Mapping[str, str] | None = None,
) -> None:
if timeout is None or isinstance(timeout, NotGiven):
if custom_httpx_client and custom_httpx_client.timeout != HTTPX_DEFAULT_TIMEOUT:
timeout = custom_httpx_client.timeout
else:
timeout = ZHIPUAI_DEFAULT_TIMEOUT
self.timeout = cast(Timeout, timeout)
self._has_custom_http_client = bool(custom_httpx_client)
self._client = custom_httpx_client or httpx.Client(
base_url=base_url,
timeout=self.timeout,
limits=ZHIPUAI_DEFAULT_LIMITS,
)
self._version = version
url = URL(url=base_url)
if not url.raw_path.endswith(b"/"):
url = url.copy_with(raw_path=url.raw_path + b"/")
self._base_url = url
self._custom_headers = custom_headers or {}
def _prepare_url(self, url: str) -> URL:
sub_url = URL(url)
if sub_url.is_relative_url:
request_raw_url = self._base_url.raw_path + sub_url.raw_path.lstrip(b"/")
return self._base_url.copy_with(raw_path=request_raw_url)
return sub_url
@property
def _default_headers(self):
return \
{
"Accept": "application/json",
"Content-Type": "application/json; charset=UTF-8",
"ZhipuAI-SDK-Ver": self._version,
"source_type": "zhipu-sdk-python",
"x-request-sdk": "zhipu-sdk-python",
**self._auth_headers,
**self._custom_headers,
}
@property
def _auth_headers(self):
return {}
def _prepare_headers(self, request_param: ClientRequestParam) -> httpx.Headers:
custom_headers = request_param.headers or {}
headers_dict = _merge_map(self._default_headers, custom_headers)
httpx_headers = httpx.Headers(headers_dict)
return httpx_headers
def _prepare_request(
self,
request_param: ClientRequestParam
) -> httpx.Request:
kwargs: dict[str, Any] = {}
json_data = request_param.json_data
headers = self._prepare_headers(request_param)
url = self._prepare_url(request_param.url)
if headers.get("Content-Type") == "multipart/form-data":
headers.pop("Content-Type")
if json_data:
kwargs["data"] = self._make_multipartform(json_data)
return self._client.build_request(
headers=headers,
timeout=self.timeout if isinstance(request_param.timeout, NotGiven) else request_param.timeout,
method=request_param.method,
url=url,
json=json_data,
files=request_param.files,
params=request_param.params,
**kwargs,
)
def _object_to_formdata(self, key: str, value: Data | Mapping[object, object]) -> list[tuple[str, str]]:
items = []
if isinstance(value, Mapping):
for k, v in value.items():
items.extend(self._object_to_formfata(f"{key}[{k}]", v))
return items
if isinstance(value, (list, tuple)):
for v in value:
items.extend(self._object_to_formdata(key + "[]", v))
return items
def _primitive_value_to_str(val) -> str:
# copied from httpx
if val is True:
return "true"
elif val is False:
return "false"
elif val is None:
return ""
return str(val)
str_data = _primitive_value_to_str(value)
if not str_data:
return []
return [(key, str_data)]
def _make_multipartform(self, data: Mapping[object, object]) -> dict[str, object]:
items = flatten([self._object_to_formdata(k, v) for k, v in data.items()])
serialized: dict[str, object] = {}
for key, value in items:
if key in serialized:
raise ValueError(f"存在重复的键: {key};")
serialized[key] = value
return serialized
def _parse_response(
self,
*,
cast_type: Type[ResponseT],
response: httpx.Response,
enable_stream: bool,
request_param: ClientRequestParam,
stream_cls: type[StreamResponse[Any]] | None = None,
) -> HttpResponse:
http_response = HttpResponse(
raw_response=response,
cast_type=cast_type,
client=self,
enable_stream=enable_stream,
stream_cls=stream_cls
)
return http_response.parse()
def _process_response_data(
self,
*,
data: object,
cast_type: type[ResponseT],
response: httpx.Response,
) -> ResponseT:
if data is None:
return cast(ResponseT, None)
try:
if inspect.isclass(cast_type) and issubclass(cast_type, pydantic.BaseModel):
return cast(ResponseT, cast_type.validate(data))
return cast(ResponseT, pydantic.TypeAdapter(cast_type).validate_python(data))
except pydantic.ValidationError as err:
raise APIResponseValidationError(response=response, json_data=data) from err
def is_closed(self) -> bool:
return self._client.is_closed
def close(self):
self._client.close()
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
def request(
self,
*,
cast_type: Type[ResponseT],
params: ClientRequestParam,
enable_stream: bool = False,
stream_cls: type[StreamResponse[Any]] | None = None,
) -> ResponseT | StreamResponse:
request = self._prepare_request(params)
try:
response = self._client.send(
request,
stream=enable_stream,
)
response.raise_for_status()
except httpx.TimeoutException as err:
raise APITimeoutError(request=request) from err
except httpx.HTTPStatusError as err:
err.response.read()
# raise err
raise self._make_status_error(err.response) from None
except Exception as err:
raise err
return self._parse_response(
cast_type=cast_type,
request_param=params,
response=response,
enable_stream=enable_stream,
stream_cls=stream_cls,
)
def get(
self,
path: str,
*,
cast_type: Type[ResponseT],
options: UserRequestInput = {},
enable_stream: bool = False,
) -> ResponseT | StreamResponse:
opts = ClientRequestParam.construct(method="get", url=path, **options)
return self.request(
cast_type=cast_type, params=opts,
enable_stream=enable_stream
)
def post(
self,
path: str,
*,
body: Body | None = None,
cast_type: Type[ResponseT],
options: UserRequestInput = {},
files: RequestFiles | None = None,
enable_stream: bool = False,
stream_cls: type[StreamResponse[Any]] | None = None,
) -> ResponseT | StreamResponse:
opts = ClientRequestParam.construct(method="post", json_data=body, files=make_httpx_files(files), url=path,
**options)
return self.request(
cast_type=cast_type, params=opts,
enable_stream=enable_stream,
stream_cls=stream_cls
)
def patch(
self,
path: str,
*,
body: Body | None = None,
cast_type: Type[ResponseT],
options: UserRequestInput = {},
) -> ResponseT:
opts = ClientRequestParam.construct(method="patch", url=path, json_data=body, **options)
return self.request(
cast_type=cast_type, params=opts,
)
def put(
self,
path: str,
*,
body: Body | None = None,
cast_type: Type[ResponseT],
options: UserRequestInput = {},
files: RequestFiles | None = None,
) -> ResponseT | StreamResponse:
opts = ClientRequestParam.construct(method="put", url=path, json_data=body, files=make_httpx_files(files),
**options)
return self.request(
cast_type=cast_type, params=opts,
)
def delete(
self,
path: str,
*,
body: Body | None = None,
cast_type: Type[ResponseT],
options: UserRequestInput = {},
) -> ResponseT | StreamResponse:
opts = ClientRequestParam.construct(method="delete", url=path, json_data=body, **options)
return self.request(
cast_type=cast_type, params=opts,
)
def _make_status_error(self, response) -> APIStatusError:
response_text = response.text.strip()
status_code = response.status_code
error_msg = f"Error code: {status_code}, with error text {response_text}"
if status_code == 400:
return _errors.APIRequestFailedError(message=error_msg, response=response)
elif status_code == 401:
return _errors.APIAuthenticationError(message=error_msg, response=response)
elif status_code == 429:
return _errors.APIReachLimitError(message=error_msg, response=response)
elif status_code == 500:
return _errors.APIInternalError(message=error_msg, response=response)
elif status_code == 503:
return _errors.APIServerFlowExceedError(message=error_msg, response=response)
return APIStatusError(message=error_msg, response=response)
def make_user_request_input(
max_retries: int | None = None,
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
extra_headers: Headers = None,
query: Query | None = None,
) -> UserRequestInput:
options: UserRequestInput = {}
if extra_headers is not None:
options["headers"] = extra_headers
if max_retries is not None:
options["max_retries"] = max_retries
if not isinstance(timeout, NotGiven):
options['timeout'] = timeout
if query is not None:
options["params"] = query
return options
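
This helper is how every api_resource method packs per-request overrides; a small sketch of what it produces:

```python
options = make_user_request_input(
    extra_headers={"X-Request-Tag": "demo"},  # hypothetical header
    timeout=10.0,
    query={"limit": 5},
)
# options is a plain UserRequestInput dict ({"headers": ..., "timeout": ...,
# "params": ...}), later expanded via ClientRequestParam.construct(**options)
# inside get()/post() above
```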

View File

@ -0,0 +1,30 @@
# -*- coding:utf-8 -*-
import time
import cachetools.func
import jwt
API_TOKEN_TTL_SECONDS = 3 * 60
CACHE_TTL_SECONDS = API_TOKEN_TTL_SECONDS - 30
@cachetools.func.ttl_cache(maxsize=10, ttl=CACHE_TTL_SECONDS)
def generate_token(apikey: str):
try:
api_key, secret = apikey.split(".")
except Exception as e:
raise Exception("invalid api_key", e)
payload = {
"api_key": api_key,
"exp": int(round(time.time() * 1000)) + API_TOKEN_TTL_SECONDS * 1000,
"timestamp": int(round(time.time() * 1000)),
}
ret = jwt.encode(
payload,
secret,
algorithm="HS256",
headers={"alg": "HS256", "sign_type": "SIGN"},
)
return ret
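
A token sketch: keys are split on "." into an id and a signing secret, and the resulting JWT is cached for CACHE_TTL_SECONDS; the key below is fake:

```python
token = generate_token("my-key-id.my-key-secret")  # fake key for illustration
headers = {"Authorization": token}  # matches ZhipuAI._auth_headers in _client.py
```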

View File

@ -0,0 +1,54 @@
from __future__ import annotations
from typing import Union, Any, cast
import pydantic.generics
from httpx import Timeout
from pydantic import ConfigDict
from typing_extensions import (
Unpack, ClassVar, TypedDict
)
from ._base_type import Body, NotGiven, Headers, HttpxRequestFiles, Query
from ._utils import remove_notgiven_indict
class UserRequestInput(TypedDict, total=False):
max_retries: int
timeout: float | Timeout | None
headers: Headers
params: Query | None
class ClientRequestParam():
method: str
url: str
max_retries: Union[int, NotGiven] = NotGiven()
timeout: Union[float, NotGiven] = NotGiven()
headers: Union[Headers, NotGiven] = NotGiven()
json_data: Union[Body, None] = None
files: Union[HttpxRequestFiles, None] = None
params: Query = {}
model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True)
def get_max_retries(self, max_retries) -> int:
if isinstance(self.max_retries, NotGiven):
return max_retries
return self.max_retries
@classmethod
def construct( # type: ignore
cls,
_fields_set: set[str] | None = None,
**values: Unpack[UserRequestInput],
) -> ClientRequestParam :
kwargs: dict[str, Any] = {
key: remove_notgiven_indict(value) for key, value in values.items()
}
client = cls()
client.__dict__.update(kwargs)
return client
model_construct = construct

View File

@ -0,0 +1,121 @@
from __future__ import annotations
import datetime
from typing import TypeVar, Generic, cast, Any, TYPE_CHECKING
import httpx
import pydantic
from typing_extensions import ParamSpec, get_origin, get_args
from ._base_type import NoneType
from ._sse_client import StreamResponse
if TYPE_CHECKING:
from ._http_client import HttpClient
P = ParamSpec("P")
R = TypeVar("R")
class HttpResponse(Generic[R]):
_cast_type: type[R]
_client: "HttpClient"
_parsed: R | None
_enable_stream: bool
_stream_cls: type[StreamResponse[Any]]
http_response: httpx.Response
def __init__(
self,
*,
raw_response: httpx.Response,
cast_type: type[R],
client: "HttpClient",
enable_stream: bool = False,
stream_cls: type[StreamResponse[Any]] | None = None,
) -> None:
self._cast_type = cast_type
self._client = client
self._parsed = None
self._stream_cls = stream_cls
self._enable_stream = enable_stream
self.http_response = raw_response
def parse(self) -> R:
self._parsed = self._parse()
return self._parsed
def _parse(self) -> R:
if self._enable_stream:
self._parsed = cast(
R,
self._stream_cls(
cast_type=cast(type, get_args(self._stream_cls)[0]),
response=self.http_response,
client=self._client
)
)
return self._parsed
cast_type = self._cast_type
if cast_type is NoneType:
return cast(R, None)
http_response = self.http_response
if cast_type == str:
return cast(R, http_response.text)
content_type, *_ = http_response.headers.get("content-type", "application/json").split(";")
origin = get_origin(cast_type) or cast_type
if content_type != "application/json":
if issubclass(origin, pydantic.BaseModel):
data = http_response.json()
return self._client._process_response_data(
data=data,
cast_type=cast_type, # type: ignore
response=http_response,
)
return http_response.text
data = http_response.json()
return self._client._process_response_data(
data=data,
cast_type=cast_type, # type: ignore
response=http_response,
)
@property
def headers(self) -> httpx.Headers:
return self.http_response.headers
@property
def http_request(self) -> httpx.Request:
return self.http_response.request
@property
def status_code(self) -> int:
return self.http_response.status_code
@property
def url(self) -> httpx.URL:
return self.http_response.url
@property
def method(self) -> str:
return self.http_request.method
@property
def content(self) -> bytes:
return self.http_response.content
@property
def text(self) -> str:
return self.http_response.text
@property
def http_version(self) -> str:
return self.http_response.http_version
@property
def elapsed(self) -> datetime.timedelta:
return self.http_response.elapsed

View File

@ -0,0 +1,149 @@
# -*- coding:utf-8 -*-
from __future__ import annotations
import json
from typing import Generic, Iterator, TYPE_CHECKING, Mapping
import httpx
from ._base_type import ResponseT
from ._errors import APIResponseError
_FIELD_SEPARATOR = ":"
if TYPE_CHECKING:
from ._http_client import HttpClient
class StreamResponse(Generic[ResponseT]):
response: httpx.Response
_cast_type: type[ResponseT]
def __init__(
self,
*,
cast_type: type[ResponseT],
response: httpx.Response,
client: HttpClient,
) -> None:
self.response = response
self._cast_type = cast_type
self._data_process_func = client._process_response_data
self._stream_chunks = self.__stream__()
def __next__(self) -> ResponseT:
return self._stream_chunks.__next__()
def __iter__(self) -> Iterator[ResponseT]:
for item in self._stream_chunks:
yield item
def __stream__(self) -> Iterator[ResponseT]:
sse_line_parser = SSELineParser()
iterator = sse_line_parser.iter_lines(self.response.iter_lines())
for sse in iterator:
if sse.data.startswith("[DONE]"):
break
if sse.event is None:
data = sse.json_data()
if isinstance(data, Mapping) and data.get("error"):
raise APIResponseError(
message="An error occurred during streaming",
request=self.response.request,
json_data=data["error"],
)
yield self._data_process_func(data=data, cast_type=self._cast_type, response=self.response)
for sse in iterator:
pass
class Event(object):
def __init__(
self,
event: str | None = None,
data: str | None = None,
id: str | None = None,
retry: int | None = None
):
self._event = event
self._data = data
self._id = id
self._retry = retry
def __repr__(self):
data_len = len(self._data) if self._data else 0
return f"Event(event={self._event}, data={self._data} ,data_length={data_len}, id={self._id}, retry={self._retry}"
@property
def event(self): return self._event
@property
def data(self): return self._data
def json_data(self): return json.loads(self._data)
@property
def id(self): return self._id
@property
def retry(self): return self._retry
class SSELineParser:
_data: list[str]
_event: str | None
_retry: int | None
_id: str | None
def __init__(self):
self._event = None
self._data = []
self._id = None
self._retry = None
def iter_lines(self, lines: Iterator[str]) -> Iterator[Event]:
for line in lines:
line = line.rstrip('\n')
if not line:
if self._event is None and \
not self._data and \
self._id is None and \
self._retry is None:
continue
sse_event = Event(
event=self._event,
data='\n'.join(self._data),
id=self._id,
retry=self._retry
)
self._event = None
self._data = []
self._id = None
self._retry = None
yield sse_event
self.decode_line(line)
def decode_line(self, line: str):
if line.startswith(":") or not line:
return
field, _p, value = line.partition(":")
if value.startswith(' '):
value = value[1:]
if field == "data":
self._data.append(value)
elif field == "event":
self._event = value
elif field == "retry":
try:
self._retry = int(value)
except (TypeError, ValueError):
pass
return
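
A parser sketch: events are terminated by a blank line, and each yielded Event exposes json_data() for "data:" payloads:

```python
parser = SSELineParser()
raw = iter(['event: message', 'data: {"x": 1}', ''])
for event in parser.iter_lines(raw):
    print(event.event, event.json_data())  # -> message {'x': 1}
```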

View File

@ -0,0 +1,18 @@
from __future__ import annotations
from typing import Mapping, Iterable, TypeVar
from ._base_type import NotGiven
def remove_notgiven_indict(obj):
if obj is None or (not isinstance(obj, Mapping)):
return obj
return {key: value for key, value in obj.items() if not isinstance(value, NotGiven)}
_T = TypeVar("_T")
def flatten(t: Iterable[Iterable[_T]]) -> list[_T]:
return [item for sublist in t for item in sublist]

View File

@ -0,0 +1,23 @@
from typing import List, Optional
from pydantic import BaseModel
from .chat_completion import CompletionChoice, CompletionUsage
__all__ = ["AsyncTaskStatus"]
class AsyncTaskStatus(BaseModel):
id: Optional[str] = None
request_id: Optional[str] = None
model: Optional[str] = None
task_status: Optional[str] = None
class AsyncCompletion(BaseModel):
id: Optional[str] = None
request_id: Optional[str] = None
model: Optional[str] = None
task_status: str
choices: List[CompletionChoice]
usage: CompletionUsage

View File

@ -0,0 +1,45 @@
from typing import List, Optional
from pydantic import BaseModel
__all__ = ["Completion", "CompletionUsage"]
class Function(BaseModel):
arguments: str
name: str
class CompletionMessageToolCall(BaseModel):
id: str
function: Function
type: str
class CompletionMessage(BaseModel):
content: Optional[str] = None
role: str
tool_calls: Optional[List[CompletionMessageToolCall]] = None
class CompletionUsage(BaseModel):
prompt_tokens: int
completion_tokens: int
total_tokens: int
class CompletionChoice(BaseModel):
index: int
finish_reason: str
message: CompletionMessage
class Completion(BaseModel):
model: Optional[str] = None
created: Optional[int] = None
choices: List[CompletionChoice]
request_id: Optional[str] = None
id: Optional[str] = None
usage: CompletionUsage

View File

@ -0,0 +1,55 @@
from typing import List, Optional
from pydantic import BaseModel
__all__ = [
"ChatCompletionChunk",
"Choice",
"ChoiceDelta",
"ChoiceDeltaFunctionCall",
"ChoiceDeltaToolCall",
"ChoiceDeltaToolCallFunction",
]
class ChoiceDeltaFunctionCall(BaseModel):
arguments: Optional[str] = None
name: Optional[str] = None
class ChoiceDeltaToolCallFunction(BaseModel):
arguments: Optional[str] = None
name: Optional[str] = None
class ChoiceDeltaToolCall(BaseModel):
index: int
id: Optional[str] = None
function: Optional[ChoiceDeltaToolCallFunction] = None
type: Optional[str] = None
class ChoiceDelta(BaseModel):
content: Optional[str] = None
role: Optional[str] = None
tool_calls: Optional[List[ChoiceDeltaToolCall]] = None
class Choice(BaseModel):
delta: ChoiceDelta
finish_reason: Optional[str] = None
index: int
class CompletionUsage(BaseModel):
prompt_tokens: int
completion_tokens: int
total_tokens: int
class ChatCompletionChunk(BaseModel):
id: Optional[str] = None
choices: List[Choice]
created: Optional[int] = None
model: Optional[str] = None
usage: Optional[CompletionUsage] = None

View File

@ -0,0 +1,8 @@
from typing import Optional
from typing_extensions import TypedDict
class Reference(TypedDict, total=False):
enable: Optional[bool]
search_query: Optional[str]

View File

@ -0,0 +1,20 @@
from __future__ import annotations
from typing import Optional, List
from pydantic import BaseModel
from .chat.chat_completion import CompletionUsage
__all__ = ["Embedding", "EmbeddingsResponded"]
class Embedding(BaseModel):
object: str
index: Optional[int] = None
embedding: List[float]
class EmbeddingsResponded(BaseModel):
object: str
data: List[Embedding]
model: str
usage: CompletionUsage

View File

@ -0,0 +1,24 @@
from typing import Optional, List
from pydantic import BaseModel
__all__ = ["FileObject"]
class FileObject(BaseModel):
id: Optional[str] = None
bytes: Optional[int] = None
created_at: Optional[int] = None
filename: Optional[str] = None
object: Optional[str] = None
purpose: Optional[str] = None
status: Optional[str] = None
status_details: Optional[str] = None
class ListOfFileObject(BaseModel):
object: Optional[str] = None
data: List[FileObject]
has_more: Optional[bool] = None

View File

@ -0,0 +1,5 @@
from __future__ import annotations
from .fine_tuning_job import FineTuningJob as FineTuningJob
from .fine_tuning_job import ListOfFineTuningJob as ListOfFineTuningJob
from .fine_tuning_job_event import FineTuningJobEvent as FineTuningJobEvent

View File

@ -0,0 +1,52 @@
from typing import List, Union, Optional
from pydantic import BaseModel
__all__ = ["FineTuningJob", "Error", "Hyperparameters", "ListOfFineTuningJob"]
class Error(BaseModel):
code: str
message: str
param: Optional[str] = None
class Hyperparameters(BaseModel):
n_epochs: Union[str, int, None] = None
class FineTuningJob(BaseModel):
id: Optional[str] = None
request_id: Optional[str] = None
created_at: Optional[int] = None
error: Optional[Error] = None
fine_tuned_model: Optional[str] = None
finished_at: Optional[int] = None
hyperparameters: Optional[Hyperparameters] = None
model: Optional[str] = None
object: Optional[str] = None
result_files: List[str]
status: str
trained_tokens: Optional[int] = None
training_file: str
validation_file: Optional[str] = None
class ListOfFineTuningJob(BaseModel):
object: Optional[str] = None
data: List[FineTuningJob]
has_more: Optional[bool] = None

View File

@ -0,0 +1,36 @@
from typing import List, Union, Optional
from pydantic import BaseModel
__all__ = ["FineTuningJobEvent", "Metric", "JobEvent"]
class Metric(BaseModel):
epoch: Optional[Union[str, int, float]] = None
current_steps: Optional[int] = None
total_steps: Optional[int] = None
elapsed_time: Optional[str] = None
remaining_time: Optional[str] = None
trained_tokens: Optional[int] = None
loss: Optional[Union[str, int, float]] = None
eval_loss: Optional[Union[str, int, float]] = None
acc: Optional[Union[str, int, float]] = None
eval_acc: Optional[Union[str, int, float]] = None
learning_rate: Optional[Union[str, int, float]] = None
class JobEvent(BaseModel):
object: Optional[str] = None
id: Optional[str] = None
type: Optional[str] = None
created_at: Optional[int] = None
level: Optional[str] = None
message: Optional[str] = None
data: Optional[Metric] = None
class FineTuningJobEvent(BaseModel):
object: Optional[str] = None
data: List[JobEvent]
has_more: Optional[bool] = None

View File

@ -0,0 +1,15 @@
from __future__ import annotations
from typing import Union
from typing_extensions import Literal, TypedDict
__all__ = ["Hyperparameters"]
class Hyperparameters(TypedDict, total=False):
batch_size: Union[Literal["auto"], int]
learning_rate_multiplier: Union[Literal["auto"], float]
n_epochs: Union[Literal["auto"], int]
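
Because the TypedDict is declared with total=False, callers may supply any subset of keys, and "auto" defers the choice to the service. An illustrative construction (assuming Hyperparameters from the module above is in scope; the model name is made up):

hyperparameters: Hyperparameters = {"n_epochs": "auto", "batch_size": 8}  # partial dicts are valid
request_body = {"model": "chatglm3-6b", "hyperparameters": hyperparameters}  # illustrative request payload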

View File

@ -0,0 +1,18 @@
from __future__ import annotations
from typing import Optional, List
from pydantic import BaseModel
__all__ = ["GeneratedImage", "ImagesResponded"]
class GeneratedImage(BaseModel):
b64_json: Optional[str] = None
url: Optional[str] = None
revised_prompt: Optional[str] = None
class ImagesResponded(BaseModel):
created: int
data: List[GeneratedImage]

View File

@ -3,7 +3,8 @@ from typing import Generator
import pytest
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage
from core.model_runtime.entities.message_entities import (AssistantPromptMessage, SystemPromptMessage,
UserPromptMessage, PromptMessageTool)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.zhipuai.llm.llm import ZhipuAILargeLanguageModel
@ -102,3 +103,48 @@ def test_get_num_tokens():
)
assert num_tokens == 14
def test_get_tools_num_tokens():
model = ZhipuAILargeLanguageModel()
num_tokens = model.get_num_tokens(
model='tools',
credentials={
'api_key': os.environ.get('ZHIPUAI_API_KEY')
},
tools=[
PromptMessageTool(
name='get_current_weather',
description='Get the current weather in a given location',
parameters={
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state e.g. San Francisco, CA"
},
"unit": {
"type": "string",
"enum": [
"c",
"f"
]
}
},
"required": [
"location"
]
}
)
],
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
),
UserPromptMessage(
content='Hello World!'
)
]
)
assert num_tokens == 108

View File

@ -42,7 +42,7 @@ def test_invoke_model():
assert isinstance(result, TextEmbeddingResult)
assert len(result.embeddings) == 2
assert result.usage.total_tokens == 2
assert result.usage.total_tokens > 0
def test_get_num_tokens():

View File

@ -229,7 +229,7 @@ const Answer: FC<IAnswerProps> = ({
<Thought
thought={item}
allToolIcons={allToolIcons || {}}
isFinished={!!item.observation}
isFinished={!!item.observation || !isResponsing}
/>
)}

View File

@ -43,7 +43,7 @@ import { fetchDatasets } from '@/service/datasets'
import { useProviderContext } from '@/context/provider-context'
import { AgentStrategy, AppType, ModelModeType, RETRIEVE_TYPE, Resolution, TransferMethod } from '@/types/app'
import { PromptMode } from '@/models/debug'
import { ANNOTATION_DEFAULT, DEFAULT_AGENT_SETTING, DEFAULT_CHAT_PROMPT_CONFIG, DEFAULT_COMPLETION_PROMPT_CONFIG } from '@/config'
import { ANNOTATION_DEFAULT, DEFAULT_AGENT_SETTING, DEFAULT_CHAT_PROMPT_CONFIG, DEFAULT_COMPLETION_PROMPT_CONFIG, supportFunctionCallModels } from '@/config'
import SelectDataSet from '@/app/components/app/configuration/dataset-config/select-dataset'
import I18n from '@/context/i18n'
import { useModalContext } from '@/context/modal-context'
@ -163,8 +163,7 @@ const Configuration: FC = () => {
doSetModelConfig(newModelConfig)
}
const isOpenAI = modelConfig.provider === 'openai'
const isFunctionCall = isOpenAI && modelConfig.mode === ModelModeType.chat
const isFunctionCall = (isOpenAI && modelConfig.mode === ModelModeType.chat) || supportFunctionCallModels.includes(modelConfig.model_id)
const [collectionList, setCollectionList] = useState<Collection[]>([])
useEffect(() => {

View File

@ -160,6 +160,8 @@ export const DEFAULT_AGENT_SETTING = {
tools: [],
}
export const supportFunctionCallModels = ['glm-3-turbo', 'glm-4']
export const DEFAULT_AGENT_PROMPT = {
chat: `Respond to the human as helpfully and accurately as possible.