mirror of
https://gitee.com/dify_ai/dify.git
synced 2024-11-30 18:27:53 +08:00
parent
0b4902bdc2
commit
e0da0744b5
@ -8,7 +8,12 @@ from urllib.parse import urljoin
|
||||
|
||||
import requests
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.llm_entities import (
|
||||
LLMMode,
|
||||
LLMResult,
|
||||
LLMResultChunk,
|
||||
LLMResultChunkDelta,
|
||||
)
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
ImagePromptMessageContent,
|
||||
@ -40,7 +45,9 @@ from core.model_runtime.errors.invoke import (
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.model_providers.__base.large_language_model import (
|
||||
LargeLanguageModel,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -50,11 +57,17 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
|
||||
Model class for Ollama large language model.
|
||||
"""
|
||||
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) \
|
||||
-> Union[LLMResult, Generator]:
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
@ -75,11 +88,16 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
|
||||
model_parameters=model_parameters,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user
|
||||
user=user,
|
||||
)
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None) -> int:
|
||||
def get_num_tokens(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
@ -100,10 +118,12 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
|
||||
if isinstance(first_prompt_message.content, str):
|
||||
text = first_prompt_message.content
|
||||
else:
|
||||
text = ''
|
||||
text = ""
|
||||
for message_content in first_prompt_message.content:
|
||||
if message_content.type == PromptMessageContentType.TEXT:
|
||||
message_content = cast(TextPromptMessageContent, message_content)
|
||||
message_content = cast(
|
||||
TextPromptMessageContent, message_content
|
||||
)
|
||||
text = message_content.data
|
||||
break
|
||||
return self._get_num_tokens_by_gpt2(text)
|
||||
@ -121,19 +141,28 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_messages=[UserPromptMessage(content="ping")],
|
||||
model_parameters={
|
||||
'num_predict': 5
|
||||
},
|
||||
stream=False
|
||||
model_parameters={"num_predict": 5},
|
||||
stream=False,
|
||||
)
|
||||
except InvokeError as ex:
|
||||
raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {ex.description}')
|
||||
raise CredentialsValidateFailedError(
|
||||
f"An error occurred during credentials validation: {ex.description}"
|
||||
)
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {str(ex)}')
|
||||
raise CredentialsValidateFailedError(
|
||||
f"An error occurred during credentials validation: {str(ex)}"
|
||||
)
|
||||
|
||||
def _generate(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
|
||||
def _generate(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke llm completion model
|
||||
|
||||
@ -146,76 +175,93 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
|
||||
:param user: unique user id
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
headers = {
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
headers = {"Content-Type": "application/json"}
|
||||
|
||||
endpoint_url = credentials['base_url']
|
||||
if not endpoint_url.endswith('/'):
|
||||
endpoint_url += '/'
|
||||
endpoint_url = credentials["base_url"]
|
||||
if not endpoint_url.endswith("/"):
|
||||
endpoint_url += "/"
|
||||
|
||||
# prepare the payload for a simple ping to the model
|
||||
data = {
|
||||
'model': model,
|
||||
'stream': stream
|
||||
}
|
||||
data = {"model": model, "stream": stream}
|
||||
|
||||
if 'format' in model_parameters:
|
||||
data['format'] = model_parameters['format']
|
||||
del model_parameters['format']
|
||||
if "format" in model_parameters:
|
||||
data["format"] = model_parameters["format"]
|
||||
del model_parameters["format"]
|
||||
|
||||
data['options'] = model_parameters or {}
|
||||
if "keep_alive" in model_parameters:
|
||||
data["keep_alive"] = model_parameters["keep_alive"]
|
||||
del model_parameters["keep_alive"]
|
||||
|
||||
data["options"] = model_parameters or {}
|
||||
|
||||
if stop:
|
||||
data['stop'] = "\n".join(stop)
|
||||
data["stop"] = "\n".join(stop)
|
||||
|
||||
completion_type = LLMMode.value_of(credentials['mode'])
|
||||
completion_type = LLMMode.value_of(credentials["mode"])
|
||||
|
||||
if completion_type is LLMMode.CHAT:
|
||||
endpoint_url = urljoin(endpoint_url, 'api/chat')
|
||||
data['messages'] = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
|
||||
endpoint_url = urljoin(endpoint_url, "api/chat")
|
||||
data["messages"] = [
|
||||
self._convert_prompt_message_to_dict(m) for m in prompt_messages
|
||||
]
|
||||
else:
|
||||
endpoint_url = urljoin(endpoint_url, 'api/generate')
|
||||
endpoint_url = urljoin(endpoint_url, "api/generate")
|
||||
first_prompt_message = prompt_messages[0]
|
||||
if isinstance(first_prompt_message, UserPromptMessage):
|
||||
first_prompt_message = cast(UserPromptMessage, first_prompt_message)
|
||||
if isinstance(first_prompt_message.content, str):
|
||||
data['prompt'] = first_prompt_message.content
|
||||
data["prompt"] = first_prompt_message.content
|
||||
else:
|
||||
text = ''
|
||||
text = ""
|
||||
images = []
|
||||
for message_content in first_prompt_message.content:
|
||||
if message_content.type == PromptMessageContentType.TEXT:
|
||||
message_content = cast(TextPromptMessageContent, message_content)
|
||||
message_content = cast(
|
||||
TextPromptMessageContent, message_content
|
||||
)
|
||||
text = message_content.data
|
||||
elif message_content.type == PromptMessageContentType.IMAGE:
|
||||
message_content = cast(ImagePromptMessageContent, message_content)
|
||||
image_data = re.sub(r'^data:image\/[a-zA-Z]+;base64,', '', message_content.data)
|
||||
message_content = cast(
|
||||
ImagePromptMessageContent, message_content
|
||||
)
|
||||
image_data = re.sub(
|
||||
r"^data:image\/[a-zA-Z]+;base64,",
|
||||
"",
|
||||
message_content.data,
|
||||
)
|
||||
images.append(image_data)
|
||||
|
||||
data['prompt'] = text
|
||||
data['images'] = images
|
||||
data["prompt"] = text
|
||||
data["images"] = images
|
||||
|
||||
# send a post request to validate the credentials
|
||||
response = requests.post(
|
||||
endpoint_url,
|
||||
headers=headers,
|
||||
json=data,
|
||||
timeout=(10, 300),
|
||||
stream=stream
|
||||
endpoint_url, headers=headers, json=data, timeout=(10, 300), stream=stream
|
||||
)
|
||||
|
||||
response.encoding = "utf-8"
|
||||
if response.status_code != 200:
|
||||
raise InvokeError(f"API request failed with status code {response.status_code}: {response.text}")
|
||||
raise InvokeError(
|
||||
f"API request failed with status code {response.status_code}: {response.text}"
|
||||
)
|
||||
|
||||
if stream:
|
||||
return self._handle_generate_stream_response(model, credentials, completion_type, response, prompt_messages)
|
||||
return self._handle_generate_stream_response(
|
||||
model, credentials, completion_type, response, prompt_messages
|
||||
)
|
||||
|
||||
return self._handle_generate_response(model, credentials, completion_type, response, prompt_messages)
|
||||
return self._handle_generate_response(
|
||||
model, credentials, completion_type, response, prompt_messages
|
||||
)
|
||||
|
||||
def _handle_generate_response(self, model: str, credentials: dict, completion_type: LLMMode,
|
||||
response: requests.Response, prompt_messages: list[PromptMessage]) -> LLMResult:
|
||||
def _handle_generate_response(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
completion_type: LLMMode,
|
||||
response: requests.Response,
|
||||
prompt_messages: list[PromptMessage],
|
||||
) -> LLMResult:
|
||||
"""
|
||||
Handle llm completion response
|
||||
|
||||
@ -229,14 +275,14 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
|
||||
response_json = response.json()
|
||||
|
||||
if completion_type is LLMMode.CHAT:
|
||||
message = response_json.get('message', {})
|
||||
response_content = message.get('content', '')
|
||||
message = response_json.get("message", {})
|
||||
response_content = message.get("content", "")
|
||||
else:
|
||||
response_content = response_json['response']
|
||||
response_content = response_json["response"]
|
||||
|
||||
assistant_message = AssistantPromptMessage(content=response_content)
|
||||
|
||||
if 'prompt_eval_count' in response_json and 'eval_count' in response_json:
|
||||
if "prompt_eval_count" in response_json and "eval_count" in response_json:
|
||||
# transform usage
|
||||
prompt_tokens = response_json["prompt_eval_count"]
|
||||
completion_tokens = response_json["eval_count"]
|
||||
@ -246,7 +292,9 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
|
||||
completion_tokens = self._get_num_tokens_by_gpt2(assistant_message.content)
|
||||
|
||||
# transform usage
|
||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
||||
usage = self._calc_response_usage(
|
||||
model, credentials, prompt_tokens, completion_tokens
|
||||
)
|
||||
|
||||
# transform response
|
||||
result = LLMResult(
|
||||
@ -258,8 +306,14 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
|
||||
|
||||
return result
|
||||
|
||||
def _handle_generate_stream_response(self, model: str, credentials: dict, completion_type: LLMMode,
|
||||
response: requests.Response, prompt_messages: list[PromptMessage]) -> Generator:
|
||||
def _handle_generate_stream_response(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
completion_type: LLMMode,
|
||||
response: requests.Response,
|
||||
prompt_messages: list[PromptMessage],
|
||||
) -> Generator:
|
||||
"""
|
||||
Handle llm completion stream response
|
||||
|
||||
@ -270,17 +324,20 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
|
||||
:param prompt_messages: prompt messages
|
||||
:return: llm response chunk generator result
|
||||
"""
|
||||
full_text = ''
|
||||
full_text = ""
|
||||
chunk_index = 0
|
||||
|
||||
def create_final_llm_result_chunk(index: int, message: AssistantPromptMessage, finish_reason: str) \
|
||||
-> LLMResultChunk:
|
||||
def create_final_llm_result_chunk(
|
||||
index: int, message: AssistantPromptMessage, finish_reason: str
|
||||
) -> LLMResultChunk:
|
||||
# calculate num tokens
|
||||
prompt_tokens = self._get_num_tokens_by_gpt2(prompt_messages[0].content)
|
||||
completion_tokens = self._get_num_tokens_by_gpt2(full_text)
|
||||
|
||||
# transform usage
|
||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
||||
usage = self._calc_response_usage(
|
||||
model, credentials, prompt_tokens, completion_tokens
|
||||
)
|
||||
|
||||
return LLMResultChunk(
|
||||
model=model,
|
||||
@ -289,11 +346,11 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
|
||||
index=index,
|
||||
message=message,
|
||||
finish_reason=finish_reason,
|
||||
usage=usage
|
||||
)
|
||||
usage=usage,
|
||||
),
|
||||
)
|
||||
|
||||
for chunk in response.iter_lines(decode_unicode=True, delimiter='\n'):
|
||||
for chunk in response.iter_lines(decode_unicode=True, delimiter="\n"):
|
||||
if not chunk:
|
||||
continue
|
||||
|
||||
@ -304,7 +361,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
|
||||
yield create_final_llm_result_chunk(
|
||||
index=chunk_index,
|
||||
message=AssistantPromptMessage(content=""),
|
||||
finish_reason="Non-JSON encountered."
|
||||
finish_reason="Non-JSON encountered.",
|
||||
)
|
||||
|
||||
chunk_index += 1
|
||||
@ -314,55 +371,57 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
|
||||
if not chunk_json:
|
||||
continue
|
||||
|
||||
if 'message' not in chunk_json:
|
||||
text = ''
|
||||
if "message" not in chunk_json:
|
||||
text = ""
|
||||
else:
|
||||
text = chunk_json.get('message').get('content', '')
|
||||
text = chunk_json.get("message").get("content", "")
|
||||
else:
|
||||
if not chunk_json:
|
||||
continue
|
||||
|
||||
# transform assistant message to prompt message
|
||||
text = chunk_json['response']
|
||||
text = chunk_json["response"]
|
||||
|
||||
assistant_prompt_message = AssistantPromptMessage(
|
||||
content=text
|
||||
)
|
||||
assistant_prompt_message = AssistantPromptMessage(content=text)
|
||||
|
||||
full_text += text
|
||||
|
||||
if chunk_json['done']:
|
||||
if chunk_json["done"]:
|
||||
# calculate num tokens
|
||||
if 'prompt_eval_count' in chunk_json and 'eval_count' in chunk_json:
|
||||
if "prompt_eval_count" in chunk_json and "eval_count" in chunk_json:
|
||||
# transform usage
|
||||
prompt_tokens = chunk_json["prompt_eval_count"]
|
||||
completion_tokens = chunk_json["eval_count"]
|
||||
else:
|
||||
# calculate num tokens
|
||||
prompt_tokens = self._get_num_tokens_by_gpt2(prompt_messages[0].content)
|
||||
prompt_tokens = self._get_num_tokens_by_gpt2(
|
||||
prompt_messages[0].content
|
||||
)
|
||||
completion_tokens = self._get_num_tokens_by_gpt2(full_text)
|
||||
|
||||
# transform usage
|
||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
||||
usage = self._calc_response_usage(
|
||||
model, credentials, prompt_tokens, completion_tokens
|
||||
)
|
||||
|
||||
yield LLMResultChunk(
|
||||
model=chunk_json['model'],
|
||||
model=chunk_json["model"],
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=chunk_index,
|
||||
message=assistant_prompt_message,
|
||||
finish_reason='stop',
|
||||
usage=usage
|
||||
)
|
||||
finish_reason="stop",
|
||||
usage=usage,
|
||||
),
|
||||
)
|
||||
else:
|
||||
yield LLMResultChunk(
|
||||
model=chunk_json['model'],
|
||||
model=chunk_json["model"],
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=chunk_index,
|
||||
message=assistant_prompt_message,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
chunk_index += 1
|
||||
@ -376,15 +435,21 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
|
||||
if isinstance(message.content, str):
|
||||
message_dict = {"role": "user", "content": message.content}
|
||||
else:
|
||||
text = ''
|
||||
text = ""
|
||||
images = []
|
||||
for message_content in message.content:
|
||||
if message_content.type == PromptMessageContentType.TEXT:
|
||||
message_content = cast(TextPromptMessageContent, message_content)
|
||||
message_content = cast(
|
||||
TextPromptMessageContent, message_content
|
||||
)
|
||||
text = message_content.data
|
||||
elif message_content.type == PromptMessageContentType.IMAGE:
|
||||
message_content = cast(ImagePromptMessageContent, message_content)
|
||||
image_data = re.sub(r'^data:image\/[a-zA-Z]+;base64,', '', message_content.data)
|
||||
message_content = cast(
|
||||
ImagePromptMessageContent, message_content
|
||||
)
|
||||
image_data = re.sub(
|
||||
r"^data:image\/[a-zA-Z]+;base64,", "", message_content.data
|
||||
)
|
||||
images.append(image_data)
|
||||
|
||||
message_dict = {"role": "user", "content": text, "images": images}
|
||||
@ -414,7 +479,9 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
|
||||
|
||||
return num_tokens
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
|
||||
def get_customizable_model_schema(
|
||||
self, model: str, credentials: dict
|
||||
) -> AIModelEntity:
|
||||
"""
|
||||
Get customizable model schema.
|
||||
|
||||
@ -425,20 +492,19 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
|
||||
"""
|
||||
extras = {}
|
||||
|
||||
if 'vision_support' in credentials and credentials['vision_support'] == 'true':
|
||||
extras['features'] = [ModelFeature.VISION]
|
||||
if "vision_support" in credentials and credentials["vision_support"] == "true":
|
||||
extras["features"] = [ModelFeature.VISION]
|
||||
|
||||
entity = AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(
|
||||
zh_Hans=model,
|
||||
en_US=model
|
||||
),
|
||||
label=I18nObject(zh_Hans=model, en_US=model),
|
||||
model_type=ModelType.LLM,
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_properties={
|
||||
ModelPropertyKey.MODE: credentials.get('mode'),
|
||||
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', 4096)),
|
||||
ModelPropertyKey.MODE: credentials.get("mode"),
|
||||
ModelPropertyKey.CONTEXT_SIZE: int(
|
||||
credentials.get("context_size", 4096)
|
||||
),
|
||||
},
|
||||
parameter_rules=[
|
||||
ParameterRule(
|
||||
@ -446,91 +512,111 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
|
||||
use_template=DefaultParameterName.TEMPERATURE.value,
|
||||
label=I18nObject(en_US="Temperature"),
|
||||
type=ParameterType.FLOAT,
|
||||
help=I18nObject(en_US="The temperature of the model. "
|
||||
"Increasing the temperature will make the model answer "
|
||||
"more creatively. (Default: 0.8)"),
|
||||
help=I18nObject(
|
||||
en_US="The temperature of the model. "
|
||||
"Increasing the temperature will make the model answer "
|
||||
"more creatively. (Default: 0.8)"
|
||||
),
|
||||
default=0.1,
|
||||
min=0,
|
||||
max=1
|
||||
max=1,
|
||||
),
|
||||
ParameterRule(
|
||||
name=DefaultParameterName.TOP_P.value,
|
||||
use_template=DefaultParameterName.TOP_P.value,
|
||||
label=I18nObject(en_US="Top P"),
|
||||
type=ParameterType.FLOAT,
|
||||
help=I18nObject(en_US="Works together with top-k. A higher value (e.g., 0.95) will lead to "
|
||||
"more diverse text, while a lower value (e.g., 0.5) will generate more "
|
||||
"focused and conservative text. (Default: 0.9)"),
|
||||
help=I18nObject(
|
||||
en_US="Works together with top-k. A higher value (e.g., 0.95) will lead to "
|
||||
"more diverse text, while a lower value (e.g., 0.5) will generate more "
|
||||
"focused and conservative text. (Default: 0.9)"
|
||||
),
|
||||
default=0.9,
|
||||
min=0,
|
||||
max=1
|
||||
max=1,
|
||||
),
|
||||
ParameterRule(
|
||||
name="top_k",
|
||||
label=I18nObject(en_US="Top K"),
|
||||
type=ParameterType.INT,
|
||||
help=I18nObject(en_US="Reduces the probability of generating nonsense. "
|
||||
"A higher value (e.g. 100) will give more diverse answers, "
|
||||
"while a lower value (e.g. 10) will be more conservative. (Default: 40)"),
|
||||
help=I18nObject(
|
||||
en_US="Reduces the probability of generating nonsense. "
|
||||
"A higher value (e.g. 100) will give more diverse answers, "
|
||||
"while a lower value (e.g. 10) will be more conservative. (Default: 40)"
|
||||
),
|
||||
min=1,
|
||||
max=100
|
||||
max=100,
|
||||
),
|
||||
ParameterRule(
|
||||
name='repeat_penalty',
|
||||
name="repeat_penalty",
|
||||
label=I18nObject(en_US="Repeat Penalty"),
|
||||
type=ParameterType.FLOAT,
|
||||
help=I18nObject(en_US="Sets how strongly to penalize repetitions. "
|
||||
"A higher value (e.g., 1.5) will penalize repetitions more strongly, "
|
||||
"while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1)"),
|
||||
help=I18nObject(
|
||||
en_US="Sets how strongly to penalize repetitions. "
|
||||
"A higher value (e.g., 1.5) will penalize repetitions more strongly, "
|
||||
"while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1)"
|
||||
),
|
||||
min=-2,
|
||||
max=2
|
||||
max=2,
|
||||
),
|
||||
ParameterRule(
|
||||
name='num_predict',
|
||||
use_template='max_tokens',
|
||||
name="num_predict",
|
||||
use_template="max_tokens",
|
||||
label=I18nObject(en_US="Num Predict"),
|
||||
type=ParameterType.INT,
|
||||
help=I18nObject(en_US="Maximum number of tokens to predict when generating text. "
|
||||
"(Default: 128, -1 = infinite generation, -2 = fill context)"),
|
||||
default=512 if int(credentials.get('max_tokens', 4096)) >= 768 else 128,
|
||||
help=I18nObject(
|
||||
en_US="Maximum number of tokens to predict when generating text. "
|
||||
"(Default: 128, -1 = infinite generation, -2 = fill context)"
|
||||
),
|
||||
default=(
|
||||
512 if int(credentials.get("max_tokens", 4096)) >= 768 else 128
|
||||
),
|
||||
min=-2,
|
||||
max=int(credentials.get('max_tokens', 4096)),
|
||||
max=int(credentials.get("max_tokens", 4096)),
|
||||
),
|
||||
ParameterRule(
|
||||
name='mirostat',
|
||||
name="mirostat",
|
||||
label=I18nObject(en_US="Mirostat sampling"),
|
||||
type=ParameterType.INT,
|
||||
help=I18nObject(en_US="Enable Mirostat sampling for controlling perplexity. "
|
||||
"(default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)"),
|
||||
help=I18nObject(
|
||||
en_US="Enable Mirostat sampling for controlling perplexity. "
|
||||
"(default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)"
|
||||
),
|
||||
min=0,
|
||||
max=2
|
||||
max=2,
|
||||
),
|
||||
ParameterRule(
|
||||
name='mirostat_eta',
|
||||
name="mirostat_eta",
|
||||
label=I18nObject(en_US="Mirostat Eta"),
|
||||
type=ParameterType.FLOAT,
|
||||
help=I18nObject(en_US="Influences how quickly the algorithm responds to feedback from "
|
||||
"the generated text. A lower learning rate will result in slower adjustments, "
|
||||
"while a higher learning rate will make the algorithm more responsive. "
|
||||
"(Default: 0.1)"),
|
||||
precision=1
|
||||
help=I18nObject(
|
||||
en_US="Influences how quickly the algorithm responds to feedback from "
|
||||
"the generated text. A lower learning rate will result in slower adjustments, "
|
||||
"while a higher learning rate will make the algorithm more responsive. "
|
||||
"(Default: 0.1)"
|
||||
),
|
||||
precision=1,
|
||||
),
|
||||
ParameterRule(
|
||||
name='mirostat_tau',
|
||||
name="mirostat_tau",
|
||||
label=I18nObject(en_US="Mirostat Tau"),
|
||||
type=ParameterType.FLOAT,
|
||||
help=I18nObject(en_US="Controls the balance between coherence and diversity of the output. "
|
||||
"A lower value will result in more focused and coherent text. (Default: 5.0)"),
|
||||
precision=1
|
||||
help=I18nObject(
|
||||
en_US="Controls the balance between coherence and diversity of the output. "
|
||||
"A lower value will result in more focused and coherent text. (Default: 5.0)"
|
||||
),
|
||||
precision=1,
|
||||
),
|
||||
ParameterRule(
|
||||
name='num_ctx',
|
||||
name="num_ctx",
|
||||
label=I18nObject(en_US="Size of context window"),
|
||||
type=ParameterType.INT,
|
||||
help=I18nObject(en_US="Sets the size of the context window used to generate the next token. "
|
||||
"(Default: 2048)"),
|
||||
help=I18nObject(
|
||||
en_US="Sets the size of the context window used to generate the next token. "
|
||||
"(Default: 2048)"
|
||||
),
|
||||
default=2048,
|
||||
min=1
|
||||
min=1,
|
||||
),
|
||||
ParameterRule(
|
||||
name='num_gpu',
|
||||
@ -544,56 +630,77 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
|
||||
default=1
|
||||
),
|
||||
ParameterRule(
|
||||
name='num_thread',
|
||||
name="num_thread",
|
||||
label=I18nObject(en_US="Num Thread"),
|
||||
type=ParameterType.INT,
|
||||
help=I18nObject(en_US="Sets the number of threads to use during computation. "
|
||||
"By default, Ollama will detect this for optimal performance. "
|
||||
"It is recommended to set this value to the number of physical CPU cores "
|
||||
"your system has (as opposed to the logical number of cores)."),
|
||||
help=I18nObject(
|
||||
en_US="Sets the number of threads to use during computation. "
|
||||
"By default, Ollama will detect this for optimal performance. "
|
||||
"It is recommended to set this value to the number of physical CPU cores "
|
||||
"your system has (as opposed to the logical number of cores)."
|
||||
),
|
||||
min=1,
|
||||
),
|
||||
ParameterRule(
|
||||
name='repeat_last_n',
|
||||
name="repeat_last_n",
|
||||
label=I18nObject(en_US="Repeat last N"),
|
||||
type=ParameterType.INT,
|
||||
help=I18nObject(en_US="Sets how far back for the model to look back to prevent repetition. "
|
||||
"(Default: 64, 0 = disabled, -1 = num_ctx)"),
|
||||
min=-1
|
||||
help=I18nObject(
|
||||
en_US="Sets how far back for the model to look back to prevent repetition. "
|
||||
"(Default: 64, 0 = disabled, -1 = num_ctx)"
|
||||
),
|
||||
min=-1,
|
||||
),
|
||||
ParameterRule(
|
||||
name='tfs_z',
|
||||
name="tfs_z",
|
||||
label=I18nObject(en_US="TFS Z"),
|
||||
type=ParameterType.FLOAT,
|
||||
help=I18nObject(en_US="Tail free sampling is used to reduce the impact of less probable tokens "
|
||||
"from the output. A higher value (e.g., 2.0) will reduce the impact more, "
|
||||
"while a value of 1.0 disables this setting. (default: 1)"),
|
||||
precision=1
|
||||
help=I18nObject(
|
||||
en_US="Tail free sampling is used to reduce the impact of less probable tokens "
|
||||
"from the output. A higher value (e.g., 2.0) will reduce the impact more, "
|
||||
"while a value of 1.0 disables this setting. (default: 1)"
|
||||
),
|
||||
precision=1,
|
||||
),
|
||||
ParameterRule(
|
||||
name='seed',
|
||||
name="seed",
|
||||
label=I18nObject(en_US="Seed"),
|
||||
type=ParameterType.INT,
|
||||
help=I18nObject(en_US="Sets the random number seed to use for generation. Setting this to "
|
||||
"a specific number will make the model generate the same text for "
|
||||
"the same prompt. (Default: 0)"),
|
||||
help=I18nObject(
|
||||
en_US="Sets the random number seed to use for generation. Setting this to "
|
||||
"a specific number will make the model generate the same text for "
|
||||
"the same prompt. (Default: 0)"
|
||||
),
|
||||
),
|
||||
ParameterRule(
|
||||
name='format',
|
||||
name="keep_alive",
|
||||
label=I18nObject(en_US="Keep Alive"),
|
||||
type=ParameterType.STRING,
|
||||
help=I18nObject(
|
||||
en_US="Sets how long the model is kept in memory after generating a response. "
|
||||
"This must be a duration string with a unit (e.g., '10m' for 10 minutes or '24h' for 24 hours). "
|
||||
"A negative number keeps the model loaded indefinitely, and '0' unloads the model immediately after generating a response. "
|
||||
"Valid time units are 's','m','h'. (Default: 5m)"
|
||||
),
|
||||
),
|
||||
ParameterRule(
|
||||
name="format",
|
||||
label=I18nObject(en_US="Format"),
|
||||
type=ParameterType.STRING,
|
||||
help=I18nObject(en_US="the format to return a response in."
|
||||
" Currently the only accepted value is json."),
|
||||
options=['json'],
|
||||
)
|
||||
help=I18nObject(
|
||||
en_US="the format to return a response in."
|
||||
" Currently the only accepted value is json."
|
||||
),
|
||||
options=["json"],
|
||||
),
|
||||
],
|
||||
pricing=PriceConfig(
|
||||
input=Decimal(credentials.get('input_price', 0)),
|
||||
output=Decimal(credentials.get('output_price', 0)),
|
||||
unit=Decimal(credentials.get('unit', 0)),
|
||||
currency=credentials.get('currency', "USD")
|
||||
input=Decimal(credentials.get("input_price", 0)),
|
||||
output=Decimal(credentials.get("output_price", 0)),
|
||||
unit=Decimal(credentials.get("unit", 0)),
|
||||
currency=credentials.get("currency", "USD"),
|
||||
),
|
||||
**extras
|
||||
**extras,
|
||||
)
|
||||
|
||||
return entity
|
||||
@ -621,10 +728,10 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
|
||||
],
|
||||
InvokeServerUnavailableError: [
|
||||
requests.exceptions.ConnectionError, # Engine Overloaded
|
||||
requests.exceptions.HTTPError # Server Error
|
||||
requests.exceptions.HTTPError, # Server Error
|
||||
],
|
||||
InvokeConnectionError: [
|
||||
requests.exceptions.ConnectTimeout, # Timeout
|
||||
requests.exceptions.ReadTimeout # Timeout
|
||||
]
|
||||
requests.exceptions.ReadTimeout, # Timeout
|
||||
],
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user