fix: generation does not stop when pressing stop link (#1961)

This commit is contained in:
takatost 2024-01-06 03:03:56 +08:00 committed by GitHub
parent a8cedea15a
commit 0c746f5c5a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 68 additions and 39 deletions

View File

@@ -1,7 +1,7 @@
import time import time
from typing import cast, Optional, List, Tuple, Generator, Union from typing import cast, Optional, List, Tuple, Generator, Union
from core.application_queue_manager import ApplicationQueueManager from core.application_queue_manager import ApplicationQueueManager, PublishFrom
from core.entities.application_entities import ModelConfigEntity, PromptTemplateEntity, AppOrchestrationConfigEntity from core.entities.application_entities import ModelConfigEntity, PromptTemplateEntity, AppOrchestrationConfigEntity
from core.file.file_obj import FileObj from core.file.file_obj import FileObj
from core.memory.token_buffer_memory import TokenBufferMemory from core.memory.token_buffer_memory import TokenBufferMemory
@@ -183,7 +183,7 @@ class AppRunner:
index=index, index=index,
message=AssistantPromptMessage(content=token) message=AssistantPromptMessage(content=token)
) )
)) ), PublishFrom.APPLICATION_MANAGER)
index += 1 index += 1
time.sleep(0.01) time.sleep(0.01)
@@ -193,7 +193,8 @@ class AppRunner:
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
message=AssistantPromptMessage(content=text), message=AssistantPromptMessage(content=text),
usage=usage if usage else LLMUsage.empty_usage() usage=usage if usage else LLMUsage.empty_usage()
) ),
pub_from=PublishFrom.APPLICATION_MANAGER
) )
def _handle_invoke_result(self, invoke_result: Union[LLMResult, Generator], def _handle_invoke_result(self, invoke_result: Union[LLMResult, Generator],
@@ -226,7 +227,8 @@ class AppRunner:
:return: :return:
""" """
queue_manager.publish_message_end( queue_manager.publish_message_end(
llm_result=invoke_result llm_result=invoke_result,
pub_from=PublishFrom.APPLICATION_MANAGER
) )
def _handle_invoke_result_stream(self, invoke_result: Generator, def _handle_invoke_result_stream(self, invoke_result: Generator,
@@ -242,7 +244,7 @@ class AppRunner:
text = '' text = ''
usage = None usage = None
for result in invoke_result: for result in invoke_result:
queue_manager.publish_chunk_message(result) queue_manager.publish_chunk_message(result, PublishFrom.APPLICATION_MANAGER)
text += result.delta.message.content text += result.delta.message.content
@@ -263,5 +265,6 @@ class AppRunner:
) )
queue_manager.publish_message_end( queue_manager.publish_message_end(
llm_result=llm_result llm_result=llm_result,
pub_from=PublishFrom.APPLICATION_MANAGER
) )

View File

@@ -5,7 +5,7 @@ from core.app_runner.app_runner import AppRunner
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.application_entities import ApplicationGenerateEntity, ModelConfigEntity, \ from core.entities.application_entities import ApplicationGenerateEntity, ModelConfigEntity, \
AppOrchestrationConfigEntity, InvokeFrom, ExternalDataVariableEntity, DatasetEntity AppOrchestrationConfigEntity, InvokeFrom, ExternalDataVariableEntity, DatasetEntity
from core.application_queue_manager import ApplicationQueueManager from core.application_queue_manager import ApplicationQueueManager, PublishFrom
from core.features.annotation_reply import AnnotationReplyFeature from core.features.annotation_reply import AnnotationReplyFeature
from core.features.dataset_retrieval import DatasetRetrievalFeature from core.features.dataset_retrieval import DatasetRetrievalFeature
from core.features.external_data_fetch import ExternalDataFetchFeature from core.features.external_data_fetch import ExternalDataFetchFeature
@@ -121,7 +121,8 @@ class BasicApplicationRunner(AppRunner):
if annotation_reply: if annotation_reply:
queue_manager.publish_annotation_reply( queue_manager.publish_annotation_reply(
message_annotation_id=annotation_reply.id message_annotation_id=annotation_reply.id,
pub_from=PublishFrom.APPLICATION_MANAGER
) )
self.direct_output( self.direct_output(
queue_manager=queue_manager, queue_manager=queue_manager,

View File

@@ -7,7 +7,7 @@ from pydantic import BaseModel
from core.app_runner.moderation_handler import OutputModerationHandler, ModerationRule from core.app_runner.moderation_handler import OutputModerationHandler, ModerationRule
from core.entities.application_entities import ApplicationGenerateEntity from core.entities.application_entities import ApplicationGenerateEntity
from core.application_queue_manager import ApplicationQueueManager from core.application_queue_manager import ApplicationQueueManager, PublishFrom
from core.entities.queue_entities import QueueErrorEvent, QueueStopEvent, QueueMessageEndEvent, \ from core.entities.queue_entities import QueueErrorEvent, QueueStopEvent, QueueMessageEndEvent, \
QueueRetrieverResourcesEvent, QueueAgentThoughtEvent, QueuePingEvent, QueueMessageEvent, QueueMessageReplaceEvent, \ QueueRetrieverResourcesEvent, QueueAgentThoughtEvent, QueuePingEvent, QueueMessageEvent, QueueMessageReplaceEvent, \
AnnotationReplyEvent AnnotationReplyEvent
@@ -312,8 +312,11 @@ class GenerateTaskPipeline:
index=0, index=0,
message=AssistantPromptMessage(content=self._task_state.llm_result.message.content) message=AssistantPromptMessage(content=self._task_state.llm_result.message.content)
) )
)) ), PublishFrom.TASK_PIPELINE)
self._queue_manager.publish(QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION)) self._queue_manager.publish(
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION),
PublishFrom.TASK_PIPELINE
)
continue continue
else: else:
self._output_moderation_handler.append_new_token(delta_text) self._output_moderation_handler.append_new_token(delta_text)

