mirror of
https://gitee.com/dify_ai/dify.git
synced 2024-11-30 10:18:13 +08:00
refactor(api/core/app/app_config/entities.py): Move Type to outside and add EXTERNAL_DATA_TOOL. (#7444)
This commit is contained in:
parent
e2d214e030
commit
a10b207de2
@ -1,6 +1,6 @@
|
||||
import re
|
||||
|
||||
from core.app.app_config.entities import ExternalDataVariableEntity, VariableEntity
|
||||
from core.app.app_config.entities import ExternalDataVariableEntity, VariableEntity, VariableEntityType
|
||||
from core.external_data_tool.factory import ExternalDataToolFactory
|
||||
|
||||
|
||||
@ -13,7 +13,7 @@ class BasicVariablesConfigManager:
|
||||
:param config: model config args
|
||||
"""
|
||||
external_data_variables = []
|
||||
variables = []
|
||||
variable_entities = []
|
||||
|
||||
# old external_data_tools
|
||||
external_data_tools = config.get('external_data_tools', [])
|
||||
@ -30,50 +30,41 @@ class BasicVariablesConfigManager:
|
||||
)
|
||||
|
||||
# variables and external_data_tools
|
||||
for variable in config.get('user_input_form', []):
|
||||
typ = list(variable.keys())[0]
|
||||
if typ == 'external_data_tool':
|
||||
val = variable[typ]
|
||||
if 'config' not in val:
|
||||
for variables in config.get('user_input_form', []):
|
||||
variable_type = list(variables.keys())[0]
|
||||
if variable_type == VariableEntityType.EXTERNAL_DATA_TOOL:
|
||||
variable = variables[variable_type]
|
||||
if 'config' not in variable:
|
||||
continue
|
||||
|
||||
external_data_variables.append(
|
||||
ExternalDataVariableEntity(
|
||||
variable=val['variable'],
|
||||
type=val['type'],
|
||||
config=val['config']
|
||||
variable=variable['variable'],
|
||||
type=variable['type'],
|
||||
config=variable['config']
|
||||
)
|
||||
)
|
||||
elif typ in [
|
||||
VariableEntity.Type.TEXT_INPUT.value,
|
||||
VariableEntity.Type.PARAGRAPH.value,
|
||||
VariableEntity.Type.NUMBER.value,
|
||||
elif variable_type in [
|
||||
VariableEntityType.TEXT_INPUT,
|
||||
VariableEntityType.PARAGRAPH,
|
||||
VariableEntityType.NUMBER,
|
||||
VariableEntityType.SELECT,
|
||||
]:
|
||||
variables.append(
|
||||
variable = variables[variable_type]
|
||||
variable_entities.append(
|
||||
VariableEntity(
|
||||
type=VariableEntity.Type.value_of(typ),
|
||||
variable=variable[typ].get('variable'),
|
||||
description=variable[typ].get('description'),
|
||||
label=variable[typ].get('label'),
|
||||
required=variable[typ].get('required', False),
|
||||
max_length=variable[typ].get('max_length'),
|
||||
default=variable[typ].get('default'),
|
||||
)
|
||||
)
|
||||
elif typ == VariableEntity.Type.SELECT.value:
|
||||
variables.append(
|
||||
VariableEntity(
|
||||
type=VariableEntity.Type.SELECT,
|
||||
variable=variable[typ].get('variable'),
|
||||
description=variable[typ].get('description'),
|
||||
label=variable[typ].get('label'),
|
||||
required=variable[typ].get('required', False),
|
||||
options=variable[typ].get('options'),
|
||||
default=variable[typ].get('default'),
|
||||
type=variable_type,
|
||||
variable=variable.get('variable'),
|
||||
description=variable.get('description'),
|
||||
label=variable.get('label'),
|
||||
required=variable.get('required', False),
|
||||
max_length=variable.get('max_length'),
|
||||
options=variable.get('options'),
|
||||
default=variable.get('default'),
|
||||
)
|
||||
)
|
||||
|
||||
return variables, external_data_variables
|
||||
return variable_entities, external_data_variables
|
||||
|
||||
@classmethod
|
||||
def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]:
|
||||
@ -183,4 +174,4 @@ class BasicVariablesConfigManager:
|
||||
config=config
|
||||
)
|
||||
|
||||
return config, ["external_data_tools"]
|
||||
return config, ["external_data_tools"]
|
||||
|
@ -82,43 +82,29 @@ class PromptTemplateEntity(BaseModel):
|
||||
advanced_completion_prompt_template: Optional[AdvancedCompletionPromptTemplateEntity] = None
|
||||
|
||||
|
||||
class VariableEntityType(str, Enum):
|
||||
TEXT_INPUT = "text-input"
|
||||
SELECT = "select"
|
||||
PARAGRAPH = "paragraph"
|
||||
NUMBER = "number"
|
||||
EXTERNAL_DATA_TOOL = "external-data-tool"
|
||||
|
||||
|
||||
class VariableEntity(BaseModel):
|
||||
"""
|
||||
Variable Entity.
|
||||
"""
|
||||
class Type(Enum):
|
||||
TEXT_INPUT = 'text-input'
|
||||
SELECT = 'select'
|
||||
PARAGRAPH = 'paragraph'
|
||||
NUMBER = 'number'
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> 'VariableEntity.Type':
|
||||
"""
|
||||
Get value of given mode.
|
||||
|
||||
:param value: mode value
|
||||
:return: mode
|
||||
"""
|
||||
for mode in cls:
|
||||
if mode.value == value:
|
||||
return mode
|
||||
raise ValueError(f'invalid variable type value {value}')
|
||||
|
||||
variable: str
|
||||
label: str
|
||||
description: Optional[str] = None
|
||||
type: Type
|
||||
type: VariableEntityType
|
||||
required: bool = False
|
||||
max_length: Optional[int] = None
|
||||
options: Optional[list[str]] = None
|
||||
default: Optional[str] = None
|
||||
hint: Optional[str] = None
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self.variable
|
||||
|
||||
|
||||
class ExternalDataVariableEntity(BaseModel):
|
||||
"""
|
||||
@ -252,4 +238,4 @@ class WorkflowUIBasedAppConfig(AppConfig):
|
||||
"""
|
||||
Workflow UI Based App Config Entity.
|
||||
"""
|
||||
workflow_id: str
|
||||
workflow_id: str
|
||||
|
@ -1,7 +1,7 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.app.app_config.entities import AppConfig, VariableEntity
|
||||
from core.app.app_config.entities import AppConfig, VariableEntity, VariableEntityType
|
||||
|
||||
|
||||
class BaseAppGenerator:
|
||||
@ -9,29 +9,29 @@ class BaseAppGenerator:
|
||||
user_inputs = user_inputs or {}
|
||||
# Filter input variables from form configuration, handle required fields, default values, and option values
|
||||
variables = app_config.variables
|
||||
filtered_inputs = {var.name: self._validate_input(inputs=user_inputs, var=var) for var in variables}
|
||||
filtered_inputs = {var.variable: self._validate_input(inputs=user_inputs, var=var) for var in variables}
|
||||
filtered_inputs = {k: self._sanitize_value(v) for k, v in filtered_inputs.items()}
|
||||
return filtered_inputs
|
||||
|
||||
def _validate_input(self, *, inputs: Mapping[str, Any], var: VariableEntity):
|
||||
user_input_value = inputs.get(var.name)
|
||||
user_input_value = inputs.get(var.variable)
|
||||
if var.required and not user_input_value:
|
||||
raise ValueError(f'{var.name} is required in input form')
|
||||
raise ValueError(f'{var.variable} is required in input form')
|
||||
if not var.required and not user_input_value:
|
||||
# TODO: should we return None here if the default value is None?
|
||||
return var.default or ''
|
||||
if (
|
||||
var.type
|
||||
in (
|
||||
VariableEntity.Type.TEXT_INPUT,
|
||||
VariableEntity.Type.SELECT,
|
||||
VariableEntity.Type.PARAGRAPH,
|
||||
VariableEntityType.TEXT_INPUT,
|
||||
VariableEntityType.SELECT,
|
||||
VariableEntityType.PARAGRAPH,
|
||||
)
|
||||
and user_input_value
|
||||
and not isinstance(user_input_value, str)
|
||||
):
|
||||
raise ValueError(f"(type '{var.type}') {var.name} in input form must be a string")
|
||||
if var.type == VariableEntity.Type.NUMBER and isinstance(user_input_value, str):
|
||||
raise ValueError(f"(type '{var.type}') {var.variable} in input form must be a string")
|
||||
if var.type == VariableEntityType.NUMBER and isinstance(user_input_value, str):
|
||||
# may raise ValueError if user_input_value is not a valid number
|
||||
try:
|
||||
if '.' in user_input_value:
|
||||
@ -39,14 +39,14 @@ class BaseAppGenerator:
|
||||
else:
|
||||
return int(user_input_value)
|
||||
except ValueError:
|
||||
raise ValueError(f"{var.name} in input form must be a valid number")
|
||||
if var.type == VariableEntity.Type.SELECT:
|
||||
raise ValueError(f"{var.variable} in input form must be a valid number")
|
||||
if var.type == VariableEntityType.SELECT:
|
||||
options = var.options or []
|
||||
if user_input_value not in options:
|
||||
raise ValueError(f'{var.name} in input form must be one of the following: {options}')
|
||||
elif var.type in (VariableEntity.Type.TEXT_INPUT, VariableEntity.Type.PARAGRAPH):
|
||||
raise ValueError(f'{var.variable} in input form must be one of the following: {options}')
|
||||
elif var.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH):
|
||||
if var.max_length and user_input_value and len(user_input_value) > var.max_length:
|
||||
raise ValueError(f'{var.name} in input form must be less than {var.max_length} characters')
|
||||
raise ValueError(f'{var.variable} in input form must be less than {var.max_length} characters')
|
||||
|
||||
return user_input_value
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
from typing import Optional
|
||||
|
||||
from core.app.app_config.entities import VariableEntity
|
||||
from core.app.app_config.entities import VariableEntity, VariableEntityType
|
||||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import (
|
||||
@ -18,6 +18,13 @@ from models.model import App, AppMode
|
||||
from models.tools import WorkflowToolProvider
|
||||
from models.workflow import Workflow
|
||||
|
||||
VARIABLE_TO_PARAMETER_TYPE_MAPPING = {
|
||||
VariableEntityType.TEXT_INPUT: ToolParameter.ToolParameterType.STRING,
|
||||
VariableEntityType.PARAGRAPH: ToolParameter.ToolParameterType.STRING,
|
||||
VariableEntityType.SELECT: ToolParameter.ToolParameterType.SELECT,
|
||||
VariableEntityType.NUMBER: ToolParameter.ToolParameterType.NUMBER,
|
||||
}
|
||||
|
||||
|
||||
class WorkflowToolProviderController(ToolProviderController):
|
||||
provider_id: str
|
||||
@ -28,7 +35,7 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||
|
||||
if not app:
|
||||
raise ValueError('app not found')
|
||||
|
||||
|
||||
controller = WorkflowToolProviderController(**{
|
||||
'identity': {
|
||||
'author': db_provider.user.name if db_provider.user_id and db_provider.user else '',
|
||||
@ -46,7 +53,7 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||
'credentials_schema': {},
|
||||
'provider_id': db_provider.id or '',
|
||||
})
|
||||
|
||||
|
||||
# init tools
|
||||
|
||||
controller.tools = [controller._get_db_provider_tool(db_provider, app)]
|
||||
@ -56,7 +63,7 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||
@property
|
||||
def provider_type(self) -> ToolProviderType:
|
||||
return ToolProviderType.WORKFLOW
|
||||
|
||||
|
||||
def _get_db_provider_tool(self, db_provider: WorkflowToolProvider, app: App) -> WorkflowTool:
|
||||
"""
|
||||
get db provider tool
|
||||
@ -93,23 +100,11 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||
if variable:
|
||||
parameter_type = None
|
||||
options = None
|
||||
if variable.type in [
|
||||
VariableEntity.Type.TEXT_INPUT,
|
||||
VariableEntity.Type.PARAGRAPH,
|
||||
]:
|
||||
parameter_type = ToolParameter.ToolParameterType.STRING
|
||||
elif variable.type in [
|
||||
VariableEntity.Type.SELECT
|
||||
]:
|
||||
parameter_type = ToolParameter.ToolParameterType.SELECT
|
||||
elif variable.type in [
|
||||
VariableEntity.Type.NUMBER
|
||||
]:
|
||||
parameter_type = ToolParameter.ToolParameterType.NUMBER
|
||||
else:
|
||||
if variable.type not in VARIABLE_TO_PARAMETER_TYPE_MAPPING:
|
||||
raise ValueError(f'unsupported variable type {variable.type}')
|
||||
|
||||
if variable.type == VariableEntity.Type.SELECT and variable.options:
|
||||
parameter_type = VARIABLE_TO_PARAMETER_TYPE_MAPPING[variable.type]
|
||||
|
||||
if variable.type == VariableEntityType.SELECT and variable.options:
|
||||
options = [
|
||||
ToolParameterOption(
|
||||
value=option,
|
||||
@ -200,7 +195,7 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||
"""
|
||||
if self.tools is not None:
|
||||
return self.tools
|
||||
|
||||
|
||||
db_providers: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter(
|
||||
WorkflowToolProvider.tenant_id == tenant_id,
|
||||
WorkflowToolProvider.app_id == self.provider_id,
|
||||
@ -208,11 +203,11 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||
|
||||
if not db_providers:
|
||||
return []
|
||||
|
||||
|
||||
self.tools = [self._get_db_provider_tool(db_providers, db_providers.app)]
|
||||
|
||||
return self.tools
|
||||
|
||||
|
||||
def get_tool(self, tool_name: str) -> Optional[WorkflowTool]:
|
||||
"""
|
||||
get tool by name
|
||||
@ -226,5 +221,5 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||
for tool in self.tools:
|
||||
if tool.identity.name == tool_name:
|
||||
return tool
|
||||
|
||||
|
||||
return None
|
||||
|
@ -1,3 +1,7 @@
|
||||
from collections.abc import Sequence
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from core.app.app_config.entities import VariableEntity
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
|
||||
@ -6,4 +10,4 @@ class StartNodeData(BaseNodeData):
|
||||
"""
|
||||
Start Node Data
|
||||
"""
|
||||
variables: list[VariableEntity] = []
|
||||
variables: Sequence[VariableEntity] = Field(default_factory=list)
|
||||
|
@ -14,6 +14,7 @@ from core.app.app_config.entities import (
|
||||
ModelConfigEntity,
|
||||
PromptTemplateEntity,
|
||||
VariableEntity,
|
||||
VariableEntityType,
|
||||
)
|
||||
from core.helper import encrypter
|
||||
from core.model_runtime.entities.llm_entities import LLMMode
|
||||
@ -25,23 +26,24 @@ from services.workflow.workflow_converter import WorkflowConverter
|
||||
|
||||
@pytest.fixture
|
||||
def default_variables():
|
||||
return [
|
||||
value = [
|
||||
VariableEntity(
|
||||
variable="text_input",
|
||||
label="text-input",
|
||||
type=VariableEntity.Type.TEXT_INPUT
|
||||
type=VariableEntityType.TEXT_INPUT,
|
||||
),
|
||||
VariableEntity(
|
||||
variable="paragraph",
|
||||
label="paragraph",
|
||||
type=VariableEntity.Type.PARAGRAPH
|
||||
type=VariableEntityType.PARAGRAPH,
|
||||
),
|
||||
VariableEntity(
|
||||
variable="select",
|
||||
label="select",
|
||||
type=VariableEntity.Type.SELECT
|
||||
)
|
||||
type=VariableEntityType.SELECT,
|
||||
),
|
||||
]
|
||||
return value
|
||||
|
||||
|
||||
def test__convert_to_start_node(default_variables):
|
||||
|
Loading…
Reference in New Issue
Block a user