chore: apply pep8-naming rules for naming convention (#8261)

This commit is contained in:
Bowen Liang 2024-09-11 16:40:52 +08:00 committed by GitHub
parent 53f37a6704
commit 292220c596
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
95 changed files with 287 additions and 258 deletions

View File

@ -20,7 +20,7 @@ from fields.conversation_fields import (
conversation_pagination_fields, conversation_pagination_fields,
conversation_with_summary_pagination_fields, conversation_with_summary_pagination_fields,
) )
from libs.helper import datetime_string from libs.helper import DatetimeString
from libs.login import login_required from libs.login import login_required
from models.model import AppMode, Conversation, EndUser, Message, MessageAnnotation from models.model import AppMode, Conversation, EndUser, Message, MessageAnnotation
@ -36,8 +36,8 @@ class CompletionConversationApi(Resource):
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("keyword", type=str, location="args") parser.add_argument("keyword", type=str, location="args")
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument( parser.add_argument(
"annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args" "annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args"
) )
@ -143,8 +143,8 @@ class ChatConversationApi(Resource):
raise Forbidden() raise Forbidden()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("keyword", type=str, location="args") parser.add_argument("keyword", type=str, location="args")
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument( parser.add_argument(
"annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args" "annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args"
) )

View File

@ -11,7 +11,7 @@ from controllers.console.app.wraps import get_app_model
from controllers.console.setup import setup_required from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required from controllers.console.wraps import account_initialization_required
from extensions.ext_database import db from extensions.ext_database import db
from libs.helper import datetime_string from libs.helper import DatetimeString
from libs.login import login_required from libs.login import login_required
from models.model import AppMode from models.model import AppMode
@ -25,8 +25,8 @@ class DailyMessageStatistic(Resource):
account = current_user account = current_user
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args() args = parser.parse_args()
sql_query = """ sql_query = """
@ -79,8 +79,8 @@ class DailyConversationStatistic(Resource):
account = current_user account = current_user
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args() args = parser.parse_args()
sql_query = """ sql_query = """
@ -133,8 +133,8 @@ class DailyTerminalsStatistic(Resource):
account = current_user account = current_user
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args() args = parser.parse_args()
sql_query = """ sql_query = """
@ -187,8 +187,8 @@ class DailyTokenCostStatistic(Resource):
account = current_user account = current_user
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args() args = parser.parse_args()
sql_query = """ sql_query = """
@ -245,8 +245,8 @@ class AverageSessionInteractionStatistic(Resource):
account = current_user account = current_user
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args() args = parser.parse_args()
sql_query = """SELECT date(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, sql_query = """SELECT date(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
@ -307,8 +307,8 @@ class UserSatisfactionRateStatistic(Resource):
account = current_user account = current_user
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args() args = parser.parse_args()
sql_query = """ sql_query = """
@ -369,8 +369,8 @@ class AverageResponseTimeStatistic(Resource):
account = current_user account = current_user
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args() args = parser.parse_args()
sql_query = """ sql_query = """
@ -425,8 +425,8 @@ class TokensPerSecondStatistic(Resource):
account = current_user account = current_user
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args() args = parser.parse_args()
sql_query = """SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, sql_query = """SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,

View File

@ -11,7 +11,7 @@ from controllers.console.app.wraps import get_app_model
from controllers.console.setup import setup_required from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required from controllers.console.wraps import account_initialization_required
from extensions.ext_database import db from extensions.ext_database import db
from libs.helper import datetime_string from libs.helper import DatetimeString
from libs.login import login_required from libs.login import login_required
from models.model import AppMode from models.model import AppMode
from models.workflow import WorkflowRunTriggeredFrom from models.workflow import WorkflowRunTriggeredFrom
@ -26,8 +26,8 @@ class WorkflowDailyRunsStatistic(Resource):
account = current_user account = current_user
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args() args = parser.parse_args()
sql_query = """ sql_query = """
@ -86,8 +86,8 @@ class WorkflowDailyTerminalsStatistic(Resource):
account = current_user account = current_user
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args() args = parser.parse_args()
sql_query = """ sql_query = """
@ -146,8 +146,8 @@ class WorkflowDailyTokenCostStatistic(Resource):
account = current_user account = current_user
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args() args = parser.parse_args()
sql_query = """ sql_query = """
@ -213,8 +213,8 @@ class WorkflowAverageAppInteractionStatistic(Resource):
account = current_user account = current_user
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args() args = parser.parse_args()
sql_query = """ sql_query = """

View File

@ -8,7 +8,7 @@ from constants.languages import supported_language
from controllers.console import api from controllers.console import api
from controllers.console.error import AlreadyActivateError from controllers.console.error import AlreadyActivateError
from extensions.ext_database import db from extensions.ext_database import db
from libs.helper import email, str_len, timezone from libs.helper import StrLen, email, timezone
from libs.password import hash_password, valid_password from libs.password import hash_password, valid_password
from models.account import AccountStatus from models.account import AccountStatus
from services.account_service import RegisterService from services.account_service import RegisterService
@ -37,7 +37,7 @@ class ActivateApi(Resource):
parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="json") parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="json")
parser.add_argument("email", type=email, required=False, nullable=True, location="json") parser.add_argument("email", type=email, required=False, nullable=True, location="json")
parser.add_argument("token", type=str, required=True, nullable=False, location="json") parser.add_argument("token", type=str, required=True, nullable=False, location="json")
parser.add_argument("name", type=str_len(30), required=True, nullable=False, location="json") parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
parser.add_argument("password", type=valid_password, required=True, nullable=False, location="json") parser.add_argument("password", type=valid_password, required=True, nullable=False, location="json")
parser.add_argument( parser.add_argument(
"interface_language", type=supported_language, required=True, nullable=False, location="json" "interface_language", type=supported_language, required=True, nullable=False, location="json"

View File

@ -4,7 +4,7 @@ from flask import session
from flask_restful import Resource, reqparse from flask_restful import Resource, reqparse
from configs import dify_config from configs import dify_config
from libs.helper import str_len from libs.helper import StrLen
from models.model import DifySetup from models.model import DifySetup
from services.account_service import TenantService from services.account_service import TenantService
@ -28,7 +28,7 @@ class InitValidateAPI(Resource):
raise AlreadySetupError() raise AlreadySetupError()
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("password", type=str_len(30), required=True, location="json") parser.add_argument("password", type=StrLen(30), required=True, location="json")
input_password = parser.parse_args()["password"] input_password = parser.parse_args()["password"]
if input_password != os.environ.get("INIT_PASSWORD"): if input_password != os.environ.get("INIT_PASSWORD"):

View File

@ -4,7 +4,7 @@ from flask import request
from flask_restful import Resource, reqparse from flask_restful import Resource, reqparse
from configs import dify_config from configs import dify_config
from libs.helper import email, get_remote_ip, str_len from libs.helper import StrLen, email, get_remote_ip
from libs.password import valid_password from libs.password import valid_password
from models.model import DifySetup from models.model import DifySetup
from services.account_service import RegisterService, TenantService from services.account_service import RegisterService, TenantService
@ -40,7 +40,7 @@ class SetupApi(Resource):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("email", type=email, required=True, location="json") parser.add_argument("email", type=email, required=True, location="json")
parser.add_argument("name", type=str_len(30), required=True, location="json") parser.add_argument("name", type=StrLen(30), required=True, location="json")
parser.add_argument("password", type=valid_password, required=True, location="json") parser.add_argument("password", type=valid_password, required=True, location="json")
args = parser.parse_args() args = parser.parse_args()

View File

@ -15,7 +15,7 @@ from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner
from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter
from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
@ -293,7 +293,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
) )
runner.run() runner.run()
except GenerateTaskStoppedException: except GenerateTaskStoppedError:
pass pass
except InvokeAuthorizationError: except InvokeAuthorizationError:
queue_manager.publish_error( queue_manager.publish_error(
@ -349,7 +349,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
return generate_task_pipeline.process() return generate_task_pipeline.process()
except ValueError as e: except ValueError as e:
if e.args[0] == "I/O operation on closed file.": # ignore this error if e.args[0] == "I/O operation on closed file.": # ignore this error
raise GenerateTaskStoppedException() raise GenerateTaskStoppedError()
else: else:
logger.exception(e) logger.exception(e)
raise e raise e

View File

@ -21,7 +21,7 @@ class AudioTrunk:
self.status = status self.status = status
def _invoiceTTS(text_content: str, model_instance, tenant_id: str, voice: str): def _invoice_tts(text_content: str, model_instance, tenant_id: str, voice: str):
if not text_content or text_content.isspace(): if not text_content or text_content.isspace():
return return
return model_instance.invoke_tts( return model_instance.invoke_tts(
@ -81,7 +81,7 @@ class AppGeneratorTTSPublisher:
if message is None: if message is None:
if self.msg_text and len(self.msg_text.strip()) > 0: if self.msg_text and len(self.msg_text.strip()) > 0:
futures_result = self.executor.submit( futures_result = self.executor.submit(
_invoiceTTS, self.msg_text, self.model_instance, self.tenant_id, self.voice _invoice_tts, self.msg_text, self.model_instance, self.tenant_id, self.voice
) )
future_queue.put(futures_result) future_queue.put(futures_result)
break break
@ -97,7 +97,7 @@ class AppGeneratorTTSPublisher:
self.MAX_SENTENCE += 1 self.MAX_SENTENCE += 1
text_content = "".join(sentence_arr) text_content = "".join(sentence_arr)
futures_result = self.executor.submit( futures_result = self.executor.submit(
_invoiceTTS, text_content, self.model_instance, self.tenant_id, self.voice _invoice_tts, text_content, self.model_instance, self.tenant_id, self.voice
) )
future_queue.put(futures_result) future_queue.put(futures_result)
if text_tmp: if text_tmp:
@ -110,7 +110,7 @@ class AppGeneratorTTSPublisher:
break break
future_queue.put(None) future_queue.put(None)
def checkAndGetAudio(self) -> AudioTrunk | None: def check_and_get_audio(self) -> AudioTrunk | None:
try: try:
if self._last_audio_event and self._last_audio_event.status == "finish": if self._last_audio_event and self._last_audio_event.status == "finish":
if self.executor: if self.executor:

View File

@ -19,7 +19,7 @@ from core.app.entities.queue_entities import (
QueueStopEvent, QueueStopEvent,
QueueTextChunkEvent, QueueTextChunkEvent,
) )
from core.moderation.base import ModerationException from core.moderation.base import ModerationError
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
from core.workflow.entities.node_entities import UserFrom from core.workflow.entities.node_entities import UserFrom
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
@ -217,7 +217,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
query=query, query=query,
message_id=message_id, message_id=message_id,
) )
except ModerationException as e: except ModerationError as e:
self._complete_with_stream_output(text=str(e), stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION) self._complete_with_stream_output(text=str(e), stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION)
return True return True

View File

@ -179,10 +179,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
stream_response=stream_response, stream_response=stream_response,
) )
def _listenAudioMsg(self, publisher, task_id: str): def _listen_audio_msg(self, publisher, task_id: str):
if not publisher: if not publisher:
return None return None
audio_msg: AudioTrunk = publisher.checkAndGetAudio() audio_msg: AudioTrunk = publisher.check_and_get_audio()
if audio_msg and audio_msg.status != "finish": if audio_msg and audio_msg.status != "finish":
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id) return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
return None return None
@ -204,7 +204,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager): for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
while True: while True:
audio_response = self._listenAudioMsg(tts_publisher, task_id=task_id) audio_response = self._listen_audio_msg(tts_publisher, task_id=task_id)
if audio_response: if audio_response:
yield audio_response yield audio_response
else: else:
@ -217,7 +217,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
try: try:
if not tts_publisher: if not tts_publisher:
break break
audio_trunk = tts_publisher.checkAndGetAudio() audio_trunk = tts_publisher.check_and_get_audio()
if audio_trunk is None: if audio_trunk is None:
# release cpu # release cpu
# sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file) # sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)

