feat(api): support wenxin text embedding (#7377)

This commit is contained in:
Chengyu Yan 2024-08-19 09:15:19 +08:00 committed by GitHub
parent a0a67873aa
commit bfd905602f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 553 additions and 228 deletions

View File

@ -0,0 +1,195 @@
from datetime import datetime, timedelta
from threading import Lock
from requests import post
from core.model_runtime.model_providers.wenxin.wenxin_errors import (
BadRequestError,
InternalServerError,
InvalidAPIKeyError,
InvalidAuthenticationError,
RateLimitReachedError,
)
# Process-wide cache of access tokens keyed by api_key.
# Every read or write of the dict must happen while holding the lock below.
baidu_access_tokens: dict[str, 'BaiduAccessToken'] = {}
baidu_access_tokens_lock = Lock()


class BaiduAccessToken:
    """An access token issued by Baidu for one api_key, plus its local cache expiry."""

    api_key: str
    access_token: str
    expires: datetime

    def __init__(self, api_key: str) -> None:
        self.api_key = api_key
        self.access_token = ''
        # Default cache lifetime; get_access_token overwrites it once the token is fetched.
        self.expires = datetime.now() + timedelta(days=3)

    @staticmethod
    def _get_access_token(api_key: str, secret_key: str) -> str:
        """
        Request an access token from Baidu.

        :param api_key: sent as the OAuth ``client_id``
        :param secret_key: sent as the OAuth ``client_secret``
        :return: the ``access_token`` field of the JSON response
        :raises InvalidAuthenticationError: the HTTP request itself failed
        :raises InvalidAPIKeyError: Baidu answered ``invalid_client``
        :raises InternalServerError: Baidu answered ``unknown_error``
        :raises BadRequestError: Baidu answered ``invalid_request``
        :raises RateLimitReachedError: Baidu answered ``rate_limit_exceeded``
        :raises Exception: Baidu answered with any other error code
        """
        try:
            response = post(
                url='https://aip.baidubce.com/oauth/2.0/token',
                # Let requests build the query string so the credentials are URL-encoded.
                params={
                    'grant_type': 'client_credentials',
                    'client_id': api_key,
                    'client_secret': secret_key,
                },
                headers={
                    'Content-Type': 'application/json',
                    'Accept': 'application/json'
                },
            )
        except Exception as e:
            raise InvalidAuthenticationError(f'Failed to get access token from Baidu: {e}') from e

        resp = response.json()
        if 'error' in resp:
            error = resp['error']
            # A missing description must not hide the real error behind a KeyError.
            description = resp.get('error_description', error)
            if error == 'invalid_client':
                raise InvalidAPIKeyError(f'Invalid API key or secret key: {description}')
            elif error == 'unknown_error':
                raise InternalServerError(f'Internal server error: {description}')
            elif error == 'invalid_request':
                raise BadRequestError(f'Bad request: {description}')
            elif error == 'rate_limit_exceeded':
                raise RateLimitReachedError(f'Rate limit reached: {description}')
            else:
                raise Exception(f'Unknown error: {description}')

        return resp['access_token']

    @staticmethod
    def get_access_token(api_key: str, secret_key: str) -> 'BaiduAccessToken':
        """
        LLM from Baidu requires access token to invoke the API.
        however, we have api_key and secret_key, and access token is valid for 30 days.
        so we can cache the access token for 3 days. (avoid memory leak)

        it may be more efficient to use a ticker to refresh access token, but it will cause
        more complexity, so we just refresh access tokens when get_access_token is called.

        A token is stored in the cache only after Baidu has issued it, so a failed
        request never leaves an empty token behind for later callers.

        :param api_key: cache key and OAuth client id
        :param secret_key: OAuth client secret, used only when the cache misses
        :return: the cached or freshly fetched token
        :raises Exception: whatever ``_get_access_token`` raises on a cache miss
        """
        now = datetime.now()

        with baidu_access_tokens_lock:
            # look up cache, remove expired access tokens
            for key in list(baidu_access_tokens.keys()):
                if baidu_access_tokens[key].expires < now:
                    baidu_access_tokens.pop(key)

            cached = baidu_access_tokens.get(api_key)
            if cached is not None:
                return cached

        # The lock is not held during the network call so other api_keys are not blocked;
        # if the call raises, nothing has been written to the cache.
        token_str = BaiduAccessToken._get_access_token(api_key, secret_key)

        token = BaiduAccessToken(api_key)
        token.access_token = token_str
        token.expires = now + timedelta(days=3)

        with baidu_access_tokens_lock:
            # If another thread fetched the same key meanwhile, keep the entry stored first.
            return baidu_access_tokens.setdefault(api_key, token)
class _CommonWenxin:
api_bases = {
'ernie-bot': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205',
'ernie-bot-4': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro',
'ernie-bot-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions',
'ernie-bot-turbo': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant',
'ernie-3.5-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions',
'ernie-3.5-8k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-0205',
'ernie-3.5-8k-1222': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-1222',
'ernie-3.5-4k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205',
'ernie-3.5-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-128k',
'ernie-4.0-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro',
'ernie-4.0-8k-latest': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro',
'ernie-speed-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed',
'ernie-speed-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-speed-128k',
'ernie-speed-appbuilder': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ai_apaas',
'ernie-lite-8k-0922': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant',
'ernie-lite-8k-0308': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-lite-8k',
'ernie-character-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k',
'ernie-character-8k-0321': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k',
'ernie-4.0-turbo-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k',
'ernie-4.0-turbo-8k-preview': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k-preview',
'yi_34b_chat': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/yi_34b_chat',
'embedding-v1': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1',
}
function_calling_supports = [
'ernie-bot',
'ernie-bot-8k',
'ernie-3.5-8k',
'ernie-3.5-8k-0205',
'ernie-3.5-8k-1222',
'ernie-3.5-4k-0205',
'ernie-3.5-128k',
'ernie-4.0-8k',
'ernie-4.0-turbo-8k',
'ernie-4.0-turbo-8k-preview',
'yi_34b_chat'
]
api_key: str = ''
secret_key: str = ''
def __init__(self, api_key: str, secret_key: str):
self.api_key = api_key
self.secret_key = secret_key
@staticmethod
def _to_credential_kwargs(credentials: dict) -> dict:
credentials_kwargs = {
"api_key": credentials['api_key'],
"secret_key": credentials['secret_key']
}
return credentials_kwargs
def _handle_error(self, code: int, msg: str):
error_map = {
1: InternalServerError,
2: InternalServerError,
3: BadRequestError,
4: RateLimitReachedError,
6: InvalidAuthenticationError,
13: InvalidAPIKeyError,
14: InvalidAPIKeyError,
15: InvalidAPIKeyError,
17: RateLimitReachedError,
18: RateLimitReachedError,
19: RateLimitReachedError,
100: InvalidAPIKeyError,
111: InvalidAPIKeyError,
200: InternalServerError,
336000: InternalServerError,
336001: BadRequestError,
336002: BadRequestError,
336003: BadRequestError,
336004: InvalidAuthenticationError,
336005: InvalidAPIKeyError,
336006: BadRequestError,
336007: BadRequestError,
336008: BadRequestError,
336100: InternalServerError,
336101: BadRequestError,
336102: BadRequestError,
336103: BadRequestError,
336104: BadRequestError,
336105: BadRequestError,
336200: InternalServerError,
336303: BadRequestError,
337006: BadRequestError
}
if code in error_map:
raise error_map[code](msg)
else:
raise InternalServerError(f'Unknown error: {msg}')
def _get_access_token(self) -> str:
token = BaiduAccessToken.get_access_token(self.api_key, self.secret_key)
return token.access_token

View File

@ -1,102 +1,17 @@
from collections.abc import Generator
from datetime import datetime, timedelta
from enum import Enum
from json import dumps, loads
from threading import Lock
from typing import Any, Union
from requests import Response, post
from core.model_runtime.entities.message_entities import PromptMessageTool
from core.model_runtime.model_providers.wenxin.llm.ernie_bot_errors import (
from core.model_runtime.model_providers.wenxin._common import _CommonWenxin
from core.model_runtime.model_providers.wenxin.wenxin_errors import (
BadRequestError,
InternalServerError,
InvalidAPIKeyError,
InvalidAuthenticationError,
RateLimitReachedError,
)
# map api_key to access_token
baidu_access_tokens: dict[str, 'BaiduAccessToken'] = {}
baidu_access_tokens_lock = Lock()
class BaiduAccessToken:
api_key: str
access_token: str
expires: datetime
def __init__(self, api_key: str) -> None:
self.api_key = api_key
self.access_token = ''
self.expires = datetime.now() + timedelta(days=3)
def _get_access_token(api_key: str, secret_key: str) -> str:
"""
request access token from Baidu
"""
try:
response = post(
url=f'https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id={api_key}&client_secret={secret_key}',
headers={
'Content-Type': 'application/json',
'Accept': 'application/json'
},
)
except Exception as e:
raise InvalidAuthenticationError(f'Failed to get access token from Baidu: {e}')
resp = response.json()
if 'error' in resp:
if resp['error'] == 'invalid_client':
raise InvalidAPIKeyError(f'Invalid API key or secret key: {resp["error_description"]}')
elif resp['error'] == 'unknown_error':
raise InternalServerError(f'Internal server error: {resp["error_description"]}')
elif resp['error'] == 'invalid_request':
raise BadRequestError(f'Bad request: {resp["error_description"]}')
elif resp['error'] == 'rate_limit_exceeded':
raise RateLimitReachedError(f'Rate limit reached: {resp["error_description"]}')
else:
raise Exception(f'Unknown error: {resp["error_description"]}')
return resp['access_token']
@staticmethod
def get_access_token(api_key: str, secret_key: str) -> 'BaiduAccessToken':
"""
LLM from Baidu requires access token to invoke the API.
however, we have api_key and secret_key, and access token is valid for 30 days.
so we can cache the access token for 3 days. (avoid memory leak)
it may be more efficient to use a ticker to refresh access token, but it will cause
more complexity, so we just refresh access tokens when get_access_token is called.
"""
# loop up cache, remove expired access token
baidu_access_tokens_lock.acquire()
now = datetime.now()
for key in list(baidu_access_tokens.keys()):
token = baidu_access_tokens[key]
if token.expires < now:
baidu_access_tokens.pop(key)
if api_key not in baidu_access_tokens:
# if access token not in cache, request it
token = BaiduAccessToken(api_key)
baidu_access_tokens[api_key] = token
# release it to enhance performance
# btw, _get_access_token will raise exception if failed, release lock here to avoid deadlock
baidu_access_tokens_lock.release()
# try to get access token
token_str = BaiduAccessToken._get_access_token(api_key, secret_key)
token.access_token = token_str
token.expires = now + timedelta(days=3)
return token
else:
# if access token in cache, return it
token = baidu_access_tokens[api_key]
baidu_access_tokens_lock.release()
return token
class ErnieMessage:
class Role(Enum):
@ -120,51 +35,7 @@ class ErnieMessage:
self.content = content
self.role = role
class ErnieBotModel:
api_bases = {
'ernie-bot': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205',
'ernie-bot-4': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro',
'ernie-bot-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions',
'ernie-bot-turbo': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant',
'ernie-3.5-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions',
'ernie-3.5-8k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-0205',
'ernie-3.5-8k-1222': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-1222',
'ernie-3.5-4k-0205': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205',
'ernie-3.5-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-128k',
'ernie-4.0-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro',
'ernie-4.0-8k-latest': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro',
'ernie-speed-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed',
'ernie-speed-128k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-speed-128k',
'ernie-speed-appbuilder': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ai_apaas',
'ernie-lite-8k-0922': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant',
'ernie-lite-8k-0308': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-lite-8k',
'ernie-character-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k',
'ernie-character-8k-0321': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k',
'ernie-4.0-turbo-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k',
'ernie-4.0-turbo-8k-preview': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k-preview',
'yi_34b_chat': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/yi_34b_chat',
}
function_calling_supports = [
'ernie-bot',
'ernie-bot-8k',
'ernie-3.5-8k',
'ernie-3.5-8k-0205',
'ernie-3.5-8k-1222',
'ernie-3.5-4k-0205',
'ernie-3.5-128k',
'ernie-4.0-8k',
'ernie-4.0-turbo-8k',
'ernie-4.0-turbo-8k-preview',
'yi_34b_chat'
]
api_key: str = ''
secret_key: str = ''
def __init__(self, api_key: str, secret_key: str):
self.api_key = api_key
self.secret_key = secret_key
class ErnieBotModel(_CommonWenxin):
def generate(self, model: str, stream: bool, messages: list[ErnieMessage],
parameters: dict[str, Any], timeout: int, tools: list[PromptMessageTool], \
@ -199,51 +70,6 @@ class ErnieBotModel:
return self._handle_chat_stream_generate_response(resp)
return self._handle_chat_generate_response(resp)
def _handle_error(self, code: int, msg: str):
error_map = {
1: InternalServerError,
2: InternalServerError,
3: BadRequestError,
4: RateLimitReachedError,
6: InvalidAuthenticationError,
13: InvalidAPIKeyError,
14: InvalidAPIKeyError,
15: InvalidAPIKeyError,
17: RateLimitReachedError,
18: RateLimitReachedError,
19: RateLimitReachedError,
100: InvalidAPIKeyError,
111: InvalidAPIKeyError,
200: InternalServerError,
336000: InternalServerError,
336001: BadRequestError,
336002: BadRequestError,
336003: BadRequestError,
336004: InvalidAuthenticationError,
336005: InvalidAPIKeyError,
336006: BadRequestError,
336007: BadRequestError,
336008: BadRequestError,
336100: InternalServerError,
336101: BadRequestError,
336102: BadRequestError,
336103: BadRequestError,
336104: BadRequestError,
336105: BadRequestError,
336200: InternalServerError,
336303: BadRequestError,
337006: BadRequestError
}
if code in error_map:
raise error_map[code](msg)
else:
raise InternalServerError(f'Unknown error: {msg}')
def _get_access_token(self) -> str:
token = BaiduAccessToken.get_access_token(self.api_key, self.secret_key)
return token.access_token
def _copy_messages(self, messages: list[ErnieMessage]) -> list[ErnieMessage]:
return [ErnieMessage(message.content, message.role) for message in messages]

View File

@ -1,17 +0,0 @@
class InvalidAuthenticationError(Exception):
pass
class InvalidAPIKeyError(Exception):
pass
class RateLimitReachedError(Exception):
pass
class InsufficientAccountBalance(Exception):
pass
class InternalServerError(Exception):
pass
class BadRequestError(Exception):
pass

View File

@ -11,24 +11,13 @@ from core.model_runtime.entities.message_entities import (
UserPromptMessage,
)
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.wenxin.llm.ernie_bot import BaiduAccessToken, ErnieBotModel, ErnieMessage
from core.model_runtime.model_providers.wenxin.llm.ernie_bot_errors import (
BadRequestError,
InsufficientAccountBalance,
InternalServerError,
InvalidAPIKeyError,
InvalidAuthenticationError,
RateLimitReachedError,
)
from core.model_runtime.model_providers.wenxin._common import BaiduAccessToken
from core.model_runtime.model_providers.wenxin.llm.ernie_bot import ErnieBotModel, ErnieMessage
from core.model_runtime.model_providers.wenxin.wenxin_errors import invoke_error_mapping
ERNIE_BOT_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure
@ -140,7 +129,7 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel):
api_key = credentials['api_key']
secret_key = credentials['secret_key']
try:
BaiduAccessToken._get_access_token(api_key, secret_key)
BaiduAccessToken.get_access_token(api_key, secret_key)
except Exception as e:
raise CredentialsValidateFailedError(f'Credentials validation failed: {e}')
@ -254,22 +243,4 @@ class ErnieBotLargeLanguageModel(LargeLanguageModel):
:return: Invoke error mapping
"""
return {
InvokeConnectionError: [
],
InvokeServerUnavailableError: [
InternalServerError
],
InvokeRateLimitError: [
RateLimitReachedError
],
InvokeAuthorizationError: [
InvalidAuthenticationError,
InsufficientAccountBalance,
InvalidAPIKeyError,
],
InvokeBadRequestError: [
BadRequestError,
KeyError
]
}
return invoke_error_mapping()

View File

@ -0,0 +1,9 @@
model: embedding-v1
model_type: text-embedding
model_properties:
context_size: 384
max_chunks: 16
pricing:
input: '0.0005'
unit: '0.001'
currency: RMB

View File

@ -0,0 +1,184 @@
import time
from abc import abstractmethod
from collections.abc import Mapping
from json import dumps
from typing import Any, Optional
import numpy as np
from requests import Response, post
from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.invoke import InvokeError
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.wenxin._common import BaiduAccessToken, _CommonWenxin
from core.model_runtime.model_providers.wenxin.wenxin_errors import (
BadRequestError,
InternalServerError,
invoke_error_mapping,
)
class TextEmbedding:
    """
    Interface for a backend that embeds a batch of texts.

    NOTE: the class does not derive from ``abc.ABC``, so ``@abstractmethod`` is only a
    marker here; an un-overridden call fails with NotImplementedError at call time.
    """

    @abstractmethod
    def embed_documents(self, model: str, texts: list[str], user: str) -> tuple[list[list[float]], int, int]:
        """
        Embed ``texts`` with ``model`` on behalf of ``user``.

        :param model: model name
        :param texts: texts to embed
        :param user: user identifier forwarded to the backend
        :return: ``(embeddings, tokens, total_tokens)`` with one embedding per input text
        :raises NotImplementedError: always, unless overridden by a subclass
        """
        raise NotImplementedError
class WenxinTextEmbedding(_CommonWenxin, TextEmbedding):
    """TextEmbedding implementation that posts batches to the Wenxin embeddings endpoint."""

    def embed_documents(self, model: str, texts: list[str], user: str) -> tuple[list[list[float]], int, int]:
        """
        Embed one batch of texts through the HTTP API.

        :param model: model name; must be a key of ``api_bases``
        :param texts: non-empty batch of texts to embed
        :param user: user identifier sent as ``user_id``
        :return: ``(embeddings, prompt_tokens, total_tokens)``
        :raises BadRequestError: ``texts`` is empty
        :raises InternalServerError: the endpoint answered with a non-200 status
        """
        # Build (and validate) the payload first so an empty batch is rejected
        # before any network round trip is made for the access token.
        body = self._build_embed_request_body(model, texts, user)

        access_token = self._get_access_token()
        url = f'{self.api_bases[model]}?access_token={access_token}'
        headers = {
            'Content-Type': 'application/json',
        }

        resp = post(url, data=dumps(body), headers=headers)
        if resp.status_code != 200:
            raise InternalServerError(f'Failed to invoke ernie bot: {resp.text}')
        return self._handle_embed_response(model, resp)

    def _build_embed_request_body(self, model: str, texts: list[str], user: str) -> dict[str, Any]:
        """
        Build the JSON payload for an embedding request.

        :param model: model name (not used in the payload)
        :param texts: texts to embed, sent as ``input``
        :param user: user identifier, sent as ``user_id``
        :raises BadRequestError: ``texts`` is empty
        """
        if not texts:
            raise BadRequestError('The number of texts should not be zero.')
        return {
            'input': texts,
            'user_id': user,
        }

    def _handle_embed_response(self, model: str, response: Response) -> tuple[list[list[float]], int, int]:
        """
        Decode an embedding response.

        :param model: model name (not used while decoding)
        :param response: HTTP response whose JSON body holds ``data`` and ``usage``
        :return: ``(embeddings, prompt_tokens, total_tokens)``
        """
        data = response.json()
        if 'error_code' in data:
            # _handle_error raises for every code, so nothing below runs on this path.
            self._handle_error(data['error_code'], data['error_msg'])

        embeddings = [item['embedding'] for item in data['data']]
        usage = data['usage']
        return embeddings, usage['prompt_tokens'], usage['total_tokens']
class WenxinTextEmbeddingModel(TextEmbeddingModel):
    """Text-embedding model for the Wenxin provider."""

    def _create_text_embedding(self, api_key: str, secret_key: str) -> TextEmbedding:
        """Create the embedding client; a separate method so it can be replaced by a stub."""
        return WenxinTextEmbedding(api_key, secret_key)

    def _invoke(self, model: str, credentials: dict, texts: list[str],
                user: Optional[str] = None) -> TextEmbeddingResult:
        """
        Invoke text embedding model

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :param user: unique user id
        :return: embeddings result
        """
        api_key = credentials['api_key']
        secret_key = credentials['secret_key']
        embedding: TextEmbedding = self._create_text_embedding(api_key, secret_key)
        user = user if user else 'ErnieBotDefault'

        context_size = self._get_context_size(model, credentials)
        max_chunks = self._get_max_chunks(model, credentials)

        inputs = []
        for text in texts:
            # Here token count is only an approximation based on the GPT2 tokenizer
            num_tokens = self._get_num_tokens_by_gpt2(text)

            if num_tokens >= context_size:
                # if num tokens is larger than context length, only use the start;
                # the cut point scales the character length by the token ratio.
                cutoff = int(np.floor(len(text) * (context_size / num_tokens)))
                inputs.append(text[0:cutoff])
            else:
                inputs.append(text)

        used_tokens = 0
        used_total_tokens = 0
        batched_embeddings = []
        # Send at most max_chunks texts per request and accumulate results in input order.
        for start in range(0, len(inputs), max_chunks):
            embeddings_batch, batch_tokens, batch_total_tokens = embedding.embed_documents(
                model,
                inputs[start: start + max_chunks],
                user)
            used_tokens += batch_tokens
            used_total_tokens += batch_total_tokens
            batched_embeddings += embeddings_batch

        usage = self._calc_response_usage(model, credentials, used_tokens, used_total_tokens)
        return TextEmbeddingResult(
            model=model,
            embeddings=batched_embeddings,
            usage=usage,
        )

    def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
        """
        Get number of tokens for given prompt messages

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :return: sum of the GPT2-tokenizer token counts of all texts (0 for an empty list)
        """
        if not texts:
            return 0
        return sum(self._get_num_tokens_by_gpt2(text) for text in texts)

    def validate_credentials(self, model: str, credentials: Mapping) -> None:
        """
        Validate credentials by obtaining an access token for them.

        :param model: model name (not used)
        :param credentials: mapping holding ``api_key`` and ``secret_key``
        :raises CredentialsValidateFailedError: obtaining the access token failed
        """
        api_key = credentials['api_key']
        secret_key = credentials['secret_key']
        try:
            BaiduAccessToken.get_access_token(api_key, secret_key)
        except Exception as e:
            raise CredentialsValidateFailedError(f'Credentials validation failed: {e}') from e

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """Mapping from unified invoke errors to the provider errors they cover."""
        return invoke_error_mapping()

    def _calc_response_usage(self, model: str, credentials: dict, tokens: int, total_tokens: int) -> EmbeddingUsage:
        """
        Calculate response usage

        :param model: model name
        :param credentials: model credentials
        :param tokens: input tokens; the price is computed from this count
        :param total_tokens: total tokens, recorded on the usage as-is
        :return: usage
        """
        # get input price info
        input_price_info = self.get_price(
            model=model,
            credentials=credentials,
            price_type=PriceType.INPUT,
            tokens=tokens
        )

        # transform usage
        # NOTE(review): self.started_at is not set in this class; presumably the base class
        # sets it when the invocation starts - confirm.
        usage = EmbeddingUsage(
            tokens=tokens,
            total_tokens=total_tokens,
            unit_price=input_price_info.unit_price,
            price_unit=input_price_info.unit,
            total_price=input_price_info.total_amount,
            currency=input_price_info.currency,
            latency=time.perf_counter() - self.started_at
        )

        return usage

View File

@ -17,6 +17,7 @@ help:
en_US: https://cloud.baidu.com/wenxin.html
supported_model_types:
- llm
- text-embedding
configurate_methods:
- predefined-model
provider_credential_schema:

View File

@ -0,0 +1,57 @@
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
def invoke_error_mapping() -> dict[type[InvokeError], list[type[Exception]]]:
    """
    Build the table that maps provider errors onto the unified invoke errors.

    Each key is the unified error type raised to the caller; each value lists the
    errors raised by the model that are converted into that unified type.

    :return: Invoke error mapping
    """
    mapping: dict[type[InvokeError], list[type[Exception]]] = {}

    # No provider error is mapped to a connection error.
    mapping[InvokeConnectionError] = []
    mapping[InvokeServerUnavailableError] = [InternalServerError]
    mapping[InvokeRateLimitError] = [RateLimitReachedError]
    mapping[InvokeAuthorizationError] = [
        InvalidAuthenticationError,
        InsufficientAccountBalance,
        InvalidAPIKeyError,
    ]
    mapping[InvokeBadRequestError] = [BadRequestError, KeyError]

    return mapping
class InvalidAuthenticationError(Exception):
    """Authentication with the provider failed."""


class InvalidAPIKeyError(Exception):
    """The API key or secret key was rejected."""


class RateLimitReachedError(Exception):
    """The provider reported that a rate limit was reached."""


class InsufficientAccountBalance(Exception):
    """The account balance is insufficient."""


class InternalServerError(Exception):
    """The provider reported a server-side or unknown error."""


class BadRequestError(Exception):
    """The request was rejected as invalid."""

View File

@ -0,0 +1,24 @@
import os
from time import sleep
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.model_providers.wenxin.text_embedding.text_embedding import WenxinTextEmbeddingModel
def test_invoke_embedding_model():
    """Embed three texts through the model and check the result type and size."""
    sleep(3)
    model = WenxinTextEmbeddingModel()
    # Credentials are read from the environment.
    credentials = {
        'api_key': os.environ.get('WENXIN_API_KEY'),
        'secret_key': os.environ.get('WENXIN_SECRET_KEY')
    }
    texts = ['hello', '你好', 'xxxxx']

    response = model.invoke(
        model='embedding-v1',
        credentials=credentials,
        texts=texts,
        user="abc-123"
    )

    assert isinstance(response, TextEmbeddingResult)
    assert len(response.embeddings) == 3
    assert isinstance(response.embeddings[0], list)

View File

@ -0,0 +1,75 @@
import numpy as np
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
from core.model_runtime.model_providers.wenxin.text_embedding.text_embedding import (
TextEmbedding,
WenxinTextEmbeddingModel,
)
def test_max_chunks():
    """Twice ``max_chunks`` inputs must still yield one embedding per input."""

    class _MockTextEmbedding(TextEmbedding):
        # Stub backend: a fixed vector per text, token count = total character count.
        def embed_documents(self, model: str, texts: list[str], user: str) -> (list[list[float]], int, int):
            vectors = [[1.0, 2.0, 3.0] for _ in texts]
            char_count = 0
            for item in texts:
                char_count += len(item)
            return vectors, char_count, char_count

    def _create_text_embedding(api_key: str, secret_key: str) -> TextEmbedding:
        return _MockTextEmbedding()

    model = 'embedding-v1'
    credentials = {
        'api_key': 'xxxx',
        'secret_key': 'yyyy',
    }
    embedding_model = WenxinTextEmbeddingModel()
    # Looked up as in the original test; the value itself is not asserted on.
    context_size = embedding_model._get_context_size(model, credentials)
    max_chunks = embedding_model._get_max_chunks(model, credentials)
    # Swap the client factory on the instance so no HTTP request is made.
    embedding_model._create_text_embedding = _create_text_embedding

    expected_count = max_chunks * 2
    texts = ['0123456789'] * expected_count
    result: TextEmbeddingResult = embedding_model.invoke(model, credentials, texts, 'test')
    assert len(result.embeddings) == expected_count
def test_context_size():
    """A text of twice the context size must be reported as exactly ``context_size`` tokens."""

    def get_num_tokens_by_gpt2(text: str) -> int:
        return GPT2Tokenizer.get_num_tokens(text)

    def mock_text(token_size: int) -> str:
        # Repeat a run of '0' characters by the whole-number chars-per-token ratio.
        seed = '0' * token_size
        seed_tokens = get_num_tokens_by_gpt2(seed)
        repeats = int(np.floor(len(seed) / seed_tokens))
        return seed * repeats

    model = 'embedding-v1'
    credentials = {
        'api_key': 'xxxx',
        'secret_key': 'yyyy',
    }
    embedding_model = WenxinTextEmbeddingModel()
    context_size = embedding_model._get_context_size(model, credentials)

    class _MockTextEmbedding(TextEmbedding):
        # Stub backend: a fixed vector per text, token count from the GPT2 tokenizer.
        def embed_documents(self, model: str, texts: list[str], user: str) -> (list[list[float]], int, int):
            vectors = [[1.0, 2.0, 3.0] for _ in texts]
            token_count = 0
            for item in texts:
                token_count += get_num_tokens_by_gpt2(item)
            return vectors, token_count, token_count

    def _create_text_embedding(api_key: str, secret_key: str) -> TextEmbedding:
        return _MockTextEmbedding()

    # Swap the client factory on the instance so no HTTP request is made.
    embedding_model._create_text_embedding = _create_text_embedding

    text = mock_text(context_size * 2)
    assert get_num_tokens_by_gpt2(text) == context_size * 2

    result: TextEmbeddingResult = embedding_model.invoke(model, credentials, [text], 'test')
    assert result.usage.tokens == context_size