mirror of
https://gitee.com/dify_ai/dify.git
synced 2024-12-02 11:18:19 +08:00
feat: Add fireworks custom llm intergration (#9333)
This commit is contained in:
parent
a8134a49c4
commit
aba70207ab
@ -18,6 +18,7 @@ supported_model_types:
|
||||
- text-embedding
|
||||
configurate_methods:
|
||||
- predefined-model
|
||||
- customizable-model
|
||||
provider_credential_schema:
|
||||
credential_form_schemas:
|
||||
- variable: fireworks_api_key
|
||||
@ -28,3 +29,75 @@ provider_credential_schema:
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的 API Key
|
||||
en_US: Enter your API Key
|
||||
model_credential_schema:
|
||||
model:
|
||||
label:
|
||||
en_US: Model URL
|
||||
zh_Hans: 模型URL
|
||||
placeholder:
|
||||
en_US: Enter your Model URL
|
||||
zh_Hans: 输入模型URL
|
||||
credential_form_schemas:
|
||||
- variable: model_label_zh_Hanns
|
||||
label:
|
||||
zh_Hans: 模型中文名称
|
||||
en_US: The zh_Hans of Model
|
||||
required: true
|
||||
type: text-input
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的模型中文名称
|
||||
en_US: Enter your zh_Hans of Model
|
||||
- variable: model_label_en_US
|
||||
label:
|
||||
zh_Hans: 模型英文名称
|
||||
en_US: The en_US of Model
|
||||
required: true
|
||||
type: text-input
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的模型英文名称
|
||||
en_US: Enter your en_US of Model
|
||||
- variable: fireworks_api_key
|
||||
label:
|
||||
en_US: API Key
|
||||
type: secret-input
|
||||
required: true
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的 API Key
|
||||
en_US: Enter your API Key
|
||||
- variable: context_size
|
||||
label:
|
||||
zh_Hans: 模型上下文长度
|
||||
en_US: Model context size
|
||||
required: true
|
||||
type: text-input
|
||||
default: '4096'
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的模型上下文长度
|
||||
en_US: Enter your Model context size
|
||||
- variable: max_tokens
|
||||
label:
|
||||
zh_Hans: 最大 token 上限
|
||||
en_US: Upper bound for max tokens
|
||||
default: '4096'
|
||||
type: text-input
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
- variable: function_calling_type
|
||||
label:
|
||||
en_US: Function calling
|
||||
type: select
|
||||
required: false
|
||||
default: no_call
|
||||
options:
|
||||
- value: no_call
|
||||
label:
|
||||
en_US: Not Support
|
||||
zh_Hans: 不支持
|
||||
- value: function_call
|
||||
label:
|
||||
en_US: Support
|
||||
zh_Hans: 支持
|
||||
show_on:
|
||||
- variable: __model_type
|
||||
value: llm
|
||||
|
@ -8,7 +8,8 @@ from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, Cho
|
||||
from openai.types.chat.chat_completion_message import FunctionCall
|
||||
|
||||
from core.model_runtime.callbacks.base_callback import Callback
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
ImagePromptMessageContent,
|
||||
@ -20,6 +21,15 @@ from core.model_runtime.entities.message_entities import (
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import (
|
||||
AIModelEntity,
|
||||
FetchFrom,
|
||||
ModelFeature,
|
||||
ModelPropertyKey,
|
||||
ModelType,
|
||||
ParameterRule,
|
||||
ParameterType,
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.model_providers.fireworks._common import _CommonFireworks
|
||||
@ -608,3 +618,50 @@ class FireworksLargeLanguageModel(_CommonFireworks, LargeLanguageModel):
|
||||
num_tokens += self._get_num_tokens_by_gpt2(required_field)
|
||||
|
||||
return num_tokens
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
|
||||
return AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(
|
||||
en_US=credentials.get("model_label_en_US", model),
|
||||
zh_Hans=credentials.get("model_label_zh_Hanns", model),
|
||||
),
|
||||
model_type=ModelType.LLM,
|
||||
features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL]
|
||||
if credentials.get("function_calling_type") == "function_call"
|
||||
else [],
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_properties={
|
||||
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 4096)),
|
||||
ModelPropertyKey.MODE: LLMMode.CHAT.value,
|
||||
},
|
||||
parameter_rules=[
|
||||
ParameterRule(
|
||||
name="temperature",
|
||||
use_template="temperature",
|
||||
label=I18nObject(en_US="Temperature", zh_Hans="温度"),
|
||||
type=ParameterType.FLOAT,
|
||||
),
|
||||
ParameterRule(
|
||||
name="max_tokens",
|
||||
use_template="max_tokens",
|
||||
default=512,
|
||||
min=1,
|
||||
max=int(credentials.get("max_tokens", 4096)),
|
||||
label=I18nObject(en_US="Max Tokens", zh_Hans="最大标记"),
|
||||
type=ParameterType.INT,
|
||||
),
|
||||
ParameterRule(
|
||||
name="top_p",
|
||||
use_template="top_p",
|
||||
label=I18nObject(en_US="Top P", zh_Hans="Top P"),
|
||||
type=ParameterType.FLOAT,
|
||||
),
|
||||
ParameterRule(
|
||||
name="top_k",
|
||||
use_template="top_k",
|
||||
label=I18nObject(en_US="Top K", zh_Hans="Top K"),
|
||||
type=ParameterType.FLOAT,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user