Mirror of https://gitee.com/dify_ai/dify.git (synced 2024-12-02 11:18:19 +08:00)
fix: generate not stop when pressing stop link (#1961)

commit 0c746f5c5a (parent a8cedea15a)
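The hunks below (file names were not preserved by the page, but the hunk contexts name the classes) thread a new PublishFrom enum with APPLICATION_MANAGER and TASK_PIPELINE members through every ApplicationQueueManager.publish_* call, so the queue knows which side produced each event. publish() now raises ConversationTaskStoppedException whenever the application side publishes after the stop flag has been set, and _is_stopped() no longer deletes the Redis stop flag on first read, so repeated checks keep seeing it. Together these changes make a pending stop request actually terminate generation, which is the behaviour the commit title describes. A consolidated sketch of the mechanism follows the ApplicationQueueManager hunks further down.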
@@ -1,7 +1,7 @@
 import time
 from typing import cast, Optional, List, Tuple, Generator, Union

-from core.application_queue_manager import ApplicationQueueManager
+from core.application_queue_manager import ApplicationQueueManager, PublishFrom
 from core.entities.application_entities import ModelConfigEntity, PromptTemplateEntity, AppOrchestrationConfigEntity
 from core.file.file_obj import FileObj
 from core.memory.token_buffer_memory import TokenBufferMemory
@@ -183,7 +183,7 @@ class AppRunner:
                         index=index,
                         message=AssistantPromptMessage(content=token)
                     )
-                ))
+                ), PublishFrom.APPLICATION_MANAGER)
                 index += 1
                 time.sleep(0.01)

@@ -193,7 +193,8 @@ class AppRunner:
                 prompt_messages=prompt_messages,
                 message=AssistantPromptMessage(content=text),
                 usage=usage if usage else LLMUsage.empty_usage()
-            )
+            ),
+            pub_from=PublishFrom.APPLICATION_MANAGER
         )

     def _handle_invoke_result(self, invoke_result: Union[LLMResult, Generator],
@@ -226,7 +227,8 @@ class AppRunner:
         :return:
         """
         queue_manager.publish_message_end(
-            llm_result=invoke_result
+            llm_result=invoke_result,
+            pub_from=PublishFrom.APPLICATION_MANAGER
         )

     def _handle_invoke_result_stream(self, invoke_result: Generator,
@@ -242,7 +244,7 @@ class AppRunner:
         text = ''
         usage = None
         for result in invoke_result:
-            queue_manager.publish_chunk_message(result)
+            queue_manager.publish_chunk_message(result, PublishFrom.APPLICATION_MANAGER)

            text += result.delta.message.content

@@ -263,5 +265,6 @@ class AppRunner:
         )

         queue_manager.publish_message_end(
-            llm_result=llm_result
+            llm_result=llm_result,
+            pub_from=PublishFrom.APPLICATION_MANAGER
         )
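In the AppRunner hunks above, every publication — the per-token loop that sleeps 0.01s between chunks, the streamed chunks in _handle_invoke_result_stream, and the message-end events — is now tagged with PublishFrom.APPLICATION_MANAGER, so the first publish after a stop request aborts the runner instead of letting the loop keep emitting tokens (see the sketch after the ApplicationQueueManager hunks).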
@@ -5,7 +5,7 @@ from core.app_runner.app_runner import AppRunner
 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
 from core.entities.application_entities import ApplicationGenerateEntity, ModelConfigEntity, \
     AppOrchestrationConfigEntity, InvokeFrom, ExternalDataVariableEntity, DatasetEntity
-from core.application_queue_manager import ApplicationQueueManager
+from core.application_queue_manager import ApplicationQueueManager, PublishFrom
 from core.features.annotation_reply import AnnotationReplyFeature
 from core.features.dataset_retrieval import DatasetRetrievalFeature
 from core.features.external_data_fetch import ExternalDataFetchFeature
@@ -121,7 +121,8 @@ class BasicApplicationRunner(AppRunner):

             if annotation_reply:
                 queue_manager.publish_annotation_reply(
-                    message_annotation_id=annotation_reply.id
+                    message_annotation_id=annotation_reply.id,
+                    pub_from=PublishFrom.APPLICATION_MANAGER
                 )
                 self.direct_output(
                     queue_manager=queue_manager,
@@ -7,7 +7,7 @@ from pydantic import BaseModel

 from core.app_runner.moderation_handler import OutputModerationHandler, ModerationRule
 from core.entities.application_entities import ApplicationGenerateEntity
-from core.application_queue_manager import ApplicationQueueManager
+from core.application_queue_manager import ApplicationQueueManager, PublishFrom
 from core.entities.queue_entities import QueueErrorEvent, QueueStopEvent, QueueMessageEndEvent, \
     QueueRetrieverResourcesEvent, QueueAgentThoughtEvent, QueuePingEvent, QueueMessageEvent, QueueMessageReplaceEvent, \
     AnnotationReplyEvent
@@ -312,8 +312,11 @@ class GenerateTaskPipeline:
                         index=0,
                         message=AssistantPromptMessage(content=self._task_state.llm_result.message.content)
                     )
-                ))
-                self._queue_manager.publish(QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION))
+                ), PublishFrom.TASK_PIPELINE)
+                self._queue_manager.publish(
+                    QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION),
+                    PublishFrom.TASK_PIPELINE
+                )
                 continue
             else:
                 self._output_moderation_handler.append_new_token(delta_text)
@@ -6,6 +6,7 @@ from typing import Any, Optional, Dict
 from flask import current_app, Flask
 from pydantic import BaseModel

+from core.application_queue_manager import PublishFrom
 from core.moderation.base import ModerationAction, ModerationOutputsResult
 from core.moderation.factory import ModerationFactory

@@ -66,7 +67,7 @@ class OutputModerationHandler(BaseModel):
                 final_output = result.text

         if public_event:
-            self.on_message_replace_func(final_output)
+            self.on_message_replace_func(final_output, PublishFrom.TASK_PIPELINE)

         return final_output

@@ -23,7 +23,7 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeErr
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.prompt.prompt_template import PromptTemplateParser
 from core.provider_manager import ProviderManager
-from core.application_queue_manager import ApplicationQueueManager, ConversationTaskStoppedException
+from core.application_queue_manager import ApplicationQueueManager, ConversationTaskStoppedException, PublishFrom
 from extensions.ext_database import db
 from models.account import Account
 from models.model import EndUser, Conversation, Message, MessageFile, App
@@ -169,15 +169,18 @@ class ApplicationManager:
         except ConversationTaskStoppedException:
             pass
         except InvokeAuthorizationError:
-            queue_manager.publish_error(InvokeAuthorizationError('Incorrect API key provided'))
+            queue_manager.publish_error(
+                InvokeAuthorizationError('Incorrect API key provided'),
+                PublishFrom.APPLICATION_MANAGER
+            )
         except ValidationError as e:
             logger.exception("Validation Error when generating")
-            queue_manager.publish_error(e)
+            queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
         except (ValueError, InvokeError) as e:
-            queue_manager.publish_error(e)
+            queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
         except Exception as e:
             logger.exception("Unknown Error when generating")
-            queue_manager.publish_error(e)
+            queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
         finally:
             db.session.remove()

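The ApplicationManager hunks above show the producer side of the contract: ConversationTaskStoppedException is swallowed as the expected shutdown path (the except ConversationTaskStoppedException: pass branch), while real failures are still forwarded to the client through publish_error, now tagged with PublishFrom.APPLICATION_MANAGER.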
@@ -1,5 +1,6 @@
 import queue
 import time
+from enum import Enum
 from typing import Generator, Any

 from sqlalchemy.orm import DeclarativeMeta
@@ -13,6 +14,11 @@ from extensions.ext_redis import redis_client
 from models.model import MessageAgentThought


+class PublishFrom(Enum):
+    APPLICATION_MANAGER = 1
+    TASK_PIPELINE = 2
+
+
 class ApplicationQueueManager:
     def __init__(self, task_id: str,
                  user_id: str,
@@ -61,11 +67,14 @@ class ApplicationQueueManager:
             if elapsed_time >= listen_timeout or self._is_stopped():
                 # publish two messages to make sure the client can receive the stop signal
                 # and stop listening after the stop signal processed
-                self.publish(QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL))
+                self.publish(
+                    QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL),
+                    PublishFrom.TASK_PIPELINE
+                )
                 self.stop_listen()

             if elapsed_time // 10 > last_ping_time:
-                self.publish(QueuePingEvent())
+                self.publish(QueuePingEvent(), PublishFrom.TASK_PIPELINE)
                 last_ping_time = elapsed_time // 10

     def stop_listen(self) -> None:
@@ -75,76 +84,83 @@ class ApplicationQueueManager:
         """
         self._q.put(None)

-    def publish_chunk_message(self, chunk: LLMResultChunk) -> None:
+    def publish_chunk_message(self, chunk: LLMResultChunk, pub_from: PublishFrom) -> None:
         """
         Publish chunk message to channel

         :param chunk: chunk
+        :param pub_from: publish from
         :return:
         """
         self.publish(QueueMessageEvent(
             chunk=chunk
-        ))
+        ), pub_from)

-    def publish_message_replace(self, text: str) -> None:
+    def publish_message_replace(self, text: str, pub_from: PublishFrom) -> None:
         """
         Publish message replace
         :param text: text
+        :param pub_from: publish from
         :return:
         """
         self.publish(QueueMessageReplaceEvent(
             text=text
-        ))
+        ), pub_from)

-    def publish_retriever_resources(self, retriever_resources: list[dict]) -> None:
+    def publish_retriever_resources(self, retriever_resources: list[dict], pub_from: PublishFrom) -> None:
         """
         Publish retriever resources
         :return:
         """
-        self.publish(QueueRetrieverResourcesEvent(retriever_resources=retriever_resources))
+        self.publish(QueueRetrieverResourcesEvent(retriever_resources=retriever_resources), pub_from)

-    def publish_annotation_reply(self, message_annotation_id: str) -> None:
+    def publish_annotation_reply(self, message_annotation_id: str, pub_from: PublishFrom) -> None:
         """
         Publish annotation reply
         :param message_annotation_id: message annotation id
+        :param pub_from: publish from
         :return:
         """
-        self.publish(AnnotationReplyEvent(message_annotation_id=message_annotation_id))
+        self.publish(AnnotationReplyEvent(message_annotation_id=message_annotation_id), pub_from)

-    def publish_message_end(self, llm_result: LLMResult) -> None:
+    def publish_message_end(self, llm_result: LLMResult, pub_from: PublishFrom) -> None:
         """
         Publish message end
         :param llm_result: llm result
+        :param pub_from: publish from
         :return:
         """
-        self.publish(QueueMessageEndEvent(llm_result=llm_result))
+        self.publish(QueueMessageEndEvent(llm_result=llm_result), pub_from)
         self.stop_listen()

-    def publish_agent_thought(self, message_agent_thought: MessageAgentThought) -> None:
+    def publish_agent_thought(self, message_agent_thought: MessageAgentThought, pub_from: PublishFrom) -> None:
         """
         Publish agent thought
         :param message_agent_thought: message agent thought
+        :param pub_from: publish from
         :return:
         """
         self.publish(QueueAgentThoughtEvent(
             agent_thought_id=message_agent_thought.id
-        ))
+        ), pub_from)

-    def publish_error(self, e) -> None:
+    def publish_error(self, e, pub_from: PublishFrom) -> None:
         """
         Publish error
         :param e: error
+        :param pub_from: publish from
         :return:
         """
         self.publish(QueueErrorEvent(
             error=e
-        ))
+        ), pub_from)
         self.stop_listen()

-    def publish(self, event: AppQueueEvent) -> None:
+    def publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
         """
         Publish event to queue
         :param event:
+        :param pub_from:
         :return:
         """
         self._check_for_sqlalchemy_models(event.dict())
@@ -162,6 +178,9 @@ class ApplicationQueueManager:
         if isinstance(event, QueueStopEvent):
             self.stop_listen()

+        if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
+            raise ConversationTaskStoppedException()
+
     @classmethod
     def set_stop_flag(cls, task_id: str, invoke_from: InvokeFrom, user_id: str) -> None:
         """
@@ -187,7 +206,6 @@ class ApplicationQueueManager:
         stopped_cache_key = ApplicationQueueManager._generate_stopped_cache_key(self._task_id)
         result = redis_client.get(stopped_cache_key)
         if result is not None:
-            redis_client.delete(stopped_cache_key)
             return True

         return False
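The ApplicationQueueManager hunks above carry the core of the fix: the PublishFrom enum distinguishes the application manager (the producer of model output) from the task pipeline (the consumer), publish() raises ConversationTaskStoppedException when the producer publishes after a stop was requested, and _is_stopped() no longer deletes the stop flag on read, so every later check still sees it. The following is a minimal, self-contained sketch of that mechanism, not dify's actual implementation: ToyQueueManager and the in-memory stop-flag set are stand-ins for ApplicationQueueManager and its Redis-backed, task-scoped flag, while the enum members, the exception name, and the publish-time check mirror the diff.

from enum import Enum
from queue import Queue


class ConversationTaskStoppedException(Exception):
    """Raised when the producer keeps publishing after a stop was requested."""


class PublishFrom(Enum):
    APPLICATION_MANAGER = 1
    TASK_PIPELINE = 2


# An in-memory set stands in for the Redis stop flags keyed by task id; flags
# are only added, never deleted on read, matching the change to _is_stopped().
_stop_flags = set()


class ToyQueueManager:
    def __init__(self, task_id: str) -> None:
        self._task_id = task_id
        self._q = Queue()

    @classmethod
    def set_stop_flag(cls, task_id: str) -> None:
        # Called from the "stop generating" endpoint; the flag stays set so
        # every later _is_stopped() check sees it.
        _stop_flags.add(task_id)

    def _is_stopped(self) -> bool:
        return self._task_id in _stop_flags

    def publish(self, event, pub_from: PublishFrom) -> None:
        self._q.put(event)
        # The application manager is the producer of model output: if it is
        # still publishing after a stop was requested, abort its task.  The
        # task pipeline (the consumer) may keep publishing stop/ping events.
        if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
            raise ConversationTaskStoppedException()


if __name__ == "__main__":
    qm = ToyQueueManager(task_id="task-1")
    qm.publish("chunk-1", PublishFrom.APPLICATION_MANAGER)  # no stop yet, accepted
    ToyQueueManager.set_stop_flag("task-1")                 # user presses stop
    qm.publish("stop-event", PublishFrom.TASK_PIPELINE)     # pipeline side still allowed
    try:
        qm.publish("chunk-2", PublishFrom.APPLICATION_MANAGER)
    except ConversationTaskStoppedException:
        print("generation aborted after stop")

In the real code the producer-side callers (AppRunner, ApplicationManager, and the callback handlers) pass PublishFrom.APPLICATION_MANAGER, while GenerateTaskPipeline, OutputModerationHandler, and the queue's own listen loop pass PublishFrom.TASK_PIPELINE, exactly as the surrounding hunks show.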
@@ -8,7 +8,7 @@ from langchain.agents import openai_functions_agent, openai_functions_multi_agen
 from langchain.callbacks.base import BaseCallbackHandler
 from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration, BaseMessage

-from core.application_queue_manager import ApplicationQueueManager
+from core.application_queue_manager import ApplicationQueueManager, PublishFrom
 from core.callback_handler.entity.agent_loop import AgentLoop
 from core.entities.application_entities import ModelConfigEntity
 from core.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResult
@@ -232,7 +232,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
         db.session.add(message_agent_thought)
         db.session.commit()

-        self.queue_manager.publish_agent_thought(message_agent_thought)
+        self.queue_manager.publish_agent_thought(message_agent_thought, PublishFrom.APPLICATION_MANAGER)

         return message_agent_thought

@@ -2,7 +2,7 @@ from typing import List, Union

 from langchain.schema import Document

-from core.application_queue_manager import ApplicationQueueManager
+from core.application_queue_manager import ApplicationQueueManager, PublishFrom
 from core.entities.application_entities import InvokeFrom
 from extensions.ext_database import db
 from models.dataset import DocumentSegment, DatasetQuery
@@ -80,4 +80,4 @@ class DatasetIndexToolCallbackHandler:
         db.session.add(dataset_retriever_resource)
         db.session.commit()

-        self._queue_manager.publish_retriever_resources(resource)
+        self._queue_manager.publish_retriever_resources(resource, PublishFrom.APPLICATION_MANAGER)