View File

@ -13,7 +13,7 @@ from core.app.app_config.features.file_upload.manager import FileUploadConfigMan
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
from core.app.apps.agent_chat.app_runner import AgentChatAppRunner from core.app.apps.agent_chat.app_runner import AgentChatAppRunner
from core.app.apps.agent_chat.generate_response_converter import AgentChatAppGenerateResponseConverter from core.app.apps.agent_chat.generate_response_converter import AgentChatAppGenerateResponseConverter
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom
@ -205,7 +205,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
conversation=conversation, conversation=conversation,
message=message, message=message,
) )
except GenerateTaskStoppedException: except GenerateTaskStoppedError:
pass pass
except InvokeAuthorizationError: except InvokeAuthorizationError:
queue_manager.publish_error( queue_manager.publish_error(

View File

@ -15,7 +15,7 @@ from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMMode, LLMUsage from core.model_runtime.entities.llm_entities import LLMMode, LLMUsage
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.moderation.base import ModerationException from core.moderation.base import ModerationError
from core.tools.entities.tool_entities import ToolRuntimeVariablePool from core.tools.entities.tool_entities import ToolRuntimeVariablePool
from extensions.ext_database import db from extensions.ext_database import db
from models.model import App, Conversation, Message, MessageAgentThought from models.model import App, Conversation, Message, MessageAgentThought
@ -103,7 +103,7 @@ class AgentChatAppRunner(AppRunner):
query=query, query=query,
message_id=message.id, message_id=message.id,
) )
except ModerationException as e: except ModerationError as e:
self.direct_output( self.direct_output(
queue_manager=queue_manager, queue_manager=queue_manager,
app_generate_entity=application_generate_entity, app_generate_entity=application_generate_entity,

View File

@ -171,5 +171,5 @@ class AppQueueManager:
) )
class GenerateTaskStoppedException(Exception): class GenerateTaskStoppedError(Exception):
pass pass

View File

