mirror of
https://gitee.com/dify_ai/dify.git
synced 2024-12-02 11:18:19 +08:00
feat(api): support wenxin text embedding (#7377)
This commit is contained in:
parent
a0a67873aa
commit
bfd905602f
195
api/core/model_runtime/model_providers/wenxin/_common.py
Normal file
195
api/core/model_runtime/model_providers/wenxin/_common.py
Normal file
@ -0,0 +1,195 @@
|
||||
from datetime import datetime, timedelta
|
||||
from threading import Lock
|
||||
|
||||
from requests import post
|
||||
|
||||
from core.model_runtime.model_providers.wenxin.wenxin_errors import (
|
||||
BadRequestError,
|
||||
InternalServerError,
|
||||
InvalidAPIKeyError,
|
||||
InvalidAuthenticationError,
|
||||
RateLimitReachedError,
|
||||
)
|
||||
|
||||
baidu_access_tokens: dict[str, 'BaiduAccessToken'] = {}
|
||||
baidu_access_tokens_lock = Lock()
|
||||
|
||||
|
||||
class BaiduAccessToken:
|
||||
api_key: str
|
||||
access_token: str
|
||||
expires: datetime
|
||||
|
||||
def __init__(self, api_key: str) -> None:
|
||||
self.api_key = api_key
|
||||
self.access_token = ''
|
||||
self.expires = datetime.now() + timedelta(days=3)
|
||||
|
||||
@staticmethod
|
||||
def _get_access_token(api_key: str, secret_key: str) -> str:
|
||||
"""
|
||||
request access token from Baidu
|
||||
"""
|
||||
try:
|
||||
response = post(
|
||||
url=f'https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id={api_key}&client_secret={secret_key}',
|
||||
headers={
|
||||
'Content-Type': 'application/json',
|
||||
'Accept': 'application/json'
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise InvalidAuthenticationError(f'Failed to get access token from Baidu: {e}')
|
||||
|
||||
resp = response.json()
|
||||
if 'error' in resp:
|
||||
if resp['error'] == 'invalid_client':
|
||||
raise InvalidAPIKeyError(f'Invalid API key or secret key: {resp["error_description"]}')
|
||||
elif resp['error'] == 'unknown_error':
|
||||
raise InternalServerError(f'Internal server error: {resp["error_description"]}')
|
||||
elif resp['error'] == 'invalid_request':
|
||||
raise BadRequestError(f'Bad request: {resp["error_description"]}')
|
||||
elif resp['error'] == 'rate_limit_exceeded':
|
||||
raise RateLimitReachedError(f'Rate limit reached: {resp["error_description"]}')
|
||||
else:
|
||||
raise Exception(f'Unknown error: {resp["error_description"]}')
|
||||
|
||||
return resp['access_token']
|
||||
|
||||
@staticmethod
|
||||
def get_access_token(api_key: str, secret_key: str) -> 'BaiduAccessToken':
|
||||
"""
|
||||
LLM from Baidu requires access token to invoke the API.
|
||||
however, we have api_key and secret_key, and access token is valid for 30 days.
|
||||
so we can cache the access token for 3 days. (avoid memory leak)
|
||||
|
||||
it may be more efficient to use a ticker to refresh access token, but it will cause
|
||||
more complexity, so we just refresh access tokens when get_access_token is called.
|
||||
"""
|
||||
|
||||
# loop up cache, remove expired access token
|
||||
baidu_access_tokens_lock.acquire()
|
||||
now = datetime.now()
|
||||
for key in list(baidu_access_tokens.keys()):
|
||||
token = baidu_access_tokens[key]
|
||||
if token.expires < now:
|
||||
baidu_access_tokens.pop(key)
|
||||
|
||||
if api_key not in baidu_access_tokens:
|
||||
# if access token not in cache, request it
|
||||
token = BaiduAccessToken(api_key)
|
||||
baidu_access_tokens[api_key] = token
|
||||
# release it to enhance performance
|
||||
# btw, _get_access_token will raise exception if failed, release lock here to avoid deadlock
|
||||
baidu_access_tokens_lock.release()
|
||||
# try to get access token
|
||||
token_str = BaiduAccessToken._get_access_token(api_key, secret_key)
|
||||
token.access_token = token_str
|
||||
token.expires = now + timedelta(days=3)
|
||||
return token
|
||||
else:
|
||||
# if access token in cache, return it
|
||||
token = baidu_access_tokens[api_key]
|
||||
baidu_access_tokens_lock.release()
|
||||
return token
|
||||
|
||||
|
||||
class _CommonWenxin:
|
||||
api_bases = {
|
||||
'ernie-bot': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205',
|
||||
'ernie-bot-4': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro',
|
||||
'ernie-bot-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions',
|
||||
'ernie-bot-turbo': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant',
|
||||
'ernie-3.5-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions',
|
||||
'ernie-3.5-8k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-0205',
|
||||
'ernie-3.5-8k-1222': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-1222',
|
||||
'ernie-3.5-4k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205',
|
||||
'ernie-3.5-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-128k',
|
||||
'ernie-4.0-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro',
|
||||
'ernie-4.0-8k-latest': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro',
|
||||
'ernie-speed-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed',
|
||||
'ernie-speed-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-speed-128k',
|
||||
'ernie-speed-appbuilder': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ai_apaas',
|
||||
'ernie-lite-8k-0922': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant',
|
||||
'ernie-lite-8k-0308': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-lite-8k',
|
||||
'ernie-character-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k',
|
||||
'ernie-character-8k-0321': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k',
|
||||
'ernie-4.0-turbo-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k',
|
||||
'ernie-4.0-turbo-8k-preview': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k-preview',
|
||||
'yi_34b_chat': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/yi_34b_chat',
|
||||
'embedding-v1': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1',
|
||||
}
|
||||
|
||||
function_calling_supports = [
|
||||
'ernie-bot',
|
||||
'ernie-bot-8k',
|
||||
'ernie-3.5-8k',
|
||||
'ernie-3.5-8k-0205',
|
||||
'ernie-3.5-8k-1222',
|
||||
'ernie-3.5-4k-0205',
|
||||
'ernie-3.5-128k',
|
||||
'ernie-4.0-8k',
|
||||
'ernie-4.0-turbo-8k',
|
||||
'ernie-4.0-turbo-8k-preview',
|
||||
'yi_34b_chat'
|
||||
]
|
||||
|
||||
api_key: str = ''
|
||||
secret_key: str = ''
|
||||
|
||||
def __init__(self, api_key: str, secret_key: str):
|
||||
self.api_key = api_key
|
||||
self.secret_key = secret_key
|
||||
|
||||
@staticmethod
|
||||
def _to_credential_kwargs(credentials: dict) -> dict:
|
||||
credentials_kwargs = {
|
||||
"api_key": credentials['api_key'],
|
||||
"secret_key": credentials['secret_key']
|
||||
}
|
||||
return credentials_kwargs
|
||||
|
||||
def _handle_error(self, code: int, msg: str):
|
||||
error_map = {
|
||||
1: InternalServerError,
|
||||
2: InternalServerError,
|
||||
3: BadRequestError,
|
||||
4: RateLimitReachedError,
|
||||
6: InvalidAuthenticationError,
|
||||
13: InvalidAPIKeyError,
|
||||
14: InvalidAPIKeyError,
|
||||
15: InvalidAPIKeyError,
|
||||
17: RateLimitReachedError,
|
||||
18: RateLimitReachedError,
|
||||
19: RateLimitReachedError,
|
||||
100: InvalidAPIKeyError,
|
||||
111: InvalidAPIKeyError,
|
||||
200: InternalServerError,
|
||||
336000: InternalServerError,
|
||||
336001: BadRequestError,
|
||||
336002: BadRequestError,
|
||||
336003: BadRequestError,
|
||||
336004: InvalidAuthenticationError,
|
||||
336005: InvalidAPIKeyError,
|
||||
336006: BadRequestError,
|
||||
336007: BadRequestError,
|
||||
336008: BadRequestError,
|
||||
336100: InternalServerError,
|
||||
336101: BadRequestError,
|
||||
336102: BadRequestError,
|
||||
336103: BadRequestError,
|
||||
336104: BadRequestError,
|
||||
336105: BadRequestError,
|
||||
336200: InternalServerError,
|
||||
336303: BadRequestError,
|
||||
337006: BadRequestError
|
||||
}
|
||||
|
||||
if code in error_map:
|
||||
raise error_map[code](msg)
|
||||
else:
|
||||
raise InternalServerError(f'Unknown error: {msg}')
|
||||
|
||||
def _get_access_token(self) -> str:
|
||||
token = BaiduAccessToken.get_access_token(self.api_key, self.secret_key)
|
||||
return token.access_token
|
@ -1,102 +1,17 @@
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime, timedelta
|
||||
from enum import Enum
|
||||
from json import dumps, loads
|
||||
from threading import Lock
|
||||
from typing import Any, Union
|
||||
|
||||
from requests import Response, post
|
||||
|
||||
from core.model_runtime.entities.message_entities import PromptMessageTool
|
||||
from core.model_runtime.model_providers.wenxin.llm.ernie_bot_errors import (
|
||||
from core.model_runtime.model_providers.wenxin._common import _CommonWenxin
|
||||
from core.model_runtime.model_providers.wenxin.wenxin_errors import (
|
||||
BadRequestError,
|
||||
InternalServerError,
|
||||
InvalidAPIKeyError,
|
||||
InvalidAuthenticationError,
|
||||
RateLimitReachedError,
|
||||
)
|
||||
|
||||
# map api_key to access_token
|
||||
baidu_access_tokens: dict[str, 'BaiduAccessToken'] = {}
|
||||
baidu_access_tokens_lock = Lock()
|
||||
|
||||
class BaiduAccessToken:
|
||||
api_key: str
|
||||
access_token: str
|
||||
expires: datetime
|
||||
|
||||
def __init__(self, api_key: str) -> None:
|
||||
self.api_key = api_key
|
||||
self.access_token = ''
|
||||
self.expires = datetime.now() + timedelta(days=3)
|
||||
|
||||
def _get_access_token(api_key: str, secret_key: str) -> str:
|
||||
"""
|
||||
request access token from Baidu
|
||||
"""
|
||||
try:
|
||||
response = post(
|
||||
url=f'https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id={api_key}&client_secret={secret_key}',
|
||||
headers={
|
||||
'Content-Type': 'application/json',
|
||||
'Accept': 'application/json'
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
raise InvalidAuthenticationError(f'Failed to get access token from Baidu: {e}')
|
||||
|
||||
resp = response.json()
|
||||
if 'error' in resp:
|
||||
if resp['error'] == 'invalid_client':
|
||||
raise InvalidAPIKeyError(f'Invalid API key or secret key: {resp["error_description"]}')
|
||||
elif resp['error'] == 'unknown_error':
|
||||
raise InternalServerError(f'Internal server error: {resp["error_description"]}')
|
||||
elif resp['error'] == 'invalid_request':
|
||||
raise BadRequestError(f'Bad request: {resp["error_description"]}')
|
||||
elif resp['error'] == 'rate_limit_exceeded':
|
||||
raise RateLimitReachedError(f'Rate limit reached: {resp["error_description"]}')
|
||||
else:
|
||||
raise Exception(f'Unknown error: {resp["error_description"]}')
|
||||
|
||||
return resp['access_token']
|
||||
|
||||
@staticmethod
|
||||
def get_access_token(api_key: str, secret_key: str) -> 'BaiduAccessToken':
|
||||
"""
|
||||
LLM from Baidu requires access token to invoke the API.
|
||||
however, we have api_key and secret_key, and access token is valid for 30 days.
|
||||
so we can cache the access token for 3 days. (avoid memory leak)
|
||||
|
||||
it may be more efficient to use a ticker to refresh access token, but it will cause
|
||||
more complexity, so we just refresh access tokens when get_access_token is called.
|
||||
"""
|
||||
|
||||
# loop up cache, remove expired access token
|
||||
baidu_access_tokens_lock.acquire()
|
||||
now = datetime.now()
|
||||
for key in list(baidu_access_tokens.keys()):
|
||||
token = baidu_access_tokens[key]
|
||||
if token.expires < now:
|
||||
baidu_access_tokens.pop(key)
|
||||
|
||||
if api_key not in baidu_access_tokens:
|
||||
# if access token not in cache, request it
|
||||
token = BaiduAccessToken(api_key)
|
||||
baidu_access_tokens[api_key] = token
|
||||
# release it to enhance performance
|
||||
# btw, _get_access_token will raise exception if failed, release lock here to avoid deadlock
|
||||
baidu_access_tokens_lock.release()
|
||||
# try to get access token
|
||||
token_str = BaiduAccessToken._get_access_token(api_key, secret_key)
|
||||
token.access_token = token_str
|
||||
token.expires = now + timedelta(days=3)
|
||||
return token
|
||||
else:
|
||||
# if access token in cache, return it
|
||||
token = baidu_access_tokens[api_key]
|
||||
baidu_access_tokens_lock.release()
|
||||
return token
|
||||
|
||||
|
||||
class ErnieMessage:
|
||||
class Role(Enum):
|
||||
@ -120,51 +35,7 @@ class ErnieMessage:
|
||||
self.content = content
|
||||
self.role = role
|
||||
|
||||
class ErnieBotModel:
|
||||
api_bases = {
|
||||
'ernie-bot': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205',
|
||||
'ernie-bot-4': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro',
|
||||
'ernie-bot-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions',
|
||||
'ernie-bot-turbo': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant',
|
||||
'ernie-3.5-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions',
|
||||
'ernie-3.5-8k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-0205',
|
||||
'ernie-3.5-8k-1222': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-1222',
|
||||
'ernie-3.5-4k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205',
|
||||
'ernie-3.5-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-128k',
|
||||
'ernie-4.0-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro',
|
||||
'ernie-4.0-8k-latest': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro',
|
||||
'ernie-speed-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed',
|
||||
'ernie-speed-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-speed-128k',
|
||||
'ernie-speed-appbuilder': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ai_apaas',
|
||||
'ernie-lite-8k-0922': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant',
|
||||
'ernie-lite-8k-0308': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-lite-8k',
|
||||
'ernie-character-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k',
|
||||
'ernie-character-8k-0321': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k',
|
||||
'ernie-4.0-turbo-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k',
|
||||
'ernie-4.0-turbo-8k-preview': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k-preview',
|
||||
'yi_34b_chat': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/yi_34b_chat',
|
||||
}
|
||||
|
||||
function_calling_supports = [
|
||||
'ernie-bot',
|
||||
'ernie-bot-8k',
|
||||
'ernie-3.5-8k',
|
||||
'ernie-3.5-8k-0205',
|
||||
'ernie-3.5-8k-1222',
|
||||
'ernie-3.5-4k-0205',
|
||||
'ernie-3.5-128k',
|
||||
'ernie-4.0-8k',
|
||||
'ernie-4.0-turbo-8k',
|
||||
'ernie-4.0-turbo-8k-preview',
|
||||
'yi_34b_chat'
|
||||
]
|
||||
|
||||
api_key: str = ''
|
||||
secret_key: str = ''
|
||||
|
||||
def __init__(self, api_key: str, secret_key: str):
|
||||
self.api_key = api_key
|
||||
self.secret_key = secret_key
|
||||
class ErnieBotModel(_CommonWenxin):
|
||||
|
||||
def generate(self, model: str, stream: bool, messages: list[ErnieMessage],
|
||||
parameters: dict[str, Any], timeout: int, tools: list[PromptMessageTool], \
|
||||
@ -199,51 +70,6 @@ class ErnieBotModel:
|
||||
return self._handle_chat_stream_generate_response(resp)
|
||||
return self._handle_chat_generate_response(resp)
|
||||
|
||||
def _handle_error(self, code: int, msg: str):
|
||||
error_map = {
|
||||
1: InternalServerError,
|
||||
2: InternalServerError,
|
||||
3: BadRequestError,
|
||||
4: RateLimitReachedError,
|
||||
6: InvalidAuthenticationError,
|
||||
13: InvalidAPIKeyError,
|
||||
14: InvalidAPIKeyError,
|
||||
15: InvalidAPIKeyError,
|
||||
17: RateLimitReachedError,
|
||||
18: RateLimitReachedError,
|
||||
19: RateLimitReachedError,
|
||||
100: InvalidAPIKeyError,
|
||||
111: InvalidAPIKeyError,
|
||||
200: InternalServerError,
|
||||
336000: InternalServerError,
|
||||
336001: BadRequestError,
|
||||
336002: BadRequestError,
|
||||
336003: BadRequestError,
|
||||
336004: InvalidAuthenticationError,
|
||||
336005: InvalidAPIKeyError,
|
||||
336006: BadRequestError,
|
||||
336007: BadRequestError,
|
||||
336008: BadRequestError,
|
||||
336100: InternalServerError,
|
||||
336101: BadRequestError,
|
||||
336102: BadRequestError,
|
||||
336103: BadRequestError,
|
||||
336104: BadRequestError,
|
||||
336105: BadRequestError,
|
||||
336200: InternalServerError,
|
||||
336303: BadRequestError,
|
||||
337006: BadRequestError
|
||||
}
|
||||
|
||||
if code in error_map:
|
||||
raise error_map[code](msg)
|
||||
else:
|
||||
raise InternalServerError(f'Unknown error: {msg}')
|
||||
|
||||
def _get_access_token(self) -> str:
|
||||
token = BaiduAccessToken.get_access_token(self.api_key, self.secret_key)
|
||||
return token.access_token
|
||||
|
||||
def _copy_messages(self, messages: list[ErnieMessage]) -> list[ErnieMessage]:
|
||||
return [ErnieMessage(message.content, message.role) for message in messages]
|
||||
|
||||
|
@ -1,17 +0,0 @@
|
||||
class InvalidAuthenticationError(Exception):
|
||||
pass
|
||||
|
||||
class InvalidAPIKeyError(Exception):
|
||||
pass
|
||||
|
||||
class RateLimitReachedError(Exception):
|
||||
pass
|
||||
|
||||
class InsufficientAccountBalance(Exception):
|
||||
pass
|
||||
|
||||
class InternalServerError(Exception):
|
||||
pass
|
||||
|
||||
class BadRequestError(Exception):
|
||||
pass
|
@ -11,24 +11,13 @@ from core.model_runtime.entities.message_entities import (
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.errors.invoke import (
|
||||
InvokeAuthorizationError,
|
||||
InvokeBadRequestError,
|
||||
InvokeConnectionError,
|
||||
InvokeError,
|
||||
InvokeRateLimitError,
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.model_providers.wenxin.llm.ernie_bot import BaiduAccessToken, ErnieBotModel, ErnieMessage
|
||||
from core.model_runtime.model_providers.wenxin.llm.ernie_bot_errors import (
|
||||
BadRequestError,
|
||||
InsufficientAccountBalance,
|
||||
InternalServerError,
|
||||
InvalidAPIKeyError,
|
||||
InvalidAuthenticationError,
|
||||
RateLimitReachedError,
|
||||
)
|
||||
from core.model_runtime.model_providers.wenxin._common import BaiduAccessToken
|
||||
from core.model_runtime.model_providers.wenxin.llm.ernie_bot import ErnieBotModel, ErnieMessage
|
||||
from core.model_runtime.model_providers.wenxin.wenxin_errors import invoke_error_mapping
|
||||
|
||||
ERNIE_BOT_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
|
||||
The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure
|
||||
@ -140,7 +129,7 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel):
|
||||
api_key = credentials['api_key']
|
||||
secret_key = credentials['secret_key']
|
||||
try:
|
||||
BaiduAccessToken._get_access_token(api_key, secret_key)
|
||||
BaiduAccessToken.get_access_token(api_key, secret_key)
|
||||
except Exception as e:
|
||||
raise CredentialsValidateFailedError(f'Credentials validation failed: {e}')
|
||||
|
||||
@ -254,22 +243,4 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel):
|
||||
|
||||
:return: Invoke error mapping
|
||||
"""
|
||||
return {
|
||||
InvokeConnectionError: [
|
||||
],
|
||||
InvokeServerUnavailableError: [
|
||||
InternalServerError
|
||||
],
|
||||
InvokeRateLimitError: [
|
||||
RateLimitReachedError
|
||||
],
|
||||
InvokeAuthorizationError: [
|
||||
InvalidAuthenticationError,
|
||||
InsufficientAccountBalance,
|
||||
InvalidAPIKeyError,
|
||||
],
|
||||
InvokeBadRequestError: [
|
||||
BadRequestError,
|
||||
KeyError
|
||||
]
|
||||
}
|
||||
return invoke_error_mapping()
|
||||
|
@ -0,0 +1,9 @@
|
||||
model: embedding-v1
|
||||
model_type: text-embedding
|
||||
model_properties:
|
||||
context_size: 384
|
||||
max_chunks: 16
|
||||
pricing:
|
||||
input: '0.0005'
|
||||
unit: '0.001'
|
||||
currency: RMB
|
@ -0,0 +1,184 @@
|
||||
import time
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Mapping
|
||||
from json import dumps
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
from requests import Response, post
|
||||
|
||||
from core.model_runtime.entities.model_entities import PriceType
|
||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from core.model_runtime.model_providers.wenxin._common import BaiduAccessToken, _CommonWenxin
|
||||
from core.model_runtime.model_providers.wenxin.wenxin_errors import (
|
||||
BadRequestError,
|
||||
InternalServerError,
|
||||
invoke_error_mapping,
|
||||
)
|
||||
|
||||
|
||||
class TextEmbedding:
|
||||
@abstractmethod
|
||||
def embed_documents(self, model: str, texts: list[str], user: str) -> (list[list[float]], int, int):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class WenxinTextEmbedding(_CommonWenxin, TextEmbedding):
|
||||
def embed_documents(self, model: str, texts: list[str], user: str) -> (list[list[float]], int, int):
|
||||
access_token = self._get_access_token()
|
||||
url = f'{self.api_bases[model]}?access_token={access_token}'
|
||||
body = self._build_embed_request_body(model, texts, user)
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
}
|
||||
|
||||
resp = post(url, data=dumps(body), headers=headers)
|
||||
if resp.status_code != 200:
|
||||
raise InternalServerError(f'Failed to invoke ernie bot: {resp.text}')
|
||||
return self._handle_embed_response(model, resp)
|
||||
|
||||
def _build_embed_request_body(self, model: str, texts: list[str], user: str) -> dict[str, Any]:
|
||||
if len(texts) == 0:
|
||||
raise BadRequestError('The number of texts should not be zero.')
|
||||
body = {
|
||||
'input': texts,
|
||||
'user_id': user,
|
||||
}
|
||||
return body
|
||||
|
||||
def _handle_embed_response(self, model: str, response: Response) -> (list[list[float]], int, int):
|
||||
data = response.json()
|
||||
if 'error_code' in data:
|
||||
code = data['error_code']
|
||||
msg = data['error_msg']
|
||||
# raise error
|
||||
self._handle_error(code, msg)
|
||||
|
||||
embeddings = [v['embedding'] for v in data['data']]
|
||||
_usage = data['usage']
|
||||
tokens = _usage['prompt_tokens']
|
||||
total_tokens = _usage['total_tokens']
|
||||
|
||||
return embeddings, tokens, total_tokens
|
||||
|
||||
|
||||
class WenxinTextEmbeddingModel(TextEmbeddingModel):
|
||||
def _create_text_embedding(self, api_key: str, secret_key: str) -> TextEmbedding:
|
||||
return WenxinTextEmbedding(api_key, secret_key)
|
||||
|
||||
def _invoke(self, model: str, credentials: dict, texts: list[str],
|
||||
user: Optional[str] = None) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:param user: unique user id
|
||||
:return: embeddings result
|
||||
"""
|
||||
|
||||
api_key = credentials['api_key']
|
||||
secret_key = credentials['secret_key']
|
||||
embedding: TextEmbedding = self._create_text_embedding(api_key, secret_key)
|
||||
user = user if user else 'ErnieBotDefault'
|
||||
|
||||
context_size = self._get_context_size(model, credentials)
|
||||
max_chunks = self._get_max_chunks(model, credentials)
|
||||
inputs = []
|
||||
indices = []
|
||||
used_tokens = 0
|
||||
used_total_tokens = 0
|
||||
|
||||
for i, text in enumerate(texts):
|
||||
|
||||
# Here token count is only an approximation based on the GPT2 tokenizer
|
||||
num_tokens = self._get_num_tokens_by_gpt2(text)
|
||||
|
||||
if num_tokens >= context_size:
|
||||
cutoff = int(np.floor(len(text) * (context_size / num_tokens)))
|
||||
# if num tokens is larger than context length, only use the start
|
||||
inputs.append(text[0:cutoff])
|
||||
else:
|
||||
inputs.append(text)
|
||||
indices += [i]
|
||||
|
||||
batched_embeddings = []
|
||||
_iter = range(0, len(inputs), max_chunks)
|
||||
for i in _iter:
|
||||
embeddings_batch, _used_tokens, _total_used_tokens = embedding.embed_documents(
|
||||
model,
|
||||
inputs[i: i + max_chunks],
|
||||
user)
|
||||
used_tokens += _used_tokens
|
||||
used_total_tokens += _total_used_tokens
|
||||
batched_embeddings += embeddings_batch
|
||||
|
||||
usage = self._calc_response_usage(model, credentials, used_tokens, used_total_tokens)
|
||||
return TextEmbeddingResult(
|
||||
model=model,
|
||||
embeddings=batched_embeddings,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:return:
|
||||
"""
|
||||
if len(texts) == 0:
|
||||
return 0
|
||||
total_num_tokens = 0
|
||||
for text in texts:
|
||||
total_num_tokens += self._get_num_tokens_by_gpt2(text)
|
||||
|
||||
return total_num_tokens
|
||||
|
||||
def validate_credentials(self, model: str, credentials: Mapping) -> None:
|
||||
api_key = credentials['api_key']
|
||||
secret_key = credentials['secret_key']
|
||||
try:
|
||||
BaiduAccessToken.get_access_token(api_key, secret_key)
|
||||
except Exception as e:
|
||||
raise CredentialsValidateFailedError(f'Credentials validation failed: {e}')
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
return invoke_error_mapping()
|
||||
|
||||
def _calc_response_usage(self, model: str, credentials: dict, tokens: int, total_tokens: int) -> EmbeddingUsage:
|
||||
"""
|
||||
Calculate response usage
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param tokens: input tokens
|
||||
:return: usage
|
||||
"""
|
||||
# get input price info
|
||||
input_price_info = self.get_price(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
price_type=PriceType.INPUT,
|
||||
tokens=tokens
|
||||
)
|
||||
|
||||
# transform usage
|
||||
usage = EmbeddingUsage(
|
||||
tokens=tokens,
|
||||
total_tokens=total_tokens,
|
||||
unit_price=input_price_info.unit_price,
|
||||
price_unit=input_price_info.unit,
|
||||
total_price=input_price_info.total_amount,
|
||||
currency=input_price_info.currency,
|
||||
latency=time.perf_counter() - self.started_at
|
||||
)
|
||||
|
||||
return usage
|
@ -17,6 +17,7 @@ help:
|
||||
en_US: https://cloud.baidu.com/wenxin.html
|
||||
supported_model_types:
|
||||
- llm
|
||||
- text-embedding
|
||||
configurate_methods:
|
||||
- predefined-model
|
||||
provider_credential_schema:
|
||||
|
@ -0,0 +1,57 @@
|
||||
from core.model_runtime.errors.invoke import (
|
||||
InvokeAuthorizationError,
|
||||
InvokeBadRequestError,
|
||||
InvokeConnectionError,
|
||||
InvokeError,
|
||||
InvokeRateLimitError,
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
|
||||
|
||||
def invoke_error_mapping() -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
The value is the error type thrown by the model,
|
||||
which needs to be converted into a unified error type for the caller.
|
||||
|
||||
:return: Invoke error mapping
|
||||
"""
|
||||
return {
|
||||
InvokeConnectionError: [
|
||||
],
|
||||
InvokeServerUnavailableError: [
|
||||
InternalServerError
|
||||
],
|
||||
InvokeRateLimitError: [
|
||||
RateLimitReachedError
|
||||
],
|
||||
InvokeAuthorizationError: [
|
||||
InvalidAuthenticationError,
|
||||
InsufficientAccountBalance,
|
||||
InvalidAPIKeyError,
|
||||
],
|
||||
InvokeBadRequestError: [
|
||||
BadRequestError,
|
||||
KeyError
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
class InvalidAuthenticationError(Exception):
|
||||
pass
|
||||
|
||||
class InvalidAPIKeyError(Exception):
|
||||
pass
|
||||
|
||||
class RateLimitReachedError(Exception):
|
||||
pass
|
||||
|
||||
class InsufficientAccountBalance(Exception):
|
||||
pass
|
||||
|
||||
class InternalServerError(Exception):
|
||||
pass
|
||||
|
||||
class BadRequestError(Exception):
|
||||
pass
|
@ -0,0 +1,24 @@
|
||||
import os
|
||||
from time import sleep
|
||||
|
||||
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
|
||||
from core.model_runtime.model_providers.wenxin.text_embedding.text_embedding import WenxinTextEmbeddingModel
|
||||
|
||||
|
||||
def test_invoke_embedding_model():
|
||||
sleep(3)
|
||||
model = WenxinTextEmbeddingModel()
|
||||
|
||||
response = model.invoke(
|
||||
model='embedding-v1',
|
||||
credentials={
|
||||
'api_key': os.environ.get('WENXIN_API_KEY'),
|
||||
'secret_key': os.environ.get('WENXIN_SECRET_KEY')
|
||||
},
|
||||
texts=['hello', '你好', 'xxxxx'],
|
||||
user="abc-123"
|
||||
)
|
||||
|
||||
assert isinstance(response, TextEmbeddingResult)
|
||||
assert len(response.embeddings) == 3
|
||||
assert isinstance(response.embeddings[0], list)
|
0
api/tests/unit_tests/core/model_runtime/__init__.py
Normal file
0
api/tests/unit_tests/core/model_runtime/__init__.py
Normal file
@ -0,0 +1,75 @@
|
||||
import numpy as np
|
||||
|
||||
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
|
||||
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
|
||||
from core.model_runtime.model_providers.wenxin.text_embedding.text_embedding import (
|
||||
TextEmbedding,
|
||||
WenxinTextEmbeddingModel,
|
||||
)
|
||||
|
||||
|
||||
def test_max_chunks():
|
||||
class _MockTextEmbedding(TextEmbedding):
|
||||
def embed_documents(self, model: str, texts: list[str], user: str) -> (list[list[float]], int, int):
|
||||
embeddings = [[1.0, 2.0, 3.0] for i in range(len(texts))]
|
||||
tokens = 0
|
||||
for text in texts:
|
||||
tokens += len(text)
|
||||
|
||||
return embeddings, tokens, tokens
|
||||
|
||||
def _create_text_embedding(api_key: str, secret_key: str) -> TextEmbedding:
|
||||
return _MockTextEmbedding()
|
||||
|
||||
model = 'embedding-v1'
|
||||
credentials = {
|
||||
'api_key': 'xxxx',
|
||||
'secret_key': 'yyyy',
|
||||
}
|
||||
embedding_model = WenxinTextEmbeddingModel()
|
||||
context_size = embedding_model._get_context_size(model, credentials)
|
||||
max_chunks = embedding_model._get_max_chunks(model, credentials)
|
||||
embedding_model._create_text_embedding = _create_text_embedding
|
||||
|
||||
texts = ['0123456789' for i in range(0, max_chunks * 2)]
|
||||
result: TextEmbeddingResult = embedding_model.invoke(model, credentials, texts, 'test')
|
||||
assert len(result.embeddings) == max_chunks * 2
|
||||
|
||||
|
||||
def test_context_size():
|
||||
def get_num_tokens_by_gpt2(text: str) -> int:
|
||||
return GPT2Tokenizer.get_num_tokens(text)
|
||||
|
||||
def mock_text(token_size: int) -> str:
|
||||
_text = "".join(['0' for i in range(token_size)])
|
||||
num_tokens = get_num_tokens_by_gpt2(_text)
|
||||
ratio = int(np.floor(len(_text) / num_tokens))
|
||||
m_text = "".join([_text for i in range(ratio)])
|
||||
return m_text
|
||||
|
||||
model = 'embedding-v1'
|
||||
credentials = {
|
||||
'api_key': 'xxxx',
|
||||
'secret_key': 'yyyy',
|
||||
}
|
||||
embedding_model = WenxinTextEmbeddingModel()
|
||||
context_size = embedding_model._get_context_size(model, credentials)
|
||||
|
||||
class _MockTextEmbedding(TextEmbedding):
|
||||
def embed_documents(self, model: str, texts: list[str], user: str) -> (list[list[float]], int, int):
|
||||
embeddings = [[1.0, 2.0, 3.0] for i in range(len(texts))]
|
||||
tokens = 0
|
||||
for text in texts:
|
||||
tokens += get_num_tokens_by_gpt2(text)
|
||||
return embeddings, tokens, tokens
|
||||
|
||||
def _create_text_embedding(api_key: str, secret_key: str) -> TextEmbedding:
|
||||
return _MockTextEmbedding()
|
||||
|
||||
embedding_model._create_text_embedding = _create_text_embedding
|
||||
text = mock_text(context_size * 2)
|
||||
assert get_num_tokens_by_gpt2(text) == context_size * 2
|
||||
|
||||
texts = [text]
|
||||
result: TextEmbeddingResult = embedding_model.invoke(model, credentials, texts, 'test')
|
||||
assert result.usage.tokens == context_size
|
Loading…
Reference in New Issue
Block a user