mirror of
https://gitee.com/dify_ai/dify.git
synced 2024-12-05 04:38:37 +08:00
1063 lines
40 KiB
Python
1063 lines
40 KiB
Python
import datetime
|
|
import json
|
|
import logging
|
|
from collections import defaultdict
|
|
from collections.abc import Iterator
|
|
from json import JSONDecodeError
|
|
from typing import Optional
|
|
|
|
from pydantic import BaseModel, ConfigDict
|
|
|
|
from constants import HIDDEN_VALUE
|
|
from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
|
|
from core.entities.provider_entities import (
|
|
CustomConfiguration,
|
|
ModelSettings,
|
|
SystemConfiguration,
|
|
SystemConfigurationStatus,
|
|
)
|
|
from core.helper import encrypter
|
|
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
|
|
from core.model_runtime.entities.model_entities import FetchFrom, ModelType
|
|
from core.model_runtime.entities.provider_entities import (
|
|
ConfigurateMethod,
|
|
CredentialFormSchema,
|
|
FormType,
|
|
ProviderEntity,
|
|
)
|
|
from core.model_runtime.model_providers import model_provider_factory
|
|
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
|
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
|
|
from extensions.ext_database import db
|
|
from models.provider import (
|
|
LoadBalancingModelConfig,
|
|
Provider,
|
|
ProviderModel,
|
|
ProviderModelSetting,
|
|
ProviderType,
|
|
TenantPreferredModelProvider,
|
|
)
|
|
|
|
logger = logging.getLogger(name=__name__)

