From b921c55677d7209f3cbf9cd69abf807122e7e591 Mon Sep 17 00:00:00 2001 From: Yeuoly <45712896+Yeuoly@users.noreply.github.com> Date: Thu, 25 Jan 2024 16:29:35 +0800 Subject: [PATCH] Feat/zhipuai function calling (#2199) Co-authored-by: Joel --- .../model_providers/zhipuai/_client.py | 61 --- .../model_providers/zhipuai/llm/llm.py | 211 +++++++--- .../zhipuai/text_embedding/text_embedding.py | 23 +- .../zhipuai/zhipuai_sdk/__init__.py | 17 + .../zhipuai/zhipuai_sdk/__version__.py | 2 + .../zhipuai/zhipuai_sdk/_client.py | 71 ++++ .../zhipuai_sdk/api_resource/__init__.py | 5 + .../zhipuai_sdk/api_resource/chat/__init__.py | 0 .../api_resource/chat/async_completions.py | 87 ++++ .../zhipuai_sdk/api_resource/chat/chat.py | 16 + .../api_resource/chat/completions.py | 71 ++++ .../zhipuai_sdk/api_resource/embeddings.py | 49 +++ .../zhipuai/zhipuai_sdk/api_resource/files.py | 78 ++++ .../api_resource/fine_tuning/__init__.py | 0 .../api_resource/fine_tuning/fine_tuning.py | 15 + .../api_resource/fine_tuning/jobs.py | 115 ++++++ .../zhipuai_sdk/api_resource/images.py | 55 +++ .../zhipuai/zhipuai_sdk/core/__init__.py | 0 .../zhipuai/zhipuai_sdk/core/_base_api.py | 17 + .../zhipuai/zhipuai_sdk/core/_base_type.py | 115 ++++++ .../zhipuai/zhipuai_sdk/core/_errors.py | 90 +++++ .../zhipuai/zhipuai_sdk/core/_files.py | 46 +++ .../zhipuai/zhipuai_sdk/core/_http_client.py | 377 ++++++++++++++++++ .../zhipuai/zhipuai_sdk/core/_jwt_token.py | 30 ++ .../zhipuai/zhipuai_sdk/core/_request_opt.py | 54 +++ .../zhipuai/zhipuai_sdk/core/_response.py | 121 ++++++ .../zhipuai/zhipuai_sdk/core/_sse_client.py | 149 +++++++ .../zhipuai/zhipuai_sdk/core/_utils.py | 18 + .../zhipuai/zhipuai_sdk/types/__init__.py | 0 .../zhipuai_sdk/types/chat/__init__.py | 0 .../types/chat/async_chat_completion.py | 23 ++ .../zhipuai_sdk/types/chat/chat_completion.py | 45 +++ .../types/chat/chat_completion_chunk.py | 55 +++ .../chat/chat_completions_create_param.py | 8 + .../zhipuai/zhipuai_sdk/types/embeddings.py | 20 + .../zhipuai/zhipuai_sdk/types/file_object.py | 24 ++ .../zhipuai_sdk/types/fine_tuning/__init__.py | 5 + .../types/fine_tuning/fine_tuning_job.py | 52 +++ .../fine_tuning/fine_tuning_job_event.py | 36 ++ .../types/fine_tuning/job_create_params.py | 15 + .../zhipuai/zhipuai_sdk/types/image.py | 18 + .../model_runtime/zhipuai/test_llm.py | 48 ++- .../zhipuai/test_text_embedding.py | 2 +- web/app/components/app/chat/answer/index.tsx | 2 +- .../components/app/configuration/index.tsx | 5 +- web/config/index.ts | 2 + 46 files changed, 2115 insertions(+), 138 deletions(-) delete mode 100644 api/core/model_runtime/model_providers/zhipuai/_client.py create mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__init__.py create mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__version__.py create mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/_client.py create mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/__init__.py create mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/__init__.py create mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/async_completions.py create mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/chat.py create mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/completions.py create mode 100644 
api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/embeddings.py create mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/files.py create mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/__init__.py create mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/fine_tuning.py create mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/jobs.py create mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/images.py create mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/__init__.py create mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_api.py create mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_type.py create mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_errors.py create mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_files.py create mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_http_client.py create mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_jwt_token.py create mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_request_opt.py create mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_response.py create mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_sse_client.py create mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_utils.py create mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/__init__.py create mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/__init__.py create mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/async_chat_completion.py create mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion.py create mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion_chunk.py create mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completions_create_param.py create mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/embeddings.py create mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/file_object.py create mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/__init__.py create mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job.py create mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job_event.py create mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/job_create_params.py create mode 100644 api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/image.py diff --git a/api/core/model_runtime/model_providers/zhipuai/_client.py b/api/core/model_runtime/model_providers/zhipuai/_client.py deleted file mode 100644 index 31042d318..000000000 --- a/api/core/model_runtime/model_providers/zhipuai/_client.py +++ /dev/null @@ -1,61 +0,0 @@ -"""Wrapper around ZhipuAI APIs.""" -from __future__ import annotations - -import logging -import posixpath - -from pydantic import BaseModel, Extra -from zhipuai.model_api.api 
import InvokeType -from zhipuai.utils import jwt_token -from zhipuai.utils.http_client import post, stream -from zhipuai.utils.sse_client import SSEClient - -logger = logging.getLogger(__name__) - - -class ZhipuModelAPI(BaseModel): - base_url: str = "https://open.bigmodel.cn/api/paas/v3/model-api" - api_key: str - api_timeout_seconds = 60 - - class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid - - def invoke(self, **kwargs): - url = self._build_api_url(kwargs, InvokeType.SYNC) - response = post(url, self._generate_token(), kwargs, self.api_timeout_seconds) - if not response['success']: - raise ValueError( - f"Error Code: {response['code']}, Message: {response['msg']} " - ) - return response - - def sse_invoke(self, **kwargs): - url = self._build_api_url(kwargs, InvokeType.SSE) - data = stream(url, self._generate_token(), kwargs, self.api_timeout_seconds) - return SSEClient(data) - - def _build_api_url(self, kwargs, *path): - if kwargs: - if "model" not in kwargs: - raise Exception("model param missed") - model = kwargs.pop("model") - else: - model = "-" - - return posixpath.join(self.base_url, model, *path) - - def _generate_token(self): - if not self.api_key: - raise Exception( - "api_key not provided, you could provide it." - ) - - try: - return jwt_token.generate_token(self.api_key) - except Exception: - raise ValueError( - f"Your api_key is invalid, please check it." - ) diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/llm.py b/api/core/model_runtime/model_providers/zhipuai/llm/llm.py index 2b7d9a1b1..c4c1dfb85 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/llm.py +++ b/api/core/model_runtime/model_providers/zhipuai/llm/llm.py @@ -3,13 +3,15 @@ from typing import Any, Dict, Generator, List, Optional, Union from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageRole, - PromptMessageTool, SystemPromptMessage, UserPromptMessage, + PromptMessageTool, SystemPromptMessage, UserPromptMessage, ToolPromptMessage, TextPromptMessageContent, ImagePromptMessageContent, PromptMessageContentType) from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.utils import helper from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.model_runtime.model_providers.zhipuai._client import ZhipuModelAPI from core.model_runtime.model_providers.zhipuai._common import _CommonZhipuaiAI - +from core.model_runtime.model_providers.zhipuai.zhipuai_sdk._client import ZhipuAI +from core.model_runtime.model_providers.zhipuai.zhipuai_sdk.types.chat.chat_completion_chunk import ChatCompletionChunk +from core.model_runtime.model_providers.zhipuai.zhipuai_sdk.types.chat.chat_completion import Completion class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): @@ -35,7 +37,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): credentials_kwargs = self._to_credential_kwargs(credentials) # invoke model - return self._generate(model, credentials_kwargs, prompt_messages, model_parameters, stop, stream, user) + return self._generate(model, credentials_kwargs, prompt_messages, model_parameters, tools, stop, stream, user) def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None) -> int: @@ -48,7 +50,7 @@ 
class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): :param tools: tools for tool calling :return: """ - prompt = self._convert_messages_to_prompt(prompt_messages) + prompt = self._convert_messages_to_prompt(prompt_messages, tools) return self._get_num_tokens_by_gpt2(prompt) @@ -72,6 +74,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): model_parameters={ "temperature": 0.5, }, + tools=[], stream=False ) except Exception as ex: @@ -79,6 +82,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): def _generate(self, model: str, credentials_kwargs: dict, prompt_messages: list[PromptMessage], model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: """ @@ -97,7 +101,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): if stop: extra_model_kwargs['stop_sequences'] = stop - client = ZhipuModelAPI( + client = ZhipuAI( api_key=credentials_kwargs['api_key'] ) @@ -128,11 +132,17 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): # not support image message continue - if new_prompt_messages and new_prompt_messages[-1].role == PromptMessageRole.USER: + if new_prompt_messages and new_prompt_messages[-1].role == PromptMessageRole.USER and \ + copy_prompt_message.role == PromptMessageRole.USER: new_prompt_messages[-1].content += "\n\n" + copy_prompt_message.content else: if copy_prompt_message.role == PromptMessageRole.USER: new_prompt_messages.append(copy_prompt_message) + elif copy_prompt_message.role == PromptMessageRole.TOOL: + new_prompt_messages.append(copy_prompt_message) + elif copy_prompt_message.role == PromptMessageRole.SYSTEM: + new_prompt_message = SystemPromptMessage(content=copy_prompt_message.content) + new_prompt_messages.append(new_prompt_message) else: new_prompt_message = UserPromptMessage(content=copy_prompt_message.content) new_prompt_messages.append(new_prompt_message) @@ -145,7 +155,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): if model == 'glm-4v': params = { 'model': model, - 'prompt': [{ + 'messages': [{ 'role': prompt_message.role.value, 'content': [ @@ -171,23 +181,63 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): else: params = { 'model': model, - 'prompt': [{ - 'role': prompt_message.role.value, - 'content': prompt_message.content, - } for prompt_message in new_prompt_messages], + 'messages': [], **model_parameters } + # glm model + if not model.startswith('chatglm'): + + for prompt_message in new_prompt_messages: + if prompt_message.role == PromptMessageRole.TOOL: + params['messages'].append({ + 'role': 'tool', + 'content': prompt_message.content, + 'tool_call_id': prompt_message.tool_call_id + }) + else: + params['messages'].append({ + 'role': prompt_message.role.value, + 'content': prompt_message.content + }) + else: + # chatglm model + for prompt_message in new_prompt_messages: + # merge system message to user message + if prompt_message.role == PromptMessageRole.SYSTEM or \ + prompt_message.role == PromptMessageRole.TOOL or \ + prompt_message.role == PromptMessageRole.USER: + if len(params['messages']) > 0 and params['messages'][-1]['role'] == 'user': + params['messages'][-1]['content'] += "\n\n" + prompt_message.content + else: + params['messages'].append({ + 'role': 'user', + 'content': prompt_message.content + }) + else: + params['messages'].append({ + 'role': 
prompt_message.role.value, + 'content': prompt_message.content + }) + + if tools and len(tools) > 0: + params['tools'] = [ + { + 'type': 'function', + 'function': helper.dump_model(tool) + } for tool in tools + ] if stream: - response = client.sse_invoke(incremental=True, **params).events() - return self._handle_generate_stream_response(model, credentials_kwargs, response, prompt_messages) + response = client.chat.completions.create(stream=stream, **params) + return self._handle_generate_stream_response(model, credentials_kwargs, tools, response, prompt_messages) - response = client.invoke(**params) - return self._handle_generate_response(model, credentials_kwargs, response, prompt_messages) + response = client.chat.completions.create(**params) + return self._handle_generate_response(model, credentials_kwargs, tools, response, prompt_messages) def _handle_generate_response(self, model: str, credentials: dict, - response: Dict[str, Any], + tools: Optional[list[PromptMessageTool]], + response: Completion, prompt_messages: list[PromptMessage]) -> LLMResult: """ Handle llm response @@ -197,26 +247,39 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): :param prompt_messages: prompt messages :return: llm response """ - data = response["data"] text = '' - for res in data["choices"]: - text += res['content'] + assistant_tool_calls: List[AssistantPromptMessage.ToolCall] = [] + for choice in response.choices: + if choice.message.tool_calls: + for tool_call in choice.message.tool_calls: + if tool_call.type == 'function': + assistant_tool_calls.append( + AssistantPromptMessage.ToolCall( + id=tool_call.id, + type=tool_call.type, + function=AssistantPromptMessage.ToolCall.ToolCallFunction( + name=tool_call.function.name, + arguments=tool_call.function.arguments, + ) + ) + ) + + text += choice.message.content or '' - token_usage = data.get("usage") - if token_usage is not None: - if 'prompt_tokens' not in token_usage: - token_usage['prompt_tokens'] = 0 - if 'completion_tokens' not in token_usage: - token_usage['completion_tokens'] = token_usage['total_tokens'] + prompt_usage = response.usage.prompt_tokens + completion_usage = response.usage.completion_tokens # transform usage - usage = self._calc_response_usage(model, credentials, token_usage['prompt_tokens'], token_usage['completion_tokens']) + usage = self._calc_response_usage(model, credentials, prompt_usage, completion_usage) # transform response result = LLMResult( model=model, prompt_messages=prompt_messages, - message=AssistantPromptMessage(content=text), + message=AssistantPromptMessage( + content=text, + tool_calls=assistant_tool_calls + ), usage=usage, ) @@ -224,7 +287,8 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): def _handle_generate_stream_response(self, model: str, credentials: dict, - responses: list[Generator], + tools: Optional[list[PromptMessageTool]], + responses: Generator[ChatCompletionChunk, None, None], prompt_messages: list[PromptMessage]) -> Generator: """ Handle llm stream response @@ -234,39 +298,64 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): :param prompt_messages: prompt messages :return: llm response chunk generator result """ - for index, event in enumerate(responses): - if event.event == "add": + full_assistant_content = '' + for chunk in responses: + if len(chunk.choices) == 0: + continue + + delta = chunk.choices[0] + + if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''): + continue + + 
assistant_tool_calls: List[AssistantPromptMessage.ToolCall] = [] + for tool_call in delta.delta.tool_calls or []: + if tool_call.type == 'function': + assistant_tool_calls.append( + AssistantPromptMessage.ToolCall( + id=tool_call.id, + type=tool_call.type, + function=AssistantPromptMessage.ToolCall.ToolCallFunction( + name=tool_call.function.name, + arguments=tool_call.function.arguments, + ) + ) + ) + + # transform assistant message to prompt message + assistant_prompt_message = AssistantPromptMessage( + content=delta.delta.content if delta.delta.content else '', + tool_calls=assistant_tool_calls + ) + + full_assistant_content += delta.delta.content if delta.delta.content else '' + + if delta.finish_reason is not None and chunk.usage is not None: + completion_tokens = chunk.usage.completion_tokens + prompt_tokens = chunk.usage.prompt_tokens + + # transform usage + usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) + yield LLMResultChunk( + model=chunk.model, prompt_messages=prompt_messages, - model=model, + system_fingerprint='', delta=LLMResultChunkDelta( - index=index, - message=AssistantPromptMessage(content=event.data) + index=delta.index, + message=assistant_prompt_message, + finish_reason=delta.finish_reason, + usage=usage ) ) - elif event.event == "error" or event.event == "interrupted": - raise ValueError( - f"{event.data}" - ) - elif event.event == "finish": - meta = json.loads(event.meta) - token_usage = meta['usage'] - if token_usage is not None: - if 'prompt_tokens' not in token_usage: - token_usage['prompt_tokens'] = 0 - if 'completion_tokens' not in token_usage: - token_usage['completion_tokens'] = token_usage['total_tokens'] - - usage = self._calc_response_usage(model, credentials, token_usage['prompt_tokens'], token_usage['completion_tokens']) - + else: yield LLMResultChunk( - model=model, + model=chunk.model, prompt_messages=prompt_messages, + system_fingerprint='', delta=LLMResultChunkDelta( - index=index, - message=AssistantPromptMessage(content=event.data), - finish_reason='finish', - usage=usage + index=delta.index, + message=assistant_prompt_message, ) ) @@ -291,11 +380,10 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): raise ValueError(f"Got unknown type {message}") return message_text - - def _convert_messages_to_prompt(self, messages: List[PromptMessage]) -> str: - """ - Format a list of messages into a full prompt for the Anthropic model + + def _convert_messages_to_prompt(self, messages: List[PromptMessage], tools: Optional[list[PromptMessageTool]] = None) -> str: + """ :param messages: List of PromptMessage to combine. :return: Combined string with necessary human_prompt and ai_prompt tags. 
""" @@ -306,5 +394,10 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): for message in messages ) + if tools and len(tools) > 0: + text += "\n\nTools:" + for tool in tools: + text += f"\n{tool.json()}" + # trim off the trailing ' ' that might come from the "Assistant: " - return text.rstrip() + return text.rstrip() \ No newline at end of file diff --git a/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py index 0fd04134b..e5ecc85c4 100644 --- a/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py @@ -5,7 +5,7 @@ from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from core.model_runtime.model_providers.zhipuai._client import ZhipuModelAPI +from core.model_runtime.model_providers.zhipuai.zhipuai_sdk._client import ZhipuAI from core.model_runtime.model_providers.zhipuai._common import _CommonZhipuaiAI from langchain.schema.language_model import _get_token_ids_default_method @@ -28,7 +28,7 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel): :return: embeddings result """ credentials_kwargs = self._to_credential_kwargs(credentials) - client = ZhipuModelAPI( + client = ZhipuAI( api_key=credentials_kwargs['api_key'] ) @@ -69,7 +69,7 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel): try: # transform credentials to kwargs for model instance credentials_kwargs = self._to_credential_kwargs(credentials) - client = ZhipuModelAPI( + client = ZhipuAI( api_key=credentials_kwargs['api_key'] ) @@ -82,7 +82,7 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel): except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def embed_documents(self, model: str, client: ZhipuModelAPI, texts: List[str]) -> Tuple[List[List[float]], int]: + def embed_documents(self, model: str, client: ZhipuAI, texts: List[str]) -> Tuple[List[List[float]], int]: """Call out to ZhipuAI's embedding endpoint. Args: @@ -91,17 +91,16 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel): Returns: List of embeddings, one for each text. """ - - embeddings = [] + embedding_used_tokens = 0 + for text in texts: - response = client.invoke(model=model, prompt=text) - data = response["data"] - embeddings.append(data.get('embedding')) + response = client.embeddings.create(model=model, input=text) + data = response.data[0] + embeddings.append(data.embedding) + embedding_used_tokens += response.usage.total_tokens - embedding_used_tokens = data.get('usage') - - return [list(map(float, e)) for e in embeddings], embedding_used_tokens['total_tokens'] if embedding_used_tokens else 0 + return [list(map(float, e)) for e in embeddings], embedding_used_tokens def embed_query(self, text: str) -> List[float]: """Call out to ZhipuAI's embedding endpoint. 
diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__init__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__init__.py
new file mode 100644
index 000000000..5527c4891
--- /dev/null
+++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__init__.py
@@ -0,0 +1,17 @@
+
+from ._client import ZhipuAI
+
+from .core._errors import (
+    ZhipuAIError,
+    APIStatusError,
+    APIRequestFailedError,
+    APIAuthenticationError,
+    APIReachLimitError,
+    APIInternalError,
+    APIServerFlowExceedError,
+    APIResponseError,
+    APIResponseValidationError,
+    APITimeoutError,
+)
+
+from .__version__ import __version__
diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__version__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__version__.py
new file mode 100644
index 000000000..eb0ad332c
--- /dev/null
+++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__version__.py
@@ -0,0 +1,2 @@
+
+__version__ = 'v2.0.1'
\ No newline at end of file
diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/_client.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/_client.py
new file mode 100644
index 000000000..e169c5485
--- /dev/null
+++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/_client.py
@@ -0,0 +1,71 @@
+from __future__ import annotations
+
+from typing import Union, Mapping
+
+from typing_extensions import override
+
+from .core import _jwt_token
+from .core._errors import ZhipuAIError
+from .core._http_client import HttpClient, ZHIPUAI_DEFAULT_MAX_RETRIES
+from .core._base_type import NotGiven, NOT_GIVEN
+from . import api_resource
+import os
+import httpx
+from httpx import Timeout
+
+
+class ZhipuAI(HttpClient):
+    chat: api_resource.chat
+    api_key: str
+
+    def __init__(
+            self,
+            *,
+            api_key: str | None = None,
+            base_url: str | httpx.URL | None = None,
+            timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN,
+            max_retries: int = ZHIPUAI_DEFAULT_MAX_RETRIES,
+            http_client: httpx.Client | None = None,
+            custom_headers: Mapping[str, str] | None = None
+    ) -> None:
+        # if api_key is None:
+        #     api_key = os.environ.get("ZHIPUAI_API_KEY")
+        if api_key is None:
+            raise ZhipuAIError("api_key not provided; please supply it via argument or environment variable")
+        self.api_key = api_key
+
+        if base_url is None:
+            base_url = os.environ.get("ZHIPUAI_BASE_URL")
+        if base_url is None:
+            base_url = "https://open.bigmodel.cn/api/paas/v4"
+        from .__version__ import __version__
+        super().__init__(
+            version=__version__,
+            base_url=base_url,
+            timeout=timeout,
+            custom_httpx_client=http_client,
+            custom_headers=custom_headers,
+        )
+        self.chat = api_resource.chat.Chat(self)
+        self.images = api_resource.images.Images(self)
+        self.embeddings = api_resource.embeddings.Embeddings(self)
+        self.files = api_resource.files.Files(self)
+        self.fine_tuning = api_resource.fine_tuning.FineTuning(self)
+
+    @property
+    @override
+    def _auth_headers(self) -> dict[str, str]:
+        api_key = self.api_key
+        return {"Authorization": f"{_jwt_token.generate_token(api_key)}"}
+
+    def __del__(self) -> None:
+        if (not hasattr(self, "_has_custom_http_client")
+                or not hasattr(self, "close")
+                or not hasattr(self, "_client")):
+            # if the '__init__' method raised an error, self would not have client attr
+            return
+
+        if self._has_custom_http_client:
+            return
+
+        self.close()
diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/__init__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/__init__.py
new file mode 100644
index 000000000..b596702bc --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/__init__.py @@ -0,0 +1,5 @@ +from .chat import chat +from .images import Images +from .embeddings import Embeddings +from .files import Files +from .fine_tuning import fine_tuning diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/__init__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/async_completions.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/async_completions.py new file mode 100644 index 000000000..b926bd013 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/async_completions.py @@ -0,0 +1,87 @@ +from __future__ import annotations + +from typing import Union, List, Optional, TYPE_CHECKING + +import httpx +from typing_extensions import Literal + +from ...core._base_api import BaseAPI +from ...core._base_type import NotGiven, NOT_GIVEN, Headers +from ...core._http_client import make_user_request_input +from ...types.chat.async_chat_completion import AsyncTaskStatus, AsyncCompletion + +if TYPE_CHECKING: + from ..._client import ZhipuAI + + +class AsyncCompletions(BaseAPI): + def __init__(self, client: "ZhipuAI") -> None: + super().__init__(client) + + + def create( + self, + *, + model: str, + request_id: Optional[str] | NotGiven = NOT_GIVEN, + do_sample: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN, + temperature: Optional[float] | NotGiven = NOT_GIVEN, + top_p: Optional[float] | NotGiven = NOT_GIVEN, + max_tokens: int | NotGiven = NOT_GIVEN, + seed: int | NotGiven = NOT_GIVEN, + messages: Union[str, List[str], List[int], List[List[int]], None], + stop: Optional[Union[str, List[str], None]] | NotGiven = NOT_GIVEN, + sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN, + tools: Optional[object] | NotGiven = NOT_GIVEN, + tool_choice: str | NotGiven = NOT_GIVEN, + extra_headers: Headers | None = None, + disable_strict_validation: Optional[bool] | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> AsyncTaskStatus: + _cast_type = AsyncTaskStatus + + if disable_strict_validation: + _cast_type = object + return self._post( + "/async/chat/completions", + body={ + "model": model, + "request_id": request_id, + "temperature": temperature, + "top_p": top_p, + "do_sample": do_sample, + "max_tokens": max_tokens, + "seed": seed, + "messages": messages, + "stop": stop, + "sensitive_word_check": sensitive_word_check, + "tools": tools, + "tool_choice": tool_choice, + }, + options=make_user_request_input( + extra_headers=extra_headers, timeout=timeout + ), + cast_type=_cast_type, + enable_stream=False, + ) + + def retrieve_completion_result( + self, + id: str, + extra_headers: Headers | None = None, + disable_strict_validation: Optional[bool] | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> Union[AsyncCompletion, AsyncTaskStatus]: + _cast_type = Union[AsyncCompletion,AsyncTaskStatus] + if disable_strict_validation: + _cast_type = object + return self._get( + path=f"/async-result/{id}", + cast_type=_cast_type, + options=make_user_request_input( + extra_headers=extra_headers, + timeout=timeout + ) + ) + + diff --git 
a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/chat.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/chat.py new file mode 100644 index 000000000..6a9fef671 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/chat.py @@ -0,0 +1,16 @@ +from typing import TYPE_CHECKING +from .completions import Completions +from .async_completions import AsyncCompletions +from ...core._base_api import BaseAPI + +if TYPE_CHECKING: + from ..._client import ZhipuAI + + +class Chat(BaseAPI): + completions: Completions + + def __init__(self, client: "ZhipuAI") -> None: + super().__init__(client) + self.completions = Completions(client) + self.asyncCompletions = AsyncCompletions(client) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/completions.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/completions.py new file mode 100644 index 000000000..bec2755f9 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/completions.py @@ -0,0 +1,71 @@ +from __future__ import annotations + +from typing import Union, List, Optional, TYPE_CHECKING + +import httpx +from typing_extensions import Literal + +from ...core._base_api import BaseAPI +from ...core._base_type import NotGiven, NOT_GIVEN, Headers +from ...core._http_client import make_user_request_input +from ...core._sse_client import StreamResponse +from ...types.chat.chat_completion import Completion +from ...types.chat.chat_completion_chunk import ChatCompletionChunk + +if TYPE_CHECKING: + from ..._client import ZhipuAI + + +class Completions(BaseAPI): + def __init__(self, client: "ZhipuAI") -> None: + super().__init__(client) + + def create( + self, + *, + model: str, + request_id: Optional[str] | NotGiven = NOT_GIVEN, + do_sample: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN, + stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN, + temperature: Optional[float] | NotGiven = NOT_GIVEN, + top_p: Optional[float] | NotGiven = NOT_GIVEN, + max_tokens: int | NotGiven = NOT_GIVEN, + seed: int | NotGiven = NOT_GIVEN, + messages: Union[str, List[str], List[int], object, None], + stop: Optional[Union[str, List[str], None]] | NotGiven = NOT_GIVEN, + sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN, + tools: Optional[object] | NotGiven = NOT_GIVEN, + tool_choice: str | NotGiven = NOT_GIVEN, + extra_headers: Headers | None = None, + disable_strict_validation: Optional[bool] | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> Completion | StreamResponse[ChatCompletionChunk]: + _cast_type = Completion + _stream_cls = StreamResponse[ChatCompletionChunk] + if disable_strict_validation: + _cast_type = object + _stream_cls = StreamResponse[object] + return self._post( + "/chat/completions", + body={ + "model": model, + "request_id": request_id, + "temperature": temperature, + "top_p": top_p, + "do_sample": do_sample, + "max_tokens": max_tokens, + "seed": seed, + "messages": messages, + "stop": stop, + "sensitive_word_check": sensitive_word_check, + "stream": stream, + "tools": tools, + "tool_choice": tool_choice, + }, + options=make_user_request_input( + extra_headers=extra_headers, + ), + cast_type=_cast_type, + enable_stream=stream or False, + stream_cls=_stream_cls, + ) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/embeddings.py 
b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/embeddings.py new file mode 100644 index 000000000..b12ce9564 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/embeddings.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +from typing import Union, List, Optional, TYPE_CHECKING + +import httpx + +from ..core._base_api import BaseAPI +from ..core._base_type import NotGiven, NOT_GIVEN, Headers +from ..core._http_client import make_user_request_input +from ..types.embeddings import EmbeddingsResponded + +if TYPE_CHECKING: + from .._client import ZhipuAI + + +class Embeddings(BaseAPI): + def __init__(self, client: "ZhipuAI") -> None: + super().__init__(client) + + def create( + self, + *, + input: Union[str, List[str], List[int], List[List[int]]], + model: Union[str], + encoding_format: str | NotGiven = NOT_GIVEN, + user: str | NotGiven = NOT_GIVEN, + sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN, + extra_headers: Headers | None = None, + disable_strict_validation: Optional[bool] | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> EmbeddingsResponded: + _cast_type = EmbeddingsResponded + if disable_strict_validation: + _cast_type = object + return self._post( + "/embeddings", + body={ + "input": input, + "model": model, + "encoding_format": encoding_format, + "user": user, + "sensitive_word_check": sensitive_word_check, + }, + options=make_user_request_input( + extra_headers=extra_headers, timeout=timeout + ), + cast_type=_cast_type, + enable_stream=False, + ) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/files.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/files.py new file mode 100644 index 000000000..5dde40dae --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/files.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import httpx + +from ..core._base_api import BaseAPI +from ..core._base_type import NOT_GIVEN, Body, Query, Headers, NotGiven, FileTypes +from ..core._files import is_file_content +from ..core._http_client import ( + make_user_request_input, +) +from ..types.file_object import FileObject, ListOfFileObject + +if TYPE_CHECKING: + from .._client import ZhipuAI + +__all__ = ["Files"] + + +class Files(BaseAPI): + + def __init__(self, client: "ZhipuAI") -> None: + super().__init__(client) + + def create( + self, + *, + file: FileTypes, + purpose: str, + extra_headers: Headers | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> FileObject: + if not is_file_content(file): + prefix = f"Expected file input `{file!r}`" + raise RuntimeError( + f"{prefix} to be bytes, an io.IOBase instance, PathLike or a tuple but received {type(file)} instead." 
+ ) from None + files = [("file", file)] + + extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})} + + return self._post( + "/files", + body={ + "purpose": purpose, + }, + files=files, + options=make_user_request_input( + extra_headers=extra_headers, timeout=timeout + ), + cast_type=FileObject, + ) + + def list( + self, + *, + purpose: str | NotGiven = NOT_GIVEN, + limit: int | NotGiven = NOT_GIVEN, + after: str | NotGiven = NOT_GIVEN, + order: str | NotGiven = NOT_GIVEN, + extra_headers: Headers | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> ListOfFileObject: + return self._get( + "/files", + cast_type=ListOfFileObject, + options=make_user_request_input( + extra_headers=extra_headers, + timeout=timeout, + query={ + "purpose": purpose, + "limit": limit, + "after": after, + "order": order, + }, + ), + ) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/__init__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/fine_tuning.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/fine_tuning.py new file mode 100644 index 000000000..b06274c95 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/fine_tuning.py @@ -0,0 +1,15 @@ +from typing import TYPE_CHECKING +from .jobs import Jobs +from ...core._base_api import BaseAPI + +if TYPE_CHECKING: + from ..._client import ZhipuAI + + +class FineTuning(BaseAPI): + jobs: Jobs + + def __init__(self, client: "ZhipuAI") -> None: + super().__init__(client) + self.jobs = Jobs(client) + diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/jobs.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/jobs.py new file mode 100644 index 000000000..03b597ddb --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/jobs.py @@ -0,0 +1,115 @@ +from __future__ import annotations + +from typing import Optional, TYPE_CHECKING + +import httpx + +from ...core._base_api import BaseAPI +from ...core._base_type import NOT_GIVEN, Headers, NotGiven +from ...core._http_client import ( + make_user_request_input, +) +from ...types.fine_tuning import ( + FineTuningJob, + job_create_params, + ListOfFineTuningJob, + FineTuningJobEvent, +) + +if TYPE_CHECKING: + from ..._client import ZhipuAI + +__all__ = ["Jobs"] + + +class Jobs(BaseAPI): + + def __init__(self, client: "ZhipuAI") -> None: + super().__init__(client) + + def create( + self, + *, + model: str, + training_file: str, + hyperparameters: job_create_params.Hyperparameters | NotGiven = NOT_GIVEN, + suffix: Optional[str] | NotGiven = NOT_GIVEN, + request_id: Optional[str] | NotGiven = NOT_GIVEN, + validation_file: Optional[str] | NotGiven = NOT_GIVEN, + extra_headers: Headers | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> FineTuningJob: + return self._post( + "/fine_tuning/jobs", + body={ + "model": model, + "training_file": training_file, + "hyperparameters": hyperparameters, + "suffix": suffix, + "validation_file": validation_file, + "request_id": request_id, + }, + options=make_user_request_input( + extra_headers=extra_headers, timeout=timeout + ), + cast_type=FineTuningJob, + 
) + + def retrieve( + self, + fine_tuning_job_id: str, + *, + extra_headers: Headers | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> FineTuningJob: + return self._get( + f"/fine_tuning/jobs/{fine_tuning_job_id}", + options=make_user_request_input( + extra_headers=extra_headers, timeout=timeout + ), + cast_type=FineTuningJob, + ) + + def list( + self, + *, + after: str | NotGiven = NOT_GIVEN, + limit: int | NotGiven = NOT_GIVEN, + extra_headers: Headers | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> ListOfFineTuningJob: + return self._get( + "/fine_tuning/jobs", + cast_type=ListOfFineTuningJob, + options=make_user_request_input( + extra_headers=extra_headers, + timeout=timeout, + query={ + "after": after, + "limit": limit, + }, + ), + ) + + def list_events( + self, + fine_tuning_job_id: str, + *, + after: str | NotGiven = NOT_GIVEN, + limit: int | NotGiven = NOT_GIVEN, + extra_headers: Headers | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> FineTuningJobEvent: + + return self._get( + f"/fine_tuning/jobs/{fine_tuning_job_id}/events", + cast_type=FineTuningJobEvent, + options=make_user_request_input( + extra_headers=extra_headers, + timeout=timeout, + query={ + "after": after, + "limit": limit, + }, + ), + ) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/images.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/images.py new file mode 100644 index 000000000..d245fb8ab --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/images.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +from typing import Union, List, Optional, TYPE_CHECKING + +import httpx + +from ..core._base_api import BaseAPI +from ..core._base_type import NotGiven, NOT_GIVEN, Headers +from ..core._http_client import make_user_request_input +from ..types.image import ImagesResponded + +if TYPE_CHECKING: + from .._client import ZhipuAI + + +class Images(BaseAPI): + def __init__(self, client: "ZhipuAI") -> None: + super().__init__(client) + + def generations( + self, + *, + prompt: str, + model: str | NotGiven = NOT_GIVEN, + n: Optional[int] | NotGiven = NOT_GIVEN, + quality: Optional[str] | NotGiven = NOT_GIVEN, + response_format: Optional[str] | NotGiven = NOT_GIVEN, + size: Optional[str] | NotGiven = NOT_GIVEN, + style: Optional[str] | NotGiven = NOT_GIVEN, + user: str | NotGiven = NOT_GIVEN, + extra_headers: Headers | None = None, + disable_strict_validation: Optional[bool] | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> ImagesResponded: + _cast_type = ImagesResponded + if disable_strict_validation: + _cast_type = object + return self._post( + "/images/generations", + body={ + "prompt": prompt, + "model": model, + "n": n, + "quality": quality, + "response_format": response_format, + "size": size, + "style": style, + "user": user, + }, + options=make_user_request_input( + extra_headers=extra_headers, timeout=timeout + ), + cast_type=_cast_type, + enable_stream=False, + ) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/__init__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_api.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_api.py new file mode 100644 
index 000000000..21291b2a7 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_api.py @@ -0,0 +1,17 @@ +from __future__ import annotations +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from .._client import ZhipuAI + + +class BaseAPI: + _client: ZhipuAI + + def __init__(self, client: ZhipuAI) -> None: + self._client = client + self._delete = client.delete + self._get = client.get + self._post = client.post + self._put = client.put + self._patch = client.patch diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_type.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_type.py new file mode 100644 index 000000000..cc613490d --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_type.py @@ -0,0 +1,115 @@ +from __future__ import annotations + +from os import PathLike +from typing import ( + TYPE_CHECKING, + Type, + Union, + Mapping, + TypeVar, IO, Tuple, Sequence, Any, List, +) + +import pydantic +from typing_extensions import ( + Literal, + override, +) + + +Query = Mapping[str, object] +Body = object +AnyMapping = Mapping[str, object] +PrimitiveData = Union[str, int, float, bool, None] +Data = Union[PrimitiveData, List[Any], Tuple[Any], "Mapping[str, Any]"] +ModelT = TypeVar("ModelT", bound=pydantic.BaseModel) +_T = TypeVar("_T") + +if TYPE_CHECKING: + NoneType: Type[None] +else: + NoneType = type(None) + + +# Sentinel class used until PEP 0661 is accepted +class NotGiven(pydantic.BaseModel): + """ + A sentinel singleton class used to distinguish omitted keyword arguments + from those passed in with the value None (which may have different behavior). + + For example: + + ```py + def get(timeout: Union[int, NotGiven, None] = NotGiven()) -> Response: ... + + get(timeout=1) # 1s timeout + get(timeout=None) # No timeout + get() # Default timeout behavior, which may not be statically known at the method definition. 
+ ``` + """ + + def __bool__(self) -> Literal[False]: + return False + + @override + def __repr__(self) -> str: + return "NOT_GIVEN" + + +NotGivenOr = Union[_T, NotGiven] +NOT_GIVEN = NotGiven() + + +class Omit(pydantic.BaseModel): + """In certain situations you need to be able to represent a case where a default value has + to be explicitly removed and `None` is not an appropriate substitute, for example: + + ```py + # as the default `Content-Type` header is `application/json` that will be sent + client.post('/upload/files', files={'file': b'my raw file content'}) + + # you can't explicitly override the header as it has to be dynamically generated + # to look something like: 'multipart/form-data; boundary=0d8382fcf5f8c3be01ca2e11002d2983' + client.post(..., headers={'Content-Type': 'multipart/form-data'}) + + # instead you can remove the default `application/json` header by passing Omit + client.post(..., headers={'Content-Type': Omit()}) + ``` + """ + + def __bool__(self) -> Literal[False]: + return False + + +Headers = Mapping[str, Union[str, Omit]] + +ResponseT = TypeVar( + "ResponseT", + bound="Union[str, None, BaseModel, List[Any], Dict[str, Any], Response, UnknownResponse, ModelBuilderProtocol, BinaryResponseContent]", +) + +# for user input files +if TYPE_CHECKING: + FileContent = Union[IO[bytes], bytes, PathLike[str]] +else: + FileContent = Union[IO[bytes], bytes, PathLike] + +FileTypes = Union[ + FileContent, # file content + Tuple[str, FileContent], # (filename, file) + Tuple[str, FileContent, str], # (filename, file , content_type) + Tuple[str, FileContent, str, Mapping[str, str]], # (filename, file , content_type, headers) +] + +RequestFiles = Union[Mapping[str, FileTypes], Sequence[Tuple[str, FileTypes]]] + +# for httpx client supported files + +HttpxFileContent = Union[bytes, IO[bytes]] +HttpxFileTypes = Union[ + FileContent, # file content + Tuple[str, HttpxFileContent], # (filename, file) + Tuple[str, HttpxFileContent, str], # (filename, file , content_type) + Tuple[str, HttpxFileContent, str, Mapping[str, str]], # (filename, file , content_type, headers) +] + +HttpxRequestFiles = Union[Mapping[str, HttpxFileTypes], Sequence[Tuple[str, HttpxFileTypes]]] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_errors.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_errors.py new file mode 100644 index 000000000..a2a438b8f --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_errors.py @@ -0,0 +1,90 @@ +from __future__ import annotations + +import httpx + +__all__ = [ + "ZhipuAIError", + "APIStatusError", + "APIRequestFailedError", + "APIAuthenticationError", + "APIReachLimitError", + "APIInternalError", + "APIServerFlowExceedError", + "APIResponseError", + "APIResponseValidationError", + "APITimeoutError", +] + + +class ZhipuAIError(Exception): + def __init__(self, message: str, ) -> None: + super().__init__(message) + + +class APIStatusError(Exception): + response: httpx.Response + status_code: int + + def __init__(self, message: str, *, response: httpx.Response) -> None: + super().__init__(message) + self.response = response + self.status_code = response.status_code + + +class APIRequestFailedError(APIStatusError): + ... + + +class APIAuthenticationError(APIStatusError): + ... + + +class APIReachLimitError(APIStatusError): + ... + + +class APIInternalError(APIStatusError): + ... + + +class APIServerFlowExceedError(APIStatusError): + ... 
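The status-specific classes above, together with the response and validation errors that follow, form the SDK's public exception hierarchy; _make_status_error in _http_client.py below maps 400, 401, 429, 500, and 503 responses onto them. A hedged sketch of branching on that hierarchy from application code (the import path assumes Dify's api/ tree, where the vendored package's __init__.py re-exports these names):

```python
# Sketch only: assumes the vendored SDK's __init__.py re-exports
# the client and the error types, as added earlier in this patch.
from core.model_runtime.model_providers.zhipuai.zhipuai_sdk import (
    APIAuthenticationError,
    APIReachLimitError,
    APIStatusError,
    ZhipuAI,
)

client = ZhipuAI(api_key="key_id.key_secret")  # placeholder credential
try:
    result = client.chat.completions.create(
        model="glm-4",
        messages=[{"role": "user", "content": "ping"}],
    )
except APIAuthenticationError:
    ...  # 401: invalid or expired API key
except APIReachLimitError:
    ...  # 429: rate limited, back off before retrying
except APIStatusError as err:
    print(err.status_code, err.response.text)  # any other non-2xx status
```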
+
+
+class APIResponseError(Exception):
+    message: str
+    request: httpx.Request
+    json_data: object
+
+    def __init__(self, message: str, request: httpx.Request, json_data: object):
+        self.message = message
+        self.request = request
+        self.json_data = json_data
+        super().__init__(message)
+
+
+class APIResponseValidationError(APIResponseError):
+    status_code: int
+    response: httpx.Response
+
+    def __init__(
+            self,
+            response: httpx.Response,
+            json_data: object | None, *,
+            message: str | None = None
+    ) -> None:
+        super().__init__(
+            message=message or "Data returned by API invalid for expected schema.",
+            request=response.request,
+            json_data=json_data
+        )
+        self.response = response
+        self.status_code = response.status_code
+
+
+class APITimeoutError(Exception):
+    request: httpx.Request
+
+    def __init__(self, request: httpx.Request):
+        self.request = request
+        super().__init__("Request Timeout")
diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_files.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_files.py
new file mode 100644
index 000000000..9ae372adc
--- /dev/null
+++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_files.py
@@ -0,0 +1,46 @@
+from __future__ import annotations
+
+import io
+import os
+from pathlib import Path
+from typing import Mapping, Sequence
+
+from ._base_type import (
+    FileTypes,
+    HttpxFileTypes,
+    HttpxRequestFiles,
+    RequestFiles,
+)
+
+
+def is_file_content(obj: object) -> bool:
+    return isinstance(obj, (bytes, tuple, io.IOBase, os.PathLike))
+
+
+def _transform_file(file: FileTypes) -> HttpxFileTypes:
+    if is_file_content(file):
+        if isinstance(file, os.PathLike):
+            path = Path(file)
+            return path.name, path.read_bytes()
+        else:
+            return file
+    if isinstance(file, tuple):
+        if isinstance(file[1], os.PathLike):
+            return (file[0], Path(file[1]).read_bytes(), *file[2:])
+        else:
+            return (file[0], file[1], *file[2:])
+    else:
+        raise TypeError(f"Unexpected input file with type {type(file)}, expected FileContent type or tuple type")
+
+
+def make_httpx_files(files: RequestFiles | None) -> HttpxRequestFiles | None:
+    if files is None:
+        return None
+
+    if isinstance(files, Mapping):
+        files = {key: _transform_file(file) for key, file in files.items()}
+    elif isinstance(files, Sequence):
+        files = [(key, _transform_file(file)) for key, file in files]
+    else:
+        raise TypeError(f"Unexpected input file with type {type(files)}, expected Mapping or Sequence")
+    return files
diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_http_client.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_http_client.py
new file mode 100644
index 000000000..09d8974c9
--- /dev/null
+++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_http_client.py
@@ -0,0 +1,377 @@
+# -*- coding:utf-8 -*-
+from __future__ import annotations
+
+import inspect
+from typing import (
+    Any,
+    Type,
+    Union,
+    cast,
+    Mapping,
+)
+
+import httpx
+import pydantic
+from httpx import URL, Timeout
+
+from . import _errors
+from ._base_type import NotGiven, ResponseT, Body, Headers, NOT_GIVEN, RequestFiles, Query, Data
+from ._errors import APIResponseValidationError, APIStatusError, APITimeoutError
+from ._files import make_httpx_files
+from ._request_opt import ClientRequestParam, UserRequestInput
+from ._response import HttpResponse
+from ._sse_client import StreamResponse
+from ._utils import flatten
+
+headers = {
+    "Accept": "application/json",
+    "Content-Type": "application/json; charset=UTF-8",
+}
+
+
+def _merge_map(map1: Mapping, map2: Mapping) -> Mapping:
+    merged = {**map1, **map2}
+    return {key: val for key, val in merged.items() if val is not None}
+
+
+from httpx._config import DEFAULT_TIMEOUT_CONFIG as HTTPX_DEFAULT_TIMEOUT
+
+ZHIPUAI_DEFAULT_TIMEOUT = httpx.Timeout(timeout=300.0, connect=8.0)
+ZHIPUAI_DEFAULT_MAX_RETRIES = 3
+ZHIPUAI_DEFAULT_LIMITS = httpx.Limits(max_connections=50, max_keepalive_connections=10)
+
+
+class HttpClient:
+    _client: httpx.Client
+    _version: str
+    _base_url: URL
+
+    timeout: Union[float, Timeout, None]
+    _limits: httpx.Limits
+    _has_custom_http_client: bool
+    _default_stream_cls: type[StreamResponse[Any]] | None = None
+
+    def __init__(
+            self,
+            *,
+            version: str,
+            base_url: URL,
+            timeout: Union[float, Timeout, None],
+            custom_httpx_client: httpx.Client | None = None,
+            custom_headers: Mapping[str, str] | None = None,
+    ) -> None:
+        if timeout is None or isinstance(timeout, NotGiven):
+            if custom_httpx_client and custom_httpx_client.timeout != HTTPX_DEFAULT_TIMEOUT:
+                timeout = custom_httpx_client.timeout
+            else:
+                timeout = ZHIPUAI_DEFAULT_TIMEOUT
+        self.timeout = cast(Timeout, timeout)
+        self._has_custom_http_client = bool(custom_httpx_client)
+        self._client = custom_httpx_client or httpx.Client(
+            base_url=base_url,
+            timeout=self.timeout,
+            limits=ZHIPUAI_DEFAULT_LIMITS,
+        )
+        self._version = version
+        url = URL(url=base_url)
+        if not url.raw_path.endswith(b"/"):
+            url = url.copy_with(raw_path=url.raw_path + b"/")
+        self._base_url = url
+        self._custom_headers = custom_headers or {}
+
+    def _prepare_url(self, url: str) -> URL:
+
+        sub_url = URL(url)
+        if sub_url.is_relative_url:
+            request_raw_url = self._base_url.raw_path + sub_url.raw_path.lstrip(b"/")
+            return self._base_url.copy_with(raw_path=request_raw_url)
+
+        return sub_url
+
+    @property
+    def _default_headers(self):
+        return \
+            {
+                "Accept": "application/json",
+                "Content-Type": "application/json; charset=UTF-8",
+                "ZhipuAI-SDK-Ver": self._version,
+                "source_type": "zhipu-sdk-python",
+                "x-request-sdk": "zhipu-sdk-python",
+                **self._auth_headers,
+                **self._custom_headers,
+            }
+
+    @property
+    def _auth_headers(self):
+        return {}
+
+    def _prepare_headers(self, request_param: ClientRequestParam) -> httpx.Headers:
+        custom_headers = request_param.headers or {}
+        headers_dict = _merge_map(self._default_headers, custom_headers)
+
+        httpx_headers = httpx.Headers(headers_dict)
+
+        return httpx_headers
+
+    def _prepare_request(
+            self,
+            request_param: ClientRequestParam
+    ) -> httpx.Request:
+        kwargs: dict[str, Any] = {}
+        json_data = request_param.json_data
+        headers = self._prepare_headers(request_param)
+        url = self._prepare_url(request_param.url)
+        json_data = request_param.json_data
+        if headers.get("Content-Type") == "multipart/form-data":
+            headers.pop("Content-Type")
+
+            if json_data:
+                kwargs["data"] = self._make_multipartform(json_data)
+
+        return self._client.build_request(
+            headers=headers,
+            timeout=self.timeout if isinstance(request_param.timeout, NotGiven) else request_param.timeout,
+            method=request_param.method,
+            url=url,
+            json=json_data,
+            files=request_param.files,
+            params=request_param.params,
+            **kwargs,
+        )
+
+    def _object_to_formdata(self, key: str, value: Data | Mapping[object, object]) -> list[tuple[str, str]]:
+        items = []
+
+        if isinstance(value, Mapping):
+            for k, v in value.items():
+                items.extend(self._object_to_formdata(f"{key}[{k}]", v))
+            return items
+        if isinstance(value, (list, tuple)):
+            for v in value:
+                items.extend(self._object_to_formdata(key + "[]", v))
+            return items
+
+        def _primitive_value_to_str(val) -> str:
+            # copied from httpx
+            if val is True:
+                return "true"
+            elif val is False:
+                return "false"
+            elif val is None:
+                return ""
+            return str(val)
+
+        str_data = _primitive_value_to_str(value)
+
+        if not str_data:
+            return []
+        return [(key, str_data)]
+
+    def _make_multipartform(self, data: Mapping[object, object]) -> dict[str, object]:
+
+        items = flatten([self._object_to_formdata(k, v) for k, v in data.items()])
+
+        serialized: dict[str, object] = {}
+        for key, value in items:
+            if key in serialized:
+                raise ValueError(f"Duplicate key encountered: {key}")
+            serialized[key] = value
+        return serialized
+
+    def _parse_response(
+            self,
+            *,
+            cast_type: Type[ResponseT],
+            response: httpx.Response,
+            enable_stream: bool,
+            request_param: ClientRequestParam,
+            stream_cls: type[StreamResponse[Any]] | None = None,
+    ) -> HttpResponse:
+
+        http_response = HttpResponse(
+            raw_response=response,
+            cast_type=cast_type,
+            client=self,
+            enable_stream=enable_stream,
+            stream_cls=stream_cls
+        )
+        return http_response.parse()
+
+    def _process_response_data(
+            self,
+            *,
+            data: object,
+            cast_type: type[ResponseT],
+            response: httpx.Response,
+    ) -> ResponseT:
+        if data is None:
+            return cast(ResponseT, None)
+
+        try:
+            if inspect.isclass(cast_type) and issubclass(cast_type, pydantic.BaseModel):
+                return cast(ResponseT, cast_type.validate(data))
+
+            return cast(ResponseT, pydantic.TypeAdapter(cast_type).validate_python(data))
+        except pydantic.ValidationError as err:
+            raise APIResponseValidationError(response=response, json_data=data) from err
+
+    def is_closed(self) -> bool:
+        return self._client.is_closed
+
+    def close(self):
+        self._client.close()
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.close()
+
+    def request(
+            self,
+            *,
+            cast_type: Type[ResponseT],
+            params: ClientRequestParam,
+            enable_stream: bool = False,
+            stream_cls: type[StreamResponse[Any]] | None = None,
+    ) -> ResponseT | StreamResponse:
+        request = self._prepare_request(params)
+
+        try:
+            response = self._client.send(
+                request,
+                stream=enable_stream,
+            )
+            response.raise_for_status()
+        except httpx.TimeoutException as err:
+            raise APITimeoutError(request=request) from err
+        except httpx.HTTPStatusError as err:
+            err.response.read()
+            # raise err
+            raise self._make_status_error(err.response) from None
+
+        except Exception as err:
+            raise err
+
+        return self._parse_response(
+            cast_type=cast_type,
+            request_param=params,
+            response=response,
+            enable_stream=enable_stream,
+            stream_cls=stream_cls,
+        )
+
+    def get(
+            self,
+            path: str,
+            *,
+            cast_type: Type[ResponseT],
+            options: UserRequestInput = {},
+            enable_stream: bool = False,
+    ) -> ResponseT | StreamResponse:
+        opts = ClientRequestParam.construct(method="get", url=path, **options)
+        return self.request(
+            cast_type=cast_type, params=opts,
+            enable_stream=enable_stream
+        )
+
+    def post(
+            self,
+            path: str,
+            *,
+            body: Body | None = None,
+            cast_type:
Type[ResponseT], + options: UserRequestInput = {}, + files: RequestFiles | None = None, + enable_stream: bool = False, + stream_cls: type[StreamResponse[Any]] | None = None, + ) -> ResponseT | StreamResponse: + opts = ClientRequestParam.construct(method="post", json_data=body, files=make_httpx_files(files), url=path, + **options) + + return self.request( + cast_type=cast_type, params=opts, + enable_stream=enable_stream, + stream_cls=stream_cls + ) + + def patch( + self, + path: str, + *, + body: Body | None = None, + cast_type: Type[ResponseT], + options: UserRequestInput = {}, + ) -> ResponseT: + opts = ClientRequestParam.construct(method="patch", url=path, json_data=body, **options) + + return self.request( + cast_type=cast_type, params=opts, + ) + + def put( + self, + path: str, + *, + body: Body | None = None, + cast_type: Type[ResponseT], + options: UserRequestInput = {}, + files: RequestFiles | None = None, + ) -> ResponseT | StreamResponse: + opts = ClientRequestParam.construct(method="put", url=path, json_data=body, files=make_httpx_files(files), + **options) + + return self.request( + cast_type=cast_type, params=opts, + ) + + def delete( + self, + path: str, + *, + body: Body | None = None, + cast_type: Type[ResponseT], + options: UserRequestInput = {}, + ) -> ResponseT | StreamResponse: + opts = ClientRequestParam.construct(method="delete", url=path, json_data=body, **options) + + return self.request( + cast_type=cast_type, params=opts, + ) + + def _make_status_error(self, response) -> APIStatusError: + response_text = response.text.strip() + status_code = response.status_code + error_msg = f"Error code: {status_code}, response text: {response_text}" + + if status_code == 400: + return _errors.APIRequestFailedError(message=error_msg, response=response) + elif status_code == 401: + return _errors.APIAuthenticationError(message=error_msg, response=response) + elif status_code == 429: + return _errors.APIReachLimitError(message=error_msg, response=response) + elif status_code == 500: + return _errors.APIInternalError(message=error_msg, response=response) + elif status_code == 503: + return _errors.APIServerFlowExceedError(message=error_msg, response=response) + return APIStatusError(message=error_msg, response=response) + + +def make_user_request_input( + max_retries: int | None = None, + timeout: float | Timeout | None | NotGiven = NOT_GIVEN, + extra_headers: Headers | None = None, + query: Query | None = None, +) -> UserRequestInput: + options: UserRequestInput = {} + + if extra_headers is not None: + options["headers"] = extra_headers + if max_retries is not None: + options["max_retries"] = max_retries + if not isinstance(timeout, NotGiven): + options["timeout"] = timeout + if query is not None: + options["params"] = query + + return options diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_jwt_token.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_jwt_token.py new file mode 100644 index 000000000..bbf2e72e6 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_jwt_token.py @@ -0,0 +1,30 @@ +# -*- coding:utf-8 -*- +import time + +import cachetools.func +import jwt + +API_TOKEN_TTL_SECONDS = 3 * 60 + +CACHE_TTL_SECONDS = API_TOKEN_TTL_SECONDS - 30 + + +# tokens are cached slightly shorter than their validity window, so a cached token is never served expired +@cachetools.func.ttl_cache(maxsize=10, ttl=CACHE_TTL_SECONDS) +def generate_token(apikey: str): + try: + api_key, secret = apikey.split(".") + except Exception as e: + raise Exception("invalid api_key, expected '<key id>.<secret>' format") from e + + payload = { + "api_key": api_key, + "exp":
int(round(time.time() * 1000)) + API_TOKEN_TTL_SECONDS * 1000, # note: ZhipuAI expects millisecond timestamps here, unlike standard JWT exp claims, which use seconds + "timestamp": int(round(time.time() * 1000)), + } + ret = jwt.encode( + payload, + secret, + algorithm="HS256", + headers={"alg": "HS256", "sign_type": "SIGN"}, + ) + return ret diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_request_opt.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_request_opt.py new file mode 100644 index 000000000..f2281c528 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_request_opt.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +from typing import Union, Any + +from httpx import Timeout +from pydantic import ConfigDict +from typing_extensions import ( + Unpack, ClassVar, TypedDict +) + +from ._base_type import Body, NotGiven, Headers, HttpxRequestFiles, Query +from ._utils import remove_notgiven_indict + + +class UserRequestInput(TypedDict, total=False): + max_retries: int + timeout: float | Timeout | None + headers: Headers + params: Query | None + + +class ClientRequestParam: + method: str + url: str + max_retries: Union[int, NotGiven] = NotGiven() + timeout: Union[float, NotGiven] = NotGiven() + headers: Union[Headers, NotGiven] = NotGiven() + json_data: Union[Body, None] = None + files: Union[HttpxRequestFiles, None] = None + params: Query = {} + model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True) + + def get_max_retries(self, max_retries) -> int: + if isinstance(self.max_retries, NotGiven): + return max_retries + return self.max_retries + + @classmethod + def construct( # type: ignore + cls, + _fields_set: set[str] | None = None, + **values: Unpack[UserRequestInput], + ) -> ClientRequestParam: + kwargs: dict[str, Any] = { + key: remove_notgiven_indict(value) for key, value in values.items() + } + client = cls() + client.__dict__.update(kwargs) + + return client + + model_construct = construct + diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_response.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_response.py new file mode 100644 index 000000000..86ce50d9f --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_response.py @@ -0,0 +1,121 @@ +from __future__ import annotations + +import datetime +from typing import TypeVar, Generic, cast, Any, TYPE_CHECKING + +import httpx +import pydantic +from typing_extensions import ParamSpec, get_origin, get_args + +from ._base_type import NoneType +from ._sse_client import StreamResponse + +if TYPE_CHECKING: + from ._http_client import HttpClient + +P = ParamSpec("P") +R = TypeVar("R") + + +class HttpResponse(Generic[R]): + _cast_type: type[R] + _client: "HttpClient" + _parsed: R | None + _enable_stream: bool + _stream_cls: type[StreamResponse[Any]] | None + http_response: httpx.Response + + def __init__( + self, + *, + raw_response: httpx.Response, + cast_type: type[R], + client: "HttpClient", + enable_stream: bool = False, + stream_cls: type[StreamResponse[Any]] | None = None, + ) -> None: + self._cast_type = cast_type + self._client = client + self._parsed = None + self._stream_cls = stream_cls + self._enable_stream = enable_stream + self.http_response = raw_response + + def parse(self) -> R: + self._parsed = self._parse() + return self._parsed + + def _parse(self) -> R: + if self._enable_stream: + self._parsed = cast( + R, + self._stream_cls( + cast_type=cast(type, get_args(self._stream_cls)[0]), +
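# note: get_args(...) extracts the concrete chunk model from the parameterized StreamResponse type, so each decoded SSE payload is validated against it +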
response=self.http_response, + client=self._client + ) + ) + return self._parsed + cast_type = self._cast_type + if cast_type is NoneType: + return cast(R, None) + http_response = self.http_response + if cast_type == str: + return cast(R, http_response.text) + + content_type, *_ = http_response.headers.get("content-type", "application/json").split(";") + origin = get_origin(cast_type) or cast_type + if content_type != "application/json": + if issubclass(origin, pydantic.BaseModel): + data = http_response.json() + return self._client._process_response_data( + data=data, + cast_type=cast_type, # type: ignore + response=http_response, + ) + + return http_response.text + + data = http_response.json() + + return self._client._process_response_data( + data=data, + cast_type=cast_type, # type: ignore + response=http_response, + ) + + @property + def headers(self) -> httpx.Headers: + return self.http_response.headers + + @property + def http_request(self) -> httpx.Request: + return self.http_response.request + + @property + def status_code(self) -> int: + return self.http_response.status_code + + @property + def url(self) -> httpx.URL: + return self.http_response.url + + @property + def method(self) -> str: + return self.http_request.method + + @property + def content(self) -> bytes: + return self.http_response.content + + @property + def text(self) -> str: + return self.http_response.text + + @property + def http_version(self) -> str: + return self.http_response.http_version + + @property + def elapsed(self) -> datetime.timedelta: + return self.http_response.elapsed diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_sse_client.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_sse_client.py new file mode 100644 index 000000000..83b487e87 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_sse_client.py @@ -0,0 +1,149 @@ +# -*- coding:utf-8 -*- +from __future__ import annotations + +import json +from typing import Generic, Iterator, TYPE_CHECKING, Mapping + +import httpx + +from ._base_type import ResponseT +from ._errors import APIResponseError + +_FIELD_SEPARATOR = ":" + +if TYPE_CHECKING: + from ._http_client import HttpClient + + +class StreamResponse(Generic[ResponseT]): + + response: httpx.Response + _cast_type: type[ResponseT] + + def __init__( + self, + *, + cast_type: type[ResponseT], + response: httpx.Response, + client: HttpClient, + ) -> None: + self.response = response + self._cast_type = cast_type + self._data_process_func = client._process_response_data + self._stream_chunks = self.__stream__() + + def __next__(self) -> ResponseT: + return self._stream_chunks.__next__() + + def __iter__(self) -> Iterator[ResponseT]: + for item in self._stream_chunks: + yield item + + def __stream__(self) -> Iterator[ResponseT]: + + sse_line_parser = SSELineParser() + iterator = sse_line_parser.iter_lines(self.response.iter_lines()) + + for sse in iterator: + if sse.data.startswith("[DONE]"): + break + + if sse.event is None: + data = sse.json_data() + if isinstance(data, Mapping) and data.get("error"): + raise APIResponseError( + message="An error occurred during streaming", + request=self.response.request, + json_data=data["error"], + ) + + yield self._data_process_func(data=data, cast_type=self._cast_type, response=self.response) + for sse in iterator: + pass + + +class Event(object): + def __init__( + self, + event: str | None = None, + data: str | None = None, + id: str | None = None, + retry: int | None = None + ): + 
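# a single parsed server-sent event; multi-line data fields have already been joined with "\n" by the parser below +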
self._event = event + self._data = data + self._id = id + self._retry = retry + + def __repr__(self): + data_len = len(self._data) if self._data else 0 + return f"Event(event={self._event}, data={self._data}, data_length={data_len}, id={self._id}, retry={self._retry})" + + @property + def event(self): return self._event + + @property + def data(self): return self._data + + def json_data(self): return json.loads(self._data) + + @property + def id(self): return self._id + + @property + def retry(self): return self._retry + + +class SSELineParser: + _data: list[str] + _event: str | None + _retry: int | None + _id: str | None + + def __init__(self): + self._event = None + self._data = [] + self._id = None + self._retry = None + + def iter_lines(self, lines: Iterator[str]) -> Iterator[Event]: + for line in lines: + line = line.rstrip('\n') + if not line: + if self._event is None and \ + not self._data and \ + self._id is None and \ + self._retry is None: + continue + sse_event = Event( + event=self._event, + data='\n'.join(self._data), + id=self._id, + retry=self._retry + ) + self._event = None + self._data = [] + self._id = None + self._retry = None + + yield sse_event + self.decode_line(line) + + def decode_line(self, line: str): + if line.startswith(":") or not line: + return + + field, _p, value = line.partition(":") + + if value.startswith(' '): + value = value[1:] + if field == "data": + self._data.append(value) + elif field == "event": + self._event = value + elif field == "id": + self._id = value + elif field == "retry": + try: + self._retry = int(value) + except (TypeError, ValueError): + pass diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_utils.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_utils.py new file mode 100644 index 000000000..6193edcbe --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_utils.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +from typing import Mapping, Iterable, TypeVar + +from ._base_type import NotGiven + + +def remove_notgiven_indict(obj): + if obj is None or (not isinstance(obj, Mapping)): + return obj + return {key: value for key, value in obj.items() if not isinstance(value, NotGiven)} + + +_T = TypeVar("_T") + + +def flatten(t: Iterable[Iterable[_T]]) -> list[_T]: + return [item for sublist in t for item in sublist] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/__init__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/__init__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/async_chat_completion.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/async_chat_completion.py new file mode 100644 index 000000000..bae4197c5 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/async_chat_completion.py @@ -0,0 +1,23 @@ +from typing import List, Optional + +from pydantic import BaseModel + +from .chat_completion import CompletionChoice, CompletionUsage + +__all__ = ["AsyncTaskStatus", "AsyncCompletion"] + + +class AsyncTaskStatus(BaseModel): + id: Optional[str] = None + request_id: Optional[str] = None + model: Optional[str] = None + task_status: Optional[str] = None + + +# AsyncCompletion mirrors Completion but carries task_status so callers can poll asynchronous results +class
AsyncCompletion(BaseModel): + id: Optional[str] = None + request_id: Optional[str] = None + model: Optional[str] = None + task_status: str + choices: List[CompletionChoice] + usage: CompletionUsage \ No newline at end of file diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion.py new file mode 100644 index 000000000..524e218d3 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion.py @@ -0,0 +1,45 @@ +from typing import List, Optional + +from pydantic import BaseModel + +__all__ = ["Completion", "CompletionUsage"] + + +class Function(BaseModel): + arguments: str + name: str + + +class CompletionMessageToolCall(BaseModel): + id: str + function: Function + type: str + + +class CompletionMessage(BaseModel): + content: Optional[str] = None + role: str + tool_calls: Optional[List[CompletionMessageToolCall]] = None + + +class CompletionUsage(BaseModel): + prompt_tokens: int + completion_tokens: int + total_tokens: int + + +class CompletionChoice(BaseModel): + index: int + finish_reason: str + message: CompletionMessage + + +class Completion(BaseModel): + model: Optional[str] = None + created: Optional[int] = None + choices: List[CompletionChoice] + request_id: Optional[str] = None + id: Optional[str] = None + usage: CompletionUsage + + diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion_chunk.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion_chunk.py new file mode 100644 index 000000000..c2e0c5766 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion_chunk.py @@ -0,0 +1,55 @@ +from typing import List, Optional + +from pydantic import BaseModel + +__all__ = [ + "ChatCompletionChunk", + "Choice", + "ChoiceDelta", + "ChoiceDeltaFunctionCall", + "ChoiceDeltaToolCall", + "ChoiceDeltaToolCallFunction", +] + + +class ChoiceDeltaFunctionCall(BaseModel): + arguments: Optional[str] = None + name: Optional[str] = None + + +class ChoiceDeltaToolCallFunction(BaseModel): + arguments: Optional[str] = None + name: Optional[str] = None + + +class ChoiceDeltaToolCall(BaseModel): + index: int + id: Optional[str] = None + function: Optional[ChoiceDeltaToolCallFunction] = None + type: Optional[str] = None + + +class ChoiceDelta(BaseModel): + content: Optional[str] = None + role: Optional[str] = None + tool_calls: Optional[List[ChoiceDeltaToolCall]] = None + + +class Choice(BaseModel): + delta: ChoiceDelta + finish_reason: Optional[str] = None + index: int + + +class CompletionUsage(BaseModel): + prompt_tokens: int + completion_tokens: int + total_tokens: int + + +class ChatCompletionChunk(BaseModel): + id: Optional[str] = None + choices: List[Choice] + created: Optional[int] = None + model: Optional[str] = None + usage: Optional[CompletionUsage] = None diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completions_create_param.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completions_create_param.py new file mode 100644 index 000000000..6ee4dc479 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completions_create_param.py @@ -0,0 +1,8 @@ +from typing import Optional + +from typing_extensions import TypedDict + + +class Reference(TypedDict, total=False): + enable: Optional[bool] 
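+ # assumed semantics (not documented in this diff): when enable is true, search_query optionally overrides the query used for web retrieval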
+ search_query: Optional[str] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/embeddings.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/embeddings.py new file mode 100644 index 000000000..9f52d296d --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/embeddings.py @@ -0,0 +1,20 @@ +from __future__ import annotations + +from typing import Optional, List + +from pydantic import BaseModel +from .chat.chat_completion import CompletionUsage +__all__ = ["Embedding", "EmbeddingsResponded"] + + +class Embedding(BaseModel): + object: str + index: Optional[int] = None + embedding: List[float] + + +class EmbeddingsResponded(BaseModel): + object: str + data: List[Embedding] + model: str + usage: CompletionUsage diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/file_object.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/file_object.py new file mode 100644 index 000000000..39599786e --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/file_object.py @@ -0,0 +1,24 @@ +from typing import Optional, List + +from pydantic import BaseModel + +__all__ = ["FileObject"] + + +class FileObject(BaseModel): + + id: Optional[str] = None + bytes: Optional[int] = None + created_at: Optional[int] = None + filename: Optional[str] = None + object: Optional[str] = None + purpose: Optional[str] = None + status: Optional[str] = None + status_details: Optional[str] = None + + +class ListOfFileObject(BaseModel): + + object: Optional[str] = None + data: List[FileObject] + has_more: Optional[bool] = None diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/__init__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/__init__.py new file mode 100644 index 000000000..af0991892 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/__init__.py @@ -0,0 +1,5 @@ +from __future__ import annotations + +from .fine_tuning_job import FineTuningJob as FineTuningJob +from .fine_tuning_job import ListOfFineTuningJob as ListOfFineTuningJob +from .fine_tuning_job_event import FineTuningJobEvent as FineTuningJobEvent diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job.py new file mode 100644 index 000000000..c41fd5f24 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job.py @@ -0,0 +1,52 @@ +from typing import List, Union, Optional +from typing_extensions import Literal + +from pydantic import BaseModel + +__all__ = ["FineTuningJob", "Error", "Hyperparameters", "ListOfFineTuningJob" ] + + +class Error(BaseModel): + code: str + message: str + param: Optional[str] = None + + +class Hyperparameters(BaseModel): + n_epochs: Union[str, int, None] = None + + +class FineTuningJob(BaseModel): + id: Optional[str] = None + + request_id: Optional[str] = None + + created_at: Optional[int] = None + + error: Optional[Error] = None + + fine_tuned_model: Optional[str] = None + + finished_at: Optional[int] = None + + hyperparameters: Optional[Hyperparameters] = None + + model: Optional[str] = None + + object: Optional[str] = None + + result_files: List[str] + + status: str + + trained_tokens: Optional[int] = None + + training_file: str + + validation_file: Optional[str] = None + + 
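+# A minimal parsing sketch with hypothetical values (pydantic v2 is assumed, matching the TypeAdapter usage in the client core): +# FineTuningJob.model_validate({"id": "ftjob-123", "status": "running", "training_file": "file-abc", "result_files": []}) + 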
+class ListOfFineTuningJob(BaseModel): + object: Optional[str] = None + data: List[FineTuningJob] + has_more: Optional[bool] = None diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job_event.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job_event.py new file mode 100644 index 000000000..fd2a4138a --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job_event.py @@ -0,0 +1,36 @@ +from typing import List, Union, Optional +from typing_extensions import Literal + +from pydantic import BaseModel + +__all__ = ["FineTuningJobEvent", "Metric", "JobEvent"] + + +class Metric(BaseModel): + epoch: Optional[Union[str, int, float]] = None + current_steps: Optional[int] = None + total_steps: Optional[int] = None + elapsed_time: Optional[str] = None + remaining_time: Optional[str] = None + trained_tokens: Optional[int] = None + loss: Optional[Union[str, int, float]] = None + eval_loss: Optional[Union[str, int, float]] = None + acc: Optional[Union[str, int, float]] = None + eval_acc: Optional[Union[str, int, float]] = None + learning_rate: Optional[Union[str, int, float]] = None + + +class JobEvent(BaseModel): + object: Optional[str] = None + id: Optional[str] = None + type: Optional[str] = None + created_at: Optional[int] = None + level: Optional[str] = None + message: Optional[str] = None + data: Optional[Metric] = None + + +class FineTuningJobEvent(BaseModel): + object: Optional[str] = None + data: List[JobEvent] + has_more: Optional[bool] = None diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/job_create_params.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/job_create_params.py new file mode 100644 index 000000000..c661f7cdd --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/job_create_params.py @@ -0,0 +1,15 @@ +from __future__ import annotations + +from typing import Union + +from typing_extensions import Literal, TypedDict + +__all__ = ["Hyperparameters"] + + +class Hyperparameters(TypedDict, total=False): + batch_size: Union[Literal["auto"], int] + + learning_rate_multiplier: Union[Literal["auto"], float] + + n_epochs: Union[Literal["auto"], int] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/image.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/image.py new file mode 100644 index 000000000..681942c84 --- /dev/null +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/image.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +from typing import Optional, List + +from pydantic import BaseModel + +__all__ = ["GeneratedImage", "ImagesResponded"] + + +class GeneratedImage(BaseModel): + b64_json: Optional[str] = None + url: Optional[str] = None + revised_prompt: Optional[str] = None + + +class ImagesResponded(BaseModel): + created: int + data: List[GeneratedImage] diff --git a/api/tests/integration_tests/model_runtime/zhipuai/test_llm.py b/api/tests/integration_tests/model_runtime/zhipuai/test_llm.py index e5a3d0ad1..a04e607a5 100644 --- a/api/tests/integration_tests/model_runtime/zhipuai/test_llm.py +++ b/api/tests/integration_tests/model_runtime/zhipuai/test_llm.py @@ -3,7 +3,8 @@ from typing import Generator import pytest from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta -from 
core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage +from core.model_runtime.entities.message_entities import (AssistantPromptMessage, SystemPromptMessage, + UserPromptMessage, PromptMessageTool) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.zhipuai.llm.llm import ZhipuAILargeLanguageModel @@ -102,3 +103,48 @@ def test_get_num_tokens(): ) assert num_tokens == 14 + +def test_get_tools_num_tokens(): + model = ZhipuAILargeLanguageModel() + + num_tokens = model.get_num_tokens( + model='tools', + credentials={ + 'api_key': os.environ.get('ZHIPUAI_API_KEY') + }, + tools=[ + PromptMessageTool( + name='get_current_weather', + description='Get the current weather in a given location', + parameters={ + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state e.g. San Francisco, CA" + }, + "unit": { + "type": "string", + "enum": [ + "c", + "f" + ] + } + }, + "required": [ + "location" + ] + } + ) + ], + prompt_messages=[ + SystemPromptMessage( + content='You are a helpful AI assistant.', + ), + UserPromptMessage( + content='Hello World!' + ) + ] + ) + + assert num_tokens == 108 \ No newline at end of file diff --git a/api/tests/integration_tests/model_runtime/zhipuai/test_text_embedding.py b/api/tests/integration_tests/model_runtime/zhipuai/test_text_embedding.py index 30453eafb..e8589350f 100644 --- a/api/tests/integration_tests/model_runtime/zhipuai/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/zhipuai/test_text_embedding.py @@ -42,7 +42,7 @@ def test_invoke_model(): assert isinstance(result, TextEmbeddingResult) assert len(result.embeddings) == 2 - assert result.usage.total_tokens == 2 + assert result.usage.total_tokens > 0 def test_get_num_tokens(): diff --git a/web/app/components/app/chat/answer/index.tsx b/web/app/components/app/chat/answer/index.tsx index 05ea6f9d8..b8242bd21 100644 --- a/web/app/components/app/chat/answer/index.tsx +++ b/web/app/components/app/chat/answer/index.tsx @@ -229,7 +229,7 @@ const Answer: FC = ({ )} diff --git a/web/app/components/app/configuration/index.tsx b/web/app/components/app/configuration/index.tsx index 60ee53b38..1f4c4c418 100644 --- a/web/app/components/app/configuration/index.tsx +++ b/web/app/components/app/configuration/index.tsx @@ -43,7 +43,7 @@ import { fetchDatasets } from '@/service/datasets' import { useProviderContext } from '@/context/provider-context' import { AgentStrategy, AppType, ModelModeType, RETRIEVE_TYPE, Resolution, TransferMethod } from '@/types/app' import { PromptMode } from '@/models/debug' -import { ANNOTATION_DEFAULT, DEFAULT_AGENT_SETTING, DEFAULT_CHAT_PROMPT_CONFIG, DEFAULT_COMPLETION_PROMPT_CONFIG } from '@/config' +import { ANNOTATION_DEFAULT, DEFAULT_AGENT_SETTING, DEFAULT_CHAT_PROMPT_CONFIG, DEFAULT_COMPLETION_PROMPT_CONFIG, supportFunctionCallModels } from '@/config' import SelectDataSet from '@/app/components/app/configuration/dataset-config/select-dataset' import I18n from '@/context/i18n' import { useModalContext } from '@/context/modal-context' @@ -163,8 +163,7 @@ const Configuration: FC = () => { doSetModelConfig(newModelConfig) } const isOpenAI = modelConfig.provider === 'openai' - const isFunctionCall = isOpenAI && modelConfig.mode === ModelModeType.chat - + const isFunctionCall = (isOpenAI && modelConfig.mode === ModelModeType.chat) || supportFunctionCallModels.includes(modelConfig.model_id) const 
[collectionList, setCollectionList] = useState([]) useEffect(() => { diff --git a/web/config/index.ts b/web/config/index.ts index 35cd3e565..197af0a9a 100644 --- a/web/config/index.ts +++ b/web/config/index.ts @@ -160,6 +160,8 @@ export const DEFAULT_AGENT_SETTING = { tools: [], } +export const supportFunctionCallModels = ['glm-3-turbo', 'glm-4'] + export const DEFAULT_AGENT_PROMPT = { chat: `Respond to the human as helpfully and accurately as possible.