chore(api/tests): apply ruff reformat #7590 (#7591)

Co-authored-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
Bowen Liang 2024-08-23 23:52:25 +08:00 committed by GitHub
parent 2da63654e5
commit b035c02f78
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
155 changed files with 4279 additions and 5925 deletions

View File

@ -76,7 +76,6 @@ exclude = [
"migrations/**/*",
"services/**/*.py",
"tasks/**/*.py",
"tests/**/*.py",
]
[tool.pytest_env]

View File

@ -22,23 +22,20 @@ from anthropic.types import (
)
from anthropic.types.message_delta_event import Delta
MOCK = os.getenv('MOCK_SWITCH', 'false') == 'true'
MOCK = os.getenv("MOCK_SWITCH", "false") == "true"
class MockAnthropicClass:
@staticmethod
def mocked_anthropic_chat_create_sync(model: str) -> Message:
return Message(
id='msg-123',
type='message',
role='assistant',
content=[ContentBlock(text='hello, I\'m a chatbot from anthropic', type='text')],
id="msg-123",
type="message",
role="assistant",
content=[ContentBlock(text="hello, I'm a chatbot from anthropic", type="text")],
model=model,
stop_reason='stop_sequence',
usage=Usage(
input_tokens=1,
output_tokens=1
)
stop_reason="stop_sequence",
usage=Usage(input_tokens=1, output_tokens=1),
)
@staticmethod
@ -46,52 +43,43 @@ class MockAnthropicClass:
full_response_text = "hello, I'm a chatbot from anthropic"
yield MessageStartEvent(
type='message_start',
type="message_start",
message=Message(
id='msg-123',
id="msg-123",
content=[],
role='assistant',
role="assistant",
model=model,
stop_reason=None,
type='message',
usage=Usage(
input_tokens=1,
output_tokens=1
)
)
type="message",
usage=Usage(input_tokens=1, output_tokens=1),
),
)
index = 0
for i in range(0, len(full_response_text)):
yield ContentBlockDeltaEvent(
type='content_block_delta',
delta=TextDelta(text=full_response_text[i], type='text_delta'),
index=index
type="content_block_delta", delta=TextDelta(text=full_response_text[i], type="text_delta"), index=index
)
index += 1
yield MessageDeltaEvent(
type='message_delta',
delta=Delta(
stop_reason='stop_sequence'
),
usage=MessageDeltaUsage(
output_tokens=1
)
type="message_delta", delta=Delta(stop_reason="stop_sequence"), usage=MessageDeltaUsage(output_tokens=1)
)
yield MessageStopEvent(type='message_stop')
yield MessageStopEvent(type="message_stop")
def mocked_anthropic(self: Messages, *,
max_tokens: int,
messages: Iterable[MessageParam],
model: str,
stream: Literal[True],
**kwargs: Any
) -> Union[Message, Stream[MessageStreamEvent]]:
def mocked_anthropic(
self: Messages,
*,
max_tokens: int,
messages: Iterable[MessageParam],
model: str,
stream: Literal[True],
**kwargs: Any,
) -> Union[Message, Stream[MessageStreamEvent]]:
if len(self._client.api_key) < 18:
raise anthropic.AuthenticationError('Invalid API key')
raise anthropic.AuthenticationError("Invalid API key")
if stream:
return MockAnthropicClass.mocked_anthropic_chat_create_stream(model=model)
@ -102,7 +90,7 @@ class MockAnthropicClass:
@pytest.fixture
def setup_anthropic_mock(request, monkeypatch: MonkeyPatch):
if MOCK:
monkeypatch.setattr(Messages, 'create', MockAnthropicClass.mocked_anthropic)
monkeypatch.setattr(Messages, "create", MockAnthropicClass.mocked_anthropic)
yield

View File

@ -12,63 +12,46 @@ from google.generativeai.client import _ClientManager, configure
from google.generativeai.types import GenerateContentResponse
from google.generativeai.types.generation_types import BaseGenerateContentResponse
current_api_key = ''
current_api_key = ""
class MockGoogleResponseClass:
_done = False
def __iter__(self):
full_response_text = 'it\'s google!'
full_response_text = "it's google!"
for i in range(0, len(full_response_text) + 1, 1):
if i == len(full_response_text):
self._done = True
yield GenerateContentResponse(
done=True,
iterator=None,
result=glm.GenerateContentResponse({
}),
chunks=[]
done=True, iterator=None, result=glm.GenerateContentResponse({}), chunks=[]
)
else:
yield GenerateContentResponse(
done=False,
iterator=None,
result=glm.GenerateContentResponse({
}),
chunks=[]
done=False, iterator=None, result=glm.GenerateContentResponse({}), chunks=[]
)
class MockGoogleResponseCandidateClass:
finish_reason = 'stop'
finish_reason = "stop"
@property
def content(self) -> gag_content.Content:
return gag_content.Content(
parts=[
gag_content.Part(text='it\'s google!')
]
)
return gag_content.Content(parts=[gag_content.Part(text="it's google!")])
class MockGoogleClass:
@staticmethod
def generate_content_sync() -> GenerateContentResponse:
return GenerateContentResponse(
done=True,
iterator=None,
result=glm.GenerateContentResponse({
}),
chunks=[]
)
return GenerateContentResponse(done=True, iterator=None, result=glm.GenerateContentResponse({}), chunks=[])
@staticmethod
def generate_content_stream() -> Generator[GenerateContentResponse, None, None]:
return MockGoogleResponseClass()
def generate_content(self: GenerativeModel,
def generate_content(
self: GenerativeModel,
contents: content_types.ContentsType,
*,
generation_config: generation_config_types.GenerationConfigType | None = None,
@ -79,21 +62,21 @@ class MockGoogleClass:
global current_api_key
if len(current_api_key) < 16:
raise Exception('Invalid API key')
raise Exception("Invalid API key")
if stream:
return MockGoogleClass.generate_content_stream()
return MockGoogleClass.generate_content_sync()
@property
def generative_response_text(self) -> str:
return 'it\'s google!'
return "it's google!"
@property
def generative_response_candidates(self) -> list[MockGoogleResponseCandidateClass]:
return [MockGoogleResponseCandidateClass()]
def make_client(self: _ClientManager, name: str):
global current_api_key
@ -121,7 +104,8 @@ class MockGoogleClass:
if not self.default_metadata:
return client
@pytest.fixture
def setup_google_mock(request, monkeypatch: MonkeyPatch):
monkeypatch.setattr(BaseGenerateContentResponse, "text", MockGoogleClass.generative_response_text)
@ -131,4 +115,4 @@ def setup_google_mock(request, monkeypatch: MonkeyPatch):
yield
monkeypatch.undo()
monkeypatch.undo()

View File

@ -6,14 +6,15 @@ from huggingface_hub import InferenceClient
from tests.integration_tests.model_runtime.__mock.huggingface_chat import MockHuggingfaceChatClass
MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
@pytest.fixture
def setup_huggingface_mock(request, monkeypatch: MonkeyPatch):
if MOCK:
monkeypatch.setattr(InferenceClient, "text_generation", MockHuggingfaceChatClass.text_generation)
yield
if MOCK:
monkeypatch.undo()
monkeypatch.undo()

View File

@ -22,10 +22,8 @@ class MockHuggingfaceChatClass:
details=Details(
finish_reason="length",
generated_tokens=6,
tokens=[
Token(id=0, text="You", logprob=0.0, special=False) for i in range(0, 6)
]
)
tokens=[Token(id=0, text="You", logprob=0.0, special=False) for i in range(0, 6)],
),
)
return response
@ -36,26 +34,23 @@ class MockHuggingfaceChatClass:
for i in range(0, len(full_text)):
response = TextGenerationStreamResponse(
token = Token(id=i, text=full_text[i], logprob=0.0, special=False),
token=Token(id=i, text=full_text[i], logprob=0.0, special=False),
)
response.generated_text = full_text[i]
response.details = StreamDetails(finish_reason='stop_sequence', generated_tokens=1)
response.details = StreamDetails(finish_reason="stop_sequence", generated_tokens=1)
yield response
def text_generation(self: InferenceClient, prompt: str, *,
stream: Literal[False] = ...,
model: Optional[str] = None,
**kwargs: Any
def text_generation(
self: InferenceClient, prompt: str, *, stream: Literal[False] = ..., model: Optional[str] = None, **kwargs: Any
) -> Union[TextGenerationResponse, Generator[TextGenerationStreamResponse, None, None]]:
# check if key is valid
if not re.match(r'Bearer\shf\-[a-zA-Z0-9]{16,}', self.headers['authorization']):
raise BadRequestError('Invalid API key')
if not re.match(r"Bearer\shf\-[a-zA-Z0-9]{16,}", self.headers["authorization"]):
raise BadRequestError("Invalid API key")
if model is None:
raise BadRequestError('Invalid model')
raise BadRequestError("Invalid model")
if stream:
return MockHuggingfaceChatClass.generate_create_stream(model)
return MockHuggingfaceChatClass.generate_create_sync(model)

View File

@ -5,10 +5,10 @@ class MockTEIClass:
@staticmethod
def get_tei_extra_parameter(server_url: str, model_name: str) -> TeiModelExtraParameter:
# During mock, we don't have a real server to query, so we just return a dummy value
if 'rerank' in model_name:
model_type = 'reranker'
if "rerank" in model_name:
model_type = "reranker"
else:
model_type = 'embedding'
model_type = "embedding"
return TeiModelExtraParameter(model_type=model_type, max_input_length=512, max_client_batch_size=1)
@ -17,16 +17,16 @@ class MockTEIClass:
# Use space as token separator, and split the text into tokens
tokenized_texts = []
for text in texts:
tokens = text.split(' ')
tokens = text.split(" ")
current_index = 0
tokenized_text = []
for idx, token in enumerate(tokens):
s_token = {
'id': idx,
'text': token,
'special': False,
'start': current_index,
'stop': current_index + len(token),
"id": idx,
"text": token,
"special": False,
"start": current_index,
"stop": current_index + len(token),
}
current_index += len(token) + 1
tokenized_text.append(s_token)
@ -55,18 +55,18 @@ class MockTEIClass:
embedding = [0.1] * 768
embeddings.append(
{
'object': 'embedding',
'embedding': embedding,
'index': idx,
"object": "embedding",
"embedding": embedding,
"index": idx,
}
)
return {
'object': 'list',
'data': embeddings,
'model': 'MODEL_NAME',
'usage': {
'prompt_tokens': sum(len(text.split(' ')) for text in texts),
'total_tokens': sum(len(text.split(' ')) for text in texts),
"object": "list",
"data": embeddings,
"model": "MODEL_NAME",
"usage": {
"prompt_tokens": sum(len(text.split(" ")) for text in texts),
"total_tokens": sum(len(text.split(" ")) for text in texts),
},
}
@ -83,9 +83,9 @@ class MockTEIClass:
for idx, text in enumerate(texts):
reranked_docs.append(
{
'index': idx,
'text': text,
'score': 0.9,
"index": idx,
"text": text,
"score": 0.9,
}
)
# For mock, only return the first document

View File

@ -21,13 +21,17 @@ from tests.integration_tests.model_runtime.__mock.openai_remote import MockModel
from tests.integration_tests.model_runtime.__mock.openai_speech2text import MockSpeech2TextClass
def mock_openai(monkeypatch: MonkeyPatch, methods: list[Literal["completion", "chat", "remote", "moderation", "speech2text", "text_embedding"]]) -> Callable[[], None]:
def mock_openai(
monkeypatch: MonkeyPatch,
methods: list[Literal["completion", "chat", "remote", "moderation", "speech2text", "text_embedding"]],
) -> Callable[[], None]:
"""
mock openai module
mock openai module
:param monkeypatch: pytest monkeypatch fixture
:return: unpatch function
:param monkeypatch: pytest monkeypatch fixture
:return: unpatch function
"""
def unpatch() -> None:
monkeypatch.undo()
@ -52,15 +56,16 @@ def mock_openai(monkeypatch: MonkeyPatch, methods: list[Literal["completion", "c
return unpatch
MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
@pytest.fixture
def setup_openai_mock(request, monkeypatch):
methods = request.param if hasattr(request, 'param') else []
methods = request.param if hasattr(request, "param") else []
if MOCK:
unpatch = mock_openai(monkeypatch, methods=methods)
yield
if MOCK:
unpatch()
unpatch()

View File

@ -43,62 +43,64 @@ class MockChatClass:
if not functions or len(functions) == 0:
return None
function: completion_create_params.Function = functions[0]
function_name = function['name']
function_description = function['description']
function_parameters = function['parameters']
function_parameters_type = function_parameters['type']
if function_parameters_type != 'object':
function_name = function["name"]
function_description = function["description"]
function_parameters = function["parameters"]
function_parameters_type = function_parameters["type"]
if function_parameters_type != "object":
return None
function_parameters_properties = function_parameters['properties']
function_parameters_required = function_parameters['required']
function_parameters_properties = function_parameters["properties"]
function_parameters_required = function_parameters["required"]
parameters = {}
for parameter_name, parameter in function_parameters_properties.items():
if parameter_name not in function_parameters_required:
continue
parameter_type = parameter['type']
if parameter_type == 'string':
if 'enum' in parameter:
if len(parameter['enum']) == 0:
parameter_type = parameter["type"]
if parameter_type == "string":
if "enum" in parameter:
if len(parameter["enum"]) == 0:
continue
parameters[parameter_name] = parameter['enum'][0]
parameters[parameter_name] = parameter["enum"][0]
else:
parameters[parameter_name] = 'kawaii'
elif parameter_type == 'integer':
parameters[parameter_name] = "kawaii"
elif parameter_type == "integer":
parameters[parameter_name] = 114514
elif parameter_type == 'number':
elif parameter_type == "number":
parameters[parameter_name] = 1919810.0
elif parameter_type == 'boolean':
elif parameter_type == "boolean":
parameters[parameter_name] = True
return FunctionCall(name=function_name, arguments=dumps(parameters))
@staticmethod
def generate_tool_calls(tools = NOT_GIVEN) -> Optional[list[ChatCompletionMessageToolCall]]:
def generate_tool_calls(tools=NOT_GIVEN) -> Optional[list[ChatCompletionMessageToolCall]]:
list_tool_calls = []
if not tools or len(tools) == 0:
return None
tool = tools[0]
if 'type' in tools and tools['type'] != 'function':
if "type" in tools and tools["type"] != "function":
return None
function = tool['function']
function = tool["function"]
function_call = MockChatClass.generate_function_call(functions=[function])
if function_call is None:
return None
list_tool_calls.append(ChatCompletionMessageToolCall(
id='sakurajima-mai',
function=Function(
name=function_call.name,
arguments=function_call.arguments,
),
type='function'
))
list_tool_calls.append(
ChatCompletionMessageToolCall(
id="sakurajima-mai",
function=Function(
name=function_call.name,
arguments=function_call.arguments,
),
type="function",
)
)
return list_tool_calls
@staticmethod
def mocked_openai_chat_create_sync(
model: str,
@ -111,30 +113,27 @@ class MockChatClass:
tool_calls = MockChatClass.generate_tool_calls(tools=tools)
return _ChatCompletion(
id='cmpl-3QJQa5jXJ5Z5X',
id="cmpl-3QJQa5jXJ5Z5X",
choices=[
_ChatCompletionChoice(
finish_reason='content_filter',
finish_reason="content_filter",
index=0,
message=ChatCompletionMessage(
content='elaina',
role='assistant',
function_call=function_call,
tool_calls=tool_calls
)
content="elaina", role="assistant", function_call=function_call, tool_calls=tool_calls
),
)
],
created=int(time()),
model=model,
object='chat.completion',
system_fingerprint='',
object="chat.completion",
system_fingerprint="",
usage=CompletionUsage(
prompt_tokens=2,
completion_tokens=1,
total_tokens=3,
)
),
)
@staticmethod
def mocked_openai_chat_create_stream(
model: str,
@ -150,36 +149,40 @@ class MockChatClass:
for i in range(0, len(full_text) + 1):
if i == len(full_text):
yield ChatCompletionChunk(
id='cmpl-3QJQa5jXJ5Z5X',
id="cmpl-3QJQa5jXJ5Z5X",
choices=[
Choice(
delta=ChoiceDelta(
content='',
content="",
function_call=ChoiceDeltaFunctionCall(
name=function_call.name,
arguments=function_call.arguments,
) if function_call else None,
role='assistant',
)
if function_call
else None,
role="assistant",
tool_calls=[
ChoiceDeltaToolCall(
index=0,
id='misaka-mikoto',
id="misaka-mikoto",
function=ChoiceDeltaToolCallFunction(
name=tool_calls[0].function.name,
arguments=tool_calls[0].function.arguments,
),
type='function'
type="function",
)
] if tool_calls and len(tool_calls) > 0 else None
]
if tool_calls and len(tool_calls) > 0
else None,
),
finish_reason='function_call',
finish_reason="function_call",
index=0,
)
],
created=int(time()),
model=model,
object='chat.completion.chunk',
system_fingerprint='',
object="chat.completion.chunk",
system_fingerprint="",
usage=CompletionUsage(
prompt_tokens=2,
completion_tokens=17,
@ -188,30 +191,45 @@ class MockChatClass:
)
else:
yield ChatCompletionChunk(
id='cmpl-3QJQa5jXJ5Z5X',
id="cmpl-3QJQa5jXJ5Z5X",
choices=[
Choice(
delta=ChoiceDelta(
content=full_text[i],
role='assistant',
role="assistant",
),
finish_reason='content_filter',
finish_reason="content_filter",
index=0,
)
],
created=int(time()),
model=model,
object='chat.completion.chunk',
system_fingerprint='',
object="chat.completion.chunk",
system_fingerprint="",
)
def chat_create(self: Completions, *,
def chat_create(
self: Completions,
*,
messages: list[ChatCompletionMessageParam],
model: Union[str,Literal[
"gpt-4-1106-preview", "gpt-4-vision-preview", "gpt-4", "gpt-4-0314", "gpt-4-0613",
"gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613",
"gpt-3.5-turbo-1106", "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-3.5-turbo-0301",
"gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613"],
model: Union[
str,
Literal[
"gpt-4-1106-preview",
"gpt-4-vision-preview",
"gpt-4",
"gpt-4-0314",
"gpt-4-0613",
"gpt-4-32k",
"gpt-4-32k-0314",
"gpt-4-32k-0613",
"gpt-3.5-turbo-1106",
"gpt-3.5-turbo",
"gpt-3.5-turbo-16k",
"gpt-3.5-turbo-0301",
"gpt-3.5-turbo-0613",
"gpt-3.5-turbo-16k-0613",
],
],
functions: list[completion_create_params.Function] | NotGiven = NOT_GIVEN,
response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
@ -220,24 +238,32 @@ class MockChatClass:
**kwargs: Any,
):
openai_models = [
"gpt-4-1106-preview", "gpt-4-vision-preview", "gpt-4", "gpt-4-0314", "gpt-4-0613",
"gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613",
"gpt-3.5-turbo-1106", "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-3.5-turbo-0301",
"gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613",
"gpt-4-1106-preview",
"gpt-4-vision-preview",
"gpt-4",
"gpt-4-0314",
"gpt-4-0613",
"gpt-4-32k",
"gpt-4-32k-0314",
"gpt-4-32k-0613",
"gpt-3.5-turbo-1106",
"gpt-3.5-turbo",
"gpt-3.5-turbo-16k",
"gpt-3.5-turbo-0301",
"gpt-3.5-turbo-0613",
"gpt-3.5-turbo-16k-0613",
]
azure_openai_models = [
"gpt35", "gpt-4v", "gpt-35-turbo"
]
if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', self._client.base_url.__str__()):
raise InvokeAuthorizationError('Invalid base url')
azure_openai_models = ["gpt35", "gpt-4v", "gpt-35-turbo"]
if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._client.base_url.__str__()):
raise InvokeAuthorizationError("Invalid base url")
if model in openai_models + azure_openai_models:
if not re.match(r'sk-[a-zA-Z0-9]{24,}$', self._client.api_key) and type(self._client) == OpenAI:
if not re.match(r"sk-[a-zA-Z0-9]{24,}$", self._client.api_key) and type(self._client) == OpenAI:
# sometime, provider use OpenAI compatible API will not have api key or have different api key format
# so we only check if model is in openai_models
raise InvokeAuthorizationError('Invalid api key')
raise InvokeAuthorizationError("Invalid api key")
if len(self._client.api_key) < 18 and type(self._client) == AzureOpenAI:
raise InvokeAuthorizationError('Invalid api key')
raise InvokeAuthorizationError("Invalid api key")
if stream:
return MockChatClass.mocked_openai_chat_create_stream(model=model, functions=functions, tools=tools)
return MockChatClass.mocked_openai_chat_create_sync(model=model, functions=functions, tools=tools)
return MockChatClass.mocked_openai_chat_create_sync(model=model, functions=functions, tools=tools)

View File

@ -17,9 +17,7 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError
class MockCompletionsClass:
@staticmethod
def mocked_openai_completion_create_sync(
model: str
) -> CompletionMessage:
def mocked_openai_completion_create_sync(model: str) -> CompletionMessage:
return CompletionMessage(
id="cmpl-3QJQa5jXJ5Z5X",
object="text_completion",
@ -38,13 +36,11 @@ class MockCompletionsClass:
prompt_tokens=2,
completion_tokens=1,
total_tokens=3,
)
),
)
@staticmethod
def mocked_openai_completion_create_stream(
model: str
) -> Generator[CompletionMessage, None, None]:
def mocked_openai_completion_create_stream(model: str) -> Generator[CompletionMessage, None, None]:
full_text = "Hello, world!\n\n```python\nprint('Hello, world!')\n```"
for i in range(0, len(full_text) + 1):
if i == len(full_text):
@ -76,46 +72,59 @@ class MockCompletionsClass:
model=model,
system_fingerprint="",
choices=[
CompletionChoice(
text=full_text[i],
index=0,
logprobs=None,
finish_reason="content_filter"
)
CompletionChoice(text=full_text[i], index=0, logprobs=None, finish_reason="content_filter")
],
)
def completion_create(self: Completions, *, model: Union[
str, Literal["babbage-002", "davinci-002", "gpt-3.5-turbo-instruct",
"text-davinci-003", "text-davinci-002", "text-davinci-001",
"code-davinci-002", "text-curie-001", "text-babbage-001",
"text-ada-001"],
def completion_create(
self: Completions,
*,
model: Union[
str,
Literal[
"babbage-002",
"davinci-002",
"gpt-3.5-turbo-instruct",
"text-davinci-003",
"text-davinci-002",
"text-davinci-001",
"code-davinci-002",
"text-curie-001",
"text-babbage-001",
"text-ada-001",
],
],
prompt: Union[str, list[str], list[int], list[list[int]], None],
stream: Optional[Literal[False]] | NotGiven = NOT_GIVEN,
**kwargs: Any
**kwargs: Any,
):
openai_models = [
"babbage-002", "davinci-002", "gpt-3.5-turbo-instruct", "text-davinci-003", "text-davinci-002", "text-davinci-001",
"code-davinci-002", "text-curie-001", "text-babbage-001", "text-ada-001",
]
azure_openai_models = [
"gpt-35-turbo-instruct"
"babbage-002",
"davinci-002",
"gpt-3.5-turbo-instruct",
"text-davinci-003",
"text-davinci-002",
"text-davinci-001",
"code-davinci-002",
"text-curie-001",
"text-babbage-001",
"text-ada-001",
]
azure_openai_models = ["gpt-35-turbo-instruct"]
if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', self._client.base_url.__str__()):
raise InvokeAuthorizationError('Invalid base url')
if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._client.base_url.__str__()):
raise InvokeAuthorizationError("Invalid base url")
if model in openai_models + azure_openai_models:
if not re.match(r'sk-[a-zA-Z0-9]{24,}$', self._client.api_key) and type(self._client) == OpenAI:
if not re.match(r"sk-[a-zA-Z0-9]{24,}$", self._client.api_key) and type(self._client) == OpenAI:
# sometime, provider use OpenAI compatible API will not have api key or have different api key format
# so we only check if model is in openai_models
raise InvokeAuthorizationError('Invalid api key')
raise InvokeAuthorizationError("Invalid api key")
if len(self._client.api_key) < 18 and type(self._client) == AzureOpenAI:
raise InvokeAuthorizationError('Invalid api key')
raise InvokeAuthorizationError("Invalid api key")
if not prompt:
raise BadRequestError('Invalid prompt')
raise BadRequestError("Invalid prompt")
if stream:
return MockCompletionsClass.mocked_openai_completion_create_stream(model=model)
return MockCompletionsClass.mocked_openai_completion_create_sync(model=model)
return MockCompletionsClass.mocked_openai_completion_create_sync(model=model)

File diff suppressed because one or more lines are too long

View File

@ -10,58 +10,92 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError
class MockModerationClass:
def moderation_create(self: Moderations,*,
def moderation_create(
self: Moderations,
*,
input: Union[str, list[str]],
model: Union[str, Literal["text-moderation-latest", "text-moderation-stable"]] | NotGiven = NOT_GIVEN,
**kwargs: Any
**kwargs: Any,
) -> ModerationCreateResponse:
if isinstance(input, str):
input = [input]
if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', self._client.base_url.__str__()):
raise InvokeAuthorizationError('Invalid base url')
if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._client.base_url.__str__()):
raise InvokeAuthorizationError("Invalid base url")
if len(self._client.api_key) < 18:
raise InvokeAuthorizationError('Invalid API key')
raise InvokeAuthorizationError("Invalid API key")
for text in input:
result = []
if 'kill' in text:
if "kill" in text:
moderation_categories = {
'harassment': False, 'harassment/threatening': False, 'hate': False, 'hate/threatening': False,
'self-harm': False, 'self-harm/instructions': False, 'self-harm/intent': False, 'sexual': False,
'sexual/minors': False, 'violence': False, 'violence/graphic': False
"harassment": False,
"harassment/threatening": False,
"hate": False,
"hate/threatening": False,
"self-harm": False,
"self-harm/instructions": False,
"self-harm/intent": False,
"sexual": False,
"sexual/minors": False,
"violence": False,
"violence/graphic": False,
}
moderation_categories_scores = {
'harassment': 1.0, 'harassment/threatening': 1.0, 'hate': 1.0, 'hate/threatening': 1.0,
'self-harm': 1.0, 'self-harm/instructions': 1.0, 'self-harm/intent': 1.0, 'sexual': 1.0,
'sexual/minors': 1.0, 'violence': 1.0, 'violence/graphic': 1.0
"harassment": 1.0,
"harassment/threatening": 1.0,
"hate": 1.0,
"hate/threatening": 1.0,
"self-harm": 1.0,
"self-harm/instructions": 1.0,
"self-harm/intent": 1.0,
"sexual": 1.0,
"sexual/minors": 1.0,
"violence": 1.0,
"violence/graphic": 1.0,
}
result.append(Moderation(
flagged=True,
categories=Categories(**moderation_categories),
category_scores=CategoryScores(**moderation_categories_scores)
))
result.append(
Moderation(
flagged=True,
categories=Categories(**moderation_categories),
category_scores=CategoryScores(**moderation_categories_scores),
)
)
else:
moderation_categories = {
'harassment': False, 'harassment/threatening': False, 'hate': False, 'hate/threatening': False,
'self-harm': False, 'self-harm/instructions': False, 'self-harm/intent': False, 'sexual': False,
'sexual/minors': False, 'violence': False, 'violence/graphic': False
"harassment": False,
"harassment/threatening": False,
"hate": False,
"hate/threatening": False,
"self-harm": False,
"self-harm/instructions": False,
"self-harm/intent": False,
"sexual": False,
"sexual/minors": False,
"violence": False,
"violence/graphic": False,
}
moderation_categories_scores = {
'harassment': 0.0, 'harassment/threatening': 0.0, 'hate': 0.0, 'hate/threatening': 0.0,
'self-harm': 0.0, 'self-harm/instructions': 0.0, 'self-harm/intent': 0.0, 'sexual': 0.0,
'sexual/minors': 0.0, 'violence': 0.0, 'violence/graphic': 0.0
"harassment": 0.0,
"harassment/threatening": 0.0,
"hate": 0.0,
"hate/threatening": 0.0,
"self-harm": 0.0,
"self-harm/instructions": 0.0,
"self-harm/intent": 0.0,
"sexual": 0.0,
"sexual/minors": 0.0,
"violence": 0.0,
"violence/graphic": 0.0,
}
result.append(Moderation(
flagged=False,
categories=Categories(**moderation_categories),
category_scores=CategoryScores(**moderation_categories_scores)
))
result.append(
Moderation(
flagged=False,
categories=Categories(**moderation_categories),
category_scores=CategoryScores(**moderation_categories_scores),
)
)
return ModerationCreateResponse(
id='shiroii kuloko',
model=model,
results=result
)
return ModerationCreateResponse(id="shiroii kuloko", model=model, results=result)

View File

@ -6,17 +6,18 @@ from openai.types.model import Model
class MockModelClass:
"""
mock class for openai.models.Models
mock class for openai.models.Models
"""
def list(
self,
**kwargs,
) -> list[Model]:
return [
Model(
id='ft:gpt-3.5-turbo-0613:personal::8GYJLPDQ',
id="ft:gpt-3.5-turbo-0613:personal::8GYJLPDQ",
created=int(time()),
object='model',
owned_by='organization:org-123',
object="model",
owned_by="organization:org-123",
)
]
]

View File

@ -9,7 +9,8 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError
class MockSpeech2TextClass:
def speech2text_create(self: Transcriptions,
def speech2text_create(
self: Transcriptions,
*,
file: FileTypes,
model: Union[str, Literal["whisper-1"]],
@ -17,14 +18,12 @@ class MockSpeech2TextClass:
prompt: str | NotGiven = NOT_GIVEN,
response_format: Literal["json", "text", "srt", "verbose_json", "vtt"] | NotGiven = NOT_GIVEN,
temperature: float | NotGiven = NOT_GIVEN,
**kwargs: Any
**kwargs: Any,
) -> Transcription:
if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', self._client.base_url.__str__()):
raise InvokeAuthorizationError('Invalid base url')
if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._client.base_url.__str__()):
raise InvokeAuthorizationError("Invalid base url")
if len(self._client.api_key) < 18:
raise InvokeAuthorizationError('Invalid API key')
return Transcription(
text='1, 2, 3, 4, 5, 6, 7, 8, 9, 10'
)
raise InvokeAuthorizationError("Invalid API key")
return Transcription(text="1, 2, 3, 4, 5, 6, 7, 8, 9, 10")

View File

@ -19,40 +19,43 @@ from xinference_client.types import Embedding, EmbeddingData, EmbeddingUsage
class MockXinferenceClass:
def get_chat_model(self: Client, model_uid: str) -> Union[RESTfulChatglmCppChatModelHandle, RESTfulGenerateModelHandle, RESTfulChatModelHandle]:
if not re.match(r'https?:\/\/[^\s\/$.?#].[^\s]*$', self.base_url):
raise RuntimeError('404 Not Found')
if 'generate' == model_uid:
def get_chat_model(
self: Client, model_uid: str
) -> Union[RESTfulChatglmCppChatModelHandle, RESTfulGenerateModelHandle, RESTfulChatModelHandle]:
if not re.match(r"https?:\/\/[^\s\/$.?#].[^\s]*$", self.base_url):
raise RuntimeError("404 Not Found")
if "generate" == model_uid:
return RESTfulGenerateModelHandle(model_uid, base_url=self.base_url, auth_headers={})
if 'chat' == model_uid:
if "chat" == model_uid:
return RESTfulChatModelHandle(model_uid, base_url=self.base_url, auth_headers={})
if 'embedding' == model_uid:
if "embedding" == model_uid:
return RESTfulEmbeddingModelHandle(model_uid, base_url=self.base_url, auth_headers={})
if 'rerank' == model_uid:
if "rerank" == model_uid:
return RESTfulRerankModelHandle(model_uid, base_url=self.base_url, auth_headers={})
raise RuntimeError('404 Not Found')
raise RuntimeError("404 Not Found")
def get(self: Session, url: str, **kwargs):
response = Response()
if 'v1/models/' in url:
if "v1/models/" in url:
# get model uid
model_uid = url.split('/')[-1] or ''
if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', model_uid) and \
model_uid not in ['generate', 'chat', 'embedding', 'rerank']:
model_uid = url.split("/")[-1] or ""
if not re.match(
r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}", model_uid
) and model_uid not in ["generate", "chat", "embedding", "rerank"]:
response.status_code = 404
response._content = b'{}'
response._content = b"{}"
return response
# check if url is valid
if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', url):
if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", url):
response.status_code = 404
response._content = b'{}'
response._content = b"{}"
return response
if model_uid in ['generate', 'chat']:
if model_uid in ["generate", "chat"]:
response.status_code = 200
response._content = b'''{
response._content = b"""{
"model_type": "LLM",
"address": "127.0.0.1:43877",
"accelerators": [
@ -75,12 +78,12 @@ class MockXinferenceClass:
"revision": null,
"context_length": 2048,
"replica": 1
}'''
}"""
return response
elif model_uid == 'embedding':
elif model_uid == "embedding":
response.status_code = 200
response._content = b'''{
response._content = b"""{
"model_type": "embedding",
"address": "127.0.0.1:43877",
"accelerators": [
@ -93,51 +96,48 @@ class MockXinferenceClass:
],
"revision": null,
"max_tokens": 512
}'''
}"""
return response
elif 'v1/cluster/auth' in url:
elif "v1/cluster/auth" in url:
response.status_code = 200
response._content = b'''{
response._content = b"""{
"auth": true
}'''
}"""
return response
def _check_cluster_authenticated(self):
self._cluster_authed = True
def rerank(self: RESTfulRerankModelHandle, documents: list[str], query: str, top_n: int, return_documents: bool) -> dict:
def rerank(
self: RESTfulRerankModelHandle, documents: list[str], query: str, top_n: int, return_documents: bool
) -> dict:
# check if self._model_uid is a valid uuid
if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', self._model_uid) and \
self._model_uid != 'rerank':
raise RuntimeError('404 Not Found')
if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', self._base_url):
raise RuntimeError('404 Not Found')
if (
not re.match(r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}", self._model_uid)
and self._model_uid != "rerank"
):
raise RuntimeError("404 Not Found")
if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", self._base_url):
raise RuntimeError("404 Not Found")
if top_n is None:
top_n = 1
return {
'results': [
{
'index': i,
'document': doc,
'relevance_score': 0.9
}
for i, doc in enumerate(documents[:top_n])
"results": [
{"index": i, "document": doc, "relevance_score": 0.9} for i, doc in enumerate(documents[:top_n])
]
}
def create_embedding(
self: RESTfulGenerateModelHandle,
input: Union[str, list[str]],
**kwargs
) -> dict:
def create_embedding(self: RESTfulGenerateModelHandle, input: Union[str, list[str]], **kwargs) -> dict:
# check if self._model_uid is a valid uuid
if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', self._model_uid) and \
self._model_uid != 'embedding':
raise RuntimeError('404 Not Found')
if (
not re.match(r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}", self._model_uid)
and self._model_uid != "embedding"
):
raise RuntimeError("404 Not Found")
if isinstance(input, str):
input = [input]
@ -147,32 +147,27 @@ class MockXinferenceClass:
object="list",
model=self._model_uid,
data=[
EmbeddingData(
index=i,
object="embedding",
embedding=[1919.810 for _ in range(768)]
)
EmbeddingData(index=i, object="embedding", embedding=[1919.810 for _ in range(768)])
for i in range(ipt_len)
],
usage=EmbeddingUsage(
prompt_tokens=ipt_len,
total_tokens=ipt_len
)
usage=EmbeddingUsage(prompt_tokens=ipt_len, total_tokens=ipt_len),
)
return embedding
MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
@pytest.fixture
def setup_xinference_mock(request, monkeypatch: MonkeyPatch):
if MOCK:
monkeypatch.setattr(Client, 'get_model', MockXinferenceClass.get_chat_model)
monkeypatch.setattr(Client, '_check_cluster_authenticated', MockXinferenceClass._check_cluster_authenticated)
monkeypatch.setattr(Session, 'get', MockXinferenceClass.get)
monkeypatch.setattr(RESTfulEmbeddingModelHandle, 'create_embedding', MockXinferenceClass.create_embedding)
monkeypatch.setattr(RESTfulRerankModelHandle, 'rerank', MockXinferenceClass.rerank)
monkeypatch.setattr(Client, "get_model", MockXinferenceClass.get_chat_model)
monkeypatch.setattr(Client, "_check_cluster_authenticated", MockXinferenceClass._check_cluster_authenticated)
monkeypatch.setattr(Session, "get", MockXinferenceClass.get)
monkeypatch.setattr(RESTfulEmbeddingModelHandle, "create_embedding", MockXinferenceClass.create_embedding)
monkeypatch.setattr(RESTfulRerankModelHandle, "rerank", MockXinferenceClass.rerank)
yield
if MOCK:
monkeypatch.undo()
monkeypatch.undo()

View File

@ -10,79 +10,60 @@ from core.model_runtime.model_providers.anthropic.llm.llm import AnthropicLargeL
from tests.integration_tests.model_runtime.__mock.anthropic import setup_anthropic_mock
@pytest.mark.parametrize('setup_anthropic_mock', [['none']], indirect=True)
@pytest.mark.parametrize("setup_anthropic_mock", [["none"]], indirect=True)
def test_validate_credentials(setup_anthropic_mock):
model = AnthropicLargeLanguageModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='claude-instant-1.2',
credentials={
'anthropic_api_key': 'invalid_key'
}
)
model.validate_credentials(model="claude-instant-1.2", credentials={"anthropic_api_key": "invalid_key"})
model.validate_credentials(
model='claude-instant-1.2',
credentials={
'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY')
}
model="claude-instant-1.2", credentials={"anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY")}
)
@pytest.mark.parametrize('setup_anthropic_mock', [['none']], indirect=True)
@pytest.mark.parametrize("setup_anthropic_mock", [["none"]], indirect=True)
def test_invoke_model(setup_anthropic_mock):
model = AnthropicLargeLanguageModel()
response = model.invoke(
model='claude-instant-1.2',
model="claude-instant-1.2",
credentials={
'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY'),
'anthropic_api_url': os.environ.get('ANTHROPIC_API_URL')
"anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY"),
"anthropic_api_url": os.environ.get("ANTHROPIC_API_URL"),
},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Hello World!'
)
UserPromptMessage(content="Hello World!"),
],
model_parameters={
'temperature': 0.0,
'top_p': 1.0,
'max_tokens': 10
},
stop=['How'],
model_parameters={"temperature": 0.0, "top_p": 1.0, "max_tokens": 10},
stop=["How"],
stream=False,
user="abc-123"
user="abc-123",
)
assert isinstance(response, LLMResult)
assert len(response.message.content) > 0
@pytest.mark.parametrize('setup_anthropic_mock', [['none']], indirect=True)
@pytest.mark.parametrize("setup_anthropic_mock", [["none"]], indirect=True)
def test_invoke_stream_model(setup_anthropic_mock):
model = AnthropicLargeLanguageModel()
response = model.invoke(
model='claude-instant-1.2',
credentials={
'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY')
},
model="claude-instant-1.2",
credentials={"anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY")},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Hello World!'
)
UserPromptMessage(content="Hello World!"),
],
model_parameters={
'temperature': 0.0,
'max_tokens': 100
},
model_parameters={"temperature": 0.0, "max_tokens": 100},
stream=True,
user="abc-123"
user="abc-123",
)
assert isinstance(response, Generator)
@ -98,18 +79,14 @@ def test_get_num_tokens():
model = AnthropicLargeLanguageModel()
num_tokens = model.get_num_tokens(
model='claude-instant-1.2',
credentials={
'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY')
},
model="claude-instant-1.2",
credentials={"anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY")},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Hello World!'
)
]
UserPromptMessage(content="Hello World!"),
],
)
assert num_tokens == 18