View File

@@ -6,6 +6,7 @@ from typing import Any, Optional, Dict
from flask import current_app, Flask from flask import current_app, Flask
from pydantic import BaseModel from pydantic import BaseModel
from core.application_queue_manager import PublishFrom
from core.moderation.base import ModerationAction, ModerationOutputsResult from core.moderation.base import ModerationAction, ModerationOutputsResult
from core.moderation.factory import ModerationFactory from core.moderation.factory import ModerationFactory
@@ -66,7 +67,7 @@ class OutputModerationHandler(BaseModel):
final_output = result.text final_output = result.text
if public_event: if public_event:
self.on_message_replace_func(final_output) self.on_message_replace_func(final_output, PublishFrom.TASK_PIPELINE)
return final_output return final_output

View File

@@ -23,7 +23,7 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeErr
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.prompt_template import PromptTemplateParser from core.prompt.prompt_template import PromptTemplateParser
from core.provider_manager import ProviderManager from core.provider_manager import ProviderManager
from core.application_queue_manager import ApplicationQueueManager, ConversationTaskStoppedException from core.application_queue_manager import ApplicationQueueManager, ConversationTaskStoppedException, PublishFrom
from extensions.ext_database import db from extensions.ext_database import db
from models.account import Account from models.account import Account
from models.model import EndUser, Conversation, Message, MessageFile, App from models.model import EndUser, Conversation, Message, MessageFile, App
@@ -169,15 +169,18 @@ class ApplicationManager:
except ConversationTaskStoppedException: except ConversationTaskStoppedException:
pass pass
except InvokeAuthorizationError: except InvokeAuthorizationError:
queue_manager.publish_error(InvokeAuthorizationError('Incorrect API key provided')) queue_manager.publish_error(
InvokeAuthorizationError('Incorrect API key provided'),
PublishFrom.APPLICATION_MANAGER
)
except ValidationError as e: except ValidationError as e:
logger.exception("Validation Error when generating") logger.exception("Validation Error when generating")
queue_manager.publish_error(e) queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except (ValueError, InvokeError) as e: except (ValueError, InvokeError) as e:
queue_manager.publish_error(e) queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
except Exception as e: except Exception as e:
logger.exception("Unknown Error when generating") logger.exception("Unknown Error when generating")
queue_manager.publish_error(e) queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
finally: finally:
db.session.remove() db.session.remove()

