feat: add LocalAI local embedding model support (#1021)

Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
This commit is contained in:
takatost 2023-08-29 22:22:02 +08:00 committed by GitHub
parent b5953039de
commit 417c19577a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 1144 additions and 7 deletions

View File

@ -63,6 +63,9 @@ class ModelProviderFactory:
elif provider_name == 'openllm':
from core.model_providers.providers.openllm_provider import OpenLLMProvider
return OpenLLMProvider
elif provider_name == 'localai':
from core.model_providers.providers.localai_provider import LocalAIProvider
return LocalAIProvider
else:
raise NotImplementedError

View File

@ -0,0 +1,29 @@
from langchain.embeddings import LocalAIEmbeddings
from replicate.exceptions import ModelError, ReplicateError
from core.model_providers.error import LLMBadRequestError
from core.model_providers.providers.base import BaseModelProvider
from core.model_providers.models.embedding.base import BaseEmbedding
class LocalAIEmbedding(BaseEmbedding):
def __init__(self, model_provider: BaseModelProvider, name: str):
credentials = model_provider.get_model_credentials(
model_name=name,
model_type=self.type
)
client = LocalAIEmbeddings(
model=name,
openai_api_key="1",
openai_api_base=credentials['server_url'],
)
super().__init__(model_provider, client, name)
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, (ModelError, ReplicateError)):
return LLMBadRequestError(f"LocalAI embedding: {str(ex)}")
else:
return ex

View File

@ -0,0 +1,131 @@
import logging
from typing import List, Optional, Any
import openai
from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult, get_buffer_string
from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
LLMRateLimitError, LLMAuthorizationError
from core.model_providers.providers.base import BaseModelProvider
from core.third_party.langchain.llms.chat_open_ai import EnhanceChatOpenAI
from core.third_party.langchain.llms.open_ai import EnhanceOpenAI
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
class LocalAIModel(BaseLLM):
def __init__(self, model_provider: BaseModelProvider,
name: str,
model_kwargs: ModelKwargs,
streaming: bool = False,
callbacks: Callbacks = None):
credentials = model_provider.get_model_credentials(
model_name=name,
model_type=self.type
)
if credentials['completion_type'] == 'chat_completion':
self.model_mode = ModelMode.CHAT
else:
self.model_mode = ModelMode.COMPLETION
super().__init__(model_provider, name, model_kwargs, streaming, callbacks)
def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
if self.model_mode == ModelMode.COMPLETION:
client = EnhanceOpenAI(
model_name=self.name,
streaming=self.streaming,
callbacks=self.callbacks,
request_timeout=60,
openai_api_key="1",
openai_api_base=self.credentials['server_url'] + '/v1',
**provider_model_kwargs
)
else:
extra_model_kwargs = {
'top_p': provider_model_kwargs.get('top_p')
}
client = EnhanceChatOpenAI(
model_name=self.name,
temperature=provider_model_kwargs.get('temperature'),
max_tokens=provider_model_kwargs.get('max_tokens'),
model_kwargs=extra_model_kwargs,
streaming=self.streaming,
callbacks=self.callbacks,
request_timeout=60,
openai_api_key="1",
openai_api_base=self.credentials['server_url'] + '/v1'
)
return client
def _run(self, messages: List[PromptMessage],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs) -> LLMResult:
"""
run predict by prompt messages and stop words.
:param messages:
:param stop:
:param callbacks:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return self._client.generate([prompts], stop, callbacks)
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
"""
get num tokens of prompt messages.
:param messages:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
if isinstance(prompts, str):
return self._client.get_num_tokens(prompts)
else:
return max(sum([self._client.get_num_tokens(get_buffer_string([m])) for m in prompts]) - len(prompts), 0)
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
if self.model_mode == ModelMode.COMPLETION:
for k, v in provider_model_kwargs.items():
if hasattr(self.client, k):
setattr(self.client, k, v)
else:
extra_model_kwargs = {
'top_p': provider_model_kwargs.get('top_p')
}
self.client.temperature = provider_model_kwargs.get('temperature')
self.client.max_tokens = provider_model_kwargs.get('max_tokens')
self.client.model_kwargs = extra_model_kwargs
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, openai.error.InvalidRequestError):
logging.warning("Invalid request to LocalAI API.")
return LLMBadRequestError(str(ex))
elif isinstance(ex, openai.error.APIConnectionError):
logging.warning("Failed to connect to LocalAI API.")
return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex))
elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
logging.warning("LocalAI service unavailable.")
return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex))
elif isinstance(ex, openai.error.RateLimitError):
return LLMRateLimitError(str(ex))
elif isinstance(ex, openai.error.AuthenticationError):
return LLMAuthorizationError(str(ex))
elif isinstance(ex, openai.error.OpenAIError):
return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
else:
return ex
@classmethod
def support_streaming(cls):
return True

