Feat/tool secret parameter (#2760)

This commit is contained in:
Yeuoly 2024-03-08 20:31:13 +08:00 committed by GitHub
parent bbc0d330a9
commit ce58f0607b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 490 additions and 117 deletions

View File

@ -27,7 +27,9 @@ from fields.app_fields import (
from libs.login import login_required
from models.model import App, AppModelConfig, Site
from services.app_model_config_service import AppModelConfigService
from core.tools.utils.configuration import ToolParameterConfigurationManager
from core.tools.tool_manager import ToolManager
from core.entities.application_entities import AgentToolEntity
def _get_app(app_id, tenant_id):
app = db.session.query(App).filter(App.id == app_id, App.tenant_id == tenant_id).first()
@ -236,7 +238,39 @@ class AppApi(Resource):
def get(self, app_id):
"""Get app detail"""
app_id = str(app_id)
app = _get_app(app_id, current_user.current_tenant_id)
app: App = _get_app(app_id, current_user.current_tenant_id)
# get original app model config
model_config: AppModelConfig = app.app_model_config
agent_mode = model_config.agent_mode_dict
# decrypt agent tool parameters if it's secret-input
for tool in agent_mode.get('tools') or []:
agent_tool_entity = AgentToolEntity(**tool)
# get tool
tool_runtime = ToolManager.get_agent_tool_runtime(
tenant_id=current_user.current_tenant_id,
agent_tool=agent_tool_entity,
agent_callback=None
)
manager = ToolParameterConfigurationManager(
tenant_id=current_user.current_tenant_id,
tool_runtime=tool_runtime,
provider_name=agent_tool_entity.provider_id,
provider_type=agent_tool_entity.provider_type,
)
# get decrypted parameters
if agent_tool_entity.tool_parameters:
parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
masked_parameter = manager.mask_tool_parameters(parameters or {})
else:
masked_parameter = {}
# override tool parameters
tool['tool_parameters'] = masked_parameter
# override agent mode
model_config.agent_mode = json.dumps(agent_mode)
return app

View File

@ -1,3 +1,4 @@
import json
from flask import request
from flask_login import current_user
@ -7,6 +8,9 @@ from controllers.console import api
from controllers.console.app import _get_app
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.entities.application_entities import AgentToolEntity
from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ToolParameterConfigurationManager
from events.app_event import app_model_config_was_updated
from extensions.ext_database import db
from libs.login import login_required
@ -38,6 +42,82 @@ class ModelConfigResource(Resource):
)
new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration)
# get original app model config
original_app_model_config: AppModelConfig = db.session.query(AppModelConfig).filter(
AppModelConfig.id == app.app_model_config_id
).first()
agent_mode = original_app_model_config.agent_mode_dict
# decrypt agent tool parameters if it's secret-input
parameter_map = {}
masked_parameter_map = {}
tool_map = {}
for tool in agent_mode.get('tools') or []:
agent_tool_entity = AgentToolEntity(**tool)
# get tool
tool_runtime = ToolManager.get_agent_tool_runtime(
tenant_id=current_user.current_tenant_id,
agent_tool=agent_tool_entity,
agent_callback=None
)
manager = ToolParameterConfigurationManager(
tenant_id=current_user.current_tenant_id,
tool_runtime=tool_runtime,
provider_name=agent_tool_entity.provider_id,
provider_type=agent_tool_entity.provider_type,
)
# get decrypted parameters
if agent_tool_entity.tool_parameters:
parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
masked_parameter = manager.mask_tool_parameters(parameters or {})
else:
parameters = {}
masked_parameter = {}
key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}'
masked_parameter_map[key] = masked_parameter
parameter_map[key] = parameters
tool_map[key] = tool_runtime
# encrypt agent tool parameters if it's secret-input
agent_mode = new_app_model_config.agent_mode_dict
for tool in agent_mode.get('tools') or []:
agent_tool_entity = AgentToolEntity(**tool)
# get tool
key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}'
if key in tool_map:
tool_runtime = tool_map[key]
else:
tool_runtime = ToolManager.get_agent_tool_runtime(
tenant_id=current_user.current_tenant_id,
agent_tool=agent_tool_entity,
agent_callback=None
)
manager = ToolParameterConfigurationManager(
tenant_id=current_user.current_tenant_id,
tool_runtime=tool_runtime,
provider_name=agent_tool_entity.provider_id,
provider_type=agent_tool_entity.provider_type,
)
manager.delete_tool_parameters_cache()
# override parameters if it equals to masked parameters
if agent_tool_entity.tool_parameters:
if key not in masked_parameter_map:
continue
if agent_tool_entity.tool_parameters == masked_parameter_map[key]:
agent_tool_entity.tool_parameters = parameter_map[key]
# encrypt parameters
if agent_tool_entity.tool_parameters:
tool['tool_parameters'] = manager.encrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
# update app model config
new_app_model_config.agent_mode = json.dumps(agent_mode)
db.session.add(new_app_model_config)
db.session.flush()