View File

@ -7,17 +7,11 @@ from core.model_runtime.model_providers.anthropic.anthropic import AnthropicProv
from tests.integration_tests.model_runtime.__mock.anthropic import setup_anthropic_mock
@pytest.mark.parametrize('setup_anthropic_mock', [['none']], indirect=True)
@pytest.mark.parametrize("setup_anthropic_mock", [["none"]], indirect=True)
def test_validate_provider_credentials(setup_anthropic_mock):
provider = AnthropicProvider()
with pytest.raises(CredentialsValidateFailedError):
provider.validate_provider_credentials(
credentials={}
)
provider.validate_provider_credentials(credentials={})
provider.validate_provider_credentials(
credentials={
'anthropic_api_key': os.environ.get('ANTHROPIC_API_KEY')
}
)
provider.validate_provider_credentials(credentials={"anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY")})

File diff suppressed because one or more lines are too long

View File

@ -8,45 +8,43 @@ from core.model_runtime.model_providers.azure_openai.text_embedding.text_embeddi
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
@pytest.mark.parametrize('setup_openai_mock', [['text_embedding']], indirect=True)
@pytest.mark.parametrize("setup_openai_mock", [["text_embedding"]], indirect=True)
def test_validate_credentials(setup_openai_mock):
model = AzureOpenAITextEmbeddingModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='embedding',
model="embedding",
credentials={
'openai_api_base': os.environ.get('AZURE_OPENAI_API_BASE'),
'openai_api_key': 'invalid_key',
'base_model_name': 'text-embedding-ada-002'
}
"openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"),
"openai_api_key": "invalid_key",
"base_model_name": "text-embedding-ada-002",
},
)
model.validate_credentials(
model='embedding',
model="embedding",
credentials={
'openai_api_base': os.environ.get('AZURE_OPENAI_API_BASE'),
'openai_api_key': os.environ.get('AZURE_OPENAI_API_KEY'),
'base_model_name': 'text-embedding-ada-002'
}
"openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"),
"openai_api_key": os.environ.get("AZURE_OPENAI_API_KEY"),
"base_model_name": "text-embedding-ada-002",
},
)
@pytest.mark.parametrize('setup_openai_mock', [['text_embedding']], indirect=True)
@pytest.mark.parametrize("setup_openai_mock", [["text_embedding"]], indirect=True)
def test_invoke_model(setup_openai_mock):
model = AzureOpenAITextEmbeddingModel()
result = model.invoke(
model='embedding',
model="embedding",
credentials={
'openai_api_base': os.environ.get('AZURE_OPENAI_API_BASE'),
'openai_api_key': os.environ.get('AZURE_OPENAI_API_KEY'),
'base_model_name': 'text-embedding-ada-002'
"openai_api_base": os.environ.get("AZURE_OPENAI_API_BASE"),
"openai_api_key": os.environ.get("AZURE_OPENAI_API_KEY"),
"base_model_name": "text-embedding-ada-002",
},
texts=[
"hello",
"world"
],
user="abc-123"
texts=["hello", "world"],
user="abc-123",
)
assert isinstance(result, TextEmbeddingResult)
@ -58,14 +56,7 @@ def test_get_num_tokens():
model = AzureOpenAITextEmbeddingModel()
num_tokens = model.get_num_tokens(
model='embedding',
credentials={
'base_model_name': 'text-embedding-ada-002'
},
texts=[
"hello",
"world"
]
model="embedding", credentials={"base_model_name": "text-embedding-ada-002"}, texts=["hello", "world"]
)
assert num_tokens == 2

View File

@ -17,111 +17,99 @@ def test_predefined_models():
assert len(model_schemas) >= 1
assert isinstance(model_schemas[0], AIModelEntity)
def test_validate_credentials_for_chat_model():
sleep(3)
model = BaichuanLarguageModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='baichuan2-turbo',
credentials={
'api_key': 'invalid_key',
'secret_key': 'invalid_key'
}
model="baichuan2-turbo", credentials={"api_key": "invalid_key", "secret_key": "invalid_key"}
)
model.validate_credentials(
model='baichuan2-turbo',
model="baichuan2-turbo",
credentials={
'api_key': os.environ.get('BAICHUAN_API_KEY'),
'secret_key': os.environ.get('BAICHUAN_SECRET_KEY')
}
"api_key": os.environ.get("BAICHUAN_API_KEY"),
"secret_key": os.environ.get("BAICHUAN_SECRET_KEY"),
},
)
def test_invoke_model():
sleep(3)
model = BaichuanLarguageModel()
response = model.invoke(
model='baichuan2-turbo',
model="baichuan2-turbo",
credentials={
'api_key': os.environ.get('BAICHUAN_API_KEY'),
'secret_key': os.environ.get('BAICHUAN_SECRET_KEY')
"api_key": os.environ.get("BAICHUAN_API_KEY"),
"secret_key": os.environ.get("BAICHUAN_SECRET_KEY"),
},
prompt_messages=[
UserPromptMessage(
content='Hello World!'
)
],
prompt_messages=[UserPromptMessage(content="Hello World!")],
model_parameters={
'temperature': 0.7,
'top_p': 1.0,
'top_k': 1,
"temperature": 0.7,
"top_p": 1.0,
"top_k": 1,
},
stop=['you'],
stop=["you"],
user="abc-123",
stream=False
stream=False,
)
assert isinstance(response, LLMResult)
assert len(response.message.content) > 0
assert response.usage.total_tokens > 0
def test_invoke_model_with_system_message():
sleep(3)
model = BaichuanLarguageModel()
response = model.invoke(
model='baichuan2-turbo',
model="baichuan2-turbo",
credentials={
'api_key': os.environ.get('BAICHUAN_API_KEY'),
'secret_key': os.environ.get('BAICHUAN_SECRET_KEY')
"api_key": os.environ.get("BAICHUAN_API_KEY"),
"secret_key": os.environ.get("BAICHUAN_SECRET_KEY"),
},
prompt_messages=[
SystemPromptMessage(
content='请记住你是Kasumi。'
),
UserPromptMessage(
content='现在告诉我你是谁?'
)
SystemPromptMessage(content="请记住你是Kasumi。"),
UserPromptMessage(content="现在告诉我你是谁?"),
],
model_parameters={
'temperature': 0.7,
'top_p': 1.0,
'top_k': 1,
"temperature": 0.7,
"top_p": 1.0,
"top_k": 1,
},
stop=['you'],
stop=["you"],
user="abc-123",
stream=False
stream=False,
)
assert isinstance(response, LLMResult)
assert len(response.message.content) > 0
assert response.usage.total_tokens > 0
def test_invoke_stream_model():
sleep(3)
model = BaichuanLarguageModel()
response = model.invoke(
model='baichuan2-turbo',
model="baichuan2-turbo",
credentials={
'api_key': os.environ.get('BAICHUAN_API_KEY'),
'secret_key': os.environ.get('BAICHUAN_SECRET_KEY')
"api_key": os.environ.get("BAICHUAN_API_KEY"),
"secret_key": os.environ.get("BAICHUAN_SECRET_KEY"),
},
prompt_messages=[
UserPromptMessage(
content='Hello World!'
)
],
prompt_messages=[UserPromptMessage(content="Hello World!")],
model_parameters={
'temperature': 0.7,
'top_p': 1.0,
'top_k': 1,
"temperature": 0.7,
"top_p": 1.0,
"top_k": 1,
},
stop=['you'],
stop=["you"],
stream=True,
user="abc-123"
user="abc-123",
)
assert isinstance(response, Generator)
@ -131,34 +119,31 @@ def test_invoke_stream_model():
assert isinstance(chunk.delta.message, AssistantPromptMessage)
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
def test_invoke_with_search():
sleep(3)
model = BaichuanLarguageModel()
response = model.invoke(
model='baichuan2-turbo',
model="baichuan2-turbo",
credentials={
'api_key': os.environ.get('BAICHUAN_API_KEY'),
'secret_key': os.environ.get('BAICHUAN_SECRET_KEY')
"api_key": os.environ.get("BAICHUAN_API_KEY"),
"secret_key": os.environ.get("BAICHUAN_SECRET_KEY"),
},
prompt_messages=[
UserPromptMessage(
content='北京今天的天气怎么样'
)
],
prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")],
model_parameters={
'temperature': 0.7,
'top_p': 1.0,
'top_k': 1,
'with_search_enhance': True,
"temperature": 0.7,
"top_p": 1.0,
"top_k": 1,
"with_search_enhance": True,
},
stop=['you'],
stop=["you"],
stream=True,
user="abc-123"
user="abc-123",
)
assert isinstance(response, Generator)
total_message = ''
total_message = ""
for chunk in response:
assert isinstance(chunk, LLMResultChunk)
assert isinstance(chunk.delta, LLMResultChunkDelta)
@ -166,25 +151,22 @@ def test_invoke_with_search():
assert len(chunk.delta.message.content) > 0 if not chunk.delta.finish_reason else True
total_message += chunk.delta.message.content
assert '' not in total_message
assert "" not in total_message
def test_get_num_tokens():
sleep(3)
model = BaichuanLarguageModel()
response = model.get_num_tokens(
model='baichuan2-turbo',
model="baichuan2-turbo",
credentials={
'api_key': os.environ.get('BAICHUAN_API_KEY'),
'secret_key': os.environ.get('BAICHUAN_SECRET_KEY')
"api_key": os.environ.get("BAICHUAN_API_KEY"),
"secret_key": os.environ.get("BAICHUAN_SECRET_KEY"),
},
prompt_messages=[
UserPromptMessage(
content='Hello World!'
)
],
tools=[]
prompt_messages=[UserPromptMessage(content="Hello World!")],
tools=[],
)
assert isinstance(response, int)
assert response == 9
assert response == 9

View File

@ -10,14 +10,6 @@ def test_validate_provider_credentials():
provider = BaichuanProvider()
with pytest.raises(CredentialsValidateFailedError):
provider.validate_provider_credentials(
credentials={
'api_key': 'hahahaha'
}
)
provider.validate_provider_credentials(credentials={"api_key": "hahahaha"})
provider.validate_provider_credentials(
credentials={
'api_key': os.environ.get('BAICHUAN_API_KEY')
}
)
provider.validate_provider_credentials(credentials={"api_key": os.environ.get("BAICHUAN_API_KEY")})

View File

@ -11,18 +11,10 @@ def test_validate_credentials():
model = BaichuanTextEmbeddingModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='baichuan-text-embedding',
credentials={
'api_key': 'invalid_key'
}
)
model.validate_credentials(model="baichuan-text-embedding", credentials={"api_key": "invalid_key"})
model.validate_credentials(
model='baichuan-text-embedding',
credentials={
'api_key': os.environ.get('BAICHUAN_API_KEY')
}
model="baichuan-text-embedding", credentials={"api_key": os.environ.get("BAICHUAN_API_KEY")}
)
@ -30,44 +22,40 @@ def test_invoke_model():
model = BaichuanTextEmbeddingModel()
result = model.invoke(
model='baichuan-text-embedding',
model="baichuan-text-embedding",
credentials={
'api_key': os.environ.get('BAICHUAN_API_KEY'),
"api_key": os.environ.get("BAICHUAN_API_KEY"),
},
texts=[
"hello",
"world"
],
user="abc-123"
texts=["hello", "world"],
user="abc-123",
)
assert isinstance(result, TextEmbeddingResult)
assert len(result.embeddings) == 2
assert result.usage.total_tokens == 6
def test_get_num_tokens():
model = BaichuanTextEmbeddingModel()
num_tokens = model.get_num_tokens(
model='baichuan-text-embedding',
model="baichuan-text-embedding",
credentials={
'api_key': os.environ.get('BAICHUAN_API_KEY'),
"api_key": os.environ.get("BAICHUAN_API_KEY"),
},
texts=[
"hello",
"world"
]
texts=["hello", "world"],
)
assert num_tokens == 2
def test_max_chunks():
model = BaichuanTextEmbeddingModel()
result = model.invoke(
model='baichuan-text-embedding',
model="baichuan-text-embedding",
credentials={
'api_key': os.environ.get('BAICHUAN_API_KEY'),
"api_key": os.environ.get("BAICHUAN_API_KEY"),
},
texts=[
"hello",
@ -92,8 +80,8 @@ def test_max_chunks():
"world",
"hello",
"world",
]
],
)
assert isinstance(result, TextEmbeddingResult)
assert len(result.embeddings) == 22
assert len(result.embeddings) == 22

View File

@ -13,77 +13,63 @@ def test_validate_credentials():
model = BedrockLargeLanguageModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='meta.llama2-13b-chat-v1',
credentials={
'anthropic_api_key': 'invalid_key'
}
)
model.validate_credentials(model="meta.llama2-13b-chat-v1", credentials={"anthropic_api_key": "invalid_key"})
model.validate_credentials(
model='meta.llama2-13b-chat-v1',
model="meta.llama2-13b-chat-v1",
credentials={
"aws_region": os.getenv("AWS_REGION"),
"aws_access_key": os.getenv("AWS_ACCESS_KEY"),
"aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY")
}
"aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"),
},
)
def test_invoke_model():
model = BedrockLargeLanguageModel()
response = model.invoke(
model='meta.llama2-13b-chat-v1',
model="meta.llama2-13b-chat-v1",
credentials={
"aws_region": os.getenv("AWS_REGION"),
"aws_access_key": os.getenv("AWS_ACCESS_KEY"),
"aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY")
"aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"),
},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Hello World!'
)
UserPromptMessage(content="Hello World!"),
],
model_parameters={
'temperature': 0.0,
'top_p': 1.0,
'max_tokens_to_sample': 10
},
stop=['How'],
model_parameters={"temperature": 0.0, "top_p": 1.0, "max_tokens_to_sample": 10},
stop=["How"],
stream=False,
user="abc-123"
user="abc-123",
)
assert isinstance(response, LLMResult)
assert len(response.message.content) > 0
def test_invoke_stream_model():
model = BedrockLargeLanguageModel()
response = model.invoke(
model='meta.llama2-13b-chat-v1',
model="meta.llama2-13b-chat-v1",
credentials={
"aws_region": os.getenv("AWS_REGION"),
"aws_access_key": os.getenv("AWS_ACCESS_KEY"),
"aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY")
"aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"),
},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Hello World!'
)
UserPromptMessage(content="Hello World!"),
],
model_parameters={
'temperature': 0.0,
'max_tokens_to_sample': 100
},
model_parameters={"temperature": 0.0, "max_tokens_to_sample": 100},
stream=True,
user="abc-123"
user="abc-123",
)
assert isinstance(response, Generator)
@ -100,20 +86,18 @@ def test_get_num_tokens():
model = BedrockLargeLanguageModel()
num_tokens = model.get_num_tokens(
model='meta.llama2-13b-chat-v1',
credentials = {
model="meta.llama2-13b-chat-v1",
credentials={
"aws_region": os.getenv("AWS_REGION"),
"aws_access_key": os.getenv("AWS_ACCESS_KEY"),
"aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY")
"aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"),
},
messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Hello World!'
)
]
UserPromptMessage(content="Hello World!"),
],
)
assert num_tokens == 18

View File

@ -10,14 +10,12 @@ def test_validate_provider_credentials():
provider = BedrockProvider()
with pytest.raises(CredentialsValidateFailedError):
provider.validate_provider_credentials(
credentials={}
)
provider.validate_provider_credentials(credentials={})
provider.validate_provider_credentials(
credentials={
"aws_region": os.getenv("AWS_REGION"),
"aws_access_key": os.getenv("AWS_ACCESS_KEY"),
"aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY")
"aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"),
}
)

View File