@ -10,7 +10,7 @@ from pydantic import ValidationError
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
from core.app.apps.chat.app_config_manager import ChatAppConfigManager from core.app.apps.chat.app_config_manager import ChatAppConfigManager
from core.app.apps.chat.app_runner import ChatAppRunner from core.app.apps.chat.app_runner import ChatAppRunner
from core.app.apps.chat.generate_response_converter import ChatAppGenerateResponseConverter from core.app.apps.chat.generate_response_converter import ChatAppGenerateResponseConverter
@ -205,7 +205,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
conversation=conversation, conversation=conversation,
message=message, message=message,
) )
except GenerateTaskStoppedException: except GenerateTaskStoppedError:
pass pass
except InvokeAuthorizationError: except InvokeAuthorizationError:
queue_manager.publish_error( queue_manager.publish_error(

View File

@ -11,7 +11,7 @@ from core.app.entities.queue_entities import QueueAnnotationReplyEvent
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.memory.token_buffer_memory import TokenBufferMemory from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance from core.model_manager import ModelInstance
from core.moderation.base import ModerationException from core.moderation.base import ModerationError
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from extensions.ext_database import db from extensions.ext_database import db
from models.model import App, Conversation, Message from models.model import App, Conversation, Message
@ -98,7 +98,7 @@ class ChatAppRunner(AppRunner):
query=query, query=query,
message_id=message.id, message_id=message.id,
) )
except ModerationException as e: except ModerationError as e:
self.direct_output( self.direct_output(
queue_manager=queue_manager, queue_manager=queue_manager,
app_generate_entity=application_generate_entity, app_generate_entity=application_generate_entity,

View File

@ -10,7 +10,7 @@ from pydantic import ValidationError
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
from core.app.apps.completion.app_config_manager import CompletionAppConfigManager from core.app.apps.completion.app_config_manager import CompletionAppConfigManager
from core.app.apps.completion.app_runner import CompletionAppRunner from core.app.apps.completion.app_runner import CompletionAppRunner
from core.app.apps.completion.generate_response_converter import CompletionAppGenerateResponseConverter from core.app.apps.completion.generate_response_converter import CompletionAppGenerateResponseConverter
@ -185,7 +185,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
queue_manager=queue_manager, queue_manager=queue_manager,
message=message, message=message,
) )
except GenerateTaskStoppedException: except GenerateTaskStoppedError:
pass pass
except InvokeAuthorizationError: except InvokeAuthorizationError:
queue_manager.publish_error( queue_manager.publish_error(

View File

@ -9,7 +9,7 @@ from core.app.entities.app_invoke_entities import (
) )
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.model_manager import ModelInstance from core.model_manager import ModelInstance
from core.moderation.base import ModerationException from core.moderation.base import ModerationError
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from extensions.ext_database import db from extensions.ext_database import db
from models.model import App, Message from models.model import App, Message
@ -79,7 +79,7 @@ class CompletionAppRunner(AppRunner):
query=query, query=query,
message_id=message.id, message_id=message.id,
) )
except ModerationException as e: except ModerationError as e:
self.direct_output( self.direct_output(
queue_manager=queue_manager, queue_manager=queue_manager,
app_generate_entity=application_generate_entity, app_generate_entity=application_generate_entity,

View File

@ -8,7 +8,7 @@ from sqlalchemy import and_
from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom
from core.app.apps.base_app_generator import BaseAppGenerator from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import ( from core.app.entities.app_invoke_entities import (
AdvancedChatAppGenerateEntity, AdvancedChatAppGenerateEntity,
AgentChatAppGenerateEntity, AgentChatAppGenerateEntity,
@ -77,7 +77,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
return generate_task_pipeline.process() return generate_task_pipeline.process()
except ValueError as e: except ValueError as e:
if e.args[0] == "I/O operation on closed file.": # ignore this error if e.args[0] == "I/O operation on closed file.": # ignore this error
raise GenerateTaskStoppedException() raise GenerateTaskStoppedError()
else: else:
logger.exception(e) logger.exception(e)
raise e raise e

View File

@ -1,4 +1,4 @@
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import ( from core.app.entities.queue_entities import (
AppQueueEvent, AppQueueEvent,
@ -53,4 +53,4 @@ class MessageBasedAppQueueManager(AppQueueManager):
self.stop_listen() self.stop_listen()
if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped(): if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
raise GenerateTaskStoppedException() raise GenerateTaskStoppedError()

View File

@ -12,7 +12,7 @@ from pydantic import ValidationError
import contexts import contexts
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.base_app_generator import BaseAppGenerator from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.app.apps.workflow.app_queue_manager import WorkflowAppQueueManager from core.app.apps.workflow.app_queue_manager import WorkflowAppQueueManager
from core.app.apps.workflow.app_runner import WorkflowAppRunner from core.app.apps.workflow.app_runner import WorkflowAppRunner
@ -253,7 +253,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
) )
runner.run() runner.run()
except GenerateTaskStoppedException: except GenerateTaskStoppedError:
pass pass
except InvokeAuthorizationError: except InvokeAuthorizationError:
queue_manager.publish_error( queue_manager.publish_error(
@ -302,7 +302,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
return generate_task_pipeline.process() return generate_task_pipeline.process()
except ValueError as e: except ValueError as e:
if e.args[0] == "I/O operation on closed file.": # ignore this error if e.args[0] == "I/O operation on closed file.": # ignore this error
raise GenerateTaskStoppedException() raise GenerateTaskStoppedError()
else: else:
logger.exception(e) logger.exception(e)
raise e raise e

View File

@ -1,4 +1,4 @@
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import ( from core.app.entities.queue_entities import (
AppQueueEvent, AppQueueEvent,
@ -39,4 +39,4 @@ class WorkflowAppQueueManager(AppQueueManager):
self.stop_listen() self.stop_listen()
if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped(): if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
raise GenerateTaskStoppedException() raise GenerateTaskStoppedError()

View File

@ -162,10 +162,10 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
yield WorkflowAppStreamResponse(workflow_run_id=workflow_run_id, stream_response=stream_response) yield WorkflowAppStreamResponse(workflow_run_id=workflow_run_id, stream_response=stream_response)
def _listenAudioMsg(self, publisher, task_id: str): def _listen_audio_msg(self, publisher, task_id: str):
if not publisher: if not publisher:
return None return None
audio_msg: AudioTrunk = publisher.checkAndGetAudio() audio_msg: AudioTrunk = publisher.check_and_get_audio()
if audio_msg and audio_msg.status != "finish": if audio_msg and audio_msg.status != "finish":
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id) return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
return None return None
@ -187,7 +187,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager): for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
while True: while True:
audio_response = self._listenAudioMsg(tts_publisher, task_id=task_id) audio_response = self._listen_audio_msg(tts_publisher, task_id=task_id)
if audio_response: if audio_response:
yield audio_response yield audio_response
else: else:
@ -199,7 +199,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
try: try:
if not tts_publisher: if not tts_publisher:
break break
audio_trunk = tts_publisher.checkAndGetAudio() audio_trunk = tts_publisher.check_and_get_audio()
if audio_trunk is None: if audio_trunk is None:
# release cpu # release cpu
# sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file) # sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)

View File

@ -15,6 +15,7 @@ class Segment(BaseModel):
value: Any value: Any
@field_validator("value_type") @field_validator("value_type")
@classmethod
def validate_value_type(cls, value): def validate_value_type(cls, value):
""" """
This validator checks if the provided value is equal to the default value of the 'value_type' field. This validator checks if the provided value is equal to the default value of the 'value_type' field.

View File

@ -201,10 +201,10 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
stream_response=stream_response, stream_response=stream_response,
) )
def _listenAudioMsg(self, publisher, task_id: str): def _listen_audio_msg(self, publisher, task_id: str):
if publisher is None: if publisher is None:
return None return None
audio_msg: AudioTrunk = publisher.checkAndGetAudio() audio_msg: AudioTrunk = publisher.check_and_get_audio()
if audio_msg and audio_msg.status != "finish": if audio_msg and audio_msg.status != "finish":
# audio_str = audio_msg.audio.decode('utf-8', errors='ignore') # audio_str = audio_msg.audio.decode('utf-8', errors='ignore')
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id) return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
@ -225,7 +225,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
publisher = AppGeneratorTTSPublisher(tenant_id, text_to_speech_dict.get("voice", None)) publisher = AppGeneratorTTSPublisher(tenant_id, text_to_speech_dict.get("voice", None))
for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager): for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
while True: while True:
audio_response = self._listenAudioMsg(publisher, task_id) audio_response = self._listen_audio_msg(publisher, task_id)
if audio_response: if audio_response:
yield audio_response yield audio_response
else: else:
@ -237,7 +237,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT: while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT:
if publisher is None: if publisher is None:
break break
audio = publisher.checkAndGetAudio() audio = publisher.check_and_get_audio()
if audio is None: if audio is None:
# release cpu # release cpu
# sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file) # sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)

View File

@ -16,7 +16,7 @@ from core.helper.code_executor.template_transformer import TemplateTransformer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class CodeExecutionException(Exception): class CodeExecutionError(Exception):
pass pass
@ -86,15 +86,15 @@ class CodeExecutor:
), ),
) )
if response.status_code == 503: if response.status_code == 503:
raise CodeExecutionException("Code execution service is unavailable") raise CodeExecutionError("Code execution service is unavailable")
elif response.status_code != 200: elif response.status_code != 200:
raise Exception( raise Exception(
f"Failed to execute code, got status code {response.status_code}, please check if the sandbox service is running" f"Failed to execute code, got status code {response.status_code}, please check if the sandbox service is running"
) )
except CodeExecutionException as e: except CodeExecutionError as e:
raise e raise e
except Exception as e: except Exception as e:
raise CodeExecutionException( raise CodeExecutionError(
"Failed to execute code, which is likely a network issue," "Failed to execute code, which is likely a network issue,"
" please check if the sandbox service is running." " please check if the sandbox service is running."
f" ( Error: {str(e)} )" f" ( Error: {str(e)} )"
@ -103,15 +103,15 @@ class CodeExecutor:
try: try:
response = response.json() response = response.json()
except: except:
raise CodeExecutionException("Failed to parse response") raise CodeExecutionError("Failed to parse response")
if (code := response.get("code")) != 0: if (code := response.get("code")) != 0:
raise CodeExecutionException(f"Got error code: {code}. Got error msg: {response.get('message')}") raise CodeExecutionError(f"Got error code: {code}. Got error msg: {response.get('message')}")
response = CodeExecutionResponse(**response) response = CodeExecutionResponse(**response)
if response.data.error: if response.data.error:
raise CodeExecutionException(response.data.error) raise CodeExecutionError(response.data.error)
return response.data.stdout or "" return response.data.stdout or ""
@ -126,13 +126,13 @@ class CodeExecutor:
""" """
template_transformer = cls.code_template_transformers.get(language) template_transformer = cls.code_template_transformers.get(language)
if not template_transformer: if not template_transformer:
raise CodeExecutionException(f"Unsupported language {language}") raise CodeExecutionError(f"Unsupported language {language}")
runner, preload = template_transformer.transform_caller(code, inputs) runner, preload = template_transformer.transform_caller(code, inputs)
try: try:
response = cls.execute_code(language, preload, runner) response = cls.execute_code(language, preload, runner)
except CodeExecutionException as e: except CodeExecutionError as e:
raise e raise e
return template_transformer.transform_response(response) return template_transformer.transform_response(response)

View File

@ -78,8 +78,8 @@ class IndexingRunner:
dataset_document=dataset_document, dataset_document=dataset_document,
documents=documents, documents=documents,
) )
except DocumentIsPausedException: except DocumentIsPausedError:
raise DocumentIsPausedException("Document paused, document id: {}".format(dataset_document.id)) raise DocumentIsPausedError("Document paused, document id: {}".format(dataset_document.id))
except ProviderTokenNotInitError as e: except ProviderTokenNotInitError as e:
dataset_document.indexing_status = "error" dataset_document.indexing_status = "error"
dataset_document.error = str(e.description) dataset_document.error = str(e.description)
@ -134,8 +134,8 @@ class IndexingRunner:
self._load( self._load(
index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents
) )
except DocumentIsPausedException: except DocumentIsPausedError:
raise DocumentIsPausedException("Document paused, document id: {}".format(dataset_document.id)) raise DocumentIsPausedError("Document paused, document id: {}".format(dataset_document.id))
except ProviderTokenNotInitError as e: except ProviderTokenNotInitError as e:
dataset_document.indexing_status = "error" dataset_document.indexing_status = "error"
dataset_document.error = str(e.description) dataset_document.error = str(e.description)
@ -192,8 +192,8 @@ class IndexingRunner:
self._load( self._load(
index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents
) )
except DocumentIsPausedException: except DocumentIsPausedError:
raise DocumentIsPausedException("Document paused, document id: {}".format(dataset_document.id)) raise DocumentIsPausedError("Document paused, document id: {}".format(dataset_document.id))
except ProviderTokenNotInitError as e: except ProviderTokenNotInitError as e:
dataset_document.indexing_status = "error" dataset_document.indexing_status = "error"
dataset_document.error = str(e.description) dataset_document.error = str(e.description)
@ -756,7 +756,7 @@ class IndexingRunner:
indexing_cache_key = "document_{}_is_paused".format(document_id) indexing_cache_key = "document_{}_is_paused".format(document_id)
result = redis_client.get(indexing_cache_key) result = redis_client.get(indexing_cache_key)
if result: if result:
raise DocumentIsPausedException() raise DocumentIsPausedError()
@staticmethod @staticmethod
def _update_document_index_status( def _update_document_index_status(
@ -767,10 +767,10 @@ class IndexingRunner:
""" """
count = DatasetDocument.query.filter_by(id=document_id, is_paused=True).count() count = DatasetDocument.query.filter_by(id=document_id, is_paused=True).count()
if count > 0: if count > 0:
raise DocumentIsPausedException() raise DocumentIsPausedError()
document = DatasetDocument.query.filter_by(id=document_id).first() document = DatasetDocument.query.filter_by(id=document_id).first()
if not document: if not document:
raise DocumentIsDeletedPausedException() raise DocumentIsDeletedPausedError()
update_params = {DatasetDocument.indexing_status: after_indexing_status} update_params = {DatasetDocument.indexing_status: after_indexing_status}
@ -875,9 +875,9 @@ class IndexingRunner:
pass pass
class DocumentIsPausedException(Exception): class DocumentIsPausedError(Exception):
pass pass
class DocumentIsDeletedPausedException(Exception): class DocumentIsDeletedPausedError(Exception):
pass pass

View File

@ -1,2 +1,2 @@
class OutputParserException(Exception): class OutputParserError(Exception):
pass pass

View File

@ -1,6 +1,6 @@
from typing import Any from typing import Any
from core.llm_generator.output_parser.errors import OutputParserException from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.prompts import ( from core.llm_generator.prompts import (
RULE_CONFIG_PARAMETER_GENERATE_TEMPLATE, RULE_CONFIG_PARAMETER_GENERATE_TEMPLATE,
RULE_CONFIG_PROMPT_GENERATE_TEMPLATE, RULE_CONFIG_PROMPT_GENERATE_TEMPLATE,
@ -29,4 +29,4 @@ class RuleConfigGeneratorOutputParser:
raise ValueError("Expected 'opening_statement' to be a str.") raise ValueError("Expected 'opening_statement' to be a str.")
return parsed return parsed
except Exception as e: except Exception as e:
raise OutputParserException(f"Parsing text\n{text}\n of rule config generator raised following error:\n{e}") raise OutputParserError(f"Parsing text\n{text}\n of rule config generator raised following error:\n{e}")

View File

@ -7,7 +7,7 @@ from requests import post
from core.model_runtime.entities.message_entities import PromptMessageTool from core.model_runtime.entities.message_entities import PromptMessageTool
from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import ( from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import (
BadRequestError, BadRequestError,
InsufficientAccountBalance, InsufficientAccountBalanceError,
InternalServerError, InternalServerError,
InvalidAPIKeyError, InvalidAPIKeyError,
InvalidAuthenticationError, InvalidAuthenticationError,
@ -124,7 +124,7 @@ class BaichuanModel:
if err == "invalid_api_key": if err == "invalid_api_key":
raise InvalidAPIKeyError(msg) raise InvalidAPIKeyError(msg)
elif err == "insufficient_quota": elif err == "insufficient_quota":
raise InsufficientAccountBalance(msg) raise InsufficientAccountBalanceError(msg)
elif err == "invalid_authentication": elif err == "invalid_authentication":
raise InvalidAuthenticationError(msg) raise InvalidAuthenticationError(msg)
elif err == "invalid_request_error": elif err == "invalid_request_error":

View File

@ -10,7 +10,7 @@ class RateLimitReachedError(Exception):
pass pass
class InsufficientAccountBalance(Exception): class InsufficientAccountBalanceError(Exception):
pass pass

View File

@ -29,7 +29,7 @@ from core.model_runtime.model_providers.baichuan.llm.baichuan_tokenizer import B
from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo import BaichuanModel from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo import BaichuanModel
from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import ( from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import (
BadRequestError, BadRequestError,
InsufficientAccountBalance, InsufficientAccountBalanceError,
InternalServerError, InternalServerError,
InvalidAPIKeyError, InvalidAPIKeyError,
InvalidAuthenticationError, InvalidAuthenticationError,
@ -289,7 +289,7 @@ class BaichuanLanguageModel(LargeLanguageModel):
InvokeRateLimitError: [RateLimitReachedError], InvokeRateLimitError: [RateLimitReachedError],
InvokeAuthorizationError: [ InvokeAuthorizationError: [
InvalidAuthenticationError, InvalidAuthenticationError,
InsufficientAccountBalance, InsufficientAccountBalanceError,
InvalidAPIKeyError, InvalidAPIKeyError,
], ],
InvokeBadRequestError: [BadRequestError, KeyError], InvokeBadRequestError: [BadRequestError, KeyError],

View File

@ -19,7 +19,7 @@ from core.model_runtime.model_providers.__base.text_embedding_model import TextE
from core.model_runtime.model_providers.baichuan.llm.baichuan_tokenizer import BaichuanTokenizer from core.model_runtime.model_providers.baichuan.llm.baichuan_tokenizer import BaichuanTokenizer
from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import ( from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import (
BadRequestError, BadRequestError,
InsufficientAccountBalance, InsufficientAccountBalanceError,
InternalServerError, InternalServerError,
InvalidAPIKeyError, InvalidAPIKeyError,
InvalidAuthenticationError, InvalidAuthenticationError,
@ -109,7 +109,7 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel):
if err == "invalid_api_key": if err == "invalid_api_key":
raise InvalidAPIKeyError(msg) raise InvalidAPIKeyError(msg)
elif err == "insufficient_quota": elif err == "insufficient_quota":
raise InsufficientAccountBalance(msg) raise InsufficientAccountBalanceError(msg)
elif err == "invalid_authentication": elif err == "invalid_authentication":
raise InvalidAuthenticationError(msg) raise InvalidAuthenticationError(msg)
elif err and "rate" in err: elif err and "rate" in err:
@ -166,7 +166,7 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel):
InvokeRateLimitError: [RateLimitReachedError], InvokeRateLimitError: [RateLimitReachedError],
InvokeAuthorizationError: [ InvokeAuthorizationError: [
InvalidAuthenticationError, InvalidAuthenticationError,
InsufficientAccountBalance, InsufficientAccountBalanceError,
InvalidAPIKeyError, InvalidAPIKeyError,
], ],
InvokeBadRequestError: [BadRequestError, KeyError], InvokeBadRequestError: [BadRequestError, KeyError],

View File

@ -10,7 +10,7 @@ from core.model_runtime.errors.invoke import (
) )
class _CommonOAI_API_Compat: class _CommonOaiApiCompat:
@property @property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
""" """

View File

@ -35,13 +35,13 @@ from core.model_runtime.entities.model_entities import (
from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.invoke import InvokeError
from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOAI_API_Compat from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat
from core.model_runtime.utils import helper from core.model_runtime.utils import helper
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
""" """
Model class for OpenAI large language model. Model class for OpenAI large language model.
""" """

View File

@ -6,10 +6,10 @@ import requests
from core.model_runtime.errors.invoke import InvokeBadRequestError from core.model_runtime.errors.invoke import InvokeBadRequestError
from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOAI_API_Compat from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat
class OAICompatSpeech2TextModel(_CommonOAI_API_Compat, Speech2TextModel): class OAICompatSpeech2TextModel(_CommonOaiApiCompat, Speech2TextModel):
""" """
Model class for OpenAI Compatible Speech to text model. Model class for OpenAI Compatible Speech to text model.
""" """

View File

@ -19,10 +19,10 @@ from core.model_runtime.entities.model_entities import (
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOAI_API_Compat from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat
class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): class OAICompatEmbeddingModel(_CommonOaiApiCompat, TextEmbeddingModel):
""" """
Model class for an OpenAI API-compatible text embedding model. Model class for an OpenAI API-compatible text embedding model.
""" """

View File

@ -19,10 +19,10 @@ from core.model_runtime.entities.model_entities import (
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOAI_API_Compat from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat
class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): class OAICompatEmbeddingModel(_CommonOaiApiCompat, TextEmbeddingModel):
""" """
Model class for an OpenAI API-compatible text embedding model. Model class for an OpenAI API-compatible text embedding model.
""" """

View File

@ -13,7 +13,7 @@ from core.model_runtime.entities.message_entities import (
UserPromptMessage, UserPromptMessage,
) )
from core.model_runtime.model_providers.volcengine_maas.legacy.errors import wrap_error from core.model_runtime.model_providers.volcengine_maas.legacy.errors import wrap_error
from core.model_runtime.model_providers.volcengine_maas.legacy.volc_sdk import ChatRole, MaasException, MaasService from core.model_runtime.model_providers.volcengine_maas.legacy.volc_sdk import ChatRole, MaasError, MaasService
class MaaSClient(MaasService): class MaaSClient(MaasService):
@ -106,7 +106,7 @@ class MaaSClient(MaasService):
def wrap_exception(fn: Callable[[], dict | Generator]) -> dict | Generator: def wrap_exception(fn: Callable[[], dict | Generator]) -> dict | Generator:
try: try:
resp = fn() resp = fn()
except MaasException as e: except MaasError as e:
raise wrap_error(e) raise wrap_error(e)
return resp return resp

View File

@ -1,144 +1,144 @@
from core.model_runtime.model_providers.volcengine_maas.legacy.volc_sdk import MaasException from core.model_runtime.model_providers.volcengine_maas.legacy.volc_sdk import MaasError
class ClientSDKRequestError(MaasException): class ClientSDKRequestError(MaasError):
pass pass
class SignatureDoesNotMatch(MaasException): class SignatureDoesNotMatchError(MaasError):
pass pass
class RequestTimeout(MaasException): class RequestTimeoutError(MaasError):
pass pass
class ServiceConnectionTimeout(MaasException): class ServiceConnectionTimeoutError(MaasError):
pass pass
class MissingAuthenticationHeader(MaasException): class MissingAuthenticationHeaderError(MaasError):
pass pass
class AuthenticationHeaderIsInvalid(MaasException): class AuthenticationHeaderIsInvalidError(MaasError):
pass pass
class InternalServiceError(MaasException): class InternalServiceError(MaasError):
pass pass
class MissingParameter(MaasException): class MissingParameterError(MaasError):
pass pass
class InvalidParameter(MaasException): class InvalidParameterError(MaasError):
pass pass
class AuthenticationExpire(MaasException): class AuthenticationExpireError(MaasError):
pass pass
class EndpointIsInvalid(MaasException): class EndpointIsInvalidError(MaasError):
pass pass
class EndpointIsNotEnable(MaasException): class EndpointIsNotEnableError(MaasError):
pass pass
class ModelNotSupportStreamMode(MaasException): class ModelNotSupportStreamModeError(MaasError):
pass pass
class ReqTextExistRisk(MaasException): class ReqTextExistRiskError(MaasError):
pass pass
class RespTextExistRisk(MaasException): class RespTextExistRiskError(MaasError):
pass pass
class EndpointRateLimitExceeded(MaasException): class EndpointRateLimitExceededError(MaasError):
pass pass
class ServiceConnectionRefused(MaasException): class ServiceConnectionRefusedError(MaasError):
pass pass
class ServiceConnectionClosed(MaasException): class ServiceConnectionClosedError(MaasError):
pass pass
class UnauthorizedUserForEndpoint(MaasException): class UnauthorizedUserForEndpointError(MaasError):
pass pass
class InvalidEndpointWithNoURL(MaasException): class InvalidEndpointWithNoURLError(MaasError):
pass pass
class EndpointAccountRpmRateLimitExceeded(MaasException): class EndpointAccountRpmRateLimitExceededError(MaasError):
pass pass
class EndpointAccountTpmRateLimitExceeded(MaasException): class EndpointAccountTpmRateLimitExceededError(MaasError):
pass pass
class ServiceResourceWaitQueueFull(MaasException): class ServiceResourceWaitQueueFullError(MaasError):
pass pass
class EndpointIsPending(MaasException): class EndpointIsPendingError(MaasError):
pass pass
class ServiceNotOpen(MaasException): class ServiceNotOpenError(MaasError):
pass pass
AuthErrors = { AuthErrors = {
"SignatureDoesNotMatch": SignatureDoesNotMatch, "SignatureDoesNotMatch": SignatureDoesNotMatchError,
"MissingAuthenticationHeader": MissingAuthenticationHeader, "MissingAuthenticationHeader": MissingAuthenticationHeaderError,
"AuthenticationHeaderIsInvalid": AuthenticationHeaderIsInvalid, "AuthenticationHeaderIsInvalid": AuthenticationHeaderIsInvalidError,
"AuthenticationExpire": AuthenticationExpire, "AuthenticationExpire": AuthenticationExpireError,
"UnauthorizedUserForEndpoint": UnauthorizedUserForEndpoint, "UnauthorizedUserForEndpoint": UnauthorizedUserForEndpointError,
} }
BadRequestErrors = { BadRequestErrors = {
"MissingParameter": MissingParameter, "MissingParameter": MissingParameterError,
"InvalidParameter": InvalidParameter, "InvalidParameter": InvalidParameterError,
"EndpointIsInvalid": EndpointIsInvalid, "EndpointIsInvalid": EndpointIsInvalidError,
"EndpointIsNotEnable": EndpointIsNotEnable, "EndpointIsNotEnable": EndpointIsNotEnableError,
"ModelNotSupportStreamMode": ModelNotSupportStreamMode, "ModelNotSupportStreamMode": ModelNotSupportStreamModeError,
"ReqTextExistRisk": ReqTextExistRisk, "ReqTextExistRisk": ReqTextExistRiskError,
"RespTextExistRisk": RespTextExistRisk, "RespTextExistRisk": RespTextExistRiskError,
"InvalidEndpointWithNoURL": InvalidEndpointWithNoURL, "InvalidEndpointWithNoURL": InvalidEndpointWithNoURLError,
"ServiceNotOpen": ServiceNotOpen, "ServiceNotOpen": ServiceNotOpenError,
} }
RateLimitErrors = { RateLimitErrors = {
"EndpointRateLimitExceeded": EndpointRateLimitExceeded, "EndpointRateLimitExceeded": EndpointRateLimitExceededError,
"EndpointAccountRpmRateLimitExceeded": EndpointAccountRpmRateLimitExceeded, "EndpointAccountRpmRateLimitExceeded": EndpointAccountRpmRateLimitExceededError,
"EndpointAccountTpmRateLimitExceeded": EndpointAccountTpmRateLimitExceeded, "EndpointAccountTpmRateLimitExceeded": EndpointAccountTpmRateLimitExceededError,
} }
ServerUnavailableErrors = { ServerUnavailableErrors = {
"InternalServiceError": InternalServiceError, "InternalServiceError": InternalServiceError,
"EndpointIsPending": EndpointIsPending, "EndpointIsPending": EndpointIsPendingError,
"ServiceResourceWaitQueueFull": ServiceResourceWaitQueueFull, "ServiceResourceWaitQueueFull": ServiceResourceWaitQueueFullError,
} }
ConnectionErrors = { ConnectionErrors = {
"ClientSDKRequestError": ClientSDKRequestError, "ClientSDKRequestError": ClientSDKRequestError,
"RequestTimeout": RequestTimeout, "RequestTimeout": RequestTimeoutError,
"ServiceConnectionTimeout": ServiceConnectionTimeout, "ServiceConnectionTimeout": ServiceConnectionTimeoutError,
"ServiceConnectionRefused": ServiceConnectionRefused, "ServiceConnectionRefused": ServiceConnectionRefusedError,
"ServiceConnectionClosed": ServiceConnectionClosed, "ServiceConnectionClosed": ServiceConnectionClosedError,
} }
ErrorCodeMap = { ErrorCodeMap = {
@ -150,7 +150,7 @@ ErrorCodeMap = {
} }
def wrap_error(e: MaasException) -> Exception: def wrap_error(e: MaasError) -> Exception:
if ErrorCodeMap.get(e.code): if ErrorCodeMap.get(e.code):
return ErrorCodeMap.get(e.code)(e.code_n, e.code, e.message, e.req_id) return ErrorCodeMap.get(e.code)(e.code_n, e.code, e.message, e.req_id)
return e return e

View File

@ -1,4 +1,4 @@
from .common import ChatRole from .common import ChatRole
from .maas import MaasException, MaasService from .maas import MaasError, MaasService
__all__ = ["MaasService", "ChatRole", "MaasException"] __all__ = ["MaasService", "ChatRole", "MaasError"]

View File

@ -63,7 +63,7 @@ class MaasService(Service):
raise raise
if res.error is not None and res.error.code_n != 0: if res.error is not None and res.error.code_n != 0:
raise MaasException( raise MaasError(
res.error.code_n, res.error.code_n,
res.error.code, res.error.code,
res.error.message, res.error.message,
@ -72,7 +72,7 @@ class MaasService(Service):
yield res yield res
return iter_fn() return iter_fn()
except MaasException: except MaasError:
raise raise
except Exception as e: except Exception as e:
raise new_client_sdk_request_error(str(e)) raise new_client_sdk_request_error(str(e))
@ -94,7 +94,7 @@ class MaasService(Service):
resp["req_id"] = req_id resp["req_id"] = req_id
return resp return resp
except MaasException as e: except MaasError as e:
raise e raise e
except Exception as e: except Exception as e:
raise new_client_sdk_request_error(str(e), req_id) raise new_client_sdk_request_error(str(e), req_id)
@ -147,14 +147,14 @@ class MaasService(Service):
raise new_client_sdk_request_error(raw, req_id) raise new_client_sdk_request_error(raw, req_id)
if resp.error: if resp.error:
raise MaasException(resp.error.code_n, resp.error.code, resp.error.message, req_id) raise MaasError(resp.error.code_n, resp.error.code, resp.error.message, req_id)
else: else:
raise new_client_sdk_request_error(resp, req_id) raise new_client_sdk_request_error(resp, req_id)
return res return res
class MaasException(Exception): class MaasError(Exception):
def __init__(self, code_n, code, message, req_id): def __init__(self, code_n, code, message, req_id):
self.code_n = code_n self.code_n = code_n
self.code = code self.code = code
@ -172,7 +172,7 @@ class MaasException(Exception):
def new_client_sdk_request_error(raw, req_id=""): def new_client_sdk_request_error(raw, req_id=""):
return MaasException(1709701, "ClientSDKRequestError", "MaaS SDK request error: {}".format(raw), req_id) return MaasError(1709701, "ClientSDKRequestError", "MaaS SDK request error: {}".format(raw), req_id)
class BinaryResponseContent: class BinaryResponseContent:
@ -192,7 +192,7 @@ class BinaryResponseContent:
if len(error_bytes) > 0: if len(error_bytes) > 0:
resp = json_to_object(str(error_bytes, encoding="utf-8"), req_id=self.request_id) resp = json_to_object(str(error_bytes, encoding="utf-8"), req_id=self.request_id)
raise MaasException(resp.error.code_n, resp.error.code, resp.error.message, self.request_id) raise MaasError(resp.error.code_n, resp.error.code, resp.error.message, self.request_id)
def iter_bytes(self) -> Iterator[bytes]: def iter_bytes(self) -> Iterator[bytes]:
yield from self.response yield from self.response

View File

@ -35,7 +35,7 @@ from core.model_runtime.model_providers.volcengine_maas.legacy.errors import (
AuthErrors, AuthErrors,
BadRequestErrors, BadRequestErrors,
ConnectionErrors, ConnectionErrors,
MaasException, MaasError,
RateLimitErrors, RateLimitErrors,
ServerUnavailableErrors, ServerUnavailableErrors,
) )
@ -85,7 +85,7 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
}, },
[UserPromptMessage(content="ping\nAnswer: ")], [UserPromptMessage(content="ping\nAnswer: ")],
) )
except MaasException as e: except MaasError as e:
raise CredentialsValidateFailedError(e.message) raise CredentialsValidateFailedError(e.message)
@staticmethod @staticmethod

View File

@ -28,7 +28,7 @@ from core.model_runtime.model_providers.volcengine_maas.legacy.errors import (
AuthErrors, AuthErrors,
BadRequestErrors, BadRequestErrors,
ConnectionErrors, ConnectionErrors,
MaasException, MaasError,
RateLimitErrors, RateLimitErrors,
ServerUnavailableErrors, ServerUnavailableErrors,
) )
@ -111,7 +111,7 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel):
def _validate_credentials_v2(self, model: str, credentials: dict) -> None: def _validate_credentials_v2(self, model: str, credentials: dict) -> None:
try: try:
self._invoke(model=model, credentials=credentials, texts=["ping"]) self._invoke(model=model, credentials=credentials, texts=["ping"])
except MaasException as e: except MaasError as e:
raise CredentialsValidateFailedError(e.message) raise CredentialsValidateFailedError(e.message)
def _validate_credentials_v3(self, model: str, credentials: dict) -> None: def _validate_credentials_v3(self, model: str, credentials: dict) -> None:

View File

@ -23,7 +23,7 @@ def invoke_error_mapping() -> dict[type[InvokeError], list[type[Exception]]]:
InvokeRateLimitError: [RateLimitReachedError], InvokeRateLimitError: [RateLimitReachedError],
InvokeAuthorizationError: [ InvokeAuthorizationError: [
InvalidAuthenticationError, InvalidAuthenticationError,
InsufficientAccountBalance, InsufficientAccountBalanceError,
InvalidAPIKeyError, InvalidAPIKeyError,
], ],
InvokeBadRequestError: [BadRequestError, KeyError], InvokeBadRequestError: [BadRequestError, KeyError],
@ -42,7 +42,7 @@ class RateLimitReachedError(Exception):
pass pass
class InsufficientAccountBalance(Exception): class InsufficientAccountBalanceError(Exception):
pass pass

View File

@ -76,7 +76,7 @@ class Moderation(Extensible, ABC):
raise NotImplementedError raise NotImplementedError
@classmethod @classmethod
def _validate_inputs_and_outputs_config(self, config: dict, is_preset_response_required: bool) -> None: def _validate_inputs_and_outputs_config(cls, config: dict, is_preset_response_required: bool) -> None:
# inputs_config # inputs_config
inputs_config = config.get("inputs_config") inputs_config = config.get("inputs_config")
if not isinstance(inputs_config, dict): if not isinstance(inputs_config, dict):
@ -111,5 +111,5 @@ class Moderation(Extensible, ABC):
raise ValueError("outputs_config.preset_response must be less than 100 characters") raise ValueError("outputs_config.preset_response must be less than 100 characters")
class ModerationException(Exception): class ModerationError(Exception):
pass pass

View File

@ -2,7 +2,7 @@ import logging
from typing import Optional from typing import Optional
from core.app.app_config.entities import AppConfig from core.app.app_config.entities import AppConfig
from core.moderation.base import ModerationAction, ModerationException from core.moderation.base import ModerationAction, ModerationError
from core.moderation.factory import ModerationFactory from core.moderation.factory import ModerationFactory
from core.ops.entities.trace_entity import TraceTaskName from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
@ -61,7 +61,7 @@ class InputModeration:
return False, inputs, query return False, inputs, query
if moderation_result.action == ModerationAction.DIRECT_OUTPUT: if moderation_result.action == ModerationAction.DIRECT_OUTPUT:
raise ModerationException(moderation_result.preset_response) raise ModerationError(moderation_result.preset_response)
elif moderation_result.action == ModerationAction.OVERRIDDEN: elif moderation_result.action == ModerationAction.OVERRIDDEN:
inputs = moderation_result.inputs inputs = moderation_result.inputs
query = moderation_result.query query = moderation_result.query

View File

@ -26,6 +26,7 @@ class LangfuseConfig(BaseTracingConfig):
host: str = "https://api.langfuse.com" host: str = "https://api.langfuse.com"
@field_validator("host") @field_validator("host")
@classmethod
def set_value(cls, v, info: ValidationInfo): def set_value(cls, v, info: ValidationInfo):
if v is None or v == "": if v is None or v == "":
v = "https://api.langfuse.com" v = "https://api.langfuse.com"
@ -45,6 +46,7 @@ class LangSmithConfig(BaseTracingConfig):
endpoint: str = "https://api.smith.langchain.com" endpoint: str = "https://api.smith.langchain.com"
@field_validator("endpoint") @field_validator("endpoint")
@classmethod
def set_value(cls, v, info: ValidationInfo): def set_value(cls, v, info: ValidationInfo):
if v is None or v == "": if v is None or v == "":
v = "https://api.smith.langchain.com" v = "https://api.smith.langchain.com"

View File

@ -15,6 +15,7 @@ class BaseTraceInfo(BaseModel):
metadata: dict[str, Any] metadata: dict[str, Any]
@field_validator("inputs", "outputs") @field_validator("inputs", "outputs")
@classmethod
def ensure_type(cls, v): def ensure_type(cls, v):
if v is None: if v is None:
return None return None

View File

@ -101,6 +101,7 @@ class LangfuseTrace(BaseModel):
) )
@field_validator("input", "output") @field_validator("input", "output")
@classmethod
def ensure_dict(cls, v, info: ValidationInfo): def ensure_dict(cls, v, info: ValidationInfo):
field_name = info.field_name field_name = info.field_name
return validate_input_output(v, field_name) return validate_input_output(v, field_name)
@ -171,6 +172,7 @@ class LangfuseSpan(BaseModel):
) )
@field_validator("input", "output") @field_validator("input", "output")
@classmethod
def ensure_dict(cls, v, info: ValidationInfo): def ensure_dict(cls, v, info: ValidationInfo):
field_name = info.field_name field_name = info.field_name
return validate_input_output(v, field_name) return validate_input_output(v, field_name)
@ -196,6 +198,7 @@ class GenerationUsage(BaseModel):
totalCost: Optional[float] = None totalCost: Optional[float] = None
@field_validator("input", "output") @field_validator("input", "output")
@classmethod
def ensure_dict(cls, v, info: ValidationInfo): def ensure_dict(cls, v, info: ValidationInfo):
field_name = info.field_name field_name = info.field_name
return validate_input_output(v, field_name) return validate_input_output(v, field_name)
@ -273,6 +276,7 @@ class LangfuseGeneration(BaseModel):
model_config = ConfigDict(protected_namespaces=()) model_config = ConfigDict(protected_namespaces=())
@field_validator("input", "output") @field_validator("input", "output")
@classmethod
def ensure_dict(cls, v, info: ValidationInfo): def ensure_dict(cls, v, info: ValidationInfo):
field_name = info.field_name field_name = info.field_name
return validate_input_output(v, field_name) return validate_input_output(v, field_name)