View File

@ -0,0 +1,164 @@
import json
from typing import Type
from langchain.embeddings import LocalAIEmbeddings
from langchain.schema import HumanMessage
from core.helper import encrypter
from core.model_providers.models.embedding.localai_embedding import LocalAIEmbedding
from core.model_providers.models.entity.model_params import ModelKwargsRules, ModelType, KwargRule
from core.model_providers.models.llm.localai_model import LocalAIModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.model_providers.models.base import BaseProviderModel
from core.third_party.langchain.llms.chat_open_ai import EnhanceChatOpenAI
from core.third_party.langchain.llms.open_ai import EnhanceOpenAI
from models.provider import ProviderType
class LocalAIProvider(BaseModelProvider):
@property
def provider_name(self):
"""
Returns the name of a provider.
"""
return 'localai'
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
return []
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
:param model_type:
:return:
"""
if model_type == ModelType.TEXT_GENERATION:
model_class = LocalAIModel
elif model_type == ModelType.EMBEDDINGS:
model_class = LocalAIEmbedding
else:
raise NotImplementedError
return model_class
def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
"""
get model parameter rules.
:param model_name:
:param model_type:
:return:
"""
return ModelKwargsRules(
temperature=KwargRule[float](min=0, max=2, default=0.7),
top_p=KwargRule[float](min=0, max=1, default=1),
max_tokens=KwargRule[int](min=10, max=4097, default=16),
)
@classmethod
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
"""
check model credentials valid.
:param model_name:
:param model_type:
:param credentials:
"""
if 'server_url' not in credentials:
raise CredentialsValidateFailedError('LocalAI Server URL must be provided.')
try:
if model_type == ModelType.EMBEDDINGS:
model = LocalAIEmbeddings(
model=model_name,
openai_api_key='1',
openai_api_base=credentials['server_url']
)
model.embed_query("ping")
else:
if ('completion_type' not in credentials
or credentials['completion_type'] not in ['completion', 'chat_completion']):
raise CredentialsValidateFailedError('LocalAI Completion Type must be provided.')
if credentials['completion_type'] == 'chat_completion':
model = EnhanceChatOpenAI(
model_name=model_name,
openai_api_key='1',
openai_api_base=credentials['server_url'] + '/v1',
max_tokens=10,
request_timeout=60,
)
model([HumanMessage(content='ping')])
else:
model = EnhanceOpenAI(
model_name=model_name,
openai_api_key='1',
openai_api_base=credentials['server_url'] + '/v1',
max_tokens=10,
request_timeout=60,
)
model('ping')
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@classmethod
def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
credentials: dict) -> dict:
"""
encrypt model credentials for save.
:param tenant_id:
:param model_name:
:param model_type:
:param credentials:
:return:
"""
credentials['server_url'] = encrypter.encrypt_token(tenant_id, credentials['server_url'])
return credentials
def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
"""
get credentials for llm use.
:param model_name:
:param model_type:
:param obfuscated:
:return:
"""
if self.provider.provider_type != ProviderType.CUSTOM.value:
raise NotImplementedError
provider_model = self._get_provider_model(model_name, model_type)
if not provider_model.encrypted_config:
return {
'server_url': None,
}
credentials = json.loads(provider_model.encrypted_config)
if credentials['server_url']:
credentials['server_url'] = encrypter.decrypt_token(
self.provider.tenant_id,
credentials['server_url']
)
if obfuscated:
credentials['server_url'] = encrypter.obfuscated_token(credentials['server_url'])
return credentials
@classmethod
def is_provider_credentials_valid_or_raise(cls, credentials: dict):
return
@classmethod
def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
return {}
def get_provider_credentials(self, obfuscated: bool = False) -> dict:
return {}

View File

@ -10,5 +10,6 @@
"replicate",
"huggingface_hub",
"xinference",
"openllm"
"openllm",
"localai"
]

View File

@ -0,0 +1,7 @@
{
"support_provider_types": [
"custom"
],
"system_config": null,
"model_flexibility": "configurable"
}

View File

@ -42,7 +42,8 @@ class EnhanceChatOpenAI(ChatOpenAI):
return {
**super()._default_params,
"api_type": 'openai',
"api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
"api_base": self.openai_api_base if self.openai_api_base
else os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
"api_version": None,
"api_key": self.openai_api_key,
"organization": self.openai_organization if self.openai_organization else None,

View File

@ -1,7 +1,10 @@
import os
from typing import Dict, Any, Mapping, Optional, Union, Tuple
from typing import Dict, Any, Mapping, Optional, Union, Tuple, List, Iterator
from langchain import OpenAI
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.openai import completion_with_retry, _stream_response_to_generation_chunk
from langchain.schema.output import GenerationChunk
from pydantic import root_validator
@ -33,7 +36,8 @@ class EnhanceOpenAI(OpenAI):
def _invocation_params(self) -> Dict[str, Any]:
return {**super()._invocation_params, **{
"api_type": 'openai',
"api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
"api_base": self.openai_api_base if self.openai_api_base
else os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
"api_version": None,
"api_key": self.openai_api_key,
"organization": self.openai_organization if self.openai_organization else None,
@ -43,8 +47,33 @@ class EnhanceOpenAI(OpenAI):
def _identifying_params(self) -> Mapping[str, Any]:
return {**super()._identifying_params, **{
"api_type": 'openai',
"api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
"api_base": self.openai_api_base if self.openai_api_base
else os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
"api_version": None,
"api_key": self.openai_api_key,
"organization": self.openai_organization if self.openai_organization else None,
}}
def _stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
params = {**self._invocation_params, **kwargs, "stream": True}
self.get_sub_prompts(params, [prompt], stop) # this mutates params
for stream_resp in completion_with_retry(
self, prompt=prompt, run_manager=run_manager, **params
):
if 'text' in stream_resp["choices"][0]:
chunk = _stream_response_to_generation_chunk(stream_resp)
yield chunk
if run_manager:
run_manager.on_llm_new_token(
chunk.text,
verbose=self.verbose,
logprobs=chunk.generation_info["logprobs"]
if chunk.generation_info
else None,
)

View File

@ -39,4 +39,7 @@ XINFERENCE_SERVER_URL=
XINFERENCE_MODEL_UID=
# OpenLLM Credentials
OPENLLM_SERVER_URL=
OPENLLM_SERVER_URL=
# LocalAI Credentials
LOCALAI_SERVER_URL=

View File

@ -0,0 +1,61 @@
import json
import os
from unittest.mock import patch, MagicMock
from core.model_providers.models.embedding.localai_embedding import LocalAIEmbedding
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.localai_provider import LocalAIProvider
from models.provider import Provider, ProviderType, ProviderModel
def get_mock_provider():
return Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name='localai',
provider_type=ProviderType.CUSTOM.value,
encrypted_config='',
is_valid=True,
)
def get_mock_embedding_model(mocker):
model_name = 'text-embedding-ada-002'
server_url = os.environ['LOCALAI_SERVER_URL']
model_provider = LocalAIProvider(provider=get_mock_provider())
mock_query = MagicMock()
mock_query.filter.return_value.first.return_value = ProviderModel(
provider_name='localai',
model_name=model_name,
model_type=ModelType.EMBEDDINGS.value,
encrypted_config=json.dumps({
'server_url': server_url,
}),
is_valid=True,
)
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
return LocalAIEmbedding(
model_provider=model_provider,
name=model_name
)
def decrypt_side_effect(tenant_id, encrypted_api_key):
return encrypted_api_key
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_embed_documents(mock_decrypt, mocker):
embedding_model = get_mock_embedding_model(mocker)
rst = embedding_model.client.embed_documents(['test', 'test1'])
assert isinstance(rst, list)
assert len(rst) == 2
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_embed_query(mock_decrypt, mocker):
embedding_model = get_mock_embedding_model(mocker)
rst = embedding_model.client.embed_query('test')
assert isinstance(rst, list)

View File

@ -0,0 +1,68 @@
import json
import os
from unittest.mock import patch, MagicMock
from core.model_providers.models.llm.localai_model import LocalAIModel
from core.model_providers.providers.localai_provider import LocalAIProvider
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.entity.model_params import ModelKwargs, ModelType
from models.provider import Provider, ProviderType, ProviderModel
def get_mock_provider(server_url):
return Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name='localai',
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps({}),
is_valid=True,
)
def get_mock_model(model_name, mocker):
model_kwargs = ModelKwargs(
max_tokens=10,
temperature=0
)
server_url = os.environ['LOCALAI_SERVER_URL']
mock_query = MagicMock()
mock_query.filter.return_value.first.return_value = ProviderModel(
provider_name='localai',
model_name=model_name,
model_type=ModelType.TEXT_GENERATION.value,
encrypted_config=json.dumps({'server_url': server_url, 'completion_type': 'completion'}),
is_valid=True,
)
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
openai_provider = LocalAIProvider(provider=get_mock_provider(server_url))
return LocalAIModel(
model_provider=openai_provider,
name=model_name,
model_kwargs=model_kwargs
)
def decrypt_side_effect(tenant_id, encrypted_openai_api_key):
return encrypted_openai_api_key
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_num_tokens(mock_decrypt, mocker):
openai_model = get_mock_model('ggml-gpt4all-j', mocker)
rst = openai_model.get_num_tokens([PromptMessage(content='you are a kindness Assistant.')])
assert rst > 0
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_run(mock_decrypt, mocker):
mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
openai_model = get_mock_model('ggml-gpt4all-j', mocker)
rst = openai_model.run(
[PromptMessage(content='Human: Are you Human? you MUST only answer `y` or `n`? \nAssistant: ')],
stop=['\nHuman:'],
)
assert len(rst.content) > 0

View File

@ -0,0 +1,116 @@
import pytest
from unittest.mock import patch, MagicMock
import json
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.base import CredentialsValidateFailedError
from core.model_providers.providers.localai_provider import LocalAIProvider
from models.provider import ProviderType, Provider, ProviderModel
PROVIDER_NAME = 'localai'
MODEL_PROVIDER_CLASS = LocalAIProvider
VALIDATE_CREDENTIAL = {
'server_url': 'http://127.0.0.1:8080/'
}
def encrypt_side_effect(tenant_id, encrypt_key):
return f'encrypted_{encrypt_key}'
def decrypt_side_effect(tenant_id, encrypted_key):
return encrypted_key.replace('encrypted_', '')
def test_is_credentials_valid_or_raise_valid(mocker):
mocker.patch('langchain.embeddings.localai.LocalAIEmbeddings.embed_query',
return_value="abc")
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
model_name='username/test_model_name',
model_type=ModelType.EMBEDDINGS,
credentials=VALIDATE_CREDENTIAL.copy()
)
def test_is_credentials_valid_or_raise_invalid():
# raise CredentialsValidateFailedError if server_url is not in credentials
with pytest.raises(CredentialsValidateFailedError):
MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
model_name='test_model_name',
model_type=ModelType.EMBEDDINGS,
credentials={}
)
@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
def test_encrypt_model_credentials(mock_encrypt, mocker):
server_url = 'http://127.0.0.1:8080/'
result = MODEL_PROVIDER_CLASS.encrypt_model_credentials(
tenant_id='tenant_id',
model_name='test_model_name',
model_type=ModelType.EMBEDDINGS,
credentials=VALIDATE_CREDENTIAL.copy()
)
mock_encrypt.assert_called_with('tenant_id', server_url)
assert result['server_url'] == f'encrypted_{server_url}'
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_model_credentials_custom(mock_decrypt, mocker):
provider = Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name=PROVIDER_NAME,
provider_type=ProviderType.CUSTOM.value,
encrypted_config=None,
is_valid=True,
)
encrypted_credential = VALIDATE_CREDENTIAL.copy()
encrypted_credential['server_url'] = 'encrypted_' + encrypted_credential['server_url']
mock_query = MagicMock()
mock_query.filter.return_value.first.return_value = ProviderModel(
encrypted_config=json.dumps(encrypted_credential)
)
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
result = model_provider.get_model_credentials(
model_name='test_model_name',
model_type=ModelType.EMBEDDINGS
)
assert result['server_url'] == 'http://127.0.0.1:8080/'
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_model_credentials_obfuscated(mock_decrypt, mocker):
provider = Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name=PROVIDER_NAME,
provider_type=ProviderType.CUSTOM.value,
encrypted_config=None,
is_valid=True,
)
encrypted_credential = VALIDATE_CREDENTIAL.copy()
encrypted_credential['server_url'] = 'encrypted_' + encrypted_credential['server_url']
mock_query = MagicMock()
mock_query.filter.return_value.first.return_value = ProviderModel(
encrypted_config=json.dumps(encrypted_credential)
)
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
result = model_provider.get_model_credentials(
model_name='test_model_name',
model_type=ModelType.EMBEDDINGS,
obfuscated=True
)
middle_token = result['server_url'][6:-2]
assert len(middle_token) == max(len(VALIDATE_CREDENTIAL['server_url']) - 8, 0)
assert all(char == '*' for char in middle_token)

File diff suppressed because one or more lines are too long

After

Width:  |  Height:  |  Size: 76 KiB

File diff suppressed because one or more lines are too long

After

Width:  |  Height:  |  Size: 73 KiB

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,14 @@
// GENERATE BY script
// DON NOT EDIT IT MANUALLY
import * as React from 'react'
import data from './Localai.json'
import IconBase from '@/app/components/base/icons/IconBase'
import type { IconBaseProps, IconData } from '@/app/components/base/icons/IconBase'
const Icon = React.forwardRef<React.MutableRefObject<SVGElement>, Omit<IconBaseProps, 'data'>>((
props,
ref,
) => <IconBase {...props} ref={ref} data={data as IconData} />)
export default Icon

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,14 @@
// GENERATE BY script
// DON NOT EDIT IT MANUALLY
import * as React from 'react'
import data from './LocalaiText.json'
import IconBase from '@/app/components/base/icons/IconBase'
import type { IconBaseProps, IconData } from '@/app/components/base/icons/IconBase'
const Icon = React.forwardRef<React.MutableRefObject<SVGElement>, Omit<IconBaseProps, 'data'>>((
props,
ref,
) => <IconBase {...props} ref={ref} data={data as IconData} />)
export default Icon

View File

@ -14,6 +14,8 @@ export { default as Huggingface } from './Huggingface'
export { default as IflytekSparkTextCn } from './IflytekSparkTextCn'
export { default as IflytekSparkText } from './IflytekSparkText'
export { default as IflytekSpark } from './IflytekSpark'
export { default as LocalaiText } from './LocalaiText'
export { default as Localai } from './Localai'
export { default as Microsoft } from './Microsoft'
export { default as OpenaiBlack } from './OpenaiBlack'
export { default as OpenaiBlue } from './OpenaiBlue'

View File

@ -10,6 +10,7 @@ import minimax from './minimax'
import chatglm from './chatglm'
import xinference from './xinference'
import openllm from './openllm'
import localai from './localai'
export default {
openai,
@ -24,4 +25,5 @@ export default {
chatglm,
xinference,
openllm,
localai,
}

View File

@ -0,0 +1,176 @@
import { ProviderEnum } from '../declarations'
import type { FormValue, ProviderConfig } from '../declarations'
import { Localai, LocalaiText } from '@/app/components/base/icons/src/public/llm'
const config: ProviderConfig = {
selector: {
name: {
'en': 'LocalAI',
'zh-Hans': 'LocalAI',
},
icon: <Localai className='w-full h-full' />,
},
item: {
key: ProviderEnum.localai,
titleIcon: {
'en': <LocalaiText className='h-6' />,
'zh-Hans': <LocalaiText className='h-6' />,
},
disable: {
tip: {
'en': 'Only supports the ',
'zh-Hans': '仅支持',
},
link: {
href: {
'en': 'https://docs.dify.ai/getting-started/install-self-hosted',
'zh-Hans': 'https://docs.dify.ai/v/zh-hans/getting-started/install-self-hosted',
},
label: {
'en': 'community open-source version',
'zh-Hans': '社区开源版本',
},
},
},
},
modal: {
key: ProviderEnum.localai,
title: {
'en': 'LocalAI',
'zh-Hans': 'LocalAI',
},
icon: <Localai className='h-6' />,
link: {
href: 'https://github.com/go-skynet/LocalAI',
label: {
'en': 'How to deploy LocalAI',
'zh-Hans': '如何部署 LocalAI',
},
},
defaultValue: {
model_type: 'text-generation',
completion_type: 'completion',
},
validateKeys: (v?: FormValue) => {
if (v?.model_type === 'text-generation') {
return [
'model_type',
'model_name',
'server_url',
'completion_type',
]
}
if (v?.model_type === 'embeddings') {
return [
'model_type',
'model_name',
'server_url',
]
}
return []
},
filterValue: (v?: FormValue) => {
let filteredKeys: string[] = []
if (v?.model_type === 'text-generation') {
filteredKeys = [
'model_type',
'model_name',
'server_url',
'completion_type',
]
}
if (v?.model_type === 'embeddings') {
filteredKeys = [
'model_type',
'model_name',
'server_url',
]
}
return filteredKeys.reduce((prev: FormValue, next: string) => {
prev[next] = v?.[next] || ''
return prev
}, {})
},
fields: [
{
type: 'radio',
key: 'model_type',
required: true,
label: {
'en': 'Model Type',
'zh-Hans': '模型类型',
},
options: [
{
key: 'text-generation',
label: {
'en': 'Text Generation',
'zh-Hans': '文本生成',
},
},
{
key: 'embeddings',
label: {
'en': 'Embeddings',
'zh-Hans': 'Embeddings',
},
},
],
},
{
type: 'text',
key: 'model_name',
required: true,
label: {
'en': 'Model Name',
'zh-Hans': '模型名称',
},
placeholder: {
'en': 'Enter your Model Name here',
'zh-Hans': '在此输入您的模型名称',
},
},
{
hidden: (value?: FormValue) => value?.model_type === 'embeddings',
type: 'radio',
key: 'completion_type',
required: true,
label: {
'en': 'Completion Type',
'zh-Hans': 'Completion Type',
},
options: [
{
key: 'completion',
label: {
'en': 'Completion',
'zh-Hans': 'Completion',
},
},
{
key: 'chat_completion',
label: {
'en': 'Chat Completion',
'zh-Hans': 'Chat Completion',
},
},
],
},
{
type: 'text',
key: 'server_url',
required: true,
label: {
'en': 'Server url',
'zh-Hans': 'Server url',
},
placeholder: {
'en': 'Enter your Server Url, eg: https://example.com/xxx',
'zh-Hans': '在此输入您的 Server Urlhttps://example.com/xxx',
},
},
],
},
}
export default config

View File

@ -41,6 +41,7 @@ export enum ProviderEnum {
'chatglm' = 'chatglm',
'xinference' = 'xinference',
'openllm' = 'openllm',
'localai' = 'localai',
}
export type ProviderConfigItem = {

View File

@ -99,6 +99,7 @@ const ModelPage = () => {
config.chatglm,
config.xinference,
config.openllm,
config.localai,
]
}

View File

@ -2,7 +2,7 @@ import { ValidatedStatus } from '../key-validator/declarations'
import { ProviderEnum } from './declarations'
import { validateModelProvider } from '@/service/common'
export const ConfigurableProviders = [ProviderEnum.azure_openai, ProviderEnum.replicate, ProviderEnum.huggingface_hub, ProviderEnum.xinference, ProviderEnum.openllm]
export const ConfigurableProviders = [ProviderEnum.azure_openai, ProviderEnum.replicate, ProviderEnum.huggingface_hub, ProviderEnum.xinference, ProviderEnum.openllm, ProviderEnum.localai]
export const validateModelProviderFn = async (providerName: ProviderEnum, v: any) => {
let body, url