@ -23,79 +23,64 @@ def test_predefined_models():
assert len(model_schemas) >= 1
assert isinstance(model_schemas[0], AIModelEntity)
@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
def test_validate_credentials_for_chat_model(setup_openai_mock):
model = ChatGLMLargeLanguageModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='chatglm2-6b',
credentials={
'api_base': 'invalid_key'
}
)
model.validate_credentials(model="chatglm2-6b", credentials={"api_base": "invalid_key"})
model.validate_credentials(
model='chatglm2-6b',
credentials={
'api_base': os.environ.get('CHATGLM_API_BASE')
}
)
model.validate_credentials(model="chatglm2-6b", credentials={"api_base": os.environ.get("CHATGLM_API_BASE")})
@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
def test_invoke_model(setup_openai_mock):
model = ChatGLMLargeLanguageModel()
response = model.invoke(
model='chatglm2-6b',
credentials={
'api_base': os.environ.get('CHATGLM_API_BASE')
},
model="chatglm2-6b",
credentials={"api_base": os.environ.get("CHATGLM_API_BASE")},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Hello World!'
)
UserPromptMessage(content="Hello World!"),
],
model_parameters={
'temperature': 0.7,
'top_p': 1.0,
"temperature": 0.7,
"top_p": 1.0,
},
stop=['you'],
stop=["you"],
user="abc-123",
stream=False
stream=False,
)
assert isinstance(response, LLMResult)
assert len(response.message.content) > 0
assert response.usage.total_tokens > 0
@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
def test_invoke_stream_model(setup_openai_mock):
model = ChatGLMLargeLanguageModel()
response = model.invoke(
model='chatglm2-6b',
credentials={
'api_base': os.environ.get('CHATGLM_API_BASE')
},
model="chatglm2-6b",
credentials={"api_base": os.environ.get("CHATGLM_API_BASE")},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Hello World!'
)
UserPromptMessage(content="Hello World!"),
],
model_parameters={
'temperature': 0.7,
'top_p': 1.0,
"temperature": 0.7,
"top_p": 1.0,
},
stop=['you'],
stop=["you"],
stream=True,
user="abc-123"
user="abc-123",
)
assert isinstance(response, Generator)
@ -105,56 +90,45 @@ def test_invoke_stream_model(setup_openai_mock):
assert isinstance(chunk.delta.message, AssistantPromptMessage)
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
def test_invoke_stream_model_with_functions(setup_openai_mock):
model = ChatGLMLargeLanguageModel()
response = model.invoke(
model='chatglm3-6b',
credentials={
'api_base': os.environ.get('CHATGLM_API_BASE')
},
model="chatglm3-6b",
credentials={"api_base": os.environ.get("CHATGLM_API_BASE")},
prompt_messages=[
SystemPromptMessage(
content='你是一个天气机器人,你不知道今天的天气怎么样,你需要通过调用一个函数来获取天气信息。'
content="你是一个天气机器人,你不知道今天的天气怎么样,你需要通过调用一个函数来获取天气信息。"
),
UserPromptMessage(
content='波士顿天气如何?'
)
UserPromptMessage(content="波士顿天气如何?"),
],
model_parameters={
'temperature': 0,
'top_p': 1.0,
"temperature": 0,
"top_p": 1.0,
},
stop=['you'],
user='abc-123',
stop=["you"],
user="abc-123",
stream=True,
tools=[
PromptMessageTool(
name='get_current_weather',
description='Get the current weather in a given location',
name="get_current_weather",
description="Get the current weather in a given location",
parameters={
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state e.g. San Francisco, CA"
},
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"]
}
"location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": [
"location"
]
}
"required": ["location"],
},
)
]
],
)
assert isinstance(response, Generator)
call: LLMResultChunk = None
chunks = []
@ -170,122 +144,87 @@ def test_invoke_stream_model_with_functions(setup_openai_mock):
break
assert call is not None
assert call.delta.message.tool_calls[0].function.name == 'get_current_weather'
assert call.delta.message.tool_calls[0].function.name == "get_current_weather"
@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
def test_invoke_model_with_functions(setup_openai_mock):
model = ChatGLMLargeLanguageModel()
response = model.invoke(
model='chatglm3-6b',
credentials={
'api_base': os.environ.get('CHATGLM_API_BASE')
},
prompt_messages=[
UserPromptMessage(
content='What is the weather like in San Francisco?'
)
],
model="chatglm3-6b",
credentials={"api_base": os.environ.get("CHATGLM_API_BASE")},
prompt_messages=[UserPromptMessage(content="What is the weather like in San Francisco?")],
model_parameters={
'temperature': 0.7,
'top_p': 1.0,
"temperature": 0.7,
"top_p": 1.0,
},
stop=['you'],
user='abc-123',
stop=["you"],
user="abc-123",
stream=False,
tools=[
PromptMessageTool(
name='get_current_weather',
description='Get the current weather in a given location',
name="get_current_weather",
description="Get the current weather in a given location",
parameters={
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state e.g. San Francisco, CA"
},
"unit": {
"type": "string",
"enum": [
"c",
"f"
]
}
"location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
"unit": {"type": "string", "enum": ["c", "f"]},
},
"required": [
"location"
]
}
"required": ["location"],
},
)
]
],
)
assert isinstance(response, LLMResult)
assert len(response.message.content) > 0
assert response.usage.total_tokens > 0
assert response.message.tool_calls[0].function.name == 'get_current_weather'
assert response.message.tool_calls[0].function.name == "get_current_weather"
def test_get_num_tokens():
model = ChatGLMLargeLanguageModel()
num_tokens = model.get_num_tokens(
model='chatglm2-6b',
credentials={
'api_base': os.environ.get('CHATGLM_API_BASE')
},
model="chatglm2-6b",
credentials={"api_base": os.environ.get("CHATGLM_API_BASE")},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Hello World!'
)
UserPromptMessage(content="Hello World!"),
],
tools=[
PromptMessageTool(
name='get_current_weather',
description='Get the current weather in a given location',
name="get_current_weather",
description="Get the current weather in a given location",
parameters={
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state e.g. San Francisco, CA"
},
"unit": {
"type": "string",
"enum": [
"c",
"f"
]
}
"location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
"unit": {"type": "string", "enum": ["c", "f"]},
},
"required": [
"location"
]
}
"required": ["location"],
},
)
]
],
)
assert isinstance(num_tokens, int)
assert num_tokens == 77
num_tokens = model.get_num_tokens(
model='chatglm2-6b',
credentials={
'api_base': os.environ.get('CHATGLM_API_BASE')
},
model="chatglm2-6b",
credentials={"api_base": os.environ.get("CHATGLM_API_BASE")},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Hello World!'
)
UserPromptMessage(content="Hello World!"),
],
)
assert isinstance(num_tokens, int)
assert num_tokens == 21
assert num_tokens == 21

View File

@ -7,19 +7,11 @@ from core.model_runtime.model_providers.chatglm.chatglm import ChatGLMProvider
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
def test_validate_provider_credentials(setup_openai_mock):
provider = ChatGLMProvider()
with pytest.raises(CredentialsValidateFailedError):
provider.validate_provider_credentials(
credentials={
'api_base': 'hahahaha'
}
)
provider.validate_provider_credentials(credentials={"api_base": "hahahaha"})
provider.validate_provider_credentials(
credentials={
'api_base': os.environ.get('CHATGLM_API_BASE')
}
)
provider.validate_provider_credentials(credentials={"api_base": os.environ.get("CHATGLM_API_BASE")})

View File

@ -13,87 +13,49 @@ def test_validate_credentials_for_chat_model():
model = CohereLargeLanguageModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='command-light-chat',
credentials={
'api_key': 'invalid_key'
}
)
model.validate_credentials(model="command-light-chat", credentials={"api_key": "invalid_key"})
model.validate_credentials(
model='command-light-chat',
credentials={
'api_key': os.environ.get('COHERE_API_KEY')
}
)
model.validate_credentials(model="command-light-chat", credentials={"api_key": os.environ.get("COHERE_API_KEY")})
def test_validate_credentials_for_completion_model():
model = CohereLargeLanguageModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='command-light',
credentials={
'api_key': 'invalid_key'
}
)
model.validate_credentials(model="command-light", credentials={"api_key": "invalid_key"})
model.validate_credentials(
model='command-light',
credentials={
'api_key': os.environ.get('COHERE_API_KEY')
}
)
model.validate_credentials(model="command-light", credentials={"api_key": os.environ.get("COHERE_API_KEY")})
def test_invoke_completion_model():
model = CohereLargeLanguageModel()
credentials = {
'api_key': os.environ.get('COHERE_API_KEY')
}
credentials = {"api_key": os.environ.get("COHERE_API_KEY")}
result = model.invoke(
model='command-light',
model="command-light",
credentials=credentials,
prompt_messages=[
UserPromptMessage(
content='Hello World!'
)
],
model_parameters={
'temperature': 0.0,
'max_tokens': 1
},
prompt_messages=[UserPromptMessage(content="Hello World!")],
model_parameters={"temperature": 0.0, "max_tokens": 1},
stream=False,
user="abc-123"
user="abc-123",
)
assert isinstance(result, LLMResult)
assert len(result.message.content) > 0
assert model._num_tokens_from_string('command-light', credentials, result.message.content) == 1
assert model._num_tokens_from_string("command-light", credentials, result.message.content) == 1
def test_invoke_stream_completion_model():
model = CohereLargeLanguageModel()
result = model.invoke(
model='command-light',
credentials={
'api_key': os.environ.get('COHERE_API_KEY')
},
prompt_messages=[
UserPromptMessage(
content='Hello World!'
)
],
model_parameters={
'temperature': 0.0,
'max_tokens': 100
},
model="command-light",
credentials={"api_key": os.environ.get("COHERE_API_KEY")},
prompt_messages=[UserPromptMessage(content="Hello World!")],
model_parameters={"temperature": 0.0, "max_tokens": 100},
stream=True,
user="abc-123"
user="abc-123",
)
assert isinstance(result, Generator)
@ -109,28 +71,24 @@ def test_invoke_chat_model():
model = CohereLargeLanguageModel()
result = model.invoke(
model='command-light-chat',
credentials={
'api_key': os.environ.get('COHERE_API_KEY')
},
model="command-light-chat",
credentials={"api_key": os.environ.get("COHERE_API_KEY")},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Hello World!'
)
UserPromptMessage(content="Hello World!"),
],
model_parameters={
'temperature': 0.0,
'p': 0.99,
'presence_penalty': 0.0,
'frequency_penalty': 0.0,
'max_tokens': 10
"temperature": 0.0,
"p": 0.99,
"presence_penalty": 0.0,
"frequency_penalty": 0.0,
"max_tokens": 10,
},
stop=['How'],
stop=["How"],
stream=False,
user="abc-123"
user="abc-123",
)
assert isinstance(result, LLMResult)
@ -141,24 +99,17 @@ def test_invoke_stream_chat_model():
model = CohereLargeLanguageModel()
result = model.invoke(
model='command-light-chat',
credentials={
'api_key': os.environ.get('COHERE_API_KEY')
},
model="command-light-chat",
credentials={"api_key": os.environ.get("COHERE_API_KEY")},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Hello World!'
)
UserPromptMessage(content="Hello World!"),
],
model_parameters={
'temperature': 0.0,
'max_tokens': 100
},
model_parameters={"temperature": 0.0, "max_tokens": 100},
stream=True,
user="abc-123"
user="abc-123",
)
assert isinstance(result, Generator)
@ -177,32 +128,22 @@ def test_get_num_tokens():
model = CohereLargeLanguageModel()
num_tokens = model.get_num_tokens(
model='command-light',
credentials={
'api_key': os.environ.get('COHERE_API_KEY')
},
prompt_messages=[
UserPromptMessage(
content='Hello World!'
)
]
model="command-light",
credentials={"api_key": os.environ.get("COHERE_API_KEY")},
prompt_messages=[UserPromptMessage(content="Hello World!")],
)
assert num_tokens == 3
num_tokens = model.get_num_tokens(
model='command-light-chat',
credentials={
'api_key': os.environ.get('COHERE_API_KEY')
},
model="command-light-chat",
credentials={"api_key": os.environ.get("COHERE_API_KEY")},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Hello World!'
)
]
UserPromptMessage(content="Hello World!"),
],
)
assert num_tokens == 15
@ -213,25 +154,17 @@ def test_fine_tuned_model():
# test invoke
result = model.invoke(
model='85ec47be-6139-4f75-a4be-0f0ec1ef115c-ft',
credentials={
'api_key': os.environ.get('COHERE_API_KEY'),
'mode': 'completion'
},
model="85ec47be-6139-4f75-a4be-0f0ec1ef115c-ft",
credentials={"api_key": os.environ.get("COHERE_API_KEY"), "mode": "completion"},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Hello World!'
)
UserPromptMessage(content="Hello World!"),
],
model_parameters={
'temperature': 0.0,
'max_tokens': 100
},
model_parameters={"temperature": 0.0, "max_tokens": 100},
stream=False,
user="abc-123"
user="abc-123",
)
assert isinstance(result, LLMResult)
@ -242,25 +175,17 @@ def test_fine_tuned_chat_model():
# test invoke
result = model.invoke(
model='94f2d55a-4c79-4c00-bde4-23962e74b170-ft',
credentials={
'api_key': os.environ.get('COHERE_API_KEY'),
'mode': 'chat'
},
model="94f2d55a-4c79-4c00-bde4-23962e74b170-ft",
credentials={"api_key": os.environ.get("COHERE_API_KEY"), "mode": "chat"},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Hello World!'
)
UserPromptMessage(content="Hello World!"),
],
model_parameters={
'temperature': 0.0,
'max_tokens': 100
},
model_parameters={"temperature": 0.0, "max_tokens": 100},
stream=False,
user="abc-123"
user="abc-123",
)
assert isinstance(result, LLMResult)

View File

@ -10,12 +10,6 @@ def test_validate_provider_credentials():
provider = CohereProvider()
with pytest.raises(CredentialsValidateFailedError):
provider.validate_provider_credentials(
credentials={}
)
provider.validate_provider_credentials(credentials={})
provider.validate_provider_credentials(
credentials={
'api_key': os.environ.get('COHERE_API_KEY')
}
)
provider.validate_provider_credentials(credentials={"api_key": os.environ.get("COHERE_API_KEY")})

View File

@ -11,29 +11,17 @@ def test_validate_credentials():
model = CohereRerankModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='rerank-english-v2.0',
credentials={
'api_key': 'invalid_key'
}
)
model.validate_credentials(model="rerank-english-v2.0", credentials={"api_key": "invalid_key"})
model.validate_credentials(
model='rerank-english-v2.0',
credentials={
'api_key': os.environ.get('COHERE_API_KEY')
}
)
model.validate_credentials(model="rerank-english-v2.0", credentials={"api_key": os.environ.get("COHERE_API_KEY")})
def test_invoke_model():
model = CohereRerankModel()
result = model.invoke(
model='rerank-english-v2.0',
credentials={
'api_key': os.environ.get('COHERE_API_KEY')
},
model="rerank-english-v2.0",
credentials={"api_key": os.environ.get("COHERE_API_KEY")},
query="What is the capital of the United States?",
docs=[
"Carson City is the capital city of the American state of Nevada. At the 2010 United States "
@ -41,9 +29,9 @@ def test_invoke_model():
"Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) "
"is the capital of the United States. It is a federal district. The President of the USA and many major "
"national government offices are in the territory. This makes it the political center of the United "
"States of America."
"States of America.",
],
score_threshold=0.8
score_threshold=0.8,
)
assert isinstance(result, RerankResult)

View File

@ -11,18 +11,10 @@ def test_validate_credentials():
model = CohereTextEmbeddingModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='embed-multilingual-v3.0',
credentials={
'api_key': 'invalid_key'
}
)
model.validate_credentials(model="embed-multilingual-v3.0", credentials={"api_key": "invalid_key"})
model.validate_credentials(
model='embed-multilingual-v3.0',
credentials={
'api_key': os.environ.get('COHERE_API_KEY')
}
model="embed-multilingual-v3.0", credentials={"api_key": os.environ.get("COHERE_API_KEY")}
)
@ -30,17 +22,10 @@ def test_invoke_model():
model = CohereTextEmbeddingModel()
result = model.invoke(
model='embed-multilingual-v3.0',
credentials={
'api_key': os.environ.get('COHERE_API_KEY')
},
texts=[
"hello",
"world",
" ".join(["long_text"] * 100),
" ".join(["another_long_text"] * 100)
],
user="abc-123"
model="embed-multilingual-v3.0",
credentials={"api_key": os.environ.get("COHERE_API_KEY")},
texts=["hello", "world", " ".join(["long_text"] * 100), " ".join(["another_long_text"] * 100)],
user="abc-123",
)
assert isinstance(result, TextEmbeddingResult)
@ -52,14 +37,9 @@ def test_get_num_tokens():
model = CohereTextEmbeddingModel()
num_tokens = model.get_num_tokens(
model='embed-multilingual-v3.0',
credentials={
'api_key': os.environ.get('COHERE_API_KEY')
},
texts=[
"hello",
"world"
]
model="embed-multilingual-v3.0",
credentials={"api_key": os.environ.get("COHERE_API_KEY")},
texts=["hello", "world"],
)
assert num_tokens == 3

File diff suppressed because one or more lines are too long

View File

@ -7,17 +7,11 @@ from core.model_runtime.model_providers.google.google import GoogleProvider
from tests.integration_tests.model_runtime.__mock.google import setup_google_mock
@pytest.mark.parametrize('setup_google_mock', [['none']], indirect=True)
@pytest.mark.parametrize("setup_google_mock", [["none"]], indirect=True)
def test_validate_provider_credentials(setup_google_mock):
provider = GoogleProvider()
with pytest.raises(CredentialsValidateFailedError):
provider.validate_provider_credentials(
credentials={}
)
provider.validate_provider_credentials(credentials={})
provider.validate_provider_credentials(
credentials={
'google_api_key': os.environ.get('GOOGLE_API_KEY')
}
)
provider.validate_provider_credentials(credentials={"google_api_key": os.environ.get("GOOGLE_API_KEY")})

View File

@ -10,87 +10,75 @@ from core.model_runtime.model_providers.huggingface_hub.llm.llm import Huggingfa
from tests.integration_tests.model_runtime.__mock.huggingface import setup_huggingface_mock
@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
def test_hosted_inference_api_validate_credentials(setup_huggingface_mock):
model = HuggingfaceHubLargeLanguageModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='HuggingFaceH4/zephyr-7b-beta',
credentials={
'huggingfacehub_api_type': 'hosted_inference_api',
'huggingfacehub_api_token': 'invalid_key'
}
model="HuggingFaceH4/zephyr-7b-beta",
credentials={"huggingfacehub_api_type": "hosted_inference_api", "huggingfacehub_api_token": "invalid_key"},
)
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='fake-model',
credentials={
'huggingfacehub_api_type': 'hosted_inference_api',
'huggingfacehub_api_token': 'invalid_key'
}
model="fake-model",
credentials={"huggingfacehub_api_type": "hosted_inference_api", "huggingfacehub_api_token": "invalid_key"},
)
model.validate_credentials(
model='HuggingFaceH4/zephyr-7b-beta',
model="HuggingFaceH4/zephyr-7b-beta",
credentials={
'huggingfacehub_api_type': 'hosted_inference_api',
'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY')
}
"huggingfacehub_api_type": "hosted_inference_api",
"huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
},
)
@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
def test_hosted_inference_api_invoke_model(setup_huggingface_mock):
model = HuggingfaceHubLargeLanguageModel()
response = model.invoke(
model='HuggingFaceH4/zephyr-7b-beta',
model="HuggingFaceH4/zephyr-7b-beta",
credentials={
'huggingfacehub_api_type': 'hosted_inference_api',
'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY')
"huggingfacehub_api_type": "hosted_inference_api",
"huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
},
prompt_messages=[
UserPromptMessage(
content='Who are you?'
)
],
prompt_messages=[UserPromptMessage(content="Who are you?")],
model_parameters={
'temperature': 1.0,
'top_k': 2,
'top_p': 0.5,
"temperature": 1.0,
"top_k": 2,
"top_p": 0.5,
},
stop=['How'],
stop=["How"],
stream=False,
user="abc-123"
user="abc-123",
)
assert isinstance(response, LLMResult)
assert len(response.message.content) > 0
@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
def test_hosted_inference_api_invoke_stream_model(setup_huggingface_mock):
model = HuggingfaceHubLargeLanguageModel()
response = model.invoke(
model='HuggingFaceH4/zephyr-7b-beta',
model="HuggingFaceH4/zephyr-7b-beta",
credentials={
'huggingfacehub_api_type': 'hosted_inference_api',
'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY')
"huggingfacehub_api_type": "hosted_inference_api",
"huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
},
prompt_messages=[
UserPromptMessage(
content='Who are you?'
)
],
prompt_messages=[UserPromptMessage(content="Who are you?")],
model_parameters={
'temperature': 1.0,
'top_k': 2,
'top_p': 0.5,
"temperature": 1.0,
"top_k": 2,
"top_p": 0.5,
},
stop=['How'],
stop=["How"],
stream=True,
user="abc-123"
user="abc-123",
)
assert isinstance(response, Generator)
@ -101,86 +89,81 @@ def test_hosted_inference_api_invoke_stream_model(setup_huggingface_mock):
assert isinstance(chunk.delta.message, AssistantPromptMessage)
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
def test_inference_endpoints_text_generation_validate_credentials(setup_huggingface_mock):
model = HuggingfaceHubLargeLanguageModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='openchat/openchat_3.5',
model="openchat/openchat_3.5",
credentials={
'huggingfacehub_api_type': 'inference_endpoints',
'huggingfacehub_api_token': 'invalid_key',
'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT_GEN_ENDPOINT_URL'),
'task_type': 'text-generation'
}
"huggingfacehub_api_type": "inference_endpoints",
"huggingfacehub_api_token": "invalid_key",
"huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT_GEN_ENDPOINT_URL"),
"task_type": "text-generation",
},
)
model.validate_credentials(
model='openchat/openchat_3.5',
model="openchat/openchat_3.5",
credentials={
'huggingfacehub_api_type': 'inference_endpoints',
'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT_GEN_ENDPOINT_URL'),
'task_type': 'text-generation'
}
"huggingfacehub_api_type": "inference_endpoints",
"huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
"huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT_GEN_ENDPOINT_URL"),
"task_type": "text-generation",
},
)
@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
def test_inference_endpoints_text_generation_invoke_model(setup_huggingface_mock):
model = HuggingfaceHubLargeLanguageModel()
response = model.invoke(
model='openchat/openchat_3.5',
model="openchat/openchat_3.5",
credentials={
'huggingfacehub_api_type': 'inference_endpoints',
'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT_GEN_ENDPOINT_URL'),
'task_type': 'text-generation'
"huggingfacehub_api_type": "inference_endpoints",
"huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
"huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT_GEN_ENDPOINT_URL"),
"task_type": "text-generation",
},
prompt_messages=[
UserPromptMessage(
content='Who are you?'
)
],
prompt_messages=[UserPromptMessage(content="Who are you?")],
model_parameters={
'temperature': 1.0,
'top_k': 2,
'top_p': 0.5,
"temperature": 1.0,
"top_k": 2,
"top_p": 0.5,
},
stop=['How'],
stop=["How"],
stream=False,
user="abc-123"
user="abc-123",
)
assert isinstance(response, LLMResult)
assert len(response.message.content) > 0
@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
def test_inference_endpoints_text_generation_invoke_stream_model(setup_huggingface_mock):
model = HuggingfaceHubLargeLanguageModel()
response = model.invoke(
model='openchat/openchat_3.5',
model="openchat/openchat_3.5",
credentials={
'huggingfacehub_api_type': 'inference_endpoints',
'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT_GEN_ENDPOINT_URL'),
'task_type': 'text-generation'
"huggingfacehub_api_type": "inference_endpoints",
"huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
"huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT_GEN_ENDPOINT_URL"),
"task_type": "text-generation",
},
prompt_messages=[
UserPromptMessage(
content='Who are you?'
)
],
prompt_messages=[UserPromptMessage(content="Who are you?")],
model_parameters={
'temperature': 1.0,
'top_k': 2,
'top_p': 0.5,
"temperature": 1.0,
"top_k": 2,
"top_p": 0.5,
},
stop=['How'],
stop=["How"],
stream=True,
user="abc-123"
user="abc-123",
)
assert isinstance(response, Generator)
@ -191,86 +174,81 @@ def test_inference_endpoints_text_generation_invoke_stream_model(setup_huggingfa
assert isinstance(chunk.delta.message, AssistantPromptMessage)
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
def test_inference_endpoints_text2text_generation_validate_credentials(setup_huggingface_mock):
model = HuggingfaceHubLargeLanguageModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='google/mt5-base',
model="google/mt5-base",
credentials={
'huggingfacehub_api_type': 'inference_endpoints',
'huggingfacehub_api_token': 'invalid_key',
'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'),
'task_type': 'text2text-generation'
}
"huggingfacehub_api_type": "inference_endpoints",
"huggingfacehub_api_token": "invalid_key",
"huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL"),
"task_type": "text2text-generation",
},
)
model.validate_credentials(
model='google/mt5-base',
model="google/mt5-base",
credentials={
'huggingfacehub_api_type': 'inference_endpoints',
'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'),
'task_type': 'text2text-generation'
}
"huggingfacehub_api_type": "inference_endpoints",
"huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
"huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL"),
"task_type": "text2text-generation",
},
)
@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
def test_inference_endpoints_text2text_generation_invoke_model(setup_huggingface_mock):
model = HuggingfaceHubLargeLanguageModel()
response = model.invoke(
model='google/mt5-base',
model="google/mt5-base",
credentials={
'huggingfacehub_api_type': 'inference_endpoints',
'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'),
'task_type': 'text2text-generation'
"huggingfacehub_api_type": "inference_endpoints",
"huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
"huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL"),
"task_type": "text2text-generation",
},
prompt_messages=[
UserPromptMessage(
content='Who are you?'
)
],
prompt_messages=[UserPromptMessage(content="Who are you?")],
model_parameters={
'temperature': 1.0,
'top_k': 2,
'top_p': 0.5,
"temperature": 1.0,
"top_k": 2,
"top_p": 0.5,
},
stop=['How'],
stop=["How"],
stream=False,
user="abc-123"
user="abc-123",
)
assert isinstance(response, LLMResult)
assert len(response.message.content) > 0
@pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
@pytest.mark.parametrize("setup_huggingface_mock", [["none"]], indirect=True)
def test_inference_endpoints_text2text_generation_invoke_stream_model(setup_huggingface_mock):
model = HuggingfaceHubLargeLanguageModel()
response = model.invoke(
model='google/mt5-base',
model="google/mt5-base",
credentials={
'huggingfacehub_api_type': 'inference_endpoints',
'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'),
'task_type': 'text2text-generation'
"huggingfacehub_api_type": "inference_endpoints",
"huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
"huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL"),
"task_type": "text2text-generation",
},
prompt_messages=[
UserPromptMessage(
content='Who are you?'
)
],
prompt_messages=[UserPromptMessage(content="Who are you?")],
model_parameters={
'temperature': 1.0,
'top_k': 2,
'top_p': 0.5,
"temperature": 1.0,
"top_k": 2,
"top_p": 0.5,
},
stop=['How'],
stop=["How"],
stream=True,
user="abc-123"
user="abc-123",
)
assert isinstance(response, Generator)
@ -286,18 +264,14 @@ def test_get_num_tokens():
model = HuggingfaceHubLargeLanguageModel()
num_tokens = model.get_num_tokens(
model='google/mt5-base',
model="google/mt5-base",
credentials={
'huggingfacehub_api_type': 'inference_endpoints',
'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL'),
'task_type': 'text2text-generation'
"huggingfacehub_api_type": "inference_endpoints",
"huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
"huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL"),
"task_type": "text2text-generation",
},
prompt_messages=[
UserPromptMessage(
content='Hello World!'
)
]
prompt_messages=[UserPromptMessage(content="Hello World!")],
)
assert num_tokens == 7

View File

@ -14,19 +14,19 @@ def test_hosted_inference_api_validate_credentials():
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='facebook/bart-base',
model="facebook/bart-base",
credentials={
'huggingfacehub_api_type': 'hosted_inference_api',
'huggingfacehub_api_token': 'invalid_key',
}
"huggingfacehub_api_type": "hosted_inference_api",
"huggingfacehub_api_token": "invalid_key",
},
)
model.validate_credentials(
model='facebook/bart-base',
model="facebook/bart-base",
credentials={
'huggingfacehub_api_type': 'hosted_inference_api',
'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
}
"huggingfacehub_api_type": "hosted_inference_api",
"huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
},
)
@ -34,15 +34,12 @@ def test_hosted_inference_api_invoke_model():
model = HuggingfaceHubTextEmbeddingModel()
result = model.invoke(
model='facebook/bart-base',
model="facebook/bart-base",
credentials={
'huggingfacehub_api_type': 'hosted_inference_api',
'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
"huggingfacehub_api_type": "hosted_inference_api",
"huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
},
texts=[
"hello",
"world"
]
texts=["hello", "world"],
)
assert isinstance(result, TextEmbeddingResult)
@ -55,25 +52,25 @@ def test_inference_endpoints_validate_credentials():
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='all-MiniLM-L6-v2',
model="all-MiniLM-L6-v2",
credentials={
'huggingfacehub_api_type': 'inference_endpoints',
'huggingfacehub_api_token': 'invalid_key',
'huggingface_namespace': 'Dify-AI',
'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL'),
'task_type': 'feature-extraction'
}
"huggingfacehub_api_type": "inference_endpoints",
"huggingfacehub_api_token": "invalid_key",
"huggingface_namespace": "Dify-AI",
"huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL"),
"task_type": "feature-extraction",
},
)
model.validate_credentials(
model='all-MiniLM-L6-v2',
model="all-MiniLM-L6-v2",
credentials={
'huggingfacehub_api_type': 'inference_endpoints',
'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
'huggingface_namespace': 'Dify-AI',
'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL'),
'task_type': 'feature-extraction'
}
"huggingfacehub_api_type": "inference_endpoints",
"huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
"huggingface_namespace": "Dify-AI",
"huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL"),
"task_type": "feature-extraction",
},
)
@ -81,18 +78,15 @@ def test_inference_endpoints_invoke_model():
model = HuggingfaceHubTextEmbeddingModel()
result = model.invoke(
model='all-MiniLM-L6-v2',
model="all-MiniLM-L6-v2",
credentials={
'huggingfacehub_api_type': 'inference_endpoints',
'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
'huggingface_namespace': 'Dify-AI',
'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL'),
'task_type': 'feature-extraction'
"huggingfacehub_api_type": "inference_endpoints",
"huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
"huggingface_namespace": "Dify-AI",
"huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL"),
"task_type": "feature-extraction",
},
texts=[
"hello",
"world"
]
texts=["hello", "world"],
)
assert isinstance(result, TextEmbeddingResult)
@ -104,18 +98,15 @@ def test_get_num_tokens():
model = HuggingfaceHubTextEmbeddingModel()
num_tokens = model.get_num_tokens(
model='all-MiniLM-L6-v2',
model="all-MiniLM-L6-v2",
credentials={
'huggingfacehub_api_type': 'inference_endpoints',
'huggingfacehub_api_token': os.environ.get('HUGGINGFACE_API_KEY'),
'huggingface_namespace': 'Dify-AI',
'huggingfacehub_endpoint_url': os.environ.get('HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL'),
'task_type': 'feature-extraction'
"huggingfacehub_api_type": "inference_endpoints",
"huggingfacehub_api_token": os.environ.get("HUGGINGFACE_API_KEY"),
"huggingface_namespace": "Dify-AI",
"huggingfacehub_endpoint_url": os.environ.get("HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL"),
"task_type": "feature-extraction",
},
texts=[
"hello",
"world"
]
texts=["hello", "world"],
)
assert num_tokens == 2

View File