View File

@ -51,6 +51,7 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel):
output_attachments: Optional[dict[str, Any]] = Field(None, description="Output attachments of the run") output_attachments: Optional[dict[str, Any]] = Field(None, description="Output attachments of the run")
@field_validator("inputs", "outputs") @field_validator("inputs", "outputs")
@classmethod
def ensure_dict(cls, v, info: ValidationInfo): def ensure_dict(cls, v, info: ValidationInfo):
field_name = info.field_name field_name = info.field_name
values = info.data values = info.data
@ -115,6 +116,7 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel):
return v return v
return v return v
@classmethod
@field_validator("start_time", "end_time") @field_validator("start_time", "end_time")
def format_time(cls, v, info: ValidationInfo): def format_time(cls, v, info: ValidationInfo):
if not isinstance(v, datetime): if not isinstance(v, datetime):

View File

@ -27,6 +27,7 @@ class ElasticSearchConfig(BaseModel):
password: str password: str
@model_validator(mode="before") @model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict: def validate_config(cls, values: dict) -> dict:
if not values["host"]: if not values["host"]:
raise ValueError("config HOST is required") raise ValueError("config HOST is required")

View File

@ -28,6 +28,7 @@ class MilvusConfig(BaseModel):
database: str = "default" database: str = "default"
@model_validator(mode="before") @model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict: def validate_config(cls, values: dict) -> dict:
if not values.get("uri"): if not values.get("uri"):
raise ValueError("config MILVUS_URI is required") raise ValueError("config MILVUS_URI is required")

