diff --git a/api/core/workflow/nodes/list_operator/exc.py b/api/core/workflow/nodes/list_operator/exc.py new file mode 100644 index 000000000..f88aa0be2 --- /dev/null +++ b/api/core/workflow/nodes/list_operator/exc.py @@ -0,0 +1,16 @@ +class ListOperatorError(ValueError): + """Base class for all ListOperator errors.""" + + pass + + +class InvalidFilterValueError(ListOperatorError): + pass + + +class InvalidKeyError(ListOperatorError): + pass + + +class InvalidConditionError(ListOperatorError): + pass diff --git a/api/core/workflow/nodes/list_operator/node.py b/api/core/workflow/nodes/list_operator/node.py index d7e4c6431..6053a15d9 100644 --- a/api/core/workflow/nodes/list_operator/node.py +++ b/api/core/workflow/nodes/list_operator/node.py @@ -1,5 +1,5 @@ from collections.abc import Callable, Sequence -from typing import Literal +from typing import Literal, Union from core.file import File from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment @@ -9,6 +9,7 @@ from core.workflow.nodes.enums import NodeType from models.workflow import WorkflowNodeExecutionStatus from .entities import ListOperatorNodeData +from .exc import InvalidConditionError, InvalidFilterValueError, InvalidKeyError, ListOperatorError class ListOperatorNode(BaseNode[ListOperatorNodeData]): @@ -26,7 +27,17 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]): return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs ) - if variable.value and not isinstance(variable, ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment): + if not variable.value: + inputs = {"variable": []} + process_data = {"variable": []} + outputs = {"result": [], "first_record": None, "last_record": None} + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=inputs, + process_data=process_data, + outputs=outputs, + ) + if not isinstance(variable, ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment): error_message = ( f"Variable {self.node_data.variable} is not an ArrayFileSegment, ArrayNumberSegment " "or ArrayStringSegment" @@ -36,70 +47,98 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]): ) if isinstance(variable, ArrayFileSegment): + inputs = {"variable": [item.to_dict() for item in variable.value]} process_data["variable"] = [item.to_dict() for item in variable.value] else: + inputs = {"variable": variable.value} process_data["variable"] = variable.value - # Filter - if self.node_data.filter_by.enabled: - for condition in self.node_data.filter_by.conditions: - if isinstance(variable, ArrayStringSegment): - if not isinstance(condition.value, str): - raise ValueError(f"Invalid filter value: {condition.value}") - value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text - filter_func = _get_string_filter_func(condition=condition.comparison_operator, value=value) - result = list(filter(filter_func, variable.value)) - variable = variable.model_copy(update={"value": result}) - elif isinstance(variable, ArrayNumberSegment): - if not isinstance(condition.value, str): - raise ValueError(f"Invalid filter value: {condition.value}") - value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text - filter_func = _get_number_filter_func(condition=condition.comparison_operator, value=float(value)) - result = list(filter(filter_func, variable.value)) - variable = variable.model_copy(update={"value": result}) - elif isinstance(variable, ArrayFileSegment): - if isinstance(condition.value, str): - value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text - else: - value = condition.value - filter_func = _get_file_filter_func( - key=condition.key, - condition=condition.comparison_operator, - value=value, - ) - result = list(filter(filter_func, variable.value)) - variable = variable.model_copy(update={"value": result}) + try: + # Filter + if self.node_data.filter_by.enabled: + variable = self._apply_filter(variable) - # Order - if self.node_data.order_by.enabled: + # Order + if self.node_data.order_by.enabled: + variable = self._apply_order(variable) + + # Slice + if self.node_data.limit.enabled: + variable = self._apply_slice(variable) + + outputs = { + "result": variable.value, + "first_record": variable.value[0] if variable.value else None, + "last_record": variable.value[-1] if variable.value else None, + } + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=inputs, + process_data=process_data, + outputs=outputs, + ) + except ListOperatorError as e: + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=str(e), + inputs=inputs, + process_data=process_data, + outputs=outputs, + ) + + def _apply_filter( + self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment] + ) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]: + for condition in self.node_data.filter_by.conditions: if isinstance(variable, ArrayStringSegment): - result = _order_string(order=self.node_data.order_by.value, array=variable.value) + if not isinstance(condition.value, str): + raise InvalidFilterValueError(f"Invalid filter value: {condition.value}") + value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text + filter_func = _get_string_filter_func(condition=condition.comparison_operator, value=value) + result = list(filter(filter_func, variable.value)) variable = variable.model_copy(update={"value": result}) elif isinstance(variable, ArrayNumberSegment): - result = _order_number(order=self.node_data.order_by.value, array=variable.value) + if not isinstance(condition.value, str): + raise InvalidFilterValueError(f"Invalid filter value: {condition.value}") + value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text + filter_func = _get_number_filter_func(condition=condition.comparison_operator, value=float(value)) + result = list(filter(filter_func, variable.value)) variable = variable.model_copy(update={"value": result}) elif isinstance(variable, ArrayFileSegment): - result = _order_file( - order=self.node_data.order_by.value, order_by=self.node_data.order_by.key, array=variable.value + if isinstance(condition.value, str): + value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text + else: + value = condition.value + filter_func = _get_file_filter_func( + key=condition.key, + condition=condition.comparison_operator, + value=value, ) + result = list(filter(filter_func, variable.value)) variable = variable.model_copy(update={"value": result}) + return variable - # Slice - if self.node_data.limit.enabled: - result = variable.value[: self.node_data.limit.size] + def _apply_order( + self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment] + ) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]: + if isinstance(variable, ArrayStringSegment): + result = _order_string(order=self.node_data.order_by.value, array=variable.value) variable = variable.model_copy(update={"value": result}) + elif isinstance(variable, ArrayNumberSegment): + result = _order_number(order=self.node_data.order_by.value, array=variable.value) + variable = variable.model_copy(update={"value": result}) + elif isinstance(variable, ArrayFileSegment): + result = _order_file( + order=self.node_data.order_by.value, order_by=self.node_data.order_by.key, array=variable.value + ) + variable = variable.model_copy(update={"value": result}) + return variable - outputs = { - "result": variable.value, - "first_record": variable.value[0] if variable.value else None, - "last_record": variable.value[-1] if variable.value else None, - } - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=inputs, - process_data=process_data, - outputs=outputs, - ) + def _apply_slice( + self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment] + ) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]: + result = variable.value[: self.node_data.limit.size] + return variable.model_copy(update={"value": result}) def _get_file_extract_number_func(*, key: str) -> Callable[[File], int]: @@ -107,7 +146,7 @@ def _get_file_extract_number_func(*, key: str) -> Callable[[File], int]: case "size": return lambda x: x.size case _: - raise ValueError(f"Invalid key: {key}") + raise InvalidKeyError(f"Invalid key: {key}") def _get_file_extract_string_func(*, key: str) -> Callable[[File], str]: @@ -125,7 +164,7 @@ def _get_file_extract_string_func(*, key: str) -> Callable[[File], str]: case "url": return lambda x: x.remote_url or "" case _: - raise ValueError(f"Invalid key: {key}") + raise InvalidKeyError(f"Invalid key: {key}") def _get_string_filter_func(*, condition: str, value: str) -> Callable[[str], bool]: @@ -151,7 +190,7 @@ def _get_string_filter_func(*, condition: str, value: str) -> Callable[[str], bo case "not empty": return lambda x: x != "" case _: - raise ValueError(f"Invalid condition: {condition}") + raise InvalidConditionError(f"Invalid condition: {condition}") def _get_sequence_filter_func(*, condition: str, value: Sequence[str]) -> Callable[[str], bool]: @@ -161,7 +200,7 @@ def _get_sequence_filter_func(*, condition: str, value: Sequence[str]) -> Callab case "not in": return lambda x: not _in(value)(x) case _: - raise ValueError(f"Invalid condition: {condition}") + raise InvalidConditionError(f"Invalid condition: {condition}") def _get_number_filter_func(*, condition: str, value: int | float) -> Callable[[int | float], bool]: @@ -179,7 +218,7 @@ def _get_number_filter_func(*, condition: str, value: int | float) -> Callable[[ case "≥": return _ge(value) case _: - raise ValueError(f"Invalid condition: {condition}") + raise InvalidConditionError(f"Invalid condition: {condition}") def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str]) -> Callable[[File], bool]: @@ -193,7 +232,7 @@ def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str extract_func = _get_file_extract_number_func(key=key) return lambda x: _get_number_filter_func(condition=condition, value=float(value))(extract_func(x)) else: - raise ValueError(f"Invalid key: {key}") + raise InvalidKeyError(f"Invalid key: {key}") def _contains(value: str):