refactor(api/core/app/app_config/entities.py): Move Type to outside and add EXTERNAL_DATA_TOOL. (#7444)

This commit is contained in:
-LAN- 2024-08-20 17:30:14 +08:00 committed by GitHub
parent e2d214e030
commit a10b207de2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 82 additions and 104 deletions

View File

@ -1,6 +1,6 @@
import re
from core.app.app_config.entities import ExternalDataVariableEntity, VariableEntity
from core.app.app_config.entities import ExternalDataVariableEntity, VariableEntity, VariableEntityType
from core.external_data_tool.factory import ExternalDataToolFactory
@ -13,7 +13,7 @@ class BasicVariablesConfigManager:
:param config: model config args
"""
external_data_variables = []
variables = []
variable_entities = []
# old external_data_tools
external_data_tools = config.get('external_data_tools', [])
@ -30,50 +30,41 @@ class BasicVariablesConfigManager:
)
# variables and external_data_tools
for variable in config.get('user_input_form', []):
typ = list(variable.keys())[0]
if typ == 'external_data_tool':
val = variable[typ]
if 'config' not in val:
for variables in config.get('user_input_form', []):
variable_type = list(variables.keys())[0]
if variable_type == VariableEntityType.EXTERNAL_DATA_TOOL:
variable = variables[variable_type]
if 'config' not in variable:
continue
external_data_variables.append(
ExternalDataVariableEntity(
variable=val['variable'],
type=val['type'],
config=val['config']
variable=variable['variable'],
type=variable['type'],
config=variable['config']
)
)
elif typ in [
VariableEntity.Type.TEXT_INPUT.value,
VariableEntity.Type.PARAGRAPH.value,
VariableEntity.Type.NUMBER.value,
elif variable_type in [
VariableEntityType.TEXT_INPUT,
VariableEntityType.PARAGRAPH,
VariableEntityType.NUMBER,
VariableEntityType.SELECT,
]:
variables.append(
variable = variables[variable_type]
variable_entities.append(
VariableEntity(
type=VariableEntity.Type.value_of(typ),
variable=variable[typ].get('variable'),
description=variable[typ].get('description'),
label=variable[typ].get('label'),
required=variable[typ].get('required', False),
max_length=variable[typ].get('max_length'),
default=variable[typ].get('default'),
)
)
elif typ == VariableEntity.Type.SELECT.value:
variables.append(
VariableEntity(
type=VariableEntity.Type.SELECT,
variable=variable[typ].get('variable'),
description=variable[typ].get('description'),
label=variable[typ].get('label'),
required=variable[typ].get('required', False),
options=variable[typ].get('options'),
default=variable[typ].get('default'),
type=variable_type,
variable=variable.get('variable'),
description=variable.get('description'),
label=variable.get('label'),
required=variable.get('required', False),
max_length=variable.get('max_length'),
options=variable.get('options'),
default=variable.get('default'),
)
)
return variables, external_data_variables
return variable_entities, external_data_variables
@classmethod
def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]:

View File

@ -82,43 +82,29 @@ class PromptTemplateEntity(BaseModel):
advanced_completion_prompt_template: Optional[AdvancedCompletionPromptTemplateEntity] = None
class VariableEntityType(str, Enum):
TEXT_INPUT = "text-input"
SELECT = "select"
PARAGRAPH = "paragraph"
NUMBER = "number"
EXTERNAL_DATA_TOOL = "external-data-tool"
class VariableEntity(BaseModel):
"""
Variable Entity.
"""
class Type(Enum):
TEXT_INPUT = 'text-input'
SELECT = 'select'
PARAGRAPH = 'paragraph'
NUMBER = 'number'
@classmethod
def value_of(cls, value: str) -> 'VariableEntity.Type':
"""
Get value of given mode.
:param value: mode value
:return: mode
"""
for mode in cls:
if mode.value == value:
return mode
raise ValueError(f'invalid variable type value {value}')
variable: str
label: str
description: Optional[str] = None
type: Type
type: VariableEntityType
required: bool = False
max_length: Optional[int] = None
options: Optional[list[str]] = None
default: Optional[str] = None
hint: Optional[str] = None
@property
def name(self) -> str:
return self.variable
class ExternalDataVariableEntity(BaseModel):
"""

View File

@ -1,7 +1,7 @@
from collections.abc import Mapping
from typing import Any, Optional
from core.app.app_config.entities import AppConfig, VariableEntity
from core.app.app_config.entities import AppConfig, VariableEntity, VariableEntityType
class BaseAppGenerator:
@ -9,29 +9,29 @@ class BaseAppGenerator:
user_inputs = user_inputs or {}
# Filter input variables from form configuration, handle required fields, default values, and option values
variables = app_config.variables
filtered_inputs = {var.name: self._validate_input(inputs=user_inputs, var=var) for var in variables}
filtered_inputs = {var.variable: self._validate_input(inputs=user_inputs, var=var) for var in variables}
filtered_inputs = {k: self._sanitize_value(v) for k, v in filtered_inputs.items()}
return filtered_inputs
def _validate_input(self, *, inputs: Mapping[str, Any], var: VariableEntity):
user_input_value = inputs.get(var.name)
user_input_value = inputs.get(var.variable)
if var.required and not user_input_value:
raise ValueError(f'{var.name} is required in input form')
raise ValueError(f'{var.variable} is required in input form')
if not var.required and not user_input_value:
# TODO: should we return None here if the default value is None?
return var.default or ''
if (
var.type
in (
VariableEntity.Type.TEXT_INPUT,
VariableEntity.Type.SELECT,
VariableEntity.Type.PARAGRAPH,
VariableEntityType.TEXT_INPUT,
VariableEntityType.SELECT,
VariableEntityType.PARAGRAPH,
)
and user_input_value
and not isinstance(user_input_value, str)
):
raise ValueError(f"(type '{var.type}') {var.name} in input form must be a string")
if var.type == VariableEntity.Type.NUMBER and isinstance(user_input_value, str):
raise ValueError(f"(type '{var.type}') {var.variable} in input form must be a string")
if var.type == VariableEntityType.NUMBER and isinstance(user_input_value, str):
# may raise ValueError if user_input_value is not a valid number
try:
if '.' in user_input_value:
@ -39,14 +39,14 @@ class BaseAppGenerator:
else:
return int(user_input_value)
except ValueError:
raise ValueError(f"{var.name} in input form must be a valid number")
if var.type == VariableEntity.Type.SELECT:
raise ValueError(f"{var.variable} in input form must be a valid number")
if var.type == VariableEntityType.SELECT:
options = var.options or []
if user_input_value not in options:
raise ValueError(f'{var.name} in input form must be one of the following: {options}')
elif var.type in (VariableEntity.Type.TEXT_INPUT, VariableEntity.Type.PARAGRAPH):
raise ValueError(f'{var.variable} in input form must be one of the following: {options}')
elif var.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH):
if var.max_length and user_input_value and len(user_input_value) > var.max_length:
raise ValueError(f'{var.name} in input form must be less than {var.max_length} characters')
raise ValueError(f'{var.variable} in input form must be less than {var.max_length} characters')
return user_input_value