@ -10,61 +10,59 @@ from core.model_runtime.model_providers.huggingface_tei.text_embedding.text_embe
)
from tests.integration_tests.model_runtime.__mock.huggingface_tei import MockTEIClass
MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
@pytest.fixture
def setup_tei_mock(request, monkeypatch: pytest.MonkeyPatch):
if MOCK:
monkeypatch.setattr(TeiHelper, 'get_tei_extra_parameter', MockTEIClass.get_tei_extra_parameter)
monkeypatch.setattr(TeiHelper, 'invoke_tokenize', MockTEIClass.invoke_tokenize)
monkeypatch.setattr(TeiHelper, 'invoke_embeddings', MockTEIClass.invoke_embeddings)
monkeypatch.setattr(TeiHelper, 'invoke_rerank', MockTEIClass.invoke_rerank)
monkeypatch.setattr(TeiHelper, "get_tei_extra_parameter", MockTEIClass.get_tei_extra_parameter)
monkeypatch.setattr(TeiHelper, "invoke_tokenize", MockTEIClass.invoke_tokenize)
monkeypatch.setattr(TeiHelper, "invoke_embeddings", MockTEIClass.invoke_embeddings)
monkeypatch.setattr(TeiHelper, "invoke_rerank", MockTEIClass.invoke_rerank)
yield
if MOCK:
monkeypatch.undo()
@pytest.mark.parametrize('setup_tei_mock', [['none']], indirect=True)
@pytest.mark.parametrize("setup_tei_mock", [["none"]], indirect=True)
def test_validate_credentials(setup_tei_mock):
model = HuggingfaceTeiTextEmbeddingModel()
# model name is only used in mock
model_name = 'embedding'
model_name = "embedding"
if MOCK:
# TEI Provider will check model type by API endpoint, at real server, the model type is correct.
# So we dont need to check model type here. Only check in mock
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='reranker',
model="reranker",
credentials={
'server_url': os.environ.get('TEI_EMBEDDING_SERVER_URL', ""),
}
"server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""),
},
)
model.validate_credentials(
model=model_name,
credentials={
'server_url': os.environ.get('TEI_EMBEDDING_SERVER_URL', ""),
}
"server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""),
},
)
@pytest.mark.parametrize('setup_tei_mock', [['none']], indirect=True)
@pytest.mark.parametrize("setup_tei_mock", [["none"]], indirect=True)
def test_invoke_model(setup_tei_mock):
model = HuggingfaceTeiTextEmbeddingModel()
model_name = 'embedding'
model_name = "embedding"
result = model.invoke(
model=model_name,
credentials={
'server_url': os.environ.get('TEI_EMBEDDING_SERVER_URL', ""),
"server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""),
},
texts=[
"hello",
"world"
],
user="abc-123"
texts=["hello", "world"],
user="abc-123",
)
assert isinstance(result, TextEmbeddingResult)

View File

@ -11,63 +11,65 @@ from core.model_runtime.model_providers.huggingface_tei.rerank.rerank import (
from core.model_runtime.model_providers.huggingface_tei.text_embedding.text_embedding import TeiHelper
from tests.integration_tests.model_runtime.__mock.huggingface_tei import MockTEIClass
MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
@pytest.fixture
def setup_tei_mock(request, monkeypatch: pytest.MonkeyPatch):
if MOCK:
monkeypatch.setattr(TeiHelper, 'get_tei_extra_parameter', MockTEIClass.get_tei_extra_parameter)
monkeypatch.setattr(TeiHelper, 'invoke_tokenize', MockTEIClass.invoke_tokenize)
monkeypatch.setattr(TeiHelper, 'invoke_embeddings', MockTEIClass.invoke_embeddings)
monkeypatch.setattr(TeiHelper, 'invoke_rerank', MockTEIClass.invoke_rerank)
monkeypatch.setattr(TeiHelper, "get_tei_extra_parameter", MockTEIClass.get_tei_extra_parameter)
monkeypatch.setattr(TeiHelper, "invoke_tokenize", MockTEIClass.invoke_tokenize)
monkeypatch.setattr(TeiHelper, "invoke_embeddings", MockTEIClass.invoke_embeddings)
monkeypatch.setattr(TeiHelper, "invoke_rerank", MockTEIClass.invoke_rerank)
yield
if MOCK:
monkeypatch.undo()
@pytest.mark.parametrize('setup_tei_mock', [['none']], indirect=True)
@pytest.mark.parametrize("setup_tei_mock", [["none"]], indirect=True)
def test_validate_credentials(setup_tei_mock):
model = HuggingfaceTeiRerankModel()
# model name is only used in mock
model_name = 'reranker'
model_name = "reranker"
if MOCK:
# TEI Provider will check model type by API endpoint, at real server, the model type is correct.
# So we dont need to check model type here. Only check in mock
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='embedding',
model="embedding",
credentials={
'server_url': os.environ.get('TEI_RERANK_SERVER_URL'),
}
"server_url": os.environ.get("TEI_RERANK_SERVER_URL"),
},
)
model.validate_credentials(
model=model_name,
credentials={
'server_url': os.environ.get('TEI_RERANK_SERVER_URL'),
}
"server_url": os.environ.get("TEI_RERANK_SERVER_URL"),
},
)
@pytest.mark.parametrize('setup_tei_mock', [['none']], indirect=True)
@pytest.mark.parametrize("setup_tei_mock", [["none"]], indirect=True)
def test_invoke_model(setup_tei_mock):
model = HuggingfaceTeiRerankModel()
# model name is only used in mock
model_name = 'reranker'
model_name = "reranker"
result = model.invoke(
model=model_name,
credentials={
'server_url': os.environ.get('TEI_RERANK_SERVER_URL'),
"server_url": os.environ.get("TEI_RERANK_SERVER_URL"),
},
query="Who is Kasumi?",
docs=[
"Kasumi is a girl's name of Japanese origin meaning \"mist\".",
'Kasumi is a girl\'s name of Japanese origin meaning "mist".',
"Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ",
"and she leads a team named PopiParty."
"and she leads a team named PopiParty.",
],
score_threshold=0.8
score_threshold=0.8,
)
assert isinstance(result, RerankResult)

View File

@ -14,19 +14,15 @@ def test_validate_credentials():
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='hunyuan-standard',
credentials={
'secret_id': 'invalid_key',
'secret_key': 'invalid_key'
}
model="hunyuan-standard", credentials={"secret_id": "invalid_key", "secret_key": "invalid_key"}
)
model.validate_credentials(
model='hunyuan-standard',
model="hunyuan-standard",
credentials={
'secret_id': os.environ.get('HUNYUAN_SECRET_ID'),
'secret_key': os.environ.get('HUNYUAN_SECRET_KEY')
}
"secret_id": os.environ.get("HUNYUAN_SECRET_ID"),
"secret_key": os.environ.get("HUNYUAN_SECRET_KEY"),
},
)
@ -34,23 +30,16 @@ def test_invoke_model():
model = HunyuanLargeLanguageModel()
response = model.invoke(
model='hunyuan-standard',
model="hunyuan-standard",
credentials={
'secret_id': os.environ.get('HUNYUAN_SECRET_ID'),
'secret_key': os.environ.get('HUNYUAN_SECRET_KEY')
"secret_id": os.environ.get("HUNYUAN_SECRET_ID"),
"secret_key": os.environ.get("HUNYUAN_SECRET_KEY"),
},
prompt_messages=[
UserPromptMessage(
content='Hi'
)
],
model_parameters={
'temperature': 0.5,
'max_tokens': 10
},
stop=['How'],
prompt_messages=[UserPromptMessage(content="Hi")],
model_parameters={"temperature": 0.5, "max_tokens": 10},
stop=["How"],
stream=False,
user="abc-123"
user="abc-123",
)
assert isinstance(response, LLMResult)
@ -61,23 +50,15 @@ def test_invoke_stream_model():
model = HunyuanLargeLanguageModel()
response = model.invoke(
model='hunyuan-standard',
model="hunyuan-standard",
credentials={
'secret_id': os.environ.get('HUNYUAN_SECRET_ID'),
'secret_key': os.environ.get('HUNYUAN_SECRET_KEY')
},
prompt_messages=[
UserPromptMessage(
content='Hi'
)
],
model_parameters={
'temperature': 0.5,
'max_tokens': 100,
'seed': 1234
"secret_id": os.environ.get("HUNYUAN_SECRET_ID"),
"secret_key": os.environ.get("HUNYUAN_SECRET_KEY"),
},
prompt_messages=[UserPromptMessage(content="Hi")],
model_parameters={"temperature": 0.5, "max_tokens": 100, "seed": 1234},
stream=True,
user="abc-123"
user="abc-123",
)
assert isinstance(response, Generator)
@ -93,19 +74,17 @@ def test_get_num_tokens():
model = HunyuanLargeLanguageModel()
num_tokens = model.get_num_tokens(
model='hunyuan-standard',
model="hunyuan-standard",
credentials={
'secret_id': os.environ.get('HUNYUAN_SECRET_ID'),
'secret_key': os.environ.get('HUNYUAN_SECRET_KEY')
"secret_id": os.environ.get("HUNYUAN_SECRET_ID"),
"secret_key": os.environ.get("HUNYUAN_SECRET_KEY"),
},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Hello World!'
)
]
UserPromptMessage(content="Hello World!"),
],
)
assert num_tokens == 14

View File

@ -10,16 +10,11 @@ def test_validate_provider_credentials():
provider = HunyuanProvider()
with pytest.raises(CredentialsValidateFailedError):
provider.validate_provider_credentials(
credentials={
'secret_id': 'invalid_key',
'secret_key': 'invalid_key'
}
)
provider.validate_provider_credentials(credentials={"secret_id": "invalid_key", "secret_key": "invalid_key"})
provider.validate_provider_credentials(
credentials={
'secret_id': os.environ.get('HUNYUAN_SECRET_ID'),
'secret_key': os.environ.get('HUNYUAN_SECRET_KEY')
"secret_id": os.environ.get("HUNYUAN_SECRET_ID"),
"secret_key": os.environ.get("HUNYUAN_SECRET_KEY"),
}
)

View File

@ -12,19 +12,15 @@ def test_validate_credentials():
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='hunyuan-embedding',
credentials={
'secret_id': 'invalid_key',
'secret_key': 'invalid_key'
}
model="hunyuan-embedding", credentials={"secret_id": "invalid_key", "secret_key": "invalid_key"}
)
model.validate_credentials(
model='hunyuan-embedding',
model="hunyuan-embedding",
credentials={
'secret_id': os.environ.get('HUNYUAN_SECRET_ID'),
'secret_key': os.environ.get('HUNYUAN_SECRET_KEY')
}
"secret_id": os.environ.get("HUNYUAN_SECRET_ID"),
"secret_key": os.environ.get("HUNYUAN_SECRET_KEY"),
},
)
@ -32,47 +28,43 @@ def test_invoke_model():
model = HunyuanTextEmbeddingModel()
result = model.invoke(
model='hunyuan-embedding',
model="hunyuan-embedding",
credentials={
'secret_id': os.environ.get('HUNYUAN_SECRET_ID'),
'secret_key': os.environ.get('HUNYUAN_SECRET_KEY')
"secret_id": os.environ.get("HUNYUAN_SECRET_ID"),
"secret_key": os.environ.get("HUNYUAN_SECRET_KEY"),
},
texts=[
"hello",
"world"
],
user="abc-123"
texts=["hello", "world"],
user="abc-123",
)
assert isinstance(result, TextEmbeddingResult)
assert len(result.embeddings) == 2
assert result.usage.total_tokens == 6
def test_get_num_tokens():
model = HunyuanTextEmbeddingModel()
num_tokens = model.get_num_tokens(
model='hunyuan-embedding',
model="hunyuan-embedding",
credentials={
'secret_id': os.environ.get('HUNYUAN_SECRET_ID'),
'secret_key': os.environ.get('HUNYUAN_SECRET_KEY')
"secret_id": os.environ.get("HUNYUAN_SECRET_ID"),
"secret_key": os.environ.get("HUNYUAN_SECRET_KEY"),
},
texts=[
"hello",
"world"
]
texts=["hello", "world"],
)
assert num_tokens == 2
def test_max_chunks():
model = HunyuanTextEmbeddingModel()
result = model.invoke(
model='hunyuan-embedding',
model="hunyuan-embedding",
credentials={
'secret_id': os.environ.get('HUNYUAN_SECRET_ID'),
'secret_key': os.environ.get('HUNYUAN_SECRET_KEY')
"secret_id": os.environ.get("HUNYUAN_SECRET_ID"),
"secret_key": os.environ.get("HUNYUAN_SECRET_KEY"),
},
texts=[
"hello",
@ -97,8 +89,8 @@ def test_max_chunks():
"world",
"hello",
"world",
]
],
)
assert isinstance(result, TextEmbeddingResult)
assert len(result.embeddings) == 22
assert len(result.embeddings) == 22

View File

@ -10,14 +10,6 @@ def test_validate_provider_credentials():
provider = JinaProvider()
with pytest.raises(CredentialsValidateFailedError):
provider.validate_provider_credentials(
credentials={
'api_key': 'hahahaha'
}
)
provider.validate_provider_credentials(credentials={"api_key": "hahahaha"})
provider.validate_provider_credentials(
credentials={
'api_key': os.environ.get('JINA_API_KEY')
}
)
provider.validate_provider_credentials(credentials={"api_key": os.environ.get("JINA_API_KEY")})

View File

@ -11,18 +11,10 @@ def test_validate_credentials():
model = JinaTextEmbeddingModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='jina-embeddings-v2-base-en',
credentials={
'api_key': 'invalid_key'
}
)
model.validate_credentials(model="jina-embeddings-v2-base-en", credentials={"api_key": "invalid_key"})
model.validate_credentials(
model='jina-embeddings-v2-base-en',
credentials={
'api_key': os.environ.get('JINA_API_KEY')
}
model="jina-embeddings-v2-base-en", credentials={"api_key": os.environ.get("JINA_API_KEY")}
)
@ -30,15 +22,12 @@ def test_invoke_model():
model = JinaTextEmbeddingModel()
result = model.invoke(
model='jina-embeddings-v2-base-en',
model="jina-embeddings-v2-base-en",
credentials={
'api_key': os.environ.get('JINA_API_KEY'),
"api_key": os.environ.get("JINA_API_KEY"),
},
texts=[
"hello",
"world"
],
user="abc-123"
texts=["hello", "world"],
user="abc-123",
)
assert isinstance(result, TextEmbeddingResult)
@ -50,14 +39,11 @@ def test_get_num_tokens():
model = JinaTextEmbeddingModel()
num_tokens = model.get_num_tokens(
model='jina-embeddings-v2-base-en',
model="jina-embeddings-v2-base-en",
credentials={
'api_key': os.environ.get('JINA_API_KEY'),
"api_key": os.environ.get("JINA_API_KEY"),
},
texts=[
"hello",
"world"
]
texts=["hello", "world"],
)
assert num_tokens == 6

View File

