From aba70207ab2245467bc29a60e6a2ae7663adc703 Mon Sep 17 00:00:00 2001 From: ice yao Date: Mon, 14 Oct 2024 22:50:31 +0800 Subject: [PATCH] feat: Add fireworks custom llm intergration (#9333) --- .../model_providers/fireworks/fireworks.yaml | 73 +++++++++++++++++++ .../model_providers/fireworks/llm/llm.py | 59 ++++++++++++++- 2 files changed, 131 insertions(+), 1 deletion(-) diff --git a/api/core/model_runtime/model_providers/fireworks/fireworks.yaml b/api/core/model_runtime/model_providers/fireworks/fireworks.yaml index cdb87a55e..ddbaa54eb 100644 --- a/api/core/model_runtime/model_providers/fireworks/fireworks.yaml +++ b/api/core/model_runtime/model_providers/fireworks/fireworks.yaml @@ -18,6 +18,7 @@ supported_model_types: - text-embedding configurate_methods: - predefined-model + - customizable-model provider_credential_schema: credential_form_schemas: - variable: fireworks_api_key @@ -28,3 +29,75 @@ provider_credential_schema: placeholder: zh_Hans: 在此输入您的 API Key en_US: Enter your API Key +model_credential_schema: + model: + label: + en_US: Model URL + zh_Hans: 模型URL + placeholder: + en_US: Enter your Model URL + zh_Hans: 输入模型URL + credential_form_schemas: + - variable: model_label_zh_Hanns + label: + zh_Hans: 模型中文名称 + en_US: The zh_Hans of Model + required: true + type: text-input + placeholder: + zh_Hans: 在此输入您的模型中文名称 + en_US: Enter your zh_Hans of Model + - variable: model_label_en_US + label: + zh_Hans: 模型英文名称 + en_US: The en_US of Model + required: true + type: text-input + placeholder: + zh_Hans: 在此输入您的模型英文名称 + en_US: Enter your en_US of Model + - variable: fireworks_api_key + label: + en_US: API Key + type: secret-input + required: true + placeholder: + zh_Hans: 在此输入您的 API Key + en_US: Enter your API Key + - variable: context_size + label: + zh_Hans: 模型上下文长度 + en_US: Model context size + required: true + type: text-input + default: '4096' + placeholder: + zh_Hans: 在此输入您的模型上下文长度 + en_US: Enter your Model context size + - variable: max_tokens + label: + zh_Hans: 最大 token 上限 + en_US: Upper bound for max tokens + default: '4096' + type: text-input + show_on: + - variable: __model_type + value: llm + - variable: function_calling_type + label: + en_US: Function calling + type: select + required: false + default: no_call + options: + - value: no_call + label: + en_US: Not Support + zh_Hans: 不支持 + - value: function_call + label: + en_US: Support + zh_Hans: 支持 + show_on: + - variable: __model_type + value: llm diff --git a/api/core/model_runtime/model_providers/fireworks/llm/llm.py b/api/core/model_runtime/model_providers/fireworks/llm/llm.py index 24aad9c4d..ffe1ad5fc 100644 --- a/api/core/model_runtime/model_providers/fireworks/llm/llm.py +++ b/api/core/model_runtime/model_providers/fireworks/llm/llm.py @@ -8,7 +8,8 @@ from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, Cho from openai.types.chat.chat_completion_message import FunctionCall from core.model_runtime.callbacks.base_callback import Callback -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, ImagePromptMessageContent, @@ -20,6 +21,15 @@ from core.model_runtime.entities.message_entities import ( ToolPromptMessage, UserPromptMessage, ) +from core.model_runtime.entities.model_entities import ( + AIModelEntity, + FetchFrom, + ModelFeature, + ModelPropertyKey, + ModelType, + ParameterRule, + ParameterType, +) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.fireworks._common import _CommonFireworks @@ -608,3 +618,50 @@ class FireworksLargeLanguageModel(_CommonFireworks, LargeLanguageModel): num_tokens += self._get_num_tokens_by_gpt2(required_field) return num_tokens + + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: + return AIModelEntity( + model=model, + label=I18nObject( + en_US=credentials.get("model_label_en_US", model), + zh_Hans=credentials.get("model_label_zh_Hanns", model), + ), + model_type=ModelType.LLM, + features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL] + if credentials.get("function_calling_type") == "function_call" + else [], + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_properties={ + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 4096)), + ModelPropertyKey.MODE: LLMMode.CHAT.value, + }, + parameter_rules=[ + ParameterRule( + name="temperature", + use_template="temperature", + label=I18nObject(en_US="Temperature", zh_Hans="温度"), + type=ParameterType.FLOAT, + ), + ParameterRule( + name="max_tokens", + use_template="max_tokens", + default=512, + min=1, + max=int(credentials.get("max_tokens", 4096)), + label=I18nObject(en_US="Max Tokens", zh_Hans="最大标记"), + type=ParameterType.INT, + ), + ParameterRule( + name="top_p", + use_template="top_p", + label=I18nObject(en_US="Top P", zh_Hans="Top P"), + type=ParameterType.FLOAT, + ), + ParameterRule( + name="top_k", + use_template="top_k", + label=I18nObject(en_US="Top K", zh_Hans="Top K"), + type=ParameterType.FLOAT, + ), + ], + )