View File

@ -29,6 +29,7 @@ class OpenSearchConfig(BaseModel):
secure: bool = False secure: bool = False
@model_validator(mode="before") @model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict: def validate_config(cls, values: dict) -> dict:
if not values.get("host"): if not values.get("host"):
raise ValueError("config OPENSEARCH_HOST is required") raise ValueError("config OPENSEARCH_HOST is required")

View File

@ -32,6 +32,7 @@ class OracleVectorConfig(BaseModel):
database: str database: str
@model_validator(mode="before") @model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict: def validate_config(cls, values: dict) -> dict:
if not values["host"]: if not values["host"]:
raise ValueError("config ORACLE_HOST is required") raise ValueError("config ORACLE_HOST is required")

View File

@ -32,6 +32,7 @@ class PgvectoRSConfig(BaseModel):
database: str database: str
@model_validator(mode="before") @model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict: def validate_config(cls, values: dict) -> dict:
if not values["host"]: if not values["host"]:
raise ValueError("config PGVECTO_RS_HOST is required") raise ValueError("config PGVECTO_RS_HOST is required")

View File

@ -25,6 +25,7 @@ class PGVectorConfig(BaseModel):
database: str database: str
@model_validator(mode="before") @model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict: def validate_config(cls, values: dict) -> dict:
if not values["host"]: if not values["host"]:
raise ValueError("config PGVECTOR_HOST is required") raise ValueError("config PGVECTOR_HOST is required")