View File

@ -154,9 +154,9 @@ class BaseAssistantApplicationRunner(AppRunner):
"""
convert tool to prompt message tool
"""
tool_entity = ToolManager.get_tool_runtime(
provider_type=tool.provider_type, provider_name=tool.provider_id, tool_name=tool.tool_name,
tenant_id=self.application_generate_entity.tenant_id,
tool_entity = ToolManager.get_agent_tool_runtime(
tenant_id=self.tenant_id,
agent_tool=tool,
agent_callback=self.agent_callback
)
tool_entity.load_variables(self.variables_pool)
@ -171,33 +171,11 @@ class BaseAssistantApplicationRunner(AppRunner):
}
)
runtime_parameters = {}
parameters = tool_entity.parameters or []
user_parameters = tool_entity.get_runtime_parameters() or []
# override parameters
for parameter in user_parameters:
# check if parameter in tool parameters
found = False
for tool_parameter in parameters:
if tool_parameter.name == parameter.name:
found = True
break
if found:
# override parameter
tool_parameter.type = parameter.type
tool_parameter.form = parameter.form
tool_parameter.required = parameter.required
tool_parameter.default = parameter.default
tool_parameter.options = parameter.options
tool_parameter.llm_description = parameter.llm_description
else:
# add new parameter
parameters.append(parameter)
parameters = tool_entity.get_all_runtime_parameters()
for parameter in parameters:
if parameter.form != ToolParameter.ToolParameterForm.LLM:
continue
parameter_type = 'string'
enum = []
if parameter.type == ToolParameter.ToolParameterType.STRING:
@ -213,59 +191,16 @@ class BaseAssistantApplicationRunner(AppRunner):
else:
raise ValueError(f"parameter type {parameter.type} is not supported")
if parameter.form == ToolParameter.ToolParameterForm.FORM:
# get tool parameter from form
tool_parameter_config = tool.tool_parameters.get(parameter.name)
if not tool_parameter_config:
# get default value
tool_parameter_config = parameter.default
if not tool_parameter_config and parameter.required:
raise ValueError(f"tool parameter {parameter.name} not found in tool config")
if parameter.type == ToolParameter.ToolParameterType.SELECT:
# check if tool_parameter_config in options
options = list(map(lambda x: x.value, parameter.options))
if tool_parameter_config not in options:
raise ValueError(f"tool parameter {parameter.name} value {tool_parameter_config} not in options {options}")
# convert tool parameter config to correct type
try:
if parameter.type == ToolParameter.ToolParameterType.NUMBER:
# check if tool parameter is integer
if isinstance(tool_parameter_config, int):
tool_parameter_config = tool_parameter_config
elif isinstance(tool_parameter_config, float):
tool_parameter_config = tool_parameter_config
elif isinstance(tool_parameter_config, str):
if '.' in tool_parameter_config:
tool_parameter_config = float(tool_parameter_config)
else:
tool_parameter_config = int(tool_parameter_config)
elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN:
tool_parameter_config = bool(tool_parameter_config)
elif parameter.type not in [ToolParameter.ToolParameterType.SELECT, ToolParameter.ToolParameterType.STRING]:
tool_parameter_config = str(tool_parameter_config)
elif parameter.type == ToolParameter.ToolParameterType:
tool_parameter_config = str(tool_parameter_config)
except Exception as e:
raise ValueError(f"tool parameter {parameter.name} value {tool_parameter_config} is not correct type")
# save tool parameter to tool entity memory
runtime_parameters[parameter.name] = tool_parameter_config
elif parameter.form == ToolParameter.ToolParameterForm.LLM:
message_tool.parameters['properties'][parameter.name] = {
"type": parameter_type,
"description": parameter.llm_description or '',
}
message_tool.parameters['properties'][parameter.name] = {
"type": parameter_type,
"description": parameter.llm_description or '',
}
if len(enum) > 0:
message_tool.parameters['properties'][parameter.name]['enum'] = enum
if len(enum) > 0:
message_tool.parameters['properties'][parameter.name]['enum'] = enum
if parameter.required:
message_tool.parameters['required'].append(parameter.name)
tool_entity.runtime.runtime_parameters.update(runtime_parameters)
if parameter.required:
message_tool.parameters['required'].append(parameter.name)
return message_tool, tool_entity
@ -305,6 +240,9 @@ class BaseAssistantApplicationRunner(AppRunner):
tool_runtime_parameters = tool.get_runtime_parameters() or []
for parameter in tool_runtime_parameters:
if parameter.form != ToolParameter.ToolParameterForm.LLM:
continue
parameter_type = 'string'
enum = []
if parameter.type == ToolParameter.ToolParameterType.STRING:
@ -320,18 +258,17 @@ class BaseAssistantApplicationRunner(AppRunner):
else:
raise ValueError(f"parameter type {parameter.type} is not supported")
if parameter.form == ToolParameter.ToolParameterForm.LLM:
prompt_tool.parameters['properties'][parameter.name] = {
"type": parameter_type,
"description": parameter.llm_description or '',
}
prompt_tool.parameters['properties'][parameter.name] = {
"type": parameter_type,
"description": parameter.llm_description or '',
}
if len(enum) > 0:
prompt_tool.parameters['properties'][parameter.name]['enum'] = enum
if len(enum) > 0:
prompt_tool.parameters['properties'][parameter.name]['enum'] = enum
if parameter.required:
if parameter.name not in prompt_tool.parameters['required']:
prompt_tool.parameters['required'].append(parameter.name)
if parameter.required:
if parameter.name not in prompt_tool.parameters['required']:
prompt_tool.parameters['required'].append(parameter.name)
return prompt_tool