View File

@@ -1,5 +1,6 @@
import queue import queue
import time import time
from enum import Enum
from typing import Generator, Any from typing import Generator, Any
from sqlalchemy.orm import DeclarativeMeta from sqlalchemy.orm import DeclarativeMeta
@@ -13,6 +14,11 @@ from extensions.ext_redis import redis_client
from models.model import MessageAgentThought from models.model import MessageAgentThought
class PublishFrom(Enum):
APPLICATION_MANAGER = 1
TASK_PIPELINE = 2
class ApplicationQueueManager: class ApplicationQueueManager:
def __init__(self, task_id: str, def __init__(self, task_id: str,
user_id: str, user_id: str,
@@ -61,11 +67,14 @@ class ApplicationQueueManager:
if elapsed_time >= listen_timeout or self._is_stopped(): if elapsed_time >= listen_timeout or self._is_stopped():
# publish two messages to make sure the client can receive the stop signal # publish two messages to make sure the client can receive the stop signal
# and stop listening after the stop signal processed # and stop listening after the stop signal processed
self.publish(QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL)) self.publish(
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL),
PublishFrom.TASK_PIPELINE
)
self.stop_listen() self.stop_listen()
if elapsed_time // 10 > last_ping_time: if elapsed_time // 10 > last_ping_time:
self.publish(QueuePingEvent()) self.publish(QueuePingEvent(), PublishFrom.TASK_PIPELINE)
last_ping_time = elapsed_time // 10 last_ping_time = elapsed_time // 10
def stop_listen(self) -> None: def stop_listen(self) -> None:
@@ -75,76 +84,83 @@ class ApplicationQueueManager:
""" """
self._q.put(None) self._q.put(None)
def publish_chunk_message(self, chunk: LLMResultChunk) -> None: def publish_chunk_message(self, chunk: LLMResultChunk, pub_from: PublishFrom) -> None:
""" """
Publish chunk message to channel Publish chunk message to channel
:param chunk: chunk :param chunk: chunk
:param pub_from: publish from
:return: :return:
""" """
self.publish(QueueMessageEvent( self.publish(QueueMessageEvent(
chunk=chunk chunk=chunk
)) ), pub_from)
def publish_message_replace(self, text: str) -> None: def publish_message_replace(self, text: str, pub_from: PublishFrom) -> None:
""" """
Publish message replace Publish message replace
:param text: text :param text: text
:param pub_from: publish from
:return: :return:
""" """
self.publish(QueueMessageReplaceEvent( self.publish(QueueMessageReplaceEvent(
text=text text=text
)) ), pub_from)
def publish_retriever_resources(self, retriever_resources: list[dict]) -> None: def publish_retriever_resources(self, retriever_resources: list[dict], pub_from: PublishFrom) -> None:
""" """
Publish retriever resources Publish retriever resources
:return: :return:
""" """
self.publish(QueueRetrieverResourcesEvent(retriever_resources=retriever_resources)) self.publish(QueueRetrieverResourcesEvent(retriever_resources=retriever_resources), pub_from)
def publish_annotation_reply(self, message_annotation_id: str) -> None: def publish_annotation_reply(self, message_annotation_id: str, pub_from: PublishFrom) -> None:
""" """
Publish annotation reply Publish annotation reply
:param message_annotation_id: message annotation id :param message_annotation_id: message annotation id
:param pub_from: publish from
:return: :return:
""" """
self.publish(AnnotationReplyEvent(message_annotation_id=message_annotation_id)) self.publish(AnnotationReplyEvent(message_annotation_id=message_annotation_id), pub_from)
def publish_message_end(self, llm_result: LLMResult) -> None: def publish_message_end(self, llm_result: LLMResult, pub_from: PublishFrom) -> None:
""" """
Publish message end Publish message end
:param llm_result: llm result :param llm_result: llm result
:param pub_from: publish from
:return: :return:
""" """
self.publish(QueueMessageEndEvent(llm_result=llm_result)) self.publish(QueueMessageEndEvent(llm_result=llm_result), pub_from)
self.stop_listen() self.stop_listen()
def publish_agent_thought(self, message_agent_thought: MessageAgentThought) -> None: def publish_agent_thought(self, message_agent_thought: MessageAgentThought, pub_from: PublishFrom) -> None:
""" """
Publish agent thought Publish agent thought
:param message_agent_thought: message agent thought :param message_agent_thought: message agent thought
:param pub_from: publish from
:return: :return:
""" """
self.publish(QueueAgentThoughtEvent( self.publish(QueueAgentThoughtEvent(
agent_thought_id=message_agent_thought.id agent_thought_id=message_agent_thought.id
)) ), pub_from)
def publish_error(self, e) -> None: def publish_error(self, e, pub_from: PublishFrom) -> None:
""" """
Publish error Publish error
:param e: error :param e: error
:param pub_from: publish from
:return: :return:
""" """
self.publish(QueueErrorEvent( self.publish(QueueErrorEvent(
error=e error=e
)) ), pub_from)
self.stop_listen() self.stop_listen()
def publish(self, event: AppQueueEvent) -> None: def publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
""" """
Publish event to queue Publish event to queue
:param event: :param event:
:param pub_from:
:return: :return:
""" """
self._check_for_sqlalchemy_models(event.dict()) self._check_for_sqlalchemy_models(event.dict())
@@ -162,6 +178,9 @@ class ApplicationQueueManager:
if isinstance(event, QueueStopEvent): if isinstance(event, QueueStopEvent):
self.stop_listen() self.stop_listen()
if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
raise ConversationTaskStoppedException()
@classmethod @classmethod
def set_stop_flag(cls, task_id: str, invoke_from: InvokeFrom, user_id: str) -> None: def set_stop_flag(cls, task_id: str, invoke_from: InvokeFrom, user_id: str) -> None:
""" """
@@ -187,7 +206,6 @@ class ApplicationQueueManager:
stopped_cache_key = ApplicationQueueManager._generate_stopped_cache_key(self._task_id) stopped_cache_key = ApplicationQueueManager._generate_stopped_cache_key(self._task_id)
result = redis_client.get(stopped_cache_key) result = redis_client.get(stopped_cache_key)
if result is not None: if result is not None:
redis_client.delete(stopped_cache_key)
return True return True
return False return False