# Process-wide snapshot of each provider's configurate methods, keyed by provider
# name and recorded the first time the provider is seen. ProviderConfiguration may
# later append PREDEFINED_MODEL to the provider's live list, so this keeps the
# methods as originally declared.
original_provider_configurate_methods: "dict[str, list[ConfigurateMethod]]" = dict()
|
|
|
|
|
|
class ProviderConfiguration(BaseModel):
    """
    Model class for provider configuration.

    Groups, for one tenant and one model provider, the system (quota based)
    configuration, the tenant's custom credentials and the per-model settings,
    and offers helpers to read, validate, persist and delete them.
    """

    tenant_id: str
    provider: ProviderEntity
    preferred_provider_type: ProviderType
    using_provider_type: ProviderType
    system_configuration: SystemConfiguration
    custom_configuration: CustomConfiguration
    model_settings: list[ModelSettings]

    # pydantic configs
    model_config = ConfigDict(protected_namespaces=())

    def __init__(self, **data):
        super().__init__(**data)

        # Snapshot the provider's configurate methods the first time this provider is seen,
        # before the append below can modify self.provider.configurate_methods.
        if self.provider.provider not in original_provider_configurate_methods:
            original_provider_configurate_methods[self.provider.provider] = list(self.provider.configurate_methods)

        if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]:
            # A customizable-only provider with restricted models in any quota also gets
            # PREDEFINED_MODEL added to its configurate methods.
            if (
                any(
                    len(quota_configuration.restrict_models) > 0
                    for quota_configuration in self.system_configuration.quota_configurations
                )
                and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods
            ):
                self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL)

    @staticmethod
    def _naive_utc_now() -> datetime.datetime:
        """Return the current UTC time with tzinfo stripped (the form written to ``updated_at``)."""
        return datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)

    def _get_custom_provider_record(self) -> Optional[Provider]:
        """Return this tenant's CUSTOM ``Provider`` record for the provider, or None."""
        return (
            db.session.query(Provider)
            .filter(
                Provider.tenant_id == self.tenant_id,
                Provider.provider_name == self.provider.provider,
                Provider.provider_type == ProviderType.CUSTOM.value,
            )
            .first()
        )

    def _get_custom_model_record(self, model_type: ModelType, model: str) -> Optional[ProviderModel]:
        """Return this tenant's ``ProviderModel`` record for (model_type, model), or None."""
        return (
            db.session.query(ProviderModel)
            .filter(
                ProviderModel.tenant_id == self.tenant_id,
                ProviderModel.provider_name == self.provider.provider,
                ProviderModel.model_name == model,
                ProviderModel.model_type == model_type.to_origin_model_type(),
            )
            .first()
        )

    def _restore_hidden_secrets(self, credentials: dict, secret_variables: list[str], original_credentials: dict) -> None:
        """
        Replace HIDDEN_VALUE placeholders in secret fields with the decrypted stored value.

        Modifies ``credentials`` in place; keys missing from ``original_credentials`` keep the placeholder.
        """
        for key, value in credentials.items():
            # if send [__HIDDEN__] in secret input, it will be same as original value
            if key in secret_variables and value == HIDDEN_VALUE and key in original_credentials:
                credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key])

    def _encrypt_secrets(self, credentials: dict, secret_variables: list[str]) -> dict:
        """Encrypt the secret fields of ``credentials`` in place and return the same dict."""
        for key, value in credentials.items():
            if key in secret_variables:
                credentials[key] = encrypter.encrypt_token(self.tenant_id, value)
        return credentials

    def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]:
        """
        Get current credentials.

        :param model_type: model type
        :param model: model name
        :return: credentials for the provider type currently in use, or None
        :raises ValueError: if the model is disabled in the model settings
        """
        # check if model is disabled by admin
        for model_setting in self.model_settings:
            if model_setting.model_type == model_type and model_setting.model == model and not model_setting.enabled:
                raise ValueError(f"Model {model} is disabled.")

        if self.using_provider_type == ProviderType.SYSTEM:
            restrict_models = []
            for quota_configuration in self.system_configuration.quota_configurations:
                if self.system_configuration.current_quota_type != quota_configuration.quota_type:
                    continue

                restrict_models = quota_configuration.restrict_models

            copy_credentials = self.system_configuration.credentials.copy()
            for restrict_model in restrict_models:
                if (
                    restrict_model.model_type == model_type
                    and restrict_model.model == model
                    and restrict_model.base_model_name
                ):
                    copy_credentials["base_model_name"] = restrict_model.base_model_name

            return copy_credentials

        credentials = None
        if self.custom_configuration.models:
            for model_configuration in self.custom_configuration.models:
                if model_configuration.model_type == model_type and model_configuration.model == model:
                    credentials = model_configuration.credentials
                    break

        # fall back to provider-level credentials when no model-level ones exist
        if not credentials and self.custom_configuration.provider:
            credentials = self.custom_configuration.provider.credentials

        return credentials

    def get_system_configuration_status(self) -> SystemConfigurationStatus:
        """
        Get system configuration status.

        :return: UNSUPPORTED when the system configuration is disabled or has no quota
                 configuration for the current quota type; otherwise ACTIVE or
                 QUOTA_EXCEEDED depending on that quota configuration's validity
        """
        if self.system_configuration.enabled is False:
            return SystemConfigurationStatus.UNSUPPORTED

        current_quota_type = self.system_configuration.current_quota_type
        current_quota_configuration = next(
            (q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type), None
        )

        if current_quota_configuration is None:
            # No quota configuration matches the current quota type; previously this
            # crashed with AttributeError on ``None.is_valid``.
            return SystemConfigurationStatus.UNSUPPORTED

        return (
            SystemConfigurationStatus.ACTIVE
            if current_quota_configuration.is_valid
            else SystemConfigurationStatus.QUOTA_EXCEEDED
        )

    def is_custom_configuration_available(self) -> bool:
        """
        Check custom configuration available.

        :return: True if provider-level custom credentials or any custom model exist
        """
        return self.custom_configuration.provider is not None or len(self.custom_configuration.models) > 0

    def get_custom_credentials(self, obfuscated: bool = False) -> Optional[dict]:
        """
        Get custom credentials.

        :param obfuscated: obfuscated secret data in credentials
        :return: provider-level custom credentials, or None if not configured
        """
        if self.custom_configuration.provider is None:
            return None

        credentials = self.custom_configuration.provider.credentials
        if not obfuscated:
            return credentials

        # Obfuscate credentials
        return self.obfuscated_credentials(
            credentials=credentials,
            credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas
            if self.provider.provider_credential_schema
            else [],
        )

    def custom_credentials_validate(self, credentials: dict) -> tuple[Optional[Provider], dict]:
        """
        Validate custom credentials.

        Secret fields sent as HIDDEN_VALUE are restored from the stored record, the
        credentials are validated by the model runtime, and the secret fields of the
        validated result are encrypted. The caller's ``credentials`` dict is not modified.

        :param credentials: provider credentials
        :return: the existing custom provider record (or None) and the encrypted credentials
        """
        # Work on a copy so restored/encrypted secrets never leak into the caller's dict.
        credentials = dict(credentials)

        provider_record = self._get_custom_provider_record()

        # Get provider credential secret variables
        provider_credential_secret_variables = self.extract_secret_variables(
            self.provider.provider_credential_schema.credential_form_schemas
            if self.provider.provider_credential_schema
            else []
        )

        if provider_record:
            try:
                # fix origin data: a non-JSON config is treated as a bare ``openai_api_key``
                # value (presumably legacy data)
                if provider_record.encrypted_config:
                    if not provider_record.encrypted_config.startswith("{"):
                        original_credentials = {"openai_api_key": provider_record.encrypted_config}
                    else:
                        original_credentials = json.loads(provider_record.encrypted_config)
                else:
                    original_credentials = {}
            except JSONDecodeError:
                original_credentials = {}

            self._restore_hidden_secrets(credentials, provider_credential_secret_variables, original_credentials)

        credentials = model_provider_factory.provider_credentials_validate(
            provider=self.provider.provider, credentials=credentials
        )

        return provider_record, self._encrypt_secrets(credentials, provider_credential_secret_variables)

    def add_or_update_custom_credentials(self, credentials: dict) -> None:
        """
        Add or update custom provider credentials.

        Validates and encrypts the credentials, stores them on the tenant's custom
        provider record, clears the provider credentials cache and switches the
        preferred provider type to CUSTOM.

        :param credentials: provider credentials
        :return:
        """
        # validate custom provider config
        provider_record, credentials = self.custom_credentials_validate(credentials)

        # save provider
        if provider_record:
            provider_record.encrypted_config = json.dumps(credentials)
            provider_record.is_valid = True
            provider_record.updated_at = self._naive_utc_now()
            db.session.commit()
        else:
            provider_record = Provider(
                tenant_id=self.tenant_id,
                provider_name=self.provider.provider,
                provider_type=ProviderType.CUSTOM.value,
                encrypted_config=json.dumps(credentials),
                is_valid=True,
            )
            db.session.add(provider_record)
            db.session.commit()

        ProviderCredentialsCache(
            tenant_id=self.tenant_id, identity_id=provider_record.id, cache_type=ProviderCredentialsCacheType.PROVIDER
        ).delete()

        # Note: saving custom credentials makes CUSTOM the preferred provider type.
        self.switch_preferred_provider_type(ProviderType.CUSTOM)

    def delete_custom_credentials(self) -> None:
        """
        Delete custom provider credentials.

        Switches the preferred provider type back to SYSTEM, deletes the custom
        provider record and clears its credentials cache. No-op if none exists.

        :return:
        """
        provider_record = self._get_custom_provider_record()
        if not provider_record:
            return

        self.switch_preferred_provider_type(ProviderType.SYSTEM)

        # Read the id before deleting so the cache key does not depend on the
        # state of a deleted ORM instance.
        provider_record_id = provider_record.id
        db.session.delete(provider_record)
        db.session.commit()

        ProviderCredentialsCache(
            tenant_id=self.tenant_id,
            identity_id=provider_record_id,
            cache_type=ProviderCredentialsCacheType.PROVIDER,
        ).delete()

    def get_custom_model_credentials(
        self, model_type: ModelType, model: str, obfuscated: bool = False
    ) -> Optional[dict]:
        """
        Get custom model credentials.

        :param model_type: model type
        :param model: model name
        :param obfuscated: obfuscated secret data in credentials
        :return: the model's custom credentials, or None if the model is not configured
        """
        if not self.custom_configuration.models:
            return None

        for model_configuration in self.custom_configuration.models:
            if model_configuration.model_type == model_type and model_configuration.model == model:
                credentials = model_configuration.credentials
                if not obfuscated:
                    return credentials

                # Obfuscate credentials
                return self.obfuscated_credentials(
                    credentials=credentials,
                    credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas
                    if self.provider.model_credential_schema
                    else [],
                )

        return None

    def custom_model_credentials_validate(
        self, model_type: ModelType, model: str, credentials: dict
    ) -> tuple[Optional[ProviderModel], dict]:
        """
        Validate custom model credentials.

        Secret fields sent as HIDDEN_VALUE are restored from the stored record, the
        credentials are validated by the model runtime, and the secret fields of the
        validated result are encrypted. The caller's ``credentials`` dict is not modified.

        :param model_type: model type
        :param model: model name
        :param credentials: model credentials
        :return: the existing provider model record (or None) and the encrypted credentials
        """
        # Work on a copy so restored/encrypted secrets never leak into the caller's dict.
        credentials = dict(credentials)

        provider_model_record = self._get_custom_model_record(model_type, model)

        # Get provider credential secret variables
        provider_credential_secret_variables = self.extract_secret_variables(
            self.provider.model_credential_schema.credential_form_schemas
            if self.provider.model_credential_schema
            else []
        )

        if provider_model_record:
            try:
                original_credentials = (
                    json.loads(provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {}
                )
            except JSONDecodeError:
                original_credentials = {}

            self._restore_hidden_secrets(credentials, provider_credential_secret_variables, original_credentials)

        credentials = model_provider_factory.model_credentials_validate(
            provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials
        )

        return provider_model_record, self._encrypt_secrets(credentials, provider_credential_secret_variables)

    def add_or_update_custom_model_credentials(self, model_type: ModelType, model: str, credentials: dict) -> None:
        """
        Add or update custom model credentials.

        :param model_type: model type
        :param model: model name
        :param credentials: model credentials
        :return:
        """
        # validate custom model config
        provider_model_record, credentials = self.custom_model_credentials_validate(model_type, model, credentials)

        # save provider model
        # Note: Do not switch the preferred provider, which allows users to use quotas first
        if provider_model_record:
            provider_model_record.encrypted_config = json.dumps(credentials)
            provider_model_record.is_valid = True
            provider_model_record.updated_at = self._naive_utc_now()
            db.session.commit()
        else:
            provider_model_record = ProviderModel(
                tenant_id=self.tenant_id,
                provider_name=self.provider.provider,
                model_name=model,
                model_type=model_type.to_origin_model_type(),
                encrypted_config=json.dumps(credentials),
                is_valid=True,
            )
            db.session.add(provider_model_record)
            db.session.commit()

        ProviderCredentialsCache(
            tenant_id=self.tenant_id,
            identity_id=provider_model_record.id,
            cache_type=ProviderCredentialsCacheType.MODEL,
        ).delete()

    def delete_custom_model_credentials(self, model_type: ModelType, model: str) -> None:
        """
        Delete custom model credentials.

        Deletes the provider model record and clears its credentials cache. No-op if none exists.

        :param model_type: model type
        :param model: model name
        :return:
        """
        provider_model_record = self._get_custom_model_record(model_type, model)
        if not provider_model_record:
            return

        # Read the id before deleting so the cache key does not depend on the
        # state of a deleted ORM instance.
        provider_model_record_id = provider_model_record.id
        db.session.delete(provider_model_record)
        db.session.commit()

        ProviderCredentialsCache(
            tenant_id=self.tenant_id,
            identity_id=provider_model_record_id,
            cache_type=ProviderCredentialsCacheType.MODEL,
        ).delete()

    def _upsert_model_setting(self, model_type: ModelType, model: str, **values) -> ProviderModelSetting:
        """
        Set ``values`` on the model's ProviderModelSetting, creating the record if missing, and commit.

        :param model_type: model type
        :param model: model name
        :param values: column values to set (e.g. ``enabled=True``)
        :return: the updated or newly created setting
        """
        model_setting = self.get_provider_model_setting(model_type, model)
        if model_setting:
            for column, value in values.items():
                setattr(model_setting, column, value)
            model_setting.updated_at = self._naive_utc_now()
        else:
            model_setting = ProviderModelSetting(
                tenant_id=self.tenant_id,
                provider_name=self.provider.provider,
                model_type=model_type.to_origin_model_type(),
                model_name=model,
                **values,
            )
            db.session.add(model_setting)

        db.session.commit()
        return model_setting

    def enable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting:
        """
        Enable model.

        :param model_type: model type
        :param model: model name
        :return: the model setting
        """
        return self._upsert_model_setting(model_type, model, enabled=True)

    def disable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting:
        """
        Disable model.

        :param model_type: model type
        :param model: model name
        :return: the model setting
        """
        return self._upsert_model_setting(model_type, model, enabled=False)

    def get_provider_model_setting(self, model_type: ModelType, model: str) -> Optional[ProviderModelSetting]:
        """
        Get provider model setting.

        :param model_type: model type
        :param model: model name
        :return: the setting record, or None
        """
        return (
            db.session.query(ProviderModelSetting)
            .filter(
                ProviderModelSetting.tenant_id == self.tenant_id,
                ProviderModelSetting.provider_name == self.provider.provider,
                ProviderModelSetting.model_type == model_type.to_origin_model_type(),
                ProviderModelSetting.model_name == model,
            )
            .first()
        )

    def enable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting:
        """
        Enable model load balancing.

        :param model_type: model type
        :param model: model name
        :return: the model setting
        :raises ValueError: if the model has at most one load balancing config
        """
        load_balancing_config_count = (
            db.session.query(LoadBalancingModelConfig)
            .filter(
                LoadBalancingModelConfig.tenant_id == self.tenant_id,
                LoadBalancingModelConfig.provider_name == self.provider.provider,
                LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
                LoadBalancingModelConfig.model_name == model,
            )
            .count()
        )

        if load_balancing_config_count <= 1:
            raise ValueError("Model load balancing configuration must be more than 1.")

        return self._upsert_model_setting(model_type, model, load_balancing_enabled=True)

    def disable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting:
        """
        Disable model load balancing.

        :param model_type: model type
        :param model: model name
        :return: the model setting
        """
        return self._upsert_model_setting(model_type, model, load_balancing_enabled=False)

    def get_provider_instance(self) -> ModelProvider:
        """
        Get provider instance.

        :return: the model runtime provider instance for this provider
        """
        return model_provider_factory.get_provider_instance(self.provider.provider)

    def get_model_type_instance(self, model_type: ModelType) -> AIModel:
        """
        Get current model type instance.

        :param model_type: model type
        :return: the provider's model instance for ``model_type``
        """
        return self.get_provider_instance().get_model_instance(model_type)

    def switch_preferred_provider_type(self, provider_type: ProviderType) -> None:
        """
        Switch preferred provider type.

        No-op when ``provider_type`` is already preferred, or when switching to
        SYSTEM while the system configuration is disabled.

        :param provider_type: provider type to prefer
        :return:
        """
        if provider_type == self.preferred_provider_type:
            return

        if provider_type == ProviderType.SYSTEM and not self.system_configuration.enabled:
            return

        # get preferred provider
        preferred_model_provider = (
            db.session.query(TenantPreferredModelProvider)
            .filter(
                TenantPreferredModelProvider.tenant_id == self.tenant_id,
                TenantPreferredModelProvider.provider_name == self.provider.provider,
            )
            .first()
        )

        if preferred_model_provider:
            preferred_model_provider.preferred_provider_type = provider_type.value
        else:
            preferred_model_provider = TenantPreferredModelProvider(
                tenant_id=self.tenant_id,
                provider_name=self.provider.provider,
                preferred_provider_type=provider_type.value,
            )
            db.session.add(preferred_model_provider)

        db.session.commit()

    def extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]:
        """
        Extract secret input form variables.

        :param credential_form_schemas: credential form schemas
        :return: variable names of SECRET_INPUT fields, in schema order
        """
        return [
            credential_form_schema.variable
            for credential_form_schema in credential_form_schemas
            if credential_form_schema.type == FormType.SECRET_INPUT
        ]

    def obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]) -> dict:
        """
        Obfuscated credentials.

        :param credentials: credentials
        :param credential_form_schemas: credential form schemas
        :return: a new dict with secret fields obfuscated; ``credentials`` is not modified
        """
        credential_secret_variables = self.extract_secret_variables(credential_form_schemas)

        return {
            key: encrypter.obfuscated_token(value) if key in credential_secret_variables else value
            for key, value in credentials.items()
        }

    def get_provider_model(
        self, model_type: ModelType, model: str, only_active: bool = False
    ) -> Optional[ModelWithProviderEntity]:
        """
        Get provider model.

        :param model_type: model type
        :param model: model name
        :param only_active: return active model only
        :return: the first matching model, or None
        """
        return next(
            (
                provider_model
                for provider_model in self.get_provider_models(model_type, only_active)
                if provider_model.model == model
            ),
            None,
        )

    def get_provider_models(
        self, model_type: Optional[ModelType] = None, only_active: bool = False
    ) -> list[ModelWithProviderEntity]:
        """
        Get provider models.

        :param model_type: model type; all model types supported by the provider when None
        :param only_active: only active models
        :return: models sorted by model type value
        """
        provider_instance = self.get_provider_instance()

        model_types = [model_type] if model_type else provider_instance.get_provider_schema().supported_model_types

        # Group model settings by model type and model
        model_setting_map = defaultdict(dict)
        for model_setting in self.model_settings:
            model_setting_map[model_setting.model_type][model_setting.model] = model_setting

        if self.using_provider_type == ProviderType.SYSTEM:
            provider_models = self._get_system_provider_models(
                model_types=model_types, provider_instance=provider_instance, model_setting_map=model_setting_map
            )
        else:
            provider_models = self._get_custom_provider_models(
                model_types=model_types, provider_instance=provider_instance, model_setting_map=model_setting_map
            )

        if only_active:
            provider_models = [m for m in provider_models if m.status == ModelStatus.ACTIVE]

        # resort provider_models
        return sorted(provider_models, key=lambda x: x.model_type.value)

    def _get_system_provider_models(
        self,
        model_types: list[ModelType],
        provider_instance: ModelProvider,
        model_setting_map: dict[ModelType, dict[str, ModelSettings]],
    ) -> list[ModelWithProviderEntity]:
        """
        Get system provider models.

        :param model_types: model types
        :param provider_instance: provider instance
        :param model_setting_map: model setting map
        :return: predefined models, plus schemas built from the current quota's restricted
                 models for customizable-only providers, with statuses adjusted for the quota
        """
        provider_models = []
        for model_type in model_types:
            for m in provider_instance.models(model_type):
                status = ModelStatus.ACTIVE
                if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]:
                    model_setting = model_setting_map[m.model_type][m.model]
                    if model_setting.enabled is False:
                        status = ModelStatus.DISABLED

                provider_models.append(
                    ModelWithProviderEntity(
                        model=m.model,
                        label=m.label,
                        model_type=m.model_type,
                        features=m.features,
                        fetch_from=m.fetch_from,
                        model_properties=m.model_properties,
                        deprecated=m.deprecated,
                        provider=SimpleModelProviderEntity(self.provider),
                        status=status,
                    )
                )

        if self.provider.provider not in original_provider_configurate_methods:
            original_provider_configurate_methods[self.provider.provider] = list(
                provider_instance.get_provider_schema().configurate_methods
            )

        should_use_custom_model = original_provider_configurate_methods[self.provider.provider] == [
            ConfigurateMethod.CUSTOMIZABLE_MODEL
        ]

        for quota_configuration in self.system_configuration.quota_configurations:
            if self.system_configuration.current_quota_type != quota_configuration.quota_type:
                continue

            restrict_models = quota_configuration.restrict_models
            if len(restrict_models) == 0:
                break

            if should_use_custom_model:
                # only customizable model: build a schema for each restricted model
                # from the system credentials
                for restrict_model in restrict_models:
                    copy_credentials = self.system_configuration.credentials.copy()
                    if restrict_model.base_model_name:
                        copy_credentials["base_model_name"] = restrict_model.base_model_name

                    try:
                        custom_model_schema = provider_instance.get_model_instance(
                            restrict_model.model_type
                        ).get_customizable_model_schema_from_credentials(restrict_model.model, copy_credentials)
                    except Exception as ex:
                        # best effort: skip models whose schema cannot be built
                        logger.warning("get custom model schema failed, %s", ex)
                        continue

                    if not custom_model_schema:
                        continue

                    if custom_model_schema.model_type not in model_types:
                        continue

                    status = ModelStatus.ACTIVE
                    if (
                        custom_model_schema.model_type in model_setting_map
                        and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]
                    ):
                        model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model]
                        if model_setting.enabled is False:
                            status = ModelStatus.DISABLED

                    provider_models.append(
                        ModelWithProviderEntity(
                            model=custom_model_schema.model,
                            label=custom_model_schema.label,
                            model_type=custom_model_schema.model_type,
                            features=custom_model_schema.features,
                            fetch_from=FetchFrom.PREDEFINED_MODEL,
                            model_properties=custom_model_schema.model_properties,
                            deprecated=custom_model_schema.deprecated,
                            provider=SimpleModelProviderEntity(self.provider),
                            status=status,
                        )
                    )

            # if llm name not in restricted llm list, remove it
            restrict_model_names = {rm.model for rm in restrict_models}
            for m in provider_models:
                if m.model_type == ModelType.LLM and m.model not in restrict_model_names:
                    m.status = ModelStatus.NO_PERMISSION
                elif not quota_configuration.is_valid:
                    m.status = ModelStatus.QUOTA_EXCEEDED

        return provider_models

    def _get_custom_provider_models(
        self,
        model_types: list[ModelType],
        provider_instance: ModelProvider,
        model_setting_map: dict[ModelType, dict[str, ModelSettings]],
    ) -> list[ModelWithProviderEntity]:
        """
        Get custom provider models.

        :param model_types: model types
        :param provider_instance: provider instance
        :param model_setting_map: model setting map
        :return: predefined models (NO_CONFIGURE without provider credentials) followed by
                 the tenant's custom models
        """
        provider_models = []

        credentials = None
        if self.custom_configuration.provider:
            credentials = self.custom_configuration.provider.credentials

        for model_type in model_types:
            if model_type not in self.provider.supported_model_types:
                continue

            for m in provider_instance.models(model_type):
                status = ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE
                load_balancing_enabled = False
                if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]:
                    model_setting = model_setting_map[m.model_type][m.model]
                    if model_setting.enabled is False:
                        status = ModelStatus.DISABLED

                    if len(model_setting.load_balancing_configs) > 1:
                        load_balancing_enabled = True

                provider_models.append(
                    ModelWithProviderEntity(
                        model=m.model,
                        label=m.label,
                        model_type=m.model_type,
                        features=m.features,
                        fetch_from=m.fetch_from,
                        model_properties=m.model_properties,
                        deprecated=m.deprecated,
                        provider=SimpleModelProviderEntity(self.provider),
                        status=status,
                        load_balancing_enabled=load_balancing_enabled,
                    )
                )

        # custom models
        for model_configuration in self.custom_configuration.models:
            if model_configuration.model_type not in model_types:
                continue

            try:
                custom_model_schema = provider_instance.get_model_instance(
                    model_configuration.model_type
                ).get_customizable_model_schema_from_credentials(
                    model_configuration.model, model_configuration.credentials
                )
            except Exception as ex:
                # best effort: skip models whose schema cannot be built
                logger.warning("get custom model schema failed, %s", ex)
                continue

            if not custom_model_schema:
                continue

            status = ModelStatus.ACTIVE
            load_balancing_enabled = False
            if (
                custom_model_schema.model_type in model_setting_map
                and custom_model_schema.model in model_setting_map[custom_model_schema.model_type]
            ):
                model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model]
                if model_setting.enabled is False:
                    status = ModelStatus.DISABLED

                if len(model_setting.load_balancing_configs) > 1:
                    load_balancing_enabled = True

            provider_models.append(
                ModelWithProviderEntity(
                    model=custom_model_schema.model,
                    label=custom_model_schema.label,
                    model_type=custom_model_schema.model_type,
                    features=custom_model_schema.features,
                    fetch_from=custom_model_schema.fetch_from,
                    model_properties=custom_model_schema.model_properties,
                    deprecated=custom_model_schema.deprecated,
                    provider=SimpleModelProviderEntity(self.provider),
                    status=status,
                    load_balancing_enabled=load_balancing_enabled,
                )
            )

        return provider_models