View File

@ -34,6 +34,7 @@ class RelytConfig(BaseModel):
database: str database: str
@model_validator(mode="before") @model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict: def validate_config(cls, values: dict) -> dict:
if not values["host"]: if not values["host"]:
raise ValueError("config RELYT_HOST is required") raise ValueError("config RELYT_HOST is required")

View File

@ -29,6 +29,7 @@ class TiDBVectorConfig(BaseModel):
program_name: str program_name: str
@model_validator(mode="before") @model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict: def validate_config(cls, values: dict) -> dict:
if not values["host"]: if not values["host"]:
raise ValueError("config TIDB_VECTOR_HOST is required") raise ValueError("config TIDB_VECTOR_HOST is required")

View File

@ -23,6 +23,7 @@ class WeaviateConfig(BaseModel):
batch_size: int = 100 batch_size: int = 100
@model_validator(mode="before") @model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict: def validate_config(cls, values: dict) -> dict:
if not values["endpoint"]: if not values["endpoint"]:
raise ValueError("config WEAVIATE_ENDPOINT is required") raise ValueError("config WEAVIATE_ENDPOINT is required")

View File

@ -69,7 +69,7 @@ class DallE3Tool(BuiltinTool):
self.create_blob_message( self.create_blob_message(
blob=b64decode(image.b64_json), blob=b64decode(image.b64_json),
meta={"mime_type": "image/png"}, meta={"mime_type": "image/png"},
save_as=self.VARIABLE_KEY.IMAGE.value, save_as=self.VariableKey.IMAGE.value,
) )
) )
result.append(self.create_text_message(f"\nGenerate image source to Seed ID: {seed_id}")) result.append(self.create_text_message(f"\nGenerate image source to Seed ID: {seed_id}"))

View File

@ -59,7 +59,7 @@ class DallE2Tool(BuiltinTool):
self.create_blob_message( self.create_blob_message(
blob=b64decode(image.b64_json), blob=b64decode(image.b64_json),
meta={"mime_type": "image/png"}, meta={"mime_type": "image/png"},
save_as=self.VARIABLE_KEY.IMAGE.value, save_as=self.VariableKey.IMAGE.value,
) )
) )

View File

@ -66,7 +66,7 @@ class DallE3Tool(BuiltinTool):
for image in response.data: for image in response.data:
mime_type, blob_image = DallE3Tool._decode_image(image.b64_json) mime_type, blob_image = DallE3Tool._decode_image(image.b64_json)
blob_message = self.create_blob_message( blob_message = self.create_blob_message(
blob=blob_image, meta={"mime_type": mime_type}, save_as=self.VARIABLE_KEY.IMAGE.value blob=blob_image, meta={"mime_type": mime_type}, save_as=self.VariableKey.IMAGE.value
) )
result.append(blob_message) result.append(blob_message)
return result return result

View File

@ -34,7 +34,7 @@ class NovitaAiCreateTileTool(BuiltinTool):
self.create_blob_message( self.create_blob_message(
blob=b64decode(client_result.image_file), blob=b64decode(client_result.image_file),
meta={"mime_type": f"image/{client_result.image_type}"}, meta={"mime_type": f"image/{client_result.image_type}"},
save_as=self.VARIABLE_KEY.IMAGE.value, save_as=self.VariableKey.IMAGE.value,
) )
) )

View File

@ -40,7 +40,7 @@ class NovitaAiTxt2ImgTool(BuiltinTool, NovitaAiToolBase):
self.create_blob_message( self.create_blob_message(
blob=b64decode(image_encoded), blob=b64decode(image_encoded),
meta={"mime_type": f"image/{image.image_type}"}, meta={"mime_type": f"image/{image.image_type}"},
save_as=self.VARIABLE_KEY.IMAGE.value, save_as=self.VariableKey.IMAGE.value,
) )
) )

View File

@ -46,7 +46,7 @@ class QRCodeGeneratorTool(BuiltinTool):
image = self._generate_qrcode(content, border, error_correction) image = self._generate_qrcode(content, border, error_correction)
image_bytes = self._image_to_byte_array(image) image_bytes = self._image_to_byte_array(image)
return self.create_blob_message( return self.create_blob_message(
blob=image_bytes, meta={"mime_type": "image/png"}, save_as=self.VARIABLE_KEY.IMAGE.value blob=image_bytes, meta={"mime_type": "image/png"}, save_as=self.VariableKey.IMAGE.value
) )
except Exception: except Exception:
logging.exception(f"Failed to generate QR code for content: {content}") logging.exception(f"Failed to generate QR code for content: {content}")

View File

@ -32,5 +32,5 @@ class FluxTool(BuiltinTool):
res = response.json() res = response.json()
result = [self.create_json_message(res)] result = [self.create_json_message(res)]
for image in res.get("images", []): for image in res.get("images", []):
result.append(self.create_image_message(image=image.get("url"), save_as=self.VARIABLE_KEY.IMAGE.value)) result.append(self.create_image_message(image=image.get("url"), save_as=self.VariableKey.IMAGE.value))
return result return result

View File

@ -41,5 +41,5 @@ class StableDiffusionTool(BuiltinTool):
res = response.json() res = response.json()
result = [self.create_json_message(res)] result = [self.create_json_message(res)]
for image in res.get("images", []): for image in res.get("images", []):
result.append(self.create_image_message(image=image.get("url"), save_as=self.VARIABLE_KEY.IMAGE.value)) result.append(self.create_image_message(image=image.get("url"), save_as=self.VariableKey.IMAGE.value))
return result return result

View File

@ -15,16 +15,16 @@ from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool from core.tools.tool.builtin_tool import BuiltinTool
class AssembleHeaderException(Exception): class AssembleHeaderError(Exception):
def __init__(self, msg): def __init__(self, msg):
self.message = msg self.message = msg
class Url: class Url:
def __init__(this, host, path, schema): def __init__(self, host, path, schema):
this.host = host self.host = host
this.path = path self.path = path
this.schema = schema self.schema = schema
# calculate sha256 and encode to base64 # calculate sha256 and encode to base64
@ -41,7 +41,7 @@ def parse_url(request_url):
schema = request_url[: stidx + 3] schema = request_url[: stidx + 3]
edidx = host.index("/") edidx = host.index("/")
if edidx <= 0: if edidx <= 0:
raise AssembleHeaderException("invalid request url:" + request_url) raise AssembleHeaderError("invalid request url:" + request_url)
path = host[edidx:] path = host[edidx:]
host = host[:edidx] host = host[:edidx]
u = Url(host, path, schema) u = Url(host, path, schema)
@ -115,7 +115,7 @@ class SparkImgGeneratorTool(BuiltinTool):
self.create_blob_message( self.create_blob_message(
blob=b64decode(image["base64_image"]), blob=b64decode(image["base64_image"]),
meta={"mime_type": "image/png"}, meta={"mime_type": "image/png"},
save_as=self.VARIABLE_KEY.IMAGE.value, save_as=self.VariableKey.IMAGE.value,
) )
) )
return result return result

View File

@ -52,5 +52,5 @@ class StableDiffusionTool(BuiltinTool, BaseStabilityAuthorization):
raise Exception(response.text) raise Exception(response.text)
return self.create_blob_message( return self.create_blob_message(
blob=response.content, meta={"mime_type": "image/png"}, save_as=self.VARIABLE_KEY.IMAGE.value blob=response.content, meta={"mime_type": "image/png"}, save_as=self.VariableKey.IMAGE.value
) )

View File

@ -260,7 +260,7 @@ class StableDiffusionTool(BuiltinTool):
image = response.json()["images"][0] image = response.json()["images"][0]
return self.create_blob_message( return self.create_blob_message(
blob=b64decode(image), meta={"mime_type": "image/png"}, save_as=self.VARIABLE_KEY.IMAGE.value blob=b64decode(image), meta={"mime_type": "image/png"}, save_as=self.VariableKey.IMAGE.value
) )
except Exception as e: except Exception as e:
@ -294,7 +294,7 @@ class StableDiffusionTool(BuiltinTool):
image = response.json()["images"][0] image = response.json()["images"][0]
return self.create_blob_message( return self.create_blob_message(
blob=b64decode(image), meta={"mime_type": "image/png"}, save_as=self.VARIABLE_KEY.IMAGE.value blob=b64decode(image), meta={"mime_type": "image/png"}, save_as=self.VariableKey.IMAGE.value
) )
except Exception as e: except Exception as e:

