Mirror of https://gitee.com/dify_ai/dify.git, synced 2024-12-02 11:18:19 +08:00

Co-authored-by: -LAN- <laipz8200@outlook.com>

This commit is contained in:
    parent 2da63654e5
    commit b035c02f78
@@ -76,7 +76,6 @@ exclude = [
     "migrations/**/*",
     "services/**/*.py",
     "tasks/**/*.py",
-    "tests/**/*.py",
 ]
 
 [tool.pytest_env]
@@ -22,23 +22,20 @@ from anthropic.types import (
 )
 from anthropic.types.message_delta_event import Delta
 
-MOCK = os.getenv('MOCK_SWITCH', 'false') == 'true'
+MOCK = os.getenv("MOCK_SWITCH", "false") == "true"
 
 
 class MockAnthropicClass:
     @staticmethod
     def mocked_anthropic_chat_create_sync(model: str) -> Message:
         return Message(
-            id='msg-123',
-            type='message',
-            role='assistant',
-            content=[ContentBlock(text='hello, I\'m a chatbot from anthropic', type='text')],
+            id="msg-123",
+            type="message",
+            role="assistant",
+            content=[ContentBlock(text="hello, I'm a chatbot from anthropic", type="text")],
             model=model,
-            stop_reason='stop_sequence',
-            usage=Usage(
-                input_tokens=1,
-                output_tokens=1
-            )
+            stop_reason="stop_sequence",
+            usage=Usage(input_tokens=1, output_tokens=1),
         )
 
     @staticmethod
@@ -46,52 +43,43 @@ class MockAnthropicClass:
         full_response_text = "hello, I'm a chatbot from anthropic"
 
         yield MessageStartEvent(
-            type='message_start',
+            type="message_start",
             message=Message(
-                id='msg-123',
+                id="msg-123",
                 content=[],
-                role='assistant',
+                role="assistant",
                 model=model,
                 stop_reason=None,
-                type='message',
-                usage=Usage(
-                    input_tokens=1,
-                    output_tokens=1
-                )
-            )
+                type="message",
+                usage=Usage(input_tokens=1, output_tokens=1),
+            ),
         )
 
         index = 0
         for i in range(0, len(full_response_text)):
             yield ContentBlockDeltaEvent(
-                type='content_block_delta',
-                delta=TextDelta(text=full_response_text[i], type='text_delta'),
-                index=index
+                type="content_block_delta", delta=TextDelta(text=full_response_text[i], type="text_delta"), index=index
             )
 
             index += 1
 
         yield MessageDeltaEvent(
-            type='message_delta',
-            delta=Delta(
-                stop_reason='stop_sequence'
-            ),
-            usage=MessageDeltaUsage(
-                output_tokens=1
-            )
+            type="message_delta", delta=Delta(stop_reason="stop_sequence"), usage=MessageDeltaUsage(output_tokens=1)
         )
 
-        yield MessageStopEvent(type='message_stop')
+        yield MessageStopEvent(type="message_stop")
 
-    def mocked_anthropic(self: Messages, *,
-                         max_tokens: int,
-                         messages: Iterable[MessageParam],
-                         model: str,
-                         stream: Literal[True],
-                         **kwargs: Any
-                         ) -> Union[Message, Stream[MessageStreamEvent]]:
+    def mocked_anthropic(
+        self: Messages,
+        *,
+        max_tokens: int,
+        messages: Iterable[MessageParam],
+        model: str,
+        stream: Literal[True],
+        **kwargs: Any,
+    ) -> Union[Message, Stream[MessageStreamEvent]]:
         if len(self._client.api_key) < 18:
-            raise anthropic.AuthenticationError('Invalid API key')
+            raise anthropic.AuthenticationError("Invalid API key")
 
         if stream:
             return MockAnthropicClass.mocked_anthropic_chat_create_stream(model=model)
@@ -102,7 +90,7 @@ class MockAnthropicClass:
 @pytest.fixture
 def setup_anthropic_mock(request, monkeypatch: MonkeyPatch):
     if MOCK:
-        monkeypatch.setattr(Messages, 'create', MockAnthropicClass.mocked_anthropic)
+        monkeypatch.setattr(Messages, "create", MockAnthropicClass.mocked_anthropic)
 
     yield
 
@@ -12,63 +12,46 @@ from google.generativeai.client import _ClientManager, configure
 from google.generativeai.types import GenerateContentResponse
 from google.generativeai.types.generation_types import BaseGenerateContentResponse
 
-current_api_key = ''
+current_api_key = ""
 
 
 class MockGoogleResponseClass:
     _done = False
 
     def __iter__(self):
-        full_response_text = 'it\'s google!'
+        full_response_text = "it's google!"
 
         for i in range(0, len(full_response_text) + 1, 1):
             if i == len(full_response_text):
                 self._done = True
                 yield GenerateContentResponse(
-                    done=True,
-                    iterator=None,
-                    result=glm.GenerateContentResponse({
-                    }),
-                    chunks=[]
+                    done=True, iterator=None, result=glm.GenerateContentResponse({}), chunks=[]
                 )
             else:
                 yield GenerateContentResponse(
-                    done=False,
-                    iterator=None,
-                    result=glm.GenerateContentResponse({
-                    }),
-                    chunks=[]
+                    done=False, iterator=None, result=glm.GenerateContentResponse({}), chunks=[]
                 )
 
 
 class MockGoogleResponseCandidateClass:
-    finish_reason = 'stop'
+    finish_reason = "stop"
 
     @property
     def content(self) -> gag_content.Content:
-        return gag_content.Content(
-            parts=[
-                gag_content.Part(text='it\'s google!')
-            ]
-        )
+        return gag_content.Content(parts=[gag_content.Part(text="it's google!")])
 
 
 class MockGoogleClass:
     @staticmethod
     def generate_content_sync() -> GenerateContentResponse:
-        return GenerateContentResponse(
-            done=True,
-            iterator=None,
-            result=glm.GenerateContentResponse({
-            }),
-            chunks=[]
-        )
+        return GenerateContentResponse(done=True, iterator=None, result=glm.GenerateContentResponse({}), chunks=[])
 
     @staticmethod
     def generate_content_stream() -> Generator[GenerateContentResponse, None, None]:
         return MockGoogleResponseClass()
 
-    def generate_content(self: GenerativeModel,
+    def generate_content(
+        self: GenerativeModel,
         contents: content_types.ContentsType,
         *,
         generation_config: generation_config_types.GenerationConfigType | None = None,
@@ -79,21 +62,21 @@ class MockGoogleClass:
         global current_api_key
 
         if len(current_api_key) < 16:
-            raise Exception('Invalid API key')
+            raise Exception("Invalid API key")
 
         if stream:
             return MockGoogleClass.generate_content_stream()
 
         return MockGoogleClass.generate_content_sync()
 
     @property
     def generative_response_text(self) -> str:
-        return 'it\'s google!'
+        return "it's google!"
 
     @property
     def generative_response_candidates(self) -> list[MockGoogleResponseCandidateClass]:
         return [MockGoogleResponseCandidateClass()]
 
     def make_client(self: _ClientManager, name: str):
         global current_api_key
 
@@ -121,7 +104,8 @@ class MockGoogleClass:
 
         if not self.default_metadata:
             return client
 
+
 @pytest.fixture
 def setup_google_mock(request, monkeypatch: MonkeyPatch):
     monkeypatch.setattr(BaseGenerateContentResponse, "text", MockGoogleClass.generative_response_text)
@@ -131,4 +115,4 @@ def setup_google_mock(request, monkeypatch: MonkeyPatch):
 
     yield
 
-    monkeypatch.undo()
\ No newline at end of file
+    monkeypatch.undo()
@@ -6,14 +6,15 @@ from huggingface_hub import InferenceClient
 
 from tests.integration_tests.model_runtime.__mock.huggingface_chat import MockHuggingfaceChatClass
 
-MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'
+MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
 
+
 @pytest.fixture
 def setup_huggingface_mock(request, monkeypatch: MonkeyPatch):
     if MOCK:
         monkeypatch.setattr(InferenceClient, "text_generation", MockHuggingfaceChatClass.text_generation)
 
     yield
 
     if MOCK:
-        monkeypatch.undo()
\ No newline at end of file
+        monkeypatch.undo()
@@ -22,10 +22,8 @@ class MockHuggingfaceChatClass:
             details=Details(
                 finish_reason="length",
                 generated_tokens=6,
-                tokens=[
-                    Token(id=0, text="You", logprob=0.0, special=False) for i in range(0, 6)
-                ]
-            )
+                tokens=[Token(id=0, text="You", logprob=0.0, special=False) for i in range(0, 6)],
+            ),
         )
 
         return response
@@ -36,26 +34,23 @@ class MockHuggingfaceChatClass:
 
         for i in range(0, len(full_text)):
             response = TextGenerationStreamResponse(
-                token = Token(id=i, text=full_text[i], logprob=0.0, special=False),
+                token=Token(id=i, text=full_text[i], logprob=0.0, special=False),
             )
             response.generated_text = full_text[i]
-            response.details = StreamDetails(finish_reason='stop_sequence', generated_tokens=1)
+            response.details = StreamDetails(finish_reason="stop_sequence", generated_tokens=1)
 
             yield response
 
-    def text_generation(self: InferenceClient, prompt: str, *,
-                        stream: Literal[False] = ...,
-                        model: Optional[str] = None,
-                        **kwargs: Any
+    def text_generation(
+        self: InferenceClient, prompt: str, *, stream: Literal[False] = ..., model: Optional[str] = None, **kwargs: Any
     ) -> Union[TextGenerationResponse, Generator[TextGenerationStreamResponse, None, None]]:
         # check if key is valid
-        if not re.match(r'Bearer\shf\-[a-zA-Z0-9]{16,}', self.headers['authorization']):
-            raise BadRequestError('Invalid API key')
+        if not re.match(r"Bearer\shf\-[a-zA-Z0-9]{16,}", self.headers["authorization"]):
+            raise BadRequestError("Invalid API key")
 
         if model is None:
-            raise BadRequestError('Invalid model')
+            raise BadRequestError("Invalid model")
 
         if stream:
             return MockHuggingfaceChatClass.generate_create_stream(model)
         return MockHuggingfaceChatClass.generate_create_sync(model)
 
@@ -5,10 +5,10 @@ class MockTEIClass:
     @staticmethod
     def get_tei_extra_parameter(server_url: str, model_name: str) -> TeiModelExtraParameter:
         # During mock, we don't have a real server to query, so we just return a dummy value
-        if 'rerank' in model_name:
-            model_type = 'reranker'
+        if "rerank" in model_name:
+            model_type = "reranker"
         else:
-            model_type = 'embedding'
+            model_type = "embedding"
 
         return TeiModelExtraParameter(model_type=model_type, max_input_length=512, max_client_batch_size=1)
 
@@ -17,16 +17,16 @@ class MockTEIClass:
         # Use space as token separator, and split the text into tokens
         tokenized_texts = []
         for text in texts:
-            tokens = text.split(' ')
+            tokens = text.split(" ")
             current_index = 0
             tokenized_text = []
             for idx, token in enumerate(tokens):
                 s_token = {
-                    'id': idx,
-                    'text': token,
-                    'special': False,
-                    'start': current_index,
-                    'stop': current_index + len(token),
+                    "id": idx,
+                    "text": token,
+                    "special": False,
+                    "start": current_index,
+                    "stop": current_index + len(token),
                 }
                 current_index += len(token) + 1
                 tokenized_text.append(s_token)
@@ -55,18 +55,18 @@ class MockTEIClass:
             embedding = [0.1] * 768
             embeddings.append(
                 {
-                    'object': 'embedding',
-                    'embedding': embedding,
-                    'index': idx,
+                    "object": "embedding",
+                    "embedding": embedding,
+                    "index": idx,
                 }
             )
         return {
-            'object': 'list',
-            'data': embeddings,
-            'model': 'MODEL_NAME',
-            'usage': {
-                'prompt_tokens': sum(len(text.split(' ')) for text in texts),
-                'total_tokens': sum(len(text.split(' ')) for text in texts),
+            "object": "list",
+            "data": embeddings,
+            "model": "MODEL_NAME",
+            "usage": {
+                "prompt_tokens": sum(len(text.split(" ")) for text in texts),
+                "total_tokens": sum(len(text.split(" ")) for text in texts),
             },
         }
 
@@ -83,9 +83,9 @@ class MockTEIClass:
         for idx, text in enumerate(texts):
             reranked_docs.append(
                 {
-                    'index': idx,
-                    'text': text,
-                    'score': 0.9,
+                    "index": idx,
+                    "text": text,
+                    "score": 0.9,
                 }
             )
         # For mock, only return the first document
@@ -21,13 +21,17 @@ from tests.integration_tests.model_runtime.__mock.openai_remote import MockModel
 from tests.integration_tests.model_runtime.__mock.openai_speech2text import MockSpeech2TextClass
 
 
-def mock_openai(monkeypatch: MonkeyPatch, methods: list[Literal["completion", "chat", "remote", "moderation", "speech2text", "text_embedding"]]) -> Callable[[], None]:
+def mock_openai(
+    monkeypatch: MonkeyPatch,
+    methods: list[Literal["completion", "chat", "remote", "moderation", "speech2text", "text_embedding"]],
+) -> Callable[[], None]:
     """
-        mock openai module
+    mock openai module
 
-        :param monkeypatch: pytest monkeypatch fixture
-        :return: unpatch function
+    :param monkeypatch: pytest monkeypatch fixture
+    :return: unpatch function
     """
 
     def unpatch() -> None:
         monkeypatch.undo()
 
@@ -52,15 +56,16 @@ def mock_openai(monkeypatch: MonkeyPatch, methods: list[Literal["completion", "c
     return unpatch
 
 
-MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'
+MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
 
+
 @pytest.fixture
 def setup_openai_mock(request, monkeypatch):
-    methods = request.param if hasattr(request, 'param') else []
+    methods = request.param if hasattr(request, "param") else []
     if MOCK:
         unpatch = mock_openai(monkeypatch, methods=methods)
 
     yield
 
     if MOCK:
-        unpatch()
\ No newline at end of file
+        unpatch()
@@ -43,62 +43,64 @@ class MockChatClass:
         if not functions or len(functions) == 0:
             return None
         function: completion_create_params.Function = functions[0]
-        function_name = function['name']
-        function_description = function['description']
-        function_parameters = function['parameters']
-        function_parameters_type = function_parameters['type']
-        if function_parameters_type != 'object':
+        function_name = function["name"]
+        function_description = function["description"]
+        function_parameters = function["parameters"]
+        function_parameters_type = function_parameters["type"]
+        if function_parameters_type != "object":
             return None
-        function_parameters_properties = function_parameters['properties']
-        function_parameters_required = function_parameters['required']
+        function_parameters_properties = function_parameters["properties"]
+        function_parameters_required = function_parameters["required"]
         parameters = {}
         for parameter_name, parameter in function_parameters_properties.items():
             if parameter_name not in function_parameters_required:
                 continue
-            parameter_type = parameter['type']
-            if parameter_type == 'string':
-                if 'enum' in parameter:
-                    if len(parameter['enum']) == 0:
+            parameter_type = parameter["type"]
+            if parameter_type == "string":
+                if "enum" in parameter:
+                    if len(parameter["enum"]) == 0:
                         continue
-                    parameters[parameter_name] = parameter['enum'][0]
+                    parameters[parameter_name] = parameter["enum"][0]
                 else:
-                    parameters[parameter_name] = 'kawaii'
-            elif parameter_type == 'integer':
+                    parameters[parameter_name] = "kawaii"
+            elif parameter_type == "integer":
                 parameters[parameter_name] = 114514
-            elif parameter_type == 'number':
+            elif parameter_type == "number":
                 parameters[parameter_name] = 1919810.0
-            elif parameter_type == 'boolean':
+            elif parameter_type == "boolean":
                 parameters[parameter_name] = True
 
         return FunctionCall(name=function_name, arguments=dumps(parameters))
 
     @staticmethod
-    def generate_tool_calls(tools = NOT_GIVEN) -> Optional[list[ChatCompletionMessageToolCall]]:
+    def generate_tool_calls(tools=NOT_GIVEN) -> Optional[list[ChatCompletionMessageToolCall]]:
         list_tool_calls = []
         if not tools or len(tools) == 0:
             return None
         tool = tools[0]
 
-        if 'type' in tools and tools['type'] != 'function':
+        if "type" in tools and tools["type"] != "function":
             return None
 
-        function = tool['function']
+        function = tool["function"]
 
         function_call = MockChatClass.generate_function_call(functions=[function])
         if function_call is None:
             return None
 
-        list_tool_calls.append(ChatCompletionMessageToolCall(
-            id='sakurajima-mai',
-            function=Function(
-                name=function_call.name,
-                arguments=function_call.arguments,
-            ),
-            type='function'
-        ))
+        list_tool_calls.append(
+            ChatCompletionMessageToolCall(
+                id="sakurajima-mai",
+                function=Function(
+                    name=function_call.name,
+                    arguments=function_call.arguments,
+                ),
+                type="function",
+            )
+        )
 
         return list_tool_calls
 
     @staticmethod
     def mocked_openai_chat_create_sync(
         model: str,
@@ -111,30 +113,27 @@ class MockChatClass:
         tool_calls = MockChatClass.generate_tool_calls(tools=tools)
 
         return _ChatCompletion(
-            id='cmpl-3QJQa5jXJ5Z5X',
+            id="cmpl-3QJQa5jXJ5Z5X",
             choices=[
                 _ChatCompletionChoice(
-                    finish_reason='content_filter',
+                    finish_reason="content_filter",
                     index=0,
                     message=ChatCompletionMessage(
-                        content='elaina',
-                        role='assistant',
-                        function_call=function_call,
-                        tool_calls=tool_calls
-                    )
+                        content="elaina", role="assistant", function_call=function_call, tool_calls=tool_calls
+                    ),
                 )
             ],
             created=int(time()),
             model=model,
-            object='chat.completion',
-            system_fingerprint='',
+            object="chat.completion",
+            system_fingerprint="",
             usage=CompletionUsage(
                 prompt_tokens=2,
                 completion_tokens=1,
                 total_tokens=3,
-            )
+            ),
         )
 
     @staticmethod
     def mocked_openai_chat_create_stream(
         model: str,
@@ -150,36 +149,40 @@ class MockChatClass:
         for i in range(0, len(full_text) + 1):
             if i == len(full_text):
                 yield ChatCompletionChunk(
-                    id='cmpl-3QJQa5jXJ5Z5X',
+                    id="cmpl-3QJQa5jXJ5Z5X",
                     choices=[
                         Choice(
                             delta=ChoiceDelta(
-                                content='',
+                                content="",
                                 function_call=ChoiceDeltaFunctionCall(
                                     name=function_call.name,
                                     arguments=function_call.arguments,
-                                ) if function_call else None,
-                                role='assistant',
+                                )
+                                if function_call
+                                else None,
+                                role="assistant",
                                 tool_calls=[
                                     ChoiceDeltaToolCall(
                                         index=0,
-                                        id='misaka-mikoto',
+                                        id="misaka-mikoto",
                                         function=ChoiceDeltaToolCallFunction(
                                             name=tool_calls[0].function.name,
                                             arguments=tool_calls[0].function.arguments,
                                         ),
-                                        type='function'
+                                        type="function",
                                     )
-                                ] if tool_calls and len(tool_calls) > 0 else None
+                                ]
+                                if tool_calls and len(tool_calls) > 0
+                                else None,
                             ),
-                            finish_reason='function_call',
+                            finish_reason="function_call",
                             index=0,
                         )
                     ],
                     created=int(time()),
                     model=model,
-                    object='chat.completion.chunk',
-                    system_fingerprint='',
+                    object="chat.completion.chunk",
+                    system_fingerprint="",
                     usage=CompletionUsage(
                         prompt_tokens=2,
                         completion_tokens=17,
@@ -188,30 +191,45 @@ class MockChatClass:
                 )
             else:
                 yield ChatCompletionChunk(
-                    id='cmpl-3QJQa5jXJ5Z5X',
+                    id="cmpl-3QJQa5jXJ5Z5X",
                     choices=[
                         Choice(
                             delta=ChoiceDelta(
                                 content=full_text[i],
-                                role='assistant',
+                                role="assistant",
                             ),
-                            finish_reason='content_filter',
+                            finish_reason="content_filter",
                             index=0,
                         )
                     ],
                     created=int(time()),
                     model=model,
-                    object='chat.completion.chunk',
-                    system_fingerprint='',
+                    object="chat.completion.chunk",
+                    system_fingerprint="",
                 )
 
-    def chat_create(self: Completions, *,
+    def chat_create(
+        self: Completions,
+        *,
         messages: list[ChatCompletionMessageParam],
-        model: Union[str,Literal[
-            "gpt-4-1106-preview", "gpt-4-vision-preview", "gpt-4", "gpt-4-0314", "gpt-4-0613",
-            "gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613",
-            "gpt-3.5-turbo-1106", "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-3.5-turbo-0301",
-            "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613"],
+        model: Union[
+            str,
+            Literal[
+                "gpt-4-1106-preview",
+                "gpt-4-vision-preview",
+                "gpt-4",
+                "gpt-4-0314",
+                "gpt-4-0613",
+                "gpt-4-32k",
+                "gpt-4-32k-0314",
+                "gpt-4-32k-0613",
+                "gpt-3.5-turbo-1106",
+                "gpt-3.5-turbo",
+                "gpt-3.5-turbo-16k",
+                "gpt-3.5-turbo-0301",
+                "gpt-3.5-turbo-0613",
+                "gpt-3.5-turbo-16k-0613",
+            ],
+        ],
         functions: list[completion_create_params.Function] | NotGiven = NOT_GIVEN,
         response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
@@ -220,24 +238,32 @@ class MockChatClass:
         **kwargs: Any,
     ):
         openai_models = [
-            "gpt-4-1106-preview", "gpt-4-vision-preview", "gpt-4", "gpt-4-0314", "gpt-4-0613",
-            "gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613",
-            "gpt-3.5-turbo-1106", "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-3.5-turbo-0301",
-            "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613",
+            "gpt-4-1106-preview",
+            "gpt-4-vision-preview",
+            "gpt-4",
+            "gpt-4-0314",
+            "gpt-4-0613",
+            "gpt-4-32k",
+            "gpt-4-32k-0314",
+            "gpt-4-32k-0613",
+            "gpt-3.5-turbo-1106",
+            "gpt-3.5-turbo",
+            "gpt-3.5-turbo-16k",
+            "gpt-3.5-turbo-0301",
+            "gpt-3.5-turbo-0613",
+            "gpt-3.5-turbo-16k-0613",
         ]
-        azure_openai_models = [
-            "gpt35", "gpt-4v", "gpt-35-turbo"
-        ]
-        if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', self._client.base_url.__str__()):
-            raise InvokeAuthorizationError('Invalid base url')
+        azure_openai_models = ["gpt35", "gpt-4v", "gpt-35-turbo"]
+        if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._client.base_url.__str__()):
+            raise InvokeAuthorizationError("Invalid base url")
         if model in openai_models + azure_openai_models:
-            if not re.match(r'sk-[a-zA-Z0-9]{24,}$', self._client.api_key) and type(self._client) == OpenAI:
+            if not re.match(r"sk-[a-zA-Z0-9]{24,}$", self._client.api_key) and type(self._client) == OpenAI:
                 # sometime, provider use OpenAI compatible API will not have api key or have different api key format
                 # so we only check if model is in openai_models
-                raise InvokeAuthorizationError('Invalid api key')
+                raise InvokeAuthorizationError("Invalid api key")
             if len(self._client.api_key) < 18 and type(self._client) == AzureOpenAI:
-                raise InvokeAuthorizationError('Invalid api key')
+                raise InvokeAuthorizationError("Invalid api key")
         if stream:
             return MockChatClass.mocked_openai_chat_create_stream(model=model, functions=functions, tools=tools)
 
-        return MockChatClass.mocked_openai_chat_create_sync(model=model, functions=functions, tools=tools)
\ No newline at end of file
+        return MockChatClass.mocked_openai_chat_create_sync(model=model, functions=functions, tools=tools)
@@ -17,9 +17,7 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError
 
 class MockCompletionsClass:
     @staticmethod
-    def mocked_openai_completion_create_sync(
-            model: str
-    ) -> CompletionMessage:
+    def mocked_openai_completion_create_sync(model: str) -> CompletionMessage:
         return CompletionMessage(
             id="cmpl-3QJQa5jXJ5Z5X",
             object="text_completion",
@@ -38,13 +36,11 @@ class MockCompletionsClass:
                 prompt_tokens=2,
                 completion_tokens=1,
                 total_tokens=3,
-            )
+            ),
         )
 
     @staticmethod
-    def mocked_openai_completion_create_stream(
-            model: str
-    ) -> Generator[CompletionMessage, None, None]:
+    def mocked_openai_completion_create_stream(model: str) -> Generator[CompletionMessage, None, None]:
         full_text = "Hello, world!\n\n```python\nprint('Hello, world!')\n```"
         for i in range(0, len(full_text) + 1):
             if i == len(full_text):
@@ -76,46 +72,59 @@ class MockCompletionsClass:
                 model=model,
                 system_fingerprint="",
                 choices=[
-                    CompletionChoice(
-                        text=full_text[i],
-                        index=0,
-                        logprobs=None,
-                        finish_reason="content_filter"
-                    )
+                    CompletionChoice(text=full_text[i], index=0, logprobs=None, finish_reason="content_filter")
                 ],
             )
 
-    def completion_create(self: Completions, *, model: Union[
-            str, Literal["babbage-002", "davinci-002", "gpt-3.5-turbo-instruct",
-                         "text-davinci-003", "text-davinci-002", "text-davinci-001",
-                         "code-davinci-002", "text-curie-001", "text-babbage-001",
-                         "text-ada-001"],
+    def completion_create(
+        self: Completions,
+        *,
+        model: Union[
+            str,
+            Literal[
+                "babbage-002",
+                "davinci-002",
+                "gpt-3.5-turbo-instruct",
+                "text-davinci-003",
+                "text-davinci-002",
+                "text-davinci-001",
+                "code-davinci-002",
+                "text-curie-001",
+                "text-babbage-001",
+                "text-ada-001",
+            ],
+        ],
         prompt: Union[str, list[str], list[int], list[list[int]], None],
         stream: Optional[Literal[False]] | NotGiven = NOT_GIVEN,
-        **kwargs: Any
+        **kwargs: Any,
     ):
         openai_models = [
-            "babbage-002", "davinci-002", "gpt-3.5-turbo-instruct", "text-davinci-003", "text-davinci-002", "text-davinci-001",
-            "code-davinci-002", "text-curie-001", "text-babbage-001", "text-ada-001",
-        ]
-        azure_openai_models = [
-            "gpt-35-turbo-instruct"
+            "babbage-002",
+            "davinci-002",
+            "gpt-3.5-turbo-instruct",
+            "text-davinci-003",
+            "text-davinci-002",
+            "text-davinci-001",
+            "code-davinci-002",
+            "text-curie-001",
+            "text-babbage-001",
+            "text-ada-001",
         ]
+        azure_openai_models = ["gpt-35-turbo-instruct"]
 
-        if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', self._client.base_url.__str__()):
-            raise InvokeAuthorizationError('Invalid base url')
+        if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._client.base_url.__str__()):
+            raise InvokeAuthorizationError("Invalid base url")
         if model in openai_models + azure_openai_models:
-            if not re.match(r'sk-[a-zA-Z0-9]{24,}$', self._client.api_key) and type(self._client) == OpenAI:
+            if not re.match(r"sk-[a-zA-Z0-9]{24,}$", self._client.api_key) and type(self._client) == OpenAI:
                 # sometime, provider use OpenAI compatible API will not have api key or have different api key format
                 # so we only check if model is in openai_models
-                raise InvokeAuthorizationError('Invalid api key')
+                raise InvokeAuthorizationError("Invalid api key")
             if len(self._client.api_key) < 18 and type(self._client) == AzureOpenAI:
-                raise InvokeAuthorizationError('Invalid api key')
+                raise InvokeAuthorizationError("Invalid api key")
 
         if not prompt:
-            raise BadRequestError('Invalid prompt')
+            raise BadRequestError("Invalid prompt")
         if stream:
             return MockCompletionsClass.mocked_openai_completion_create_stream(model=model)
 
-        return MockCompletionsClass.mocked_openai_completion_create_sync(model=model)
\ No newline at end of file
+        return MockCompletionsClass.mocked_openai_completion_create_sync(model=model)
File diff suppressed because one or more lines are too long
@@ -10,58 +10,92 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError
 
 
 class MockModerationClass:
-    def moderation_create(self: Moderations,*,
+    def moderation_create(
+        self: Moderations,
+        *,
         input: Union[str, list[str]],
         model: Union[str, Literal["text-moderation-latest", "text-moderation-stable"]] | NotGiven = NOT_GIVEN,
-        **kwargs: Any
+        **kwargs: Any,
     ) -> ModerationCreateResponse:
         if isinstance(input, str):
             input = [input]
 
-        if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', self._client.base_url.__str__()):
-            raise InvokeAuthorizationError('Invalid base url')
+        if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._client.base_url.__str__()):
+            raise InvokeAuthorizationError("Invalid base url")
 
         if len(self._client.api_key) < 18:
-            raise InvokeAuthorizationError('Invalid API key')
+            raise InvokeAuthorizationError("Invalid API key")
 
         for text in input:
             result = []
-            if 'kill' in text:
+            if "kill" in text:
                 moderation_categories = {
-                    'harassment': False, 'harassment/threatening': False, 'hate': False, 'hate/threatening': False,
-                    'self-harm': False, 'self-harm/instructions': False, 'self-harm/intent': False, 'sexual': False,
-                    'sexual/minors': False, 'violence': False, 'violence/graphic': False
+                    "harassment": False,
+                    "harassment/threatening": False,
+                    "hate": False,
+                    "hate/threatening": False,
+                    "self-harm": False,
+                    "self-harm/instructions": False,
+                    "self-harm/intent": False,
+                    "sexual": False,
+                    "sexual/minors": False,
+                    "violence": False,
+                    "violence/graphic": False,
                 }
                 moderation_categories_scores = {
-                    'harassment': 1.0, 'harassment/threatening': 1.0, 'hate': 1.0, 'hate/threatening': 1.0,
-                    'self-harm': 1.0, 'self-harm/instructions': 1.0, 'self-harm/intent': 1.0, 'sexual': 1.0,
-                    'sexual/minors': 1.0, 'violence': 1.0, 'violence/graphic': 1.0
+                    "harassment": 1.0,
+                    "harassment/threatening": 1.0,
+                    "hate": 1.0,
+                    "hate/threatening": 1.0,
+                    "self-harm": 1.0,
+                    "self-harm/instructions": 1.0,
+                    "self-harm/intent": 1.0,
+                    "sexual": 1.0,
+                    "sexual/minors": 1.0,
+                    "violence": 1.0,
+                    "violence/graphic": 1.0,
                 }
 
-                result.append(Moderation(
-                    flagged=True,
-                    categories=Categories(**moderation_categories),
-                    category_scores=CategoryScores(**moderation_categories_scores)
-                ))
+                result.append(
+                    Moderation(
+                        flagged=True,
+                        categories=Categories(**moderation_categories),
+                        category_scores=CategoryScores(**moderation_categories_scores),
+                    )
+                )
             else:
                 moderation_categories = {
-                    'harassment': False, 'harassment/threatening': False, 'hate': False, 'hate/threatening': False,
-                    'self-harm': False, 'self-harm/instructions': False, 'self-harm/intent': False, 'sexual': False,
-                    'sexual/minors': False, 'violence': False, 'violence/graphic': False
+                    "harassment": False,
+                    "harassment/threatening": False,
+                    "hate": False,
+                    "hate/threatening": False,
+                    "self-harm": False,
+                    "self-harm/instructions": False,
+                    "self-harm/intent": False,
+                    "sexual": False,
+                    "sexual/minors": False,
+                    "violence": False,
+                    "violence/graphic": False,
                 }
                 moderation_categories_scores = {
-                    'harassment': 0.0, 'harassment/threatening': 0.0, 'hate': 0.0, 'hate/threatening': 0.0,
-                    'self-harm': 0.0, 'self-harm/instructions': 0.0, 'self-harm/intent': 0.0, 'sexual': 0.0,
-                    'sexual/minors': 0.0, 'violence': 0.0, 'violence/graphic': 0.0
+                    "harassment": 0.0,
+                    "harassment/threatening": 0.0,
+                    "hate": 0.0,
+                    "hate/threatening": 0.0,
+                    "self-harm": 0.0,
+                    "self-harm/instructions": 0.0,
+                    "self-harm/intent": 0.0,
+                    "sexual": 0.0,
+                    "sexual/minors": 0.0,
+                    "violence": 0.0,
+                    "violence/graphic": 0.0,
                 }
-                result.append(Moderation(
-                    flagged=False,
-                    categories=Categories(**moderation_categories),
-                    category_scores=CategoryScores(**moderation_categories_scores)
-                ))
+                result.append(
+                    Moderation(
+                        flagged=False,
+                        categories=Categories(**moderation_categories),
+                        category_scores=CategoryScores(**moderation_categories_scores),
+                    )
+                )
 
-        return ModerationCreateResponse(
-            id='shiroii kuloko',
-            model=model,
-            results=result
-        )
+        return ModerationCreateResponse(id="shiroii kuloko", model=model, results=result)
@@ -6,17 +6,18 @@ from openai.types.model import Model
 
 class MockModelClass:
     """
-        mock class for openai.models.Models
+    mock class for openai.models.Models
     """
 
     def list(
         self,
         **kwargs,
     ) -> list[Model]:
         return [
             Model(
-                id='ft:gpt-3.5-turbo-0613:personal::8GYJLPDQ',
+                id="ft:gpt-3.5-turbo-0613:personal::8GYJLPDQ",
                 created=int(time()),
-                object='model',
-                owned_by='organization:org-123',
+                object="model",
+                owned_by="organization:org-123",
             )
-        ]
\ No newline at end of file
+        ]
@@ -9,7 +9,8 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError
 
 
 class MockSpeech2TextClass:
-    def speech2text_create(self: Transcriptions,
+    def speech2text_create(
+        self: Transcriptions,
         *,
         file: FileTypes,
         model: Union[str, Literal["whisper-1"]],
|
||||
prompt: str | NotGiven = NOT_GIVEN,
|
||||
response_format: Literal["json", "text", "srt", "verbose_json", "vtt"] | NotGiven = NOT_GIVEN,
|
||||
temperature: float | NotGiven = NOT_GIVEN,
|
||||
**kwargs: Any
|
||||
**kwargs: Any,
|
||||
) -> Transcription:
|
||||
if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', self._client.base_url.__str__()):
|
||||
raise InvokeAuthorizationError('Invalid base url')
|
||||
|
||||
if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._client.base_url.__str__()):
|
||||
raise InvokeAuthorizationError("Invalid base url")
|
||||
|
||||
if len(self._client.api_key) < 18:
|
||||
raise InvokeAuthorizationError('Invalid API key')
|
||||
|
||||
return Transcription(
|
||||
text='1, 2, 3, 4, 5, 6, 7, 8, 9, 10'
|
||||
)
|
||||
raise InvokeAuthorizationError("Invalid API key")
|
||||
|
||||
return Transcription(text="1, 2, 3, 4, 5, 6, 7, 8, 9, 10")
|
||||
|
@@ -19,40 +19,43 @@ from xinference_client.types import Embedding, EmbeddingData, EmbeddingUsage
 
 
 class MockXinferenceClass:
-    def get_chat_model(self: Client, model_uid: str) -> Union[RESTfulChatglmCppChatModelHandle, RESTfulGenerateModelHandle, RESTfulChatModelHandle]:
-        if not re.match(r'https?:\/\/[^\s\/$.?#].[^\s]*$', self.base_url):
-            raise RuntimeError('404 Not Found')
-
-        if 'generate' == model_uid:
+    def get_chat_model(
+        self: Client, model_uid: str
+    ) -> Union[RESTfulChatglmCppChatModelHandle, RESTfulGenerateModelHandle, RESTfulChatModelHandle]:
+        if not re.match(r"https?:\/\/[^\s\/$.?#].[^\s]*$", self.base_url):
+            raise RuntimeError("404 Not Found")
+
+        if "generate" == model_uid:
             return RESTfulGenerateModelHandle(model_uid, base_url=self.base_url, auth_headers={})
-        if 'chat' == model_uid:
+        if "chat" == model_uid:
             return RESTfulChatModelHandle(model_uid, base_url=self.base_url, auth_headers={})
-        if 'embedding' == model_uid:
+        if "embedding" == model_uid:
             return RESTfulEmbeddingModelHandle(model_uid, base_url=self.base_url, auth_headers={})
-        if 'rerank' == model_uid:
+        if "rerank" == model_uid:
             return RESTfulRerankModelHandle(model_uid, base_url=self.base_url, auth_headers={})
-        raise RuntimeError('404 Not Found')
+        raise RuntimeError("404 Not Found")
 
     def get(self: Session, url: str, **kwargs):
         response = Response()
-        if 'v1/models/' in url:
+        if "v1/models/" in url:
             # get model uid
-            model_uid = url.split('/')[-1] or ''
-            if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', model_uid) and \
-                model_uid not in ['generate', 'chat', 'embedding', 'rerank']:
+            model_uid = url.split("/")[-1] or ""
+            if not re.match(
+                r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}", model_uid
+            ) and model_uid not in ["generate", "chat", "embedding", "rerank"]:
                 response.status_code = 404
-                response._content = b'{}'
+                response._content = b"{}"
                 return response
 
             # check if url is valid
-            if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', url):
+            if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", url):
                 response.status_code = 404
-                response._content = b'{}'
+                response._content = b"{}"
                 return response
 
-            if model_uid in ['generate', 'chat']:
+            if model_uid in ["generate", "chat"]:
                 response.status_code = 200
-                response._content = b'''{
+                response._content = b"""{
     "model_type": "LLM",
     "address": "127.0.0.1:43877",
     "accelerators": [
@@ -75,12 +78,12 @@ class MockXinferenceClass:
     "revision": null,
     "context_length": 2048,
     "replica": 1
-}'''
+}"""
                 return response
 
-            elif model_uid == 'embedding':
+            elif model_uid == "embedding":
                 response.status_code = 200
-                response._content = b'''{
+                response._content = b"""{
     "model_type": "embedding",
     "address": "127.0.0.1:43877",
     "accelerators": [
@@ -93,51 +96,48 @@ class MockXinferenceClass:
     ],
     "revision": null,
     "max_tokens": 512
-}'''
+}"""
                 return response
 
-            elif 'v1/cluster/auth' in url:
+            elif "v1/cluster/auth" in url:
                 response.status_code = 200
-                response._content = b'''{
+                response._content = b"""{
     "auth": true
-}'''
+}"""
                 return response
 
     def _check_cluster_authenticated(self):
         self._cluster_authed = True
 
-    def rerank(self: RESTfulRerankModelHandle, documents: list[str], query: str, top_n: int, return_documents: bool) -> dict:
+    def rerank(
+        self: RESTfulRerankModelHandle, documents: list[str], query: str, top_n: int, return_documents: bool
+    ) -> dict:
         # check if self._model_uid is a valid uuid
-        if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', self._model_uid) and \
-            self._model_uid != 'rerank':
-            raise RuntimeError('404 Not Found')
-
-        if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', self._base_url):
-            raise RuntimeError('404 Not Found')
+        if (
+            not re.match(r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}", self._model_uid)
+            and self._model_uid != "rerank"
+        ):
+            raise RuntimeError("404 Not Found")
+
+        if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._base_url):
+            raise RuntimeError("404 Not Found")
 
         if top_n is None:
             top_n = 1
 
         return {
-            'results': [
-                {
-                    'index': i,
-                    'document': doc,
-                    'relevance_score': 0.9
-                }
-                for i, doc in enumerate(documents[:top_n])
+            "results": [
+                {"index": i, "document": doc, "relevance_score": 0.9} for i, doc in enumerate(documents[:top_n])
             ]
         }
 
-    def create_embedding(
-        self: RESTfulGenerateModelHandle,
-        input: Union[str, list[str]],
-        **kwargs
-    ) -> dict:
+    def create_embedding(self: RESTfulGenerateModelHandle, input: Union[str, list[str]], **kwargs) -> dict:
         # check if self._model_uid is a valid uuid
-        if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', self._model_uid) and \
-            self._model_uid != 'embedding':
-            raise RuntimeError('404 Not Found')
+        if (
+            not re.match(r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}", self._model_uid)
+            and self._model_uid != "embedding"
+        ):
+            raise RuntimeError("404 Not Found")
 
         if isinstance(input, str):
             input = [input]
@@ -147,32 +147,27 @@ class MockXinferenceClass:
             object="list",
             model=self._model_uid,
             data=[
-                EmbeddingData(
-                    index=i,
-                    object="embedding",
-                    embedding=[1919.810 for _ in range(768)]
-                )
+                EmbeddingData(index=i, object="embedding", embedding=[1919.810 for _ in range(768)])
                 for i in range(ipt_len)
             ],
-            usage=EmbeddingUsage(
-                prompt_tokens=ipt_len,
-                total_tokens=ipt_len
-            )
+            usage=EmbeddingUsage(prompt_tokens=ipt_len, total_tokens=ipt_len),
         )
 
         return embedding
 
-MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'
+
+MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
 
 
 @pytest.fixture
 def setup_xinference_mock(request, monkeypatch: MonkeyPatch):
     if MOCK:
-        monkeypatch.setattr(Client, 'get_model', MockXinferenceClass.get_chat_model)
-        monkeypatch.setattr(Client, '_check_cluster_authenticated', MockXinferenceClass._check_cluster_authenticated)
-        monkeypatch.setattr(Session, 'get', MockXinferenceClass.get)
-        monkeypatch.setattr(RESTfulEmbeddingModelHandle, 'create_embedding', MockXinferenceClass.create_embedding)
-        monkeypatch.setattr(RESTfulRerankModelHandle, 'rerank', MockXinferenceClass.rerank)
+        monkeypatch.setattr(Client, "get_model", MockXinferenceClass.get_chat_model)
+        monkeypatch.setattr(Client, "_check_cluster_authenticated", MockXinferenceClass._check_cluster_authenticated)
+        monkeypatch.setattr(Session, "get", MockXinferenceClass.get)
+        monkeypatch.setattr(RESTfulEmbeddingModelHandle, "create_embedding", MockXinferenceClass.create_embedding)
+        monkeypatch.setattr(RESTfulRerankModelHandle, "rerank", MockXinferenceClass.rerank)
     yield
 
     if MOCK:
-        monkeypatch.undo()
\ No newline at end of file
+        monkeypatch.undo()
@@ -10,79 +10,60 @@ from core.model_runtime.model_providers.anthropic.llm.llm import AnthropicLargeL
 from tests.integration_tests.model_runtime.__mock.anthropic import setup_anthropic_mock
 
 
-@pytest.mark.parametrize('setup_anthropic_mock', [['none']], indirect=True)
+@pytest.mark.parametrize("setup_anthropic_mock", [["none"]], indirect=True)
 def test_validate_credentials(setup_anthropic_mock):
     model = AnthropicLargeLanguageModel()
 
     with pytest.raises(CredentialsValidateFailedError):
-        model.validate_credentials(
-            model='claude-instant-1.2',
-            credentials={
-                'anthropic_api_key': 'invalid_key'
-            }
-        )
+        model.validate_credentials(model="claude-instant-1.2", credentials={"anthropic_api_key": "invalid_key"})
 
     model.validate_credentials(
-        model='claude-instant-1.2',
-        credentials={
-            'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY')
-        }
+        model="claude-instant-1.2", credentials={"anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY")}
     )
 
-@pytest.mark.parametrize('setup_anthropic_mock', [['none']], indirect=True)
+
+@pytest.mark.parametrize("setup_anthropic_mock", [["none"]], indirect=True)
 def test_invoke_model(setup_anthropic_mock):
     model = AnthropicLargeLanguageModel()
 
     response = model.invoke(
-        model='claude-instant-1.2',
+        model="claude-instant-1.2",
         credentials={
-            'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY'),
-            'anthropic_api_url': os.environ.get('ANTHROPIC_API_URL')
+            "anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY"),
+            "anthropic_api_url": os.environ.get("ANTHROPIC_API_URL"),
        },
         prompt_messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
-            UserPromptMessage(
-                content='Hello World!'
-            )
+            UserPromptMessage(content="Hello World!"),
         ],
-        model_parameters={
-            'temperature': 0.0,
-            'top_p': 1.0,
-            'max_tokens': 10
-        },
-        stop=['How'],
+        model_parameters={"temperature": 0.0, "top_p": 1.0, "max_tokens": 10},
+        stop=["How"],
         stream=False,
-        user="abc-123"
+        user="abc-123",
     )
 
     assert isinstance(response, LLMResult)
     assert len(response.message.content) > 0
 
-@pytest.mark.parametrize('setup_anthropic_mock', [['none']], indirect=True)
+
+@pytest.mark.parametrize("setup_anthropic_mock", [["none"]], indirect=True)
 def test_invoke_stream_model(setup_anthropic_mock):
     model = AnthropicLargeLanguageModel()
 
     response = model.invoke(
-        model='claude-instant-1.2',
-        credentials={
-            'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY')
-        },
+        model="claude-instant-1.2",
+        credentials={"anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY")},
         prompt_messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
-            UserPromptMessage(
-                content='Hello World!'
-            )
+            UserPromptMessage(content="Hello World!"),
         ],
-        model_parameters={
-            'temperature': 0.0,
-            'max_tokens': 100
-        },
+        model_parameters={"temperature": 0.0, "max_tokens": 100},
         stream=True,
-        user="abc-123"
+        user="abc-123",
     )
 
     assert isinstance(response, Generator)
@@ -98,18 +79,14 @@ def test_get_num_tokens():
     model = AnthropicLargeLanguageModel()
 
     num_tokens = model.get_num_tokens(
-        model='claude-instant-1.2',
-        credentials={
-            'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY')
-        },
+        model="claude-instant-1.2",
+        credentials={"anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY")},
         prompt_messages=[
             SystemPromptMessage(
-                content='You are a helpful AI assistant.',
+                content="You are a helpful AI assistant.",
             ),
-            UserPromptMessage(
-                content='Hello World!'
-            )
-        ]
+            UserPromptMessage(content="Hello World!"),
+        ],
     )
 
     assert num_tokens == 18
@@ -7,17 +7,11 @@ from core.model_runtime.model_providers.anthropic.anthropic import AnthropicProv
 from tests.integration_tests.model_runtime.__mock.anthropic import setup_anthropic_mock
 
 
-@pytest.mark.parametrize('setup_anthropic_mock', [['none']], indirect=True)
+@pytest.mark.parametrize("setup_anthropic_mock", [["none"]], indirect=True)
 def test_validate_provider_credentials(setup_anthropic_mock):
     provider = AnthropicProvider()
 
     with pytest.raises(CredentialsValidateFailedError):
-        provider.validate_provider_credentials(
-            credentials={}
-        )
+        provider.validate_provider_credentials(credentials={})
 
-    provider.validate_provider_credentials(
-        credentials={
-            'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY')
-        }
-    )
+    provider.validate_provider_credentials(credentials={"anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY")})
File diff suppressed because one or more lines are too long
@@ -8,45 +8,43 @@ from core.model_runtime.model_providers.azure_openai.text_embedding.text_embeddi
 from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
 
 
-@pytest.mark.parametrize('setup_openai_mock', [['text_embedding']], indirect=True)
+@pytest.mark.parametrize("setup_openai_mock", [["text_embedding"]], indirect=True)
 def test_validate_credentials(setup_openai_mock):
     model = AzureOpenAITextEmbeddingModel()
 
     with pytest.raises(CredentialsValidateFailedError):
         model.validate_credentials(
-            model='embedding',
+            model="embedding",
             credentials={
-                'openai_api_base': os.environ.get('AZURE_OPENAI_API_BASE'),
-                'openai_api_key': 'invalid_key',
-                'base_model_name': 'text-embedding-ada-002'
-            }
+                "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"),
+                "openai_api_key": "invalid_key",
+                "base_model_name": "text-embedding-ada-002",
+            },
         )
 
     model.validate_credentials(
-        model='embedding',
+        model="embedding",
         credentials={
-            'openai_api_base': os.environ.get('AZURE_OPENAI_API_BASE'),
-            'openai_api_key': os.environ.get('AZURE_OPENAI_API_KEY'),
-            'base_model_name': 'text-embedding-ada-002'
-        }
+            "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"),
+            "openai_api_key": os.environ.get("AZURE_OPENAI_API_KEY"),
+            "base_model_name": "text-embedding-ada-002",
+        },
     )
 
-@pytest.mark.parametrize('setup_openai_mock', [['text_embedding']], indirect=True)
+
+@pytest.mark.parametrize("setup_openai_mock", [["text_embedding"]], indirect=True)
 def test_invoke_model(setup_openai_mock):
     model = AzureOpenAITextEmbeddingModel()
 
     result = model.invoke(
-        model='embedding',
+        model="embedding",
         credentials={
-            'openai_api_base': os.environ.get('AZURE_OPENAI_API_BASE'),
-            'openai_api_key': os.environ.get('AZURE_OPENAI_API_KEY'),
-            'base_model_name': 'text-embedding-ada-002'
+            "openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"),
+            "openai_api_key": os.environ.get("AZURE_OPENAI_API_KEY"),
+            "base_model_name": "text-embedding-ada-002",
         },
-        texts=[
-            "hello",
-            "world"
-        ],
-        user="abc-123"
+        texts=["hello", "world"],
+        user="abc-123",
    )
 
    assert isinstance(result, TextEmbeddingResult)
@@ -58,14 +56,7 @@ def test_get_num_tokens():
     model = AzureOpenAITextEmbeddingModel()
 
     num_tokens = model.get_num_tokens(
-        model='embedding',
-        credentials={
-            'base_model_name': 'text-embedding-ada-002'
-        },
-        texts=[
-            "hello",
-            "world"
-        ]
+        model="embedding", credentials={"base_model_name": "text-embedding-ada-002"}, texts=["hello", "world"]
     )
 
     assert num_tokens == 2
@@ -17,111 +17,99 @@ def test_predefined_models():
     assert len(model_schemas) >= 1
     assert isinstance(model_schemas[0], AIModelEntity)
 
 
 def test_validate_credentials_for_chat_model():
     sleep(3)
     model = BaichuanLarguageModel()
 
     with pytest.raises(CredentialsValidateFailedError):
         model.validate_credentials(
-            model='baichuan2-turbo',
-            credentials={
-                'api_key': 'invalid_key',
-                'secret_key': 'invalid_key'
-            }
+            model="baichuan2-turbo", credentials={"api_key": "invalid_key", "secret_key": "invalid_key"}
         )
 
     model.validate_credentials(
-        model='baichuan2-turbo',
+        model="baichuan2-turbo",
         credentials={
-            'api_key': os.environ.get('BAICHUAN_API_KEY'),
-            'secret_key': os.environ.get('BAICHUAN_SECRET_KEY')
-        }
+            "api_key": os.environ.get("BAICHUAN_API_KEY"),
+            "secret_key": os.environ.get("BAICHUAN_SECRET_KEY"),
+        },
     )
 
 
 def test_invoke_model():
     sleep(3)
     model = BaichuanLarguageModel()
 
     response = model.invoke(
-        model='baichuan2-turbo',
+        model="baichuan2-turbo",
         credentials={
-            'api_key': os.environ.get('BAICHUAN_API_KEY'),
-            'secret_key': os.environ.get('BAICHUAN_SECRET_KEY')
+            "api_key": os.environ.get("BAICHUAN_API_KEY"),
+            "secret_key": os.environ.get("BAICHUAN_SECRET_KEY"),
         },
-        prompt_messages=[
-            UserPromptMessage(
-                content='Hello World!'
-            )
-        ],
+        prompt_messages=[UserPromptMessage(content="Hello World!")],
         model_parameters={
-            'temperature': 0.7,
-            'top_p': 1.0,
-            'top_k': 1,
+            "temperature": 0.7,
+            "top_p": 1.0,
+            "top_k": 1,
         },
-        stop=['you'],
+        stop=["you"],
         user="abc-123",
-        stream=False
+        stream=False,
     )
 
     assert isinstance(response, LLMResult)
     assert len(response.message.content) > 0
     assert response.usage.total_tokens > 0
 
 
 def test_invoke_model_with_system_message():
     sleep(3)
     model = BaichuanLarguageModel()
 
     response = model.invoke(
-        model='baichuan2-turbo',
+        model="baichuan2-turbo",
         credentials={
-            'api_key': os.environ.get('BAICHUAN_API_KEY'),
-            'secret_key': os.environ.get('BAICHUAN_SECRET_KEY')
+            "api_key": os.environ.get("BAICHUAN_API_KEY"),
+            "secret_key": os.environ.get("BAICHUAN_SECRET_KEY"),
        },
         prompt_messages=[
-            SystemPromptMessage(
-                content='请记住你是Kasumi。'
-            ),
-            UserPromptMessage(
-                content='现在告诉我你是谁?'
-            )
+            SystemPromptMessage(content="请记住你是Kasumi。"),
+            UserPromptMessage(content="现在告诉我你是谁?"),
         ],
         model_parameters={
-            'temperature': 0.7,
-            'top_p': 1.0,
-            'top_k': 1,
+            "temperature": 0.7,
+            "top_p": 1.0,
+            "top_k": 1,
         },
-        stop=['you'],
+        stop=["you"],
         user="abc-123",
-        stream=False
+        stream=False,
     )
 
     assert isinstance(response, LLMResult)
     assert len(response.message.content) > 0
     assert response.usage.total_tokens > 0
 
 
 def test_invoke_stream_model():
     sleep(3)
     model = BaichuanLarguageModel()
 
     response = model.invoke(
-        model='baichuan2-turbo',
+        model="baichuan2-turbo",
         credentials={
-            'api_key': os.environ.get('BAICHUAN_API_KEY'),
-            'secret_key': os.environ.get('BAICHUAN_SECRET_KEY')
+            "api_key": os.environ.get("BAICHUAN_API_KEY"),
+            "secret_key": os.environ.get("BAICHUAN_SECRET_KEY"),
        },
-        prompt_messages=[
-            UserPromptMessage(
-                content='Hello World!'
-            )
-        ],
+        prompt_messages=[UserPromptMessage(content="Hello World!")],
         model_parameters={
-            'temperature': 0.7,
-            'top_p': 1.0,
-            'top_k': 1,
+            "temperature": 0.7,
+            "top_p": 1.0,
+            "top_k": 1,
        },
-        stop=['you'],
+        stop=["you"],
        stream=True,
-        user="abc-123"
+        user="abc-123",
    )
 
    assert isinstance(response, Generator)
@@ -131,34 +119,31 @@ def test_invoke_stream_model():
         assert isinstance(chunk.delta.message, AssistantPromptMessage)
         assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
 
 
 def test_invoke_with_search():
     sleep(3)
     model = BaichuanLarguageModel()
 
     response = model.invoke(
-        model='baichuan2-turbo',
+        model="baichuan2-turbo",
         credentials={
-            'api_key': os.environ.get('BAICHUAN_API_KEY'),
-            'secret_key': os.environ.get('BAICHUAN_SECRET_KEY')
+            "api_key": os.environ.get("BAICHUAN_API_KEY"),
+            "secret_key": os.environ.get("BAICHUAN_SECRET_KEY"),
         },
-        prompt_messages=[
-            UserPromptMessage(
-                content='北京今天的天气怎么样'
-            )
-        ],
+        prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")],
         model_parameters={
-            'temperature': 0.7,
-            'top_p': 1.0,
-            'top_k': 1,
-            'with_search_enhance': True,
+            "temperature": 0.7,
+            "top_p": 1.0,
+            "top_k": 1,
+            "with_search_enhance": True,
         },
-        stop=['you'],
+        stop=["you"],
         stream=True,
-        user="abc-123"
+        user="abc-123",
     )
 
     assert isinstance(response, Generator)
-    total_message = ''
+    total_message = ""
     for chunk in response:
         assert isinstance(chunk, LLMResultChunk)
         assert isinstance(chunk.delta, LLMResultChunkDelta)
|
||||
assert len(chunk.delta.message.content) > 0 if not chunk.delta.finish_reason else True
|
||||
total_message += chunk.delta.message.content
|
||||
|
||||
assert '不' not in total_message
|
||||
assert "不" not in total_message
|
||||
|
||||
|
||||
def test_get_num_tokens():
|
||||
sleep(3)
|
||||
model = BaichuanLarguageModel()
|
||||
|
||||
response = model.get_num_tokens(
|
||||
model='baichuan2-turbo',
|
||||
model="baichuan2-turbo",
|
||||
credentials={
|
||||
'api_key': os.environ.get('BAICHUAN_API_KEY'),
|
||||
'secret_key': os.environ.get('BAICHUAN_SECRET_KEY')
|
||||
"api_key": os.environ.get("BAICHUAN_API_KEY"),
|
||||
"secret_key": os.environ.get("BAICHUAN_SECRET_KEY"),
|
||||
},
|
||||
prompt_messages=[
|
||||
UserPromptMessage(
|
||||
content='Hello World!'
|
||||
)
|
||||
],
|
||||
tools=[]
|
||||
prompt_messages=[UserPromptMessage(content="Hello World!")],
|
||||
tools=[],
|
||||
)
|
||||
|
||||
assert isinstance(response, int)
|
||||
assert response == 9
|
||||
assert response == 9
|
||||
|
@ -10,14 +10,6 @@ def test_validate_provider_credentials():
    provider = BaichuanProvider()

    with pytest.raises(CredentialsValidateFailedError):
        provider.validate_provider_credentials(
            credentials={
                'api_key': 'hahahaha'
            }
        )
        provider.validate_provider_credentials(credentials={"api_key": "hahahaha"})

    provider.validate_provider_credentials(
        credentials={
            'api_key': os.environ.get('BAICHUAN_API_KEY')
        }
    )
    provider.validate_provider_credentials(credentials={"api_key": os.environ.get("BAICHUAN_API_KEY")})

@ -11,18 +11,10 @@ def test_validate_credentials():
    model = BaichuanTextEmbeddingModel()

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model='baichuan-text-embedding',
            credentials={
                'api_key': 'invalid_key'
            }
        )
        model.validate_credentials(model="baichuan-text-embedding", credentials={"api_key": "invalid_key"})

    model.validate_credentials(
        model='baichuan-text-embedding',
        credentials={
            'api_key': os.environ.get('BAICHUAN_API_KEY')
        }
        model="baichuan-text-embedding", credentials={"api_key": os.environ.get("BAICHUAN_API_KEY")}
    )


@ -30,44 +22,40 @@ def test_invoke_model():
    model = BaichuanTextEmbeddingModel()

    result = model.invoke(
        model='baichuan-text-embedding',
        model="baichuan-text-embedding",
        credentials={
            'api_key': os.environ.get('BAICHUAN_API_KEY'),
            "api_key": os.environ.get("BAICHUAN_API_KEY"),
        },
        texts=[
            "hello",
            "world"
        ],
        user="abc-123"
        texts=["hello", "world"],
        user="abc-123",
    )

    assert isinstance(result, TextEmbeddingResult)
    assert len(result.embeddings) == 2
    assert result.usage.total_tokens == 6


def test_get_num_tokens():
    model = BaichuanTextEmbeddingModel()

    num_tokens = model.get_num_tokens(
        model='baichuan-text-embedding',
        model="baichuan-text-embedding",
        credentials={
            'api_key': os.environ.get('BAICHUAN_API_KEY'),
            "api_key": os.environ.get("BAICHUAN_API_KEY"),
        },
        texts=[
            "hello",
            "world"
        ]
        texts=["hello", "world"],
    )

    assert num_tokens == 2


def test_max_chunks():
    model = BaichuanTextEmbeddingModel()

    result = model.invoke(
        model='baichuan-text-embedding',
        model="baichuan-text-embedding",
        credentials={
            'api_key': os.environ.get('BAICHUAN_API_KEY'),
            "api_key": os.environ.get("BAICHUAN_API_KEY"),
        },
        texts=[
            "hello",
@ -92,8 +80,8 @@ def test_max_chunks():
            "world",
            "hello",
            "world",
        ]
        ],
    )

    assert isinstance(result, TextEmbeddingResult)
    assert len(result.embeddings) == 22
    assert len(result.embeddings) == 22

@ -13,77 +13,63 @@ def test_validate_credentials():
    model = BedrockLargeLanguageModel()

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model='meta.llama2-13b-chat-v1',
            credentials={
                'anthropic_api_key': 'invalid_key'
            }
        )
        model.validate_credentials(model="meta.llama2-13b-chat-v1", credentials={"anthropic_api_key": "invalid_key"})

    model.validate_credentials(
        model='meta.llama2-13b-chat-v1',
        model="meta.llama2-13b-chat-v1",
        credentials={
            "aws_region": os.getenv("AWS_REGION"),
            "aws_access_key": os.getenv("AWS_ACCESS_KEY"),
            "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY")
        }
            "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"),
        },
    )


def test_invoke_model():
    model = BedrockLargeLanguageModel()

    response = model.invoke(
        model='meta.llama2-13b-chat-v1',
        model="meta.llama2-13b-chat-v1",
        credentials={
            "aws_region": os.getenv("AWS_REGION"),
            "aws_access_key": os.getenv("AWS_ACCESS_KEY"),
            "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY")
            "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"),
        },
        prompt_messages=[
            SystemPromptMessage(
                content='You are a helpful AI assistant.',
                content="You are a helpful AI assistant.",
            ),
            UserPromptMessage(
                content='Hello World!'
            )
            UserPromptMessage(content="Hello World!"),
        ],
        model_parameters={
            'temperature': 0.0,
            'top_p': 1.0,
            'max_tokens_to_sample': 10
        },
        stop=['How'],
        model_parameters={"temperature": 0.0, "top_p": 1.0, "max_tokens_to_sample": 10},
        stop=["How"],
        stream=False,
        user="abc-123"
        user="abc-123",
    )

    assert isinstance(response, LLMResult)
    assert len(response.message.content) > 0


def test_invoke_stream_model():
    model = BedrockLargeLanguageModel()

    response = model.invoke(
        model='meta.llama2-13b-chat-v1',
        model="meta.llama2-13b-chat-v1",
        credentials={
            "aws_region": os.getenv("AWS_REGION"),
            "aws_access_key": os.getenv("AWS_ACCESS_KEY"),
            "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY")
            "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"),
        },
        prompt_messages=[
            SystemPromptMessage(
                content='You are a helpful AI assistant.',
                content="You are a helpful AI assistant.",
            ),
            UserPromptMessage(
                content='Hello World!'
            )
            UserPromptMessage(content="Hello World!"),
        ],
        model_parameters={
            'temperature': 0.0,
            'max_tokens_to_sample': 100
        },
        model_parameters={"temperature": 0.0, "max_tokens_to_sample": 100},
        stream=True,
        user="abc-123"
        user="abc-123",
    )

    assert isinstance(response, Generator)
@ -100,20 +86,18 @@ def test_get_num_tokens():
    model = BedrockLargeLanguageModel()

    num_tokens = model.get_num_tokens(
        model='meta.llama2-13b-chat-v1',
        credentials = {
        model="meta.llama2-13b-chat-v1",
        credentials={
            "aws_region": os.getenv("AWS_REGION"),
            "aws_access_key": os.getenv("AWS_ACCESS_KEY"),
            "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY")
            "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"),
        },
        messages=[
            SystemPromptMessage(
                content='You are a helpful AI assistant.',
                content="You are a helpful AI assistant.",
            ),
            UserPromptMessage(
                content='Hello World!'
            )
        ]
            UserPromptMessage(content="Hello World!"),
        ],
    )

    assert num_tokens == 18

@ -10,14 +10,12 @@ def test_validate_provider_credentials():
    provider = BedrockProvider()

    with pytest.raises(CredentialsValidateFailedError):
        provider.validate_provider_credentials(
            credentials={}
        )
        provider.validate_provider_credentials(credentials={})

    provider.validate_provider_credentials(
        credentials={
            "aws_region": os.getenv("AWS_REGION"),
            "aws_access_key": os.getenv("AWS_ACCESS_KEY"),
            "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY")
            "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"),
        }
    )

@ -23,79 +23,64 @@ def test_predefined_models():
    assert len(model_schemas) >= 1
    assert isinstance(model_schemas[0], AIModelEntity)

@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)

@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
def test_validate_credentials_for_chat_model(setup_openai_mock):
    model = ChatGLMLargeLanguageModel()

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model='chatglm2-6b',
            credentials={
                'api_base': 'invalid_key'
            }
        )
        model.validate_credentials(model="chatglm2-6b", credentials={"api_base": "invalid_key"})

    model.validate_credentials(
        model='chatglm2-6b',
        credentials={
            'api_base': os.environ.get('CHATGLM_API_BASE')
        }
    )
    model.validate_credentials(model="chatglm2-6b", credentials={"api_base": os.environ.get("CHATGLM_API_BASE")})

@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)

@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
def test_invoke_model(setup_openai_mock):
    model = ChatGLMLargeLanguageModel()

    response = model.invoke(
        model='chatglm2-6b',
        credentials={
            'api_base': os.environ.get('CHATGLM_API_BASE')
        },
        model="chatglm2-6b",
        credentials={"api_base": os.environ.get("CHATGLM_API_BASE")},
        prompt_messages=[
            SystemPromptMessage(
                content='You are a helpful AI assistant.',
                content="You are a helpful AI assistant.",
            ),
            UserPromptMessage(
                content='Hello World!'
            )
            UserPromptMessage(content="Hello World!"),
        ],
        model_parameters={
            'temperature': 0.7,
            'top_p': 1.0,
            "temperature": 0.7,
            "top_p": 1.0,
        },
        stop=['you'],
        stop=["you"],
        user="abc-123",
        stream=False
        stream=False,
    )

    assert isinstance(response, LLMResult)
    assert len(response.message.content) > 0
    assert response.usage.total_tokens > 0

@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)

@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
def test_invoke_stream_model(setup_openai_mock):
    model = ChatGLMLargeLanguageModel()

    response = model.invoke(
        model='chatglm2-6b',
        credentials={
            'api_base': os.environ.get('CHATGLM_API_BASE')
        },
        model="chatglm2-6b",
        credentials={"api_base": os.environ.get("CHATGLM_API_BASE")},
        prompt_messages=[
            SystemPromptMessage(
                content='You are a helpful AI assistant.',
                content="You are a helpful AI assistant.",
            ),
            UserPromptMessage(
                content='Hello World!'
            )
            UserPromptMessage(content="Hello World!"),
        ],
        model_parameters={
            'temperature': 0.7,
            'top_p': 1.0,
            "temperature": 0.7,
            "top_p": 1.0,
        },
        stop=['you'],
        stop=["you"],
        stream=True,
        user="abc-123"
        user="abc-123",
    )

    assert isinstance(response, Generator)
@ -105,56 +90,45 @@ def test_invoke_stream_model(setup_openai_mock):
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True

@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)

@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
def test_invoke_stream_model_with_functions(setup_openai_mock):
    model = ChatGLMLargeLanguageModel()

    response = model.invoke(
        model='chatglm3-6b',
        credentials={
            'api_base': os.environ.get('CHATGLM_API_BASE')
        },
        model="chatglm3-6b",
        credentials={"api_base": os.environ.get("CHATGLM_API_BASE")},
        prompt_messages=[
            SystemPromptMessage(
                content='你是一个天气机器人,你不知道今天的天气怎么样,你需要通过调用一个函数来获取天气信息。'
                content="你是一个天气机器人,你不知道今天的天气怎么样,你需要通过调用一个函数来获取天气信息。"
            ),
            UserPromptMessage(
                content='波士顿天气如何?'
            )
            UserPromptMessage(content="波士顿天气如何?"),
        ],
        model_parameters={
            'temperature': 0,
            'top_p': 1.0,
            "temperature": 0,
            "top_p": 1.0,
        },
        stop=['you'],
        user='abc-123',
        stop=["you"],
        user="abc-123",
        stream=True,
        tools=[
            PromptMessageTool(
                name='get_current_weather',
                description='Get the current weather in a given location',
                name="get_current_weather",
                description="Get the current weather in a given location",
                parameters={
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and state e.g. San Francisco, CA"
                        },
                        "unit": {
                            "type": "string",
                            "enum": ["celsius", "fahrenheit"]
                        }
                        "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                    },
                    "required": [
                        "location"
                    ]
                }
                    "required": ["location"],
                },
            )
        ]
        ],
    )

    assert isinstance(response, Generator)


    call: LLMResultChunk = None
    chunks = []

@ -170,122 +144,87 @@ def test_invoke_stream_model_with_functions(setup_openai_mock):
            break

    assert call is not None
    assert call.delta.message.tool_calls[0].function.name == 'get_current_weather'
    assert call.delta.message.tool_calls[0].function.name == "get_current_weather"

@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)

@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
def test_invoke_model_with_functions(setup_openai_mock):
    model = ChatGLMLargeLanguageModel()

    response = model.invoke(
        model='chatglm3-6b',
        credentials={
            'api_base': os.environ.get('CHATGLM_API_BASE')
        },
        prompt_messages=[
            UserPromptMessage(
                content='What is the weather like in San Francisco?'
            )
        ],
        model="chatglm3-6b",
        credentials={"api_base": os.environ.get("CHATGLM_API_BASE")},
        prompt_messages=[UserPromptMessage(content="What is the weather like in San Francisco?")],
        model_parameters={
            'temperature': 0.7,
            'top_p': 1.0,
            "temperature": 0.7,
            "top_p": 1.0,
        },
        stop=['you'],
        user='abc-123',
        stop=["you"],
        user="abc-123",
        stream=False,
        tools=[
            PromptMessageTool(
                name='get_current_weather',
                description='Get the current weather in a given location',
                name="get_current_weather",
                description="Get the current weather in a given location",
                parameters={
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and state e.g. San Francisco, CA"
                        },
                        "unit": {
                            "type": "string",
                            "enum": [
                                "c",
                                "f"
                            ]
                        }
                        "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
                        "unit": {"type": "string", "enum": ["c", "f"]},
                    },
                    "required": [
                        "location"
                    ]
                }
                    "required": ["location"],
                },
            )
        ]
        ],
    )

    assert isinstance(response, LLMResult)
    assert len(response.message.content) > 0
    assert response.usage.total_tokens > 0
    assert response.message.tool_calls[0].function.name == 'get_current_weather'
    assert response.message.tool_calls[0].function.name == "get_current_weather"


def test_get_num_tokens():
    model = ChatGLMLargeLanguageModel()

    num_tokens = model.get_num_tokens(
        model='chatglm2-6b',
        credentials={
            'api_base': os.environ.get('CHATGLM_API_BASE')
        },
        model="chatglm2-6b",
        credentials={"api_base": os.environ.get("CHATGLM_API_BASE")},
        prompt_messages=[
            SystemPromptMessage(
                content='You are a helpful AI assistant.',
                content="You are a helpful AI assistant.",
            ),
            UserPromptMessage(
                content='Hello World!'
            )
            UserPromptMessage(content="Hello World!"),
        ],
        tools=[
            PromptMessageTool(
                name='get_current_weather',
                description='Get the current weather in a given location',
                name="get_current_weather",
                description="Get the current weather in a given location",
                parameters={
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and state e.g. San Francisco, CA"
                        },
                        "unit": {
                            "type": "string",
                            "enum": [
                                "c",
                                "f"
                            ]
                        }
                        "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
                        "unit": {"type": "string", "enum": ["c", "f"]},
                    },
                    "required": [
                        "location"
                    ]
                }
                    "required": ["location"],
                },
            )
        ]
        ],
    )

    assert isinstance(num_tokens, int)
    assert num_tokens == 77

    num_tokens = model.get_num_tokens(
        model='chatglm2-6b',
        credentials={
            'api_base': os.environ.get('CHATGLM_API_BASE')
        },
        model="chatglm2-6b",
        credentials={"api_base": os.environ.get("CHATGLM_API_BASE")},
        prompt_messages=[
            SystemPromptMessage(
                content='You are a helpful AI assistant.',
                content="You are a helpful AI assistant.",
            ),
            UserPromptMessage(
                content='Hello World!'
            )
            UserPromptMessage(content="Hello World!"),
        ],
    )

    assert isinstance(num_tokens, int)
    assert num_tokens == 21
    assert num_tokens == 21

@ -7,19 +7,11 @@ from core.model_runtime.model_providers.chatglm.chatglm import ChatGLMProvider
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock


@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
def test_validate_provider_credentials(setup_openai_mock):
    provider = ChatGLMProvider()

    with pytest.raises(CredentialsValidateFailedError):
        provider.validate_provider_credentials(
            credentials={
                'api_base': 'hahahaha'
            }
        )
        provider.validate_provider_credentials(credentials={"api_base": "hahahaha"})

    provider.validate_provider_credentials(
        credentials={
            'api_base': os.environ.get('CHATGLM_API_BASE')
        }
    )
    provider.validate_provider_credentials(credentials={"api_base": os.environ.get("CHATGLM_API_BASE")})

@ -13,87 +13,49 @@ def test_validate_credentials_for_chat_model():
    model = CohereLargeLanguageModel()

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model='command-light-chat',
            credentials={
                'api_key': 'invalid_key'
            }
        )
        model.validate_credentials(model="command-light-chat", credentials={"api_key": "invalid_key"})

    model.validate_credentials(
        model='command-light-chat',
        credentials={
            'api_key': os.environ.get('COHERE_API_KEY')
        }
    )
    model.validate_credentials(model="command-light-chat", credentials={"api_key": os.environ.get("COHERE_API_KEY")})


def test_validate_credentials_for_completion_model():
    model = CohereLargeLanguageModel()

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model='command-light',
            credentials={
                'api_key': 'invalid_key'
            }
        )
        model.validate_credentials(model="command-light", credentials={"api_key": "invalid_key"})

    model.validate_credentials(
        model='command-light',
        credentials={
            'api_key': os.environ.get('COHERE_API_KEY')
        }
    )
    model.validate_credentials(model="command-light", credentials={"api_key": os.environ.get("COHERE_API_KEY")})


def test_invoke_completion_model():
    model = CohereLargeLanguageModel()

    credentials = {
        'api_key': os.environ.get('COHERE_API_KEY')
    }
    credentials = {"api_key": os.environ.get("COHERE_API_KEY")}

    result = model.invoke(
        model='command-light',
        model="command-light",
        credentials=credentials,
        prompt_messages=[
            UserPromptMessage(
                content='Hello World!'
            )
        ],
        model_parameters={
            'temperature': 0.0,
            'max_tokens': 1
        },
        prompt_messages=[UserPromptMessage(content="Hello World!")],
        model_parameters={"temperature": 0.0, "max_tokens": 1},
        stream=False,
        user="abc-123"
        user="abc-123",
    )

    assert isinstance(result, LLMResult)
    assert len(result.message.content) > 0
    assert model._num_tokens_from_string('command-light', credentials, result.message.content) == 1
    assert model._num_tokens_from_string("command-light", credentials, result.message.content) == 1


def test_invoke_stream_completion_model():
    model = CohereLargeLanguageModel()

    result = model.invoke(
        model='command-light',
        credentials={
            'api_key': os.environ.get('COHERE_API_KEY')
        },
        prompt_messages=[
            UserPromptMessage(
                content='Hello World!'
            )
        ],
        model_parameters={
            'temperature': 0.0,
            'max_tokens': 100
        },
        model="command-light",
        credentials={"api_key": os.environ.get("COHERE_API_KEY")},
        prompt_messages=[UserPromptMessage(content="Hello World!")],
        model_parameters={"temperature": 0.0, "max_tokens": 100},
        stream=True,
        user="abc-123"
        user="abc-123",
    )

    assert isinstance(result, Generator)
@ -109,28 +71,24 @@ def test_invoke_chat_model():
    model = CohereLargeLanguageModel()

    result = model.invoke(
        model='command-light-chat',
        credentials={
            'api_key': os.environ.get('COHERE_API_KEY')
        },
        model="command-light-chat",
        credentials={"api_key": os.environ.get("COHERE_API_KEY")},
        prompt_messages=[
            SystemPromptMessage(
                content='You are a helpful AI assistant.',
                content="You are a helpful AI assistant.",
            ),
            UserPromptMessage(
                content='Hello World!'
            )
            UserPromptMessage(content="Hello World!"),
        ],
        model_parameters={
            'temperature': 0.0,
            'p': 0.99,
            'presence_penalty': 0.0,
            'frequency_penalty': 0.0,
            'max_tokens': 10
            "temperature": 0.0,
            "p": 0.99,
            "presence_penalty": 0.0,
            "frequency_penalty": 0.0,
            "max_tokens": 10,
        },
        stop=['How'],
        stop=["How"],
        stream=False,
        user="abc-123"
        user="abc-123",
    )

    assert isinstance(result, LLMResult)
@ -141,24 +99,17 @@ def test_invoke_stream_chat_model():
    model = CohereLargeLanguageModel()

    result = model.invoke(
        model='command-light-chat',
        credentials={
            'api_key': os.environ.get('COHERE_API_KEY')
        },
        model="command-light-chat",
        credentials={"api_key": os.environ.get("COHERE_API_KEY")},
        prompt_messages=[
            SystemPromptMessage(
                content='You are a helpful AI assistant.',
                content="You are a helpful AI assistant.",
            ),
            UserPromptMessage(
                content='Hello World!'
            )
            UserPromptMessage(content="Hello World!"),
        ],
        model_parameters={
            'temperature': 0.0,
            'max_tokens': 100
        },
        model_parameters={"temperature": 0.0, "max_tokens": 100},
        stream=True,
        user="abc-123"
        user="abc-123",
    )

    assert isinstance(result, Generator)
@ -177,32 +128,22 @@ def test_get_num_tokens():
    model = CohereLargeLanguageModel()

    num_tokens = model.get_num_tokens(
        model='command-light',
        credentials={
            'api_key': os.environ.get('COHERE_API_KEY')
        },
        prompt_messages=[
            UserPromptMessage(
                content='Hello World!'
            )
        ]
        model="command-light",
        credentials={"api_key": os.environ.get("COHERE_API_KEY")},
        prompt_messages=[UserPromptMessage(content="Hello World!")],
    )

    assert num_tokens == 3

    num_tokens = model.get_num_tokens(
        model='command-light-chat',
        credentials={
            'api_key': os.environ.get('COHERE_API_KEY')
        },
        model="command-light-chat",
        credentials={"api_key": os.environ.get("COHERE_API_KEY")},
        prompt_messages=[
            SystemPromptMessage(
                content='You are a helpful AI assistant.',
                content="You are a helpful AI assistant.",
            ),
            UserPromptMessage(
                content='Hello World!'
            )
        ]
            UserPromptMessage(content="Hello World!"),
        ],
    )

    assert num_tokens == 15
@ -213,25 +154,17 @@ def test_fine_tuned_model():

    # test invoke
    result = model.invoke(
        model='85ec47be-6139-4f75-a4be-0f0ec1ef115c-ft',
        credentials={
            'api_key': os.environ.get('COHERE_API_KEY'),
            'mode': 'completion'
        },
        model="85ec47be-6139-4f75-a4be-0f0ec1ef115c-ft",
        credentials={"api_key": os.environ.get("COHERE_API_KEY"), "mode": "completion"},
        prompt_messages=[
            SystemPromptMessage(
                content='You are a helpful AI assistant.',
                content="You are a helpful AI assistant.",
            ),
            UserPromptMessage(
                content='Hello World!'
            )
            UserPromptMessage(content="Hello World!"),
        ],
        model_parameters={
            'temperature': 0.0,
            'max_tokens': 100
        },
        model_parameters={"temperature": 0.0, "max_tokens": 100},
        stream=False,
        user="abc-123"
        user="abc-123",
    )

    assert isinstance(result, LLMResult)
@ -242,25 +175,17 @@ def test_fine_tuned_chat_model():

    # test invoke
    result = model.invoke(
        model='94f2d55a-4c79-4c00-bde4-23962e74b170-ft',
        credentials={
            'api_key': os.environ.get('COHERE_API_KEY'),
            'mode': 'chat'
        },
        model="94f2d55a-4c79-4c00-bde4-23962e74b170-ft",
        credentials={"api_key": os.environ.get("COHERE_API_KEY"), "mode": "chat"},
        prompt_messages=[
            SystemPromptMessage(
                content='You are a helpful AI assistant.',
                content="You are a helpful AI assistant.",
            ),
            UserPromptMessage(
                content='Hello World!'
            )
            UserPromptMessage(content="Hello World!"),
        ],
        model_parameters={
            'temperature': 0.0,
            'max_tokens': 100
        },
        model_parameters={"temperature": 0.0, "max_tokens": 100},
        stream=False,
        user="abc-123"
        user="abc-123",
    )

    assert isinstance(result, LLMResult)

@ -10,12 +10,6 @@ def test_validate_provider_credentials():
    provider = CohereProvider()

    with pytest.raises(CredentialsValidateFailedError):
        provider.validate_provider_credentials(
            credentials={}
        )
        provider.validate_provider_credentials(credentials={})

    provider.validate_provider_credentials(
        credentials={
            'api_key': os.environ.get('COHERE_API_KEY')
        }
    )
    provider.validate_provider_credentials(credentials={"api_key": os.environ.get("COHERE_API_KEY")})

@ -11,29 +11,17 @@ def test_validate_credentials():
    model = CohereRerankModel()

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model='rerank-english-v2.0',
            credentials={
                'api_key': 'invalid_key'
            }
        )
        model.validate_credentials(model="rerank-english-v2.0", credentials={"api_key": "invalid_key"})

    model.validate_credentials(
        model='rerank-english-v2.0',
        credentials={
            'api_key': os.environ.get('COHERE_API_KEY')
        }
    )
    model.validate_credentials(model="rerank-english-v2.0", credentials={"api_key": os.environ.get("COHERE_API_KEY")})


def test_invoke_model():
    model = CohereRerankModel()

    result = model.invoke(
        model='rerank-english-v2.0',
        credentials={
            'api_key': os.environ.get('COHERE_API_KEY')
        },
        model="rerank-english-v2.0",
        credentials={"api_key": os.environ.get("COHERE_API_KEY")},
        query="What is the capital of the United States?",
        docs=[
            "Carson City is the capital city of the American state of Nevada. At the 2010 United States "
@ -41,9 +29,9 @@ def test_invoke_model():
            "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) "
            "is the capital of the United States. It is a federal district. The President of the USA and many major "
            "national government offices are in the territory. This makes it the political center of the United "
            "States of America."
            "States of America.",
        ],
        score_threshold=0.8
        score_threshold=0.8,
    )

    assert isinstance(result, RerankResult)

@ -11,18 +11,10 @@ def test_validate_credentials():
    model = CohereTextEmbeddingModel()

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model='embed-multilingual-v3.0',
            credentials={
                'api_key': 'invalid_key'
            }
        )
        model.validate_credentials(model="embed-multilingual-v3.0", credentials={"api_key": "invalid_key"})

    model.validate_credentials(
        model='embed-multilingual-v3.0',
        credentials={
            'api_key': os.environ.get('COHERE_API_KEY')
        }
        model="embed-multilingual-v3.0", credentials={"api_key": os.environ.get("COHERE_API_KEY")}
    )


@ -30,17 +22,10 @@ def test_invoke_model():
    model = CohereTextEmbeddingModel()

    result = model.invoke(
        model='embed-multilingual-v3.0',
        credentials={
            'api_key': os.environ.get('COHERE_API_KEY')
        },
        texts=[
            "hello",
            "world",
            " ".join(["long_text"] * 100),
            " ".join(["another_long_text"] * 100)
        ],
        user="abc-123"
        model="embed-multilingual-v3.0",
        credentials={"api_key": os.environ.get("COHERE_API_KEY")},
        texts=["hello", "world", " ".join(["long_text"] * 100), " ".join(["another_long_text"] * 100)],
        user="abc-123",
    )

    assert isinstance(result, TextEmbeddingResult)
@ -52,14 +37,9 @@ def test_get_num_tokens():
    model = CohereTextEmbeddingModel()

    num_tokens = model.get_num_tokens(
        model='embed-multilingual-v3.0',
        credentials={
            'api_key': os.environ.get('COHERE_API_KEY')
        },
        texts=[
            "hello",
            "world"
        ]
        model="embed-multilingual-v3.0",
        credentials={"api_key": os.environ.get("COHERE_API_KEY")},
        texts=["hello", "world"],
    )

    assert num_tokens == 3

File diff suppressed because one or more lines are too long

@ -7,17 +7,11 @@ from core.model_runtime.model_providers.google.google import GoogleProvider
from tests.integration_tests.model_runtime.__mock.google import setup_google_mock


@pytest.mark.parametrize('setup_google_mock', [['none']], indirect=True)
@pytest.mark.parametrize("setup_google_mock", [["none"]], indirect=True)
def test_validate_provider_credentials(setup_google_mock):
    provider = GoogleProvider()

    with pytest.raises(CredentialsValidateFailedError):
        provider.validate_provider_credentials(
            credentials={}
        )
        provider.validate_provider_credentials(credentials={})

    provider.validate_provider_credentials(
        credentials={
            'google_api_key': os.environ.get('GOOGLE_API_KEY')
        }
    )
    provider.validate_provider_credentials(credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")})

@ -10,87 +10,75 @@ from core.model_runtime.model_providers.huggingface_hub.llm.llm import Huggingfa
from tests.integration_tests.model_runtime.__mock.huggingface import setup_huggingface_mock


@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
def test_hosted_inference_api_validate_credentials(setup_huggingface_mock):
    model = HuggingfaceHubLargeLanguageModel()

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model='HuggingFaceH4/zephyr-7b-beta',
            credentials={
                'huggingfacehub_api_type': 'hosted_inference_api',
                'huggingfacehub_api_token': 'invalid_key'
            }
            model="HuggingFaceH4/zephyr-7b-beta",
            credentials={"huggingfacehub_api_type": "hosted_inference_api", "huggingfacehub_api_token": "invalid_key"},
        )

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model='fake-model',
            credentials={
                'huggingfacehub_api_type': 'hosted_inference_api',
                'huggingfacehub_api_token': 'invalid_key'
            }
            model="fake-model",
            credentials={"huggingfacehub_api_type": "hosted_inference_api", "huggingfacehub_api_token": "invalid_key"},
        )

    model.validate_credentials(
        model='HuggingFaceH4/zephyr-7b-beta',
        model="HuggingFaceH4/zephyr-7b-beta",
        credentials={
            'huggingfacehub_api_type': 'hosted_inference_api',
            'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY')
        }
            "huggingfacehub_api_type": "hosted_inference_api",
            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
        },
    )

@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)

@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
def test_hosted_inference_api_invoke_model(setup_huggingface_mock):
    model = HuggingfaceHubLargeLanguageModel()

    response = model.invoke(
        model='HuggingFaceH4/zephyr-7b-beta',
        model="HuggingFaceH4/zephyr-7b-beta",
        credentials={
            'huggingfacehub_api_type': 'hosted_inference_api',
            'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY')
            "huggingfacehub_api_type": "hosted_inference_api",
            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
        },
        prompt_messages=[
            UserPromptMessage(
                content='Who are you?'
            )
        ],
        prompt_messages=[UserPromptMessage(content="Who are you?")],
        model_parameters={
            'temperature': 1.0,
            'top_k': 2,
            'top_p': 0.5,
            "temperature": 1.0,
            "top_k": 2,
            "top_p": 0.5,
        },
        stop=['How'],
        stop=["How"],
        stream=False,
        user="abc-123"
        user="abc-123",
    )

    assert isinstance(response, LLMResult)
    assert len(response.message.content) > 0

@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)

@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
def test_hosted_inference_api_invoke_stream_model(setup_huggingface_mock):
    model = HuggingfaceHubLargeLanguageModel()

    response = model.invoke(
        model='HuggingFaceH4/zephyr-7b-beta',
        model="HuggingFaceH4/zephyr-7b-beta",
        credentials={
            'huggingfacehub_api_type': 'hosted_inference_api',
            'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY')
            "huggingfacehub_api_type": "hosted_inference_api",
            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
        },
        prompt_messages=[
            UserPromptMessage(
                content='Who are you?'
            )
        ],
        prompt_messages=[UserPromptMessage(content="Who are you?")],
        model_parameters={
            'temperature': 1.0,
            'top_k': 2,
            'top_p': 0.5,
            "temperature": 1.0,
            "top_k": 2,
            "top_p": 0.5,
        },
        stop=['How'],
        stop=["How"],
        stream=True,
        user="abc-123"
        user="abc-123",
    )

    assert isinstance(response, Generator)
@ -101,86 +89,81 @@ def test_hosted_inference_api_invoke_stream_model(setup_huggingface_mock):
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True

@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)

@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
def test_inference_endpoints_text_generation_validate_credentials(setup_huggingface_mock):
    model = HuggingfaceHubLargeLanguageModel()

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model='openchat/openchat_3.5',
            model="openchat/openchat_3.5",
            credentials={
                'huggingfacehub_api_type': 'inference_endpoints',
                'huggingfacehub_api_token': 'invalid_key',
                'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT_GEN_ENDPOINT_URL'),
                'task_type': 'text-generation'
            }
                "huggingfacehub_api_type": "inference_endpoints",
                "huggingfacehub_api_token": "invalid_key",
                "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT_GEN_ENDPOINT_URL"),
                "task_type": "text-generation",
            },
        )

    model.validate_credentials(
        model='openchat/openchat_3.5',
        model="openchat/openchat_3.5",
        credentials={
            'huggingfacehub_api_type': 'inference_endpoints',
            'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
            'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT_GEN_ENDPOINT_URL'),
            'task_type': 'text-generation'
        }
            "huggingfacehub_api_type": "inference_endpoints",
            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
            "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT_GEN_ENDPOINT_URL"),
            "task_type": "text-generation",
        },
    )

@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)

@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
def test_inference_endpoints_text_generation_invoke_model(setup_huggingface_mock):
    model = HuggingfaceHubLargeLanguageModel()

    response = model.invoke(
        model='openchat/openchat_3.5',
        model="openchat/openchat_3.5",
        credentials={
            'huggingfacehub_api_type': 'inference_endpoints',
            'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
            'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT_GEN_ENDPOINT_URL'),
            'task_type': 'text-generation'
            "huggingfacehub_api_type": "inference_endpoints",
            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
            "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT_GEN_ENDPOINT_URL"),
            "task_type": "text-generation",
        },
        prompt_messages=[
            UserPromptMessage(
                content='Who are you?'
            )
        ],
        prompt_messages=[UserPromptMessage(content="Who are you?")],
        model_parameters={
            'temperature': 1.0,
            'top_k': 2,
            'top_p': 0.5,
            "temperature": 1.0,
            "top_k": 2,
            "top_p": 0.5,
        },
        stop=['How'],
        stop=["How"],
        stream=False,
        user="abc-123"
        user="abc-123",
    )

    assert isinstance(response, LLMResult)
    assert len(response.message.content) > 0

@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)

@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
def test_inference_endpoints_text_generation_invoke_stream_model(setup_huggingface_mock):
    model = HuggingfaceHubLargeLanguageModel()

    response = model.invoke(
        model='openchat/openchat_3.5',
        model="openchat/openchat_3.5",
        credentials={
            'huggingfacehub_api_type': 'inference_endpoints',
            'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
            'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT_GEN_ENDPOINT_URL'),
            'task_type': 'text-generation'
            "huggingfacehub_api_type": "inference_endpoints",
            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
            "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT_GEN_ENDPOINT_URL"),
            "task_type": "text-generation",
        },
        prompt_messages=[
            UserPromptMessage(
                content='Who are you?'
            )
        ],
        prompt_messages=[UserPromptMessage(content="Who are you?")],
        model_parameters={
            'temperature': 1.0,
            'top_k': 2,
            'top_p': 0.5,
            "temperature": 1.0,
            "top_k": 2,
            "top_p": 0.5,
        },
        stop=['How'],
        stop=["How"],
        stream=True,
        user="abc-123"
        user="abc-123",
    )

    assert isinstance(response, Generator)
@ -191,86 +174,81 @@ def test_inference_endpoints_text_generation_invoke_stream_model(setup_huggingfa
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True

@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)

@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
def test_inference_endpoints_text2text_generation_validate_credentials(setup_huggingface_mock):
    model = HuggingfaceHubLargeLanguageModel()

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model='google/mt5-base',
            model="google/mt5-base",
            credentials={
                'huggingfacehub_api_type': 'inference_endpoints',
                'huggingfacehub_api_token': 'invalid_key',
                'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'),
                'task_type': 'text2text-generation'
            }
                "huggingfacehub_api_type": "inference_endpoints",
                "huggingfacehub_api_token": "invalid_key",
                "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL"),
                "task_type": "text2text-generation",
            },
        )

    model.validate_credentials(
        model='google/mt5-base',
        model="google/mt5-base",
        credentials={
            'huggingfacehub_api_type': 'inference_endpoints',
            'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
            'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'),
            'task_type': 'text2text-generation'
        }
            "huggingfacehub_api_type": "inference_endpoints",
            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
            "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL"),
            "task_type": "text2text-generation",
        },
    )

@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)

@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
def test_inference_endpoints_text2text_generation_invoke_model(setup_huggingface_mock):
    model = HuggingfaceHubLargeLanguageModel()

    response = model.invoke(
        model='google/mt5-base',
        model="google/mt5-base",
        credentials={
            'huggingfacehub_api_type': 'inference_endpoints',
            'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
            'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'),
            'task_type': 'text2text-generation'
            "huggingfacehub_api_type": "inference_endpoints",
            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
            "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL"),
            "task_type": "text2text-generation",
        },
        prompt_messages=[
            UserPromptMessage(
                content='Who are you?'
            )
        ],
        prompt_messages=[UserPromptMessage(content="Who are you?")],
        model_parameters={
            'temperature': 1.0,
            'top_k': 2,
            'top_p': 0.5,
            "temperature": 1.0,
            "top_k": 2,
            "top_p": 0.5,
        },
        stop=['How'],
        stop=["How"],
        stream=False,
        user="abc-123"
        user="abc-123",
    )

    assert isinstance(response, LLMResult)
    assert len(response.message.content) > 0

@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)

@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
def test_inference_endpoints_text2text_generation_invoke_stream_model(setup_huggingface_mock):
    model = HuggingfaceHubLargeLanguageModel()

    response = model.invoke(
        model='google/mt5-base',
        model="google/mt5-base",
        credentials={
            'huggingfacehub_api_type': 'inference_endpoints',
            'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
            'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'),
            'task_type': 'text2text-generation'
            "huggingfacehub_api_type": "inference_endpoints",
            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
            "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL"),
            "task_type": "text2text-generation",
        },
        prompt_messages=[
            UserPromptMessage(
                content='Who are you?'
            )
        ],
        prompt_messages=[UserPromptMessage(content="Who are you?")],
        model_parameters={
            'temperature': 1.0,
            'top_k': 2,
            'top_p': 0.5,
            "temperature": 1.0,
            "top_k": 2,
            "top_p": 0.5,
        },
        stop=['How'],
        stop=["How"],
        stream=True,
        user="abc-123"
        user="abc-123",
    )

    assert isinstance(response, Generator)
@ -286,18 +264,14 @@ def test_get_num_tokens():
    model = HuggingfaceHubLargeLanguageModel()

    num_tokens = model.get_num_tokens(
        model='google/mt5-base',
        model="google/mt5-base",
        credentials={
            'huggingfacehub_api_type': 'inference_endpoints',
            'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
            'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'),
            'task_type': 'text2text-generation'
            "huggingfacehub_api_type": "inference_endpoints",
            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
            "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL"),
            "task_type": "text2text-generation",
        },
        prompt_messages=[
            UserPromptMessage(
                content='Hello World!'
            )
        ]
        prompt_messages=[UserPromptMessage(content="Hello World!")],
    )

    assert num_tokens == 7

@ -14,19 +14,19 @@ def test_hosted_inference_api_validate_credentials():

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model='facebook/bart-base',
            model="facebook/bart-base",
            credentials={
                'huggingfacehub_api_type': 'hosted_inference_api',
                'huggingfacehub_api_token': 'invalid_key',
            }
                "huggingfacehub_api_type": "hosted_inference_api",
                "huggingfacehub_api_token": "invalid_key",
            },
        )

    model.validate_credentials(
        model='facebook/bart-base',
        model="facebook/bart-base",
        credentials={
            'huggingfacehub_api_type': 'hosted_inference_api',
            'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
        }
            "huggingfacehub_api_type": "hosted_inference_api",
            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
        },
    )


@ -34,15 +34,12 @@ def test_hosted_inference_api_invoke_model():
    model = HuggingfaceHubTextEmbeddingModel()

    result = model.invoke(
        model='facebook/bart-base',
        model="facebook/bart-base",
        credentials={
            'huggingfacehub_api_type': 'hosted_inference_api',
            'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
            "huggingfacehub_api_type": "hosted_inference_api",
            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
        },
        texts=[
            "hello",
            "world"
        ]
        texts=["hello", "world"],
    )

    assert isinstance(result, TextEmbeddingResult)
@ -55,25 +52,25 @@ def test_inference_endpoints_validate_credentials():

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model='all-MiniLM-L6-v2',
            model="all-MiniLM-L6-v2",
            credentials={
                'huggingfacehub_api_type': 'inference_endpoints',
                'huggingfacehub_api_token': 'invalid_key',
                'huggingface_namespace': 'Dify-AI',
                'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL'),
                'task_type': 'feature-extraction'
            }
                "huggingfacehub_api_type": "inference_endpoints",
                "huggingfacehub_api_token": "invalid_key",
                "huggingface_namespace": "Dify-AI",
                "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL"),
                "task_type": "feature-extraction",
            },
        )

    model.validate_credentials(
        model='all-MiniLM-L6-v2',
        model="all-MiniLM-L6-v2",
        credentials={
            'huggingfacehub_api_type': 'inference_endpoints',
            'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
            'huggingface_namespace': 'Dify-AI',
            'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL'),
            'task_type': 'feature-extraction'
        }
            "huggingfacehub_api_type": "inference_endpoints",
            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
            "huggingface_namespace": "Dify-AI",
            "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL"),
            "task_type": "feature-extraction",
        },
    )


@ -81,18 +78,15 @@ def test_inference_endpoints_invoke_model():
    model = HuggingfaceHubTextEmbeddingModel()

    result = model.invoke(
        model='all-MiniLM-L6-v2',
        model="all-MiniLM-L6-v2",
        credentials={
            'huggingfacehub_api_type': 'inference_endpoints',
            'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
            'huggingface_namespace': 'Dify-AI',
            'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL'),
            'task_type': 'feature-extraction'
            "huggingfacehub_api_type": "inference_endpoints",
            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
            "huggingface_namespace": "Dify-AI",
            "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL"),
            "task_type": "feature-extraction",
        },
        texts=[
            "hello",
            "world"
        ]
        texts=["hello", "world"],
    )

    assert isinstance(result, TextEmbeddingResult)
@ -104,18 +98,15 @@ def test_get_num_tokens():
    model = HuggingfaceHubTextEmbeddingModel()

    num_tokens = model.get_num_tokens(
        model='all-MiniLM-L6-v2',
        model="all-MiniLM-L6-v2",
        credentials={
            'huggingfacehub_api_type': 'inference_endpoints',
            'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
            'huggingface_namespace': 'Dify-AI',
            'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL'),
            'task_type': 'feature-extraction'
            "huggingfacehub_api_type": "inference_endpoints",
            "huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
            "huggingface_namespace": "Dify-AI",
            "huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL"),
            "task_type": "feature-extraction",
        },
        texts=[
            "hello",
            "world"
        ]
        texts=["hello", "world"],
    )

    assert num_tokens == 2

@ -10,61 +10,59 @@ from core.model_runtime.model_providers.huggingface_tei.text_embedding.text_embe
|
||||
)
|
||||
from tests.integration_tests.model_runtime.__mock.huggingface_tei import MockTEIClass
|
||||
|
||||
MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'
|
||||
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def setup_tei_mock(request, monkeypatch: pytest.MonkeyPatch):
|
||||
if MOCK:
|
||||
monkeypatch.setattr(TeiHelper, 'get_tei_extra_parameter', MockTEIClass.get_tei_extra_parameter)
|
||||
monkeypatch.setattr(TeiHelper, 'invoke_tokenize', MockTEIClass.invoke_tokenize)
|
||||
monkeypatch.setattr(TeiHelper, 'invoke_embeddings', MockTEIClass.invoke_embeddings)
|
||||
monkeypatch.setattr(TeiHelper, 'invoke_rerank', MockTEIClass.invoke_rerank)
|
||||
monkeypatch.setattr(TeiHelper, "get_tei_extra_parameter", MockTEIClass.get_tei_extra_parameter)
|
||||
monkeypatch.setattr(TeiHelper, "invoke_tokenize", MockTEIClass.invoke_tokenize)
|
||||
monkeypatch.setattr(TeiHelper, "invoke_embeddings", MockTEIClass.invoke_embeddings)
|
||||
monkeypatch.setattr(TeiHelper, "invoke_rerank", MockTEIClass.invoke_rerank)
|
||||
yield
|
||||
|
||||
if MOCK:
|
||||
monkeypatch.undo()
|
||||
|
||||
|
||||
@pytest.mark.parametrize('setup_tei_mock', [['none']], indirect=True)
|
||||
@pytest.mark.parametrize("setup_tei_mock", [["none"]], indirect=True)
|
||||
def test_validate_credentials(setup_tei_mock):
|
||||
model = HuggingfaceTeiTextEmbeddingModel()
|
||||
# model name is only used in mock
|
||||
model_name = 'embedding'
|
||||
model_name = "embedding"
|
||||
|
||||
if MOCK:
|
||||
# TEI Provider will check model type by API endpoint, at real server, the model type is correct.
|
||||
# So we dont need to check model type here. Only check in mock
|
||||
        with pytest.raises(CredentialsValidateFailedError):
            model.validate_credentials(
                model='reranker',
                model="reranker",
                credentials={
                    'server_url': os.environ.get('TEI_EMBEDDING_SERVER_URL', ""),
                }
                    "server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""),
                },
            )

    model.validate_credentials(
        model=model_name,
        credentials={
            'server_url': os.environ.get('TEI_EMBEDDING_SERVER_URL', ""),
        }
            "server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""),
        },
    )

@pytest.mark.parametrize('setup_tei_mock', [['none']], indirect=True)

@pytest.mark.parametrize("setup_tei_mock", [["none"]], indirect=True)
def test_invoke_model(setup_tei_mock):
    model = HuggingfaceTeiTextEmbeddingModel()
    model_name = 'embedding'
    model_name = "embedding"

    result = model.invoke(
        model=model_name,
        credentials={
            'server_url': os.environ.get('TEI_EMBEDDING_SERVER_URL', ""),
            "server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""),
        },
        texts=[
            "hello",
            "world"
        ],
        user="abc-123"
        texts=["hello", "world"],
        user="abc-123",
    )

    assert isinstance(result, TextEmbeddingResult)

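The fixture pattern above is worth spelling out: when MOCK_SWITCH=true, every TeiHelper entry point is monkeypatched to MockTEIClass before the test body runs, and monkeypatch.undo() restores the real implementations afterwards, so the same suite can run against either the mock or a live TEI server. A minimal standalone sketch of the same env-gated pattern (the helper names here are illustrative, not part of the commit):

import os

import pytest

MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"


class RealHelper:
    @staticmethod
    def invoke(payload):
        raise RuntimeError("would call a live server")


class MockHelper:
    @staticmethod
    def invoke(payload):
        # Deterministic stand-in for the network call.
        return {"mocked": True, "payload": payload}


@pytest.fixture
def setup_mock(monkeypatch: pytest.MonkeyPatch):
    # Patch only when MOCK_SWITCH=true, so the test can also hit a real server.
    if MOCK:
        monkeypatch.setattr(RealHelper, "invoke", MockHelper.invoke)
    yield
    if MOCK:
        monkeypatch.undo()
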
@ -11,63 +11,65 @@ from core.model_runtime.model_providers.huggingface_tei.rerank.rerank import (
from core.model_runtime.model_providers.huggingface_tei.text_embedding.text_embedding import TeiHelper
from tests.integration_tests.model_runtime.__mock.huggingface_tei import MockTEIClass

MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"


@pytest.fixture
def setup_tei_mock(request, monkeypatch: pytest.MonkeyPatch):
    if MOCK:
        monkeypatch.setattr(TeiHelper, 'get_tei_extra_parameter', MockTEIClass.get_tei_extra_parameter)
        monkeypatch.setattr(TeiHelper, 'invoke_tokenize', MockTEIClass.invoke_tokenize)
        monkeypatch.setattr(TeiHelper, 'invoke_embeddings', MockTEIClass.invoke_embeddings)
        monkeypatch.setattr(TeiHelper, 'invoke_rerank', MockTEIClass.invoke_rerank)
        monkeypatch.setattr(TeiHelper, "get_tei_extra_parameter", MockTEIClass.get_tei_extra_parameter)
        monkeypatch.setattr(TeiHelper, "invoke_tokenize", MockTEIClass.invoke_tokenize)
        monkeypatch.setattr(TeiHelper, "invoke_embeddings", MockTEIClass.invoke_embeddings)
        monkeypatch.setattr(TeiHelper, "invoke_rerank", MockTEIClass.invoke_rerank)
    yield

    if MOCK:
        monkeypatch.undo()

@pytest.mark.parametrize('setup_tei_mock', [['none']], indirect=True)

@pytest.mark.parametrize("setup_tei_mock", [["none"]], indirect=True)
def test_validate_credentials(setup_tei_mock):
    model = HuggingfaceTeiRerankModel()
    # model name is only used in mock
    model_name = 'reranker'
    model_name = "reranker"

    if MOCK:
        # The TEI provider checks the model type via the API endpoint; against a real server the type is always correct,
        # so we don't need to check the model type here. It is only checked in the mock.
        with pytest.raises(CredentialsValidateFailedError):
            model.validate_credentials(
                model='embedding',
                model="embedding",
                credentials={
                    'server_url': os.environ.get('TEI_RERANK_SERVER_URL'),
                }
                    "server_url": os.environ.get("TEI_RERANK_SERVER_URL"),
                },
            )

    model.validate_credentials(
        model=model_name,
        credentials={
            'server_url': os.environ.get('TEI_RERANK_SERVER_URL'),
        }
            "server_url": os.environ.get("TEI_RERANK_SERVER_URL"),
        },
    )

@pytest.mark.parametrize('setup_tei_mock', [['none']], indirect=True)

@pytest.mark.parametrize("setup_tei_mock", [["none"]], indirect=True)
def test_invoke_model(setup_tei_mock):
    model = HuggingfaceTeiRerankModel()
    # model name is only used in mock
    model_name = 'reranker'
    model_name = "reranker"

    result = model.invoke(
        model=model_name,
        credentials={
            'server_url': os.environ.get('TEI_RERANK_SERVER_URL'),
            "server_url": os.environ.get("TEI_RERANK_SERVER_URL"),
        },
        query="Who is Kasumi?",
        docs=[
            "Kasumi is a girl's name of Japanese origin meaning \"mist\".",
            'Kasumi is a girl\'s name of Japanese origin meaning "mist".',
            "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ",
            "and she leads a team named PopiParty."
            "and she leads a team named PopiParty.",
        ],
        score_threshold=0.8
        score_threshold=0.8,
    )

    assert isinstance(result, RerankResult)

@ -14,19 +14,15 @@ def test_validate_credentials():

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model='hunyuan-standard',
            credentials={
                'secret_id': 'invalid_key',
                'secret_key': 'invalid_key'
            }
            model="hunyuan-standard", credentials={"secret_id": "invalid_key", "secret_key": "invalid_key"}
        )

    model.validate_credentials(
        model='hunyuan-standard',
        model="hunyuan-standard",
        credentials={
            'secret_id': os.environ.get('HUNYUAN_SECRET_ID'),
            'secret_key': os.environ.get('HUNYUAN_SECRET_KEY')
        }
            "secret_id": os.environ.get("HUNYUAN_SECRET_ID"),
            "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"),
        },
    )


@ -34,23 +30,16 @@ def test_invoke_model():
    model = HunyuanLargeLanguageModel()

    response = model.invoke(
        model='hunyuan-standard',
        model="hunyuan-standard",
        credentials={
            'secret_id': os.environ.get('HUNYUAN_SECRET_ID'),
            'secret_key': os.environ.get('HUNYUAN_SECRET_KEY')
            "secret_id": os.environ.get("HUNYUAN_SECRET_ID"),
            "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"),
        },
        prompt_messages=[
            UserPromptMessage(
                content='Hi'
            )
        ],
        model_parameters={
            'temperature': 0.5,
            'max_tokens': 10
        },
        stop=['How'],
        prompt_messages=[UserPromptMessage(content="Hi")],
        model_parameters={"temperature": 0.5, "max_tokens": 10},
        stop=["How"],
        stream=False,
        user="abc-123"
        user="abc-123",
    )

    assert isinstance(response, LLMResult)
@ -61,23 +50,15 @@ def test_invoke_stream_model():
    model = HunyuanLargeLanguageModel()

    response = model.invoke(
        model='hunyuan-standard',
        model="hunyuan-standard",
        credentials={
            'secret_id': os.environ.get('HUNYUAN_SECRET_ID'),
            'secret_key': os.environ.get('HUNYUAN_SECRET_KEY')
        },
        prompt_messages=[
            UserPromptMessage(
                content='Hi'
            )
        ],
        model_parameters={
            'temperature': 0.5,
            'max_tokens': 100,
            'seed': 1234
            "secret_id": os.environ.get("HUNYUAN_SECRET_ID"),
            "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"),
        },
        prompt_messages=[UserPromptMessage(content="Hi")],
        model_parameters={"temperature": 0.5, "max_tokens": 100, "seed": 1234},
        stream=True,
        user="abc-123"
        user="abc-123",
    )

    assert isinstance(response, Generator)
@ -93,19 +74,17 @@ def test_get_num_tokens():
    model = HunyuanLargeLanguageModel()

    num_tokens = model.get_num_tokens(
        model='hunyuan-standard',
        model="hunyuan-standard",
        credentials={
            'secret_id': os.environ.get('HUNYUAN_SECRET_ID'),
            'secret_key': os.environ.get('HUNYUAN_SECRET_KEY')
            "secret_id": os.environ.get("HUNYUAN_SECRET_ID"),
            "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"),
        },
        prompt_messages=[
            SystemPromptMessage(
                content='You are a helpful AI assistant.',
                content="You are a helpful AI assistant.",
            ),
            UserPromptMessage(
                content='Hello World!'
            )
        ]
            UserPromptMessage(content="Hello World!"),
        ],
    )

    assert num_tokens == 14

@ -10,16 +10,11 @@ def test_validate_provider_credentials():
    provider = HunyuanProvider()

    with pytest.raises(CredentialsValidateFailedError):
        provider.validate_provider_credentials(
            credentials={
                'secret_id': 'invalid_key',
                'secret_key': 'invalid_key'
            }
        )
        provider.validate_provider_credentials(credentials={"secret_id": "invalid_key", "secret_key": "invalid_key"})

    provider.validate_provider_credentials(
        credentials={
            'secret_id': os.environ.get('HUNYUAN_SECRET_ID'),
            'secret_key': os.environ.get('HUNYUAN_SECRET_KEY')
            "secret_id": os.environ.get("HUNYUAN_SECRET_ID"),
            "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"),
        }
    )

@ -12,19 +12,15 @@ def test_validate_credentials():

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model='hunyuan-embedding',
            credentials={
                'secret_id': 'invalid_key',
                'secret_key': 'invalid_key'
            }
            model="hunyuan-embedding", credentials={"secret_id": "invalid_key", "secret_key": "invalid_key"}
        )

    model.validate_credentials(
        model='hunyuan-embedding',
        model="hunyuan-embedding",
        credentials={
            'secret_id': os.environ.get('HUNYUAN_SECRET_ID'),
            'secret_key': os.environ.get('HUNYUAN_SECRET_KEY')
        }
            "secret_id": os.environ.get("HUNYUAN_SECRET_ID"),
            "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"),
        },
    )


@ -32,47 +28,43 @@ def test_invoke_model():
    model = HunyuanTextEmbeddingModel()

    result = model.invoke(
        model='hunyuan-embedding',
        model="hunyuan-embedding",
        credentials={
            'secret_id': os.environ.get('HUNYUAN_SECRET_ID'),
            'secret_key': os.environ.get('HUNYUAN_SECRET_KEY')
            "secret_id": os.environ.get("HUNYUAN_SECRET_ID"),
            "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"),
        },
        texts=[
            "hello",
            "world"
        ],
        user="abc-123"
        texts=["hello", "world"],
        user="abc-123",
    )

    assert isinstance(result, TextEmbeddingResult)
    assert len(result.embeddings) == 2
    assert result.usage.total_tokens == 6


def test_get_num_tokens():
    model = HunyuanTextEmbeddingModel()

    num_tokens = model.get_num_tokens(
        model='hunyuan-embedding',
        model="hunyuan-embedding",
        credentials={
            'secret_id': os.environ.get('HUNYUAN_SECRET_ID'),
            'secret_key': os.environ.get('HUNYUAN_SECRET_KEY')
            "secret_id": os.environ.get("HUNYUAN_SECRET_ID"),
            "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"),
        },
        texts=[
            "hello",
            "world"
        ]
        texts=["hello", "world"],
    )

    assert num_tokens == 2


def test_max_chunks():
    model = HunyuanTextEmbeddingModel()

    result = model.invoke(
        model='hunyuan-embedding',
        model="hunyuan-embedding",
        credentials={
            'secret_id': os.environ.get('HUNYUAN_SECRET_ID'),
            'secret_key': os.environ.get('HUNYUAN_SECRET_KEY')
            "secret_id": os.environ.get("HUNYUAN_SECRET_ID"),
            "secret_key": os.environ.get("HUNYUAN_SECRET_KEY"),
        },
        texts=[
            "hello",
@ -97,8 +89,8 @@ def test_max_chunks():
            "world",
            "hello",
            "world",
        ]
        ],
    )

    assert isinstance(result, TextEmbeddingResult)
    assert len(result.embeddings) == 22
    assert len(result.embeddings) == 22

@ -10,14 +10,6 @@ def test_validate_provider_credentials():
    provider = JinaProvider()

    with pytest.raises(CredentialsValidateFailedError):
        provider.validate_provider_credentials(
            credentials={
                'api_key': 'hahahaha'
            }
        )
        provider.validate_provider_credentials(credentials={"api_key": "hahahaha"})

    provider.validate_provider_credentials(
        credentials={
            'api_key': os.environ.get('JINA_API_KEY')
        }
    )
    provider.validate_provider_credentials(credentials={"api_key": os.environ.get("JINA_API_KEY")})

@ -11,18 +11,10 @@ def test_validate_credentials():
    model = JinaTextEmbeddingModel()

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model='jina-embeddings-v2-base-en',
            credentials={
                'api_key': 'invalid_key'
            }
        )
        model.validate_credentials(model="jina-embeddings-v2-base-en", credentials={"api_key": "invalid_key"})

    model.validate_credentials(
        model='jina-embeddings-v2-base-en',
        credentials={
            'api_key': os.environ.get('JINA_API_KEY')
        }
        model="jina-embeddings-v2-base-en", credentials={"api_key": os.environ.get("JINA_API_KEY")}
    )


@ -30,15 +22,12 @@ def test_invoke_model():
    model = JinaTextEmbeddingModel()

    result = model.invoke(
        model='jina-embeddings-v2-base-en',
        model="jina-embeddings-v2-base-en",
        credentials={
            'api_key': os.environ.get('JINA_API_KEY'),
            "api_key": os.environ.get("JINA_API_KEY"),
        },
        texts=[
            "hello",
            "world"
        ],
        user="abc-123"
        texts=["hello", "world"],
        user="abc-123",
    )

    assert isinstance(result, TextEmbeddingResult)
@ -50,14 +39,11 @@ def test_get_num_tokens():
    model = JinaTextEmbeddingModel()

    num_tokens = model.get_num_tokens(
        model='jina-embeddings-v2-base-en',
        model="jina-embeddings-v2-base-en",
        credentials={
            'api_key': os.environ.get('JINA_API_KEY'),
            "api_key": os.environ.get("JINA_API_KEY"),
        },
        texts=[
            "hello",
            "world"
        ]
        texts=["hello", "world"],
    )

    assert num_tokens == 6

@ -1,4 +1,4 @@
"""
LocalAI Embedding Interface is temporarily unavailable because
we could not find a way to test it for now.
"""
LocalAI Embedding Interface is temporarily unavailable because
we could not find a way to test it for now.
"""

@ -21,99 +21,78 @@ def test_validate_credentials_for_chat_model():

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model='chinese-llama-2-7b',
            model="chinese-llama-2-7b",
            credentials={
                'server_url': 'hahahaha',
                'completion_type': 'completion',
            }
                "server_url": "hahahaha",
                "completion_type": "completion",
            },
        )

    model.validate_credentials(
        model='chinese-llama-2-7b',
        model="chinese-llama-2-7b",
        credentials={
            'server_url': os.environ.get('LOCALAI_SERVER_URL'),
            'completion_type': 'completion',
        }
            "server_url": os.environ.get("LOCALAI_SERVER_URL"),
            "completion_type": "completion",
        },
    )


def test_invoke_completion_model():
    model = LocalAILanguageModel()

    response = model.invoke(
        model='chinese-llama-2-7b',
        model="chinese-llama-2-7b",
        credentials={
            'server_url': os.environ.get('LOCALAI_SERVER_URL'),
            'completion_type': 'completion',
        },
        prompt_messages=[
            UserPromptMessage(
                content='ping'
            )
        ],
        model_parameters={
            'temperature': 0.7,
            'top_p': 1.0,
            'max_tokens': 10
            "server_url": os.environ.get("LOCALAI_SERVER_URL"),
            "completion_type": "completion",
        },
        prompt_messages=[UserPromptMessage(content="ping")],
        model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10},
        stop=[],
        user="abc-123",
        stream=False
        stream=False,
    )

    assert isinstance(response, LLMResult)
    assert len(response.message.content) > 0
    assert response.usage.total_tokens > 0


def test_invoke_chat_model():
    model = LocalAILanguageModel()

    response = model.invoke(
        model='chinese-llama-2-7b',
        model="chinese-llama-2-7b",
        credentials={
            'server_url': os.environ.get('LOCALAI_SERVER_URL'),
            'completion_type': 'chat_completion',
        },
        prompt_messages=[
            UserPromptMessage(
                content='ping'
            )
        ],
        model_parameters={
            'temperature': 0.7,
            'top_p': 1.0,
            'max_tokens': 10
            "server_url": os.environ.get("LOCALAI_SERVER_URL"),
            "completion_type": "chat_completion",
        },
        prompt_messages=[UserPromptMessage(content="ping")],
        model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10},
        stop=[],
        user="abc-123",
        stream=False
        stream=False,
    )

    assert isinstance(response, LLMResult)
    assert len(response.message.content) > 0
    assert response.usage.total_tokens > 0


def test_invoke_stream_completion_model():
    model = LocalAILanguageModel()

    response = model.invoke(
        model='chinese-llama-2-7b',
        model="chinese-llama-2-7b",
        credentials={
            'server_url': os.environ.get('LOCALAI_SERVER_URL'),
            'completion_type': 'completion',
            "server_url": os.environ.get("LOCALAI_SERVER_URL"),
            "completion_type": "completion",
        },
        prompt_messages=[
            UserPromptMessage(
                content='Hello World!'
            )
        ],
        model_parameters={
            'temperature': 0.7,
            'top_p': 1.0,
            'max_tokens': 10
        },
        stop=['you'],
        prompt_messages=[UserPromptMessage(content="Hello World!")],
        model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10},
        stop=["you"],
        stream=True,
        user="abc-123"
        user="abc-123",
    )

    assert isinstance(response, Generator)
@ -123,28 +102,21 @@ def test_invoke_stream_completion_model():
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True


def test_invoke_stream_chat_model():
    model = LocalAILanguageModel()

    response = model.invoke(
        model='chinese-llama-2-7b',
        model="chinese-llama-2-7b",
        credentials={
            'server_url': os.environ.get('LOCALAI_SERVER_URL'),
            'completion_type': 'chat_completion',
            "server_url": os.environ.get("LOCALAI_SERVER_URL"),
            "completion_type": "chat_completion",
        },
        prompt_messages=[
            UserPromptMessage(
                content='Hello World!'
            )
        ],
        model_parameters={
            'temperature': 0.7,
            'top_p': 1.0,
            'max_tokens': 10
        },
        stop=['you'],
        prompt_messages=[UserPromptMessage(content="Hello World!")],
        model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10},
        stop=["you"],
        stream=True,
        user="abc-123"
        user="abc-123",
    )

    assert isinstance(response, Generator)
@ -154,64 +126,48 @@ def test_invoke_stream_chat_model():
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True


def test_get_num_tokens():
    model = LocalAILanguageModel()

    num_tokens = model.get_num_tokens(
        model='????',
        model="????",
        credentials={
            'server_url': os.environ.get('LOCALAI_SERVER_URL'),
            'completion_type': 'chat_completion',
            "server_url": os.environ.get("LOCALAI_SERVER_URL"),
            "completion_type": "chat_completion",
        },
        prompt_messages=[
            SystemPromptMessage(
                content='You are a helpful AI assistant.',
                content="You are a helpful AI assistant.",
            ),
            UserPromptMessage(
                content='Hello World!'
            )
            UserPromptMessage(content="Hello World!"),
        ],
        tools=[
            PromptMessageTool(
                name='get_current_weather',
                description='Get the current weather in a given location',
                name="get_current_weather",
                description="Get the current weather in a given location",
                parameters={
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and state e.g. San Francisco, CA"
                        },
                        "unit": {
                            "type": "string",
                            "enum": [
                                "c",
                                "f"
                            ]
                        }
                        "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
                        "unit": {"type": "string", "enum": ["c", "f"]},
                    },
                    "required": [
                        "location"
                    ]
                }
                    "required": ["location"],
                },
            )
        ]
        ],
    )

    assert isinstance(num_tokens, int)
    assert num_tokens == 77

    num_tokens = model.get_num_tokens(
        model='????',
        model="????",
        credentials={
            'server_url': os.environ.get('LOCALAI_SERVER_URL'),
            'completion_type': 'chat_completion',
            "server_url": os.environ.get("LOCALAI_SERVER_URL"),
            "completion_type": "chat_completion",
        },
        prompt_messages=[
            UserPromptMessage(
                content='Hello World!'
            )
        ],
        prompt_messages=[UserPromptMessage(content="Hello World!")],
    )

    assert isinstance(num_tokens, int)

@ -12,30 +12,29 @@ def test_validate_credentials_for_chat_model():

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model='bge-reranker-v2-m3',
            model="bge-reranker-v2-m3",
            credentials={
                'server_url': 'hahahaha',
                'completion_type': 'completion',
            }
                "server_url": "hahahaha",
                "completion_type": "completion",
            },
        )

    model.validate_credentials(
        model='bge-reranker-base',
        model="bge-reranker-base",
        credentials={
            'server_url': os.environ.get('LOCALAI_SERVER_URL'),
            'completion_type': 'completion',
        }
            "server_url": os.environ.get("LOCALAI_SERVER_URL"),
            "completion_type": "completion",
        },
    )


def test_invoke_rerank_model():
    model = LocalaiRerankModel()

    response = model.invoke(
        model='bge-reranker-base',
        credentials={
            'server_url': os.environ.get('LOCALAI_SERVER_URL')
        },
        query='Organic skincare products for sensitive skin',
        model="bge-reranker-base",
        credentials={"server_url": os.environ.get("LOCALAI_SERVER_URL")},
        query="Organic skincare products for sensitive skin",
        docs=[
            "Eco-friendly kitchenware for modern homes",
            "Biodegradable cleaning supplies for eco-conscious consumers",
@ -45,43 +44,38 @@ def test_invoke_rerank_model():
            "Sustainable gardening tools and compost solutions",
            "Sensitive skin-friendly facial cleansers and toners",
            "Organic food wraps and storage solutions",
            "Yoga mats made from recycled materials"
            "Yoga mats made from recycled materials",
        ],
        top_n=3,
        score_threshold=0.75,
        user="abc-123"
        user="abc-123",
    )

    assert isinstance(response, RerankResult)
    assert len(response.docs) == 3


def test__invoke():
    model = LocalaiRerankModel()

    # Test case 1: Empty docs
    result = model._invoke(
        model='bge-reranker-base',
        credentials={
            'server_url': 'https://example.com',
            'api_key': '1234567890'
        },
        query='Organic skincare products for sensitive skin',
        model="bge-reranker-base",
        credentials={"server_url": "https://example.com", "api_key": "1234567890"},
        query="Organic skincare products for sensitive skin",
        docs=[],
        top_n=3,
        score_threshold=0.75,
        user="abc-123"
        user="abc-123",
    )
    assert isinstance(result, RerankResult)
    assert len(result.docs) == 0

    # Test case 2: Valid invocation
    result = model._invoke(
        model='bge-reranker-base',
        credentials={
            'server_url': 'https://example.com',
            'api_key': '1234567890'
        },
        query='Organic skincare products for sensitive skin',
        model="bge-reranker-base",
        credentials={"server_url": "https://example.com", "api_key": "1234567890"},
        query="Organic skincare products for sensitive skin",
        docs=[
            "Eco-friendly kitchenware for modern homes",
            "Biodegradable cleaning supplies for eco-conscious consumers",
@ -91,12 +85,12 @@ def test__invoke():
            "Sustainable gardening tools and compost solutions",
            "Sensitive skin-friendly facial cleansers and toners",
            "Organic food wraps and storage solutions",
            "Yoga mats made from recycled materials"
            "Yoga mats made from recycled materials",
        ],
        top_n=3,
        score_threshold=0.75,
        user="abc-123"
        user="abc-123",
    )
    assert isinstance(result, RerankResult)
    assert len(result.docs) == 3
    assert all(isinstance(doc, RerankDocument) for doc in result.docs)
    assert all(isinstance(doc, RerankDocument) for doc in result.docs)

@ -10,19 +10,9 @@ def test_validate_credentials():
    model = LocalAISpeech2text()

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model='whisper-1',
            credentials={
                'server_url': 'invalid_url'
            }
        )
        model.validate_credentials(model="whisper-1", credentials={"server_url": "invalid_url"})

    model.validate_credentials(
        model='whisper-1',
        credentials={
            'server_url': os.environ.get('LOCALAI_SERVER_URL')
        }
    )
    model.validate_credentials(model="whisper-1", credentials={"server_url": os.environ.get("LOCALAI_SERVER_URL")})


def test_invoke_model():
@ -32,23 +22,21 @@ def test_invoke_model():
    current_dir = os.path.dirname(os.path.abspath(__file__))

    # Get assets directory
    assets_dir = os.path.join(os.path.dirname(current_dir), 'assets')
    assets_dir = os.path.join(os.path.dirname(current_dir), "assets")

    # Construct the path to the audio file
    audio_file_path = os.path.join(assets_dir, 'audio.mp3')
    audio_file_path = os.path.join(assets_dir, "audio.mp3")

    # Open the file and get the file object
    with open(audio_file_path, 'rb') as audio_file:
    with open(audio_file_path, "rb") as audio_file:
        file = audio_file

    result = model.invoke(
        model='whisper-1',
        credentials={
            'server_url': os.environ.get('LOCALAI_SERVER_URL')
        },
        model="whisper-1",
        credentials={"server_url": os.environ.get("LOCALAI_SERVER_URL")},
        file=file,
        user="abc-123"
        user="abc-123",
    )

    assert isinstance(result, str)
    assert result == '1, 2, 3, 4, 5, 6, 7, 8, 9, 10'
    assert result == "1, 2, 3, 4, 5, 6, 7, 8, 9, 10"

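One detail in the test above is unchanged by this commit but easy to trip over: file is assigned inside the with block and then passed to model.invoke after the block has exited, so the handle is already closed by the time it is used. A safer sketch of the same flow (illustrative only, not the committed code) keeps the invocation inside the block:

with open(audio_file_path, "rb") as audio_file:
    # Invoke while the handle is still open; the with block closes it afterwards.
    result = model.invoke(
        model="whisper-1",
        credentials={"server_url": os.environ.get("LOCALAI_SERVER_URL")},
        file=audio_file,
        user="abc-123",
    )
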
@ -12,54 +12,47 @@ def test_validate_credentials():

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model='embo-01',
            credentials={
                'minimax_api_key': 'invalid_key',
                'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
            }
            model="embo-01",
            credentials={"minimax_api_key": "invalid_key", "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID")},
        )

    model.validate_credentials(
        model='embo-01',
        model="embo-01",
        credentials={
            'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
            'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
        }
            "minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
            "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
        },
    )


def test_invoke_model():
    model = MinimaxTextEmbeddingModel()

    result = model.invoke(
        model='embo-01',
        model="embo-01",
        credentials={
            'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
            'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
            "minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
            "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
        },
        texts=[
            "hello",
            "world"
        ],
        user="abc-123"
        texts=["hello", "world"],
        user="abc-123",
    )

    assert isinstance(result, TextEmbeddingResult)
    assert len(result.embeddings) == 2
    assert result.usage.total_tokens == 16


def test_get_num_tokens():
    model = MinimaxTextEmbeddingModel()

    num_tokens = model.get_num_tokens(
        model='embo-01',
        model="embo-01",
        credentials={
            'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
            'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
            "minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
            "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
        },
        texts=[
            "hello",
            "world"
        ]
        texts=["hello", "world"],
    )

    assert num_tokens == 2

@ -17,79 +17,70 @@ def test_predefined_models():
    assert len(model_schemas) >= 1
    assert isinstance(model_schemas[0], AIModelEntity)


def test_validate_credentials_for_chat_model():
    sleep(3)
    model = MinimaxLargeLanguageModel()

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model='abab5.5-chat',
            credentials={
                'minimax_api_key': 'invalid_key',
                'minimax_group_id': 'invalid_key'
            }
            model="abab5.5-chat", credentials={"minimax_api_key": "invalid_key", "minimax_group_id": "invalid_key"}
        )

    model.validate_credentials(
        model='abab5.5-chat',
        model="abab5.5-chat",
        credentials={
            'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
            'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
        }
            "minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
            "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
        },
    )


def test_invoke_model():
    sleep(3)
    model = MinimaxLargeLanguageModel()

    response = model.invoke(
        model='abab5-chat',
        model="abab5-chat",
        credentials={
            'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
            'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
            "minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
            "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
        },
        prompt_messages=[
            UserPromptMessage(
                content='Hello World!'
            )
        ],
        prompt_messages=[UserPromptMessage(content="Hello World!")],
        model_parameters={
            'temperature': 0.7,
            'top_p': 1.0,
            'top_k': 1,
            "temperature": 0.7,
            "top_p": 1.0,
            "top_k": 1,
        },
        stop=['you'],
        stop=["you"],
        user="abc-123",
        stream=False
        stream=False,
    )

    assert isinstance(response, LLMResult)
    assert len(response.message.content) > 0
    assert response.usage.total_tokens > 0


def test_invoke_stream_model():
    sleep(3)
    model = MinimaxLargeLanguageModel()

    response = model.invoke(
        model='abab5.5-chat',
        model="abab5.5-chat",
        credentials={
            'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
            'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
            "minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
            "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
        },
        prompt_messages=[
            UserPromptMessage(
                content='Hello World!'
            )
        ],
        prompt_messages=[UserPromptMessage(content="Hello World!")],
        model_parameters={
            'temperature': 0.7,
            'top_p': 1.0,
            'top_k': 1,
            "temperature": 0.7,
            "top_p": 1.0,
            "top_k": 1,
        },
        stop=['you'],
        stop=["you"],
        stream=True,
        user="abc-123"
        user="abc-123",
    )

    assert isinstance(response, Generator)
@ -99,34 +90,31 @@ def test_invoke_stream_model():
        assert isinstance(chunk.delta.message, AssistantPromptMessage)
        assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True


def test_invoke_with_search():
    sleep(3)
    model = MinimaxLargeLanguageModel()

    response = model.invoke(
        model='abab5.5-chat',
        model="abab5.5-chat",
        credentials={
            'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
            'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
            "minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
            "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
        },
        prompt_messages=[
            UserPromptMessage(
                content='北京今天的天气怎么样'
            )
        ],
        prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")],
        model_parameters={
            'temperature': 0.7,
            'top_p': 1.0,
            'top_k': 1,
            'plugin_web_search': True,
            "temperature": 0.7,
            "top_p": 1.0,
            "top_k": 1,
            "plugin_web_search": True,
        },
        stop=['you'],
        stop=["you"],
        stream=True,
        user="abc-123"
        user="abc-123",
    )

    assert isinstance(response, Generator)
    total_message = ''
    total_message = ""
    for chunk in response:
        assert isinstance(chunk, LLMResultChunk)
        assert isinstance(chunk.delta, LLMResultChunkDelta)
@ -134,25 +122,22 @@ def test_invoke_with_search():
            total_message += chunk.delta.message.content
        assert len(chunk.delta.message.content) > 0 if not chunk.delta.finish_reason else True

    assert '参考资料' in total_message
    assert "参考资料" in total_message


def test_get_num_tokens():
    sleep(3)
    model = MinimaxLargeLanguageModel()

    response = model.get_num_tokens(
        model='abab5.5-chat',
        model="abab5.5-chat",
        credentials={
            'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
            'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
            "minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
            "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
        },
        prompt_messages=[
            UserPromptMessage(
                content='Hello World!'
            )
        ],
        tools=[]
        prompt_messages=[UserPromptMessage(content="Hello World!")],
        tools=[],
    )

    assert isinstance(response, int)
    assert response == 30
    assert response == 30

@ -12,14 +12,14 @@ def test_validate_provider_credentials():
    with pytest.raises(CredentialsValidateFailedError):
        provider.validate_provider_credentials(
            credentials={
                'minimax_api_key': 'hahahaha',
                'minimax_group_id': '123',
                "minimax_api_key": "hahahaha",
                "minimax_group_id": "123",
            }
        )

    provider.validate_provider_credentials(
        credentials={
            'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
            'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID'),
            "minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
            "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
        }
    )

@ -19,19 +19,12 @@ def test_validate_credentials():

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model='meta-llama/llama-3-8b-instruct',
            credentials={
                'api_key': 'invalid_key',
                'mode': 'chat'
            }
            model="meta-llama/llama-3-8b-instruct", credentials={"api_key": "invalid_key", "mode": "chat"}
        )

    model.validate_credentials(
        model='meta-llama/llama-3-8b-instruct',
        credentials={
            'api_key': os.environ.get('NOVITA_API_KEY'),
            'mode': 'chat'
        }
        model="meta-llama/llama-3-8b-instruct",
        credentials={"api_key": os.environ.get("NOVITA_API_KEY"), "mode": "chat"},
    )


@ -39,27 +32,22 @@ def test_invoke_model():
    model = NovitaLargeLanguageModel()

    response = model.invoke(
        model='meta-llama/llama-3-8b-instruct',
        credentials={
            'api_key': os.environ.get('NOVITA_API_KEY'),
            'mode': 'completion'
        },
        model="meta-llama/llama-3-8b-instruct",
        credentials={"api_key": os.environ.get("NOVITA_API_KEY"), "mode": "completion"},
        prompt_messages=[
            SystemPromptMessage(
                content='You are a helpful AI assistant.',
                content="You are a helpful AI assistant.",
            ),
            UserPromptMessage(
                content='Who are you?'
            )
            UserPromptMessage(content="Who are you?"),
        ],
        model_parameters={
            'temperature': 1.0,
            'top_p': 0.5,
            'max_tokens': 10,
            "temperature": 1.0,
            "top_p": 0.5,
            "max_tokens": 10,
        },
        stop=['How'],
        stop=["How"],
        stream=False,
        user="novita"
        user="novita",
    )

    assert isinstance(response, LLMResult)
@ -70,27 +58,17 @@ def test_invoke_stream_model():
    model = NovitaLargeLanguageModel()

    response = model.invoke(
        model='meta-llama/llama-3-8b-instruct',
        credentials={
            'api_key': os.environ.get('NOVITA_API_KEY'),
            'mode': 'chat'
        },
        model="meta-llama/llama-3-8b-instruct",
        credentials={"api_key": os.environ.get("NOVITA_API_KEY"), "mode": "chat"},
        prompt_messages=[
            SystemPromptMessage(
                content='You are a helpful AI assistant.',
                content="You are a helpful AI assistant.",
            ),
            UserPromptMessage(
                content='Who are you?'
            )
            UserPromptMessage(content="Who are you?"),
        ],
        model_parameters={
            'temperature': 1.0,
            'top_k': 2,
            'top_p': 0.5,
            'max_tokens': 100
        },
        model_parameters={"temperature": 1.0, "top_k": 2, "top_p": 0.5, "max_tokens": 100},
        stream=True,
        user="novita"
        user="novita",
    )

    assert isinstance(response, Generator)
@ -105,18 +83,16 @@ def test_get_num_tokens():
    model = NovitaLargeLanguageModel()

    num_tokens = model.get_num_tokens(
        model='meta-llama/llama-3-8b-instruct',
        model="meta-llama/llama-3-8b-instruct",
        credentials={
            'api_key': os.environ.get('NOVITA_API_KEY'),
            "api_key": os.environ.get("NOVITA_API_KEY"),
        },
        prompt_messages=[
            SystemPromptMessage(
                content='You are a helpful AI assistant.',
                content="You are a helpful AI assistant.",
            ),
            UserPromptMessage(
                content='Hello World!'
            )
        ]
            UserPromptMessage(content="Hello World!"),
        ],
    )

    assert isinstance(num_tokens, int)

@ -10,12 +10,10 @@ def test_validate_provider_credentials():
    provider = NovitaProvider()

    with pytest.raises(CredentialsValidateFailedError):
        provider.validate_provider_credentials(
            credentials={}
        )
        provider.validate_provider_credentials(credentials={})

    provider.validate_provider_credentials(
        credentials={
            'api_key': os.environ.get('NOVITA_API_KEY'),
            "api_key": os.environ.get("NOVITA_API_KEY"),
        }
    )

File diff suppressed because one or more lines are too long
@ -12,21 +12,21 @@ def test_validate_credentials():

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model='mistral:text',
            model="mistral:text",
            credentials={
                'base_url': 'http://localhost:21434',
                'mode': 'chat',
                'context_size': 4096,
            }
                "base_url": "http://localhost:21434",
                "mode": "chat",
                "context_size": 4096,
            },
        )

    model.validate_credentials(
        model='mistral:text',
        model="mistral:text",
        credentials={
            'base_url': os.environ.get('OLLAMA_BASE_URL'),
            'mode': 'chat',
            'context_size': 4096,
        }
            "base_url": os.environ.get("OLLAMA_BASE_URL"),
            "mode": "chat",
            "context_size": 4096,
        },
    )


@ -34,17 +34,14 @@ def test_invoke_model():
    model = OllamaEmbeddingModel()

    result = model.invoke(
        model='mistral:text',
        model="mistral:text",
        credentials={
            'base_url': os.environ.get('OLLAMA_BASE_URL'),
            'mode': 'chat',
            'context_size': 4096,
            "base_url": os.environ.get("OLLAMA_BASE_URL"),
            "mode": "chat",
            "context_size": 4096,
        },
        texts=[
            "hello",
            "world"
        ],
        user="abc-123"
        texts=["hello", "world"],
        user="abc-123",
    )

    assert isinstance(result, TextEmbeddingResult)
@ -56,16 +53,13 @@ def test_get_num_tokens():
    model = OllamaEmbeddingModel()

    num_tokens = model.get_num_tokens(
        model='mistral:text',
        model="mistral:text",
        credentials={
            'base_url': os.environ.get('OLLAMA_BASE_URL'),
            'mode': 'chat',
            'context_size': 4096,
            "base_url": os.environ.get("OLLAMA_BASE_URL"),
            "mode": "chat",
            "context_size": 4096,
        },
        texts=[
            "hello",
            "world"
        ]
        texts=["hello", "world"],
    )

    assert num_tokens == 2

File diff suppressed because one or more lines are too long
@ -7,48 +7,37 @@ from core.model_runtime.model_providers.openai.moderation.moderation import Open
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock


@pytest.mark.parametrize('setup_openai_mock', [['moderation']], indirect=True)
@pytest.mark.parametrize("setup_openai_mock", [["moderation"]], indirect=True)
def test_validate_credentials(setup_openai_mock):
    model = OpenAIModerationModel()

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model='text-moderation-stable',
            credentials={
                'openai_api_key': 'invalid_key'
            }
        )
        model.validate_credentials(model="text-moderation-stable", credentials={"openai_api_key": "invalid_key"})

    model.validate_credentials(
        model='text-moderation-stable',
        credentials={
            'openai_api_key': os.environ.get('OPENAI_API_KEY')
        }
        model="text-moderation-stable", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}
    )

@pytest.mark.parametrize('setup_openai_mock', [['moderation']], indirect=True)

@pytest.mark.parametrize("setup_openai_mock", [["moderation"]], indirect=True)
def test_invoke_model(setup_openai_mock):
    model = OpenAIModerationModel()

    result = model.invoke(
        model='text-moderation-stable',
        credentials={
            'openai_api_key': os.environ.get('OPENAI_API_KEY')
        },
        model="text-moderation-stable",
        credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
        text="hello",
        user="abc-123"
        user="abc-123",
    )

    assert isinstance(result, bool)
    assert result is False

    result = model.invoke(
        model='text-moderation-stable',
        credentials={
            'openai_api_key': os.environ.get('OPENAI_API_KEY')
        },
        model="text-moderation-stable",
        credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
        text="i will kill you",
        user="abc-123"
        user="abc-123",
    )

    assert isinstance(result, bool)

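The two invocations above differ only in the input text and the expected flag; a parametrized variant would express that more compactly (a sketch, not part of the commit; note that the original only asserts a bool for the second text, so the True expectation below is an assumption):

@pytest.mark.parametrize(
    ("text", "expected"),
    [
        ("hello", False),
        ("i will kill you", True),  # assumption: the original test only checks the type here
    ],
)
@pytest.mark.parametrize("setup_openai_mock", [["moderation"]], indirect=True)
def test_invoke_model_parametrized(setup_openai_mock, text, expected):
    model = OpenAIModerationModel()
    result = model.invoke(
        model="text-moderation-stable",
        credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
        text=text,
        user="abc-123",
    )
    assert result is expected
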
@ -7,17 +7,11 @@ from core.model_runtime.model_providers.openai.openai import OpenAIProvider
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock


@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
def test_validate_provider_credentials(setup_openai_mock):
    provider = OpenAIProvider()

    with pytest.raises(CredentialsValidateFailedError):
        provider.validate_provider_credentials(
            credentials={}
        )
        provider.validate_provider_credentials(credentials={})

    provider.validate_provider_credentials(
        credentials={
            'openai_api_key': os.environ.get('OPENAI_API_KEY')
        }
    )
    provider.validate_provider_credentials(credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")})

@ -7,26 +7,17 @@ from core.model_runtime.model_providers.openai.speech2text.speech2text import Op
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock


@pytest.mark.parametrize('setup_openai_mock', [['speech2text']], indirect=True)
@pytest.mark.parametrize("setup_openai_mock", [["speech2text"]], indirect=True)
def test_validate_credentials(setup_openai_mock):
    model = OpenAISpeech2TextModel()

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model='whisper-1',
            credentials={
                'openai_api_key': 'invalid_key'
            }
        )
        model.validate_credentials(model="whisper-1", credentials={"openai_api_key": "invalid_key"})

    model.validate_credentials(
        model='whisper-1',
        credentials={
            'openai_api_key': os.environ.get('OPENAI_API_KEY')
        }
    )
    model.validate_credentials(model="whisper-1", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")})

@pytest.mark.parametrize('setup_openai_mock', [['speech2text']], indirect=True)

@pytest.mark.parametrize("setup_openai_mock", [["speech2text"]], indirect=True)
def test_invoke_model(setup_openai_mock):
    model = OpenAISpeech2TextModel()

@ -34,23 +25,21 @@ def test_invoke_model(setup_openai_mock):
    current_dir = os.path.dirname(os.path.abspath(__file__))

    # Get assets directory
    assets_dir = os.path.join(os.path.dirname(current_dir), 'assets')
    assets_dir = os.path.join(os.path.dirname(current_dir), "assets")

    # Construct the path to the audio file
    audio_file_path = os.path.join(assets_dir, 'audio.mp3')
    audio_file_path = os.path.join(assets_dir, "audio.mp3")

    # Open the file and get the file object
    with open(audio_file_path, 'rb') as audio_file:
    with open(audio_file_path, "rb") as audio_file:
        file = audio_file

    result = model.invoke(
        model='whisper-1',
        credentials={
            'openai_api_key': os.environ.get('OPENAI_API_KEY')
        },
        model="whisper-1",
        credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
        file=file,
        user="abc-123"
        user="abc-123",
    )

    assert isinstance(result, str)
    assert result == '1, 2, 3, 4, 5, 6, 7, 8, 9, 10'
    assert result == "1, 2, 3, 4, 5, 6, 7, 8, 9, 10"

@ -8,42 +8,27 @@ from core.model_runtime.model_providers.openai.text_embedding.text_embedding imp
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock


@pytest.mark.parametrize('setup_openai_mock', [['text_embedding']], indirect=True)
@pytest.mark.parametrize("setup_openai_mock", [["text_embedding"]], indirect=True)
def test_validate_credentials(setup_openai_mock):
    model = OpenAITextEmbeddingModel()

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model='text-embedding-ada-002',
            credentials={
                'openai_api_key': 'invalid_key'
            }
        )
        model.validate_credentials(model="text-embedding-ada-002", credentials={"openai_api_key": "invalid_key"})

    model.validate_credentials(
        model='text-embedding-ada-002',
        credentials={
            'openai_api_key': os.environ.get('OPENAI_API_KEY')
        }
        model="text-embedding-ada-002", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}
    )

@pytest.mark.parametrize('setup_openai_mock', [['text_embedding']], indirect=True)

@pytest.mark.parametrize("setup_openai_mock", [["text_embedding"]], indirect=True)
def test_invoke_model(setup_openai_mock):
    model = OpenAITextEmbeddingModel()

    result = model.invoke(
        model='text-embedding-ada-002',
        credentials={
            'openai_api_key': os.environ.get('OPENAI_API_KEY'),
            'openai_api_base': 'https://api.openai.com'
        },
        texts=[
            "hello",
            "world",
            " ".join(["long_text"] * 100),
            " ".join(["another_long_text"] * 100)
        ],
        user="abc-123"
        model="text-embedding-ada-002",
        credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY"), "openai_api_base": "https://api.openai.com"},
        texts=["hello", "world", " ".join(["long_text"] * 100), " ".join(["another_long_text"] * 100)],
        user="abc-123",
    )

    assert isinstance(result, TextEmbeddingResult)
@ -55,15 +40,9 @@ def test_get_num_tokens():
    model = OpenAITextEmbeddingModel()

    num_tokens = model.get_num_tokens(
        model='text-embedding-ada-002',
        credentials={
            'openai_api_key': os.environ.get('OPENAI_API_KEY'),
            'openai_api_base': 'https://api.openai.com'
        },
        texts=[
            "hello",
            "world"
        ]
        model="text-embedding-ada-002",
        credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY"), "openai_api_base": "https://api.openai.com"},
        texts=["hello", "world"],
    )

    assert num_tokens == 2

@ -23,21 +23,17 @@ def test_validate_credentials():

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model='mistralai/Mixtral-8x7B-Instruct-v0.1',
            credentials={
                'api_key': 'invalid_key',
                'endpoint_url': 'https://api.together.xyz/v1/',
                'mode': 'chat'
            }
            model="mistralai/Mixtral-8x7B-Instruct-v0.1",
            credentials={"api_key": "invalid_key", "endpoint_url": "https://api.together.xyz/v1/", "mode": "chat"},
        )

    model.validate_credentials(
        model='mistralai/Mixtral-8x7B-Instruct-v0.1',
        model="mistralai/Mixtral-8x7B-Instruct-v0.1",
        credentials={
            'api_key': os.environ.get('TOGETHER_API_KEY'),
            'endpoint_url': 'https://api.together.xyz/v1/',
            'mode': 'chat'
        }
            "api_key": os.environ.get("TOGETHER_API_KEY"),
            "endpoint_url": "https://api.together.xyz/v1/",
            "mode": "chat",
        },
    )


@ -45,28 +41,26 @@ def test_invoke_model():
    model = OAIAPICompatLargeLanguageModel()

    response = model.invoke(
        model='mistralai/Mixtral-8x7B-Instruct-v0.1',
        model="mistralai/Mixtral-8x7B-Instruct-v0.1",
        credentials={
            'api_key': os.environ.get('TOGETHER_API_KEY'),
            'endpoint_url': 'https://api.together.xyz/v1/',
            'mode': 'completion'
            "api_key": os.environ.get("TOGETHER_API_KEY"),
            "endpoint_url": "https://api.together.xyz/v1/",
            "mode": "completion",
        },
        prompt_messages=[
            SystemPromptMessage(
                content='You are a helpful AI assistant.',
                content="You are a helpful AI assistant.",
            ),
            UserPromptMessage(
                content='Who are you?'
            )
            UserPromptMessage(content="Who are you?"),
        ],
        model_parameters={
            'temperature': 1.0,
            'top_k': 2,
            'top_p': 0.5,
            "temperature": 1.0,
            "top_k": 2,
            "top_p": 0.5,
        },
        stop=['How'],
        stop=["How"],
        stream=False,
        user="abc-123"
        user="abc-123",
    )

    assert isinstance(response, LLMResult)
@ -77,29 +71,27 @@ def test_invoke_stream_model():
    model = OAIAPICompatLargeLanguageModel()

    response = model.invoke(
        model='mistralai/Mixtral-8x7B-Instruct-v0.1',
        model="mistralai/Mixtral-8x7B-Instruct-v0.1",
        credentials={
            'api_key': os.environ.get('TOGETHER_API_KEY'),
            'endpoint_url': 'https://api.together.xyz/v1/',
            'mode': 'chat',
            'stream_mode_delimiter': '\\n\\n'
            "api_key": os.environ.get("TOGETHER_API_KEY"),
            "endpoint_url": "https://api.together.xyz/v1/",
            "mode": "chat",
            "stream_mode_delimiter": "\\n\\n",
        },
        prompt_messages=[
            SystemPromptMessage(
                content='You are a helpful AI assistant.',
                content="You are a helpful AI assistant.",
            ),
            UserPromptMessage(
                content='Who are you?'
            )
            UserPromptMessage(content="Who are you?"),
        ],
        model_parameters={
            'temperature': 1.0,
            'top_k': 2,
            'top_p': 0.5,
            "temperature": 1.0,
            "top_k": 2,
            "top_p": 0.5,
        },
        stop=['How'],
        stop=["How"],
        stream=True,
        user="abc-123"
        user="abc-123",
    )

    assert isinstance(response, Generator)
@ -114,28 +106,26 @@ def test_invoke_stream_model_without_delimiter():
    model = OAIAPICompatLargeLanguageModel()

    response = model.invoke(
        model='mistralai/Mixtral-8x7B-Instruct-v0.1',
        model="mistralai/Mixtral-8x7B-Instruct-v0.1",
        credentials={
            'api_key': os.environ.get('TOGETHER_API_KEY'),
            'endpoint_url': 'https://api.together.xyz/v1/',
            'mode': 'chat'
            "api_key": os.environ.get("TOGETHER_API_KEY"),
            "endpoint_url": "https://api.together.xyz/v1/",
            "mode": "chat",
        },
        prompt_messages=[
            SystemPromptMessage(
                content='You are a helpful AI assistant.',
                content="You are a helpful AI assistant.",
            ),
            UserPromptMessage(
                content='Who are you?'
            )
            UserPromptMessage(content="Who are you?"),
        ],
        model_parameters={
            'temperature': 1.0,
            'top_k': 2,
            'top_p': 0.5,
            "temperature": 1.0,
            "top_k": 2,
            "top_p": 0.5,
        },
        stop=['How'],
        stop=["How"],
        stream=True,
        user="abc-123"
        user="abc-123",
    )

    assert isinstance(response, Generator)
@ -151,51 +141,37 @@ def test_invoke_chat_model_with_tools():
    model = OAIAPICompatLargeLanguageModel()

    result = model.invoke(
        model='gpt-3.5-turbo',
        model="gpt-3.5-turbo",
        credentials={
            'api_key': os.environ.get('OPENAI_API_KEY'),
            'endpoint_url': 'https://api.openai.com/v1/',
            'mode': 'chat'
            "api_key": os.environ.get("OPENAI_API_KEY"),
            "endpoint_url": "https://api.openai.com/v1/",
            "mode": "chat",
        },
        prompt_messages=[
            SystemPromptMessage(
                content='You are a helpful AI assistant.',
                content="You are a helpful AI assistant.",
            ),
            UserPromptMessage(
                content="what's the weather today in London?",
            )
            ),
        ],
        tools=[
            PromptMessageTool(
                name='get_weather',
                description='Determine weather in my location',
                name="get_weather",
                description="Determine weather in my location",
                parameters={
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and state e.g. San Francisco, CA"
                        },
                        "unit": {
                            "type": "string",
                            "enum": [
                                "celsius",
                                "fahrenheit"
                            ]
                        }
                        "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                    },
                    "required": [
                        "location"
                    ]
                }
                    "required": ["location"],
                },
            ),
        ],
        model_parameters={
            'temperature': 0.0,
            'max_tokens': 1024
        },
        model_parameters={"temperature": 0.0, "max_tokens": 1024},
        stream=False,
        user="abc-123"
        user="abc-123",
    )

    assert isinstance(result, LLMResult)
@ -207,19 +183,14 @@ def test_get_num_tokens():
    model = OAIAPICompatLargeLanguageModel()

    num_tokens = model.get_num_tokens(
        model='mistralai/Mixtral-8x7B-Instruct-v0.1',
        credentials={
            'api_key': os.environ.get('OPENAI_API_KEY'),
            'endpoint_url': 'https://api.openai.com/v1/'
        },
        model="mistralai/Mixtral-8x7B-Instruct-v0.1",
        credentials={"api_key": os.environ.get("OPENAI_API_KEY"), "endpoint_url": "https://api.openai.com/v1/"},
        prompt_messages=[
            SystemPromptMessage(
                content='You are a helpful AI assistant.',
                content="You are a helpful AI assistant.",
            ),
            UserPromptMessage(
                content='Hello World!'
            )
        ]
            UserPromptMessage(content="Hello World!"),
        ],
    )

    assert isinstance(num_tokens, int)

@ -14,18 +14,12 @@ def test_validate_credentials():
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model="whisper-1",
credentials={
"api_key": "invalid_key",
"endpoint_url": "https://api.openai.com/v1/"
},
credentials={"api_key": "invalid_key", "endpoint_url": "https://api.openai.com/v1/"},
)

model.validate_credentials(
model="whisper-1",
credentials={
"api_key": os.environ.get("OPENAI_API_KEY"),
"endpoint_url": "https://api.openai.com/v1/"
},
credentials={"api_key": os.environ.get("OPENAI_API_KEY"), "endpoint_url": "https://api.openai.com/v1/"},
)


@ -47,13 +41,10 @@ def test_invoke_model():

result = model.invoke(
model="whisper-1",
credentials={
"api_key": os.environ.get("OPENAI_API_KEY"),
"endpoint_url": "https://api.openai.com/v1/"
},
credentials={"api_key": os.environ.get("OPENAI_API_KEY"), "endpoint_url": "https://api.openai.com/v1/"},
file=file,
user="abc-123",
)

assert isinstance(result, str)
assert result == '1, 2, 3, 4, 5, 6, 7, 8, 9, 10'
assert result == "1, 2, 3, 4, 5, 6, 7, 8, 9, 10"
@ -12,27 +12,23 @@ from core.model_runtime.model_providers.openai_api_compatible.text_embedding.tex
Using OpenAI's API as testing endpoint
"""


def test_validate_credentials():
model = OAICompatEmbeddingModel()

with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='text-embedding-ada-002',
credentials={
'api_key': 'invalid_key',
'endpoint_url': 'https://api.openai.com/v1/',
'context_size': 8184

}
model="text-embedding-ada-002",
credentials={"api_key": "invalid_key", "endpoint_url": "https://api.openai.com/v1/", "context_size": 8184},
)

model.validate_credentials(
model='text-embedding-ada-002',
model="text-embedding-ada-002",
credentials={
'api_key': os.environ.get('OPENAI_API_KEY'),
'endpoint_url': 'https://api.openai.com/v1/',
'context_size': 8184
}
"api_key": os.environ.get("OPENAI_API_KEY"),
"endpoint_url": "https://api.openai.com/v1/",
"context_size": 8184,
},
)


@ -40,19 +36,14 @@ def test_invoke_model():
model = OAICompatEmbeddingModel()

result = model.invoke(
model='text-embedding-ada-002',
model="text-embedding-ada-002",
credentials={
'api_key': os.environ.get('OPENAI_API_KEY'),
'endpoint_url': 'https://api.openai.com/v1/',
'context_size': 8184
"api_key": os.environ.get("OPENAI_API_KEY"),
"endpoint_url": "https://api.openai.com/v1/",
"context_size": 8184,
},
texts=[
"hello",
"world",
" ".join(["long_text"] * 100),
" ".join(["another_long_text"] * 100)
],
user="abc-123"
texts=["hello", "world", " ".join(["long_text"] * 100), " ".join(["another_long_text"] * 100)],
user="abc-123",
)

assert isinstance(result, TextEmbeddingResult)
@ -64,16 +55,13 @@ def test_get_num_tokens():
model = OAICompatEmbeddingModel()

num_tokens = model.get_num_tokens(
model='text-embedding-ada-002',
model="text-embedding-ada-002",
credentials={
'api_key': os.environ.get('OPENAI_API_KEY'),
'endpoint_url': 'https://api.openai.com/v1/embeddings',
'context_size': 8184
"api_key": os.environ.get("OPENAI_API_KEY"),
"endpoint_url": "https://api.openai.com/v1/embeddings",
"context_size": 8184,
},
texts=[
"hello",
"world"
]
texts=["hello", "world"],
)

assert num_tokens == 2
assert num_tokens == 2
@ -12,17 +12,17 @@ def test_validate_credentials():

with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='NOT IMPORTANT',
model="NOT IMPORTANT",
credentials={
'server_url': 'ww' + os.environ.get('OPENLLM_SERVER_URL'),
}
"server_url": "ww" + os.environ.get("OPENLLM_SERVER_URL"),
},
)

model.validate_credentials(
model='NOT IMPORTANT',
model="NOT IMPORTANT",
credentials={
'server_url': os.environ.get('OPENLLM_SERVER_URL'),
}
"server_url": os.environ.get("OPENLLM_SERVER_URL"),
},
)


@ -30,33 +30,28 @@ def test_invoke_model():
model = OpenLLMTextEmbeddingModel()

result = model.invoke(
model='NOT IMPORTANT',
model="NOT IMPORTANT",
credentials={
'server_url': os.environ.get('OPENLLM_SERVER_URL'),
"server_url": os.environ.get("OPENLLM_SERVER_URL"),
},
texts=[
"hello",
"world"
],
user="abc-123"
texts=["hello", "world"],
user="abc-123",
)

assert isinstance(result, TextEmbeddingResult)
assert len(result.embeddings) == 2
assert result.usage.total_tokens > 0


def test_get_num_tokens():
model = OpenLLMTextEmbeddingModel()

num_tokens = model.get_num_tokens(
model='NOT IMPORTANT',
model="NOT IMPORTANT",
credentials={
'server_url': os.environ.get('OPENLLM_SERVER_URL'),
"server_url": os.environ.get("OPENLLM_SERVER_URL"),
},
texts=[
"hello",
"world"
]
texts=["hello", "world"],
)

assert num_tokens == 2
@ -14,67 +14,61 @@ def test_validate_credentials_for_chat_model():

with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='NOT IMPORTANT',
model="NOT IMPORTANT",
credentials={
'server_url': 'invalid_key',
}
"server_url": "invalid_key",
},
)

model.validate_credentials(
model='NOT IMPORTANT',
model="NOT IMPORTANT",
credentials={
'server_url': os.environ.get('OPENLLM_SERVER_URL'),
}
"server_url": os.environ.get("OPENLLM_SERVER_URL"),
},
)


def test_invoke_model():
model = OpenLLMLargeLanguageModel()

response = model.invoke(
model='NOT IMPORTANT',
model="NOT IMPORTANT",
credentials={
'server_url': os.environ.get('OPENLLM_SERVER_URL'),
"server_url": os.environ.get("OPENLLM_SERVER_URL"),
},
prompt_messages=[
UserPromptMessage(
content='Hello World!'
)
],
prompt_messages=[UserPromptMessage(content="Hello World!")],
model_parameters={
'temperature': 0.7,
'top_p': 1.0,
'top_k': 1,
"temperature": 0.7,
"top_p": 1.0,
"top_k": 1,
},
stop=['you'],
stop=["you"],
user="abc-123",
stream=False
stream=False,
)

assert isinstance(response, LLMResult)
assert len(response.message.content) > 0
assert response.usage.total_tokens > 0


def test_invoke_stream_model():
model = OpenLLMLargeLanguageModel()

response = model.invoke(
model='NOT IMPORTANT',
model="NOT IMPORTANT",
credentials={
'server_url': os.environ.get('OPENLLM_SERVER_URL'),
"server_url": os.environ.get("OPENLLM_SERVER_URL"),
},
prompt_messages=[
UserPromptMessage(
content='Hello World!'
)
],
prompt_messages=[UserPromptMessage(content="Hello World!")],
model_parameters={
'temperature': 0.7,
'top_p': 1.0,
'top_k': 1,
"temperature": 0.7,
"top_p": 1.0,
"top_k": 1,
},
stop=['you'],
stop=["you"],
stream=True,
user="abc-123"
user="abc-123",
)

assert isinstance(response, Generator)
@ -84,21 +78,18 @@ def test_invoke_stream_model():
assert isinstance(chunk.delta.message, AssistantPromptMessage)
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True


def test_get_num_tokens():
model = OpenLLMLargeLanguageModel()

response = model.get_num_tokens(
model='NOT IMPORTANT',
model="NOT IMPORTANT",
credentials={
'server_url': os.environ.get('OPENLLM_SERVER_URL'),
"server_url": os.environ.get("OPENLLM_SERVER_URL"),
},
prompt_messages=[
UserPromptMessage(
content='Hello World!'
)
],
tools=[]
prompt_messages=[UserPromptMessage(content="Hello World!")],
tools=[],
)

assert isinstance(response, int)
assert response == 3
assert response == 3
@ -19,19 +19,12 @@ def test_validate_credentials():

with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='mistralai/mixtral-8x7b-instruct',
credentials={
'api_key': 'invalid_key',
'mode': 'chat'
}
model="mistralai/mixtral-8x7b-instruct", credentials={"api_key": "invalid_key", "mode": "chat"}
)

model.validate_credentials(
model='mistralai/mixtral-8x7b-instruct',
credentials={
'api_key': os.environ.get('TOGETHER_API_KEY'),
'mode': 'chat'
}
model="mistralai/mixtral-8x7b-instruct",
credentials={"api_key": os.environ.get("TOGETHER_API_KEY"), "mode": "chat"},
)


@ -39,27 +32,22 @@ def test_invoke_model():
model = OpenRouterLargeLanguageModel()

response = model.invoke(
model='mistralai/mixtral-8x7b-instruct',
credentials={
'api_key': os.environ.get('TOGETHER_API_KEY'),
'mode': 'completion'
},
model="mistralai/mixtral-8x7b-instruct",
credentials={"api_key": os.environ.get("TOGETHER_API_KEY"), "mode": "completion"},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Who are you?'
)
UserPromptMessage(content="Who are you?"),
],
model_parameters={
'temperature': 1.0,
'top_k': 2,
'top_p': 0.5,
"temperature": 1.0,
"top_k": 2,
"top_p": 0.5,
},
stop=['How'],
stop=["How"],
stream=False,
user="abc-123"
user="abc-123",
)

assert isinstance(response, LLMResult)
@ -70,27 +58,22 @@ def test_invoke_stream_model():
model = OpenRouterLargeLanguageModel()

response = model.invoke(
model='mistralai/mixtral-8x7b-instruct',
credentials={
'api_key': os.environ.get('TOGETHER_API_KEY'),
'mode': 'chat'
},
model="mistralai/mixtral-8x7b-instruct",
credentials={"api_key": os.environ.get("TOGETHER_API_KEY"), "mode": "chat"},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Who are you?'
)
UserPromptMessage(content="Who are you?"),
],
model_parameters={
'temperature': 1.0,
'top_k': 2,
'top_p': 0.5,
"temperature": 1.0,
"top_k": 2,
"top_p": 0.5,
},
stop=['How'],
stop=["How"],
stream=True,
user="abc-123"
user="abc-123",
)

assert isinstance(response, Generator)
@ -105,18 +88,16 @@ def test_get_num_tokens():
model = OpenRouterLargeLanguageModel()

num_tokens = model.get_num_tokens(
model='mistralai/mixtral-8x7b-instruct',
model="mistralai/mixtral-8x7b-instruct",
credentials={
'api_key': os.environ.get('TOGETHER_API_KEY'),
"api_key": os.environ.get("TOGETHER_API_KEY"),
},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Hello World!'
)
]
UserPromptMessage(content="Hello World!"),
],
)

assert isinstance(num_tokens, int)
@ -14,19 +14,19 @@ def test_validate_credentials():

with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='meta/llama-2-13b-chat',
model="meta/llama-2-13b-chat",
credentials={
'replicate_api_token': 'invalid_key',
'model_version': 'f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d'
}
"replicate_api_token": "invalid_key",
"model_version": "f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d",
},
)

model.validate_credentials(
model='meta/llama-2-13b-chat',
model="meta/llama-2-13b-chat",
credentials={
'replicate_api_token': os.environ.get('REPLICATE_API_KEY'),
'model_version': 'f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d'
}
"replicate_api_token": os.environ.get("REPLICATE_API_KEY"),
"model_version": "f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d",
},
)


@ -34,27 +34,25 @@ def test_invoke_model():
model = ReplicateLargeLanguageModel()

response = model.invoke(
model='meta/llama-2-13b-chat',
model="meta/llama-2-13b-chat",
credentials={
'replicate_api_token': os.environ.get('REPLICATE_API_KEY'),
'model_version': 'f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d'
"replicate_api_token": os.environ.get("REPLICATE_API_KEY"),
"model_version": "f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d",
},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Who are you?'
)
UserPromptMessage(content="Who are you?"),
],
model_parameters={
'temperature': 1.0,
'top_k': 2,
'top_p': 0.5,
"temperature": 1.0,
"top_k": 2,
"top_p": 0.5,
},
stop=['How'],
stop=["How"],
stream=False,
user="abc-123"
user="abc-123",
)

assert isinstance(response, LLMResult)
@ -65,27 +63,25 @@ def test_invoke_stream_model():
model = ReplicateLargeLanguageModel()

response = model.invoke(
model='mistralai/mixtral-8x7b-instruct-v0.1',
model="mistralai/mixtral-8x7b-instruct-v0.1",
credentials={
'replicate_api_token': os.environ.get('REPLICATE_API_KEY'),
'model_version': '2b56576fcfbe32fa0526897d8385dd3fb3d36ba6fd0dbe033c72886b81ade93e'
"replicate_api_token": os.environ.get("REPLICATE_API_KEY"),
"model_version": "2b56576fcfbe32fa0526897d8385dd3fb3d36ba6fd0dbe033c72886b81ade93e",
},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Who are you?'
)
UserPromptMessage(content="Who are you?"),
],
model_parameters={
'temperature': 1.0,
'top_k': 2,
'top_p': 0.5,
"temperature": 1.0,
"top_k": 2,
"top_p": 0.5,
},
stop=['How'],
stop=["How"],
stream=True,
user="abc-123"
user="abc-123",
)

assert isinstance(response, Generator)
@ -100,19 +96,17 @@ def test_get_num_tokens():
model = ReplicateLargeLanguageModel()

num_tokens = model.get_num_tokens(
model='',
model="",
credentials={
'replicate_api_token': os.environ.get('REPLICATE_API_KEY'),
'model_version': '2b56576fcfbe32fa0526897d8385dd3fb3d36ba6fd0dbe033c72886b81ade93e'
"replicate_api_token": os.environ.get("REPLICATE_API_KEY"),
"model_version": "2b56576fcfbe32fa0526897d8385dd3fb3d36ba6fd0dbe033c72886b81ade93e",
},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Hello World!'
)
]
UserPromptMessage(content="Hello World!"),
],
)

assert num_tokens == 14
@ -12,19 +12,19 @@ def test_validate_credentials_one():

with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='replicate/all-mpnet-base-v2',
model="replicate/all-mpnet-base-v2",
credentials={
'replicate_api_token': 'invalid_key',
'model_version': 'b6b7585c9640cd7a9572c6e129c9549d79c9c31f0d3fdce7baac7c67ca38f305'
}
"replicate_api_token": "invalid_key",
"model_version": "b6b7585c9640cd7a9572c6e129c9549d79c9c31f0d3fdce7baac7c67ca38f305",
},
)

model.validate_credentials(
model='replicate/all-mpnet-base-v2',
model="replicate/all-mpnet-base-v2",
credentials={
'replicate_api_token': os.environ.get('REPLICATE_API_KEY'),
'model_version': 'b6b7585c9640cd7a9572c6e129c9549d79c9c31f0d3fdce7baac7c67ca38f305'
}
"replicate_api_token": os.environ.get("REPLICATE_API_KEY"),
"model_version": "b6b7585c9640cd7a9572c6e129c9549d79c9c31f0d3fdce7baac7c67ca38f305",
},
)


@ -33,19 +33,19 @@ def test_validate_credentials_two():

with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='nateraw/bge-large-en-v1.5',
model="nateraw/bge-large-en-v1.5",
credentials={
'replicate_api_token': 'invalid_key',
'model_version': '9cf9f015a9cb9c61d1a2610659cdac4a4ca222f2d3707a68517b18c198a9add1'
}
"replicate_api_token": "invalid_key",
"model_version": "9cf9f015a9cb9c61d1a2610659cdac4a4ca222f2d3707a68517b18c198a9add1",
},
)

model.validate_credentials(
model='nateraw/bge-large-en-v1.5',
model="nateraw/bge-large-en-v1.5",
credentials={
'replicate_api_token': os.environ.get('REPLICATE_API_KEY'),
'model_version': '9cf9f015a9cb9c61d1a2610659cdac4a4ca222f2d3707a68517b18c198a9add1'
}
"replicate_api_token": os.environ.get("REPLICATE_API_KEY"),
"model_version": "9cf9f015a9cb9c61d1a2610659cdac4a4ca222f2d3707a68517b18c198a9add1",
},
)


@ -53,16 +53,13 @@ def test_invoke_model_one():
model = ReplicateEmbeddingModel()

result = model.invoke(
model='nateraw/bge-large-en-v1.5',
model="nateraw/bge-large-en-v1.5",
credentials={
'replicate_api_token': os.environ.get('REPLICATE_API_KEY'),
'model_version': '9cf9f015a9cb9c61d1a2610659cdac4a4ca222f2d3707a68517b18c198a9add1'
"replicate_api_token": os.environ.get("REPLICATE_API_KEY"),
"model_version": "9cf9f015a9cb9c61d1a2610659cdac4a4ca222f2d3707a68517b18c198a9add1",
},
texts=[
"hello",
"world"
],
user="abc-123"
texts=["hello", "world"],
user="abc-123",
)

assert isinstance(result, TextEmbeddingResult)
@ -74,16 +71,13 @@ def test_invoke_model_two():
model = ReplicateEmbeddingModel()

result = model.invoke(
model='andreasjansson/clip-features',
model="andreasjansson/clip-features",
credentials={
'replicate_api_token': os.environ.get('REPLICATE_API_KEY'),
'model_version': '75b33f253f7714a281ad3e9b28f63e3232d583716ef6718f2e46641077ea040a'
"replicate_api_token": os.environ.get("REPLICATE_API_KEY"),
"model_version": "75b33f253f7714a281ad3e9b28f63e3232d583716ef6718f2e46641077ea040a",
},
texts=[
"hello",
"world"
],
user="abc-123"
texts=["hello", "world"],
user="abc-123",
)

assert isinstance(result, TextEmbeddingResult)
@ -95,16 +89,13 @@ def test_invoke_model_three():
model = ReplicateEmbeddingModel()

result = model.invoke(
model='replicate/all-mpnet-base-v2',
model="replicate/all-mpnet-base-v2",
credentials={
'replicate_api_token': os.environ.get('REPLICATE_API_KEY'),
'model_version': 'b6b7585c9640cd7a9572c6e129c9549d79c9c31f0d3fdce7baac7c67ca38f305'
"replicate_api_token": os.environ.get("REPLICATE_API_KEY"),
"model_version": "b6b7585c9640cd7a9572c6e129c9549d79c9c31f0d3fdce7baac7c67ca38f305",
},
texts=[
"hello",
"world"
],
user="abc-123"
texts=["hello", "world"],
user="abc-123",
)

assert isinstance(result, TextEmbeddingResult)
@ -116,16 +107,13 @@ def test_invoke_model_four():
model = ReplicateEmbeddingModel()

result = model.invoke(
model='nateraw/jina-embeddings-v2-base-en',
model="nateraw/jina-embeddings-v2-base-en",
credentials={
'replicate_api_token': os.environ.get('REPLICATE_API_KEY'),
'model_version': 'f8367a1c072ba2bc28af549d1faeacfe9b88b3f0e475add7a75091dac507f79e'
"replicate_api_token": os.environ.get("REPLICATE_API_KEY"),
"model_version": "f8367a1c072ba2bc28af549d1faeacfe9b88b3f0e475add7a75091dac507f79e",
},
texts=[
"hello",
"world"
],
user="abc-123"
texts=["hello", "world"],
user="abc-123",
)

assert isinstance(result, TextEmbeddingResult)
@ -137,15 +125,12 @@ def test_get_num_tokens():
model = ReplicateEmbeddingModel()

num_tokens = model.get_num_tokens(
model='nateraw/jina-embeddings-v2-base-en',
model="nateraw/jina-embeddings-v2-base-en",
credentials={
'replicate_api_token': os.environ.get('REPLICATE_API_KEY'),
'model_version': 'f8367a1c072ba2bc28af549d1faeacfe9b88b3f0e475add7a75091dac507f79e'
"replicate_api_token": os.environ.get("REPLICATE_API_KEY"),
"model_version": "f8367a1c072ba2bc28af549d1faeacfe9b88b3f0e475add7a75091dac507f79e",
},
texts=[
"hello",
"world"
]
texts=["hello", "world"],
)

assert num_tokens == 2
@ -10,10 +10,6 @@ def test_validate_provider_credentials():
provider = SageMakerProvider()

with pytest.raises(CredentialsValidateFailedError):
provider.validate_provider_credentials(
credentials={}
)
provider.validate_provider_credentials(credentials={})

provider.validate_provider_credentials(
credentials={}
)
provider.validate_provider_credentials(credentials={})
@ -12,11 +12,11 @@ def test_validate_credentials():

with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='bge-m3-rerank-v2',
model="bge-m3-rerank-v2",
credentials={
"aws_region": os.getenv("AWS_REGION"),
"aws_access_key": os.getenv("AWS_ACCESS_KEY"),
"aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY")
"aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"),
},
query="What is the capital of the United States?",
docs=[
@ -25,7 +25,7 @@ def test_validate_credentials():
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
"are a political division controlled by the United States. Its capital is Saipan.",
],
score_threshold=0.8
score_threshold=0.8,
)


@ -33,11 +33,11 @@ def test_invoke_model():
model = SageMakerRerankModel()

result = model.invoke(
model='bge-m3-rerank-v2',
model="bge-m3-rerank-v2",
credentials={
"aws_region": os.getenv("AWS_REGION"),
"aws_access_key": os.getenv("AWS_ACCESS_KEY"),
"aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY")
"aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"),
},
query="What is the capital of the United States?",
docs=[
@ -46,7 +46,7 @@ def test_invoke_model():
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
"are a political division controlled by the United States. Its capital is Saipan.",
],
score_threshold=0.8
score_threshold=0.8,
)

assert isinstance(result, RerankResult)
@ -11,45 +11,23 @@ def test_validate_credentials():
model = SageMakerEmbeddingModel()

with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='bge-m3',
credentials={
}
)
model.validate_credentials(model="bge-m3", credentials={})

model.validate_credentials(
model='bge-m3-embedding',
credentials={
}
)
model.validate_credentials(model="bge-m3-embedding", credentials={})


def test_invoke_model():
model = SageMakerEmbeddingModel()

result = model.invoke(
model='bge-m3-embedding',
credentials={
},
texts=[
"hello",
"world"
],
user="abc-123"
)
result = model.invoke(model="bge-m3-embedding", credentials={}, texts=["hello", "world"], user="abc-123")

assert isinstance(result, TextEmbeddingResult)
assert len(result.embeddings) == 2


def test_get_num_tokens():
model = SageMakerEmbeddingModel()

num_tokens = model.get_num_tokens(
model='bge-m3-embedding',
credentials={
},
texts=[
]
)
num_tokens = model.get_num_tokens(model="bge-m3-embedding", credentials={}, texts=[])

assert num_tokens == 0
@ -13,41 +13,22 @@ def test_validate_credentials():
model = SiliconflowLargeLanguageModel()

with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='deepseek-ai/DeepSeek-V2-Chat',
credentials={
'api_key': 'invalid_key'
}
)
model.validate_credentials(model="deepseek-ai/DeepSeek-V2-Chat", credentials={"api_key": "invalid_key"})

model.validate_credentials(
model='deepseek-ai/DeepSeek-V2-Chat',
credentials={
'api_key': os.environ.get('API_KEY')
}
)
model.validate_credentials(model="deepseek-ai/DeepSeek-V2-Chat", credentials={"api_key": os.environ.get("API_KEY")})


def test_invoke_model():
model = SiliconflowLargeLanguageModel()

response = model.invoke(
model='deepseek-ai/DeepSeek-V2-Chat',
credentials={
'api_key': os.environ.get('API_KEY')
},
prompt_messages=[
UserPromptMessage(
content='Who are you?'
)
],
model_parameters={
'temperature': 0.5,
'max_tokens': 10
},
stop=['How'],
model="deepseek-ai/DeepSeek-V2-Chat",
credentials={"api_key": os.environ.get("API_KEY")},
prompt_messages=[UserPromptMessage(content="Who are you?")],
model_parameters={"temperature": 0.5, "max_tokens": 10},
stop=["How"],
stream=False,
user="abc-123"
user="abc-123",
)

assert isinstance(response, LLMResult)
@ -58,22 +39,12 @@ def test_invoke_stream_model():
model = SiliconflowLargeLanguageModel()

response = model.invoke(
model='deepseek-ai/DeepSeek-V2-Chat',
credentials={
'api_key': os.environ.get('API_KEY')
},
prompt_messages=[
UserPromptMessage(
content='Hello World!'
)
],
model_parameters={
'temperature': 0.5,
'max_tokens': 100,
'seed': 1234
},
model="deepseek-ai/DeepSeek-V2-Chat",
credentials={"api_key": os.environ.get("API_KEY")},
prompt_messages=[UserPromptMessage(content="Hello World!")],
model_parameters={"temperature": 0.5, "max_tokens": 100, "seed": 1234},
stream=True,
user="abc-123"
user="abc-123",
)

assert isinstance(response, Generator)
@ -89,18 +60,14 @@ def test_get_num_tokens():
model = SiliconflowLargeLanguageModel()

num_tokens = model.get_num_tokens(
model='deepseek-ai/DeepSeek-V2-Chat',
credentials={
'api_key': os.environ.get('API_KEY')
},
model="deepseek-ai/DeepSeek-V2-Chat",
credentials={"api_key": os.environ.get("API_KEY")},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Hello World!'
)
]
UserPromptMessage(content="Hello World!"),
],
)

assert num_tokens == 12
@ -10,12 +10,6 @@ def test_validate_provider_credentials():
provider = SiliconflowProvider()

with pytest.raises(CredentialsValidateFailedError):
provider.validate_provider_credentials(
credentials={}
)
provider.validate_provider_credentials(credentials={})

provider.validate_provider_credentials(
credentials={
'api_key': os.environ.get('API_KEY')
}
)
provider.validate_provider_credentials(credentials={"api_key": os.environ.get("API_KEY")})
@ -13,9 +13,7 @@ def test_validate_credentials():
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model="BAAI/bge-reranker-v2-m3",
credentials={
"api_key": "invalid_key"
},
credentials={"api_key": "invalid_key"},
)

model.validate_credentials(
@ -30,17 +28,17 @@ def test_invoke_model():
model = SiliconflowRerankModel()

result = model.invoke(
model='BAAI/bge-reranker-v2-m3',
model="BAAI/bge-reranker-v2-m3",
credentials={
"api_key": os.environ.get("API_KEY"),
},
query="Who is Kasumi?",
docs=[
"Kasumi is a girl's name of Japanese origin meaning \"mist\".",
'Kasumi is a girl\'s name of Japanese origin meaning "mist".',
"Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ",
"and she leads a team named PopiParty."
"and she leads a team named PopiParty.",
],
score_threshold=0.8
score_threshold=0.8,
)

assert isinstance(result, RerankResult)
@ -12,16 +12,12 @@ def test_validate_credentials():
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model="iic/SenseVoiceSmall",
credentials={
"api_key": "invalid_key"
},
credentials={"api_key": "invalid_key"},
)

model.validate_credentials(
model="iic/SenseVoiceSmall",
credentials={
"api_key": os.environ.get("API_KEY")
},
credentials={"api_key": os.environ.get("API_KEY")},
)


@ -42,12 +38,8 @@ def test_invoke_model():
file = audio_file

result = model.invoke(
model="iic/SenseVoiceSmall",
credentials={
"api_key": os.environ.get("API_KEY")
},
file=file
model="iic/SenseVoiceSmall", credentials={"api_key": os.environ.get("API_KEY")}, file=file
)

assert isinstance(result, str)
assert result == '1,2,3,4,5,6,7,8,9,10.'
assert result == "1,2,3,4,5,6,7,8,9,10."
@ -15,9 +15,7 @@ def test_validate_credentials():
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model="BAAI/bge-large-zh-v1.5",
credentials={
"api_key": "invalid_key"
},
credentials={"api_key": "invalid_key"},
)

model.validate_credentials(
@ -13,20 +13,15 @@ def test_validate_credentials():
model = SparkLargeLanguageModel()

with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='spark-1.5',
credentials={
'app_id': 'invalid_key'
}
)
model.validate_credentials(model="spark-1.5", credentials={"app_id": "invalid_key"})

model.validate_credentials(
model='spark-1.5',
model="spark-1.5",
credentials={
'app_id': os.environ.get('SPARK_APP_ID'),
'api_secret': os.environ.get('SPARK_API_SECRET'),
'api_key': os.environ.get('SPARK_API_KEY')
}
"app_id": os.environ.get("SPARK_APP_ID"),
"api_secret": os.environ.get("SPARK_API_SECRET"),
"api_key": os.environ.get("SPARK_API_KEY"),
},
)


@ -34,24 +29,17 @@ def test_invoke_model():
model = SparkLargeLanguageModel()

response = model.invoke(
model='spark-1.5',
model="spark-1.5",
credentials={
'app_id': os.environ.get('SPARK_APP_ID'),
'api_secret': os.environ.get('SPARK_API_SECRET'),
'api_key': os.environ.get('SPARK_API_KEY')
"app_id": os.environ.get("SPARK_APP_ID"),
"api_secret": os.environ.get("SPARK_API_SECRET"),
"api_key": os.environ.get("SPARK_API_KEY"),
},
prompt_messages=[
UserPromptMessage(
content='Who are you?'
)
],
model_parameters={
'temperature': 0.5,
'max_tokens': 10
},
stop=['How'],
prompt_messages=[UserPromptMessage(content="Who are you?")],
model_parameters={"temperature": 0.5, "max_tokens": 10},
stop=["How"],
stream=False,
user="abc-123"
user="abc-123",
)

assert isinstance(response, LLMResult)
@ -62,23 +50,16 @@ def test_invoke_stream_model():
model = SparkLargeLanguageModel()

response = model.invoke(
model='spark-1.5',
model="spark-1.5",
credentials={
'app_id': os.environ.get('SPARK_APP_ID'),
'api_secret': os.environ.get('SPARK_API_SECRET'),
'api_key': os.environ.get('SPARK_API_KEY')
},
prompt_messages=[
UserPromptMessage(
content='Hello World!'
)
],
model_parameters={
'temperature': 0.5,
'max_tokens': 100
"app_id": os.environ.get("SPARK_APP_ID"),
"api_secret": os.environ.get("SPARK_API_SECRET"),
"api_key": os.environ.get("SPARK_API_KEY"),
},
prompt_messages=[UserPromptMessage(content="Hello World!")],
model_parameters={"temperature": 0.5, "max_tokens": 100},
stream=True,
user="abc-123"
user="abc-123",
)

assert isinstance(response, Generator)
@ -94,20 +75,18 @@ def test_get_num_tokens():
model = SparkLargeLanguageModel()

num_tokens = model.get_num_tokens(
model='spark-1.5',
model="spark-1.5",
credentials={
'app_id': os.environ.get('SPARK_APP_ID'),
'api_secret': os.environ.get('SPARK_API_SECRET'),
'api_key': os.environ.get('SPARK_API_KEY')
"app_id": os.environ.get("SPARK_APP_ID"),
"api_secret": os.environ.get("SPARK_API_SECRET"),
"api_key": os.environ.get("SPARK_API_KEY"),
},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Hello World!'
)
]
UserPromptMessage(content="Hello World!"),
],
)

assert num_tokens == 14
@ -10,14 +10,12 @@ def test_validate_provider_credentials():
provider = SparkProvider()

with pytest.raises(CredentialsValidateFailedError):
provider.validate_provider_credentials(
credentials={}
)
provider.validate_provider_credentials(credentials={})

provider.validate_provider_credentials(
credentials={
'app_id': os.environ.get('SPARK_APP_ID'),
'api_secret': os.environ.get('SPARK_API_SECRET'),
'api_key': os.environ.get('SPARK_API_KEY')
"app_id": os.environ.get("SPARK_APP_ID"),
"api_secret": os.environ.get("SPARK_API_SECRET"),
"api_key": os.environ.get("SPARK_API_KEY"),
}
)
@ -21,40 +21,22 @@ def test_validate_credentials():
model = StepfunLargeLanguageModel()

with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='step-1-8k',
credentials={
'api_key': 'invalid_key'
}
)
model.validate_credentials(model="step-1-8k", credentials={"api_key": "invalid_key"})

model.validate_credentials(model="step-1-8k", credentials={"api_key": os.environ.get("STEPFUN_API_KEY")})

model.validate_credentials(
model='step-1-8k',
credentials={
'api_key': os.environ.get('STEPFUN_API_KEY')
}
)

def test_invoke_model():
model = StepfunLargeLanguageModel()

response = model.invoke(
model='step-1-8k',
credentials={
'api_key': os.environ.get('STEPFUN_API_KEY')
},
prompt_messages=[
UserPromptMessage(
content='Hello World!'
)
],
model_parameters={
'temperature': 0.9,
'top_p': 0.7
},
stop=['Hi'],
model="step-1-8k",
credentials={"api_key": os.environ.get("STEPFUN_API_KEY")},
prompt_messages=[UserPromptMessage(content="Hello World!")],
model_parameters={"temperature": 0.9, "top_p": 0.7},
stop=["Hi"],
stream=False,
user="abc-123"
user="abc-123",
)

assert isinstance(response, LLMResult)
@ -65,24 +47,17 @@ def test_invoke_stream_model():
model = StepfunLargeLanguageModel()

response = model.invoke(
model='step-1-8k',
credentials={
'api_key': os.environ.get('STEPFUN_API_KEY')
},
model="step-1-8k",
credentials={"api_key": os.environ.get("STEPFUN_API_KEY")},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Hello World!'
)
UserPromptMessage(content="Hello World!"),
],
model_parameters={
'temperature': 0.9,
'top_p': 0.7
},
model_parameters={"temperature": 0.9, "top_p": 0.7},
stream=True,
user="abc-123"
user="abc-123",
)

assert isinstance(response, Generator)
@ -98,10 +73,7 @@ def test_get_customizable_model_schema():
model = StepfunLargeLanguageModel()

schema = model.get_customizable_model_schema(
model='step-1-8k',
credentials={
'api_key': os.environ.get('STEPFUN_API_KEY')
}
model="step-1-8k", credentials={"api_key": os.environ.get("STEPFUN_API_KEY")}
)
assert isinstance(schema, AIModelEntity)

@ -110,67 +82,44 @@ def test_invoke_chat_model_with_tools():
model = StepfunLargeLanguageModel()

result = model.invoke(
model='step-1-8k',
credentials={
'api_key': os.environ.get('STEPFUN_API_KEY')
},
model="step-1-8k",
credentials={"api_key": os.environ.get("STEPFUN_API_KEY")},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content="what's the weather today in Shanghai?",
)
),
],
model_parameters={
'temperature': 0.9,
'max_tokens': 100
},
model_parameters={"temperature": 0.9, "max_tokens": 100},
tools=[
PromptMessageTool(
name='get_weather',
description='Determine weather in my location',
name="get_weather",
description="Determine weather in my location",
parameters={
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state e.g. San Francisco, CA"
},
"unit": {
"type": "string",
"enum": [
"c",
"f"
]
}
"location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
"unit": {"type": "string", "enum": ["c", "f"]},
},
"required": [
"location"
]
}
"required": ["location"],
},
),
PromptMessageTool(
name='get_stock_price',
description='Get the current stock price',
name="get_stock_price",
description="Get the current stock price",
parameters={
"type": "object",
"properties": {
"symbol": {
"type": "string",
"description": "The stock symbol"
}
},
"required": [
"symbol"
]
}
)
"properties": {"symbol": {"type": "string", "description": "The stock symbol"}},
"required": ["symbol"],
},
),
],
stream=False,
user="abc-123"
user="abc-123",
)

assert isinstance(result, LLMResult)
assert isinstance(result.message, AssistantPromptMessage)
assert len(result.message.tool_calls) > 0
assert len(result.message.tool_calls) > 0
@ -24,13 +24,8 @@ def test_get_models():
providers = factory.get_models(
model_type=ModelType.LLM,
provider_configs=[
ProviderConfig(
provider='openai',
credentials={
'openai_api_key': os.environ.get('OPENAI_API_KEY')
}
)
]
ProviderConfig(provider="openai", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")})
],
)

logger.debug(providers)
@ -44,29 +39,21 @@ def test_get_models():
assert provider_model.model_type == ModelType.LLM

providers = factory.get_models(
provider='openai',
provider="openai",
provider_configs=[
ProviderConfig(
provider='openai',
credentials={
'openai_api_key': os.environ.get('OPENAI_API_KEY')
}
)
]
ProviderConfig(provider="openai", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")})
],
)

assert len(providers) == 1
assert isinstance(providers[0], SimpleProviderEntity)
assert providers[0].provider == 'openai'
assert providers[0].provider == "openai"


def test_provider_credentials_validate():
factory = ModelProviderFactory()
factory.provider_credentials_validate(
provider='openai',
credentials={
'openai_api_key': os.environ.get('OPENAI_API_KEY')
}
provider="openai", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}
)


@ -79,4 +66,4 @@ def test__get_model_provider_map():
logger.debug(model_provider.provider_instance)

assert len(model_providers) >= 1
assert isinstance(model_providers['openai'], ModelProviderExtension)
assert isinstance(model_providers["openai"], ModelProviderExtension)
@ -19,76 +19,61 @@ def test_validate_credentials():

with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='mistralai/Mixtral-8x7B-Instruct-v0.1',
credentials={
'api_key': 'invalid_key',
'mode': 'chat'
}
model="mistralai/Mixtral-8x7B-Instruct-v0.1", credentials={"api_key": "invalid_key", "mode": "chat"}
)

model.validate_credentials(
model='mistralai/Mixtral-8x7B-Instruct-v0.1',
credentials={
'api_key': os.environ.get('TOGETHER_API_KEY'),
'mode': 'chat'
}
model="mistralai/Mixtral-8x7B-Instruct-v0.1",
credentials={"api_key": os.environ.get("TOGETHER_API_KEY"), "mode": "chat"},
)


def test_invoke_model():
model = TogetherAILargeLanguageModel()

response = model.invoke(
model='mistralai/Mixtral-8x7B-Instruct-v0.1',
credentials={
'api_key': os.environ.get('TOGETHER_API_KEY'),
'mode': 'completion'
},
model="mistralai/Mixtral-8x7B-Instruct-v0.1",
credentials={"api_key": os.environ.get("TOGETHER_API_KEY"), "mode": "completion"},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Who are you?'
)
UserPromptMessage(content="Who are you?"),
],
model_parameters={
'temperature': 1.0,
'top_k': 2,
'top_p': 0.5,
"temperature": 1.0,
"top_k": 2,
"top_p": 0.5,
},
stop=['How'],
stop=["How"],
stream=False,
user="abc-123"
user="abc-123",
)

assert isinstance(response, LLMResult)
assert len(response.message.content) > 0


def test_invoke_stream_model():
model = TogetherAILargeLanguageModel()

response = model.invoke(
model='mistralai/Mixtral-8x7B-Instruct-v0.1',
credentials={
'api_key': os.environ.get('TOGETHER_API_KEY'),
'mode': 'chat'
},
model="mistralai/Mixtral-8x7B-Instruct-v0.1",
credentials={"api_key": os.environ.get("TOGETHER_API_KEY"), "mode": "chat"},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Who are you?'
)
UserPromptMessage(content="Who are you?"),
],
model_parameters={
'temperature': 1.0,
'top_k': 2,
'top_p': 0.5,
"temperature": 1.0,
"top_k": 2,
"top_p": 0.5,
},
stop=['How'],
stop=["How"],
stream=True,
user="abc-123"
user="abc-123",
)

assert isinstance(response, Generator)
@ -98,22 +83,21 @@ def test_invoke_stream_model():
assert isinstance(chunk.delta, LLMResultChunkDelta)
assert isinstance(chunk.delta.message, AssistantPromptMessage)


def test_get_num_tokens():
model = TogetherAILargeLanguageModel()

num_tokens = model.get_num_tokens(
model='mistralai/Mixtral-8x7B-Instruct-v0.1',
model="mistralai/Mixtral-8x7B-Instruct-v0.1",
credentials={
'api_key': os.environ.get('TOGETHER_API_KEY'),
"api_key": os.environ.get("TOGETHER_API_KEY"),
},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Hello World!'
)
]
UserPromptMessage(content="Hello World!"),
],
)

assert isinstance(num_tokens, int)
@ -13,18 +13,10 @@ def test_validate_credentials():
model = TongyiLargeLanguageModel()

with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='qwen-turbo',
credentials={
'dashscope_api_key': 'invalid_key'
}
)
model.validate_credentials(model="qwen-turbo", credentials={"dashscope_api_key": "invalid_key"})

model.validate_credentials(
model='qwen-turbo',
credentials={
'dashscope_api_key': os.environ.get('TONGYI_DASHSCOPE_API_KEY')
}
model="qwen-turbo", credentials={"dashscope_api_key": os.environ.get("TONGYI_DASHSCOPE_API_KEY")}
)


@ -32,22 +24,13 @@ def test_invoke_model():
model = TongyiLargeLanguageModel()

response = model.invoke(
model='qwen-turbo',
credentials={
'dashscope_api_key': os.environ.get('TONGYI_DASHSCOPE_API_KEY')
},
prompt_messages=[
UserPromptMessage(
content='Who are you?'
)
],
model_parameters={
'temperature': 0.5,
'max_tokens': 10
},
stop=['How'],
model="qwen-turbo",
credentials={"dashscope_api_key": os.environ.get("TONGYI_DASHSCOPE_API_KEY")},
prompt_messages=[UserPromptMessage(content="Who are you?")],
model_parameters={"temperature": 0.5, "max_tokens": 10},
stop=["How"],
stream=False,
user="abc-123"
user="abc-123",
)

assert isinstance(response, LLMResult)
@ -58,22 +41,12 @@ def test_invoke_stream_model():
model = TongyiLargeLanguageModel()

response = model.invoke(
model='qwen-turbo',
credentials={
'dashscope_api_key': os.environ.get('TONGYI_DASHSCOPE_API_KEY')
},
prompt_messages=[
UserPromptMessage(
content='Hello World!'
)
],
model_parameters={
'temperature': 0.5,
'max_tokens': 100,
'seed': 1234
},
model="qwen-turbo",
credentials={"dashscope_api_key": os.environ.get("TONGYI_DASHSCOPE_API_KEY")},
prompt_messages=[UserPromptMessage(content="Hello World!")],
model_parameters={"temperature": 0.5, "max_tokens": 100, "seed": 1234},
stream=True,
user="abc-123"
user="abc-123",
)

assert isinstance(response, Generator)
@ -89,18 +62,14 @@ def test_get_num_tokens():
model = TongyiLargeLanguageModel()

num_tokens = model.get_num_tokens(
model='qwen-turbo',
credentials={
'dashscope_api_key': os.environ.get('TONGYI_DASHSCOPE_API_KEY')
},
model="qwen-turbo",
credentials={"dashscope_api_key": os.environ.get("TONGYI_DASHSCOPE_API_KEY")},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Hello World!'
)
]
UserPromptMessage(content="Hello World!"),
],
)

assert num_tokens == 12
@ -10,12 +10,8 @@ def test_validate_provider_credentials():
provider = TongyiProvider()

with pytest.raises(CredentialsValidateFailedError):
provider.validate_provider_credentials(
credentials={}
)
provider.validate_provider_credentials(credentials={})

provider.validate_provider_credentials(
credentials={
'dashscope_api_key': os.environ.get('TONGYI_DASHSCOPE_API_KEY')
}
credentials={"dashscope_api_key": os.environ.get("TONGYI_DASHSCOPE_API_KEY")}
)
@ -39,21 +39,17 @@ def invoke_model_with_json_response(model_name="qwen-max-0403"):

response = model.invoke(
model=model_name,
credentials={
'dashscope_api_key': os.environ.get('TONGYI_DASHSCOPE_API_KEY')
},
credentials={"dashscope_api_key": os.environ.get("TONGYI_DASHSCOPE_API_KEY")},
prompt_messages=[
UserPromptMessage(
content='output json data with format `{"data": "test", "code": 200, "msg": "success"}'
)
UserPromptMessage(content='output json data with format `{"data": "test", "code": 200, "msg": "success"}')
],
model_parameters={
'temperature': 0.5,
'max_tokens': 50,
'response_format': 'JSON',
"temperature": 0.5,
"max_tokens": 50,
"response_format": "JSON",
},
stream=True,
user="abc-123"
user="abc-123",
)
print("=====================================")
print(response)
@ -81,4 +77,4 @@ def is_json(s):
json.loads(s)
except ValueError:
return False
return True
return True
@@ -26,151 +26,113 @@ def test_predefined_models():
assert len(model_schemas) >= 1
assert isinstance(model_schemas[0], AIModelEntity)


@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
def test_validate_credentials_for_chat_model(setup_openai_mock):
model = UpstageLargeLanguageModel()

with pytest.raises(CredentialsValidateFailedError):
# model name to gpt-3.5-turbo because of mocking
model.validate_credentials(
model='gpt-3.5-turbo',
credentials={
'upstage_api_key': 'invalid_key'
}
)
model.validate_credentials(model="gpt-3.5-turbo", credentials={"upstage_api_key": "invalid_key"})

model.validate_credentials(
model='solar-1-mini-chat',
credentials={
'upstage_api_key': os.environ.get('UPSTAGE_API_KEY')
}
model="solar-1-mini-chat", credentials={"upstage_api_key": os.environ.get("UPSTAGE_API_KEY")}
)


@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
def test_invoke_chat_model(setup_openai_mock):
model = UpstageLargeLanguageModel()

result = model.invoke(
model='solar-1-mini-chat',
credentials={
'upstage_api_key': os.environ.get('UPSTAGE_API_KEY')
},
model="solar-1-mini-chat",
credentials={"upstage_api_key": os.environ.get("UPSTAGE_API_KEY")},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Hello World!'
)
UserPromptMessage(content="Hello World!"),
],
model_parameters={
'temperature': 0.0,
'top_p': 1.0,
'presence_penalty': 0.0,
'frequency_penalty': 0.0,
'max_tokens': 10
"temperature": 0.0,
"top_p": 1.0,
"presence_penalty": 0.0,
"frequency_penalty": 0.0,
"max_tokens": 10,
},
stop=['How'],
stop=["How"],
stream=False,
user="abc-123"
user="abc-123",
)

assert isinstance(result, LLMResult)
assert len(result.message.content) > 0


@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
def test_invoke_chat_model_with_tools(setup_openai_mock):
model = UpstageLargeLanguageModel()

result = model.invoke(
model='solar-1-mini-chat',
credentials={
'upstage_api_key': os.environ.get('UPSTAGE_API_KEY')
},
model="solar-1-mini-chat",
credentials={"upstage_api_key": os.environ.get("UPSTAGE_API_KEY")},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content="what's the weather today in London?",
)
),
],
model_parameters={
'temperature': 0.0,
'max_tokens': 100
},
model_parameters={"temperature": 0.0, "max_tokens": 100},
tools=[
PromptMessageTool(
name='get_weather',
description='Determine weather in my location',
name="get_weather",
description="Determine weather in my location",
parameters={
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state e.g. San Francisco, CA"
},
"unit": {
"type": "string",
"enum": [
"c",
"f"
]
}
"location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
"unit": {"type": "string", "enum": ["c", "f"]},
},
"required": [
"location"
]
}
"required": ["location"],
},
),
PromptMessageTool(
name='get_stock_price',
description='Get the current stock price',
name="get_stock_price",
description="Get the current stock price",
parameters={
"type": "object",
"properties": {
"symbol": {
"type": "string",
"description": "The stock symbol"
}
},
"required": [
"symbol"
]
}
)
"properties": {"symbol": {"type": "string", "description": "The stock symbol"}},
"required": ["symbol"],
},
),
],
stream=False,
user="abc-123"
user="abc-123",
)

assert isinstance(result, LLMResult)
assert isinstance(result.message, AssistantPromptMessage)
assert len(result.message.tool_calls) > 0


@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
def test_invoke_stream_chat_model(setup_openai_mock):
model = UpstageLargeLanguageModel()

result = model.invoke(
model='solar-1-mini-chat',
credentials={
'upstage_api_key': os.environ.get('UPSTAGE_API_KEY')
},
model="solar-1-mini-chat",
credentials={"upstage_api_key": os.environ.get("UPSTAGE_API_KEY")},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Hello World!'
)
UserPromptMessage(content="Hello World!"),
],
model_parameters={
'temperature': 0.0,
'max_tokens': 100
},
model_parameters={"temperature": 0.0, "max_tokens": 100},
stream=True,
user="abc-123"
user="abc-123",
)

assert isinstance(result, Generator)
@@ -189,57 +151,36 @@ def test_get_num_tokens():
model = UpstageLargeLanguageModel()

num_tokens = model.get_num_tokens(
model='solar-1-mini-chat',
credentials={
'upstage_api_key': os.environ.get('UPSTAGE_API_KEY')
},
prompt_messages=[
UserPromptMessage(
content='Hello World!'
)
]
model="solar-1-mini-chat",
credentials={"upstage_api_key": os.environ.get("UPSTAGE_API_KEY")},
prompt_messages=[UserPromptMessage(content="Hello World!")],
)

assert num_tokens == 13

num_tokens = model.get_num_tokens(
model='solar-1-mini-chat',
credentials={
'upstage_api_key': os.environ.get('UPSTAGE_API_KEY')
},
model="solar-1-mini-chat",
credentials={"upstage_api_key": os.environ.get("UPSTAGE_API_KEY")},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Hello World!'
)
UserPromptMessage(content="Hello World!"),
],
tools=[
PromptMessageTool(
name='get_weather',
description='Determine weather in my location',
name="get_weather",
description="Determine weather in my location",
parameters={
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state e.g. San Francisco, CA"
},
"unit": {
"type": "string",
"enum": [
"c",
"f"
]
}
"location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
"unit": {"type": "string", "enum": ["c", "f"]},
},
"required": [
"location"
]
}
"required": ["location"],
},
),
]
],
)

assert num_tokens == 106
@@ -7,17 +7,11 @@ from core.model_runtime.model_providers.upstage.upstage import UpstageProvider
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock


@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
def test_validate_provider_credentials(setup_openai_mock):
provider = UpstageProvider()

with pytest.raises(CredentialsValidateFailedError):
provider.validate_provider_credentials(
credentials={}
)
provider.validate_provider_credentials(credentials={})

provider.validate_provider_credentials(
credentials={
'upstage_api_key': os.environ.get('UPSTAGE_API_KEY')
}
)
provider.validate_provider_credentials(credentials={"upstage_api_key": os.environ.get("UPSTAGE_API_KEY")})
@@ -8,41 +8,31 @@ from core.model_runtime.model_providers.upstage.text_embedding.text_embedding im
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock


@pytest.mark.parametrize('setup_openai_mock', [['text_embedding']], indirect=True)
@pytest.mark.parametrize("setup_openai_mock", [["text_embedding"]], indirect=True)
def test_validate_credentials(setup_openai_mock):
model = UpstageTextEmbeddingModel()

with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='solar-embedding-1-large-passage',
credentials={
'upstage_api_key': 'invalid_key'
}
model="solar-embedding-1-large-passage", credentials={"upstage_api_key": "invalid_key"}
)

model.validate_credentials(
model='solar-embedding-1-large-passage',
credentials={
'upstage_api_key': os.environ.get('UPSTAGE_API_KEY')
}
model="solar-embedding-1-large-passage", credentials={"upstage_api_key": os.environ.get("UPSTAGE_API_KEY")}
)


@pytest.mark.parametrize('setup_openai_mock', [['text_embedding']], indirect=True)
@pytest.mark.parametrize("setup_openai_mock", [["text_embedding"]], indirect=True)
def test_invoke_model(setup_openai_mock):
model = UpstageTextEmbeddingModel()

result = model.invoke(
model='solar-embedding-1-large-passage',
model="solar-embedding-1-large-passage",
credentials={
'upstage_api_key': os.environ.get('UPSTAGE_API_KEY'),
"upstage_api_key": os.environ.get("UPSTAGE_API_KEY"),
},
texts=[
"hello",
"world",
" ".join(["long_text"] * 100),
" ".join(["another_long_text"] * 100)
],
user="abc-123"
texts=["hello", "world", " ".join(["long_text"] * 100), " ".join(["another_long_text"] * 100)],
user="abc-123",
)

assert isinstance(result, TextEmbeddingResult)

@@ -54,14 +44,11 @@ def test_get_num_tokens():
model = UpstageTextEmbeddingModel()

num_tokens = model.get_num_tokens(
model='solar-embedding-1-large-passage',
model="solar-embedding-1-large-passage",
credentials={
'upstage_api_key': os.environ.get('UPSTAGE_API_KEY'),
"upstage_api_key": os.environ.get("UPSTAGE_API_KEY"),
},
texts=[
"hello",
"world"
]
texts=["hello", "world"],
)

assert num_tokens == 5
@@ -14,26 +14,26 @@ def test_validate_credentials():

with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='NOT IMPORTANT',
model="NOT IMPORTANT",
credentials={
'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com',
'volc_region': 'cn-beijing',
'volc_access_key_id': 'INVALID',
'volc_secret_access_key': 'INVALID',
'endpoint_id': 'INVALID',
'base_model_name': 'Doubao-embedding',
}
"api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com",
"volc_region": "cn-beijing",
"volc_access_key_id": "INVALID",
"volc_secret_access_key": "INVALID",
"endpoint_id": "INVALID",
"base_model_name": "Doubao-embedding",
},
)

model.validate_credentials(
model='NOT IMPORTANT',
model="NOT IMPORTANT",
credentials={
'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com',
'volc_region': 'cn-beijing',
'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
'endpoint_id': os.environ.get('VOLC_EMBEDDING_ENDPOINT_ID'),
'base_model_name': 'Doubao-embedding',
"api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com",
"volc_region": "cn-beijing",
"volc_access_key_id": os.environ.get("VOLC_API_KEY"),
"volc_secret_access_key": os.environ.get("VOLC_SECRET_KEY"),
"endpoint_id": os.environ.get("VOLC_EMBEDDING_ENDPOINT_ID"),
"base_model_name": "Doubao-embedding",
},
)

@@ -42,20 +42,17 @@ def test_invoke_model():
model = VolcengineMaaSTextEmbeddingModel()

result = model.invoke(
model='NOT IMPORTANT',
model="NOT IMPORTANT",
credentials={
'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com',
'volc_region': 'cn-beijing',
'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
'endpoint_id': os.environ.get('VOLC_EMBEDDING_ENDPOINT_ID'),
'base_model_name': 'Doubao-embedding',
"api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com",
"volc_region": "cn-beijing",
"volc_access_key_id": os.environ.get("VOLC_API_KEY"),
"volc_secret_access_key": os.environ.get("VOLC_SECRET_KEY"),
"endpoint_id": os.environ.get("VOLC_EMBEDDING_ENDPOINT_ID"),
"base_model_name": "Doubao-embedding",
},
texts=[
"hello",
"world"
],
user="abc-123"
texts=["hello", "world"],
user="abc-123",
)

assert isinstance(result, TextEmbeddingResult)

@@ -67,19 +64,16 @@ def test_get_num_tokens():
model = VolcengineMaaSTextEmbeddingModel()

num_tokens = model.get_num_tokens(
model='NOT IMPORTANT',
model="NOT IMPORTANT",
credentials={
'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com',
'volc_region': 'cn-beijing',
'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
'endpoint_id': os.environ.get('VOLC_EMBEDDING_ENDPOINT_ID'),
'base_model_name': 'Doubao-embedding',
"api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com",
"volc_region": "cn-beijing",
"volc_access_key_id": os.environ.get("VOLC_API_KEY"),
"volc_secret_access_key": os.environ.get("VOLC_SECRET_KEY"),
"endpoint_id": os.environ.get("VOLC_EMBEDDING_ENDPOINT_ID"),
"base_model_name": "Doubao-embedding",
},
texts=[
"hello",
"world"
]
texts=["hello", "world"],
)

assert num_tokens == 2
@@ -14,25 +14,25 @@ def test_validate_credentials_for_chat_model():

with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='NOT IMPORTANT',
model="NOT IMPORTANT",
credentials={
'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com',
'volc_region': 'cn-beijing',
'volc_access_key_id': 'INVALID',
'volc_secret_access_key': 'INVALID',
'endpoint_id': 'INVALID',
}
"api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com",
"volc_region": "cn-beijing",
"volc_access_key_id": "INVALID",
"volc_secret_access_key": "INVALID",
"endpoint_id": "INVALID",
},
)

model.validate_credentials(
model='NOT IMPORTANT',
model="NOT IMPORTANT",
credentials={
'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com',
'volc_region': 'cn-beijing',
'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
'endpoint_id': os.environ.get('VOLC_MODEL_ENDPOINT_ID'),
}
"api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com",
"volc_region": "cn-beijing",
"volc_access_key_id": os.environ.get("VOLC_API_KEY"),
"volc_secret_access_key": os.environ.get("VOLC_SECRET_KEY"),
"endpoint_id": os.environ.get("VOLC_MODEL_ENDPOINT_ID"),
},
)


@@ -40,28 +40,24 @@ def test_invoke_model():
model = VolcengineMaaSLargeLanguageModel()

response = model.invoke(
model='NOT IMPORTANT',
model="NOT IMPORTANT",
credentials={
'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com',
'volc_region': 'cn-beijing',
'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
'endpoint_id': os.environ.get('VOLC_MODEL_ENDPOINT_ID'),
'base_model_name': 'Skylark2-pro-4k',
"api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com",
"volc_region": "cn-beijing",
"volc_access_key_id": os.environ.get("VOLC_API_KEY"),
"volc_secret_access_key": os.environ.get("VOLC_SECRET_KEY"),
"endpoint_id": os.environ.get("VOLC_MODEL_ENDPOINT_ID"),
"base_model_name": "Skylark2-pro-4k",
},
prompt_messages=[
UserPromptMessage(
content='Hello World!'
)
],
prompt_messages=[UserPromptMessage(content="Hello World!")],
model_parameters={
'temperature': 0.7,
'top_p': 1.0,
'top_k': 1,
"temperature": 0.7,
"top_p": 1.0,
"top_k": 1,
},
stop=['you'],
stop=["you"],
user="abc-123",
stream=False
stream=False,
)

assert isinstance(response, LLMResult)
@@ -73,28 +69,24 @@ def test_invoke_stream_model():
model = VolcengineMaaSLargeLanguageModel()

response = model.invoke(
model='NOT IMPORTANT',
model="NOT IMPORTANT",
credentials={
'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com',
'volc_region': 'cn-beijing',
'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
'endpoint_id': os.environ.get('VOLC_MODEL_ENDPOINT_ID'),
'base_model_name': 'Skylark2-pro-4k',
"api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com",
"volc_region": "cn-beijing",
"volc_access_key_id": os.environ.get("VOLC_API_KEY"),
"volc_secret_access_key": os.environ.get("VOLC_SECRET_KEY"),
"endpoint_id": os.environ.get("VOLC_MODEL_ENDPOINT_ID"),
"base_model_name": "Skylark2-pro-4k",
},
prompt_messages=[
UserPromptMessage(
content='Hello World!'
)
],
prompt_messages=[UserPromptMessage(content="Hello World!")],
model_parameters={
'temperature': 0.7,
'top_p': 1.0,
'top_k': 1,
"temperature": 0.7,
"top_p": 1.0,
"top_k": 1,
},
stop=['you'],
stop=["you"],
stream=True,
user="abc-123"
user="abc-123",
)

assert isinstance(response, Generator)

@@ -102,29 +94,24 @@ def test_invoke_stream_model():
assert isinstance(chunk, LLMResultChunk)
assert isinstance(chunk.delta, LLMResultChunkDelta)
assert isinstance(chunk.delta.message, AssistantPromptMessage)
assert len(
chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True


def test_get_num_tokens():
model = VolcengineMaaSLargeLanguageModel()

response = model.get_num_tokens(
model='NOT IMPORTANT',
model="NOT IMPORTANT",
credentials={
'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com',
'volc_region': 'cn-beijing',
'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
'endpoint_id': os.environ.get('VOLC_MODEL_ENDPOINT_ID'),
'base_model_name': 'Skylark2-pro-4k',
"api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com",
"volc_region": "cn-beijing",
"volc_access_key_id": os.environ.get("VOLC_API_KEY"),
"volc_secret_access_key": os.environ.get("VOLC_SECRET_KEY"),
"endpoint_id": os.environ.get("VOLC_MODEL_ENDPOINT_ID"),
"base_model_name": "Skylark2-pro-4k",
},
prompt_messages=[
UserPromptMessage(
content='Hello World!'
)
],
tools=[]
prompt_messages=[UserPromptMessage(content="Hello World!")],
tools=[],
)

assert isinstance(response, int)
@@ -10,13 +10,10 @@ def test_invoke_embedding_v1():
model = WenxinTextEmbeddingModel()

response = model.invoke(
model='embedding-v1',
credentials={
'api_key': os.environ.get('WENXIN_API_KEY'),
'secret_key': os.environ.get('WENXIN_SECRET_KEY')
},
texts=['hello', '你好', 'xxxxx'],
user="abc-123"
model="embedding-v1",
credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")},
texts=["hello", "你好", "xxxxx"],
user="abc-123",
)

assert isinstance(response, TextEmbeddingResult)

@@ -29,13 +26,10 @@ def test_invoke_embedding_bge_large_en():
model = WenxinTextEmbeddingModel()

response = model.invoke(
model='bge-large-en',
credentials={
'api_key': os.environ.get('WENXIN_API_KEY'),
'secret_key': os.environ.get('WENXIN_SECRET_KEY')
},
texts=['hello', '你好', 'xxxxx'],
user="abc-123"
model="bge-large-en",
credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")},
texts=["hello", "你好", "xxxxx"],
user="abc-123",
)

assert isinstance(response, TextEmbeddingResult)

@@ -48,13 +42,10 @@ def test_invoke_embedding_bge_large_zh():
model = WenxinTextEmbeddingModel()

response = model.invoke(
model='bge-large-zh',
credentials={
'api_key': os.environ.get('WENXIN_API_KEY'),
'secret_key': os.environ.get('WENXIN_SECRET_KEY')
},
texts=['hello', '你好', 'xxxxx'],
user="abc-123"
model="bge-large-zh",
credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")},
texts=["hello", "你好", "xxxxx"],
user="abc-123",
)

assert isinstance(response, TextEmbeddingResult)

@@ -67,13 +58,10 @@ def test_invoke_embedding_tao_8k():
model = WenxinTextEmbeddingModel()

response = model.invoke(
model='tao-8k',
credentials={
'api_key': os.environ.get('WENXIN_API_KEY'),
'secret_key': os.environ.get('WENXIN_SECRET_KEY')
},
texts=['hello', '你好', 'xxxxx'],
user="abc-123"
model="tao-8k",
credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")},
texts=["hello", "你好", "xxxxx"],
user="abc-123",
)

assert isinstance(response, TextEmbeddingResult)
@@ -17,161 +17,125 @@ def test_predefined_models():
assert len(model_schemas) >= 1
assert isinstance(model_schemas[0], AIModelEntity)


def test_validate_credentials_for_chat_model():
sleep(3)
model = ErnieBotLargeLanguageModel()

with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='ernie-bot',
credentials={
'api_key': 'invalid_key',
'secret_key': 'invalid_key'
}
model="ernie-bot", credentials={"api_key": "invalid_key", "secret_key": "invalid_key"}
)

model.validate_credentials(
model='ernie-bot',
credentials={
'api_key': os.environ.get('WENXIN_API_KEY'),
'secret_key': os.environ.get('WENXIN_SECRET_KEY')
}
model="ernie-bot",
credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")},
)


def test_invoke_model_ernie_bot():
sleep(3)
model = ErnieBotLargeLanguageModel()

response = model.invoke(
model='ernie-bot',
credentials={
'api_key': os.environ.get('WENXIN_API_KEY'),
'secret_key': os.environ.get('WENXIN_SECRET_KEY')
},
prompt_messages=[
UserPromptMessage(
content='Hello World!'
)
],
model="ernie-bot",
credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")},
prompt_messages=[UserPromptMessage(content="Hello World!")],
model_parameters={
'temperature': 0.7,
'top_p': 1.0,
"temperature": 0.7,
"top_p": 1.0,
},
stop=['you'],
stop=["you"],
user="abc-123",
stream=False
stream=False,
)

assert isinstance(response, LLMResult)
assert len(response.message.content) > 0
assert response.usage.total_tokens > 0


def test_invoke_model_ernie_bot_turbo():
sleep(3)
model = ErnieBotLargeLanguageModel()

response = model.invoke(
model='ernie-bot-turbo',
credentials={
'api_key': os.environ.get('WENXIN_API_KEY'),
'secret_key': os.environ.get('WENXIN_SECRET_KEY')
},
prompt_messages=[
UserPromptMessage(
content='Hello World!'
)
],
model="ernie-bot-turbo",
credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")},
prompt_messages=[UserPromptMessage(content="Hello World!")],
model_parameters={
'temperature': 0.7,
'top_p': 1.0,
"temperature": 0.7,
"top_p": 1.0,
},
stop=['you'],
stop=["you"],
user="abc-123",
stream=False
stream=False,
)

assert isinstance(response, LLMResult)
assert len(response.message.content) > 0
assert response.usage.total_tokens > 0

def test_invoke_model_ernie_8k():
sleep(3)
model = ErnieBotLargeLanguageModel()

response = model.invoke(
model='ernie-bot-8k',
credentials={
'api_key': os.environ.get('WENXIN_API_KEY'),
'secret_key': os.environ.get('WENXIN_SECRET_KEY')
},
prompt_messages=[
UserPromptMessage(
content='Hello World!'
)
],
model="ernie-bot-8k",
credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")},
prompt_messages=[UserPromptMessage(content="Hello World!")],
model_parameters={
'temperature': 0.7,
'top_p': 1.0,
"temperature": 0.7,
"top_p": 1.0,
},
stop=['you'],
stop=["you"],
user="abc-123",
stream=False
stream=False,
)

assert isinstance(response, LLMResult)
assert len(response.message.content) > 0
assert response.usage.total_tokens > 0


def test_invoke_model_ernie_bot_4():
sleep(3)
model = ErnieBotLargeLanguageModel()

response = model.invoke(
model='ernie-bot-4',
credentials={
'api_key': os.environ.get('WENXIN_API_KEY'),
'secret_key': os.environ.get('WENXIN_SECRET_KEY')
},
prompt_messages=[
UserPromptMessage(
content='Hello World!'
)
],
model="ernie-bot-4",
credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")},
prompt_messages=[UserPromptMessage(content="Hello World!")],
model_parameters={
'temperature': 0.7,
'top_p': 1.0,
"temperature": 0.7,
"top_p": 1.0,
},
stop=['you'],
stop=["you"],
user="abc-123",
stream=False
stream=False,
)

assert isinstance(response, LLMResult)
assert len(response.message.content) > 0
assert response.usage.total_tokens > 0


def test_invoke_stream_model():
sleep(3)
model = ErnieBotLargeLanguageModel()

response = model.invoke(
model='ernie-3.5-8k',
credentials={
'api_key': os.environ.get('WENXIN_API_KEY'),
'secret_key': os.environ.get('WENXIN_SECRET_KEY')
},
prompt_messages=[
UserPromptMessage(
content='Hello World!'
)
],
model="ernie-3.5-8k",
credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")},
prompt_messages=[UserPromptMessage(content="Hello World!")],
model_parameters={
'temperature': 0.7,
'top_p': 1.0,
"temperature": 0.7,
"top_p": 1.0,
},
stop=['you'],
stop=["you"],
stream=True,
user="abc-123"
user="abc-123",
)

assert isinstance(response, Generator)
@@ -181,63 +145,48 @@ def test_invoke_stream_model():
assert isinstance(chunk.delta.message, AssistantPromptMessage)
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True


def test_invoke_model_with_system():
sleep(3)
model = ErnieBotLargeLanguageModel()

response = model.invoke(
model='ernie-bot',
credentials={
'api_key': os.environ.get('WENXIN_API_KEY'),
'secret_key': os.environ.get('WENXIN_SECRET_KEY')
},
prompt_messages=[
SystemPromptMessage(
content='你是Kasumi'
),
UserPromptMessage(
content='你是谁?'
)
],
model="ernie-bot",
credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")},
prompt_messages=[SystemPromptMessage(content="你是Kasumi"), UserPromptMessage(content="你是谁?")],
model_parameters={
'temperature': 0.7,
'top_p': 1.0,
"temperature": 0.7,
"top_p": 1.0,
},
stop=['you'],
stop=["you"],
stream=False,
user="abc-123"
user="abc-123",
)

assert isinstance(response, LLMResult)
assert 'kasumi' in response.message.content.lower()
assert "kasumi" in response.message.content.lower()


def test_invoke_with_search():
sleep(3)
model = ErnieBotLargeLanguageModel()

response = model.invoke(
model='ernie-bot',
credentials={
'api_key': os.environ.get('WENXIN_API_KEY'),
'secret_key': os.environ.get('WENXIN_SECRET_KEY')
},
prompt_messages=[
UserPromptMessage(
content='北京今天的天气怎么样'
)
],
model="ernie-bot",
credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")},
prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")],
model_parameters={
'temperature': 0.7,
'top_p': 1.0,
'disable_search': True,
"temperature": 0.7,
"top_p": 1.0,
"disable_search": True,
},
stop=[],
stream=True,
user="abc-123"
user="abc-123",
)

assert isinstance(response, Generator)
total_message = ''
total_message = ""
for chunk in response:
assert isinstance(chunk, LLMResultChunk)
assert isinstance(chunk.delta, LLMResultChunkDelta)

@@ -247,25 +196,19 @@ def test_invoke_with_search():
assert len(chunk.delta.message.content) > 0 if not chunk.delta.finish_reason else True

# the response should contain a refusal such as 对不起, 我不能, or 不支持
assert ('不' in total_message or '抱歉' in total_message or '无法' in total_message)
assert "不" in total_message or "抱歉" in total_message or "无法" in total_message


def test_get_num_tokens():
sleep(3)
model = ErnieBotLargeLanguageModel()

response = model.get_num_tokens(
model='ernie-bot',
credentials={
'api_key': os.environ.get('WENXIN_API_KEY'),
'secret_key': os.environ.get('WENXIN_SECRET_KEY')
},
prompt_messages=[
UserPromptMessage(
content='Hello World!'
)
],
tools=[]
model="ernie-bot",
credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")},
prompt_messages=[UserPromptMessage(content="Hello World!")],
tools=[],
)

assert isinstance(response, int)
assert response == 10
assert response == 10

@@ -10,16 +10,8 @@ def test_validate_provider_credentials():
provider = WenxinProvider()

with pytest.raises(CredentialsValidateFailedError):
provider.validate_provider_credentials(
credentials={
'api_key': 'hahahaha',
'secret_key': 'hahahaha'
}
)
provider.validate_provider_credentials(credentials={"api_key": "hahahaha", "secret_key": "hahahaha"})

provider.validate_provider_credentials(
credentials={
'api_key': os.environ.get('WENXIN_API_KEY'),
'secret_key': os.environ.get('WENXIN_SECRET_KEY')
}
credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")}
)
@@ -8,61 +8,57 @@ from core.model_runtime.model_providers.xinference.text_embedding.text_embedding
from tests.integration_tests.model_runtime.__mock.xinference import MOCK, setup_xinference_mock


@pytest.mark.parametrize('setup_xinference_mock', [['none']], indirect=True)
@pytest.mark.parametrize("setup_xinference_mock", [["none"]], indirect=True)
def test_validate_credentials(setup_xinference_mock):
model = XinferenceTextEmbeddingModel()

with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='bge-base-en',
model="bge-base-en",
credentials={
'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
'model_uid': 'www ' + os.environ.get('XINFERENCE_EMBEDDINGS_MODEL_UID')
}
"server_url": os.environ.get("XINFERENCE_SERVER_URL"),
"model_uid": "www " + os.environ.get("XINFERENCE_EMBEDDINGS_MODEL_UID"),
},
)

model.validate_credentials(
model='bge-base-en',
model="bge-base-en",
credentials={
'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
'model_uid': os.environ.get('XINFERENCE_EMBEDDINGS_MODEL_UID')
}
"server_url": os.environ.get("XINFERENCE_SERVER_URL"),
"model_uid": os.environ.get("XINFERENCE_EMBEDDINGS_MODEL_UID"),
},
)


@pytest.mark.parametrize('setup_xinference_mock', [['none']], indirect=True)
@pytest.mark.parametrize("setup_xinference_mock", [["none"]], indirect=True)
def test_invoke_model(setup_xinference_mock):
model = XinferenceTextEmbeddingModel()

result = model.invoke(
model='bge-base-en',
model="bge-base-en",
credentials={
'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
'model_uid': os.environ.get('XINFERENCE_EMBEDDINGS_MODEL_UID')
"server_url": os.environ.get("XINFERENCE_SERVER_URL"),
"model_uid": os.environ.get("XINFERENCE_EMBEDDINGS_MODEL_UID"),
},
texts=[
"hello",
"world"
],
user="abc-123"
texts=["hello", "world"],
user="abc-123",
)

assert isinstance(result, TextEmbeddingResult)
assert len(result.embeddings) == 2
assert result.usage.total_tokens > 0


def test_get_num_tokens():
model = XinferenceTextEmbeddingModel()

num_tokens = model.get_num_tokens(
model='bge-base-en',
model="bge-base-en",
credentials={
'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
'model_uid': os.environ.get('XINFERENCE_EMBEDDINGS_MODEL_UID')
"server_url": os.environ.get("XINFERENCE_SERVER_URL"),
"model_uid": os.environ.get("XINFERENCE_EMBEDDINGS_MODEL_UID"),
},
texts=[
"hello",
"world"
]
texts=["hello", "world"],
)

assert num_tokens == 2
@@ -20,92 +20,84 @@ from tests.integration_tests.model_runtime.__mock.openai import setup_openai_moc
from tests.integration_tests.model_runtime.__mock.xinference import setup_xinference_mock


@pytest.mark.parametrize('setup_openai_mock, setup_xinference_mock', [['chat', 'none']], indirect=True)
@pytest.mark.parametrize("setup_openai_mock, setup_xinference_mock", [["chat", "none"]], indirect=True)
def test_validate_credentials_for_chat_model(setup_openai_mock, setup_xinference_mock):
model = XinferenceAILargeLanguageModel()

with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='ChatGLM3',
model="ChatGLM3",
credentials={
'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
'model_uid': 'www ' + os.environ.get('XINFERENCE_CHAT_MODEL_UID')
}
"server_url": os.environ.get("XINFERENCE_SERVER_URL"),
"model_uid": "www " + os.environ.get("XINFERENCE_CHAT_MODEL_UID"),
},
)

with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='aaaaa',
credentials={
'server_url': '',
'model_uid': ''
}
)
model.validate_credentials(model="aaaaa", credentials={"server_url": "", "model_uid": ""})

model.validate_credentials(
model='ChatGLM3',
model="ChatGLM3",
credentials={
'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
'model_uid': os.environ.get('XINFERENCE_CHAT_MODEL_UID')
}
"server_url": os.environ.get("XINFERENCE_SERVER_URL"),
"model_uid": os.environ.get("XINFERENCE_CHAT_MODEL_UID"),
},
)


@pytest.mark.parametrize('setup_openai_mock, setup_xinference_mock', [['chat', 'none']], indirect=True)
@pytest.mark.parametrize("setup_openai_mock, setup_xinference_mock", [["chat", "none"]], indirect=True)
def test_invoke_chat_model(setup_openai_mock, setup_xinference_mock):
model = XinferenceAILargeLanguageModel()

response = model.invoke(
model='ChatGLM3',
model="ChatGLM3",
credentials={
'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
'model_uid': os.environ.get('XINFERENCE_CHAT_MODEL_UID')
"server_url": os.environ.get("XINFERENCE_SERVER_URL"),
"model_uid": os.environ.get("XINFERENCE_CHAT_MODEL_UID"),
},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Hello World!'
)
UserPromptMessage(content="Hello World!"),
],
model_parameters={
'temperature': 0.7,
'top_p': 1.0,
"temperature": 0.7,
"top_p": 1.0,
},
stop=['you'],
stop=["you"],
user="abc-123",
stream=False
stream=False,
)

assert isinstance(response, LLMResult)
assert len(response.message.content) > 0
assert response.usage.total_tokens > 0


@pytest.mark.parametrize('setup_openai_mock, setup_xinference_mock', [['chat', 'none']], indirect=True)
@pytest.mark.parametrize("setup_openai_mock, setup_xinference_mock", [["chat", "none"]], indirect=True)
def test_invoke_stream_chat_model(setup_openai_mock, setup_xinference_mock):
model = XinferenceAILargeLanguageModel()

response = model.invoke(
model='ChatGLM3',
model="ChatGLM3",
credentials={
'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
'model_uid': os.environ.get('XINFERENCE_CHAT_MODEL_UID')
"server_url": os.environ.get("XINFERENCE_SERVER_URL"),
"model_uid": os.environ.get("XINFERENCE_CHAT_MODEL_UID"),
},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Hello World!'
)
UserPromptMessage(content="Hello World!"),
],
model_parameters={
'temperature': 0.7,
'top_p': 1.0,
"temperature": 0.7,
"top_p": 1.0,
},
stop=['you'],
stop=["you"],
stream=True,
user="abc-123"
user="abc-123",
)

assert isinstance(response, Generator)
@@ -114,6 +106,8 @@ def test_invoke_stream_chat_model(setup_openai_mock, setup_xinference_mock):
assert isinstance(chunk.delta, LLMResultChunkDelta)
assert isinstance(chunk.delta.message, AssistantPromptMessage)
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True

"""
|
||||
Funtion calling of xinference does not support stream mode currently
|
||||
"""
|
||||
@@ -168,7 +162,7 @@ def test_invoke_stream_chat_model(setup_openai_mock, setup_xinference_mock):
# )

# assert isinstance(response, Generator)

# call: LLMResultChunk = None
# chunks = []
@@ -241,86 +235,75 @@ def test_invoke_stream_chat_model(setup_openai_mock, setup_xinference_mock):
# assert response.usage.total_tokens > 0
# assert response.message.tool_calls[0].function.name == 'get_current_weather'


@pytest.mark.parametrize('setup_openai_mock, setup_xinference_mock', [['completion', 'none']], indirect=True)
@pytest.mark.parametrize("setup_openai_mock, setup_xinference_mock", [["completion", "none"]], indirect=True)
def test_validate_credentials_for_generation_model(setup_openai_mock, setup_xinference_mock):
model = XinferenceAILargeLanguageModel()

with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='alapaca',
model="alapaca",
credentials={
'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
'model_uid': 'www ' + os.environ.get('XINFERENCE_GENERATION_MODEL_UID')
}
"server_url": os.environ.get("XINFERENCE_SERVER_URL"),
"model_uid": "www " + os.environ.get("XINFERENCE_GENERATION_MODEL_UID"),
},
)

with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='alapaca',
credentials={
'server_url': '',
'model_uid': ''
}
)
model.validate_credentials(model="alapaca", credentials={"server_url": "", "model_uid": ""})

model.validate_credentials(
model='alapaca',
model="alapaca",
credentials={
'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
'model_uid': os.environ.get('XINFERENCE_GENERATION_MODEL_UID')
}
"server_url": os.environ.get("XINFERENCE_SERVER_URL"),
"model_uid": os.environ.get("XINFERENCE_GENERATION_MODEL_UID"),
},
)


@pytest.mark.parametrize('setup_openai_mock, setup_xinference_mock', [['completion', 'none']], indirect=True)
@pytest.mark.parametrize("setup_openai_mock, setup_xinference_mock", [["completion", "none"]], indirect=True)
def test_invoke_generation_model(setup_openai_mock, setup_xinference_mock):
model = XinferenceAILargeLanguageModel()

response = model.invoke(
model='alapaca',
model="alapaca",
credentials={
'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
'model_uid': os.environ.get('XINFERENCE_GENERATION_MODEL_UID')
"server_url": os.environ.get("XINFERENCE_SERVER_URL"),
"model_uid": os.environ.get("XINFERENCE_GENERATION_MODEL_UID"),
},
prompt_messages=[
UserPromptMessage(
content='the United States is'
)
],
prompt_messages=[UserPromptMessage(content="the United States is")],
model_parameters={
'temperature': 0.7,
'top_p': 1.0,
"temperature": 0.7,
"top_p": 1.0,
},
stop=['you'],
stop=["you"],
user="abc-123",
stream=False
stream=False,
)

assert isinstance(response, LLMResult)
assert len(response.message.content) > 0
assert response.usage.total_tokens > 0


@pytest.mark.parametrize('setup_openai_mock, setup_xinference_mock', [['completion', 'none']], indirect=True)
@pytest.mark.parametrize("setup_openai_mock, setup_xinference_mock", [["completion", "none"]], indirect=True)
def test_invoke_stream_generation_model(setup_openai_mock, setup_xinference_mock):
model = XinferenceAILargeLanguageModel()

response = model.invoke(
model='alapaca',
model="alapaca",
credentials={
'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
'model_uid': os.environ.get('XINFERENCE_GENERATION_MODEL_UID')
"server_url": os.environ.get("XINFERENCE_SERVER_URL"),
"model_uid": os.environ.get("XINFERENCE_GENERATION_MODEL_UID"),
},
prompt_messages=[
UserPromptMessage(
content='the United States is'
)
],
prompt_messages=[UserPromptMessage(content="the United States is")],
model_parameters={
'temperature': 0.7,
'top_p': 1.0,
"temperature": 0.7,
"top_p": 1.0,
},
stop=['you'],
stop=["you"],
stream=True,
user="abc-123"
user="abc-123",
)

assert isinstance(response, Generator)
@@ -330,68 +313,54 @@ def test_invoke_stream_generation_model(setup_openai_mock, setup_xinference_mock
assert isinstance(chunk.delta.message, AssistantPromptMessage)
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True


def test_get_num_tokens():
model = XinferenceAILargeLanguageModel()

num_tokens = model.get_num_tokens(
model='ChatGLM3',
model="ChatGLM3",
credentials={
'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
'model_uid': os.environ.get('XINFERENCE_GENERATION_MODEL_UID')
"server_url": os.environ.get("XINFERENCE_SERVER_URL"),
"model_uid": os.environ.get("XINFERENCE_GENERATION_MODEL_UID"),
},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Hello World!'
)
UserPromptMessage(content="Hello World!"),
],
tools=[
PromptMessageTool(
name='get_current_weather',
description='Get the current weather in a given location',
name="get_current_weather",
description="Get the current weather in a given location",
parameters={
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state e.g. San Francisco, CA"
},
"unit": {
"type": "string",
"enum": [
"c",
"f"
]
}
"location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
"unit": {"type": "string", "enum": ["c", "f"]},
},
"required": [
"location"
]
}
"required": ["location"],
},
)
]
],
)

assert isinstance(num_tokens, int)
assert num_tokens == 77

num_tokens = model.get_num_tokens(
model='ChatGLM3',
model="ChatGLM3",
credentials={
'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
'model_uid': os.environ.get('XINFERENCE_GENERATION_MODEL_UID')
"server_url": os.environ.get("XINFERENCE_SERVER_URL"),
"model_uid": os.environ.get("XINFERENCE_GENERATION_MODEL_UID"),
},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Hello World!'
)
UserPromptMessage(content="Hello World!"),
],
)

assert isinstance(num_tokens, int)
assert num_tokens == 21
assert num_tokens == 21

@@ -8,44 +8,42 @@ from core.model_runtime.model_providers.xinference.rerank.rerank import Xinferen
from tests.integration_tests.model_runtime.__mock.xinference import MOCK, setup_xinference_mock


@pytest.mark.parametrize('setup_xinference_mock', [['none']], indirect=True)
@pytest.mark.parametrize("setup_xinference_mock", [["none"]], indirect=True)
def test_validate_credentials(setup_xinference_mock):
model = XinferenceRerankModel()

with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='bge-reranker-base',
credentials={
'server_url': 'awdawdaw',
'model_uid': os.environ.get('XINFERENCE_RERANK_MODEL_UID')
}
model="bge-reranker-base",
credentials={"server_url": "awdawdaw", "model_uid": os.environ.get("XINFERENCE_RERANK_MODEL_UID")},
)

model.validate_credentials(
model='bge-reranker-base',
model="bge-reranker-base",
credentials={
'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
'model_uid': os.environ.get('XINFERENCE_RERANK_MODEL_UID')
}
"server_url": os.environ.get("XINFERENCE_SERVER_URL"),
"model_uid": os.environ.get("XINFERENCE_RERANK_MODEL_UID"),
},
)


@pytest.mark.parametrize('setup_xinference_mock', [['none']], indirect=True)
@pytest.mark.parametrize("setup_xinference_mock", [["none"]], indirect=True)
def test_invoke_model(setup_xinference_mock):
model = XinferenceRerankModel()

result = model.invoke(
model='bge-reranker-base',
model="bge-reranker-base",
credentials={
'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
'model_uid': os.environ.get('XINFERENCE_RERANK_MODEL_UID')
"server_url": os.environ.get("XINFERENCE_SERVER_URL"),
"model_uid": os.environ.get("XINFERENCE_RERANK_MODEL_UID"),
},
query="Who is Kasumi?",
docs=[
"Kasumi is a girl's name of Japanese origin meaning \"mist\".",
'Kasumi is a girl\'s name of Japanese origin meaning "mist".',
"Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ",
"and she leads a team named PopiParty."
"and she leads a team named PopiParty.",
],
score_threshold=0.8
score_threshold=0.8,
)

assert isinstance(result, RerankResult)
@@ -13,41 +13,22 @@ def test_validate_credentials():
model = ZhinaoLargeLanguageModel()

with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='360gpt2-pro',
credentials={
'api_key': 'invalid_key'
}
)
model.validate_credentials(model="360gpt2-pro", credentials={"api_key": "invalid_key"})

model.validate_credentials(
model='360gpt2-pro',
credentials={
'api_key': os.environ.get('ZHINAO_API_KEY')
}
)
model.validate_credentials(model="360gpt2-pro", credentials={"api_key": os.environ.get("ZHINAO_API_KEY")})


def test_invoke_model():
model = ZhinaoLargeLanguageModel()

response = model.invoke(
model='360gpt2-pro',
credentials={
'api_key': os.environ.get('ZHINAO_API_KEY')
},
prompt_messages=[
UserPromptMessage(
content='Who are you?'
)
],
model_parameters={
'temperature': 0.5,
'max_tokens': 10
},
stop=['How'],
model="360gpt2-pro",
credentials={"api_key": os.environ.get("ZHINAO_API_KEY")},
prompt_messages=[UserPromptMessage(content="Who are you?")],
model_parameters={"temperature": 0.5, "max_tokens": 10},
stop=["How"],
stream=False,
user="abc-123"
user="abc-123",
)

assert isinstance(response, LLMResult)

@@ -58,22 +39,12 @@ def test_invoke_stream_model():
model = ZhinaoLargeLanguageModel()

response = model.invoke(
model='360gpt2-pro',
credentials={
'api_key': os.environ.get('ZHINAO_API_KEY')
},
prompt_messages=[
UserPromptMessage(
content='Hello World!'
)
],
model_parameters={
'temperature': 0.5,
'max_tokens': 100,
'seed': 1234
},
model="360gpt2-pro",
credentials={"api_key": os.environ.get("ZHINAO_API_KEY")},
prompt_messages=[UserPromptMessage(content="Hello World!")],
model_parameters={"temperature": 0.5, "max_tokens": 100, "seed": 1234},
stream=True,
user="abc-123"
user="abc-123",
)

assert isinstance(response, Generator)

@@ -89,18 +60,14 @@ def test_get_num_tokens():
model = ZhinaoLargeLanguageModel()

num_tokens = model.get_num_tokens(
model='360gpt2-pro',
credentials={
'api_key': os.environ.get('ZHINAO_API_KEY')
},
model="360gpt2-pro",
credentials={"api_key": os.environ.get("ZHINAO_API_KEY")},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Hello World!'
)
]
UserPromptMessage(content="Hello World!"),
],
)

assert num_tokens == 21
@@ -10,12 +10,6 @@ def test_validate_provider_credentials():
provider = ZhinaoProvider()

with pytest.raises(CredentialsValidateFailedError):
provider.validate_provider_credentials(
credentials={}
)
provider.validate_provider_credentials(credentials={})

provider.validate_provider_credentials(
credentials={
'api_key': os.environ.get('ZHINAO_API_KEY')
}
)
provider.validate_provider_credentials(credentials={"api_key": os.environ.get("ZHINAO_API_KEY")})
@@ -18,41 +18,22 @@ def test_validate_credentials():
model = ZhipuAILargeLanguageModel()

with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='chatglm_turbo',
credentials={
'api_key': 'invalid_key'
}
)
model.validate_credentials(model="chatglm_turbo", credentials={"api_key": "invalid_key"})

model.validate_credentials(
model='chatglm_turbo',
credentials={
'api_key': os.environ.get('ZHIPUAI_API_KEY')
}
)
model.validate_credentials(model="chatglm_turbo", credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")})


def test_invoke_model():
model = ZhipuAILargeLanguageModel()

response = model.invoke(
model='chatglm_turbo',
credentials={
'api_key': os.environ.get('ZHIPUAI_API_KEY')
},
prompt_messages=[
UserPromptMessage(
content='Who are you?'
)
],
model_parameters={
'temperature': 0.9,
'top_p': 0.7
},
stop=['How'],
model="chatglm_turbo",
credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")},
prompt_messages=[UserPromptMessage(content="Who are you?")],
model_parameters={"temperature": 0.9, "top_p": 0.7},
stop=["How"],
stream=False,
user="abc-123"
user="abc-123",
)

assert isinstance(response, LLMResult)

@@ -63,21 +44,12 @@ def test_invoke_stream_model():
model = ZhipuAILargeLanguageModel()

response = model.invoke(
model='chatglm_turbo',
credentials={
'api_key': os.environ.get('ZHIPUAI_API_KEY')
},
prompt_messages=[
UserPromptMessage(
content='Hello World!'
)
],
model_parameters={
'temperature': 0.9,
'top_p': 0.7
},
model="chatglm_turbo",
credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")},
prompt_messages=[UserPromptMessage(content="Hello World!")],
model_parameters={"temperature": 0.9, "top_p": 0.7},
stream=True,
user="abc-123"
user="abc-123",
)

assert isinstance(response, Generator)
@ -93,63 +65,45 @@ def test_get_num_tokens():
    model = ZhipuAILargeLanguageModel()

    num_tokens = model.get_num_tokens(
        model='chatglm_turbo',
        credentials={
            'api_key': os.environ.get('ZHIPUAI_API_KEY')
        },
        model="chatglm_turbo",
        credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")},
        prompt_messages=[
            SystemPromptMessage(
                content='You are a helpful AI assistant.',
                content="You are a helpful AI assistant.",
            ),
            UserPromptMessage(
                content='Hello World!'
            )
        ]
            UserPromptMessage(content="Hello World!"),
        ],
    )

    assert num_tokens == 14


def test_get_tools_num_tokens():
    model = ZhipuAILargeLanguageModel()

    num_tokens = model.get_num_tokens(
        model='tools',
        credentials={
            'api_key': os.environ.get('ZHIPUAI_API_KEY')
        },
        model="tools",
        credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")},
        tools=[
            PromptMessageTool(
                name='get_current_weather',
                description='Get the current weather in a given location',
                name="get_current_weather",
                description="Get the current weather in a given location",
                parameters={
                    "type": "object",
                    "properties": {
                        "location": {
                            "type": "string",
                            "description": "The city and state e.g. San Francisco, CA"
                        },
                        "unit": {
                            "type": "string",
                            "enum": [
                                "c",
                                "f"
                            ]
                        }
                        "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
                        "unit": {"type": "string", "enum": ["c", "f"]},
                    },
                    "required": [
                        "location"
                    ]
                }
                    "required": ["location"],
                },
            )
        ],
        prompt_messages=[
            SystemPromptMessage(
                content='You are a helpful AI assistant.',
                content="You are a helpful AI assistant.",
            ),
            UserPromptMessage(
                content='Hello World!'
            )
        ]
            UserPromptMessage(content="Hello World!"),
        ],
    )

    assert num_tokens == 88

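The parameters block above is ordinary JSON Schema, so it can be sanity-checked on its own. A minimal sketch using the third-party jsonschema package (an assumption; it is not a dependency shown in this diff):

    from jsonschema import validate

    weather_schema = {
        "type": "object",
        "properties": {
            "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
            "unit": {"type": "string", "enum": ["c", "f"]},
        },
        "required": ["location"],
    }

    # raises jsonschema.exceptions.ValidationError if the arguments do not fit the schema
    validate(instance={"location": "San Francisco, CA", "unit": "c"}, schema=weather_schema)
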
@ -10,12 +10,6 @@ def test_validate_provider_credentials():
    provider = ZhipuaiProvider()

    with pytest.raises(CredentialsValidateFailedError):
        provider.validate_provider_credentials(
            credentials={}
        )
        provider.validate_provider_credentials(credentials={})

    provider.validate_provider_credentials(
        credentials={
            'api_key': os.environ.get('ZHIPUAI_API_KEY')
        }
    )
    provider.validate_provider_credentials(credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")})

@ -11,34 +11,19 @@ def test_validate_credentials():
    model = ZhipuAITextEmbeddingModel()

    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model='text_embedding',
            credentials={
                'api_key': 'invalid_key'
            }
        )
        model.validate_credentials(model="text_embedding", credentials={"api_key": "invalid_key"})

    model.validate_credentials(
        model='text_embedding',
        credentials={
            'api_key': os.environ.get('ZHIPUAI_API_KEY')
        }
    )
    model.validate_credentials(model="text_embedding", credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")})


def test_invoke_model():
    model = ZhipuAITextEmbeddingModel()

    result = model.invoke(
        model='text_embedding',
        credentials={
            'api_key': os.environ.get('ZHIPUAI_API_KEY')
        },
        texts=[
            "hello",
            "world"
        ],
        user="abc-123"
        model="text_embedding",
        credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")},
        texts=["hello", "world"],
        user="abc-123",
    )

    assert isinstance(result, TextEmbeddingResult)

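Beyond the isinstance check, embedding tests often pin down the shape of the result as well. A sketch, assuming TextEmbeddingResult exposes an embeddings list (one vector per input text) and a usage record, as it does elsewhere in this model runtime:

    # one embedding vector per input text
    assert len(result.embeddings) == 2
    # token usage should be populated by the provider
    assert result.usage.total_tokens > 0
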
@ -50,14 +35,7 @@ def test_get_num_tokens():
    model = ZhipuAITextEmbeddingModel()

    num_tokens = model.get_num_tokens(
        model='text_embedding',
        credentials={
            'api_key': os.environ.get('ZHIPUAI_API_KEY')
        },
        texts=[
            "hello",
            "world"
        ]
        model="text_embedding", credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")}, texts=["hello", "world"]
    )

    assert num_tokens == 2

@ -7,20 +7,17 @@ from _pytest.monkeypatch import MonkeyPatch


class MockedHttp:
    def httpx_request(method: Literal['GET', 'POST', 'PUT', 'DELETE', 'PATCH', 'HEAD'],
                      url: str, **kwargs) -> httpx.Response:
    def httpx_request(
        method: Literal["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD"], url: str, **kwargs
    ) -> httpx.Response:
        """
        Mocked httpx.request
        """
        request = httpx.Request(
            method,
            url,
            params=kwargs.get('params'),
            headers=kwargs.get('headers'),
            cookies=kwargs.get('cookies')
            method, url, params=kwargs.get("params"), headers=kwargs.get("headers"), cookies=kwargs.get("cookies")
        )
        data = kwargs.get('data', None)
        resp = json.dumps(data).encode('utf-8') if data else b'OK'
        data = kwargs.get("data", None)
        resp = json.dumps(data).encode("utf-8") if data else b"OK"
        response = httpx.Response(
            status_code=200,
            request=request,

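The hunk ends before showing how MockedHttp is installed. One plausible wiring, sketched under the assumption that the setup_http_mock fixture simply patches httpx.request for the duration of a test:

    import httpx
    import pytest
    from _pytest.monkeypatch import MonkeyPatch


    @pytest.fixture
    def setup_http_mock(monkeypatch: MonkeyPatch):
        # route every httpx.request call through the mock while the test runs
        monkeypatch.setattr(httpx, "request", MockedHttp.httpx_request)
        yield
        monkeypatch.undo()
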
@ -10,6 +10,7 @@ todos_data = {
    "user1": ["Go for a run", "Read a book"],
}


class TodosResource(Resource):
    def get(self, username):
        todos = todos_data.get(username, [])
@ -32,7 +33,8 @@ class TodosResource(Resource):

        return {"error": "Invalid todo index"}, 400

api.add_resource(TodosResource, '/todos/<string:username>')

if __name__ == '__main__':
api.add_resource(TodosResource, "/todos/<string:username>")


if __name__ == "__main__":
    app.run(port=5003, debug=True)

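For context, these fragments belong to a small Flask-RESTful demo server. A self-contained skeleton of the same wiring (only todos_data, the route, and the port are taken from the hunks; the rest is a sketch):

    from flask import Flask
    from flask_restful import Api, Resource

    app = Flask(__name__)
    api = Api(app)

    todos_data = {
        "user1": ["Go for a run", "Read a book"],
    }


    class TodosResource(Resource):
        def get(self, username):
            # unknown users get an empty todo list rather than a 404
            return {"todos": todos_data.get(username, [])}


    api.add_resource(TodosResource, "/todos/<string:username>")

    if __name__ == "__main__":
        app.run(port=5003, debug=True)
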
@ -3,37 +3,40 @@ from core.tools.tool.tool import Tool
from tests.integration_tests.tools.__mock.http import setup_http_mock

tool_bundle = {
    'server_url': 'http://www.example.com/{path_param}',
    'method': 'post',
    'author': '',
    'openapi': {'parameters': [{'in': 'path', 'name': 'path_param'},
                               {'in': 'query', 'name': 'query_param'},
                               {'in': 'cookie', 'name': 'cookie_param'},
                               {'in': 'header', 'name': 'header_param'},
                               ],
                'requestBody': {
                    'content': {'application/json': {'schema': {'properties': {'body_param': {'type': 'string'}}}}}}
                },
    'parameters': []
    "server_url": "http://www.example.com/{path_param}",
    "method": "post",
    "author": "",
    "openapi": {
        "parameters": [
            {"in": "path", "name": "path_param"},
            {"in": "query", "name": "query_param"},
            {"in": "cookie", "name": "cookie_param"},
            {"in": "header", "name": "header_param"},
        ],
        "requestBody": {
            "content": {"application/json": {"schema": {"properties": {"body_param": {"type": "string"}}}}}
        },
    },
    "parameters": [],
}
parameters = {
    'path_param': 'p_param',
    'query_param': 'q_param',
    'cookie_param': 'c_param',
    'header_param': 'h_param',
    'body_param': 'b_param',
    "path_param": "p_param",
    "query_param": "q_param",
    "cookie_param": "c_param",
    "header_param": "h_param",
    "body_param": "b_param",
}


def test_api_tool(setup_http_mock):
    tool = ApiTool(api_bundle=tool_bundle, runtime=Tool.Runtime(credentials={'auth_type': 'none'}))
    tool = ApiTool(api_bundle=tool_bundle, runtime=Tool.Runtime(credentials={"auth_type": "none"}))
    headers = tool.assembling_request(parameters)
    response = tool.do_http_request(tool.api_bundle.server_url, tool.api_bundle.method, headers, parameters)

    assert response.status_code == 200
    assert '/p_param' == response.request.url.path
    assert b'query_param=q_param' == response.request.url.query
    assert 'h_param' == response.request.headers.get('header_param')
    assert 'application/json' == response.request.headers.get('content-type')
    assert 'cookie_param=c_param' == response.request.headers.get('cookie')
    assert 'b_param' in response.content.decode()
    assert "/p_param" == response.request.url.path
    assert b"query_param=q_param" == response.request.url.query
    assert "h_param" == response.request.headers.get("header_param")
    assert "application/json" == response.request.headers.get("content-type")
    assert "cookie_param=c_param" == response.request.headers.get("cookie")
    assert "b_param" in response.content.decode()

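The assertions above verify that each OpenAPI parameter landed in its declared location: path, query, cookie, header, or JSON body. Built directly with httpx instead of going through ApiTool, the request under test would look roughly like this sketch:

    import httpx

    request = httpx.Request(
        "POST",
        "http://www.example.com/p_param",       # path_param substituted into the URL path
        params={"query_param": "q_param"},      # query string
        headers={"header_param": "h_param"},    # plain header
        cookies={"cookie_param": "c_param"},    # serialized into the Cookie header
        json={"body_param": "b_param"},         # application/json body
    )
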
@ -7,16 +7,17 @@ provider_names = [provider.identity.name for provider in provider_generator]
ToolManager.clear_builtin_providers_cache()
provider_generator = ToolManager.list_builtin_providers()

@pytest.mark.parametrize('name', provider_names)

@pytest.mark.parametrize("name", provider_names)
def test_tool_providers(benchmark, name):
    """
    Test that all tool providers can be loaded
    """

    def test(generator):
        try:
            return next(generator)
        except StopIteration:
            return None

    benchmark.pedantic(test, args=(provider_generator,), iterations=1, rounds=1)

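One subtlety here: provider_names is materialized before the cache is cleared, and each parametrized case then pulls a single provider from the fresh generator, timed with pytest-benchmark's pedantic mode. A minimal standalone sketch of that pattern:

    import pytest


    @pytest.mark.parametrize("value", [1, 2, 3])
    def test_consume_one(benchmark, value):
        gen = iter([value])

        def consume(generator):
            # pedantic(iterations=1, rounds=1) times exactly one call
            return next(generator, None)

        assert benchmark.pedantic(consume, args=(gen,), iterations=1, rounds=1) == value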