|
|
|
|
|
|
class ProviderConfigurations(BaseModel):
    """
    Dict-like collection of ProviderConfiguration objects for one tenant,
    keyed by provider name.
    """

    tenant_id: str
    configurations: dict[str, ProviderConfiguration] = {}

    def __init__(self, tenant_id: str):
        super().__init__(tenant_id=tenant_id)

    def get_models(
        self, provider: Optional[str] = None, model_type: Optional[ModelType] = None, only_active: bool = False
    ) -> list[ModelWithProviderEntity]:
        """
        Get available models across the stored provider configurations.

        Resolution of the effective mode per provider:

        - Preferred type `system`: use the current **system mode** if the provider
          supports it; when no system mode is available (no quota), fall back to the
          **custom credential mode**; when no custom model is configured either, the
          provider is treated as no_configure.
          Order: system > custom > no_configure.
        - Preferred type `custom`: use custom mode when custom credentials are
          configured; otherwise the current **system mode** if supported; when no
          system mode is available (no quota), treat it as no_configure.
          Order: custom > system > no_configure.

        Model sources per effective mode:

        - `system`: models are listed with the system credentials
          (paid quotas > provider free quotas > system free quotas); includes
          pre-defined models, with restricted ones (e.g. GPT-4) marked `no_permission`.
        - `custom`: models are listed with the workspace custom credentials; includes
          pre-defined models plus manually appended custom models.
        - `no_configure`: only the pre-defined models from `model runtime` are returned
          (status `no_configure` if the preferred type is `custom`, otherwise
          `quota_exceeded`).

        Only models whose status is `active` are usable.

        :param provider: provider name; all providers when empty
        :param model_type: model type
        :param only_active: only active models
        :return: models of every matching provider, in configuration order
        """
        return [
            provider_model
            for configuration in self.values()
            if not provider or configuration.provider.provider == provider
            for provider_model in configuration.get_provider_models(model_type, only_active)
        ]

    def to_list(self) -> list[ProviderConfiguration]:
        """
        Return the stored configurations as a list.

        :return: configurations in insertion order
        """
        return [*self.values()]

    def __getitem__(self, key):
        return self.configurations[key]

    def __setitem__(self, key, value):
        self.configurations[key] = value

    def __iter__(self):
        # like a dict, iteration yields the provider names (keys)
        return iter(self.configurations)

    def values(self) -> Iterator[ProviderConfiguration]:
        return self.configurations.values()

    def get(self, key, default=None):
        return self.configurations.get(key, default)
|
|
|
|
|
|
class ProviderModelBundle(BaseModel):
    """
    Provider model bundle: groups a provider configuration with a provider
    runtime instance and a model-type runtime instance.
    """

    # tenant-scoped provider configuration
    configuration: ProviderConfiguration
    # model runtime provider instance
    provider_instance: ModelProvider
    # model runtime instance for a single model type
    model_type_instance: AIModel

    # pydantic configs: allow the runtime classes as field types (presumably not
    # pydantic models), and disable the protected "model_" namespace so the
    # `model_type_instance` field name is accepted without a warning.
    model_config = ConfigDict(protected_namespaces=(), arbitrary_types_allowed=True)
|