mirror of
https://gitee.com/dify_ai/dify.git
synced 2024-11-29 17:58:19 +08:00
Feat/tool secret parameter (#2760)
This commit is contained in:
parent
bbc0d330a9
commit
ce58f0607b
@ -27,7 +27,9 @@ from fields.app_fields import (
|
||||
from libs.login import login_required
|
||||
from models.model import App, AppModelConfig, Site
|
||||
from services.app_model_config_service import AppModelConfigService
|
||||
|
||||
from core.tools.utils.configuration import ToolParameterConfigurationManager
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.entities.application_entities import AgentToolEntity
|
||||
|
||||
def _get_app(app_id, tenant_id):
|
||||
app = db.session.query(App).filter(App.id == app_id, App.tenant_id == tenant_id).first()
|
||||
@ -236,7 +238,39 @@ class AppApi(Resource):
|
||||
def get(self, app_id):
|
||||
"""Get app detail"""
|
||||
app_id = str(app_id)
|
||||
app = _get_app(app_id, current_user.current_tenant_id)
|
||||
app: App = _get_app(app_id, current_user.current_tenant_id)
|
||||
|
||||
# get original app model config
|
||||
model_config: AppModelConfig = app.app_model_config
|
||||
agent_mode = model_config.agent_mode_dict
|
||||
# decrypt agent tool parameters if it's secret-input
|
||||
for tool in agent_mode.get('tools') or []:
|
||||
agent_tool_entity = AgentToolEntity(**tool)
|
||||
# get tool
|
||||
tool_runtime = ToolManager.get_agent_tool_runtime(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
agent_tool=agent_tool_entity,
|
||||
agent_callback=None
|
||||
)
|
||||
manager = ToolParameterConfigurationManager(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tool_runtime=tool_runtime,
|
||||
provider_name=agent_tool_entity.provider_id,
|
||||
provider_type=agent_tool_entity.provider_type,
|
||||
)
|
||||
|
||||
# get decrypted parameters
|
||||
if agent_tool_entity.tool_parameters:
|
||||
parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
|
||||
masked_parameter = manager.mask_tool_parameters(parameters or {})
|
||||
else:
|
||||
masked_parameter = {}
|
||||
|
||||
# override tool parameters
|
||||
tool['tool_parameters'] = masked_parameter
|
||||
|
||||
# override agent mode
|
||||
model_config.agent_mode = json.dumps(agent_mode)
|
||||
|
||||
return app
|
||||
|
||||
|
@ -1,3 +1,4 @@
|
||||
import json
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
@ -7,6 +8,9 @@ from controllers.console import api
|
||||
from controllers.console.app import _get_app
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.entities.application_entities import AgentToolEntity
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.configuration import ToolParameterConfigurationManager
|
||||
from events.app_event import app_model_config_was_updated
|
||||
from extensions.ext_database import db
|
||||
from libs.login import login_required
|
||||
@ -38,6 +42,82 @@ class ModelConfigResource(Resource):
|
||||
)
|
||||
new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration)
|
||||
|
||||
# get original app model config
|
||||
original_app_model_config: AppModelConfig = db.session.query(AppModelConfig).filter(
|
||||
AppModelConfig.id == app.app_model_config_id
|
||||
).first()
|
||||
agent_mode = original_app_model_config.agent_mode_dict
|
||||
# decrypt agent tool parameters if it's secret-input
|
||||
parameter_map = {}
|
||||
masked_parameter_map = {}
|
||||
tool_map = {}
|
||||
for tool in agent_mode.get('tools') or []:
|
||||
agent_tool_entity = AgentToolEntity(**tool)
|
||||
# get tool
|
||||
tool_runtime = ToolManager.get_agent_tool_runtime(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
agent_tool=agent_tool_entity,
|
||||
agent_callback=None
|
||||
)
|
||||
manager = ToolParameterConfigurationManager(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tool_runtime=tool_runtime,
|
||||
provider_name=agent_tool_entity.provider_id,
|
||||
provider_type=agent_tool_entity.provider_type,
|
||||
)
|
||||
|
||||
# get decrypted parameters
|
||||
if agent_tool_entity.tool_parameters:
|
||||
parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
|
||||
masked_parameter = manager.mask_tool_parameters(parameters or {})
|
||||
else:
|
||||
parameters = {}
|
||||
masked_parameter = {}
|
||||
|
||||
key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}'
|
||||
masked_parameter_map[key] = masked_parameter
|
||||
parameter_map[key] = parameters
|
||||
tool_map[key] = tool_runtime
|
||||
|
||||
# encrypt agent tool parameters if it's secret-input
|
||||
agent_mode = new_app_model_config.agent_mode_dict
|
||||
for tool in agent_mode.get('tools') or []:
|
||||
agent_tool_entity = AgentToolEntity(**tool)
|
||||
|
||||
# get tool
|
||||
key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}'
|
||||
if key in tool_map:
|
||||
tool_runtime = tool_map[key]
|
||||
else:
|
||||
tool_runtime = ToolManager.get_agent_tool_runtime(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
agent_tool=agent_tool_entity,
|
||||
agent_callback=None
|
||||
)
|
||||
|
||||
manager = ToolParameterConfigurationManager(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
tool_runtime=tool_runtime,
|
||||
provider_name=agent_tool_entity.provider_id,
|
||||
provider_type=agent_tool_entity.provider_type,
|
||||
)
|
||||
manager.delete_tool_parameters_cache()
|
||||
|
||||
# override parameters if it equals to masked parameters
|
||||
if agent_tool_entity.tool_parameters:
|
||||
if key not in masked_parameter_map:
|
||||
continue
|
||||
|
||||
if agent_tool_entity.tool_parameters == masked_parameter_map[key]:
|
||||
agent_tool_entity.tool_parameters = parameter_map[key]
|
||||
|
||||
# encrypt parameters
|
||||
if agent_tool_entity.tool_parameters:
|
||||
tool['tool_parameters'] = manager.encrypt_tool_parameters(agent_tool_entity.tool_parameters or {})
|
||||
|
||||
# update app model config
|
||||
new_app_model_config.agent_mode = json.dumps(agent_mode)
|
||||
|
||||
db.session.add(new_app_model_config)
|
||||
db.session.flush()
|
||||
|
||||
|
@ -154,9 +154,9 @@ class BaseAssistantApplicationRunner(AppRunner):
|
||||
"""
|
||||
convert tool to prompt message tool
|
||||
"""
|
||||
tool_entity = ToolManager.get_tool_runtime(
|
||||
provider_type=tool.provider_type, provider_name=tool.provider_id, tool_name=tool.tool_name,
|
||||
tenant_id=self.application_generate_entity.tenant_id,
|
||||
tool_entity = ToolManager.get_agent_tool_runtime(
|
||||
tenant_id=self.tenant_id,
|
||||
agent_tool=tool,
|
||||
agent_callback=self.agent_callback
|
||||
)
|
||||
tool_entity.load_variables(self.variables_pool)
|
||||
@ -171,33 +171,11 @@ class BaseAssistantApplicationRunner(AppRunner):
|
||||
}
|
||||
)
|
||||
|
||||
runtime_parameters = {}
|
||||
|
||||
parameters = tool_entity.parameters or []
|
||||
user_parameters = tool_entity.get_runtime_parameters() or []
|
||||
|
||||
# override parameters
|
||||
for parameter in user_parameters:
|
||||
# check if parameter in tool parameters
|
||||
found = False
|
||||
for tool_parameter in parameters:
|
||||
if tool_parameter.name == parameter.name:
|
||||
found = True
|
||||
break
|
||||
|
||||
if found:
|
||||
# override parameter
|
||||
tool_parameter.type = parameter.type
|
||||
tool_parameter.form = parameter.form
|
||||
tool_parameter.required = parameter.required
|
||||
tool_parameter.default = parameter.default
|
||||
tool_parameter.options = parameter.options
|
||||
tool_parameter.llm_description = parameter.llm_description
|
||||
else:
|
||||
# add new parameter
|
||||
parameters.append(parameter)
|
||||
|
||||
parameters = tool_entity.get_all_runtime_parameters()
|
||||
for parameter in parameters:
|
||||
if parameter.form != ToolParameter.ToolParameterForm.LLM:
|
||||
continue
|
||||
|
||||
parameter_type = 'string'
|
||||
enum = []
|
||||
if parameter.type == ToolParameter.ToolParameterType.STRING:
|
||||
@ -213,59 +191,16 @@ class BaseAssistantApplicationRunner(AppRunner):
|
||||
else:
|
||||
raise ValueError(f"parameter type {parameter.type} is not supported")
|
||||
|
||||
if parameter.form == ToolParameter.ToolParameterForm.FORM:
|
||||
# get tool parameter from form
|
||||
tool_parameter_config = tool.tool_parameters.get(parameter.name)
|
||||
if not tool_parameter_config:
|
||||
# get default value
|
||||
tool_parameter_config = parameter.default
|
||||
if not tool_parameter_config and parameter.required:
|
||||
raise ValueError(f"tool parameter {parameter.name} not found in tool config")
|
||||
|
||||
if parameter.type == ToolParameter.ToolParameterType.SELECT:
|
||||
# check if tool_parameter_config in options
|
||||
options = list(map(lambda x: x.value, parameter.options))
|
||||
if tool_parameter_config not in options:
|
||||
raise ValueError(f"tool parameter {parameter.name} value {tool_parameter_config} not in options {options}")
|
||||
|
||||
# convert tool parameter config to correct type
|
||||
try:
|
||||
if parameter.type == ToolParameter.ToolParameterType.NUMBER:
|
||||
# check if tool parameter is integer
|
||||
if isinstance(tool_parameter_config, int):
|
||||
tool_parameter_config = tool_parameter_config
|
||||
elif isinstance(tool_parameter_config, float):
|
||||
tool_parameter_config = tool_parameter_config
|
||||
elif isinstance(tool_parameter_config, str):
|
||||
if '.' in tool_parameter_config:
|
||||
tool_parameter_config = float(tool_parameter_config)
|
||||
else:
|
||||
tool_parameter_config = int(tool_parameter_config)
|
||||
elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN:
|
||||
tool_parameter_config = bool(tool_parameter_config)
|
||||
elif parameter.type not in [ToolParameter.ToolParameterType.SELECT, ToolParameter.ToolParameterType.STRING]:
|
||||
tool_parameter_config = str(tool_parameter_config)
|
||||
elif parameter.type == ToolParameter.ToolParameterType:
|
||||
tool_parameter_config = str(tool_parameter_config)
|
||||
except Exception as e:
|
||||
raise ValueError(f"tool parameter {parameter.name} value {tool_parameter_config} is not correct type")
|
||||
|
||||
# save tool parameter to tool entity memory
|
||||
runtime_parameters[parameter.name] = tool_parameter_config
|
||||
|
||||
elif parameter.form == ToolParameter.ToolParameterForm.LLM:
|
||||
message_tool.parameters['properties'][parameter.name] = {
|
||||
"type": parameter_type,
|
||||
"description": parameter.llm_description or '',
|
||||
}
|
||||
message_tool.parameters['properties'][parameter.name] = {
|
||||
"type": parameter_type,
|
||||
"description": parameter.llm_description or '',
|
||||
}
|
||||
|
||||
if len(enum) > 0:
|
||||
message_tool.parameters['properties'][parameter.name]['enum'] = enum
|
||||
if len(enum) > 0:
|
||||
message_tool.parameters['properties'][parameter.name]['enum'] = enum
|
||||
|
||||
if parameter.required:
|
||||
message_tool.parameters['required'].append(parameter.name)
|
||||
|
||||
tool_entity.runtime.runtime_parameters.update(runtime_parameters)
|
||||
if parameter.required:
|
||||
message_tool.parameters['required'].append(parameter.name)
|
||||
|
||||
return message_tool, tool_entity
|
||||
|
||||
@ -305,6 +240,9 @@ class BaseAssistantApplicationRunner(AppRunner):
|
||||
tool_runtime_parameters = tool.get_runtime_parameters() or []
|
||||
|
||||
for parameter in tool_runtime_parameters:
|
||||
if parameter.form != ToolParameter.ToolParameterForm.LLM:
|
||||
continue
|
||||
|
||||
parameter_type = 'string'
|
||||
enum = []
|
||||
if parameter.type == ToolParameter.ToolParameterType.STRING:
|
||||
@ -320,18 +258,17 @@ class BaseAssistantApplicationRunner(AppRunner):
|
||||
else:
|
||||
raise ValueError(f"parameter type {parameter.type} is not supported")
|
||||
|
||||
if parameter.form == ToolParameter.ToolParameterForm.LLM:
|
||||
prompt_tool.parameters['properties'][parameter.name] = {
|
||||
"type": parameter_type,
|
||||
"description": parameter.llm_description or '',
|
||||
}
|
||||
prompt_tool.parameters['properties'][parameter.name] = {
|
||||
"type": parameter_type,
|
||||
"description": parameter.llm_description or '',
|
||||
}
|
||||
|
||||
if len(enum) > 0:
|
||||
prompt_tool.parameters['properties'][parameter.name]['enum'] = enum
|
||||
if len(enum) > 0:
|
||||
prompt_tool.parameters['properties'][parameter.name]['enum'] = enum
|
||||
|
||||
if parameter.required:
|
||||
if parameter.name not in prompt_tool.parameters['required']:
|
||||
prompt_tool.parameters['required'].append(parameter.name)
|
||||
if parameter.required:
|
||||
if parameter.name not in prompt_tool.parameters['required']:
|
||||
prompt_tool.parameters['required'].append(parameter.name)
|
||||
|
||||
return prompt_tool
|
||||
|
||||
|
54
api/core/helper/tool_parameter_cache.py
Normal file
54
api/core/helper/tool_parameter_cache.py
Normal file
@ -0,0 +1,54 @@
|
||||
import json
|
||||
from enum import Enum
|
||||
from json import JSONDecodeError
|
||||
from typing import Optional
|
||||
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
|
||||
class ToolParameterCacheType(Enum):
|
||||
PARAMETER = "tool_parameter"
|
||||
|
||||
class ToolParameterCache:
|
||||
def __init__(self,
|
||||
tenant_id: str,
|
||||
provider: str,
|
||||
tool_name: str,
|
||||
cache_type: ToolParameterCacheType
|
||||
):
|
||||
self.cache_key = f"{cache_type.value}_secret:tenant_id:{tenant_id}:provider:{provider}:tool_name:{tool_name}"
|
||||
|
||||
def get(self) -> Optional[dict]:
|
||||
"""
|
||||
Get cached model provider credentials.
|
||||
|
||||
:return:
|
||||
"""
|
||||
cached_tool_parameter = redis_client.get(self.cache_key)
|
||||
if cached_tool_parameter:
|
||||
try:
|
||||
cached_tool_parameter = cached_tool_parameter.decode('utf-8')
|
||||
cached_tool_parameter = json.loads(cached_tool_parameter)
|
||||
except JSONDecodeError:
|
||||
return None
|
||||
|
||||
return cached_tool_parameter
|
||||
else:
|
||||
return None
|
||||
|
||||
def set(self, parameters: dict) -> None:
|
||||
"""
|
||||
Cache model provider credentials.
|
||||
|
||||
:param credentials: provider credentials
|
||||
:return:
|
||||
"""
|
||||
redis_client.setex(self.cache_key, 86400, json.dumps(parameters))
|
||||
|
||||
def delete(self) -> None:
|
||||
"""
|
||||
Delete cached model provider credentials.
|
||||
|
||||
:return:
|
||||
"""
|
||||
redis_client.delete(self.cache_key)
|
@ -119,7 +119,7 @@ parameters: # Parameter list
|
||||
- The `identity` field is mandatory, it contains the basic information of the tool, including name, author, label, description, etc.
|
||||
- `parameters` Parameter list
|
||||
- `name` Parameter name, unique, no duplication with other parameters
|
||||
- `type` Parameter type, currently supports `string`, `number`, `boolean`, `select` four types, corresponding to string, number, boolean, drop-down box
|
||||
- `type` Parameter type, currently supports `string`, `number`, `boolean`, `select`, `secret-input` four types, corresponding to string, number, boolean, drop-down box, and encrypted input box, respectively. For sensitive information, we recommend using `secret-input` type
|
||||
- `required` Required or not
|
||||
- In `llm` mode, if the parameter is required, the Agent is required to infer this parameter
|
||||
- In `form` mode, if the parameter is required, the user is required to fill in this parameter on the frontend before the conversation starts
|
||||
|
@ -119,7 +119,7 @@ parameters: # 参数列表
|
||||
- `identity` 字段是必须的,它包含了工具的基本信息,包括名称、作者、标签、描述等
|
||||
- `parameters` 参数列表
|
||||
- `name` 参数名称,唯一,不允许和其他参数重名
|
||||
- `type` 参数类型,目前支持`string`、`number`、`boolean`、`select` 四种类型,分别对应字符串、数字、布尔值、下拉框
|
||||
- `type` 参数类型,目前支持`string`、`number`、`boolean`、`select`、`secret-input` 五种类型,分别对应字符串、数字、布尔值、下拉框、加密输入框,对于敏感信息,我们建议使用`secret-input`类型
|
||||
- `required` 是否必填
|
||||
- 在`llm`模式下,如果参数为必填,则会要求Agent必须要推理出这个参数
|
||||
- 在`form`模式下,如果参数为必填,则会要求用户在对话开始前在前端填写这个参数
|
||||
|
@ -100,6 +100,7 @@ class ToolParameter(BaseModel):
|
||||
NUMBER = "number"
|
||||
BOOLEAN = "boolean"
|
||||
SELECT = "select"
|
||||
SECRET_INPUT = "secret-input"
|
||||
|
||||
class ToolParameterForm(Enum):
|
||||
SCHEMA = "schema" # should be set while adding tool
|
||||
|
@ -23,6 +23,8 @@ class AIPPTGenerateTool(BuiltinTool):
|
||||
_api_base_url = URL('https://co.aippt.cn/api')
|
||||
_api_token_cache = {}
|
||||
_api_token_cache_lock = Lock()
|
||||
_style_cache = {}
|
||||
_style_cache_lock = Lock()
|
||||
|
||||
_task = {}
|
||||
_task_type_map = {
|
||||
@ -390,20 +392,31 @@ class AIPPTGenerateTool(BuiltinTool):
|
||||
).digest()
|
||||
).decode('utf-8')
|
||||
|
||||
def get_styles(self, user_id: str) -> tuple[list[dict], list[dict]]:
|
||||
@classmethod
|
||||
def _get_styles(cls, credentials: dict[str, str], user_id: str) -> tuple[list[dict], list[dict]]:
|
||||
"""
|
||||
Get styles
|
||||
|
||||
:param credentials: the credentials
|
||||
:return: Tuple[list[dict[id, color]], list[dict[id, style]]
|
||||
"""
|
||||
|
||||
# check cache
|
||||
with cls._style_cache_lock:
|
||||
# clear expired styles
|
||||
now = time()
|
||||
for key in list(cls._style_cache.keys()):
|
||||
if cls._style_cache[key]['expire'] < now:
|
||||
del cls._style_cache[key]
|
||||
|
||||
key = f'{credentials["aippt_access_key"]}#@#{user_id}'
|
||||
if key in cls._style_cache:
|
||||
return cls._style_cache[key]['colors'], cls._style_cache[key]['styles']
|
||||
|
||||
headers = {
|
||||
'x-channel': '',
|
||||
'x-api-key': self.runtime.credentials['aippt_access_key'],
|
||||
'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id=user_id)
|
||||
'x-api-key': credentials['aippt_access_key'],
|
||||
'x-token': cls._get_api_token(credentials=credentials, user_id=user_id)
|
||||
}
|
||||
response = get(
|
||||
str(self._api_base_url / 'template_component' / 'suit' / 'select'),
|
||||
str(cls._api_base_url / 'template_component' / 'suit' / 'select'),
|
||||
headers=headers
|
||||
)
|
||||
|
||||
@ -425,7 +438,26 @@ class AIPPTGenerateTool(BuiltinTool):
|
||||
'name': item.get('title'),
|
||||
} for item in response.get('data', {}).get('suit_style') or []]
|
||||
|
||||
with cls._style_cache_lock:
|
||||
cls._style_cache[key] = {
|
||||
'colors': colors,
|
||||
'styles': styles,
|
||||
'expire': now + 60 * 60
|
||||
}
|
||||
|
||||
return colors, styles
|
||||
|
||||
def get_styles(self, user_id: str) -> tuple[list[dict], list[dict]]:
|
||||
"""
|
||||
Get styles
|
||||
|
||||
:param credentials: the credentials
|
||||
:return: Tuple[list[dict[id, color]], list[dict[id, style]]
|
||||
"""
|
||||
if not self.runtime.credentials.get('aippt_access_key') or not self.runtime.credentials.get('aippt_secret_key'):
|
||||
return [], []
|
||||
|
||||
return self._get_styles(credentials=self.runtime.credentials, user_id=user_id)
|
||||
|
||||
def _get_suit(self, style_id: int, colour_id: int) -> int:
|
||||
"""
|
||||
|
@ -14,7 +14,7 @@ description:
|
||||
llm: A tool for sending messages to a chat group on Wecom(企业微信) .
|
||||
parameters:
|
||||
- name: hook_key
|
||||
type: string
|
||||
type: secret-input
|
||||
required: true
|
||||
label:
|
||||
en_US: Wecom Group bot webhook key
|
||||
|
@ -266,6 +266,40 @@ class Tool(BaseModel, ABC):
|
||||
"""
|
||||
return self.parameters
|
||||
|
||||
def get_all_runtime_parameters(self) -> list[ToolParameter]:
|
||||
"""
|
||||
get all runtime parameters
|
||||
|
||||
:return: all runtime parameters
|
||||
"""
|
||||
parameters = self.parameters or []
|
||||
parameters = parameters.copy()
|
||||
user_parameters = self.get_runtime_parameters() or []
|
||||
user_parameters = user_parameters.copy()
|
||||
|
||||
# override parameters
|
||||
for parameter in user_parameters:
|
||||
# check if parameter in tool parameters
|
||||
found = False
|
||||
for tool_parameter in parameters:
|
||||
if tool_parameter.name == parameter.name:
|
||||
found = True
|
||||
break
|
||||
|
||||
if found:
|
||||
# override parameter
|
||||
tool_parameter.type = parameter.type
|
||||
tool_parameter.form = parameter.form
|
||||
tool_parameter.required = parameter.required
|
||||
tool_parameter.default = parameter.default
|
||||
tool_parameter.options = parameter.options
|
||||
tool_parameter.llm_description = parameter.llm_description
|
||||
else:
|
||||
# add new parameter
|
||||
parameters.append(parameter)
|
||||
|
||||
return parameters
|
||||
|
||||
def is_tool_available(self) -> bool:
|
||||
"""
|
||||
check if the tool is available
|
||||
|
@ -6,11 +6,17 @@ from os import listdir, path
|
||||
from typing import Any, Union
|
||||
|
||||
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
|
||||
from core.entities.application_entities import AgentToolEntity
|
||||
from core.model_runtime.entities.message_entities import PromptMessage
|
||||
from core.provider_manager import ProviderManager
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.constant import DEFAULT_PROVIDERS
|
||||
from core.tools.entities.tool_entities import ApiProviderAuthType, ToolInvokeMessage, ToolProviderCredentials
|
||||
from core.tools.entities.tool_entities import (
|
||||
ApiProviderAuthType,
|
||||
ToolInvokeMessage,
|
||||
ToolParameter,
|
||||
ToolProviderCredentials,
|
||||
)
|
||||
from core.tools.entities.user_entities import UserToolProvider
|
||||
from core.tools.errors import ToolProviderNotFoundError
|
||||
from core.tools.provider.api_tool_provider import ApiBasedToolProviderController
|
||||
@ -21,7 +27,12 @@ from core.tools.provider.model_tool_provider import ModelToolProviderController
|
||||
from core.tools.provider.tool_provider import ToolProviderController
|
||||
from core.tools.tool.api_tool import ApiTool
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
from core.tools.utils.configuration import ModelToolConfigurationManager, ToolConfiguration
|
||||
from core.tools.tool.tool import Tool
|
||||
from core.tools.utils.configuration import (
|
||||
ModelToolConfigurationManager,
|
||||
ToolConfigurationManager,
|
||||
ToolParameterConfigurationManager,
|
||||
)
|
||||
from core.tools.utils.encoder import serialize_base_model_dict
|
||||
from extensions.ext_database import db
|
||||
from models.tools import ApiToolProvider, BuiltinToolProvider
|
||||
@ -172,7 +183,7 @@ class ToolManager:
|
||||
# decrypt the credentials
|
||||
credentials = builtin_provider.credentials
|
||||
controller = ToolManager.get_builtin_provider(provider_name)
|
||||
tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=controller)
|
||||
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller)
|
||||
|
||||
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
|
||||
|
||||
@ -189,7 +200,7 @@ class ToolManager:
|
||||
api_provider, credentials = ToolManager.get_api_provider_controller(tenant_id, provider_name)
|
||||
|
||||
# decrypt the credentials
|
||||
tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=api_provider)
|
||||
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=api_provider)
|
||||
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
|
||||
|
||||
return api_provider.get_tool(tool_name).fork_tool_runtime(meta={
|
||||
@ -214,6 +225,71 @@ class ToolManager:
|
||||
else:
|
||||
raise ToolProviderNotFoundError(f'provider type {provider_type} not found')
|
||||
|
||||
@staticmethod
|
||||
def get_agent_tool_runtime(tenant_id: str, agent_tool: AgentToolEntity, agent_callback: DifyAgentCallbackHandler) -> Tool:
|
||||
"""
|
||||
get the agent tool runtime
|
||||
"""
|
||||
tool_entity = ToolManager.get_tool_runtime(
|
||||
provider_type=agent_tool.provider_type, provider_name=agent_tool.provider_id, tool_name=agent_tool.tool_name,
|
||||
tenant_id=tenant_id,
|
||||
agent_callback=agent_callback
|
||||
)
|
||||
runtime_parameters = {}
|
||||
parameters = tool_entity.get_all_runtime_parameters()
|
||||
for parameter in parameters:
|
||||
if parameter.form == ToolParameter.ToolParameterForm.FORM:
|
||||
# get tool parameter from form
|
||||
tool_parameter_config = agent_tool.tool_parameters.get(parameter.name)
|
||||
if not tool_parameter_config:
|
||||
# get default value
|
||||
tool_parameter_config = parameter.default
|
||||
if not tool_parameter_config and parameter.required:
|
||||
raise ValueError(f"tool parameter {parameter.name} not found in tool config")
|
||||
|
||||
if parameter.type == ToolParameter.ToolParameterType.SELECT:
|
||||
# check if tool_parameter_config in options
|
||||
options = list(map(lambda x: x.value, parameter.options))
|
||||
if tool_parameter_config not in options:
|
||||
raise ValueError(f"tool parameter {parameter.name} value {tool_parameter_config} not in options {options}")
|
||||
|
||||
# convert tool parameter config to correct type
|
||||
try:
|
||||
if parameter.type == ToolParameter.ToolParameterType.NUMBER:
|
||||
# check if tool parameter is integer
|
||||
if isinstance(tool_parameter_config, int):
|
||||
tool_parameter_config = tool_parameter_config
|
||||
elif isinstance(tool_parameter_config, float):
|
||||
tool_parameter_config = tool_parameter_config
|
||||
elif isinstance(tool_parameter_config, str):
|
||||
if '.' in tool_parameter_config:
|
||||
tool_parameter_config = float(tool_parameter_config)
|
||||
else:
|
||||
tool_parameter_config = int(tool_parameter_config)
|
||||
elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN:
|
||||
tool_parameter_config = bool(tool_parameter_config)
|
||||
elif parameter.type not in [ToolParameter.ToolParameterType.SELECT, ToolParameter.ToolParameterType.STRING]:
|
||||
tool_parameter_config = str(tool_parameter_config)
|
||||
elif parameter.type == ToolParameter.ToolParameterType:
|
||||
tool_parameter_config = str(tool_parameter_config)
|
||||
except Exception as e:
|
||||
raise ValueError(f"tool parameter {parameter.name} value {tool_parameter_config} is not correct type")
|
||||
|
||||
# save tool parameter to tool entity memory
|
||||
runtime_parameters[parameter.name] = tool_parameter_config
|
||||
|
||||
# decrypt runtime parameters
|
||||
encryption_manager = ToolParameterConfigurationManager(
|
||||
tenant_id=tenant_id,
|
||||
tool_runtime=tool_entity,
|
||||
provider_name=agent_tool.provider_id,
|
||||
provider_type=agent_tool.provider_type,
|
||||
)
|
||||
runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
|
||||
|
||||
tool_entity.runtime.runtime_parameters.update(runtime_parameters)
|
||||
return tool_entity
|
||||
|
||||
@staticmethod
|
||||
def get_builtin_provider_icon(provider: str) -> tuple[str, str]:
|
||||
"""
|
||||
@ -396,7 +472,7 @@ class ToolManager:
|
||||
controller = ToolManager.get_builtin_provider(provider_name)
|
||||
|
||||
# init tool configuration
|
||||
tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=controller)
|
||||
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller)
|
||||
# decrypt the credentials and mask the credentials
|
||||
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials=credentials)
|
||||
masked_credentials = tool_configuration.mask_tool_credentials(credentials=decrypted_credentials)
|
||||
@ -463,7 +539,7 @@ class ToolManager:
|
||||
)
|
||||
|
||||
# init tool configuration
|
||||
tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=controller)
|
||||
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller)
|
||||
|
||||
# decrypt the credentials and mask the credentials
|
||||
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials=credentials)
|
||||
@ -523,7 +599,7 @@ class ToolManager:
|
||||
provider, ApiProviderAuthType.API_KEY if credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE
|
||||
)
|
||||
# init tool configuration
|
||||
tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=controller)
|
||||
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller)
|
||||
|
||||
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
|
||||
masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
|
||||
|
@ -5,16 +5,19 @@ from pydantic import BaseModel
|
||||
from yaml import FullLoader, load
|
||||
|
||||
from core.helper import encrypter
|
||||
from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType
|
||||
from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType
|
||||
from core.tools.entities.tool_entities import (
|
||||
ModelToolConfiguration,
|
||||
ModelToolProviderConfiguration,
|
||||
ToolParameter,
|
||||
ToolProviderCredentials,
|
||||
)
|
||||
from core.tools.provider.tool_provider import ToolProviderController
|
||||
from core.tools.tool.tool import Tool
|
||||
|
||||
|
||||
class ToolConfiguration(BaseModel):
|
||||
class ToolConfigurationManager(BaseModel):
|
||||
tenant_id: str
|
||||
provider_controller: ToolProviderController
|
||||
|
||||
@ -101,6 +104,128 @@ class ToolConfiguration(BaseModel):
|
||||
)
|
||||
cache.delete()
|
||||
|
||||
class ToolParameterConfigurationManager(BaseModel):
|
||||
"""
|
||||
Tool parameter configuration manager
|
||||
"""
|
||||
tenant_id: str
|
||||
tool_runtime: Tool
|
||||
provider_name: str
|
||||
provider_type: str
|
||||
|
||||
def _deep_copy(self, parameters: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
deep copy parameters
|
||||
"""
|
||||
return {key: value for key, value in parameters.items()}
|
||||
|
||||
def _merge_parameters(self) -> list[ToolParameter]:
|
||||
"""
|
||||
merge parameters
|
||||
"""
|
||||
# get tool parameters
|
||||
tool_parameters = self.tool_runtime.parameters or []
|
||||
# get tool runtime parameters
|
||||
runtime_parameters = self.tool_runtime.get_runtime_parameters() or []
|
||||
# override parameters
|
||||
current_parameters = tool_parameters.copy()
|
||||
for runtime_parameter in runtime_parameters:
|
||||
found = False
|
||||
for index, parameter in enumerate(current_parameters):
|
||||
if parameter.name == runtime_parameter.name and parameter.form == runtime_parameter.form:
|
||||
current_parameters[index] = runtime_parameter
|
||||
found = True
|
||||
break
|
||||
|
||||
if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
|
||||
current_parameters.append(runtime_parameter)
|
||||
|
||||
return current_parameters
|
||||
|
||||
def mask_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
mask tool parameters
|
||||
|
||||
return a deep copy of parameters with masked values
|
||||
"""
|
||||
parameters = self._deep_copy(parameters)
|
||||
|
||||
# override parameters
|
||||
current_parameters = self._merge_parameters()
|
||||
|
||||
for parameter in current_parameters:
|
||||
if parameter.form == ToolParameter.ToolParameterForm.FORM and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT:
|
||||
if parameter.name in parameters:
|
||||
if len(parameters[parameter.name]) > 6:
|
||||
parameters[parameter.name] = \
|
||||
parameters[parameter.name][:2] + \
|
||||
'*' * (len(parameters[parameter.name]) - 4) +\
|
||||
parameters[parameter.name][-2:]
|
||||
else:
|
||||
parameters[parameter.name] = '*' * len(parameters[parameter.name])
|
||||
|
||||
return parameters
|
||||
|
||||
def encrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
encrypt tool parameters with tenant id
|
||||
|
||||
return a deep copy of parameters with encrypted values
|
||||
"""
|
||||
# override parameters
|
||||
current_parameters = self._merge_parameters()
|
||||
|
||||
for parameter in current_parameters:
|
||||
if parameter.form == ToolParameter.ToolParameterForm.FORM and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT:
|
||||
if parameter.name in parameters:
|
||||
encrypted = encrypter.encrypt_token(self.tenant_id, parameters[parameter.name])
|
||||
parameters[parameter.name] = encrypted
|
||||
|
||||
return parameters
|
||||
|
||||
def decrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
decrypt tool parameters with tenant id
|
||||
|
||||
return a deep copy of parameters with decrypted values
|
||||
"""
|
||||
cache = ToolParameterCache(
|
||||
tenant_id=self.tenant_id,
|
||||
provider=f'{self.provider_type}.{self.provider_name}',
|
||||
tool_name=self.tool_runtime.identity.name,
|
||||
cache_type=ToolParameterCacheType.PARAMETER
|
||||
)
|
||||
cached_parameters = cache.get()
|
||||
if cached_parameters:
|
||||
return cached_parameters
|
||||
|
||||
# override parameters
|
||||
current_parameters = self._merge_parameters()
|
||||
has_secret_input = False
|
||||
|
||||
for parameter in current_parameters:
|
||||
if parameter.form == ToolParameter.ToolParameterForm.FORM and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT:
|
||||
if parameter.name in parameters:
|
||||
try:
|
||||
has_secret_input = True
|
||||
parameters[parameter.name] = encrypter.decrypt_token(self.tenant_id, parameters[parameter.name])
|
||||
except:
|
||||
pass
|
||||
|
||||
if has_secret_input:
|
||||
cache.set(parameters)
|
||||
|
||||
return parameters
|
||||
|
||||
def delete_tool_parameters_cache(self):
|
||||
cache = ToolParameterCache(
|
||||
tenant_id=self.tenant_id,
|
||||
provider=f'{self.provider_type}.{self.provider_name}',
|
||||
tool_name=self.tool_runtime.identity.name,
|
||||
cache_type=ToolParameterCacheType.PARAMETER
|
||||
)
|
||||
cache.delete()
|
||||
|
||||
class ModelToolConfigurationManager:
|
||||
"""
|
||||
Model as tool configuration
|
||||
|
@ -17,7 +17,7 @@ from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidatio
|
||||
from core.tools.provider.api_tool_provider import ApiBasedToolProviderController
|
||||
from core.tools.provider.tool_provider import ToolProviderController
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.configuration import ToolConfiguration
|
||||
from core.tools.utils.configuration import ToolConfigurationManager
|
||||
from core.tools.utils.encoder import serialize_base_model_array, serialize_base_model_dict
|
||||
from core.tools.utils.parser import ApiBasedToolSchemaParser
|
||||
from extensions.ext_database import db
|
||||
@ -77,7 +77,7 @@ class ToolManageService:
|
||||
provider_controller: ToolProviderController = ToolManager.get_builtin_provider(provider)
|
||||
tools = provider_controller.get_tools()
|
||||
|
||||
tool_provider_configurations = ToolConfiguration(tenant_id=tenant_id, provider_controller=provider_controller)
|
||||
tool_provider_configurations = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
|
||||
# check if user has added the provider
|
||||
builtin_provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
@ -279,7 +279,7 @@ class ToolManageService:
|
||||
provider_controller.load_bundled_tools(tool_bundles)
|
||||
|
||||
# encrypt credentials
|
||||
tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=provider_controller)
|
||||
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
|
||||
encrypted_credentials = tool_configuration.encrypt_tool_credentials(credentials)
|
||||
db_provider.credentials_str = json.dumps(encrypted_credentials)
|
||||
|
||||
@ -366,7 +366,7 @@ class ToolManageService:
|
||||
provider_controller = ToolManager.get_builtin_provider(provider_name)
|
||||
if not provider_controller.need_credentials:
|
||||
raise ValueError(f'provider {provider_name} does not need credentials')
|
||||
tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=provider_controller)
|
||||
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
|
||||
# get original credentials if exists
|
||||
if provider is not None:
|
||||
original_credentials = tool_configuration.decrypt_tool_credentials(provider.credentials)
|
||||
@ -450,7 +450,7 @@ class ToolManageService:
|
||||
provider_controller.load_bundled_tools(tool_bundles)
|
||||
|
||||
# get original credentials if exists
|
||||
tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=provider_controller)
|
||||
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
|
||||
|
||||
original_credentials = tool_configuration.decrypt_tool_credentials(provider.credentials)
|
||||
masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)
|
||||
@ -490,7 +490,7 @@ class ToolManageService:
|
||||
|
||||
# delete cache
|
||||
provider_controller = ToolManager.get_builtin_provider(provider_name)
|
||||
tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=provider_controller)
|
||||
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
|
||||
tool_configuration.delete_tool_credentials_cache()
|
||||
|
||||
return { 'result': 'success' }
|
||||
@ -632,7 +632,7 @@ class ToolManageService:
|
||||
|
||||
# decrypt credentials
|
||||
if db_provider.id:
|
||||
tool_configuration = ToolConfiguration(
|
||||
tool_configuration = ToolConfigurationManager(
|
||||
tenant_id=tenant_id,
|
||||
provider_controller=provider_controller
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user