@ -1,4 +1,4 @@
"""
LocalAI Embedding Interface is temporarily unavailable due to
we could not find a way to test it for now.
"""
LocalAI Embedding Interface is temporarily unavailable due to
we could not find a way to test it for now.
"""

View File

@ -21,99 +21,78 @@ def test_validate_credentials_for_chat_model():
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='chinese-llama-2-7b',
model="chinese-llama-2-7b",
credentials={
'server_url': 'hahahaha',
'completion_type': 'completion',
}
"server_url": "hahahaha",
"completion_type": "completion",
},
)
model.validate_credentials(
model='chinese-llama-2-7b',
model="chinese-llama-2-7b",
credentials={
'server_url': os.environ.get('LOCALAI_SERVER_URL'),
'completion_type': 'completion',
}
"server_url": os.environ.get("LOCALAI_SERVER_URL"),
"completion_type": "completion",
},
)
def test_invoke_completion_model():
model = LocalAILanguageModel()
response = model.invoke(
model='chinese-llama-2-7b',
model="chinese-llama-2-7b",
credentials={
'server_url': os.environ.get('LOCALAI_SERVER_URL'),
'completion_type': 'completion',
},
prompt_messages=[
UserPromptMessage(
content='ping'
)
],
model_parameters={
'temperature': 0.7,
'top_p': 1.0,
'max_tokens': 10
"server_url": os.environ.get("LOCALAI_SERVER_URL"),
"completion_type": "completion",
},
prompt_messages=[UserPromptMessage(content="ping")],
model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10},
stop=[],
user="abc-123",
stream=False
stream=False,
)
assert isinstance(response, LLMResult)
assert len(response.message.content) > 0
assert response.usage.total_tokens > 0
def test_invoke_chat_model():
model = LocalAILanguageModel()
response = model.invoke(
model='chinese-llama-2-7b',
model="chinese-llama-2-7b",
credentials={
'server_url': os.environ.get('LOCALAI_SERVER_URL'),
'completion_type': 'chat_completion',
},
prompt_messages=[
UserPromptMessage(
content='ping'
)
],
model_parameters={
'temperature': 0.7,
'top_p': 1.0,
'max_tokens': 10
"server_url": os.environ.get("LOCALAI_SERVER_URL"),
"completion_type": "chat_completion",
},
prompt_messages=[UserPromptMessage(content="ping")],
model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10},
stop=[],
user="abc-123",
stream=False
stream=False,
)
assert isinstance(response, LLMResult)
assert len(response.message.content) > 0
assert response.usage.total_tokens > 0
def test_invoke_stream_completion_model():
model = LocalAILanguageModel()
response = model.invoke(
model='chinese-llama-2-7b',
model="chinese-llama-2-7b",
credentials={
'server_url': os.environ.get('LOCALAI_SERVER_URL'),
'completion_type': 'completion',
"server_url": os.environ.get("LOCALAI_SERVER_URL"),
"completion_type": "completion",
},
prompt_messages=[
UserPromptMessage(
content='Hello World!'
)
],
model_parameters={
'temperature': 0.7,
'top_p': 1.0,
'max_tokens': 10
},
stop=['you'],
prompt_messages=[UserPromptMessage(content="Hello World!")],
model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10},
stop=["you"],
stream=True,
user="abc-123"
user="abc-123",
)
assert isinstance(response, Generator)
@ -123,28 +102,21 @@ def test_invoke_stream_completion_model():
assert isinstance(chunk.delta.message, AssistantPromptMessage)
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
def test_invoke_stream_chat_model():
model = LocalAILanguageModel()
response = model.invoke(
model='chinese-llama-2-7b',
model="chinese-llama-2-7b",
credentials={
'server_url': os.environ.get('LOCALAI_SERVER_URL'),
'completion_type': 'chat_completion',
"server_url": os.environ.get("LOCALAI_SERVER_URL"),
"completion_type": "chat_completion",
},
prompt_messages=[
UserPromptMessage(
content='Hello World!'
)
],
model_parameters={
'temperature': 0.7,
'top_p': 1.0,
'max_tokens': 10
},
stop=['you'],
prompt_messages=[UserPromptMessage(content="Hello World!")],
model_parameters={"temperature": 0.7, "top_p": 1.0, "max_tokens": 10},
stop=["you"],
stream=True,
user="abc-123"
user="abc-123",
)
assert isinstance(response, Generator)
@ -154,64 +126,48 @@ def test_invoke_stream_chat_model():
assert isinstance(chunk.delta.message, AssistantPromptMessage)
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
def test_get_num_tokens():
model = LocalAILanguageModel()
num_tokens = model.get_num_tokens(
model='????',
model="????",
credentials={
'server_url': os.environ.get('LOCALAI_SERVER_URL'),
'completion_type': 'chat_completion',
"server_url": os.environ.get("LOCALAI_SERVER_URL"),
"completion_type": "chat_completion",
},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Hello World!'
)
UserPromptMessage(content="Hello World!"),
],
tools=[
PromptMessageTool(
name='get_current_weather',
description='Get the current weather in a given location',
name="get_current_weather",
description="Get the current weather in a given location",
parameters={
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state e.g. San Francisco, CA"
},
"unit": {
"type": "string",
"enum": [
"c",
"f"
]
}
"location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
"unit": {"type": "string", "enum": ["c", "f"]},
},
"required": [
"location"
]
}
"required": ["location"],
},
)
]
],
)
assert isinstance(num_tokens, int)
assert num_tokens == 77
num_tokens = model.get_num_tokens(
model='????',
model="????",
credentials={
'server_url': os.environ.get('LOCALAI_SERVER_URL'),
'completion_type': 'chat_completion',
"server_url": os.environ.get("LOCALAI_SERVER_URL"),
"completion_type": "chat_completion",
},
prompt_messages=[
UserPromptMessage(
content='Hello World!'
)
],
prompt_messages=[UserPromptMessage(content="Hello World!")],
)
assert isinstance(num_tokens, int)

View File

@ -12,30 +12,29 @@ def test_validate_credentials_for_chat_model():
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='bge-reranker-v2-m3',
model="bge-reranker-v2-m3",
credentials={
'server_url': 'hahahaha',
'completion_type': 'completion',
}
"server_url": "hahahaha",
"completion_type": "completion",
},
)
model.validate_credentials(
model='bge-reranker-base',
model="bge-reranker-base",
credentials={
'server_url': os.environ.get('LOCALAI_SERVER_URL'),
'completion_type': 'completion',
}
"server_url": os.environ.get("LOCALAI_SERVER_URL"),
"completion_type": "completion",
},
)
def test_invoke_rerank_model():
model = LocalaiRerankModel()
response = model.invoke(
model='bge-reranker-base',
credentials={
'server_url': os.environ.get('LOCALAI_SERVER_URL')
},
query='Organic skincare products for sensitive skin',
model="bge-reranker-base",
credentials={"server_url": os.environ.get("LOCALAI_SERVER_URL")},
query="Organic skincare products for sensitive skin",
docs=[
"Eco-friendly kitchenware for modern homes",
"Biodegradable cleaning supplies for eco-conscious consumers",
@ -45,43 +44,38 @@ def test_invoke_rerank_model():
"Sustainable gardening tools and compost solutions",
"Sensitive skin-friendly facial cleansers and toners",
"Organic food wraps and storage solutions",
"Yoga mats made from recycled materials"
"Yoga mats made from recycled materials",
],
top_n=3,
score_threshold=0.75,
user="abc-123"
user="abc-123",
)
assert isinstance(response, RerankResult)
assert len(response.docs) == 3
def test__invoke():
model = LocalaiRerankModel()
# Test case 1: Empty docs
result = model._invoke(
model='bge-reranker-base',
credentials={
'server_url': 'https://example.com',
'api_key': '1234567890'
},
query='Organic skincare products for sensitive skin',
model="bge-reranker-base",
credentials={"server_url": "https://example.com", "api_key": "1234567890"},
query="Organic skincare products for sensitive skin",
docs=[],
top_n=3,
score_threshold=0.75,
user="abc-123"
user="abc-123",
)
assert isinstance(result, RerankResult)
assert len(result.docs) == 0
# Test case 2: Valid invocation
result = model._invoke(
model='bge-reranker-base',
credentials={
'server_url': 'https://example.com',
'api_key': '1234567890'
},
query='Organic skincare products for sensitive skin',
model="bge-reranker-base",
credentials={"server_url": "https://example.com", "api_key": "1234567890"},
query="Organic skincare products for sensitive skin",
docs=[
"Eco-friendly kitchenware for modern homes",
"Biodegradable cleaning supplies for eco-conscious consumers",
@ -91,12 +85,12 @@ def test__invoke():
"Sustainable gardening tools and compost solutions",
"Sensitive skin-friendly facial cleansers and toners",
"Organic food wraps and storage solutions",
"Yoga mats made from recycled materials"
"Yoga mats made from recycled materials",
],
top_n=3,
score_threshold=0.75,
user="abc-123"
user="abc-123",
)
assert isinstance(result, RerankResult)
assert len(result.docs) == 3
assert all(isinstance(doc, RerankDocument) for doc in result.docs)
assert all(isinstance(doc, RerankDocument) for doc in result.docs)

View File

@ -10,19 +10,9 @@ def test_validate_credentials():
model = LocalAISpeech2text()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='whisper-1',
credentials={
'server_url': 'invalid_url'
}
)
model.validate_credentials(model="whisper-1", credentials={"server_url": "invalid_url"})
model.validate_credentials(
model='whisper-1',
credentials={
'server_url': os.environ.get('LOCALAI_SERVER_URL')
}
)
model.validate_credentials(model="whisper-1", credentials={"server_url": os.environ.get("LOCALAI_SERVER_URL")})
def test_invoke_model():
@ -32,23 +22,21 @@ def test_invoke_model():
current_dir = os.path.dirname(os.path.abspath(__file__))
# Get assets directory
assets_dir = os.path.join(os.path.dirname(current_dir), 'assets')
assets_dir = os.path.join(os.path.dirname(current_dir), "assets")
# Construct the path to the audio file
audio_file_path = os.path.join(assets_dir, 'audio.mp3')
audio_file_path = os.path.join(assets_dir, "audio.mp3")
# Open the file and get the file object
with open(audio_file_path, 'rb') as audio_file:
with open(audio_file_path, "rb") as audio_file:
file = audio_file
result = model.invoke(
model='whisper-1',
credentials={
'server_url': os.environ.get('LOCALAI_SERVER_URL')
},
model="whisper-1",
credentials={"server_url": os.environ.get("LOCALAI_SERVER_URL")},
file=file,
user="abc-123"
user="abc-123",
)
assert isinstance(result, str)
assert result == '1, 2, 3, 4, 5, 6, 7, 8, 9, 10'
assert result == "1, 2, 3, 4, 5, 6, 7, 8, 9, 10"

View File

@ -12,54 +12,47 @@ def test_validate_credentials():
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='embo-01',
credentials={
'minimax_api_key': 'invalid_key',
'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
}
model="embo-01",
credentials={"minimax_api_key": "invalid_key", "minimax_group_id": os.environ.get("MINIMAX_GROUP_ID")},
)
model.validate_credentials(
model='embo-01',
model="embo-01",
credentials={
'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
}
"minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
"minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
},
)
def test_invoke_model():
model = MinimaxTextEmbeddingModel()
result = model.invoke(
model='embo-01',
model="embo-01",
credentials={
'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
"minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
"minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
},
texts=[
"hello",
"world"
],
user="abc-123"
texts=["hello", "world"],
user="abc-123",
)
assert isinstance(result, TextEmbeddingResult)
assert len(result.embeddings) == 2
assert result.usage.total_tokens == 16
def test_get_num_tokens():
model = MinimaxTextEmbeddingModel()
num_tokens = model.get_num_tokens(
model='embo-01',
model="embo-01",
credentials={
'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
"minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
"minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
},
texts=[
"hello",
"world"
]
texts=["hello", "world"],
)
assert num_tokens == 2

View File

@ -17,79 +17,70 @@ def test_predefined_models():
assert len(model_schemas) >= 1
assert isinstance(model_schemas[0], AIModelEntity)
def test_validate_credentials_for_chat_model():
sleep(3)
model = MinimaxLargeLanguageModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='abab5.5-chat',
credentials={
'minimax_api_key': 'invalid_key',
'minimax_group_id': 'invalid_key'
}
model="abab5.5-chat", credentials={"minimax_api_key": "invalid_key", "minimax_group_id": "invalid_key"}
)
model.validate_credentials(
model='abab5.5-chat',
model="abab5.5-chat",
credentials={
'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
}
"minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
"minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
},
)
def test_invoke_model():
sleep(3)
model = MinimaxLargeLanguageModel()
response = model.invoke(
model='abab5-chat',
model="abab5-chat",
credentials={
'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
"minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
"minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
},
prompt_messages=[
UserPromptMessage(
content='Hello World!'
)
],
prompt_messages=[UserPromptMessage(content="Hello World!")],
model_parameters={
'temperature': 0.7,
'top_p': 1.0,
'top_k': 1,
"temperature": 0.7,
"top_p": 1.0,
"top_k": 1,
},
stop=['you'],
stop=["you"],
user="abc-123",
stream=False
stream=False,
)
assert isinstance(response, LLMResult)
assert len(response.message.content) > 0
assert response.usage.total_tokens > 0
def test_invoke_stream_model():
sleep(3)
model = MinimaxLargeLanguageModel()
response = model.invoke(
model='abab5.5-chat',
model="abab5.5-chat",
credentials={
'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
"minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
"minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
},
prompt_messages=[
UserPromptMessage(
content='Hello World!'
)
],
prompt_messages=[UserPromptMessage(content="Hello World!")],
model_parameters={
'temperature': 0.7,
'top_p': 1.0,
'top_k': 1,
"temperature": 0.7,
"top_p": 1.0,
"top_k": 1,
},
stop=['you'],
stop=["you"],
stream=True,
user="abc-123"
user="abc-123",
)
assert isinstance(response, Generator)
@ -99,34 +90,31 @@ def test_invoke_stream_model():
assert isinstance(chunk.delta.message, AssistantPromptMessage)
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
def test_invoke_with_search():
sleep(3)
model = MinimaxLargeLanguageModel()
response = model.invoke(
model='abab5.5-chat',
model="abab5.5-chat",
credentials={
'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
"minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
"minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
},
prompt_messages=[
UserPromptMessage(
content='北京今天的天气怎么样'
)
],
prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")],
model_parameters={
'temperature': 0.7,
'top_p': 1.0,
'top_k': 1,
'plugin_web_search': True,
"temperature": 0.7,
"top_p": 1.0,
"top_k": 1,
"plugin_web_search": True,
},
stop=['you'],
stop=["you"],
stream=True,
user="abc-123"
user="abc-123",
)
assert isinstance(response, Generator)
total_message = ''
total_message = ""
for chunk in response:
assert isinstance(chunk, LLMResultChunk)
assert isinstance(chunk.delta, LLMResultChunkDelta)
@ -134,25 +122,22 @@ def test_invoke_with_search():
total_message += chunk.delta.message.content
assert len(chunk.delta.message.content) > 0 if not chunk.delta.finish_reason else True
assert '参考资料' in total_message
assert "参考资料" in total_message
def test_get_num_tokens():
sleep(3)
model = MinimaxLargeLanguageModel()
response = model.get_num_tokens(
model='abab5.5-chat',
model="abab5.5-chat",
credentials={
'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID')
"minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
"minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
},
prompt_messages=[
UserPromptMessage(
content='Hello World!'
)
],
tools=[]
prompt_messages=[UserPromptMessage(content="Hello World!")],
tools=[],
)
assert isinstance(response, int)
assert response == 30
assert response == 30

View File

@ -12,14 +12,14 @@ def test_validate_provider_credentials():
with pytest.raises(CredentialsValidateFailedError):
provider.validate_provider_credentials(
credentials={
'minimax_api_key': 'hahahaha',
'minimax_group_id': '123',
"minimax_api_key": "hahahaha",
"minimax_group_id": "123",
}
)
provider.validate_provider_credentials(
credentials={
'minimax_api_key': os.environ.get('MINIMAX_API_KEY'),
'minimax_group_id': os.environ.get('MINIMAX_GROUP_ID'),
"minimax_api_key": os.environ.get("MINIMAX_API_KEY"),
"minimax_group_id": os.environ.get("MINIMAX_GROUP_ID"),
}
)

View File

@ -19,19 +19,12 @@ def test_validate_credentials():
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='meta-llama/llama-3-8b-instruct',
credentials={
'api_key': 'invalid_key',
'mode': 'chat'
}
model="meta-llama/llama-3-8b-instruct", credentials={"api_key": "invalid_key", "mode": "chat"}
)
model.validate_credentials(
model='meta-llama/llama-3-8b-instruct',
credentials={
'api_key': os.environ.get('NOVITA_API_KEY'),
'mode': 'chat'
}
model="meta-llama/llama-3-8b-instruct",
credentials={"api_key": os.environ.get("NOVITA_API_KEY"), "mode": "chat"},
)
@ -39,27 +32,22 @@ def test_invoke_model():
model = NovitaLargeLanguageModel()
response = model.invoke(
model='meta-llama/llama-3-8b-instruct',
credentials={
'api_key': os.environ.get('NOVITA_API_KEY'),
'mode': 'completion'
},
model="meta-llama/llama-3-8b-instruct",
credentials={"api_key": os.environ.get("NOVITA_API_KEY"), "mode": "completion"},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Who are you?'
)
UserPromptMessage(content="Who are you?"),
],
model_parameters={
'temperature': 1.0,
'top_p': 0.5,
'max_tokens': 10,
"temperature": 1.0,
"top_p": 0.5,
"max_tokens": 10,
},
stop=['How'],
stop=["How"],
stream=False,
user="novita"
user="novita",
)
assert isinstance(response, LLMResult)
@ -70,27 +58,17 @@ def test_invoke_stream_model():
model = NovitaLargeLanguageModel()
response = model.invoke(
model='meta-llama/llama-3-8b-instruct',
credentials={
'api_key': os.environ.get('NOVITA_API_KEY'),
'mode': 'chat'
},
model="meta-llama/llama-3-8b-instruct",
credentials={"api_key": os.environ.get("NOVITA_API_KEY"), "mode": "chat"},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Who are you?'
)
UserPromptMessage(content="Who are you?"),
],
model_parameters={
'temperature': 1.0,
'top_k': 2,
'top_p': 0.5,
'max_tokens': 100
},
model_parameters={"temperature": 1.0, "top_k": 2, "top_p": 0.5, "max_tokens": 100},
stream=True,
user="novita"
user="novita",
)
assert isinstance(response, Generator)
@ -105,18 +83,16 @@ def test_get_num_tokens():
model = NovitaLargeLanguageModel()
num_tokens = model.get_num_tokens(
model='meta-llama/llama-3-8b-instruct',
model="meta-llama/llama-3-8b-instruct",
credentials={
'api_key': os.environ.get('NOVITA_API_KEY'),
"api_key": os.environ.get("NOVITA_API_KEY"),
},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Hello World!'
)
]
UserPromptMessage(content="Hello World!"),
],
)
assert isinstance(num_tokens, int)

View File

@ -10,12 +10,10 @@ def test_validate_provider_credentials():
provider = NovitaProvider()
with pytest.raises(CredentialsValidateFailedError):
provider.validate_provider_credentials(
credentials={}
)
provider.validate_provider_credentials(credentials={})
provider.validate_provider_credentials(
credentials={
'api_key': os.environ.get('NOVITA_API_KEY'),
"api_key": os.environ.get("NOVITA_API_KEY"),
}
)

File diff suppressed because one or more lines are too long

View File

@ -12,21 +12,21 @@ def test_validate_credentials():
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='mistral:text',
model="mistral:text",
credentials={
'base_url': 'http://localhost:21434',
'mode': 'chat',
'context_size': 4096,
}
"base_url": "http://localhost:21434",
"mode": "chat",
"context_size": 4096,
},
)
model.validate_credentials(
model='mistral:text',
model="mistral:text",
credentials={
'base_url': os.environ.get('OLLAMA_BASE_URL'),
'mode': 'chat',
'context_size': 4096,
}
"base_url": os.environ.get("OLLAMA_BASE_URL"),
"mode": "chat",
"context_size": 4096,
},
)
@ -34,17 +34,14 @@ def test_invoke_model():
model = OllamaEmbeddingModel()
result = model.invoke(
model='mistral:text',
model="mistral:text",
credentials={
'base_url': os.environ.get('OLLAMA_BASE_URL'),
'mode': 'chat',
'context_size': 4096,
"base_url": os.environ.get("OLLAMA_BASE_URL"),
"mode": "chat",
"context_size": 4096,
},
texts=[
"hello",
"world"
],
user="abc-123"
texts=["hello", "world"],
user="abc-123",
)
assert isinstance(result, TextEmbeddingResult)
@ -56,16 +53,13 @@ def test_get_num_tokens():
model = OllamaEmbeddingModel()
num_tokens = model.get_num_tokens(
model='mistral:text',
model="mistral:text",
credentials={
'base_url': os.environ.get('OLLAMA_BASE_URL'),
'mode': 'chat',
'context_size': 4096,
"base_url": os.environ.get("OLLAMA_BASE_URL"),
"mode": "chat",
"context_size": 4096,
},
texts=[
"hello",
"world"
]
texts=["hello", "world"],
)
assert num_tokens == 2

File diff suppressed because one or more lines are too long

View File

@ -7,48 +7,37 @@ from core.model_runtime.model_providers.openai.moderation.moderation import Open
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
@pytest.mark.parametrize('setup_openai_mock', [['moderation']], indirect=True)
@pytest.mark.parametrize("setup_openai_mock", [["moderation"]], indirect=True)
def test_validate_credentials(setup_openai_mock):
model = OpenAIModerationModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='text-moderation-stable',
credentials={
'openai_api_key': 'invalid_key'
}
)
model.validate_credentials(model="text-moderation-stable", credentials={"openai_api_key": "invalid_key"})
model.validate_credentials(
model='text-moderation-stable',
credentials={
'openai_api_key': os.environ.get('OPENAI_API_KEY')
}
model="text-moderation-stable", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}
)
@pytest.mark.parametrize('setup_openai_mock', [['moderation']], indirect=True)
@pytest.mark.parametrize("setup_openai_mock", [["moderation"]], indirect=True)
def test_invoke_model(setup_openai_mock):
model = OpenAIModerationModel()
result = model.invoke(
model='text-moderation-stable',
credentials={
'openai_api_key': os.environ.get('OPENAI_API_KEY')
},
model="text-moderation-stable",
credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
text="hello",
user="abc-123"
user="abc-123",
)
assert isinstance(result, bool)
assert result is False
result = model.invoke(
model='text-moderation-stable',
credentials={
'openai_api_key': os.environ.get('OPENAI_API_KEY')
},
model="text-moderation-stable",
credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
text="i will kill you",
user="abc-123"
user="abc-123",
)
assert isinstance(result, bool)

View File

@ -7,17 +7,11 @@ from core.model_runtime.model_providers.openai.openai import OpenAIProvider
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
def test_validate_provider_credentials(setup_openai_mock):
provider = OpenAIProvider()
with pytest.raises(CredentialsValidateFailedError):
provider.validate_provider_credentials(
credentials={}
)
provider.validate_provider_credentials(credentials={})
provider.validate_provider_credentials(
credentials={
'openai_api_key': os.environ.get('OPENAI_API_KEY')
}
)
provider.validate_provider_credentials(credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")})

View File

@ -7,26 +7,17 @@ from core.model_runtime.model_providers.openai.speech2text.speech2text import Op
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
@pytest.mark.parametrize('setup_openai_mock', [['speech2text']], indirect=True)
@pytest.mark.parametrize("setup_openai_mock", [["speech2text"]], indirect=True)
def test_validate_credentials(setup_openai_mock):
model = OpenAISpeech2TextModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='whisper-1',
credentials={
'openai_api_key': 'invalid_key'
}
)
model.validate_credentials(model="whisper-1", credentials={"openai_api_key": "invalid_key"})
model.validate_credentials(
model='whisper-1',
credentials={
'openai_api_key': os.environ.get('OPENAI_API_KEY')
}
)
model.validate_credentials(model="whisper-1", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")})
@pytest.mark.parametrize('setup_openai_mock', [['speech2text']], indirect=True)
@pytest.mark.parametrize("setup_openai_mock", [["speech2text"]], indirect=True)
def test_invoke_model(setup_openai_mock):
model = OpenAISpeech2TextModel()
@ -34,23 +25,21 @@ def test_invoke_model(setup_openai_mock):
current_dir = os.path.dirname(os.path.abspath(__file__))
# Get assets directory
assets_dir = os.path.join(os.path.dirname(current_dir), 'assets')
assets_dir = os.path.join(os.path.dirname(current_dir), "assets")
# Construct the path to the audio file
audio_file_path = os.path.join(assets_dir, 'audio.mp3')
audio_file_path = os.path.join(assets_dir, "audio.mp3")
# Open the file and get the file object
with open(audio_file_path, 'rb') as audio_file:
with open(audio_file_path, "rb") as audio_file:
file = audio_file
result = model.invoke(
model='whisper-1',
credentials={
'openai_api_key': os.environ.get('OPENAI_API_KEY')
},
model="whisper-1",
credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
file=file,
user="abc-123"
user="abc-123",
)
assert isinstance(result, str)
assert result == '1, 2, 3, 4, 5, 6, 7, 8, 9, 10'
assert result == "1, 2, 3, 4, 5, 6, 7, 8, 9, 10"

View File

@ -8,42 +8,27 @@ from core.model_runtime.model_providers.openai.text_embedding.text_embedding imp
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
@pytest.mark.parametrize('setup_openai_mock', [['text_embedding']], indirect=True)
@pytest.mark.parametrize("setup_openai_mock", [["text_embedding"]], indirect=True)
def test_validate_credentials(setup_openai_mock):
model = OpenAITextEmbeddingModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='text-embedding-ada-002',
credentials={
'openai_api_key': 'invalid_key'
}
)
model.validate_credentials(model="text-embedding-ada-002", credentials={"openai_api_key": "invalid_key"})
model.validate_credentials(
model='text-embedding-ada-002',
credentials={
'openai_api_key': os.environ.get('OPENAI_API_KEY')
}
model="text-embedding-ada-002", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}
)
@pytest.mark.parametrize('setup_openai_mock', [['text_embedding']], indirect=True)
@pytest.mark.parametrize("setup_openai_mock", [["text_embedding"]], indirect=True)
def test_invoke_model(setup_openai_mock):
model = OpenAITextEmbeddingModel()
result = model.invoke(
model='text-embedding-ada-002',
credentials={
'openai_api_key': os.environ.get('OPENAI_API_KEY'),
'openai_api_base': 'https://api.openai.com'
},
texts=[
"hello",
"world",
" ".join(["long_text"] * 100),
" ".join(["another_long_text"] * 100)
],
user="abc-123"
model="text-embedding-ada-002",
credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY"), "openai_api_base": "https://api.openai.com"},
texts=["hello", "world", " ".join(["long_text"] * 100), " ".join(["another_long_text"] * 100)],
user="abc-123",
)
assert isinstance(result, TextEmbeddingResult)
@ -55,15 +40,9 @@ def test_get_num_tokens():
model = OpenAITextEmbeddingModel()
num_tokens = model.get_num_tokens(
model='text-embedding-ada-002',
credentials={
'openai_api_key': os.environ.get('OPENAI_API_KEY'),
'openai_api_base': 'https://api.openai.com'
},
texts=[
"hello",
"world"
]
model="text-embedding-ada-002",
credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY"), "openai_api_base": "https://api.openai.com"},
texts=["hello", "world"],
)
assert num_tokens == 2

View File

@ -23,21 +23,17 @@ def test_validate_credentials():
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='mistralai/Mixtral-8x7B-Instruct-v0.1',
credentials={
'api_key': 'invalid_key',
'endpoint_url': 'https://api.together.xyz/v1/',
'mode': 'chat'
}
model="mistralai/Mixtral-8x7B-Instruct-v0.1",
credentials={"api_key": "invalid_key", "endpoint_url": "https://api.together.xyz/v1/", "mode": "chat"},
)
model.validate_credentials(
model='mistralai/Mixtral-8x7B-Instruct-v0.1',
model="mistralai/Mixtral-8x7B-Instruct-v0.1",
credentials={
'api_key': os.environ.get('TOGETHER_API_KEY'),
'endpoint_url': 'https://api.together.xyz/v1/',
'mode': 'chat'
}
"api_key": os.environ.get("TOGETHER_API_KEY"),
"endpoint_url": "https://api.together.xyz/v1/",
"mode": "chat",
},
)
@ -45,28 +41,26 @@ def test_invoke_model():
model = OAIAPICompatLargeLanguageModel()
response = model.invoke(
model='mistralai/Mixtral-8x7B-Instruct-v0.1',
model="mistralai/Mixtral-8x7B-Instruct-v0.1",
credentials={
'api_key': os.environ.get('TOGETHER_API_KEY'),
'endpoint_url': 'https://api.together.xyz/v1/',
'mode': 'completion'
"api_key": os.environ.get("TOGETHER_API_KEY"),
"endpoint_url": "https://api.together.xyz/v1/",
"mode": "completion",
},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Who are you?'
)
UserPromptMessage(content="Who are you?"),
],
model_parameters={
'temperature': 1.0,
'top_k': 2,
'top_p': 0.5,
"temperature": 1.0,
"top_k": 2,
"top_p": 0.5,
},
stop=['How'],
stop=["How"],
stream=False,
user="abc-123"
user="abc-123",
)
assert isinstance(response, LLMResult)
@ -77,29 +71,27 @@ def test_invoke_stream_model():
model = OAIAPICompatLargeLanguageModel()
response = model.invoke(
model='mistralai/Mixtral-8x7B-Instruct-v0.1',
model="mistralai/Mixtral-8x7B-Instruct-v0.1",
credentials={
'api_key': os.environ.get('TOGETHER_API_KEY'),
'endpoint_url': 'https://api.together.xyz/v1/',
'mode': 'chat',
'stream_mode_delimiter': '\\n\\n'
"api_key": os.environ.get("TOGETHER_API_KEY"),
"endpoint_url": "https://api.together.xyz/v1/",
"mode": "chat",
"stream_mode_delimiter": "\\n\\n",
},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Who are you?'
)
UserPromptMessage(content="Who are you?"),
],
model_parameters={
'temperature': 1.0,
'top_k': 2,
'top_p': 0.5,
"temperature": 1.0,
"top_k": 2,
"top_p": 0.5,
},
stop=['How'],
stop=["How"],
stream=True,
user="abc-123"
user="abc-123",
)
assert isinstance(response, Generator)
@ -114,28 +106,26 @@ def test_invoke_stream_model_without_delimiter():
model = OAIAPICompatLargeLanguageModel()
response = model.invoke(
model='mistralai/Mixtral-8x7B-Instruct-v0.1',
model="mistralai/Mixtral-8x7B-Instruct-v0.1",
credentials={
'api_key': os.environ.get('TOGETHER_API_KEY'),
'endpoint_url': 'https://api.together.xyz/v1/',
'mode': 'chat'
"api_key": os.environ.get("TOGETHER_API_KEY"),
"endpoint_url": "https://api.together.xyz/v1/",
"mode": "chat",
},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Who are you?'
)
UserPromptMessage(content="Who are you?"),
],
model_parameters={
'temperature': 1.0,
'top_k': 2,
'top_p': 0.5,
"temperature": 1.0,
"top_k": 2,
"top_p": 0.5,
},
stop=['How'],
stop=["How"],
stream=True,
user="abc-123"
user="abc-123",
)
assert isinstance(response, Generator)
@ -151,51 +141,37 @@ def test_invoke_chat_model_with_tools():
model = OAIAPICompatLargeLanguageModel()
result = model.invoke(
model='gpt-3.5-turbo',
model="gpt-3.5-turbo",
credentials={
'api_key': os.environ.get('OPENAI_API_KEY'),
'endpoint_url': 'https://api.openai.com/v1/',
'mode': 'chat'
"api_key": os.environ.get("OPENAI_API_KEY"),
"endpoint_url": "https://api.openai.com/v1/",
"mode": "chat",
},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content="what's the weather today in London?",
)
),
],
tools=[
PromptMessageTool(
name='get_weather',
description='Determine weather in my location',
name="get_weather",
description="Determine weather in my location",
parameters={
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state e.g. San Francisco, CA"
},
"unit": {
"type": "string",
"enum": [
"celsius",
"fahrenheit"
]
}
"location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": [
"location"
]
}
"required": ["location"],
},
),
],
model_parameters={
'temperature': 0.0,
'max_tokens': 1024
},
model_parameters={"temperature": 0.0, "max_tokens": 1024},
stream=False,
user="abc-123"
user="abc-123",
)
assert isinstance(result, LLMResult)
@ -207,19 +183,14 @@ def test_get_num_tokens():
model = OAIAPICompatLargeLanguageModel()
num_tokens = model.get_num_tokens(
model='mistralai/Mixtral-8x7B-Instruct-v0.1',
credentials={
'api_key': os.environ.get('OPENAI_API_KEY'),
'endpoint_url': 'https://api.openai.com/v1/'
},
model="mistralai/Mixtral-8x7B-Instruct-v0.1",
credentials={"api_key": os.environ.get("OPENAI_API_KEY"), "endpoint_url": "https://api.openai.com/v1/"},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Hello World!'
)
]
UserPromptMessage(content="Hello World!"),
],
)
assert isinstance(num_tokens, int)

View File

@ -14,18 +14,12 @@ def test_validate_credentials():
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model="whisper-1",
credentials={
"api_key": "invalid_key",
"endpoint_url": "https://api.openai.com/v1/"
},
credentials={"api_key": "invalid_key", "endpoint_url": "https://api.openai.com/v1/"},
)
model.validate_credentials(
model="whisper-1",
credentials={
"api_key": os.environ.get("OPENAI_API_KEY"),
"endpoint_url": "https://api.openai.com/v1/"
},
credentials={"api_key": os.environ.get("OPENAI_API_KEY"), "endpoint_url": "https://api.openai.com/v1/"},
)
@ -47,13 +41,10 @@ def test_invoke_model():
result = model.invoke(
model="whisper-1",
credentials={
"api_key": os.environ.get("OPENAI_API_KEY"),
"endpoint_url": "https://api.openai.com/v1/"
},
credentials={"api_key": os.environ.get("OPENAI_API_KEY"), "endpoint_url": "https://api.openai.com/v1/"},
file=file,
user="abc-123",
)
assert isinstance(result, str)
assert result == '1, 2, 3, 4, 5, 6, 7, 8, 9, 10'
assert result == "1, 2, 3, 4, 5, 6, 7, 8, 9, 10"

View File

@ -12,27 +12,23 @@ from core.model_runtime.model_providers.openai_api_compatible.text_embedding.tex
Using OpenAI's API as testing endpoint
"""
def test_validate_credentials():
model = OAICompatEmbeddingModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='text-embedding-ada-002',
credentials={
'api_key': 'invalid_key',
'endpoint_url': 'https://api.openai.com/v1/',
'context_size': 8184
}
model="text-embedding-ada-002",
credentials={"api_key": "invalid_key", "endpoint_url": "https://api.openai.com/v1/", "context_size": 8184},
)
model.validate_credentials(
model='text-embedding-ada-002',
model="text-embedding-ada-002",
credentials={
'api_key': os.environ.get('OPENAI_API_KEY'),
'endpoint_url': 'https://api.openai.com/v1/',
'context_size': 8184
}
"api_key": os.environ.get("OPENAI_API_KEY"),
"endpoint_url": "https://api.openai.com/v1/",
"context_size": 8184,
},
)
@ -40,19 +36,14 @@ def test_invoke_model():
model = OAICompatEmbeddingModel()
result = model.invoke(
model='text-embedding-ada-002',
model="text-embedding-ada-002",
credentials={
'api_key': os.environ.get('OPENAI_API_KEY'),
'endpoint_url': 'https://api.openai.com/v1/',
'context_size': 8184
"api_key": os.environ.get("OPENAI_API_KEY"),
"endpoint_url": "https://api.openai.com/v1/",
"context_size": 8184,
},
texts=[
"hello",
"world",
" ".join(["long_text"] * 100),
" ".join(["another_long_text"] * 100)
],
user="abc-123"
texts=["hello", "world", " ".join(["long_text"] * 100), " ".join(["another_long_text"] * 100)],
user="abc-123",
)
assert isinstance(result, TextEmbeddingResult)
@ -64,16 +55,13 @@ def test_get_num_tokens():
model = OAICompatEmbeddingModel()
num_tokens = model.get_num_tokens(
model='text-embedding-ada-002',
model="text-embedding-ada-002",
credentials={
'api_key': os.environ.get('OPENAI_API_KEY'),
'endpoint_url': 'https://api.openai.com/v1/embeddings',
'context_size': 8184
"api_key": os.environ.get("OPENAI_API_KEY"),
"endpoint_url": "https://api.openai.com/v1/embeddings",
"context_size": 8184,
},
texts=[
"hello",
"world"
]
texts=["hello", "world"],
)
assert num_tokens == 2
assert num_tokens == 2

View File

@ -12,17 +12,17 @@ def test_validate_credentials():
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='NOT IMPORTANT',
model="NOT IMPORTANT",
credentials={
'server_url': 'ww' + os.environ.get('OPENLLM_SERVER_URL'),
}
"server_url": "ww" + os.environ.get("OPENLLM_SERVER_URL"),
},
)
model.validate_credentials(
model='NOT IMPORTANT',
model="NOT IMPORTANT",
credentials={
'server_url': os.environ.get('OPENLLM_SERVER_URL'),
}
"server_url": os.environ.get("OPENLLM_SERVER_URL"),
},
)
@ -30,33 +30,28 @@ def test_invoke_model():
model = OpenLLMTextEmbeddingModel()
result = model.invoke(
model='NOT IMPORTANT',
model="NOT IMPORTANT",
credentials={
'server_url': os.environ.get('OPENLLM_SERVER_URL'),
"server_url": os.environ.get("OPENLLM_SERVER_URL"),
},
texts=[
"hello",
"world"
],
user="abc-123"
texts=["hello", "world"],
user="abc-123",
)
assert isinstance(result, TextEmbeddingResult)
assert len(result.embeddings) == 2
assert result.usage.total_tokens > 0
def test_get_num_tokens():
model = OpenLLMTextEmbeddingModel()
num_tokens = model.get_num_tokens(
model='NOT IMPORTANT',
model="NOT IMPORTANT",
credentials={
'server_url': os.environ.get('OPENLLM_SERVER_URL'),
"server_url": os.environ.get("OPENLLM_SERVER_URL"),
},
texts=[
"hello",
"world"
]
texts=["hello", "world"],
)
assert num_tokens == 2

View File

@ -14,67 +14,61 @@ def test_validate_credentials_for_chat_model():
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='NOT IMPORTANT',
model="NOT IMPORTANT",
credentials={
'server_url': 'invalid_key',
}
"server_url": "invalid_key",
},
)
model.validate_credentials(
model='NOT IMPORTANT',
model="NOT IMPORTANT",
credentials={
'server_url': os.environ.get('OPENLLM_SERVER_URL'),
}
"server_url": os.environ.get("OPENLLM_SERVER_URL"),
},
)
def test_invoke_model():
model = OpenLLMLargeLanguageModel()
response = model.invoke(
model='NOT IMPORTANT',
model="NOT IMPORTANT",
credentials={
'server_url': os.environ.get('OPENLLM_SERVER_URL'),
"server_url": os.environ.get("OPENLLM_SERVER_URL"),
},
prompt_messages=[
UserPromptMessage(
content='Hello World!'
)
],
prompt_messages=[UserPromptMessage(content="Hello World!")],
model_parameters={
'temperature': 0.7,
'top_p': 1.0,
'top_k': 1,
"temperature": 0.7,
"top_p": 1.0,
"top_k": 1,
},
stop=['you'],
stop=["you"],
user="abc-123",
stream=False
stream=False,
)
assert isinstance(response, LLMResult)
assert len(response.message.content) > 0
assert response.usage.total_tokens > 0
def test_invoke_stream_model():
model = OpenLLMLargeLanguageModel()
response = model.invoke(
model='NOT IMPORTANT',
model="NOT IMPORTANT",
credentials={
'server_url': os.environ.get('OPENLLM_SERVER_URL'),
"server_url": os.environ.get("OPENLLM_SERVER_URL"),
},
prompt_messages=[
UserPromptMessage(
content='Hello World!'
)
],
prompt_messages=[UserPromptMessage(content="Hello World!")],
model_parameters={
'temperature': 0.7,
'top_p': 1.0,
'top_k': 1,
"temperature": 0.7,
"top_p": 1.0,
"top_k": 1,
},
stop=['you'],
stop=["you"],
stream=True,
user="abc-123"
user="abc-123",
)
assert isinstance(response, Generator)
@ -84,21 +78,18 @@ def test_invoke_stream_model():
assert isinstance(chunk.delta.message, AssistantPromptMessage)
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
def test_get_num_tokens():
model = OpenLLMLargeLanguageModel()
response = model.get_num_tokens(
model='NOT IMPORTANT',
model="NOT IMPORTANT",
credentials={
'server_url': os.environ.get('OPENLLM_SERVER_URL'),
"server_url": os.environ.get("OPENLLM_SERVER_URL"),
},
prompt_messages=[
UserPromptMessage(
content='Hello World!'
)
],
tools=[]
prompt_messages=[UserPromptMessage(content="Hello World!")],
tools=[],
)
assert isinstance(response, int)
assert response == 3
assert response == 3

View File

@ -19,19 +19,12 @@ def test_validate_credentials():
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='mistralai/mixtral-8x7b-instruct',
credentials={
'api_key': 'invalid_key',
'mode': 'chat'
}
model="mistralai/mixtral-8x7b-instruct", credentials={"api_key": "invalid_key", "mode": "chat"}
)
model.validate_credentials(
model='mistralai/mixtral-8x7b-instruct',
credentials={
'api_key': os.environ.get('TOGETHER_API_KEY'),
'mode': 'chat'
}
model="mistralai/mixtral-8x7b-instruct",
credentials={"api_key": os.environ.get("TOGETHER_API_KEY"), "mode": "chat"},
)
@ -39,27 +32,22 @@ def test_invoke_model():
model = OpenRouterLargeLanguageModel()
response = model.invoke(
model='mistralai/mixtral-8x7b-instruct',
credentials={
'api_key': os.environ.get('TOGETHER_API_KEY'),
'mode': 'completion'
},
model="mistralai/mixtral-8x7b-instruct",
credentials={"api_key": os.environ.get("TOGETHER_API_KEY"), "mode": "completion"},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Who are you?'
)
UserPromptMessage(content="Who are you?"),
],
model_parameters={
'temperature': 1.0,
'top_k': 2,
'top_p': 0.5,
"temperature": 1.0,
"top_k": 2,
"top_p": 0.5,
},
stop=['How'],
stop=["How"],
stream=False,
user="abc-123"
user="abc-123",
)
assert isinstance(response, LLMResult)
@ -70,27 +58,22 @@ def test_invoke_stream_model():
model = OpenRouterLargeLanguageModel()
response = model.invoke(
model='mistralai/mixtral-8x7b-instruct',
credentials={
'api_key': os.environ.get('TOGETHER_API_KEY'),
'mode': 'chat'
},
model="mistralai/mixtral-8x7b-instruct",
credentials={"api_key": os.environ.get("TOGETHER_API_KEY"), "mode": "chat"},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Who are you?'
)
UserPromptMessage(content="Who are you?"),
],
model_parameters={
'temperature': 1.0,
'top_k': 2,
'top_p': 0.5,
"temperature": 1.0,
"top_k": 2,
"top_p": 0.5,
},
stop=['How'],
stop=["How"],
stream=True,
user="abc-123"
user="abc-123",
)
assert isinstance(response, Generator)
@ -105,18 +88,16 @@ def test_get_num_tokens():
model = OpenRouterLargeLanguageModel()
num_tokens = model.get_num_tokens(
model='mistralai/mixtral-8x7b-instruct',
model="mistralai/mixtral-8x7b-instruct",
credentials={
'api_key': os.environ.get('TOGETHER_API_KEY'),
"api_key": os.environ.get("TOGETHER_API_KEY"),
},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Hello World!'
)
]
UserPromptMessage(content="Hello World!"),
],
)
assert isinstance(num_tokens, int)