View File

@ -45,5 +45,5 @@ class PoiSearchTool(BuiltinTool):
).content ).content
return self.create_blob_message( return self.create_blob_message(
blob=result, meta={"mime_type": "image/png"}, save_as=self.VARIABLE_KEY.IMAGE.value blob=result, meta={"mime_type": "image/png"}, save_as=self.VariableKey.IMAGE.value
) )

View File

@ -32,7 +32,7 @@ class VectorizerTool(BuiltinTool):
if image_id.startswith("__test_"): if image_id.startswith("__test_"):
image_binary = b64decode(VECTORIZER_ICON_PNG) image_binary = b64decode(VECTORIZER_ICON_PNG)
else: else:
image_binary = self.get_variable_file(self.VARIABLE_KEY.IMAGE) image_binary = self.get_variable_file(self.VariableKey.IMAGE)
if not image_binary: if not image_binary:
return self.create_text_message("Image not found, please request user to generate image firstly.") return self.create_text_message("Image not found, please request user to generate image firstly.")

View File

@ -63,7 +63,7 @@ class Tool(BaseModel, ABC):
def __init__(self, **data: Any): def __init__(self, **data: Any):
super().__init__(**data) super().__init__(**data)
class VARIABLE_KEY(Enum): class VariableKey(Enum):
IMAGE = "image" IMAGE = "image"
def fork_tool_runtime(self, runtime: dict[str, Any]) -> "Tool": def fork_tool_runtime(self, runtime: dict[str, Any]) -> "Tool":
@ -142,7 +142,7 @@ class Tool(BaseModel, ABC):
if not self.variables: if not self.variables:
return None return None
return self.get_variable(self.VARIABLE_KEY.IMAGE) return self.get_variable(self.VariableKey.IMAGE)
def get_variable_file(self, name: Union[str, Enum]) -> Optional[bytes]: def get_variable_file(self, name: Union[str, Enum]) -> Optional[bytes]:
""" """
@ -189,7 +189,7 @@ class Tool(BaseModel, ABC):
result = [] result = []
for variable in self.variables.pool: for variable in self.variables.pool:
if variable.name.startswith(self.VARIABLE_KEY.IMAGE.value): if variable.name.startswith(self.VariableKey.IMAGE.value):
result.append(variable) result.append(variable)
return result return result

View File

@ -8,7 +8,7 @@ from typing import Any, Optional
from flask import Flask, current_app from flask import Flask, current_app
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import ( from core.workflow.entities.node_entities import (
NodeRunMetadataKey, NodeRunMetadataKey,
@ -669,7 +669,7 @@ class GraphEngine:
parent_parallel_id=parent_parallel_id, parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id, parent_parallel_start_node_id=parent_parallel_start_node_id,
) )
except GenerateTaskStoppedException: except GenerateTaskStoppedError:
# trigger node run failed event # trigger node run failed event
route_node_state.status = RouteNodeState.Status.FAILED route_node_state.status = RouteNodeState.Status.FAILED
route_node_state.failed_reason = "Workflow stopped." route_node_state.failed_reason = "Workflow stopped."

View File

@ -2,7 +2,7 @@ from collections.abc import Mapping, Sequence
from typing import Any, Optional, Union, cast from typing import Any, Optional, Union, cast
from configs import dify_config from configs import dify_config
from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor, CodeLanguage from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
from core.helper.code_executor.code_node_provider import CodeNodeProvider from core.helper.code_executor.code_node_provider import CodeNodeProvider
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
@ -61,7 +61,7 @@ class CodeNode(BaseNode):
# Transform result # Transform result
result = self._transform_result(result, node_data.outputs) result = self._transform_result(result, node_data.outputs)
except (CodeExecutionException, ValueError) as e: except (CodeExecutionError, ValueError) as e:
return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e)) return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e))
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=result) return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=result)

View File

@ -2,7 +2,7 @@ import os
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from typing import Any, Optional, cast from typing import Any, Optional, cast
from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor, CodeLanguage from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData
@ -45,7 +45,7 @@ class TemplateTransformNode(BaseNode):
result = CodeExecutor.execute_workflow_code_template( result = CodeExecutor.execute_workflow_code_template(
language=CodeLanguage.JINJA2, code=node_data.template, inputs=variables language=CodeLanguage.JINJA2, code=node_data.template, inputs=variables
) )
except CodeExecutionException as e: except CodeExecutionError as e:
return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e)) return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e))
if len(result["result"]) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH: if len(result["result"]) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH:

View File

@ -6,7 +6,7 @@ from typing import Any, Optional, cast
from configs import dify_config from configs import dify_config
from core.app.app_config.entities import FileExtraConfig from core.app.app_config.entities import FileExtraConfig
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.file_obj import FileTransferMethod, FileType, FileVar from core.file.file_obj import FileTransferMethod, FileType, FileVar
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
@ -103,7 +103,7 @@ class WorkflowEntry:
for callback in callbacks: for callback in callbacks:
callback.on_event(event=event) callback.on_event(event=event)
yield event yield event
except GenerateTaskStoppedException: except GenerateTaskStoppedError:
pass pass
except Exception as e: except Exception as e:
logger.exception("Unknown Error when workflow entry running") logger.exception("Unknown Error when workflow entry running")

View File

@ -5,7 +5,7 @@ import time
import click import click
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from core.indexing_runner import DocumentIsPausedException, IndexingRunner from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from events.event_handlers.document_index_event import document_index_created from events.event_handlers.document_index_event import document_index_created
from extensions.ext_database import db from extensions.ext_database import db
from models.dataset import Document from models.dataset import Document
@ -43,7 +43,7 @@ def handle(sender, **kwargs):
indexing_runner.run(documents) indexing_runner.run(documents)
end_at = time.perf_counter() end_at = time.perf_counter()
logging.info(click.style("Processed dataset: {} latency: {}".format(dataset_id, end_at - start_at), fg="green")) logging.info(click.style("Processed dataset: {} latency: {}".format(dataset_id, end_at - start_at), fg="green"))
except DocumentIsPausedException as ex: except DocumentIsPausedError as ex:
logging.info(click.style(str(ex), fg="yellow")) logging.info(click.style(str(ex), fg="yellow"))
except Exception: except Exception:
pass pass

View File

@ -5,7 +5,7 @@ from collections.abc import Generator
from contextlib import closing from contextlib import closing
from flask import Flask from flask import Flask
from google.cloud import storage as GoogleCloudStorage from google.cloud import storage as google_cloud_storage
from extensions.storage.base_storage import BaseStorage from extensions.storage.base_storage import BaseStorage
@ -23,9 +23,9 @@ class GoogleStorage(BaseStorage):
service_account_json = base64.b64decode(service_account_json_str).decode("utf-8") service_account_json = base64.b64decode(service_account_json_str).decode("utf-8")
# convert str to object # convert str to object
service_account_obj = json.loads(service_account_json) service_account_obj = json.loads(service_account_json)
self.client = GoogleCloudStorage.Client.from_service_account_info(service_account_obj) self.client = google_cloud_storage.Client.from_service_account_info(service_account_obj)
else: else:
self.client = GoogleCloudStorage.Client() self.client = google_cloud_storage.Client()
def save(self, filename, data): def save(self, filename, data):
bucket = self.client.get_bucket(self.bucket_name) bucket = self.client.get_bucket(self.bucket_name)

View File

@ -31,7 +31,7 @@ from Crypto.Util.py3compat import _copy_bytes, bord
from Crypto.Util.strxor import strxor from Crypto.Util.strxor import strxor
class PKCS1OAEP_Cipher: class PKCS1OAepCipher:
"""Cipher object for PKCS#1 v1.5 OAEP. """Cipher object for PKCS#1 v1.5 OAEP.
Do not create directly: use :func:`new` instead.""" Do not create directly: use :func:`new` instead."""
@ -237,4 +237,4 @@ def new(key, hashAlgo=None, mgfunc=None, label=b"", randfunc=None):
if randfunc is None: if randfunc is None:
randfunc = Random.get_random_bytes randfunc = Random.get_random_bytes
return PKCS1OAEP_Cipher(key, hashAlgo, mgfunc, label, randfunc) return PKCS1OAepCipher(key, hashAlgo, mgfunc, label, randfunc)

View File

@ -84,7 +84,7 @@ def timestamp_value(timestamp):
raise ValueError(error) raise ValueError(error)
class str_len: class StrLen:
"""Restrict input to an integer in a range (inclusive)""" """Restrict input to an integer in a range (inclusive)"""
def __init__(self, max_length, argument="argument"): def __init__(self, max_length, argument="argument"):
@ -102,7 +102,7 @@ class str_len:
return value return value
class float_range: class FloatRange:
"""Restrict input to an float in a range (inclusive)""" """Restrict input to an float in a range (inclusive)"""
def __init__(self, low, high, argument="argument"): def __init__(self, low, high, argument="argument"):
@ -121,7 +121,7 @@ class float_range:
return value return value
class datetime_string: class DatetimeString:
def __init__(self, format, argument="argument"): def __init__(self, format, argument="argument"):
self.format = format self.format = format
self.argument = argument self.argument = argument

View File

@ -1,6 +1,6 @@
import json import json
from core.llm_generator.output_parser.errors import OutputParserException from core.llm_generator.output_parser.errors import OutputParserError
def parse_json_markdown(json_string: str) -> dict: def parse_json_markdown(json_string: str) -> dict:
@ -33,10 +33,10 @@ def parse_and_check_json_markdown(text: str, expected_keys: list[str]) -> dict:
try: try:
json_obj = parse_json_markdown(text) json_obj = parse_json_markdown(text)
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
raise OutputParserException(f"Got invalid JSON object. Error: {e}") raise OutputParserError(f"Got invalid JSON object. Error: {e}")
for key in expected_keys: for key in expected_keys:
if key not in json_obj: if key not in json_obj:
raise OutputParserException( raise OutputParserError(
f"Got invalid return object. Expected key `{key}` " f"to be present, but got {json_obj}" f"Got invalid return object. Expected key `{key}` " f"to be present, but got {json_obj}"
) )
return json_obj return json_obj

View File

