diff --git a/api/core/model_runtime/model_providers/wenxin/_common.py b/api/core/model_runtime/model_providers/wenxin/_common.py
index ee9c34b6a..0230c78b7 100644
--- a/api/core/model_runtime/model_providers/wenxin/_common.py
+++ b/api/core/model_runtime/model_providers/wenxin/_common.py
@@ -118,6 +118,9 @@ class _CommonWenxin:
         'ernie-4.0-turbo-8k-preview': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-4.0-turbo-8k-preview',
         'yi_34b_chat': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/yi_34b_chat',
         'embedding-v1': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1',
+        'bge-large-en': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/bge_large_en',
+        'bge-large-zh': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/bge_large_zh',
+        'tao-8k': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/tao_8k',
     }
 
     function_calling_supports = [
diff --git a/api/core/model_runtime/model_providers/wenxin/text_embedding/bge-large-en.yaml b/api/core/model_runtime/model_providers/wenxin/text_embedding/bge-large-en.yaml
new file mode 100644
index 000000000..74fadb7f9
--- /dev/null
+++ b/api/core/model_runtime/model_providers/wenxin/text_embedding/bge-large-en.yaml
@@ -0,0 +1,9 @@
+model: bge-large-en
+model_type: text-embedding
+model_properties:
+  context_size: 512
+  max_chunks: 16
+pricing:
+  input: '0.0005'
+  unit: '0.001'
+  currency: RMB
diff --git a/api/core/model_runtime/model_providers/wenxin/text_embedding/bge-large-zh.yaml b/api/core/model_runtime/model_providers/wenxin/text_embedding/bge-large-zh.yaml
new file mode 100644
index 000000000..d4af27ec3
--- /dev/null
+++ b/api/core/model_runtime/model_providers/wenxin/text_embedding/bge-large-zh.yaml
@@ -0,0 +1,9 @@
+model: bge-large-zh
+model_type: text-embedding
+model_properties:
+  context_size: 512
+  max_chunks: 16
+pricing:
+  input: '0.0005'
+  unit: '0.001'
+  currency: RMB
diff --git a/api/core/model_runtime/model_providers/wenxin/text_embedding/tao-8k.yaml b/api/core/model_runtime/model_providers/wenxin/text_embedding/tao-8k.yaml
new file mode 100644
index 000000000..e28f253eb
--- /dev/null
+++ b/api/core/model_runtime/model_providers/wenxin/text_embedding/tao-8k.yaml
@@ -0,0 +1,9 @@
+model: tao-8k
+model_type: text-embedding
+model_properties:
+  context_size: 8192
+  max_chunks: 1
+pricing:
+  input: '0.0005'
+  unit: '0.001'
+  currency: RMB
diff --git a/api/tests/integration_tests/model_runtime/wenxin/test_embedding.py b/api/tests/integration_tests/model_runtime/wenxin/test_embedding.py
index 60e803622..d886226cf 100644
--- a/api/tests/integration_tests/model_runtime/wenxin/test_embedding.py
+++ b/api/tests/integration_tests/model_runtime/wenxin/test_embedding.py
@@ -5,7 +5,7 @@ from core.model_runtime.entities.text_embedding_entities import TextEmbeddingRes
 from core.model_runtime.model_providers.wenxin.text_embedding.text_embedding import WenxinTextEmbeddingModel
 
 
-def test_invoke_embedding_model():
+def test_invoke_embedding_v1():
     sleep(3)
     model = WenxinTextEmbeddingModel()
 
@@ -21,4 +21,61 @@
 
     assert isinstance(response, TextEmbeddingResult)
     assert len(response.embeddings) == 3
-    assert isinstance(response.embeddings[0], list)
\ No newline at end of file
+    assert isinstance(response.embeddings[0], list)
+
+
+def test_invoke_embedding_bge_large_en():
+    sleep(3)
+    model = WenxinTextEmbeddingModel()
+
+    response = model.invoke(
+        model='bge-large-en',
+        credentials={
+            'api_key': os.environ.get('WENXIN_API_KEY'),
+            'secret_key': os.environ.get('WENXIN_SECRET_KEY')
+        },
+        texts=['hello', '你好', 'xxxxx'],
+        user="abc-123"
+    )
+
+    assert isinstance(response, TextEmbeddingResult)
+    assert len(response.embeddings) == 3
+    assert isinstance(response.embeddings[0], list)
+
+
+def test_invoke_embedding_bge_large_zh():
+    sleep(3)
+    model = WenxinTextEmbeddingModel()
+
+    response = model.invoke(
+        model='bge-large-zh',
+        credentials={
+            'api_key': os.environ.get('WENXIN_API_KEY'),
+            'secret_key': os.environ.get('WENXIN_SECRET_KEY')
+        },
+        texts=['hello', '你好', 'xxxxx'],
+        user="abc-123"
+    )
+
+    assert isinstance(response, TextEmbeddingResult)
+    assert len(response.embeddings) == 3
+    assert isinstance(response.embeddings[0], list)
+
+
+def test_invoke_embedding_tao_8k():
+    sleep(3)
+    model = WenxinTextEmbeddingModel()
+
+    response = model.invoke(
+        model='tao-8k',
+        credentials={
+            'api_key': os.environ.get('WENXIN_API_KEY'),
+            'secret_key': os.environ.get('WENXIN_SECRET_KEY')
+        },
+        texts=['hello', '你好', 'xxxxx'],
+        user="abc-123"
+    )
+
+    assert isinstance(response, TextEmbeddingResult)
+    assert len(response.embeddings) == 3
+    assert isinstance(response.embeddings[0], list)