View File

@ -14,19 +14,19 @@ def test_validate_credentials():
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='meta/llama-2-13b-chat',
model="meta/llama-2-13b-chat",
credentials={
'replicate_api_token': 'invalid_key',
'model_version': 'f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d'
}
"replicate_api_token": "invalid_key",
"model_version": "f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d",
},
)
model.validate_credentials(
model='meta/llama-2-13b-chat',
model="meta/llama-2-13b-chat",
credentials={
'replicate_api_token': os.environ.get('REPLICATE_API_KEY'),
'model_version': 'f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d'
}
"replicate_api_token": os.environ.get("REPLICATE_API_KEY"),
"model_version": "f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d",
},
)
@ -34,27 +34,25 @@ def test_invoke_model():
model = ReplicateLargeLanguageModel()
response = model.invoke(
model='meta/llama-2-13b-chat',
model="meta/llama-2-13b-chat",
credentials={
'replicate_api_token': os.environ.get('REPLICATE_API_KEY'),
'model_version': 'f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d'
"replicate_api_token": os.environ.get("REPLICATE_API_KEY"),
"model_version": "f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d",
},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Who are you?'
)
UserPromptMessage(content="Who are you?"),
],
model_parameters={
'temperature': 1.0,
'top_k': 2,
'top_p': 0.5,
"temperature": 1.0,
"top_k": 2,
"top_p": 0.5,
},
stop=['How'],
stop=["How"],
stream=False,
user="abc-123"
user="abc-123",
)
assert isinstance(response, LLMResult)
@ -65,27 +63,25 @@ def test_invoke_stream_model():
model = ReplicateLargeLanguageModel()
response = model.invoke(
model='mistralai/mixtral-8x7b-instruct-v0.1',
model="mistralai/mixtral-8x7b-instruct-v0.1",
credentials={
'replicate_api_token': os.environ.get('REPLICATE_API_KEY'),
'model_version': '2b56576fcfbe32fa0526897d8385dd3fb3d36ba6fd0dbe033c72886b81ade93e'
"replicate_api_token": os.environ.get("REPLICATE_API_KEY"),
"model_version": "2b56576fcfbe32fa0526897d8385dd3fb3d36ba6fd0dbe033c72886b81ade93e",
},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Who are you?'
)
UserPromptMessage(content="Who are you?"),
],
model_parameters={
'temperature': 1.0,
'top_k': 2,
'top_p': 0.5,
"temperature": 1.0,
"top_k": 2,
"top_p": 0.5,
},
stop=['How'],
stop=["How"],
stream=True,
user="abc-123"
user="abc-123",
)
assert isinstance(response, Generator)
@ -100,19 +96,17 @@ def test_get_num_tokens():
model = ReplicateLargeLanguageModel()
num_tokens = model.get_num_tokens(
model='',
model="",
credentials={
'replicate_api_token': os.environ.get('REPLICATE_API_KEY'),
'model_version': '2b56576fcfbe32fa0526897d8385dd3fb3d36ba6fd0dbe033c72886b81ade93e'
"replicate_api_token": os.environ.get("REPLICATE_API_KEY"),
"model_version": "2b56576fcfbe32fa0526897d8385dd3fb3d36ba6fd0dbe033c72886b81ade93e",
},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Hello World!'
)
]
UserPromptMessage(content="Hello World!"),
],
)
assert num_tokens == 14

View File

@ -12,19 +12,19 @@ def test_validate_credentials_one():
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='replicate/all-mpnet-base-v2',
model="replicate/all-mpnet-base-v2",
credentials={
'replicate_api_token': 'invalid_key',
'model_version': 'b6b7585c9640cd7a9572c6e129c9549d79c9c31f0d3fdce7baac7c67ca38f305'
}
"replicate_api_token": "invalid_key",
"model_version": "b6b7585c9640cd7a9572c6e129c9549d79c9c31f0d3fdce7baac7c67ca38f305",
},
)
model.validate_credentials(
model='replicate/all-mpnet-base-v2',
model="replicate/all-mpnet-base-v2",
credentials={
'replicate_api_token': os.environ.get('REPLICATE_API_KEY'),
'model_version': 'b6b7585c9640cd7a9572c6e129c9549d79c9c31f0d3fdce7baac7c67ca38f305'
}
"replicate_api_token": os.environ.get("REPLICATE_API_KEY"),
"model_version": "b6b7585c9640cd7a9572c6e129c9549d79c9c31f0d3fdce7baac7c67ca38f305",
},
)
@ -33,19 +33,19 @@ def test_validate_credentials_two():
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='nateraw/bge-large-en-v1.5',
model="nateraw/bge-large-en-v1.5",
credentials={
'replicate_api_token': 'invalid_key',
'model_version': '9cf9f015a9cb9c61d1a2610659cdac4a4ca222f2d3707a68517b18c198a9add1'
}
"replicate_api_token": "invalid_key",
"model_version": "9cf9f015a9cb9c61d1a2610659cdac4a4ca222f2d3707a68517b18c198a9add1",
},
)
model.validate_credentials(
model='nateraw/bge-large-en-v1.5',
model="nateraw/bge-large-en-v1.5",
credentials={
'replicate_api_token': os.environ.get('REPLICATE_API_KEY'),
'model_version': '9cf9f015a9cb9c61d1a2610659cdac4a4ca222f2d3707a68517b18c198a9add1'
}
"replicate_api_token": os.environ.get("REPLICATE_API_KEY"),
"model_version": "9cf9f015a9cb9c61d1a2610659cdac4a4ca222f2d3707a68517b18c198a9add1",
},
)
@ -53,16 +53,13 @@ def test_invoke_model_one():
model = ReplicateEmbeddingModel()
result = model.invoke(
model='nateraw/bge-large-en-v1.5',
model="nateraw/bge-large-en-v1.5",
credentials={
'replicate_api_token': os.environ.get('REPLICATE_API_KEY'),
'model_version': '9cf9f015a9cb9c61d1a2610659cdac4a4ca222f2d3707a68517b18c198a9add1'
"replicate_api_token": os.environ.get("REPLICATE_API_KEY"),
"model_version": "9cf9f015a9cb9c61d1a2610659cdac4a4ca222f2d3707a68517b18c198a9add1",
},
texts=[
"hello",
"world"
],
user="abc-123"
texts=["hello", "world"],
user="abc-123",
)
assert isinstance(result, TextEmbeddingResult)
@ -74,16 +71,13 @@ def test_invoke_model_two():
model = ReplicateEmbeddingModel()
result = model.invoke(
model='andreasjansson/clip-features',
model="andreasjansson/clip-features",
credentials={
'replicate_api_token': os.environ.get('REPLICATE_API_KEY'),
'model_version': '75b33f253f7714a281ad3e9b28f63e3232d583716ef6718f2e46641077ea040a'
"replicate_api_token": os.environ.get("REPLICATE_API_KEY"),
"model_version": "75b33f253f7714a281ad3e9b28f63e3232d583716ef6718f2e46641077ea040a",
},
texts=[
"hello",
"world"
],
user="abc-123"
texts=["hello", "world"],
user="abc-123",
)
assert isinstance(result, TextEmbeddingResult)
@ -95,16 +89,13 @@ def test_invoke_model_three():
model = ReplicateEmbeddingModel()
result = model.invoke(
model='replicate/all-mpnet-base-v2',
model="replicate/all-mpnet-base-v2",
credentials={
'replicate_api_token': os.environ.get('REPLICATE_API_KEY'),
'model_version': 'b6b7585c9640cd7a9572c6e129c9549d79c9c31f0d3fdce7baac7c67ca38f305'
"replicate_api_token": os.environ.get("REPLICATE_API_KEY"),
"model_version": "b6b7585c9640cd7a9572c6e129c9549d79c9c31f0d3fdce7baac7c67ca38f305",
},
texts=[
"hello",
"world"
],
user="abc-123"
texts=["hello", "world"],
user="abc-123",
)
assert isinstance(result, TextEmbeddingResult)
@ -116,16 +107,13 @@ def test_invoke_model_four():
model = ReplicateEmbeddingModel()
result = model.invoke(
model='nateraw/jina-embeddings-v2-base-en',
model="nateraw/jina-embeddings-v2-base-en",
credentials={
'replicate_api_token': os.environ.get('REPLICATE_API_KEY'),
'model_version': 'f8367a1c072ba2bc28af549d1faeacfe9b88b3f0e475add7a75091dac507f79e'
"replicate_api_token": os.environ.get("REPLICATE_API_KEY"),
"model_version": "f8367a1c072ba2bc28af549d1faeacfe9b88b3f0e475add7a75091dac507f79e",
},
texts=[
"hello",
"world"
],
user="abc-123"
texts=["hello", "world"],
user="abc-123",
)
assert isinstance(result, TextEmbeddingResult)
@ -137,15 +125,12 @@ def test_get_num_tokens():
model = ReplicateEmbeddingModel()
num_tokens = model.get_num_tokens(
model='nateraw/jina-embeddings-v2-base-en',
model="nateraw/jina-embeddings-v2-base-en",
credentials={
'replicate_api_token': os.environ.get('REPLICATE_API_KEY'),
'model_version': 'f8367a1c072ba2bc28af549d1faeacfe9b88b3f0e475add7a75091dac507f79e'
"replicate_api_token": os.environ.get("REPLICATE_API_KEY"),
"model_version": "f8367a1c072ba2bc28af549d1faeacfe9b88b3f0e475add7a75091dac507f79e",
},
texts=[
"hello",
"world"
]
texts=["hello", "world"],
)
assert num_tokens == 2

View File

@ -10,10 +10,6 @@ def test_validate_provider_credentials():
provider = SageMakerProvider()
with pytest.raises(CredentialsValidateFailedError):
provider.validate_provider_credentials(
credentials={}
)
provider.validate_provider_credentials(credentials={})
provider.validate_provider_credentials(
credentials={}
)
provider.validate_provider_credentials(credentials={})

View File

@ -12,11 +12,11 @@ def test_validate_credentials():
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='bge-m3-rerank-v2',
model="bge-m3-rerank-v2",
credentials={
"aws_region": os.getenv("AWS_REGION"),
"aws_access_key": os.getenv("AWS_ACCESS_KEY"),
"aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY")
"aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"),
},
query="What is the capital of the United States?",
docs=[
@ -25,7 +25,7 @@ def test_validate_credentials():
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
"are a political division controlled by the United States. Its capital is Saipan.",
],
score_threshold=0.8
score_threshold=0.8,
)
@ -33,11 +33,11 @@ def test_invoke_model():
model = SageMakerRerankModel()
result = model.invoke(
model='bge-m3-rerank-v2',
model="bge-m3-rerank-v2",
credentials={
"aws_region": os.getenv("AWS_REGION"),
"aws_access_key": os.getenv("AWS_ACCESS_KEY"),
"aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY")
"aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"),
},
query="What is the capital of the United States?",
docs=[
@ -46,7 +46,7 @@ def test_invoke_model():
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
"are a political division controlled by the United States. Its capital is Saipan.",
],
score_threshold=0.8
score_threshold=0.8,
)
assert isinstance(result, RerankResult)

View File

@ -11,45 +11,23 @@ def test_validate_credentials():
model = SageMakerEmbeddingModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='bge-m3',
credentials={
}
)
model.validate_credentials(model="bge-m3", credentials={})
model.validate_credentials(
model='bge-m3-embedding',
credentials={
}
)
model.validate_credentials(model="bge-m3-embedding", credentials={})
def test_invoke_model():
model = SageMakerEmbeddingModel()
result = model.invoke(
model='bge-m3-embedding',
credentials={
},
texts=[
"hello",
"world"
],
user="abc-123"
)
result = model.invoke(model="bge-m3-embedding", credentials={}, texts=["hello", "world"], user="abc-123")
assert isinstance(result, TextEmbeddingResult)
assert len(result.embeddings) == 2
def test_get_num_tokens():
model = SageMakerEmbeddingModel()
num_tokens = model.get_num_tokens(
model='bge-m3-embedding',
credentials={
},
texts=[
]
)
num_tokens = model.get_num_tokens(model="bge-m3-embedding", credentials={}, texts=[])
assert num_tokens == 0

View File

@ -13,41 +13,22 @@ def test_validate_credentials():
model = SiliconflowLargeLanguageModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='deepseek-ai/DeepSeek-V2-Chat',
credentials={
'api_key': 'invalid_key'
}
)
model.validate_credentials(model="deepseek-ai/DeepSeek-V2-Chat", credentials={"api_key": "invalid_key"})
model.validate_credentials(
model='deepseek-ai/DeepSeek-V2-Chat',
credentials={
'api_key': os.environ.get('API_KEY')
}
)
model.validate_credentials(model="deepseek-ai/DeepSeek-V2-Chat", credentials={"api_key": os.environ.get("API_KEY")})
def test_invoke_model():
model = SiliconflowLargeLanguageModel()
response = model.invoke(
model='deepseek-ai/DeepSeek-V2-Chat',
credentials={
'api_key': os.environ.get('API_KEY')
},
prompt_messages=[
UserPromptMessage(
content='Who are you?'
)
],
model_parameters={
'temperature': 0.5,
'max_tokens': 10
},
stop=['How'],
model="deepseek-ai/DeepSeek-V2-Chat",
credentials={"api_key": os.environ.get("API_KEY")},
prompt_messages=[UserPromptMessage(content="Who are you?")],
model_parameters={"temperature": 0.5, "max_tokens": 10},
stop=["How"],
stream=False,
user="abc-123"
user="abc-123",
)
assert isinstance(response, LLMResult)
@ -58,22 +39,12 @@ def test_invoke_stream_model():
model = SiliconflowLargeLanguageModel()
response = model.invoke(
model='deepseek-ai/DeepSeek-V2-Chat',
credentials={
'api_key': os.environ.get('API_KEY')
},
prompt_messages=[
UserPromptMessage(
content='Hello World!'
)
],
model_parameters={
'temperature': 0.5,
'max_tokens': 100,
'seed': 1234
},
model="deepseek-ai/DeepSeek-V2-Chat",
credentials={"api_key": os.environ.get("API_KEY")},
prompt_messages=[UserPromptMessage(content="Hello World!")],
model_parameters={"temperature": 0.5, "max_tokens": 100, "seed": 1234},
stream=True,
user="abc-123"
user="abc-123",
)
assert isinstance(response, Generator)
@ -89,18 +60,14 @@ def test_get_num_tokens():
model = SiliconflowLargeLanguageModel()
num_tokens = model.get_num_tokens(
model='deepseek-ai/DeepSeek-V2-Chat',
credentials={
'api_key': os.environ.get('API_KEY')
},
model="deepseek-ai/DeepSeek-V2-Chat",
credentials={"api_key": os.environ.get("API_KEY")},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Hello World!'
)
]
UserPromptMessage(content="Hello World!"),
],
)
assert num_tokens == 12

View File

@ -10,12 +10,6 @@ def test_validate_provider_credentials():
provider = SiliconflowProvider()
with pytest.raises(CredentialsValidateFailedError):
provider.validate_provider_credentials(
credentials={}
)
provider.validate_provider_credentials(credentials={})
provider.validate_provider_credentials(
credentials={
'api_key': os.environ.get('API_KEY')
}
)
provider.validate_provider_credentials(credentials={"api_key": os.environ.get("API_KEY")})

View File

@ -13,9 +13,7 @@ def test_validate_credentials():
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model="BAAI/bge-reranker-v2-m3",
credentials={
"api_key": "invalid_key"
},
credentials={"api_key": "invalid_key"},
)
model.validate_credentials(
@ -30,17 +28,17 @@ def test_invoke_model():
model = SiliconflowRerankModel()
result = model.invoke(
model='BAAI/bge-reranker-v2-m3',
model="BAAI/bge-reranker-v2-m3",
credentials={
"api_key": os.environ.get("API_KEY"),
},
query="Who is Kasumi?",
docs=[
"Kasumi is a girl's name of Japanese origin meaning \"mist\".",
'Kasumi is a girl\'s name of Japanese origin meaning "mist".',
"Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ",
"and she leads a team named PopiParty."
"and she leads a team named PopiParty.",
],
score_threshold=0.8
score_threshold=0.8,
)
assert isinstance(result, RerankResult)

View File

@ -12,16 +12,12 @@ def test_validate_credentials():
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model="iic/SenseVoiceSmall",
credentials={
"api_key": "invalid_key"
},
credentials={"api_key": "invalid_key"},
)
model.validate_credentials(
model="iic/SenseVoiceSmall",
credentials={
"api_key": os.environ.get("API_KEY")
},
credentials={"api_key": os.environ.get("API_KEY")},
)
@ -42,12 +38,8 @@ def test_invoke_model():
file = audio_file
result = model.invoke(
model="iic/SenseVoiceSmall",
credentials={
"api_key": os.environ.get("API_KEY")
},
file=file
model="iic/SenseVoiceSmall", credentials={"api_key": os.environ.get("API_KEY")}, file=file
)
assert isinstance(result, str)
assert result == '1,2,3,4,5,6,7,8,9,10.'
assert result == "1,2,3,4,5,6,7,8,9,10."

View File

@ -15,9 +15,7 @@ def test_validate_credentials():
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model="BAAI/bge-large-zh-v1.5",
credentials={
"api_key": "invalid_key"
},
credentials={"api_key": "invalid_key"},
)
model.validate_credentials(

View File

@ -13,20 +13,15 @@ def test_validate_credentials():
model = SparkLargeLanguageModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='spark-1.5',
credentials={
'app_id': 'invalid_key'
}
)
model.validate_credentials(model="spark-1.5", credentials={"app_id": "invalid_key"})
model.validate_credentials(
model='spark-1.5',
model="spark-1.5",
credentials={
'app_id': os.environ.get('SPARK_APP_ID'),
'api_secret': os.environ.get('SPARK_API_SECRET'),
'api_key': os.environ.get('SPARK_API_KEY')
}
"app_id": os.environ.get("SPARK_APP_ID"),
"api_secret": os.environ.get("SPARK_API_SECRET"),
"api_key": os.environ.get("SPARK_API_KEY"),
},
)
@ -34,24 +29,17 @@ def test_invoke_model():
model = SparkLargeLanguageModel()
response = model.invoke(
model='spark-1.5',
model="spark-1.5",
credentials={
'app_id': os.environ.get('SPARK_APP_ID'),
'api_secret': os.environ.get('SPARK_API_SECRET'),
'api_key': os.environ.get('SPARK_API_KEY')
"app_id": os.environ.get("SPARK_APP_ID"),
"api_secret": os.environ.get("SPARK_API_SECRET"),
"api_key": os.environ.get("SPARK_API_KEY"),
},
prompt_messages=[
UserPromptMessage(
content='Who are you?'
)
],
model_parameters={
'temperature': 0.5,
'max_tokens': 10
},
stop=['How'],
prompt_messages=[UserPromptMessage(content="Who are you?")],
model_parameters={"temperature": 0.5, "max_tokens": 10},
stop=["How"],
stream=False,
user="abc-123"
user="abc-123",
)
assert isinstance(response, LLMResult)
@ -62,23 +50,16 @@ def test_invoke_stream_model():
model = SparkLargeLanguageModel()
response = model.invoke(
model='spark-1.5',
model="spark-1.5",
credentials={
'app_id': os.environ.get('SPARK_APP_ID'),
'api_secret': os.environ.get('SPARK_API_SECRET'),
'api_key': os.environ.get('SPARK_API_KEY')
},
prompt_messages=[
UserPromptMessage(
content='Hello World!'
)
],
model_parameters={
'temperature': 0.5,
'max_tokens': 100
"app_id": os.environ.get("SPARK_APP_ID"),
"api_secret": os.environ.get("SPARK_API_SECRET"),
"api_key": os.environ.get("SPARK_API_KEY"),
},
prompt_messages=[UserPromptMessage(content="Hello World!")],
model_parameters={"temperature": 0.5, "max_tokens": 100},
stream=True,
user="abc-123"
user="abc-123",
)
assert isinstance(response, Generator)
@ -94,20 +75,18 @@ def test_get_num_tokens():
model = SparkLargeLanguageModel()
num_tokens = model.get_num_tokens(
model='spark-1.5',
model="spark-1.5",
credentials={
'app_id': os.environ.get('SPARK_APP_ID'),
'api_secret': os.environ.get('SPARK_API_SECRET'),
'api_key': os.environ.get('SPARK_API_KEY')
"app_id": os.environ.get("SPARK_APP_ID"),
"api_secret": os.environ.get("SPARK_API_SECRET"),
"api_key": os.environ.get("SPARK_API_KEY"),
},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Hello World!'
)
]
UserPromptMessage(content="Hello World!"),
],
)
assert num_tokens == 14

View File

@ -10,14 +10,12 @@ def test_validate_provider_credentials():
provider = SparkProvider()
with pytest.raises(CredentialsValidateFailedError):
provider.validate_provider_credentials(
credentials={}
)
provider.validate_provider_credentials(credentials={})
provider.validate_provider_credentials(
credentials={
'app_id': os.environ.get('SPARK_APP_ID'),
'api_secret': os.environ.get('SPARK_API_SECRET'),
'api_key': os.environ.get('SPARK_API_KEY')
"app_id": os.environ.get("SPARK_APP_ID"),
"api_secret": os.environ.get("SPARK_API_SECRET"),
"api_key": os.environ.get("SPARK_API_KEY"),
}
)

View File

@ -21,40 +21,22 @@ def test_validate_credentials():
model = StepfunLargeLanguageModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='step-1-8k',
credentials={
'api_key': 'invalid_key'
}
)
model.validate_credentials(model="step-1-8k", credentials={"api_key": "invalid_key"})
model.validate_credentials(model="step-1-8k", credentials={"api_key": os.environ.get("STEPFUN_API_KEY")})
model.validate_credentials(
model='step-1-8k',
credentials={
'api_key': os.environ.get('STEPFUN_API_KEY')
}
)
def test_invoke_model():
model = StepfunLargeLanguageModel()
response = model.invoke(
model='step-1-8k',
credentials={
'api_key': os.environ.get('STEPFUN_API_KEY')
},
prompt_messages=[
UserPromptMessage(
content='Hello World!'
)
],
model_parameters={
'temperature': 0.9,
'top_p': 0.7
},
stop=['Hi'],
model="step-1-8k",
credentials={"api_key": os.environ.get("STEPFUN_API_KEY")},
prompt_messages=[UserPromptMessage(content="Hello World!")],
model_parameters={"temperature": 0.9, "top_p": 0.7},
stop=["Hi"],
stream=False,
user="abc-123"
user="abc-123",
)
assert isinstance(response, LLMResult)
@ -65,24 +47,17 @@ def test_invoke_stream_model():
model = StepfunLargeLanguageModel()
response = model.invoke(
model='step-1-8k',
credentials={
'api_key': os.environ.get('STEPFUN_API_KEY')
},
model="step-1-8k",
credentials={"api_key": os.environ.get("STEPFUN_API_KEY")},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Hello World!'
)
UserPromptMessage(content="Hello World!"),
],
model_parameters={
'temperature': 0.9,
'top_p': 0.7
},
model_parameters={"temperature": 0.9, "top_p": 0.7},
stream=True,
user="abc-123"
user="abc-123",
)
assert isinstance(response, Generator)
@ -98,10 +73,7 @@ def test_get_customizable_model_schema():
model = StepfunLargeLanguageModel()
schema = model.get_customizable_model_schema(
model='step-1-8k',
credentials={
'api_key': os.environ.get('STEPFUN_API_KEY')
}
model="step-1-8k", credentials={"api_key": os.environ.get("STEPFUN_API_KEY")}
)
assert isinstance(schema, AIModelEntity)
@ -110,67 +82,44 @@ def test_invoke_chat_model_with_tools():
model = StepfunLargeLanguageModel()
result = model.invoke(
model='step-1-8k',
credentials={
'api_key': os.environ.get('STEPFUN_API_KEY')
},
model="step-1-8k",
credentials={"api_key": os.environ.get("STEPFUN_API_KEY")},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content="what's the weather today in Shanghai?",
)
),
],
model_parameters={
'temperature': 0.9,
'max_tokens': 100
},
model_parameters={"temperature": 0.9, "max_tokens": 100},
tools=[
PromptMessageTool(
name='get_weather',
description='Determine weather in my location',
name="get_weather",
description="Determine weather in my location",
parameters={
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state e.g. San Francisco, CA"
},
"unit": {
"type": "string",
"enum": [
"c",
"f"
]
}
"location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
"unit": {"type": "string", "enum": ["c", "f"]},
},
"required": [
"location"
]
}
"required": ["location"],
},
),
PromptMessageTool(
name='get_stock_price',
description='Get the current stock price',
name="get_stock_price",
description="Get the current stock price",
parameters={
"type": "object",
"properties": {
"symbol": {
"type": "string",
"description": "The stock symbol"
}
},
"required": [
"symbol"
]
}
)
"properties": {"symbol": {"type": "string", "description": "The stock symbol"}},
"required": ["symbol"],
},
),
],
stream=False,
user="abc-123"
user="abc-123",
)
assert isinstance(result, LLMResult)
assert isinstance(result.message, AssistantPromptMessage)
assert len(result.message.tool_calls) > 0
assert len(result.message.tool_calls) > 0

View File

@ -24,13 +24,8 @@ def test_get_models():
providers = factory.get_models(
model_type=ModelType.LLM,
provider_configs=[
ProviderConfig(
provider='openai',
credentials={
'openai_api_key': os.environ.get('OPENAI_API_KEY')
}
)
]
ProviderConfig(provider="openai", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")})
],
)
logger.debug(providers)
@ -44,29 +39,21 @@ def test_get_models():
assert provider_model.model_type == ModelType.LLM
providers = factory.get_models(
provider='openai',
provider="openai",
provider_configs=[
ProviderConfig(
provider='openai',
credentials={
'openai_api_key': os.environ.get('OPENAI_API_KEY')
}
)
]
ProviderConfig(provider="openai", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")})
],
)
assert len(providers) == 1
assert isinstance(providers[0], SimpleProviderEntity)
assert providers[0].provider == 'openai'
assert providers[0].provider == "openai"
def test_provider_credentials_validate():
factory = ModelProviderFactory()
factory.provider_credentials_validate(
provider='openai',
credentials={
'openai_api_key': os.environ.get('OPENAI_API_KEY')
}
provider="openai", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}
)
@ -79,4 +66,4 @@ def test__get_model_provider_map():
logger.debug(model_provider.provider_instance)
assert len(model_providers) >= 1
assert isinstance(model_providers['openai'], ModelProviderExtension)
assert isinstance(model_providers["openai"], ModelProviderExtension)

View File

@ -19,76 +19,61 @@ def test_validate_credentials():
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='mistralai/Mixtral-8x7B-Instruct-v0.1',
credentials={
'api_key': 'invalid_key',
'mode': 'chat'
}
model="mistralai/Mixtral-8x7B-Instruct-v0.1", credentials={"api_key": "invalid_key", "mode": "chat"}
)
model.validate_credentials(
model='mistralai/Mixtral-8x7B-Instruct-v0.1',
credentials={
'api_key': os.environ.get('TOGETHER_API_KEY'),
'mode': 'chat'
}
model="mistralai/Mixtral-8x7B-Instruct-v0.1",
credentials={"api_key": os.environ.get("TOGETHER_API_KEY"), "mode": "chat"},
)
def test_invoke_model():
model = TogetherAILargeLanguageModel()
response = model.invoke(
model='mistralai/Mixtral-8x7B-Instruct-v0.1',
credentials={
'api_key': os.environ.get('TOGETHER_API_KEY'),
'mode': 'completion'
},
model="mistralai/Mixtral-8x7B-Instruct-v0.1",
credentials={"api_key": os.environ.get("TOGETHER_API_KEY"), "mode": "completion"},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Who are you?'
)
UserPromptMessage(content="Who are you?"),
],
model_parameters={
'temperature': 1.0,
'top_k': 2,
'top_p': 0.5,
"temperature": 1.0,
"top_k": 2,
"top_p": 0.5,
},
stop=['How'],
stop=["How"],
stream=False,
user="abc-123"
user="abc-123",
)
assert isinstance(response, LLMResult)
assert len(response.message.content) > 0
def test_invoke_stream_model():
model = TogetherAILargeLanguageModel()
response = model.invoke(
model='mistralai/Mixtral-8x7B-Instruct-v0.1',
credentials={
'api_key': os.environ.get('TOGETHER_API_KEY'),
'mode': 'chat'
},
model="mistralai/Mixtral-8x7B-Instruct-v0.1",
credentials={"api_key": os.environ.get("TOGETHER_API_KEY"), "mode": "chat"},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Who are you?'
)
UserPromptMessage(content="Who are you?"),
],
model_parameters={
'temperature': 1.0,
'top_k': 2,
'top_p': 0.5,
"temperature": 1.0,
"top_k": 2,
"top_p": 0.5,
},
stop=['How'],
stop=["How"],
stream=True,
user="abc-123"
user="abc-123",
)
assert isinstance(response, Generator)
@ -98,22 +83,21 @@ def test_invoke_stream_model():
assert isinstance(chunk.delta, LLMResultChunkDelta)
assert isinstance(chunk.delta.message, AssistantPromptMessage)
def test_get_num_tokens():
model = TogetherAILargeLanguageModel()
num_tokens = model.get_num_tokens(
model='mistralai/Mixtral-8x7B-Instruct-v0.1',
model="mistralai/Mixtral-8x7B-Instruct-v0.1",
credentials={
'api_key': os.environ.get('TOGETHER_API_KEY'),
"api_key": os.environ.get("TOGETHER_API_KEY"),
},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Hello World!'
)
]
UserPromptMessage(content="Hello World!"),
],
)
assert isinstance(num_tokens, int)