View File

@ -1,6 +1,6 @@
from typing import Optional
from core.app.app_config.entities import VariableEntity
from core.app.app_config.entities import VariableEntity, VariableEntityType
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import (
@ -18,6 +18,13 @@ from models.model import App, AppMode
from models.tools import WorkflowToolProvider
from models.workflow import Workflow
VARIABLE_TO_PARAMETER_TYPE_MAPPING = {
VariableEntityType.TEXT_INPUT: ToolParameter.ToolParameterType.STRING,
VariableEntityType.PARAGRAPH: ToolParameter.ToolParameterType.STRING,
VariableEntityType.SELECT: ToolParameter.ToolParameterType.SELECT,
VariableEntityType.NUMBER: ToolParameter.ToolParameterType.NUMBER,
}
class WorkflowToolProviderController(ToolProviderController):
provider_id: str
@ -93,23 +100,11 @@ class WorkflowToolProviderController(ToolProviderController):
if variable:
parameter_type = None
options = None
if variable.type in [
VariableEntity.Type.TEXT_INPUT,
VariableEntity.Type.PARAGRAPH,
]:
parameter_type = ToolParameter.ToolParameterType.STRING
elif variable.type in [
VariableEntity.Type.SELECT
]:
parameter_type = ToolParameter.ToolParameterType.SELECT
elif variable.type in [
VariableEntity.Type.NUMBER
]:
parameter_type = ToolParameter.ToolParameterType.NUMBER
else:
if variable.type not in VARIABLE_TO_PARAMETER_TYPE_MAPPING:
raise ValueError(f'unsupported variable type {variable.type}')
parameter_type = VARIABLE_TO_PARAMETER_TYPE_MAPPING[variable.type]
if variable.type == VariableEntity.Type.SELECT and variable.options:
if variable.type == VariableEntityType.SELECT and variable.options:
options = [
ToolParameterOption(
value=option,

View File

@ -1,3 +1,7 @@
from collections.abc import Sequence
from pydantic import Field
from core.app.app_config.entities import VariableEntity
from core.workflow.entities.base_node_data_entities import BaseNodeData
@ -6,4 +10,4 @@ class StartNodeData(BaseNodeData):
"""
Start Node Data
"""
variables: list[VariableEntity] = []
variables: Sequence[VariableEntity] = Field(default_factory=list)

View File

@ -14,6 +14,7 @@ from core.app.app_config.entities import (
ModelConfigEntity,
PromptTemplateEntity,
VariableEntity,
VariableEntityType,
)
from core.helper import encrypter
from core.model_runtime.entities.llm_entities import LLMMode
@ -25,23 +26,24 @@ from services.workflow.workflow_converter import WorkflowConverter
@pytest.fixture
def default_variables():
return [
value = [
VariableEntity(
variable="text_input",
label="text-input",
type=VariableEntity.Type.TEXT_INPUT
type=VariableEntityType.TEXT_INPUT,
),
VariableEntity(
variable="paragraph",
label="paragraph",
type=VariableEntity.Type.PARAGRAPH
type=VariableEntityType.PARAGRAPH,
),
VariableEntity(
variable="select",
label="select",
type=VariableEntity.Type.SELECT
)
type=VariableEntityType.SELECT,
),
]
return value
def test__convert_to_start_node(default_variables):