@ -15,8 +15,8 @@ select = [
"C4", # flake8-comprehensions "C4", # flake8-comprehensions
"F", # pyflakes rules "F", # pyflakes rules
"I", # isort rules "I", # isort rules
"N", # pep8-naming
"UP", # pyupgrade rules "UP", # pyupgrade rules
"B035", # static-key-dict-comprehension
"E101", # mixed-spaces-and-tabs "E101", # mixed-spaces-and-tabs
"E111", # indentation-with-invalid-multiple "E111", # indentation-with-invalid-multiple
"E112", # no-indented-block "E112", # no-indented-block
@ -47,9 +47,10 @@ ignore = [
"B006", # mutable-argument-default "B006", # mutable-argument-default
"B007", # unused-loop-control-variable "B007", # unused-loop-control-variable
"B026", # star-arg-unpacking-after-keyword-arg "B026", # star-arg-unpacking-after-keyword-arg
# "B901", # return-in-generator
"B904", # raise-without-from-inside-except "B904", # raise-without-from-inside-except
"B905", # zip-without-explicit-strict "B905", # zip-without-explicit-strict
"N806", # non-lowercase-variable-in-function
"N815", # mixed-case-variable-in-class-scope
] ]
[tool.ruff.lint.per-file-ignores] [tool.ruff.lint.per-file-ignores]
@ -65,6 +66,12 @@ ignore = [
"F401", # unused-import "F401", # unused-import
"F811", # redefined-while-unused "F811", # redefined-while-unused
] ]
"configs/*" = [
"N802", # invalid-function-name
]
"libs/gmpy2_pkcs10aep_cipher.py" = [
"N803", # invalid-argument-name
]
[tool.ruff.format] [tool.ruff.format]
exclude = [ exclude = [

View File

@ -32,7 +32,7 @@ from services.errors.account import (
NoPermissionError, NoPermissionError,
RateLimitExceededError, RateLimitExceededError,
RoleAlreadyAssignedError, RoleAlreadyAssignedError,
TenantNotFound, TenantNotFoundError,
) )
from tasks.mail_invite_member_task import send_invite_member_mail_task from tasks.mail_invite_member_task import send_invite_member_mail_task
from tasks.mail_reset_password_task import send_reset_password_mail_task from tasks.mail_reset_password_task import send_reset_password_mail_task
@ -311,13 +311,13 @@ class TenantService:
"""Get tenant by account and add the role""" """Get tenant by account and add the role"""
tenant = account.current_tenant tenant = account.current_tenant
if not tenant: if not tenant:
raise TenantNotFound("Tenant not found.") raise TenantNotFoundError("Tenant not found.")
ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first() ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first()
if ta: if ta:
tenant.role = ta.role tenant.role = ta.role
else: else:
raise TenantNotFound("Tenant not found for the account.") raise TenantNotFoundError("Tenant not found for the account.")
return tenant return tenant
@staticmethod @staticmethod
@ -614,8 +614,8 @@ class RegisterService:
"email": account.email, "email": account.email,
"workspace_id": tenant.id, "workspace_id": tenant.id,
} }
expiryHours = dify_config.INVITE_EXPIRY_HOURS expiry_hours = dify_config.INVITE_EXPIRY_HOURS
redis_client.setex(cls._get_invitation_token_key(token), expiryHours * 60 * 60, json.dumps(invitation_data)) redis_client.setex(cls._get_invitation_token_key(token), expiry_hours * 60 * 60, json.dumps(invitation_data))
return token return token
@classmethod @classmethod

View File

@ -1,7 +1,7 @@
from services.errors.base import BaseServiceError from services.errors.base import BaseServiceError
class AccountNotFound(BaseServiceError): class AccountNotFoundError(BaseServiceError):
pass pass
@ -25,7 +25,7 @@ class LinkAccountIntegrateError(BaseServiceError):
pass pass
class TenantNotFound(BaseServiceError): class TenantNotFoundError(BaseServiceError):
pass pass

View File

@ -6,7 +6,7 @@ import click
from celery import shared_task from celery import shared_task
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from core.indexing_runner import DocumentIsPausedException, IndexingRunner from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.extractor.notion_extractor import NotionExtractor from core.rag.extractor.notion_extractor import NotionExtractor
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db from extensions.ext_database import db
@ -106,7 +106,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
logging.info( logging.info(
click.style("update document: {} latency: {}".format(document.id, end_at - start_at), fg="green") click.style("update document: {} latency: {}".format(document.id, end_at - start_at), fg="green")
) )
except DocumentIsPausedException as ex: except DocumentIsPausedError as ex:
logging.info(click.style(str(ex), fg="yellow")) logging.info(click.style(str(ex), fg="yellow"))
except Exception: except Exception:
pass pass

View File

@ -6,7 +6,7 @@ import click
from celery import shared_task from celery import shared_task
from configs import dify_config from configs import dify_config
from core.indexing_runner import DocumentIsPausedException, IndexingRunner from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from extensions.ext_database import db from extensions.ext_database import db
from models.dataset import Dataset, Document from models.dataset import Dataset, Document
from services.feature_service import FeatureService from services.feature_service import FeatureService
@ -72,7 +72,7 @@ def document_indexing_task(dataset_id: str, document_ids: list):
indexing_runner.run(documents) indexing_runner.run(documents)
end_at = time.perf_counter() end_at = time.perf_counter()
logging.info(click.style("Processed dataset: {} latency: {}".format(dataset_id, end_at - start_at), fg="green")) logging.info(click.style("Processed dataset: {} latency: {}".format(dataset_id, end_at - start_at), fg="green"))
except DocumentIsPausedException as ex: except DocumentIsPausedError as ex:
logging.info(click.style(str(ex), fg="yellow")) logging.info(click.style(str(ex), fg="yellow"))
except Exception: except Exception:
pass pass

View File

@ -6,7 +6,7 @@ import click
from celery import shared_task from celery import shared_task
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from core.indexing_runner import DocumentIsPausedException, IndexingRunner from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment from models.dataset import Dataset, Document, DocumentSegment
@ -69,7 +69,7 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
indexing_runner.run([document]) indexing_runner.run([document])
end_at = time.perf_counter() end_at = time.perf_counter()
logging.info(click.style("update document: {} latency: {}".format(document.id, end_at - start_at), fg="green")) logging.info(click.style("update document: {} latency: {}".format(document.id, end_at - start_at), fg="green"))
except DocumentIsPausedException as ex: except DocumentIsPausedError as ex:
logging.info(click.style(str(ex), fg="yellow")) logging.info(click.style(str(ex), fg="yellow"))
except Exception: except Exception:
pass pass

View File

@ -6,7 +6,7 @@ import click
from celery import shared_task from celery import shared_task
from configs import dify_config from configs import dify_config
from core.indexing_runner import DocumentIsPausedException, IndexingRunner from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment from models.dataset import Dataset, Document, DocumentSegment
@ -88,7 +88,7 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list):
indexing_runner.run(documents) indexing_runner.run(documents)
end_at = time.perf_counter() end_at = time.perf_counter()
logging.info(click.style("Processed dataset: {} latency: {}".format(dataset_id, end_at - start_at), fg="green")) logging.info(click.style("Processed dataset: {} latency: {}".format(dataset_id, end_at - start_at), fg="green"))
except DocumentIsPausedException as ex: except DocumentIsPausedError as ex:
logging.info(click.style(str(ex), fg="yellow")) logging.info(click.style(str(ex), fg="yellow"))
except Exception: except Exception:
pass pass

View File

@ -5,7 +5,7 @@ import click
from celery import shared_task from celery import shared_task
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from core.indexing_runner import DocumentIsPausedException, IndexingRunner from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from extensions.ext_database import db from extensions.ext_database import db
from models.dataset import Document from models.dataset import Document
@ -39,7 +39,7 @@ def recover_document_indexing_task(dataset_id: str, document_id: str):
logging.info( logging.info(
click.style("Processed document: {} latency: {}".format(document.id, end_at - start_at), fg="green") click.style("Processed document: {} latency: {}".format(document.id, end_at - start_at), fg="green")
) )
except DocumentIsPausedException as ex: except DocumentIsPausedError as ex:
logging.info(click.style(str(ex), fg="yellow")) logging.info(click.style(str(ex), fg="yellow"))
except Exception: except Exception:
pass pass

View File

@ -70,6 +70,7 @@ class MockTEIClass:
}, },
} }
@staticmethod
def invoke_rerank(server_url: str, query: str, texts: list[str]) -> list[dict]: def invoke_rerank(server_url: str, query: str, texts: list[str]) -> list[dict]:
# Example response: # Example response:
# [ # [

View File

@ -7,6 +7,7 @@ from _pytest.monkeypatch import MonkeyPatch
class MockedHttp: class MockedHttp:
@staticmethod
def httpx_request( def httpx_request(
method: Literal["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD"], url: str, **kwargs method: Literal["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD"], url: str, **kwargs
) -> httpx.Response: ) -> httpx.Response:

View File

@ -13,7 +13,7 @@ from xinference_client.types import Embedding
class MockTcvectordbClass: class MockTcvectordbClass:
def VectorDBClient( def mock_vector_db_client(
self, self,
url=None, url=None,
username="", username="",
@ -110,7 +110,7 @@ MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
@pytest.fixture @pytest.fixture
def setup_tcvectordb_mock(request, monkeypatch: MonkeyPatch): def setup_tcvectordb_mock(request, monkeypatch: MonkeyPatch):
if MOCK: if MOCK:
monkeypatch.setattr(VectorDBClient, "__init__", MockTcvectordbClass.VectorDBClient) monkeypatch.setattr(VectorDBClient, "__init__", MockTcvectordbClass.mock_vector_db_client)
monkeypatch.setattr(VectorDBClient, "list_databases", MockTcvectordbClass.list_databases) monkeypatch.setattr(VectorDBClient, "list_databases", MockTcvectordbClass.list_databases)
monkeypatch.setattr(Database, "collection", MockTcvectordbClass.describe_collection) monkeypatch.setattr(Database, "collection", MockTcvectordbClass.describe_collection)
monkeypatch.setattr(Database, "list_collections", MockTcvectordbClass.list_collections) monkeypatch.setattr(Database, "list_collections", MockTcvectordbClass.list_collections)

View File

@ -10,6 +10,7 @@ MOCK = os.getenv("MOCK_SWITCH", "false") == "true"
class MockedHttp: class MockedHttp:
@staticmethod
def httpx_request( def httpx_request(
method: Literal["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD"], url: str, **kwargs method: Literal["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD"], url: str, **kwargs
) -> httpx.Response: ) -> httpx.Response:

View File

@ -1,11 +1,11 @@
import pytest import pytest
from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor
CODE_LANGUAGE = "unsupported_language" CODE_LANGUAGE = "unsupported_language"
def test_unsupported_with_code_template(): def test_unsupported_with_code_template():
with pytest.raises(CodeExecutionException) as e: with pytest.raises(CodeExecutionError) as e:
CodeExecutor.execute_workflow_code_template(language=CODE_LANGUAGE, code="", inputs={}) CodeExecutor.execute_workflow_code_template(language=CODE_LANGUAGE, code="", inputs={})
assert str(e.value) == f"Unsupported language {CODE_LANGUAGE}" assert str(e.value) == f"Unsupported language {CODE_LANGUAGE}"