View File

@ -13,18 +13,10 @@ def test_validate_credentials():
model = TongyiLargeLanguageModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='qwen-turbo',
credentials={
'dashscope_api_key': 'invalid_key'
}
)
model.validate_credentials(model="qwen-turbo", credentials={"dashscope_api_key": "invalid_key"})
model.validate_credentials(
model='qwen-turbo',
credentials={
'dashscope_api_key': os.environ.get('TONGYI_DASHSCOPE_API_KEY')
}
model="qwen-turbo", credentials={"dashscope_api_key": os.environ.get("TONGYI_DASHSCOPE_API_KEY")}
)
@ -32,22 +24,13 @@ def test_invoke_model():
model = TongyiLargeLanguageModel()
response = model.invoke(
model='qwen-turbo',
credentials={
'dashscope_api_key': os.environ.get('TONGYI_DASHSCOPE_API_KEY')
},
prompt_messages=[
UserPromptMessage(
content='Who are you?'
)
],
model_parameters={
'temperature': 0.5,
'max_tokens': 10
},
stop=['How'],
model="qwen-turbo",
credentials={"dashscope_api_key": os.environ.get("TONGYI_DASHSCOPE_API_KEY")},
prompt_messages=[UserPromptMessage(content="Who are you?")],
model_parameters={"temperature": 0.5, "max_tokens": 10},
stop=["How"],
stream=False,
user="abc-123"
user="abc-123",
)
assert isinstance(response, LLMResult)
@ -58,22 +41,12 @@ def test_invoke_stream_model():
model = TongyiLargeLanguageModel()
response = model.invoke(
model='qwen-turbo',
credentials={
'dashscope_api_key': os.environ.get('TONGYI_DASHSCOPE_API_KEY')
},
prompt_messages=[
UserPromptMessage(
content='Hello World!'
)
],
model_parameters={
'temperature': 0.5,
'max_tokens': 100,
'seed': 1234
},
model="qwen-turbo",
credentials={"dashscope_api_key": os.environ.get("TONGYI_DASHSCOPE_API_KEY")},
prompt_messages=[UserPromptMessage(content="Hello World!")],
model_parameters={"temperature": 0.5, "max_tokens": 100, "seed": 1234},
stream=True,
user="abc-123"
user="abc-123",
)
assert isinstance(response, Generator)
@ -89,18 +62,14 @@ def test_get_num_tokens():
model = TongyiLargeLanguageModel()
num_tokens = model.get_num_tokens(
model='qwen-turbo',
credentials={
'dashscope_api_key': os.environ.get('TONGYI_DASHSCOPE_API_KEY')
},
model="qwen-turbo",
credentials={"dashscope_api_key": os.environ.get("TONGYI_DASHSCOPE_API_KEY")},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Hello World!'
)
]
UserPromptMessage(content="Hello World!"),
],
)
assert num_tokens == 12

View File

@ -10,12 +10,8 @@ def test_validate_provider_credentials():
provider = TongyiProvider()
with pytest.raises(CredentialsValidateFailedError):
provider.validate_provider_credentials(
credentials={}
)
provider.validate_provider_credentials(credentials={})
provider.validate_provider_credentials(
credentials={
'dashscope_api_key': os.environ.get('TONGYI_DASHSCOPE_API_KEY')
}
credentials={"dashscope_api_key": os.environ.get("TONGYI_DASHSCOPE_API_KEY")}
)

View File

@ -39,21 +39,17 @@ def invoke_model_with_json_response(model_name="qwen-max-0403"):
response = model.invoke(
model=model_name,
credentials={
'dashscope_api_key': os.environ.get('TONGYI_DASHSCOPE_API_KEY')
},
credentials={"dashscope_api_key": os.environ.get("TONGYI_DASHSCOPE_API_KEY")},
prompt_messages=[
UserPromptMessage(
content='output json data with format `{"data": "test", "code": 200, "msg": "success"}'
)
UserPromptMessage(content='output json data with format `{"data": "test", "code": 200, "msg": "success"}')
],
model_parameters={
'temperature': 0.5,
'max_tokens': 50,
'response_format': 'JSON',
"temperature": 0.5,
"max_tokens": 50,
"response_format": "JSON",
},
stream=True,
user="abc-123"
user="abc-123",
)
print("=====================================")
print(response)
@ -81,4 +77,4 @@ def is_json(s):
json.loads(s)
except ValueError:
return False
return True
return True

View File

@ -26,151 +26,113 @@ def test_predefined_models():
assert len(model_schemas) >= 1
assert isinstance(model_schemas[0], AIModelEntity)
@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
def test_validate_credentials_for_chat_model(setup_openai_mock):
model = UpstageLargeLanguageModel()
with pytest.raises(CredentialsValidateFailedError):
# model name to gpt-3.5-turbo because of mocking
model.validate_credentials(
model='gpt-3.5-turbo',
credentials={
'upstage_api_key': 'invalid_key'
}
)
model.validate_credentials(model="gpt-3.5-turbo", credentials={"upstage_api_key": "invalid_key"})
model.validate_credentials(
model='solar-1-mini-chat',
credentials={
'upstage_api_key': os.environ.get('UPSTAGE_API_KEY')
}
model="solar-1-mini-chat", credentials={"upstage_api_key": os.environ.get("UPSTAGE_API_KEY")}
)
@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
def test_invoke_chat_model(setup_openai_mock):
model = UpstageLargeLanguageModel()
result = model.invoke(
model='solar-1-mini-chat',
credentials={
'upstage_api_key': os.environ.get('UPSTAGE_API_KEY')
},
model="solar-1-mini-chat",
credentials={"upstage_api_key": os.environ.get("UPSTAGE_API_KEY")},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Hello World!'
)
UserPromptMessage(content="Hello World!"),
],
model_parameters={
'temperature': 0.0,
'top_p': 1.0,
'presence_penalty': 0.0,
'frequency_penalty': 0.0,
'max_tokens': 10
"temperature": 0.0,
"top_p": 1.0,
"presence_penalty": 0.0,
"frequency_penalty": 0.0,
"max_tokens": 10,
},
stop=['How'],
stop=["How"],
stream=False,
user="abc-123"
user="abc-123",
)
assert isinstance(result, LLMResult)
assert len(result.message.content) > 0
@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
def test_invoke_chat_model_with_tools(setup_openai_mock):
model = UpstageLargeLanguageModel()
result = model.invoke(
model='solar-1-mini-chat',
credentials={
'upstage_api_key': os.environ.get('UPSTAGE_API_KEY')
},
model="solar-1-mini-chat",
credentials={"upstage_api_key": os.environ.get("UPSTAGE_API_KEY")},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content="what's the weather today in London?",
)
),
],
model_parameters={
'temperature': 0.0,
'max_tokens': 100
},
model_parameters={"temperature": 0.0, "max_tokens": 100},
tools=[
PromptMessageTool(
name='get_weather',
description='Determine weather in my location',
name="get_weather",
description="Determine weather in my location",
parameters={
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state e.g. San Francisco, CA"
},
"unit": {
"type": "string",
"enum": [
"c",
"f"
]
}
"location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
"unit": {"type": "string", "enum": ["c", "f"]},
},
"required": [
"location"
]
}
"required": ["location"],
},
),
PromptMessageTool(
name='get_stock_price',
description='Get the current stock price',
name="get_stock_price",
description="Get the current stock price",
parameters={
"type": "object",
"properties": {
"symbol": {
"type": "string",
"description": "The stock symbol"
}
},
"required": [
"symbol"
]
}
)
"properties": {"symbol": {"type": "string", "description": "The stock symbol"}},
"required": ["symbol"],
},
),
],
stream=False,
user="abc-123"
user="abc-123",
)
assert isinstance(result, LLMResult)
assert isinstance(result.message, AssistantPromptMessage)
assert len(result.message.tool_calls) > 0
@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
def test_invoke_stream_chat_model(setup_openai_mock):
model = UpstageLargeLanguageModel()
result = model.invoke(
model='solar-1-mini-chat',
credentials={
'upstage_api_key': os.environ.get('UPSTAGE_API_KEY')
},
model="solar-1-mini-chat",
credentials={"upstage_api_key": os.environ.get("UPSTAGE_API_KEY")},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Hello World!'
)
UserPromptMessage(content="Hello World!"),
],
model_parameters={
'temperature': 0.0,
'max_tokens': 100
},
model_parameters={"temperature": 0.0, "max_tokens": 100},
stream=True,
user="abc-123"
user="abc-123",
)
assert isinstance(result, Generator)
@ -189,57 +151,36 @@ def test_get_num_tokens():
model = UpstageLargeLanguageModel()
num_tokens = model.get_num_tokens(
model='solar-1-mini-chat',
credentials={
'upstage_api_key': os.environ.get('UPSTAGE_API_KEY')
},
prompt_messages=[
UserPromptMessage(
content='Hello World!'
)
]
model="solar-1-mini-chat",
credentials={"upstage_api_key": os.environ.get("UPSTAGE_API_KEY")},
prompt_messages=[UserPromptMessage(content="Hello World!")],
)
assert num_tokens == 13
num_tokens = model.get_num_tokens(
model='solar-1-mini-chat',
credentials={
'upstage_api_key': os.environ.get('UPSTAGE_API_KEY')
},
model="solar-1-mini-chat",
credentials={"upstage_api_key": os.environ.get("UPSTAGE_API_KEY")},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Hello World!'
)
UserPromptMessage(content="Hello World!"),
],
tools=[
PromptMessageTool(
name='get_weather',
description='Determine weather in my location',
name="get_weather",
description="Determine weather in my location",
parameters={
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state e.g. San Francisco, CA"
},
"unit": {
"type": "string",
"enum": [
"c",
"f"
]
}
"location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
"unit": {"type": "string", "enum": ["c", "f"]},
},
"required": [
"location"
]
}
"required": ["location"],
},
),
]
],
)
assert num_tokens == 106

View File

@ -7,17 +7,11 @@ from core.model_runtime.model_providers.upstage.upstage import UpstageProvider
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
def test_validate_provider_credentials(setup_openai_mock):
provider = UpstageProvider()
with pytest.raises(CredentialsValidateFailedError):
provider.validate_provider_credentials(
credentials={}
)
provider.validate_provider_credentials(credentials={})
provider.validate_provider_credentials(
credentials={
'upstage_api_key': os.environ.get('UPSTAGE_API_KEY')
}
)
provider.validate_provider_credentials(credentials={"upstage_api_key": os.environ.get("UPSTAGE_API_KEY")})

View File

@ -8,41 +8,31 @@ from core.model_runtime.model_providers.upstage.text_embedding.text_embedding im
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
@pytest.mark.parametrize('setup_openai_mock', [['text_embedding']], indirect=True)
@pytest.mark.parametrize("setup_openai_mock", [["text_embedding"]], indirect=True)
def test_validate_credentials(setup_openai_mock):
model = UpstageTextEmbeddingModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='solar-embedding-1-large-passage',
credentials={
'upstage_api_key': 'invalid_key'
}
model="solar-embedding-1-large-passage", credentials={"upstage_api_key": "invalid_key"}
)
model.validate_credentials(
model='solar-embedding-1-large-passage',
credentials={
'upstage_api_key': os.environ.get('UPSTAGE_API_KEY')
}
model="solar-embedding-1-large-passage", credentials={"upstage_api_key": os.environ.get("UPSTAGE_API_KEY")}
)
@pytest.mark.parametrize('setup_openai_mock', [['text_embedding']], indirect=True)
@pytest.mark.parametrize("setup_openai_mock", [["text_embedding"]], indirect=True)
def test_invoke_model(setup_openai_mock):
model = UpstageTextEmbeddingModel()
result = model.invoke(
model='solar-embedding-1-large-passage',
model="solar-embedding-1-large-passage",
credentials={
'upstage_api_key': os.environ.get('UPSTAGE_API_KEY'),
"upstage_api_key": os.environ.get("UPSTAGE_API_KEY"),
},
texts=[
"hello",
"world",
" ".join(["long_text"] * 100),
" ".join(["another_long_text"] * 100)
],
user="abc-123"
texts=["hello", "world", " ".join(["long_text"] * 100), " ".join(["another_long_text"] * 100)],
user="abc-123",
)
assert isinstance(result, TextEmbeddingResult)
@ -54,14 +44,11 @@ def test_get_num_tokens():
model = UpstageTextEmbeddingModel()
num_tokens = model.get_num_tokens(
model='solar-embedding-1-large-passage',
model="solar-embedding-1-large-passage",
credentials={
'upstage_api_key': os.environ.get('UPSTAGE_API_KEY'),
"upstage_api_key": os.environ.get("UPSTAGE_API_KEY"),
},
texts=[
"hello",
"world"
]
texts=["hello", "world"],
)
assert num_tokens == 5

View File

@ -14,26 +14,26 @@ def test_validate_credentials():
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='NOT IMPORTANT',
model="NOT IMPORTANT",
credentials={
'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com',
'volc_region': 'cn-beijing',
'volc_access_key_id': 'INVALID',
'volc_secret_access_key': 'INVALID',
'endpoint_id': 'INVALID',
'base_model_name': 'Doubao-embedding',
}
"api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com",
"volc_region": "cn-beijing",
"volc_access_key_id": "INVALID",
"volc_secret_access_key": "INVALID",
"endpoint_id": "INVALID",
"base_model_name": "Doubao-embedding",
},
)
model.validate_credentials(
model='NOT IMPORTANT',
model="NOT IMPORTANT",
credentials={
'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com',
'volc_region': 'cn-beijing',
'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
'endpoint_id': os.environ.get('VOLC_EMBEDDING_ENDPOINT_ID'),
'base_model_name': 'Doubao-embedding',
"api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com",
"volc_region": "cn-beijing",
"volc_access_key_id": os.environ.get("VOLC_API_KEY"),
"volc_secret_access_key": os.environ.get("VOLC_SECRET_KEY"),
"endpoint_id": os.environ.get("VOLC_EMBEDDING_ENDPOINT_ID"),
"base_model_name": "Doubao-embedding",
},
)
@ -42,20 +42,17 @@ def test_invoke_model():
model = VolcengineMaaSTextEmbeddingModel()
result = model.invoke(
model='NOT IMPORTANT',
model="NOT IMPORTANT",
credentials={
'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com',
'volc_region': 'cn-beijing',
'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
'endpoint_id': os.environ.get('VOLC_EMBEDDING_ENDPOINT_ID'),
'base_model_name': 'Doubao-embedding',
"api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com",
"volc_region": "cn-beijing",
"volc_access_key_id": os.environ.get("VOLC_API_KEY"),
"volc_secret_access_key": os.environ.get("VOLC_SECRET_KEY"),
"endpoint_id": os.environ.get("VOLC_EMBEDDING_ENDPOINT_ID"),
"base_model_name": "Doubao-embedding",
},
texts=[
"hello",
"world"
],
user="abc-123"
texts=["hello", "world"],
user="abc-123",
)
assert isinstance(result, TextEmbeddingResult)
@ -67,19 +64,16 @@ def test_get_num_tokens():
model = VolcengineMaaSTextEmbeddingModel()
num_tokens = model.get_num_tokens(
model='NOT IMPORTANT',
model="NOT IMPORTANT",
credentials={
'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com',
'volc_region': 'cn-beijing',
'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
'endpoint_id': os.environ.get('VOLC_EMBEDDING_ENDPOINT_ID'),
'base_model_name': 'Doubao-embedding',
"api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com",
"volc_region": "cn-beijing",
"volc_access_key_id": os.environ.get("VOLC_API_KEY"),
"volc_secret_access_key": os.environ.get("VOLC_SECRET_KEY"),
"endpoint_id": os.environ.get("VOLC_EMBEDDING_ENDPOINT_ID"),
"base_model_name": "Doubao-embedding",
},
texts=[
"hello",
"world"
]
texts=["hello", "world"],
)
assert num_tokens == 2

View File

@ -14,25 +14,25 @@ def test_validate_credentials_for_chat_model():
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='NOT IMPORTANT',
model="NOT IMPORTANT",
credentials={
'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com',
'volc_region': 'cn-beijing',
'volc_access_key_id': 'INVALID',
'volc_secret_access_key': 'INVALID',
'endpoint_id': 'INVALID',
}
"api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com",
"volc_region": "cn-beijing",
"volc_access_key_id": "INVALID",
"volc_secret_access_key": "INVALID",
"endpoint_id": "INVALID",
},
)
model.validate_credentials(
model='NOT IMPORTANT',
model="NOT IMPORTANT",
credentials={
'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com',
'volc_region': 'cn-beijing',
'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
'endpoint_id': os.environ.get('VOLC_MODEL_ENDPOINT_ID'),
}
"api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com",
"volc_region": "cn-beijing",
"volc_access_key_id": os.environ.get("VOLC_API_KEY"),
"volc_secret_access_key": os.environ.get("VOLC_SECRET_KEY"),
"endpoint_id": os.environ.get("VOLC_MODEL_ENDPOINT_ID"),
},
)
@ -40,28 +40,24 @@ def test_invoke_model():
model = VolcengineMaaSLargeLanguageModel()
response = model.invoke(
model='NOT IMPORTANT',
model="NOT IMPORTANT",
credentials={
'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com',
'volc_region': 'cn-beijing',
'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
'endpoint_id': os.environ.get('VOLC_MODEL_ENDPOINT_ID'),
'base_model_name': 'Skylark2-pro-4k',
"api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com",
"volc_region": "cn-beijing",
"volc_access_key_id": os.environ.get("VOLC_API_KEY"),
"volc_secret_access_key": os.environ.get("VOLC_SECRET_KEY"),
"endpoint_id": os.environ.get("VOLC_MODEL_ENDPOINT_ID"),
"base_model_name": "Skylark2-pro-4k",
},
prompt_messages=[
UserPromptMessage(
content='Hello World!'
)
],
prompt_messages=[UserPromptMessage(content="Hello World!")],
model_parameters={
'temperature': 0.7,
'top_p': 1.0,
'top_k': 1,
"temperature": 0.7,
"top_p": 1.0,
"top_k": 1,
},
stop=['you'],
stop=["you"],
user="abc-123",
stream=False
stream=False,
)
assert isinstance(response, LLMResult)
@ -73,28 +69,24 @@ def test_invoke_stream_model():
model = VolcengineMaaSLargeLanguageModel()
response = model.invoke(
model='NOT IMPORTANT',
model="NOT IMPORTANT",
credentials={
'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com',
'volc_region': 'cn-beijing',
'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
'endpoint_id': os.environ.get('VOLC_MODEL_ENDPOINT_ID'),
'base_model_name': 'Skylark2-pro-4k',
"api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com",
"volc_region": "cn-beijing",
"volc_access_key_id": os.environ.get("VOLC_API_KEY"),
"volc_secret_access_key": os.environ.get("VOLC_SECRET_KEY"),
"endpoint_id": os.environ.get("VOLC_MODEL_ENDPOINT_ID"),
"base_model_name": "Skylark2-pro-4k",
},
prompt_messages=[
UserPromptMessage(
content='Hello World!'
)
],
prompt_messages=[UserPromptMessage(content="Hello World!")],
model_parameters={
'temperature': 0.7,
'top_p': 1.0,
'top_k': 1,
"temperature": 0.7,
"top_p": 1.0,
"top_k": 1,
},
stop=['you'],
stop=["you"],
stream=True,
user="abc-123"
user="abc-123",
)
assert isinstance(response, Generator)
@ -102,29 +94,24 @@ def test_invoke_stream_model():
assert isinstance(chunk, LLMResultChunk)
assert isinstance(chunk.delta, LLMResultChunkDelta)
assert isinstance(chunk.delta.message, AssistantPromptMessage)
assert len(
chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
def test_get_num_tokens():
model = VolcengineMaaSLargeLanguageModel()
response = model.get_num_tokens(
model='NOT IMPORTANT',
model="NOT IMPORTANT",
credentials={
'api_endpoint_host': 'maas-api.ml-platform-cn-beijing.volces.com',
'volc_region': 'cn-beijing',
'volc_access_key_id': os.environ.get('VOLC_API_KEY'),
'volc_secret_access_key': os.environ.get('VOLC_SECRET_KEY'),
'endpoint_id': os.environ.get('VOLC_MODEL_ENDPOINT_ID'),
'base_model_name': 'Skylark2-pro-4k',
"api_endpoint_host": "maas-api.ml-platform-cn-beijing.volces.com",
"volc_region": "cn-beijing",
"volc_access_key_id": os.environ.get("VOLC_API_KEY"),
"volc_secret_access_key": os.environ.get("VOLC_SECRET_KEY"),
"endpoint_id": os.environ.get("VOLC_MODEL_ENDPOINT_ID"),
"base_model_name": "Skylark2-pro-4k",
},
prompt_messages=[
UserPromptMessage(
content='Hello World!'
)
],
tools=[]
prompt_messages=[UserPromptMessage(content="Hello World!")],
tools=[],
)
assert isinstance(response, int)

View File

@ -10,13 +10,10 @@ def test_invoke_embedding_v1():
model = WenxinTextEmbeddingModel()
response = model.invoke(
model='embedding-v1',
credentials={
'api_key': os.environ.get('WENXIN_API_KEY'),
'secret_key': os.environ.get('WENXIN_SECRET_KEY')
},
texts=['hello', '你好', 'xxxxx'],
user="abc-123"
model="embedding-v1",
credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")},
texts=["hello", "你好", "xxxxx"],
user="abc-123",
)
assert isinstance(response, TextEmbeddingResult)
@ -29,13 +26,10 @@ def test_invoke_embedding_bge_large_en():
model = WenxinTextEmbeddingModel()
response = model.invoke(
model='bge-large-en',
credentials={
'api_key': os.environ.get('WENXIN_API_KEY'),
'secret_key': os.environ.get('WENXIN_SECRET_KEY')
},
texts=['hello', '你好', 'xxxxx'],
user="abc-123"
model="bge-large-en",
credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")},
texts=["hello", "你好", "xxxxx"],
user="abc-123",
)
assert isinstance(response, TextEmbeddingResult)
@ -48,13 +42,10 @@ def test_invoke_embedding_bge_large_zh():
model = WenxinTextEmbeddingModel()
response = model.invoke(
model='bge-large-zh',
credentials={
'api_key': os.environ.get('WENXIN_API_KEY'),
'secret_key': os.environ.get('WENXIN_SECRET_KEY')
},
texts=['hello', '你好', 'xxxxx'],
user="abc-123"
model="bge-large-zh",
credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")},
texts=["hello", "你好", "xxxxx"],
user="abc-123",
)
assert isinstance(response, TextEmbeddingResult)
@ -67,13 +58,10 @@ def test_invoke_embedding_tao_8k():
model = WenxinTextEmbeddingModel()
response = model.invoke(
model='tao-8k',
credentials={
'api_key': os.environ.get('WENXIN_API_KEY'),
'secret_key': os.environ.get('WENXIN_SECRET_KEY')
},
texts=['hello', '你好', 'xxxxx'],
user="abc-123"
model="tao-8k",
credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")},
texts=["hello", "你好", "xxxxx"],
user="abc-123",
)
assert isinstance(response, TextEmbeddingResult)

View File

@ -17,161 +17,125 @@ def test_predefined_models():
assert len(model_schemas) >= 1
assert isinstance(model_schemas[0], AIModelEntity)
def test_validate_credentials_for_chat_model():
sleep(3)
model = ErnieBotLargeLanguageModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='ernie-bot',
credentials={
'api_key': 'invalid_key',
'secret_key': 'invalid_key'
}
model="ernie-bot", credentials={"api_key": "invalid_key", "secret_key": "invalid_key"}
)
model.validate_credentials(
model='ernie-bot',
credentials={
'api_key': os.environ.get('WENXIN_API_KEY'),
'secret_key': os.environ.get('WENXIN_SECRET_KEY')
}
model="ernie-bot",
credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")},
)
def test_invoke_model_ernie_bot():
sleep(3)
model = ErnieBotLargeLanguageModel()
response = model.invoke(
model='ernie-bot',
credentials={
'api_key': os.environ.get('WENXIN_API_KEY'),
'secret_key': os.environ.get('WENXIN_SECRET_KEY')
},
prompt_messages=[
UserPromptMessage(
content='Hello World!'
)
],
model="ernie-bot",
credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")},
prompt_messages=[UserPromptMessage(content="Hello World!")],
model_parameters={
'temperature': 0.7,
'top_p': 1.0,
"temperature": 0.7,
"top_p": 1.0,
},
stop=['you'],
stop=["you"],
user="abc-123",
stream=False
stream=False,
)
assert isinstance(response, LLMResult)
assert len(response.message.content) > 0
assert response.usage.total_tokens > 0
def test_invoke_model_ernie_bot_turbo():
sleep(3)
model = ErnieBotLargeLanguageModel()
response = model.invoke(
model='ernie-bot-turbo',
credentials={
'api_key': os.environ.get('WENXIN_API_KEY'),
'secret_key': os.environ.get('WENXIN_SECRET_KEY')
},
prompt_messages=[
UserPromptMessage(
content='Hello World!'
)
],
model="ernie-bot-turbo",
credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")},
prompt_messages=[UserPromptMessage(content="Hello World!")],
model_parameters={
'temperature': 0.7,
'top_p': 1.0,
"temperature": 0.7,
"top_p": 1.0,
},
stop=['you'],
stop=["you"],
user="abc-123",
stream=False
stream=False,
)
assert isinstance(response, LLMResult)
assert len(response.message.content) > 0
assert response.usage.total_tokens > 0
def test_invoke_model_ernie_8k():
sleep(3)
model = ErnieBotLargeLanguageModel()
response = model.invoke(
model='ernie-bot-8k',
credentials={
'api_key': os.environ.get('WENXIN_API_KEY'),
'secret_key': os.environ.get('WENXIN_SECRET_KEY')
},
prompt_messages=[
UserPromptMessage(
content='Hello World!'
)
],
model="ernie-bot-8k",
credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")},
prompt_messages=[UserPromptMessage(content="Hello World!")],
model_parameters={
'temperature': 0.7,
'top_p': 1.0,
"temperature": 0.7,
"top_p": 1.0,
},
stop=['you'],
stop=["you"],
user="abc-123",
stream=False
stream=False,
)
assert isinstance(response, LLMResult)
assert len(response.message.content) > 0
assert response.usage.total_tokens > 0
def test_invoke_model_ernie_bot_4():
sleep(3)
model = ErnieBotLargeLanguageModel()
response = model.invoke(
model='ernie-bot-4',
credentials={
'api_key': os.environ.get('WENXIN_API_KEY'),
'secret_key': os.environ.get('WENXIN_SECRET_KEY')
},
prompt_messages=[
UserPromptMessage(
content='Hello World!'
)
],
model="ernie-bot-4",
credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")},
prompt_messages=[UserPromptMessage(content="Hello World!")],
model_parameters={
'temperature': 0.7,
'top_p': 1.0,
"temperature": 0.7,
"top_p": 1.0,
},
stop=['you'],
stop=["you"],
user="abc-123",
stream=False
stream=False,
)
assert isinstance(response, LLMResult)
assert len(response.message.content) > 0
assert response.usage.total_tokens > 0
def test_invoke_stream_model():
sleep(3)
model = ErnieBotLargeLanguageModel()
response = model.invoke(
model='ernie-3.5-8k',
credentials={
'api_key': os.environ.get('WENXIN_API_KEY'),
'secret_key': os.environ.get('WENXIN_SECRET_KEY')
},
prompt_messages=[
UserPromptMessage(
content='Hello World!'
)
],
model="ernie-3.5-8k",
credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")},
prompt_messages=[UserPromptMessage(content="Hello World!")],
model_parameters={
'temperature': 0.7,
'top_p': 1.0,
"temperature": 0.7,
"top_p": 1.0,
},
stop=['you'],
stop=["you"],
stream=True,
user="abc-123"
user="abc-123",
)
assert isinstance(response, Generator)
@ -181,63 +145,48 @@ def test_invoke_stream_model():
assert isinstance(chunk.delta.message, AssistantPromptMessage)
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
def test_invoke_model_with_system():
sleep(3)
model = ErnieBotLargeLanguageModel()
response = model.invoke(
model='ernie-bot',
credentials={
'api_key': os.environ.get('WENXIN_API_KEY'),
'secret_key': os.environ.get('WENXIN_SECRET_KEY')
},
prompt_messages=[
SystemPromptMessage(
content='你是Kasumi'
),
UserPromptMessage(
content='你是谁?'
)
],
model="ernie-bot",
credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")},
prompt_messages=[SystemPromptMessage(content="你是Kasumi"), UserPromptMessage(content="你是谁?")],
model_parameters={
'temperature': 0.7,
'top_p': 1.0,
"temperature": 0.7,
"top_p": 1.0,
},
stop=['you'],
stop=["you"],
stream=False,
user="abc-123"
user="abc-123",
)
assert isinstance(response, LLMResult)
assert 'kasumi' in response.message.content.lower()
assert "kasumi" in response.message.content.lower()
def test_invoke_with_search():
sleep(3)
model = ErnieBotLargeLanguageModel()
response = model.invoke(
model='ernie-bot',
credentials={
'api_key': os.environ.get('WENXIN_API_KEY'),
'secret_key': os.environ.get('WENXIN_SECRET_KEY')
},
prompt_messages=[
UserPromptMessage(
content='北京今天的天气怎么样'
)
],
model="ernie-bot",
credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")},
prompt_messages=[UserPromptMessage(content="北京今天的天气怎么样")],
model_parameters={
'temperature': 0.7,
'top_p': 1.0,
'disable_search': True,
"temperature": 0.7,
"top_p": 1.0,
"disable_search": True,
},
stop=[],
stream=True,
user="abc-123"
user="abc-123",
)
assert isinstance(response, Generator)
total_message = ''
total_message = ""
for chunk in response:
assert isinstance(chunk, LLMResultChunk)
assert isinstance(chunk.delta, LLMResultChunkDelta)
@ -247,25 +196,19 @@ def test_invoke_with_search():
assert len(chunk.delta.message.content) > 0 if not chunk.delta.finish_reason else True
# there should be 对不起、我不能、不支持……
assert ('' in total_message or '抱歉' in total_message or '无法' in total_message)
assert "" in total_message or "抱歉" in total_message or "无法" in total_message
def test_get_num_tokens():
sleep(3)
model = ErnieBotLargeLanguageModel()
response = model.get_num_tokens(
model='ernie-bot',
credentials={
'api_key': os.environ.get('WENXIN_API_KEY'),
'secret_key': os.environ.get('WENXIN_SECRET_KEY')
},
prompt_messages=[
UserPromptMessage(
content='Hello World!'
)
],
tools=[]
model="ernie-bot",
credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")},
prompt_messages=[UserPromptMessage(content="Hello World!")],
tools=[],
)
assert isinstance(response, int)
assert response == 10
assert response == 10

