From 72ea3d6b98157bc6e360ba5b7c1a8c4137d78243 Mon Sep 17 00:00:00 2001
From: -LAN-
Date: Fri, 25 Oct 2024 22:33:34 +0800
Subject: [PATCH] fix(workflow): Take back LLM streaming output after IF-ELSE (#9875)

---
 api/core/workflow/graph_engine/graph_engine.py       | 13 ++++++-------
 .../nodes/answer/answer_stream_generate_router.py    |  8 ++++----
 .../nodes/answer/answer_stream_processor.py          |  2 +-
 .../workflow/nodes/answer/base_stream_processor.py   |  1 -
 api/core/workflow/nodes/answer/entities.py           |  3 ++-
 5 files changed, 13 insertions(+), 14 deletions(-)

diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py
index ada0b14ce..8f58af00e 100644
--- a/api/core/workflow/graph_engine/graph_engine.py
+++ b/api/core/workflow/graph_engine/graph_engine.py
@@ -130,15 +130,14 @@ class GraphEngine:
         yield GraphRunStartedEvent()
 
         try:
-            stream_processor_cls: type[AnswerStreamProcessor | EndStreamProcessor]
             if self.init_params.workflow_type == WorkflowType.CHAT:
-                stream_processor_cls = AnswerStreamProcessor
+                stream_processor = AnswerStreamProcessor(
+                    graph=self.graph, variable_pool=self.graph_runtime_state.variable_pool
+                )
             else:
-                stream_processor_cls = EndStreamProcessor
-
-            stream_processor = stream_processor_cls(
-                graph=self.graph, variable_pool=self.graph_runtime_state.variable_pool
-            )
+                stream_processor = EndStreamProcessor(
+                    graph=self.graph, variable_pool=self.graph_runtime_state.variable_pool
+                )
 
             # run graph
             generator = stream_processor.process(self._run(start_node_id=self.graph.root_node_id))
diff --git a/api/core/workflow/nodes/answer/answer_stream_generate_router.py b/api/core/workflow/nodes/answer/answer_stream_generate_router.py
index bce28c5fc..bc4b05614 100644
--- a/api/core/workflow/nodes/answer/answer_stream_generate_router.py
+++ b/api/core/workflow/nodes/answer/answer_stream_generate_router.py
@@ -149,10 +149,10 @@ class AnswerStreamGeneratorRouter:
             source_node_id = edge.source_node_id
             source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type")
             if source_node_type in {
-                NodeType.ANSWER.value,
-                NodeType.IF_ELSE.value,
-                NodeType.QUESTION_CLASSIFIER.value,
-                NodeType.ITERATION.value,
+                NodeType.ANSWER,
+                NodeType.IF_ELSE,
+                NodeType.QUESTION_CLASSIFIER,
+                NodeType.ITERATION,
             }:
                 answer_dependencies[answer_node_id].append(source_node_id)
             else:
diff --git a/api/core/workflow/nodes/answer/answer_stream_processor.py b/api/core/workflow/nodes/answer/answer_stream_processor.py
index e3889941c..8a768088d 100644
--- a/api/core/workflow/nodes/answer/answer_stream_processor.py
+++ b/api/core/workflow/nodes/answer/answer_stream_processor.py
@@ -22,7 +22,7 @@ class AnswerStreamProcessor(StreamProcessor):
         super().__init__(graph, variable_pool)
         self.generate_routes = graph.answer_stream_generate_routes
         self.route_position = {}
-        for answer_node_id, route_chunks in self.generate_routes.answer_generate_route.items():
+        for answer_node_id in self.generate_routes.answer_generate_route:
             self.route_position[answer_node_id] = 0
         self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {}
 
diff --git a/api/core/workflow/nodes/answer/base_stream_processor.py b/api/core/workflow/nodes/answer/base_stream_processor.py
index 52d0358c7..36c3fe180 100644
--- a/api/core/workflow/nodes/answer/base_stream_processor.py
+++ b/api/core/workflow/nodes/answer/base_stream_processor.py
@@ -41,7 +41,6 @@ class StreamProcessor(ABC):
                     continue
                 else:
                     unreachable_first_node_ids.append(edge.target_node_id)
-                    unreachable_first_node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
 
             for node_id in unreachable_first_node_ids:
                 self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids)
diff --git a/api/core/workflow/nodes/answer/entities.py b/api/core/workflow/nodes/answer/entities.py
index e543d02dd..a05cc44c9 100644
--- a/api/core/workflow/nodes/answer/entities.py
+++ b/api/core/workflow/nodes/answer/entities.py
@@ -1,3 +1,4 @@
+from collections.abc import Sequence
 from enum import Enum
 
 from pydantic import BaseModel, Field
@@ -32,7 +33,7 @@ class VarGenerateRouteChunk(GenerateRouteChunk):
 
     type: GenerateRouteChunk.ChunkType = GenerateRouteChunk.ChunkType.VAR
     """generate route chunk type"""
-    value_selector: list[str] = Field(..., description="value selector")
+    value_selector: Sequence[str] = Field(..., description="value selector")
 
 
 class TextGenerateRouteChunk(GenerateRouteChunk):
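Reviewer note: the behavioural core of this patch is the single deletion in base_stream_processor.py. The pruning helper removes each seeded node unconditionally and only checks reachability when it recurses, so pre-seeding every node of the skipped branch also removed shared nodes behind the merge point, which is why answer streaming stopped after an IF-ELSE. The standalone Python sketch below is a toy model, not the Dify API (the node names "if_else", "llm", "fallback", "answer" and both helpers are made up for illustration); it only reproduces that difference on a minimal graph.

# Toy graph: if_else branches to llm (true) or fallback (false); both converge on answer.
edges = {
    "if_else": [("llm", "true"), ("fallback", "false")],
    "llm": [("answer", None)],
    "fallback": [("answer", None)],
    "answer": [],
}


def downstream(node_id: str) -> list[str]:
    # every node reachable from node_id (the role _fetch_node_ids_in_reachable_branch plays)
    found: list[str] = []
    for target, _ in edges[node_id]:
        found.append(target)
        found.extend(downstream(target))
    return found


def prune(rest_node_ids: set[str], seed_whole_branch: bool) -> set[str]:
    # simulate pruning after the "true" branch of if_else is taken
    reachable: list[str] = []
    unreachable_first: list[str] = []
    for target, branch in edges["if_else"]:
        if branch == "true":
            reachable.extend([target, *downstream(target)])
        else:
            unreachable_first.append(target)
            if seed_whole_branch:
                # pre-fix behaviour: the extend() call this patch deletes
                unreachable_first.extend(downstream(target))

    def remove_branch(node_id: str) -> None:
        rest_node_ids.discard(node_id)  # removed without a reachability check, like the real helper
        for target, _ in edges[node_id]:
            if target in reachable:  # recursion stops at the merge point
                continue
            remove_branch(target)

    for node_id in unreachable_first:
        remove_branch(node_id)
    return rest_node_ids


print(prune({"llm", "fallback", "answer"}, seed_whole_branch=True))   # {'llm'}: "answer" was pruned, streaming stops
print(prune({"llm", "fallback", "answer"}, seed_whole_branch=False))  # {'llm', 'answer'}: "answer" keeps streaming

With the whole branch seeded, the shared "answer" node is dropped even though the taken branch still needs it; seeding only the first node lets the reachability check stop the removal at the merge, which matches the streaming behaviour this patch restores.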