View File

@ -0,0 +1,54 @@
import json
from enum import Enum
from json import JSONDecodeError
from typing import Optional
from extensions.ext_redis import redis_client
class ToolParameterCacheType(Enum):
PARAMETER = "tool_parameter"
class ToolParameterCache:
def __init__(self,
tenant_id: str,
provider: str,
tool_name: str,
cache_type: ToolParameterCacheType
):
self.cache_key = f"{cache_type.value}_secret:tenant_id:{tenant_id}:provider:{provider}:tool_name:{tool_name}"
def get(self) -> Optional[dict]:
"""
Get cached model provider credentials.
:return:
"""
cached_tool_parameter = redis_client.get(self.cache_key)
if cached_tool_parameter:
try:
cached_tool_parameter = cached_tool_parameter.decode('utf-8')
cached_tool_parameter = json.loads(cached_tool_parameter)
except JSONDecodeError:
return None
return cached_tool_parameter
else:
return None
def set(self, parameters: dict) -> None:
"""
Cache model provider credentials.
:param credentials: provider credentials
:return:
"""
redis_client.setex(self.cache_key, 86400, json.dumps(parameters))
def delete(self) -> None:
"""
Delete cached model provider credentials.
:return:
"""
redis_client.delete(self.cache_key)

View File

@ -119,7 +119,7 @@ parameters: # Parameter list
- The `identity` field is mandatory, it contains the basic information of the tool, including name, author, label, description, etc.
- `parameters` Parameter list
- `name` Parameter name, unique, no duplication with other parameters
- `type` Parameter type, currently supports `string`, `number`, `boolean`, `select` four types, corresponding to string, number, boolean, drop-down box
- `type` Parameter type, currently supports `string`, `number`, `boolean`, `select`, `secret-input` four types, corresponding to string, number, boolean, drop-down box, and encrypted input box, respectively. For sensitive information, we recommend using `secret-input` type
- `required` Required or not
- In `llm` mode, if the parameter is required, the Agent is required to infer this parameter
- In `form` mode, if the parameter is required, the user is required to fill in this parameter on the frontend before the conversation starts