View File

@ -10,16 +10,8 @@ def test_validate_provider_credentials():
provider = WenxinProvider()
with pytest.raises(CredentialsValidateFailedError):
provider.validate_provider_credentials(
credentials={
'api_key': 'hahahaha',
'secret_key': 'hahahaha'
}
)
provider.validate_provider_credentials(credentials={"api_key": "hahahaha", "secret_key": "hahahaha"})
provider.validate_provider_credentials(
credentials={
'api_key': os.environ.get('WENXIN_API_KEY'),
'secret_key': os.environ.get('WENXIN_SECRET_KEY')
}
credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")}
)

View File

@ -8,61 +8,57 @@ from core.model_runtime.model_providers.xinference.text_embedding.text_embedding
from tests.integration_tests.model_runtime.__mock.xinference import MOCK, setup_xinference_mock
@pytest.mark.parametrize('setup_xinference_mock', [['none']], indirect=True)
@pytest.mark.parametrize("setup_xinference_mock", [["none"]], indirect=True)
def test_validate_credentials(setup_xinference_mock):
model = XinferenceTextEmbeddingModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='bge-base-en',
model="bge-base-en",
credentials={
'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
'model_uid': 'www ' + os.environ.get('XINFERENCE_EMBEDDINGS_MODEL_UID')
}
"server_url": os.environ.get("XINFERENCE_SERVER_URL"),
"model_uid": "www " + os.environ.get("XINFERENCE_EMBEDDINGS_MODEL_UID"),
},
)
model.validate_credentials(
model='bge-base-en',
model="bge-base-en",
credentials={
'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
'model_uid': os.environ.get('XINFERENCE_EMBEDDINGS_MODEL_UID')
}
"server_url": os.environ.get("XINFERENCE_SERVER_URL"),
"model_uid": os.environ.get("XINFERENCE_EMBEDDINGS_MODEL_UID"),
},
)
@pytest.mark.parametrize('setup_xinference_mock', [['none']], indirect=True)
@pytest.mark.parametrize("setup_xinference_mock", [["none"]], indirect=True)
def test_invoke_model(setup_xinference_mock):
model = XinferenceTextEmbeddingModel()
result = model.invoke(
model='bge-base-en',
model="bge-base-en",
credentials={
'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
'model_uid': os.environ.get('XINFERENCE_EMBEDDINGS_MODEL_UID')
"server_url": os.environ.get("XINFERENCE_SERVER_URL"),
"model_uid": os.environ.get("XINFERENCE_EMBEDDINGS_MODEL_UID"),
},
texts=[
"hello",
"world"
],
user="abc-123"
texts=["hello", "world"],
user="abc-123",
)
assert isinstance(result, TextEmbeddingResult)
assert len(result.embeddings) == 2
assert result.usage.total_tokens > 0
def test_get_num_tokens():
model = XinferenceTextEmbeddingModel()
num_tokens = model.get_num_tokens(
model='bge-base-en',
model="bge-base-en",
credentials={
'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
'model_uid': os.environ.get('XINFERENCE_EMBEDDINGS_MODEL_UID')
"server_url": os.environ.get("XINFERENCE_SERVER_URL"),
"model_uid": os.environ.get("XINFERENCE_EMBEDDINGS_MODEL_UID"),
},
texts=[
"hello",
"world"
]
texts=["hello", "world"],
)
assert num_tokens == 2

View File

@ -20,92 +20,84 @@ from tests.integration_tests.model_runtime.__mock.openai import setup_openai_moc
from tests.integration_tests.model_runtime.__mock.xinference import setup_xinference_mock
@pytest.mark.parametrize('setup_openai_mock, setup_xinference_mock', [['chat', 'none']], indirect=True)
@pytest.mark.parametrize("setup_openai_mock, setup_xinference_mock", [["chat", "none"]], indirect=True)
def test_validate_credentials_for_chat_model(setup_openai_mock, setup_xinference_mock):
model = XinferenceAILargeLanguageModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='ChatGLM3',
model="ChatGLM3",
credentials={
'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
'model_uid': 'www ' + os.environ.get('XINFERENCE_CHAT_MODEL_UID')
}
"server_url": os.environ.get("XINFERENCE_SERVER_URL"),
"model_uid": "www " + os.environ.get("XINFERENCE_CHAT_MODEL_UID"),
},
)
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='aaaaa',
credentials={
'server_url': '',
'model_uid': ''
}
)
model.validate_credentials(model="aaaaa", credentials={"server_url": "", "model_uid": ""})
model.validate_credentials(
model='ChatGLM3',
model="ChatGLM3",
credentials={
'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
'model_uid': os.environ.get('XINFERENCE_CHAT_MODEL_UID')
}
"server_url": os.environ.get("XINFERENCE_SERVER_URL"),
"model_uid": os.environ.get("XINFERENCE_CHAT_MODEL_UID"),
},
)
@pytest.mark.parametrize('setup_openai_mock, setup_xinference_mock', [['chat', 'none']], indirect=True)
@pytest.mark.parametrize("setup_openai_mock, setup_xinference_mock", [["chat", "none"]], indirect=True)
def test_invoke_chat_model(setup_openai_mock, setup_xinference_mock):
model = XinferenceAILargeLanguageModel()
response = model.invoke(
model='ChatGLM3',
model="ChatGLM3",
credentials={
'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
'model_uid': os.environ.get('XINFERENCE_CHAT_MODEL_UID')
"server_url": os.environ.get("XINFERENCE_SERVER_URL"),
"model_uid": os.environ.get("XINFERENCE_CHAT_MODEL_UID"),
},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Hello World!'
)
UserPromptMessage(content="Hello World!"),
],
model_parameters={
'temperature': 0.7,
'top_p': 1.0,
"temperature": 0.7,
"top_p": 1.0,
},
stop=['you'],
stop=["you"],
user="abc-123",
stream=False
stream=False,
)
assert isinstance(response, LLMResult)
assert len(response.message.content) > 0
assert response.usage.total_tokens > 0
@pytest.mark.parametrize('setup_openai_mock, setup_xinference_mock', [['chat', 'none']], indirect=True)
@pytest.mark.parametrize("setup_openai_mock, setup_xinference_mock", [["chat", "none"]], indirect=True)
def test_invoke_stream_chat_model(setup_openai_mock, setup_xinference_mock):
model = XinferenceAILargeLanguageModel()
response = model.invoke(
model='ChatGLM3',
model="ChatGLM3",
credentials={
'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
'model_uid': os.environ.get('XINFERENCE_CHAT_MODEL_UID')
"server_url": os.environ.get("XINFERENCE_SERVER_URL"),
"model_uid": os.environ.get("XINFERENCE_CHAT_MODEL_UID"),
},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Hello World!'
)
UserPromptMessage(content="Hello World!"),
],
model_parameters={
'temperature': 0.7,
'top_p': 1.0,
"temperature": 0.7,
"top_p": 1.0,
},
stop=['you'],
stop=["you"],
stream=True,
user="abc-123"
user="abc-123",
)
assert isinstance(response, Generator)
@ -114,6 +106,8 @@ def test_invoke_stream_chat_model(setup_openai_mock, setup_xinference_mock):
assert isinstance(chunk.delta, LLMResultChunkDelta)
assert isinstance(chunk.delta.message, AssistantPromptMessage)
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
"""
Funtion calling of xinference does not support stream mode currently
"""
@ -168,7 +162,7 @@ def test_invoke_stream_chat_model(setup_openai_mock, setup_xinference_mock):
# )
# assert isinstance(response, Generator)
# call: LLMResultChunk = None
# chunks = []
@ -241,86 +235,75 @@ def test_invoke_stream_chat_model(setup_openai_mock, setup_xinference_mock):
# assert response.usage.total_tokens > 0
# assert response.message.tool_calls[0].function.name == 'get_current_weather'
@pytest.mark.parametrize('setup_openai_mock, setup_xinference_mock', [['completion', 'none']], indirect=True)
@pytest.mark.parametrize("setup_openai_mock, setup_xinference_mock", [["completion", "none"]], indirect=True)
def test_validate_credentials_for_generation_model(setup_openai_mock, setup_xinference_mock):
model = XinferenceAILargeLanguageModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='alapaca',
model="alapaca",
credentials={
'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
'model_uid': 'www ' + os.environ.get('XINFERENCE_GENERATION_MODEL_UID')
}
"server_url": os.environ.get("XINFERENCE_SERVER_URL"),
"model_uid": "www " + os.environ.get("XINFERENCE_GENERATION_MODEL_UID"),
},
)
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='alapaca',
credentials={
'server_url': '',
'model_uid': ''
}
)
model.validate_credentials(model="alapaca", credentials={"server_url": "", "model_uid": ""})
model.validate_credentials(
model='alapaca',
model="alapaca",
credentials={
'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
'model_uid': os.environ.get('XINFERENCE_GENERATION_MODEL_UID')
}
"server_url": os.environ.get("XINFERENCE_SERVER_URL"),
"model_uid": os.environ.get("XINFERENCE_GENERATION_MODEL_UID"),
},
)
@pytest.mark.parametrize('setup_openai_mock, setup_xinference_mock', [['completion', 'none']], indirect=True)
@pytest.mark.parametrize("setup_openai_mock, setup_xinference_mock", [["completion", "none"]], indirect=True)
def test_invoke_generation_model(setup_openai_mock, setup_xinference_mock):
model = XinferenceAILargeLanguageModel()
response = model.invoke(
model='alapaca',
model="alapaca",
credentials={
'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
'model_uid': os.environ.get('XINFERENCE_GENERATION_MODEL_UID')
"server_url": os.environ.get("XINFERENCE_SERVER_URL"),
"model_uid": os.environ.get("XINFERENCE_GENERATION_MODEL_UID"),
},
prompt_messages=[
UserPromptMessage(
content='the United States is'
)
],
prompt_messages=[UserPromptMessage(content="the United States is")],
model_parameters={
'temperature': 0.7,
'top_p': 1.0,
"temperature": 0.7,
"top_p": 1.0,
},
stop=['you'],
stop=["you"],
user="abc-123",
stream=False
stream=False,
)
assert isinstance(response, LLMResult)
assert len(response.message.content) > 0
assert response.usage.total_tokens > 0
@pytest.mark.parametrize('setup_openai_mock, setup_xinference_mock', [['completion', 'none']], indirect=True)
@pytest.mark.parametrize("setup_openai_mock, setup_xinference_mock", [["completion", "none"]], indirect=True)
def test_invoke_stream_generation_model(setup_openai_mock, setup_xinference_mock):
model = XinferenceAILargeLanguageModel()
response = model.invoke(
model='alapaca',
model="alapaca",
credentials={
'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
'model_uid': os.environ.get('XINFERENCE_GENERATION_MODEL_UID')
"server_url": os.environ.get("XINFERENCE_SERVER_URL"),
"model_uid": os.environ.get("XINFERENCE_GENERATION_MODEL_UID"),
},
prompt_messages=[
UserPromptMessage(
content='the United States is'
)
],
prompt_messages=[UserPromptMessage(content="the United States is")],
model_parameters={
'temperature': 0.7,
'top_p': 1.0,
"temperature": 0.7,
"top_p": 1.0,
},
stop=['you'],
stop=["you"],
stream=True,
user="abc-123"
user="abc-123",
)
assert isinstance(response, Generator)
@ -330,68 +313,54 @@ def test_invoke_stream_generation_model(setup_openai_mock, setup_xinference_mock
assert isinstance(chunk.delta.message, AssistantPromptMessage)
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
def test_get_num_tokens():
model = XinferenceAILargeLanguageModel()
num_tokens = model.get_num_tokens(
model='ChatGLM3',
model="ChatGLM3",
credentials={
'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
'model_uid': os.environ.get('XINFERENCE_GENERATION_MODEL_UID')
"server_url": os.environ.get("XINFERENCE_SERVER_URL"),
"model_uid": os.environ.get("XINFERENCE_GENERATION_MODEL_UID"),
},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Hello World!'
)
UserPromptMessage(content="Hello World!"),
],
tools=[
PromptMessageTool(
name='get_current_weather',
description='Get the current weather in a given location',
name="get_current_weather",
description="Get the current weather in a given location",
parameters={
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state e.g. San Francisco, CA"
},
"unit": {
"type": "string",
"enum": [
"c",
"f"
]
}
"location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
"unit": {"type": "string", "enum": ["c", "f"]},
},
"required": [
"location"
]
}
"required": ["location"],
},
)
]
],
)
assert isinstance(num_tokens, int)
assert num_tokens == 77
num_tokens = model.get_num_tokens(
model='ChatGLM3',
model="ChatGLM3",
credentials={
'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
'model_uid': os.environ.get('XINFERENCE_GENERATION_MODEL_UID')
"server_url": os.environ.get("XINFERENCE_SERVER_URL"),
"model_uid": os.environ.get("XINFERENCE_GENERATION_MODEL_UID"),
},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Hello World!'
)
UserPromptMessage(content="Hello World!"),
],
)
assert isinstance(num_tokens, int)
assert num_tokens == 21
assert num_tokens == 21

View File

@ -8,44 +8,42 @@ from core.model_runtime.model_providers.xinference.rerank.rerank import Xinferen
from tests.integration_tests.model_runtime.__mock.xinference import MOCK, setup_xinference_mock
@pytest.mark.parametrize('setup_xinference_mock', [['none']], indirect=True)
@pytest.mark.parametrize("setup_xinference_mock", [["none"]], indirect=True)
def test_validate_credentials(setup_xinference_mock):
model = XinferenceRerankModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='bge-reranker-base',
credentials={
'server_url': 'awdawdaw',
'model_uid': os.environ.get('XINFERENCE_RERANK_MODEL_UID')
}
model="bge-reranker-base",
credentials={"server_url": "awdawdaw", "model_uid": os.environ.get("XINFERENCE_RERANK_MODEL_UID")},
)
model.validate_credentials(
model='bge-reranker-base',
model="bge-reranker-base",
credentials={
'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
'model_uid': os.environ.get('XINFERENCE_RERANK_MODEL_UID')
}
"server_url": os.environ.get("XINFERENCE_SERVER_URL"),
"model_uid": os.environ.get("XINFERENCE_RERANK_MODEL_UID"),
},
)
@pytest.mark.parametrize('setup_xinference_mock', [['none']], indirect=True)
@pytest.mark.parametrize("setup_xinference_mock", [["none"]], indirect=True)
def test_invoke_model(setup_xinference_mock):
model = XinferenceRerankModel()
result = model.invoke(
model='bge-reranker-base',
model="bge-reranker-base",
credentials={
'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
'model_uid': os.environ.get('XINFERENCE_RERANK_MODEL_UID')
"server_url": os.environ.get("XINFERENCE_SERVER_URL"),
"model_uid": os.environ.get("XINFERENCE_RERANK_MODEL_UID"),
},
query="Who is Kasumi?",
docs=[
"Kasumi is a girl's name of Japanese origin meaning \"mist\".",
'Kasumi is a girl\'s name of Japanese origin meaning "mist".',
"Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ",
"and she leads a team named PopiParty."
"and she leads a team named PopiParty.",
],
score_threshold=0.8
score_threshold=0.8,
)
assert isinstance(result, RerankResult)

View File

@ -13,41 +13,22 @@ def test_validate_credentials():
model = ZhinaoLargeLanguageModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='360gpt2-pro',
credentials={
'api_key': 'invalid_key'
}
)
model.validate_credentials(model="360gpt2-pro", credentials={"api_key": "invalid_key"})
model.validate_credentials(
model='360gpt2-pro',
credentials={
'api_key': os.environ.get('ZHINAO_API_KEY')
}
)
model.validate_credentials(model="360gpt2-pro", credentials={"api_key": os.environ.get("ZHINAO_API_KEY")})
def test_invoke_model():
model = ZhinaoLargeLanguageModel()
response = model.invoke(
model='360gpt2-pro',
credentials={
'api_key': os.environ.get('ZHINAO_API_KEY')
},
prompt_messages=[
UserPromptMessage(
content='Who are you?'
)
],
model_parameters={
'temperature': 0.5,
'max_tokens': 10
},
stop=['How'],
model="360gpt2-pro",
credentials={"api_key": os.environ.get("ZHINAO_API_KEY")},
prompt_messages=[UserPromptMessage(content="Who are you?")],
model_parameters={"temperature": 0.5, "max_tokens": 10},
stop=["How"],
stream=False,
user="abc-123"
user="abc-123",
)
assert isinstance(response, LLMResult)
@ -58,22 +39,12 @@ def test_invoke_stream_model():
model = ZhinaoLargeLanguageModel()
response = model.invoke(
model='360gpt2-pro',
credentials={
'api_key': os.environ.get('ZHINAO_API_KEY')
},
prompt_messages=[
UserPromptMessage(
content='Hello World!'
)
],
model_parameters={
'temperature': 0.5,
'max_tokens': 100,
'seed': 1234
},
model="360gpt2-pro",
credentials={"api_key": os.environ.get("ZHINAO_API_KEY")},
prompt_messages=[UserPromptMessage(content="Hello World!")],
model_parameters={"temperature": 0.5, "max_tokens": 100, "seed": 1234},
stream=True,
user="abc-123"
user="abc-123",
)
assert isinstance(response, Generator)
@ -89,18 +60,14 @@ def test_get_num_tokens():
model = ZhinaoLargeLanguageModel()
num_tokens = model.get_num_tokens(
model='360gpt2-pro',
credentials={
'api_key': os.environ.get('ZHINAO_API_KEY')
},
model="360gpt2-pro",
credentials={"api_key": os.environ.get("ZHINAO_API_KEY")},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Hello World!'
)
]
UserPromptMessage(content="Hello World!"),
],
)
assert num_tokens == 21

View File

@ -10,12 +10,6 @@ def test_validate_provider_credentials():
provider = ZhinaoProvider()
with pytest.raises(CredentialsValidateFailedError):
provider.validate_provider_credentials(
credentials={}
)
provider.validate_provider_credentials(credentials={})
provider.validate_provider_credentials(
credentials={
'api_key': os.environ.get('ZHINAO_API_KEY')
}
)
provider.validate_provider_credentials(credentials={"api_key": os.environ.get("ZHINAO_API_KEY")})

View File

@ -18,41 +18,22 @@ def test_validate_credentials():
model = ZhipuAILargeLanguageModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='chatglm_turbo',
credentials={
'api_key': 'invalid_key'
}
)
model.validate_credentials(model="chatglm_turbo", credentials={"api_key": "invalid_key"})
model.validate_credentials(
model='chatglm_turbo',
credentials={
'api_key': os.environ.get('ZHIPUAI_API_KEY')
}
)
model.validate_credentials(model="chatglm_turbo", credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")})
def test_invoke_model():
model = ZhipuAILargeLanguageModel()
response = model.invoke(
model='chatglm_turbo',
credentials={
'api_key': os.environ.get('ZHIPUAI_API_KEY')
},
prompt_messages=[
UserPromptMessage(
content='Who are you?'
)
],
model_parameters={
'temperature': 0.9,
'top_p': 0.7
},
stop=['How'],
model="chatglm_turbo",
credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")},
prompt_messages=[UserPromptMessage(content="Who are you?")],
model_parameters={"temperature": 0.9, "top_p": 0.7},
stop=["How"],
stream=False,
user="abc-123"
user="abc-123",
)
assert isinstance(response, LLMResult)
@ -63,21 +44,12 @@ def test_invoke_stream_model():
model = ZhipuAILargeLanguageModel()
response = model.invoke(
model='chatglm_turbo',
credentials={
'api_key': os.environ.get('ZHIPUAI_API_KEY')
},
prompt_messages=[
UserPromptMessage(
content='Hello World!'
)
],
model_parameters={
'temperature': 0.9,
'top_p': 0.7
},
model="chatglm_turbo",
credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")},
prompt_messages=[UserPromptMessage(content="Hello World!")],
model_parameters={"temperature": 0.9, "top_p": 0.7},
stream=True,
user="abc-123"
user="abc-123",
)
assert isinstance(response, Generator)
@ -93,63 +65,45 @@ def test_get_num_tokens():
model = ZhipuAILargeLanguageModel()
num_tokens = model.get_num_tokens(
model='chatglm_turbo',
credentials={
'api_key': os.environ.get('ZHIPUAI_API_KEY')
},
model="chatglm_turbo",
credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Hello World!'
)
]
UserPromptMessage(content="Hello World!"),
],
)
assert num_tokens == 14
def test_get_tools_num_tokens():
model = ZhipuAILargeLanguageModel()
num_tokens = model.get_num_tokens(
model='tools',
credentials={
'api_key': os.environ.get('ZHIPUAI_API_KEY')
},
model="tools",
credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")},
tools=[
PromptMessageTool(
name='get_current_weather',
description='Get the current weather in a given location',
name="get_current_weather",
description="Get the current weather in a given location",
parameters={
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state e.g. San Francisco, CA"
},
"unit": {
"type": "string",
"enum": [
"c",
"f"
]
}
"location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
"unit": {"type": "string", "enum": ["c", "f"]},
},
"required": [
"location"
]
}
"required": ["location"],
},
)
],
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.',
content="You are a helpful AI assistant.",
),
UserPromptMessage(
content='Hello World!'
)
]
UserPromptMessage(content="Hello World!"),
],
)
assert num_tokens == 88

View File

@ -10,12 +10,6 @@ def test_validate_provider_credentials():
provider = ZhipuaiProvider()
with pytest.raises(CredentialsValidateFailedError):
provider.validate_provider_credentials(
credentials={}
)
provider.validate_provider_credentials(credentials={})
provider.validate_provider_credentials(
credentials={
'api_key': os.environ.get('ZHIPUAI_API_KEY')
}
)
provider.validate_provider_credentials(credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")})

View File

@ -11,34 +11,19 @@ def test_validate_credentials():
model = ZhipuAITextEmbeddingModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model='text_embedding',
credentials={
'api_key': 'invalid_key'
}
)
model.validate_credentials(model="text_embedding", credentials={"api_key": "invalid_key"})
model.validate_credentials(
model='text_embedding',
credentials={
'api_key': os.environ.get('ZHIPUAI_API_KEY')
}
)
model.validate_credentials(model="text_embedding", credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")})
def test_invoke_model():
model = ZhipuAITextEmbeddingModel()
result = model.invoke(
model='text_embedding',
credentials={
'api_key': os.environ.get('ZHIPUAI_API_KEY')
},
texts=[
"hello",
"world"
],
user="abc-123"
model="text_embedding",
credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")},
texts=["hello", "world"],
user="abc-123",
)
assert isinstance(result, TextEmbeddingResult)
@ -50,14 +35,7 @@ def test_get_num_tokens():
model = ZhipuAITextEmbeddingModel()
num_tokens = model.get_num_tokens(
model='text_embedding',
credentials={
'api_key': os.environ.get('ZHIPUAI_API_KEY')
},
texts=[
"hello",
"world"
]
model="text_embedding", credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")}, texts=["hello", "world"]
)
assert num_tokens == 2

View File

@ -7,20 +7,17 @@ from _pytest.monkeypatch import MonkeyPatch
class MockedHttp:
def httpx_request(method: Literal['GET', 'POST', 'PUT', 'DELETE', 'PATCH', 'HEAD'],
url: str, **kwargs) -> httpx.Response:
def httpx_request(
method: Literal["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD"], url: str, **kwargs
) -> httpx.Response:
"""
Mocked httpx.request
"""
request = httpx.Request(
method,
url,
params=kwargs.get('params'),
headers=kwargs.get('headers'),
cookies=kwargs.get('cookies')
method, url, params=kwargs.get("params"), headers=kwargs.get("headers"), cookies=kwargs.get("cookies")
)
data = kwargs.get('data', None)
resp = json.dumps(data).encode('utf-8') if data else b'OK'
data = kwargs.get("data", None)
resp = json.dumps(data).encode("utf-8") if data else b"OK"
response = httpx.Response(
status_code=200,
request=request,

View File

@ -10,6 +10,7 @@ todos_data = {
"user1": ["Go for a run", "Read a book"],
}
class TodosResource(Resource):
def get(self, username):
todos = todos_data.get(username, [])
@ -32,7 +33,8 @@ class TodosResource(Resource):
return {"error": "Invalid todo index"}, 400
api.add_resource(TodosResource, '/todos/<string:username>')
if __name__ == '__main__':
api.add_resource(TodosResource, "/todos/<string:username>")
if __name__ == "__main__":
app.run(port=5003, debug=True)

View File

@ -3,37 +3,40 @@ from core.tools.tool.tool import Tool
from tests.integration_tests.tools.__mock.http import setup_http_mock
tool_bundle = {
'server_url': 'http://www.example.com/{path_param}',
'method': 'post',
'author': '',
'openapi': {'parameters': [{'in': 'path', 'name': 'path_param'},
{'in': 'query', 'name': 'query_param'},
{'in': 'cookie', 'name': 'cookie_param'},
{'in': 'header', 'name': 'header_param'},
],
'requestBody': {
'content': {'application/json': {'schema': {'properties': {'body_param': {'type': 'string'}}}}}}
},
'parameters': []
"server_url": "http://www.example.com/{path_param}",
"method": "post",
"author": "",
"openapi": {
"parameters": [
{"in": "path", "name": "path_param"},
{"in": "query", "name": "query_param"},
{"in": "cookie", "name": "cookie_param"},
{"in": "header", "name": "header_param"},
],
"requestBody": {
"content": {"application/json": {"schema": {"properties": {"body_param": {"type": "string"}}}}}
},
},
"parameters": [],
}
parameters = {
'path_param': 'p_param',
'query_param': 'q_param',
'cookie_param': 'c_param',
'header_param': 'h_param',
'body_param': 'b_param',
"path_param": "p_param",
"query_param": "q_param",
"cookie_param": "c_param",
"header_param": "h_param",
"body_param": "b_param",
}
def test_api_tool(setup_http_mock):
tool = ApiTool(api_bundle=tool_bundle, runtime=Tool.Runtime(credentials={'auth_type': 'none'}))
tool = ApiTool(api_bundle=tool_bundle, runtime=Tool.Runtime(credentials={"auth_type": "none"}))
headers = tool.assembling_request(parameters)
response = tool.do_http_request(tool.api_bundle.server_url, tool.api_bundle.method, headers, parameters)
assert response.status_code == 200
assert '/p_param' == response.request.url.path
assert b'query_param=q_param' == response.request.url.query
assert 'h_param' == response.request.headers.get('header_param')
assert 'application/json' == response.request.headers.get('content-type')
assert 'cookie_param=c_param' == response.request.headers.get('cookie')
assert 'b_param' in response.content.decode()
assert "/p_param" == response.request.url.path
assert b"query_param=q_param" == response.request.url.query
assert "h_param" == response.request.headers.get("header_param")
assert "application/json" == response.request.headers.get("content-type")
assert "cookie_param=c_param" == response.request.headers.get("cookie")
assert "b_param" in response.content.decode()

View File

@ -7,16 +7,17 @@ provider_names = [provider.identity.name for provider in provider_generator]
ToolManager.clear_builtin_providers_cache()
provider_generator = ToolManager.list_builtin_providers()
@pytest.mark.parametrize('name', provider_names)
@pytest.mark.parametrize("name", provider_names)
def test_tool_providers(benchmark, name):
"""
Test that all tool providers can be loaded
"""
def test(generator):
try:
return next(generator)
except StopIteration:
return None
benchmark.pedantic(test, args=(provider_generator,), iterations=1, rounds=1)
benchmark.pedantic(test, args=(provider_generator,), iterations=1, rounds=1)

Some files were not shown because too many files have changed in this diff Show More