View File

@@ -8,7 +8,7 @@ from langchain.agents import openai_functions_agent, openai_functions_multi_agen
from langchain.callbacks.base import BaseCallbackHandler from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration, BaseMessage from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration, BaseMessage
from core.application_queue_manager import ApplicationQueueManager from core.application_queue_manager import ApplicationQueueManager, PublishFrom
from core.callback_handler.entity.agent_loop import AgentLoop from core.callback_handler.entity.agent_loop import AgentLoop
from core.entities.application_entities import ModelConfigEntity from core.entities.application_entities import ModelConfigEntity
from core.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResult from core.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResult
@@ -232,7 +232,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
db.session.add(message_agent_thought) db.session.add(message_agent_thought)
db.session.commit() db.session.commit()
self.queue_manager.publish_agent_thought(message_agent_thought) self.queue_manager.publish_agent_thought(message_agent_thought, PublishFrom.APPLICATION_MANAGER)
return message_agent_thought return message_agent_thought

View File

@@ -2,7 +2,7 @@ from typing import List, Union
from langchain.schema import Document from langchain.schema import Document
from core.application_queue_manager import ApplicationQueueManager from core.application_queue_manager import ApplicationQueueManager, PublishFrom
from core.entities.application_entities import InvokeFrom from core.entities.application_entities import InvokeFrom
from extensions.ext_database import db from extensions.ext_database import db
from models.dataset import DocumentSegment, DatasetQuery from models.dataset import DocumentSegment, DatasetQuery
@@ -80,4 +80,4 @@ class DatasetIndexToolCallbackHandler:
db.session.add(dataset_retriever_resource) db.session.add(dataset_retriever_resource)
db.session.commit() db.session.commit()
self._queue_manager.publish_retriever_resources(resource) self._queue_manager.publish_retriever_resources(resource, PublishFrom.APPLICATION_MANAGER)