View File

@ -119,7 +119,7 @@ parameters: # 参数列表
- `identity` 字段是必须的,它包含了工具的基本信息,包括名称、作者、标签、描述等
- `parameters` 参数列表
- `name` 参数名称,唯一,不允许和其他参数重名
- `type` 参数类型,目前支持`string`、`number`、`boolean`、`select`种类型,分别对应字符串、数字、布尔值、下拉框
- `type` 参数类型,目前支持`string`、`number`、`boolean`、`select`、`secret-input` 五种类型,分别对应字符串、数字、布尔值、下拉框、加密输入框,对于敏感信息,我们建议使用`secret-input`类型
- `required` 是否必填
- 在`llm`模式下如果参数为必填则会要求Agent必须要推理出这个参数
- 在`form`模式下,如果参数为必填,则会要求用户在对话开始前在前端填写这个参数

View File

@ -100,6 +100,7 @@ class ToolParameter(BaseModel):
NUMBER = "number"
BOOLEAN = "boolean"
SELECT = "select"
SECRET_INPUT = "secret-input"
class ToolParameterForm(Enum):
SCHEMA = "schema" # should be set while adding tool

View File

@ -23,6 +23,8 @@ class AIPPTGenerateTool(BuiltinTool):
_api_base_url = URL('https://co.aippt.cn/api')
_api_token_cache = {}
_api_token_cache_lock = Lock()
_style_cache = {}
_style_cache_lock = Lock()
_task = {}
_task_type_map = {
@ -390,20 +392,31 @@ class AIPPTGenerateTool(BuiltinTool):
).digest()
).decode('utf-8')
def get_styles(self, user_id: str) -> tuple[list[dict], list[dict]]:
@classmethod
def _get_styles(cls, credentials: dict[str, str], user_id: str) -> tuple[list[dict], list[dict]]:
"""
Get styles
:param credentials: the credentials
:return: Tuple[list[dict[id, color]], list[dict[id, style]]
"""
# check cache
with cls._style_cache_lock:
# clear expired styles
now = time()
for key in list(cls._style_cache.keys()):
if cls._style_cache[key]['expire'] < now:
del cls._style_cache[key]
key = f'{credentials["aippt_access_key"]}#@#{user_id}'
if key in cls._style_cache:
return cls._style_cache[key]['colors'], cls._style_cache[key]['styles']
headers = {
'x-channel': '',
'x-api-key': self.runtime.credentials['aippt_access_key'],
'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id=user_id)
'x-api-key': credentials['aippt_access_key'],
'x-token': cls._get_api_token(credentials=credentials, user_id=user_id)
}
response = get(
str(self._api_base_url / 'template_component' / 'suit' / 'select'),
str(cls._api_base_url / 'template_component' / 'suit' / 'select'),
headers=headers
)
@ -425,7 +438,26 @@ class AIPPTGenerateTool(BuiltinTool):
'name': item.get('title'),
} for item in response.get('data', {}).get('suit_style') or []]
with cls._style_cache_lock:
cls._style_cache[key] = {
'colors': colors,
'styles': styles,
'expire': now + 60 * 60
}
return colors, styles
def get_styles(self, user_id: str) -> tuple[list[dict], list[dict]]:
"""
Get styles
:param credentials: the credentials
:return: Tuple[list[dict[id, color]], list[dict[id, style]]
"""
if not self.runtime.credentials.get('aippt_access_key') or not self.runtime.credentials.get('aippt_secret_key'):
return [], []
return self._get_styles(credentials=self.runtime.credentials, user_id=user_id)
def _get_suit(self, style_id: int, colour_id: int) -> int:
"""

View File

@ -14,7 +14,7 @@ description:
llm: A tool for sending messages to a chat group on Wecom(企业微信) .
parameters:
- name: hook_key
type: string
type: secret-input
required: true
label:
en_US: Wecom Group bot webhook key

View File

@ -266,6 +266,40 @@ class Tool(BaseModel, ABC):
"""
return self.parameters
def get_all_runtime_parameters(self) -> list[ToolParameter]:
"""
get all runtime parameters
:return: all runtime parameters
"""
parameters = self.parameters or []
parameters = parameters.copy()
user_parameters = self.get_runtime_parameters() or []
user_parameters = user_parameters.copy()
# override parameters
for parameter in user_parameters:
# check if parameter in tool parameters
found = False
for tool_parameter in parameters:
if tool_parameter.name == parameter.name:
found = True
break
if found:
# override parameter
tool_parameter.type = parameter.type
tool_parameter.form = parameter.form
tool_parameter.required = parameter.required
tool_parameter.default = parameter.default
tool_parameter.options = parameter.options
tool_parameter.llm_description = parameter.llm_description
else:
# add new parameter
parameters.append(parameter)
return parameters
def is_tool_available(self) -> bool:
"""
check if the tool is available

View File

@ -6,11 +6,17 @@ from os import listdir, path
from typing import Any, Union
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
from core.entities.application_entities import AgentToolEntity
from core.model_runtime.entities.message_entities import PromptMessage
from core.provider_manager import ProviderManager
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.constant import DEFAULT_PROVIDERS
from core.tools.entities.tool_entities import ApiProviderAuthType, ToolInvokeMessage, ToolProviderCredentials
from core.tools.entities.tool_entities import (
ApiProviderAuthType,
ToolInvokeMessage,
ToolParameter,
ToolProviderCredentials,
)
from core.tools.entities.user_entities import UserToolProvider
from core.tools.errors import ToolProviderNotFoundError
from core.tools.provider.api_tool_provider import ApiBasedToolProviderController
@ -21,7 +27,12 @@ from core.tools.provider.model_tool_provider import ModelToolProviderController
from core.tools.provider.tool_provider import ToolProviderController
from core.tools.tool.api_tool import ApiTool
from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.utils.configuration import ModelToolConfigurationManager, ToolConfiguration
from core.tools.tool.tool import Tool
from core.tools.utils.configuration import (
ModelToolConfigurationManager,
ToolConfigurationManager,
ToolParameterConfigurationManager,
)
from core.tools.utils.encoder import serialize_base_model_dict
from extensions.ext_database import db
from models.tools import ApiToolProvider, BuiltinToolProvider
@ -172,7 +183,7 @@ class ToolManager:
# decrypt the credentials
credentials = builtin_provider.credentials
controller = ToolManager.get_builtin_provider(provider_name)
tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=controller)
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller)
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
@ -189,7 +200,7 @@ class ToolManager:
api_provider, credentials = ToolManager.get_api_provider_controller(tenant_id, provider_name)
# decrypt the credentials
tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=api_provider)
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=api_provider)
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
return api_provider.get_tool(tool_name).fork_tool_runtime(meta={
@ -214,6 +225,71 @@ class ToolManager:
else:
raise ToolProviderNotFoundError(f'provider type {provider_type} not found')
@staticmethod
def get_agent_tool_runtime(tenant_id: str, agent_tool: AgentToolEntity, agent_callback: DifyAgentCallbackHandler) -> Tool:
"""
get the agent tool runtime
"""
tool_entity = ToolManager.get_tool_runtime(
provider_type=agent_tool.provider_type, provider_name=agent_tool.provider_id, tool_name=agent_tool.tool_name,
tenant_id=tenant_id,
agent_callback=agent_callback
)
runtime_parameters = {}
parameters = tool_entity.get_all_runtime_parameters()
for parameter in parameters:
if parameter.form == ToolParameter.ToolParameterForm.FORM:
# get tool parameter from form
tool_parameter_config = agent_tool.tool_parameters.get(parameter.name)
if not tool_parameter_config:
# get default value
tool_parameter_config = parameter.default
if not tool_parameter_config and parameter.required:
raise ValueError(f"tool parameter {parameter.name} not found in tool config")
if parameter.type == ToolParameter.ToolParameterType.SELECT:
# check if tool_parameter_config in options
options = list(map(lambda x: x.value, parameter.options))
if tool_parameter_config not in options:
raise ValueError(f"tool parameter {parameter.name} value {tool_parameter_config} not in options {options}")
# convert tool parameter config to correct type
try:
if parameter.type == ToolParameter.ToolParameterType.NUMBER:
# check if tool parameter is integer
if isinstance(tool_parameter_config, int):
tool_parameter_config = tool_parameter_config
elif isinstance(tool_parameter_config, float):
tool_parameter_config = tool_parameter_config
elif isinstance(tool_parameter_config, str):
if '.' in tool_parameter_config:
tool_parameter_config = float(tool_parameter_config)
else:
tool_parameter_config = int(tool_parameter_config)
elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN:
tool_parameter_config = bool(tool_parameter_config)
elif parameter.type not in [ToolParameter.ToolParameterType.SELECT, ToolParameter.ToolParameterType.STRING]:
tool_parameter_config = str(tool_parameter_config)
elif parameter.type == ToolParameter.ToolParameterType:
tool_parameter_config = str(tool_parameter_config)
except Exception as e:
raise ValueError(f"tool parameter {parameter.name} value {tool_parameter_config} is not correct type")
# save tool parameter to tool entity memory
runtime_parameters[parameter.name] = tool_parameter_config
# decrypt runtime parameters
encryption_manager = ToolParameterConfigurationManager(
tenant_id=tenant_id,
tool_runtime=tool_entity,
provider_name=agent_tool.provider_id,
provider_type=agent_tool.provider_type,
)
runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
tool_entity.runtime.runtime_parameters.update(runtime_parameters)
return tool_entity
@staticmethod
def get_builtin_provider_icon(provider: str) -> tuple[str, str]:
"""
@ -396,7 +472,7 @@ class ToolManager:
controller = ToolManager.get_builtin_provider(provider_name)
# init tool configuration
tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=controller)
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller)
# decrypt the credentials and mask the credentials
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials=credentials)
masked_credentials = tool_configuration.mask_tool_credentials(credentials=decrypted_credentials)
@ -463,7 +539,7 @@ class ToolManager:
)
# init tool configuration
tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=controller)
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller)
# decrypt the credentials and mask the credentials
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials=credentials)
@ -523,7 +599,7 @@ class ToolManager:
provider, ApiProviderAuthType.API_KEY if credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE
)
# init tool configuration
tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=controller)
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller)
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)

View File

@ -5,16 +5,19 @@ from pydantic import BaseModel
from yaml import FullLoader, load
from core.helper import encrypter
from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType
from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType
from core.tools.entities.tool_entities import (
ModelToolConfiguration,
ModelToolProviderConfiguration,
ToolParameter,
ToolProviderCredentials,
)
from core.tools.provider.tool_provider import ToolProviderController
from core.tools.tool.tool import Tool
class ToolConfiguration(BaseModel):
class ToolConfigurationManager(BaseModel):
tenant_id: str
provider_controller: ToolProviderController
@ -101,6 +104,128 @@ class ToolConfiguration(BaseModel):
)
cache.delete()
class ToolParameterConfigurationManager(BaseModel):
"""
Tool parameter configuration manager
"""
tenant_id: str
tool_runtime: Tool
provider_name: str
provider_type: str
def _deep_copy(self, parameters: dict[str, Any]) -> dict[str, Any]:
"""
deep copy parameters
"""
return {key: value for key, value in parameters.items()}
def _merge_parameters(self) -> list[ToolParameter]:
"""
merge parameters
"""
# get tool parameters
tool_parameters = self.tool_runtime.parameters or []
# get tool runtime parameters
runtime_parameters = self.tool_runtime.get_runtime_parameters() or []
# override parameters
current_parameters = tool_parameters.copy()
for runtime_parameter in runtime_parameters:
found = False
for index, parameter in enumerate(current_parameters):
if parameter.name == runtime_parameter.name and parameter.form == runtime_parameter.form:
current_parameters[index] = runtime_parameter
found = True
break
if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
current_parameters.append(runtime_parameter)
return current_parameters
def mask_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
"""
mask tool parameters
return a deep copy of parameters with masked values
"""
parameters = self._deep_copy(parameters)
# override parameters
current_parameters = self._merge_parameters()
for parameter in current_parameters:
if parameter.form == ToolParameter.ToolParameterForm.FORM and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT:
if parameter.name in parameters:
if len(parameters[parameter.name]) > 6:
parameters[parameter.name] = \
parameters[parameter.name][:2] + \
'*' * (len(parameters[parameter.name]) - 4) +\
parameters[parameter.name][-2:]
else:
parameters[parameter.name] = '*' * len(parameters[parameter.name])
return parameters
def encrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
"""
encrypt tool parameters with tenant id
return a deep copy of parameters with encrypted values
"""
# override parameters
current_parameters = self._merge_parameters()
for parameter in current_parameters:
if parameter.form == ToolParameter.ToolParameterForm.FORM and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT:
if parameter.name in parameters:
encrypted = encrypter.encrypt_token(self.tenant_id, parameters[parameter.name])
parameters[parameter.name] = encrypted
return parameters
def decrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
"""
decrypt tool parameters with tenant id
return a deep copy of parameters with decrypted values
"""
cache = ToolParameterCache(
tenant_id=self.tenant_id,
provider=f'{self.provider_type}.{self.provider_name}',
tool_name=self.tool_runtime.identity.name,
cache_type=ToolParameterCacheType.PARAMETER
)
cached_parameters = cache.get()
if cached_parameters:
return cached_parameters
# override parameters
current_parameters = self._merge_parameters()
has_secret_input = False
for parameter in current_parameters:
if parameter.form == ToolParameter.ToolParameterForm.FORM and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT:
if parameter.name in parameters:
try:
has_secret_input = True
parameters[parameter.name] = encrypter.decrypt_token(self.tenant_id, parameters[parameter.name])
except:
pass
if has_secret_input:
cache.set(parameters)
return parameters
def delete_tool_parameters_cache(self):
cache = ToolParameterCache(
tenant_id=self.tenant_id,
provider=f'{self.provider_type}.{self.provider_name}',
tool_name=self.tool_runtime.identity.name,
cache_type=ToolParameterCacheType.PARAMETER
)
cache.delete()
class ModelToolConfigurationManager:
"""
Model as tool configuration

View File

@ -17,7 +17,7 @@ from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidatio
from core.tools.provider.api_tool_provider import ApiBasedToolProviderController
from core.tools.provider.tool_provider import ToolProviderController
from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ToolConfiguration
from core.tools.utils.configuration import ToolConfigurationManager
from core.tools.utils.encoder import serialize_base_model_array, serialize_base_model_dict
from core.tools.utils.parser import ApiBasedToolSchemaParser
from extensions.ext_database import db
@ -77,7 +77,7 @@ class ToolManageService:
provider_controller: ToolProviderController = ToolManager.get_builtin_provider(provider)
tools = provider_controller.get_tools()
tool_provider_configurations = ToolConfiguration(tenant_id=tenant_id, provider_controller=provider_controller)
tool_provider_configurations = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
# check if user has added the provider
builtin_provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
BuiltinToolProvider.tenant_id == tenant_id,
@ -279,7 +279,7 @@ class ToolManageService:
provider_controller.load_bundled_tools(tool_bundles)
# encrypt credentials
tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=provider_controller)
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
encrypted_credentials = tool_configuration.encrypt_tool_credentials(credentials)
db_provider.credentials_str = json.dumps(encrypted_credentials)
@ -366,7 +366,7 @@ class ToolManageService:
provider_controller = ToolManager.get_builtin_provider(provider_name)
if not provider_controller.need_credentials:
raise ValueError(f'provider {provider_name} does not need credentials')
tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=provider_controller)
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
# get original credentials if exists
if provider is not None:
original_credentials = tool_configuration.decrypt_tool_credentials(provider.credentials)
@ -450,7 +450,7 @@ class ToolManageService:
provider_controller.load_bundled_tools(tool_bundles)
# get original credentials if exists
tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=provider_controller)
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
original_credentials = tool_configuration.decrypt_tool_credentials(provider.credentials)
masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)
@ -490,7 +490,7 @@ class ToolManageService:
# delete cache
provider_controller = ToolManager.get_builtin_provider(provider_name)
tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=provider_controller)
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
tool_configuration.delete_tool_credentials_cache()
return { 'result': 'success' }
@ -632,7 +632,7 @@ class ToolManageService:
# decrypt credentials
if db_provider.id:
tool_configuration = ToolConfiguration(
tool_configuration = ToolConfigurationManager(
tenant_id=tenant_id,
provider_controller=provider_controller
)