mirror of
https://gitee.com/dify_ai/dify.git
synced 2024-11-30 02:08:37 +08:00
chore: apply pep8-naming rules for naming convention (#8261)
This commit is contained in:
parent
53f37a6704
commit
292220c596
@ -20,7 +20,7 @@ from fields.conversation_fields import (
|
||||
conversation_pagination_fields,
|
||||
conversation_with_summary_pagination_fields,
|
||||
)
|
||||
from libs.helper import datetime_string
|
||||
from libs.helper import DatetimeString
|
||||
from libs.login import login_required
|
||||
from models.model import AppMode, Conversation, EndUser, Message, MessageAnnotation
|
||||
|
||||
@ -36,8 +36,8 @@ class CompletionConversationApi(Resource):
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("keyword", type=str, location="args")
|
||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument(
|
||||
"annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args"
|
||||
)
|
||||
@ -143,8 +143,8 @@ class ChatConversationApi(Resource):
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("keyword", type=str, location="args")
|
||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument(
|
||||
"annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args"
|
||||
)
|
||||
|
@ -11,7 +11,7 @@ from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import datetime_string
|
||||
from libs.helper import DatetimeString
|
||||
from libs.login import login_required
|
||||
from models.model import AppMode
|
||||
|
||||
@ -25,8 +25,8 @@ class DailyMessageStatistic(Resource):
|
||||
account = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """
|
||||
@ -79,8 +79,8 @@ class DailyConversationStatistic(Resource):
|
||||
account = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """
|
||||
@ -133,8 +133,8 @@ class DailyTerminalsStatistic(Resource):
|
||||
account = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """
|
||||
@ -187,8 +187,8 @@ class DailyTokenCostStatistic(Resource):
|
||||
account = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """
|
||||
@ -245,8 +245,8 @@ class AverageSessionInteractionStatistic(Resource):
|
||||
account = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """SELECT date(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
@ -307,8 +307,8 @@ class UserSatisfactionRateStatistic(Resource):
|
||||
account = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """
|
||||
@ -369,8 +369,8 @@ class AverageResponseTimeStatistic(Resource):
|
||||
account = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """
|
||||
@ -425,8 +425,8 @@ class TokensPerSecondStatistic(Resource):
|
||||
account = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
|
||||
|
@ -11,7 +11,7 @@ from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import datetime_string
|
||||
from libs.helper import DatetimeString
|
||||
from libs.login import login_required
|
||||
from models.model import AppMode
|
||||
from models.workflow import WorkflowRunTriggeredFrom
|
||||
@ -26,8 +26,8 @@ class WorkflowDailyRunsStatistic(Resource):
|
||||
account = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """
|
||||
@ -86,8 +86,8 @@ class WorkflowDailyTerminalsStatistic(Resource):
|
||||
account = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """
|
||||
@ -146,8 +146,8 @@ class WorkflowDailyTokenCostStatistic(Resource):
|
||||
account = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """
|
||||
@ -213,8 +213,8 @@ class WorkflowAverageAppInteractionStatistic(Resource):
|
||||
account = current_user
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("start", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
sql_query = """
|
||||
|
@ -8,7 +8,7 @@ from constants.languages import supported_language
|
||||
from controllers.console import api
|
||||
from controllers.console.error import AlreadyActivateError
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import email, str_len, timezone
|
||||
from libs.helper import StrLen, email, timezone
|
||||
from libs.password import hash_password, valid_password
|
||||
from models.account import AccountStatus
|
||||
from services.account_service import RegisterService
|
||||
@ -37,7 +37,7 @@ class ActivateApi(Resource):
|
||||
parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="json")
|
||||
parser.add_argument("email", type=email, required=False, nullable=True, location="json")
|
||||
parser.add_argument("token", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("name", type=str_len(30), required=True, nullable=False, location="json")
|
||||
parser.add_argument("name", type=StrLen(30), required=True, nullable=False, location="json")
|
||||
parser.add_argument("password", type=valid_password, required=True, nullable=False, location="json")
|
||||
parser.add_argument(
|
||||
"interface_language", type=supported_language, required=True, nullable=False, location="json"
|
||||
|
@ -4,7 +4,7 @@ from flask import session
|
||||
from flask_restful import Resource, reqparse
|
||||
|
||||
from configs import dify_config
|
||||
from libs.helper import str_len
|
||||
from libs.helper import StrLen
|
||||
from models.model import DifySetup
|
||||
from services.account_service import TenantService
|
||||
|
||||
@ -28,7 +28,7 @@ class InitValidateAPI(Resource):
|
||||
raise AlreadySetupError()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("password", type=str_len(30), required=True, location="json")
|
||||
parser.add_argument("password", type=StrLen(30), required=True, location="json")
|
||||
input_password = parser.parse_args()["password"]
|
||||
|
||||
if input_password != os.environ.get("INIT_PASSWORD"):
|
||||
|
@ -4,7 +4,7 @@ from flask import request
|
||||
from flask_restful import Resource, reqparse
|
||||
|
||||
from configs import dify_config
|
||||
from libs.helper import email, get_remote_ip, str_len
|
||||
from libs.helper import StrLen, email, get_remote_ip
|
||||
from libs.password import valid_password
|
||||
from models.model import DifySetup
|
||||
from services.account_service import RegisterService, TenantService
|
||||
@ -40,7 +40,7 @@ class SetupApi(Resource):
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("email", type=email, required=True, location="json")
|
||||
parser.add_argument("name", type=str_len(30), required=True, location="json")
|
||||
parser.add_argument("name", type=StrLen(30), required=True, location="json")
|
||||
parser.add_argument("password", type=valid_password, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
|
@ -15,7 +15,7 @@ from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
|
||||
from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner
|
||||
from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter
|
||||
from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
|
||||
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
|
||||
@ -293,7 +293,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
)
|
||||
|
||||
runner.run()
|
||||
except GenerateTaskStoppedException:
|
||||
except GenerateTaskStoppedError:
|
||||
pass
|
||||
except InvokeAuthorizationError:
|
||||
queue_manager.publish_error(
|
||||
@ -349,7 +349,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
return generate_task_pipeline.process()
|
||||
except ValueError as e:
|
||||
if e.args[0] == "I/O operation on closed file.": # ignore this error
|
||||
raise GenerateTaskStoppedException()
|
||||
raise GenerateTaskStoppedError()
|
||||
else:
|
||||
logger.exception(e)
|
||||
raise e
|
||||
|
@ -21,7 +21,7 @@ class AudioTrunk:
|
||||
self.status = status
|
||||
|
||||
|
||||
def _invoiceTTS(text_content: str, model_instance, tenant_id: str, voice: str):
|
||||
def _invoice_tts(text_content: str, model_instance, tenant_id: str, voice: str):
|
||||
if not text_content or text_content.isspace():
|
||||
return
|
||||
return model_instance.invoke_tts(
|
||||
@ -81,7 +81,7 @@ class AppGeneratorTTSPublisher:
|
||||
if message is None:
|
||||
if self.msg_text and len(self.msg_text.strip()) > 0:
|
||||
futures_result = self.executor.submit(
|
||||
_invoiceTTS, self.msg_text, self.model_instance, self.tenant_id, self.voice
|
||||
_invoice_tts, self.msg_text, self.model_instance, self.tenant_id, self.voice
|
||||
)
|
||||
future_queue.put(futures_result)
|
||||
break
|
||||
@ -97,7 +97,7 @@ class AppGeneratorTTSPublisher:
|
||||
self.MAX_SENTENCE += 1
|
||||
text_content = "".join(sentence_arr)
|
||||
futures_result = self.executor.submit(
|
||||
_invoiceTTS, text_content, self.model_instance, self.tenant_id, self.voice
|
||||
_invoice_tts, text_content, self.model_instance, self.tenant_id, self.voice
|
||||
)
|
||||
future_queue.put(futures_result)
|
||||
if text_tmp:
|
||||
@ -110,7 +110,7 @@ class AppGeneratorTTSPublisher:
|
||||
break
|
||||
future_queue.put(None)
|
||||
|
||||
def checkAndGetAudio(self) -> AudioTrunk | None:
|
||||
def check_and_get_audio(self) -> AudioTrunk | None:
|
||||
try:
|
||||
if self._last_audio_event and self._last_audio_event.status == "finish":
|
||||
if self.executor:
|
||||
|
@ -19,7 +19,7 @@ from core.app.entities.queue_entities import (
|
||||
QueueStopEvent,
|
||||
QueueTextChunkEvent,
|
||||
)
|
||||
from core.moderation.base import ModerationException
|
||||
from core.moderation.base import ModerationError
|
||||
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
|
||||
from core.workflow.entities.node_entities import UserFrom
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
@ -217,7 +217,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
query=query,
|
||||
message_id=message_id,
|
||||
)
|
||||
except ModerationException as e:
|
||||
except ModerationError as e:
|
||||
self._complete_with_stream_output(text=str(e), stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION)
|
||||
return True
|
||||
|
||||
|
@ -179,10 +179,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
stream_response=stream_response,
|
||||
)
|
||||
|
||||
def _listenAudioMsg(self, publisher, task_id: str):
|
||||
def _listen_audio_msg(self, publisher, task_id: str):
|
||||
if not publisher:
|
||||
return None
|
||||
audio_msg: AudioTrunk = publisher.checkAndGetAudio()
|
||||
audio_msg: AudioTrunk = publisher.check_and_get_audio()
|
||||
if audio_msg and audio_msg.status != "finish":
|
||||
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
|
||||
return None
|
||||
@ -204,7 +204,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
|
||||
for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
|
||||
while True:
|
||||
audio_response = self._listenAudioMsg(tts_publisher, task_id=task_id)
|
||||
audio_response = self._listen_audio_msg(tts_publisher, task_id=task_id)
|
||||
if audio_response:
|
||||
yield audio_response
|
||||
else:
|
||||
@ -217,7 +217,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||
try:
|
||||
if not tts_publisher:
|
||||
break
|
||||
audio_trunk = tts_publisher.checkAndGetAudio()
|
||||
audio_trunk = tts_publisher.check_and_get_audio()
|
||||
if audio_trunk is None:
|
||||
# release cpu
|
||||
# sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)
|
||||
|
@ -13,7 +13,7 @@ from core.app.app_config.features.file_upload.manager import FileUploadConfigMan
|
||||
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
|
||||
from core.app.apps.agent_chat.app_runner import AgentChatAppRunner
|
||||
from core.app.apps.agent_chat.generate_response_converter import AgentChatAppGenerateResponseConverter
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
|
||||
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
||||
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom
|
||||
@ -205,7 +205,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
||||
conversation=conversation,
|
||||
message=message,
|
||||
)
|
||||
except GenerateTaskStoppedException:
|
||||
except GenerateTaskStoppedError:
|
||||
pass
|
||||
except InvokeAuthorizationError:
|
||||
queue_manager.publish_error(
|
||||
|
@ -15,7 +15,7 @@ from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.llm_entities import LLMMode, LLMUsage
|
||||
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.moderation.base import ModerationException
|
||||
from core.moderation.base import ModerationError
|
||||
from core.tools.entities.tool_entities import ToolRuntimeVariablePool
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, Conversation, Message, MessageAgentThought
|
||||
@ -103,7 +103,7 @@ class AgentChatAppRunner(AppRunner):
|
||||
query=query,
|
||||
message_id=message.id,
|
||||
)
|
||||
except ModerationException as e:
|
||||
except ModerationError as e:
|
||||
self.direct_output(
|
||||
queue_manager=queue_manager,
|
||||
app_generate_entity=application_generate_entity,
|
||||
|
@ -171,5 +171,5 @@ class AppQueueManager:
|
||||
)
|
||||
|
||||
|
||||
class GenerateTaskStoppedException(Exception):
|
||||
class GenerateTaskStoppedError(Exception):
|
||||
pass
|
||||
|
@ -10,7 +10,7 @@ from pydantic import ValidationError
|
||||
|
||||
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
|
||||
from core.app.apps.chat.app_config_manager import ChatAppConfigManager
|
||||
from core.app.apps.chat.app_runner import ChatAppRunner
|
||||
from core.app.apps.chat.generate_response_converter import ChatAppGenerateResponseConverter
|
||||
@ -205,7 +205,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
||||
conversation=conversation,
|
||||
message=message,
|
||||
)
|
||||
except GenerateTaskStoppedException:
|
||||
except GenerateTaskStoppedError:
|
||||
pass
|
||||
except InvokeAuthorizationError:
|
||||
queue_manager.publish_error(
|
||||
|
@ -11,7 +11,7 @@ from core.app.entities.queue_entities import QueueAnnotationReplyEvent
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.moderation.base import ModerationException
|
||||
from core.moderation.base import ModerationError
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, Conversation, Message
|
||||
@ -98,7 +98,7 @@ class ChatAppRunner(AppRunner):
|
||||
query=query,
|
||||
message_id=message.id,
|
||||
)
|
||||
except ModerationException as e:
|
||||
except ModerationError as e:
|
||||
self.direct_output(
|
||||
queue_manager=queue_manager,
|
||||
app_generate_entity=application_generate_entity,
|
||||
|
@ -10,7 +10,7 @@ from pydantic import ValidationError
|
||||
|
||||
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
|
||||
from core.app.apps.completion.app_config_manager import CompletionAppConfigManager
|
||||
from core.app.apps.completion.app_runner import CompletionAppRunner
|
||||
from core.app.apps.completion.generate_response_converter import CompletionAppGenerateResponseConverter
|
||||
@ -185,7 +185,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||
queue_manager=queue_manager,
|
||||
message=message,
|
||||
)
|
||||
except GenerateTaskStoppedException:
|
||||
except GenerateTaskStoppedError:
|
||||
pass
|
||||
except InvokeAuthorizationError:
|
||||
queue_manager.publish_error(
|
||||
|
@ -9,7 +9,7 @@ from core.app.entities.app_invoke_entities import (
|
||||
)
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
from core.model_manager import ModelInstance
|
||||
from core.moderation.base import ModerationException
|
||||
from core.moderation.base import ModerationError
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, Message
|
||||
@ -79,7 +79,7 @@ class CompletionAppRunner(AppRunner):
|
||||
query=query,
|
||||
message_id=message.id,
|
||||
)
|
||||
except ModerationException as e:
|
||||
except ModerationError as e:
|
||||
self.direct_output(
|
||||
queue_manager=queue_manager,
|
||||
app_generate_entity=application_generate_entity,
|
||||
|
@ -8,7 +8,7 @@ from sqlalchemy import and_
|
||||
|
||||
from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom
|
||||
from core.app.apps.base_app_generator import BaseAppGenerator
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
AdvancedChatAppGenerateEntity,
|
||||
AgentChatAppGenerateEntity,
|
||||
@ -77,7 +77,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
||||
return generate_task_pipeline.process()
|
||||
except ValueError as e:
|
||||
if e.args[0] == "I/O operation on closed file.": # ignore this error
|
||||
raise GenerateTaskStoppedException()
|
||||
raise GenerateTaskStoppedError()
|
||||
else:
|
||||
logger.exception(e)
|
||||
raise e
|
||||
|
@ -1,4 +1,4 @@
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import (
|
||||
AppQueueEvent,
|
||||
@ -53,4 +53,4 @@ class MessageBasedAppQueueManager(AppQueueManager):
|
||||
self.stop_listen()
|
||||
|
||||
if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
|
||||
raise GenerateTaskStoppedException()
|
||||
raise GenerateTaskStoppedError()
|
||||
|
@ -12,7 +12,7 @@ from pydantic import ValidationError
|
||||
import contexts
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.app.apps.base_app_generator import BaseAppGenerator
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
|
||||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
||||
from core.app.apps.workflow.app_queue_manager import WorkflowAppQueueManager
|
||||
from core.app.apps.workflow.app_runner import WorkflowAppRunner
|
||||
@ -253,7 +253,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
)
|
||||
|
||||
runner.run()
|
||||
except GenerateTaskStoppedException:
|
||||
except GenerateTaskStoppedError:
|
||||
pass
|
||||
except InvokeAuthorizationError:
|
||||
queue_manager.publish_error(
|
||||
@ -302,7 +302,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
return generate_task_pipeline.process()
|
||||
except ValueError as e:
|
||||
if e.args[0] == "I/O operation on closed file.": # ignore this error
|
||||
raise GenerateTaskStoppedException()
|
||||
raise GenerateTaskStoppedError()
|
||||
else:
|
||||
logger.exception(e)
|
||||
raise e
|
||||
|
@ -1,4 +1,4 @@
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedException, PublishFrom
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import (
|
||||
AppQueueEvent,
|
||||
@ -39,4 +39,4 @@ class WorkflowAppQueueManager(AppQueueManager):
|
||||
self.stop_listen()
|
||||
|
||||
if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
|
||||
raise GenerateTaskStoppedException()
|
||||
raise GenerateTaskStoppedError()
|
||||
|
@ -162,10 +162,10 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
|
||||
yield WorkflowAppStreamResponse(workflow_run_id=workflow_run_id, stream_response=stream_response)
|
||||
|
||||
def _listenAudioMsg(self, publisher, task_id: str):
|
||||
def _listen_audio_msg(self, publisher, task_id: str):
|
||||
if not publisher:
|
||||
return None
|
||||
audio_msg: AudioTrunk = publisher.checkAndGetAudio()
|
||||
audio_msg: AudioTrunk = publisher.check_and_get_audio()
|
||||
if audio_msg and audio_msg.status != "finish":
|
||||
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
|
||||
return None
|
||||
@ -187,7 +187,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
|
||||
for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
|
||||
while True:
|
||||
audio_response = self._listenAudioMsg(tts_publisher, task_id=task_id)
|
||||
audio_response = self._listen_audio_msg(tts_publisher, task_id=task_id)
|
||||
if audio_response:
|
||||
yield audio_response
|
||||
else:
|
||||
@ -199,7 +199,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||
try:
|
||||
if not tts_publisher:
|
||||
break
|
||||
audio_trunk = tts_publisher.checkAndGetAudio()
|
||||
audio_trunk = tts_publisher.check_and_get_audio()
|
||||
if audio_trunk is None:
|
||||
# release cpu
|
||||
# sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)
|
||||
|
@ -15,6 +15,7 @@ class Segment(BaseModel):
|
||||
value: Any
|
||||
|
||||
@field_validator("value_type")
|
||||
@classmethod
|
||||
def validate_value_type(cls, value):
|
||||
"""
|
||||
This validator checks if the provided value is equal to the default value of the 'value_type' field.
|
||||
|
@ -201,10 +201,10 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
||||
stream_response=stream_response,
|
||||
)
|
||||
|
||||
def _listenAudioMsg(self, publisher, task_id: str):
|
||||
def _listen_audio_msg(self, publisher, task_id: str):
|
||||
if publisher is None:
|
||||
return None
|
||||
audio_msg: AudioTrunk = publisher.checkAndGetAudio()
|
||||
audio_msg: AudioTrunk = publisher.check_and_get_audio()
|
||||
if audio_msg and audio_msg.status != "finish":
|
||||
# audio_str = audio_msg.audio.decode('utf-8', errors='ignore')
|
||||
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
|
||||
@ -225,7 +225,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
||||
publisher = AppGeneratorTTSPublisher(tenant_id, text_to_speech_dict.get("voice", None))
|
||||
for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager):
|
||||
while True:
|
||||
audio_response = self._listenAudioMsg(publisher, task_id)
|
||||
audio_response = self._listen_audio_msg(publisher, task_id)
|
||||
if audio_response:
|
||||
yield audio_response
|
||||
else:
|
||||
@ -237,7 +237,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
|
||||
while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT:
|
||||
if publisher is None:
|
||||
break
|
||||
audio = publisher.checkAndGetAudio()
|
||||
audio = publisher.check_and_get_audio()
|
||||
if audio is None:
|
||||
# release cpu
|
||||
# sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)
|
||||
|
@ -16,7 +16,7 @@ from core.helper.code_executor.template_transformer import TemplateTransformer
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CodeExecutionException(Exception):
|
||||
class CodeExecutionError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
@ -86,15 +86,15 @@ class CodeExecutor:
|
||||
),
|
||||
)
|
||||
if response.status_code == 503:
|
||||
raise CodeExecutionException("Code execution service is unavailable")
|
||||
raise CodeExecutionError("Code execution service is unavailable")
|
||||
elif response.status_code != 200:
|
||||
raise Exception(
|
||||
f"Failed to execute code, got status code {response.status_code}, please check if the sandbox service is running"
|
||||
)
|
||||
except CodeExecutionException as e:
|
||||
except CodeExecutionError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
raise CodeExecutionException(
|
||||
raise CodeExecutionError(
|
||||
"Failed to execute code, which is likely a network issue,"
|
||||
" please check if the sandbox service is running."
|
||||
f" ( Error: {str(e)} )"
|
||||
@ -103,15 +103,15 @@ class CodeExecutor:
|
||||
try:
|
||||
response = response.json()
|
||||
except:
|
||||
raise CodeExecutionException("Failed to parse response")
|
||||
raise CodeExecutionError("Failed to parse response")
|
||||
|
||||
if (code := response.get("code")) != 0:
|
||||
raise CodeExecutionException(f"Got error code: {code}. Got error msg: {response.get('message')}")
|
||||
raise CodeExecutionError(f"Got error code: {code}. Got error msg: {response.get('message')}")
|
||||
|
||||
response = CodeExecutionResponse(**response)
|
||||
|
||||
if response.data.error:
|
||||
raise CodeExecutionException(response.data.error)
|
||||
raise CodeExecutionError(response.data.error)
|
||||
|
||||
return response.data.stdout or ""
|
||||
|
||||
@ -126,13 +126,13 @@ class CodeExecutor:
|
||||
"""
|
||||
template_transformer = cls.code_template_transformers.get(language)
|
||||
if not template_transformer:
|
||||
raise CodeExecutionException(f"Unsupported language {language}")
|
||||
raise CodeExecutionError(f"Unsupported language {language}")
|
||||
|
||||
runner, preload = template_transformer.transform_caller(code, inputs)
|
||||
|
||||
try:
|
||||
response = cls.execute_code(language, preload, runner)
|
||||
except CodeExecutionException as e:
|
||||
except CodeExecutionError as e:
|
||||
raise e
|
||||
|
||||
return template_transformer.transform_response(response)
|
||||
|
@ -78,8 +78,8 @@ class IndexingRunner:
|
||||
dataset_document=dataset_document,
|
||||
documents=documents,
|
||||
)
|
||||
except DocumentIsPausedException:
|
||||
raise DocumentIsPausedException("Document paused, document id: {}".format(dataset_document.id))
|
||||
except DocumentIsPausedError:
|
||||
raise DocumentIsPausedError("Document paused, document id: {}".format(dataset_document.id))
|
||||
except ProviderTokenNotInitError as e:
|
||||
dataset_document.indexing_status = "error"
|
||||
dataset_document.error = str(e.description)
|
||||
@ -134,8 +134,8 @@ class IndexingRunner:
|
||||
self._load(
|
||||
index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents
|
||||
)
|
||||
except DocumentIsPausedException:
|
||||
raise DocumentIsPausedException("Document paused, document id: {}".format(dataset_document.id))
|
||||
except DocumentIsPausedError:
|
||||
raise DocumentIsPausedError("Document paused, document id: {}".format(dataset_document.id))
|
||||
except ProviderTokenNotInitError as e:
|
||||
dataset_document.indexing_status = "error"
|
||||
dataset_document.error = str(e.description)
|
||||
@ -192,8 +192,8 @@ class IndexingRunner:
|
||||
self._load(
|
||||
index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents
|
||||
)
|
||||
except DocumentIsPausedException:
|
||||
raise DocumentIsPausedException("Document paused, document id: {}".format(dataset_document.id))
|
||||
except DocumentIsPausedError:
|
||||
raise DocumentIsPausedError("Document paused, document id: {}".format(dataset_document.id))
|
||||
except ProviderTokenNotInitError as e:
|
||||
dataset_document.indexing_status = "error"
|
||||
dataset_document.error = str(e.description)
|
||||
@ -756,7 +756,7 @@ class IndexingRunner:
|
||||
indexing_cache_key = "document_{}_is_paused".format(document_id)
|
||||
result = redis_client.get(indexing_cache_key)
|
||||
if result:
|
||||
raise DocumentIsPausedException()
|
||||
raise DocumentIsPausedError()
|
||||
|
||||
@staticmethod
|
||||
def _update_document_index_status(
|
||||
@ -767,10 +767,10 @@ class IndexingRunner:
|
||||
"""
|
||||
count = DatasetDocument.query.filter_by(id=document_id, is_paused=True).count()
|
||||
if count > 0:
|
||||
raise DocumentIsPausedException()
|
||||
raise DocumentIsPausedError()
|
||||
document = DatasetDocument.query.filter_by(id=document_id).first()
|
||||
if not document:
|
||||
raise DocumentIsDeletedPausedException()
|
||||
raise DocumentIsDeletedPausedError()
|
||||
|
||||
update_params = {DatasetDocument.indexing_status: after_indexing_status}
|
||||
|
||||
@ -875,9 +875,9 @@ class IndexingRunner:
|
||||
pass
|
||||
|
||||
|
||||
class DocumentIsPausedException(Exception):
|
||||
class DocumentIsPausedError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class DocumentIsDeletedPausedException(Exception):
|
||||
class DocumentIsDeletedPausedError(Exception):
|
||||
pass
|
||||
|
@ -1,2 +1,2 @@
|
||||
class OutputParserException(Exception):
|
||||
class OutputParserError(Exception):
|
||||
pass
|
||||
|
@ -1,6 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
from core.llm_generator.output_parser.errors import OutputParserException
|
||||
from core.llm_generator.output_parser.errors import OutputParserError
|
||||
from core.llm_generator.prompts import (
|
||||
RULE_CONFIG_PARAMETER_GENERATE_TEMPLATE,
|
||||
RULE_CONFIG_PROMPT_GENERATE_TEMPLATE,
|
||||
@ -29,4 +29,4 @@ class RuleConfigGeneratorOutputParser:
|
||||
raise ValueError("Expected 'opening_statement' to be a str.")
|
||||
return parsed
|
||||
except Exception as e:
|
||||
raise OutputParserException(f"Parsing text\n{text}\n of rule config generator raised following error:\n{e}")
|
||||
raise OutputParserError(f"Parsing text\n{text}\n of rule config generator raised following error:\n{e}")
|
||||
|
@ -7,7 +7,7 @@ from requests import post
|
||||
from core.model_runtime.entities.message_entities import PromptMessageTool
|
||||
from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import (
|
||||
BadRequestError,
|
||||
InsufficientAccountBalance,
|
||||
InsufficientAccountBalanceError,
|
||||
InternalServerError,
|
||||
InvalidAPIKeyError,
|
||||
InvalidAuthenticationError,
|
||||
@ -124,7 +124,7 @@ class BaichuanModel:
|
||||
if err == "invalid_api_key":
|
||||
raise InvalidAPIKeyError(msg)
|
||||
elif err == "insufficient_quota":
|
||||
raise InsufficientAccountBalance(msg)
|
||||
raise InsufficientAccountBalanceError(msg)
|
||||
elif err == "invalid_authentication":
|
||||
raise InvalidAuthenticationError(msg)
|
||||
elif err == "invalid_request_error":
|
||||
|
@ -10,7 +10,7 @@ class RateLimitReachedError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class InsufficientAccountBalance(Exception):
|
||||
class InsufficientAccountBalanceError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
|
@ -29,7 +29,7 @@ from core.model_runtime.model_providers.baichuan.llm.baichuan_tokenizer import B
|
||||
from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo import BaichuanModel
|
||||
from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import (
|
||||
BadRequestError,
|
||||
InsufficientAccountBalance,
|
||||
InsufficientAccountBalanceError,
|
||||
InternalServerError,
|
||||
InvalidAPIKeyError,
|
||||
InvalidAuthenticationError,
|
||||
@ -289,7 +289,7 @@ class BaichuanLanguageModel(LargeLanguageModel):
|
||||
InvokeRateLimitError: [RateLimitReachedError],
|
||||
InvokeAuthorizationError: [
|
||||
InvalidAuthenticationError,
|
||||
InsufficientAccountBalance,
|
||||
InsufficientAccountBalanceError,
|
||||
InvalidAPIKeyError,
|
||||
],
|
||||
InvokeBadRequestError: [BadRequestError, KeyError],
|
||||
|
@ -19,7 +19,7 @@ from core.model_runtime.model_providers.__base.text_embedding_model import TextE
|
||||
from core.model_runtime.model_providers.baichuan.llm.baichuan_tokenizer import BaichuanTokenizer
|
||||
from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import (
|
||||
BadRequestError,
|
||||
InsufficientAccountBalance,
|
||||
InsufficientAccountBalanceError,
|
||||
InternalServerError,
|
||||
InvalidAPIKeyError,
|
||||
InvalidAuthenticationError,
|
||||
@ -109,7 +109,7 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel):
|
||||
if err == "invalid_api_key":
|
||||
raise InvalidAPIKeyError(msg)
|
||||
elif err == "insufficient_quota":
|
||||
raise InsufficientAccountBalance(msg)
|
||||
raise InsufficientAccountBalanceError(msg)
|
||||
elif err == "invalid_authentication":
|
||||
raise InvalidAuthenticationError(msg)
|
||||
elif err and "rate" in err:
|
||||
@ -166,7 +166,7 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel):
|
||||
InvokeRateLimitError: [RateLimitReachedError],
|
||||
InvokeAuthorizationError: [
|
||||
InvalidAuthenticationError,
|
||||
InsufficientAccountBalance,
|
||||
InsufficientAccountBalanceError,
|
||||
InvalidAPIKeyError,
|
||||
],
|
||||
InvokeBadRequestError: [BadRequestError, KeyError],
|
||||
|
@ -10,7 +10,7 @@ from core.model_runtime.errors.invoke import (
|
||||
)
|
||||
|
||||
|
||||
class _CommonOAI_API_Compat:
|
||||
class _CommonOaiApiCompat:
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
"""
|
||||
|
@ -35,13 +35,13 @@ from core.model_runtime.entities.model_entities import (
|
||||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOAI_API_Compat
|
||||
from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat
|
||||
from core.model_runtime.utils import helper
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
|
||||
class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel):
|
||||
"""
|
||||
Model class for OpenAI large language model.
|
||||
"""
|
||||
|
@ -6,10 +6,10 @@ import requests
|
||||
from core.model_runtime.errors.invoke import InvokeBadRequestError
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
|
||||
from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOAI_API_Compat
|
||||
from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat
|
||||
|
||||
|
||||
class OAICompatSpeech2TextModel(_CommonOAI_API_Compat, Speech2TextModel):
|
||||
class OAICompatSpeech2TextModel(_CommonOaiApiCompat, Speech2TextModel):
|
||||
"""
|
||||
Model class for OpenAI Compatible Speech to text model.
|
||||
"""
|
||||
|
@ -19,10 +19,10 @@ from core.model_runtime.entities.model_entities import (
|
||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOAI_API_Compat
|
||||
from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat
|
||||
|
||||
|
||||
class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel):
|
||||
class OAICompatEmbeddingModel(_CommonOaiApiCompat, TextEmbeddingModel):
|
||||
"""
|
||||
Model class for an OpenAI API-compatible text embedding model.
|
||||
"""
|
||||
|
@ -19,10 +19,10 @@ from core.model_runtime.entities.model_entities import (
|
||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOAI_API_Compat
|
||||
from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat
|
||||
|
||||
|
||||
class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel):
|
||||
class OAICompatEmbeddingModel(_CommonOaiApiCompat, TextEmbeddingModel):
|
||||
"""
|
||||
Model class for an OpenAI API-compatible text embedding model.
|
||||
"""
|
||||
|
@ -13,7 +13,7 @@ from core.model_runtime.entities.message_entities import (
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.model_providers.volcengine_maas.legacy.errors import wrap_error
|
||||
from core.model_runtime.model_providers.volcengine_maas.legacy.volc_sdk import ChatRole, MaasException, MaasService
|
||||
from core.model_runtime.model_providers.volcengine_maas.legacy.volc_sdk import ChatRole, MaasError, MaasService
|
||||
|
||||
|
||||
class MaaSClient(MaasService):
|
||||
@ -106,7 +106,7 @@ class MaaSClient(MaasService):
|
||||
def wrap_exception(fn: Callable[[], dict | Generator]) -> dict | Generator:
|
||||
try:
|
||||
resp = fn()
|
||||
except MaasException as e:
|
||||
except MaasError as e:
|
||||
raise wrap_error(e)
|
||||
|
||||
return resp
|
||||
|
@ -1,144 +1,144 @@
|
||||
from core.model_runtime.model_providers.volcengine_maas.legacy.volc_sdk import MaasException
|
||||
from core.model_runtime.model_providers.volcengine_maas.legacy.volc_sdk import MaasError
|
||||
|
||||
|
||||
class ClientSDKRequestError(MaasException):
|
||||
class ClientSDKRequestError(MaasError):
|
||||
pass
|
||||
|
||||
|
||||
class SignatureDoesNotMatch(MaasException):
|
||||
class SignatureDoesNotMatchError(MaasError):
|
||||
pass
|
||||
|
||||
|
||||
class RequestTimeout(MaasException):
|
||||
class RequestTimeoutError(MaasError):
|
||||
pass
|
||||
|
||||
|
||||
class ServiceConnectionTimeout(MaasException):
|
||||
class ServiceConnectionTimeoutError(MaasError):
|
||||
pass
|
||||
|
||||
|
||||
class MissingAuthenticationHeader(MaasException):
|
||||
class MissingAuthenticationHeaderError(MaasError):
|
||||
pass
|
||||
|
||||
|
||||
class AuthenticationHeaderIsInvalid(MaasException):
|
||||
class AuthenticationHeaderIsInvalidError(MaasError):
|
||||
pass
|
||||
|
||||
|
||||
class InternalServiceError(MaasException):
|
||||
class InternalServiceError(MaasError):
|
||||
pass
|
||||
|
||||
|
||||
class MissingParameter(MaasException):
|
||||
class MissingParameterError(MaasError):
|
||||
pass
|
||||
|
||||
|
||||
class InvalidParameter(MaasException):
|
||||
class InvalidParameterError(MaasError):
|
||||
pass
|
||||
|
||||
|
||||
class AuthenticationExpire(MaasException):
|
||||
class AuthenticationExpireError(MaasError):
|
||||
pass
|
||||
|
||||
|
||||
class EndpointIsInvalid(MaasException):
|
||||
class EndpointIsInvalidError(MaasError):
|
||||
pass
|
||||
|
||||
|
||||
class EndpointIsNotEnable(MaasException):
|
||||
class EndpointIsNotEnableError(MaasError):
|
||||
pass
|
||||
|
||||
|
||||
class ModelNotSupportStreamMode(MaasException):
|
||||
class ModelNotSupportStreamModeError(MaasError):
|
||||
pass
|
||||
|
||||
|
||||
class ReqTextExistRisk(MaasException):
|
||||
class ReqTextExistRiskError(MaasError):
|
||||
pass
|
||||
|
||||
|
||||
class RespTextExistRisk(MaasException):
|
||||
class RespTextExistRiskError(MaasError):
|
||||
pass
|
||||
|
||||
|
||||
class EndpointRateLimitExceeded(MaasException):
|
||||
class EndpointRateLimitExceededError(MaasError):
|
||||
pass
|
||||
|
||||
|
||||
class ServiceConnectionRefused(MaasException):
|
||||
class ServiceConnectionRefusedError(MaasError):
|
||||
pass
|
||||
|
||||
|
||||
class ServiceConnectionClosed(MaasException):
|
||||
class ServiceConnectionClosedError(MaasError):
|
||||
pass
|
||||
|
||||
|
||||
class UnauthorizedUserForEndpoint(MaasException):
|
||||
class UnauthorizedUserForEndpointError(MaasError):
|
||||
pass
|
||||
|
||||
|
||||
class InvalidEndpointWithNoURL(MaasException):
|
||||
class InvalidEndpointWithNoURLError(MaasError):
|
||||
pass
|
||||
|
||||
|
||||
class EndpointAccountRpmRateLimitExceeded(MaasException):
|
||||
class EndpointAccountRpmRateLimitExceededError(MaasError):
|
||||
pass
|
||||
|
||||
|
||||
class EndpointAccountTpmRateLimitExceeded(MaasException):
|
||||
class EndpointAccountTpmRateLimitExceededError(MaasError):
|
||||
pass
|
||||
|
||||
|
||||
class ServiceResourceWaitQueueFull(MaasException):
|
||||
class ServiceResourceWaitQueueFullError(MaasError):
|
||||
pass
|
||||
|
||||
|
||||
class EndpointIsPending(MaasException):
|
||||
class EndpointIsPendingError(MaasError):
|
||||
pass
|
||||
|
||||
|
||||
class ServiceNotOpen(MaasException):
|
||||
class ServiceNotOpenError(MaasError):
|
||||
pass
|
||||
|
||||
|
||||
AuthErrors = {
|
||||
"SignatureDoesNotMatch": SignatureDoesNotMatch,
|
||||
"MissingAuthenticationHeader": MissingAuthenticationHeader,
|
||||
"AuthenticationHeaderIsInvalid": AuthenticationHeaderIsInvalid,
|
||||
"AuthenticationExpire": AuthenticationExpire,
|
||||
"UnauthorizedUserForEndpoint": UnauthorizedUserForEndpoint,
|
||||
"SignatureDoesNotMatch": SignatureDoesNotMatchError,
|
||||
"MissingAuthenticationHeader": MissingAuthenticationHeaderError,
|
||||
"AuthenticationHeaderIsInvalid": AuthenticationHeaderIsInvalidError,
|
||||
"AuthenticationExpire": AuthenticationExpireError,
|
||||
"UnauthorizedUserForEndpoint": UnauthorizedUserForEndpointError,
|
||||
}
|
||||
|
||||
BadRequestErrors = {
|
||||
"MissingParameter": MissingParameter,
|
||||
"InvalidParameter": InvalidParameter,
|
||||
"EndpointIsInvalid": EndpointIsInvalid,
|
||||
"EndpointIsNotEnable": EndpointIsNotEnable,
|
||||
"ModelNotSupportStreamMode": ModelNotSupportStreamMode,
|
||||
"ReqTextExistRisk": ReqTextExistRisk,
|
||||
"RespTextExistRisk": RespTextExistRisk,
|
||||
"InvalidEndpointWithNoURL": InvalidEndpointWithNoURL,
|
||||
"ServiceNotOpen": ServiceNotOpen,
|
||||
"MissingParameter": MissingParameterError,
|
||||
"InvalidParameter": InvalidParameterError,
|
||||
"EndpointIsInvalid": EndpointIsInvalidError,
|
||||
"EndpointIsNotEnable": EndpointIsNotEnableError,
|
||||
"ModelNotSupportStreamMode": ModelNotSupportStreamModeError,
|
||||
"ReqTextExistRisk": ReqTextExistRiskError,
|
||||
"RespTextExistRisk": RespTextExistRiskError,
|
||||
"InvalidEndpointWithNoURL": InvalidEndpointWithNoURLError,
|
||||
"ServiceNotOpen": ServiceNotOpenError,
|
||||
}
|
||||
|
||||
RateLimitErrors = {
|
||||
"EndpointRateLimitExceeded": EndpointRateLimitExceeded,
|
||||
"EndpointAccountRpmRateLimitExceeded": EndpointAccountRpmRateLimitExceeded,
|
||||
"EndpointAccountTpmRateLimitExceeded": EndpointAccountTpmRateLimitExceeded,
|
||||
"EndpointRateLimitExceeded": EndpointRateLimitExceededError,
|
||||
"EndpointAccountRpmRateLimitExceeded": EndpointAccountRpmRateLimitExceededError,
|
||||
"EndpointAccountTpmRateLimitExceeded": EndpointAccountTpmRateLimitExceededError,
|
||||
}
|
||||
|
||||
ServerUnavailableErrors = {
|
||||
"InternalServiceError": InternalServiceError,
|
||||
"EndpointIsPending": EndpointIsPending,
|
||||
"ServiceResourceWaitQueueFull": ServiceResourceWaitQueueFull,
|
||||
"EndpointIsPending": EndpointIsPendingError,
|
||||
"ServiceResourceWaitQueueFull": ServiceResourceWaitQueueFullError,
|
||||
}
|
||||
|
||||
ConnectionErrors = {
|
||||
"ClientSDKRequestError": ClientSDKRequestError,
|
||||
"RequestTimeout": RequestTimeout,
|
||||
"ServiceConnectionTimeout": ServiceConnectionTimeout,
|
||||
"ServiceConnectionRefused": ServiceConnectionRefused,
|
||||
"ServiceConnectionClosed": ServiceConnectionClosed,
|
||||
"RequestTimeout": RequestTimeoutError,
|
||||
"ServiceConnectionTimeout": ServiceConnectionTimeoutError,
|
||||
"ServiceConnectionRefused": ServiceConnectionRefusedError,
|
||||
"ServiceConnectionClosed": ServiceConnectionClosedError,
|
||||
}
|
||||
|
||||
ErrorCodeMap = {
|
||||
@ -150,7 +150,7 @@ ErrorCodeMap = {
|
||||
}
|
||||
|
||||
|
||||
def wrap_error(e: MaasException) -> Exception:
|
||||
def wrap_error(e: MaasError) -> Exception:
|
||||
if ErrorCodeMap.get(e.code):
|
||||
return ErrorCodeMap.get(e.code)(e.code_n, e.code, e.message, e.req_id)
|
||||
return e
|
||||
|
@ -1,4 +1,4 @@
|
||||
from .common import ChatRole
|
||||
from .maas import MaasException, MaasService
|
||||
from .maas import MaasError, MaasService
|
||||
|
||||
__all__ = ["MaasService", "ChatRole", "MaasException"]
|
||||
__all__ = ["MaasService", "ChatRole", "MaasError"]
|
||||
|
@ -63,7 +63,7 @@ class MaasService(Service):
|
||||
raise
|
||||
|
||||
if res.error is not None and res.error.code_n != 0:
|
||||
raise MaasException(
|
||||
raise MaasError(
|
||||
res.error.code_n,
|
||||
res.error.code,
|
||||
res.error.message,
|
||||
@ -72,7 +72,7 @@ class MaasService(Service):
|
||||
yield res
|
||||
|
||||
return iter_fn()
|
||||
except MaasException:
|
||||
except MaasError:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise new_client_sdk_request_error(str(e))
|
||||
@ -94,7 +94,7 @@ class MaasService(Service):
|
||||
resp["req_id"] = req_id
|
||||
return resp
|
||||
|
||||
except MaasException as e:
|
||||
except MaasError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
raise new_client_sdk_request_error(str(e), req_id)
|
||||
@ -147,14 +147,14 @@ class MaasService(Service):
|
||||
raise new_client_sdk_request_error(raw, req_id)
|
||||
|
||||
if resp.error:
|
||||
raise MaasException(resp.error.code_n, resp.error.code, resp.error.message, req_id)
|
||||
raise MaasError(resp.error.code_n, resp.error.code, resp.error.message, req_id)
|
||||
else:
|
||||
raise new_client_sdk_request_error(resp, req_id)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
class MaasException(Exception):
|
||||
class MaasError(Exception):
|
||||
def __init__(self, code_n, code, message, req_id):
|
||||
self.code_n = code_n
|
||||
self.code = code
|
||||
@ -172,7 +172,7 @@ class MaasException(Exception):
|
||||
|
||||
|
||||
def new_client_sdk_request_error(raw, req_id=""):
|
||||
return MaasException(1709701, "ClientSDKRequestError", "MaaS SDK request error: {}".format(raw), req_id)
|
||||
return MaasError(1709701, "ClientSDKRequestError", "MaaS SDK request error: {}".format(raw), req_id)
|
||||
|
||||
|
||||
class BinaryResponseContent:
|
||||
@ -192,7 +192,7 @@ class BinaryResponseContent:
|
||||
|
||||
if len(error_bytes) > 0:
|
||||
resp = json_to_object(str(error_bytes, encoding="utf-8"), req_id=self.request_id)
|
||||
raise MaasException(resp.error.code_n, resp.error.code, resp.error.message, self.request_id)
|
||||
raise MaasError(resp.error.code_n, resp.error.code, resp.error.message, self.request_id)
|
||||
|
||||
def iter_bytes(self) -> Iterator[bytes]:
|
||||
yield from self.response
|
||||
|
@ -35,7 +35,7 @@ from core.model_runtime.model_providers.volcengine_maas.legacy.errors import (
|
||||
AuthErrors,
|
||||
BadRequestErrors,
|
||||
ConnectionErrors,
|
||||
MaasException,
|
||||
MaasError,
|
||||
RateLimitErrors,
|
||||
ServerUnavailableErrors,
|
||||
)
|
||||
@ -85,7 +85,7 @@ class VolcengineMaaSLargeLanguageModel(LargeLanguageModel):
|
||||
},
|
||||
[UserPromptMessage(content="ping\nAnswer: ")],
|
||||
)
|
||||
except MaasException as e:
|
||||
except MaasError as e:
|
||||
raise CredentialsValidateFailedError(e.message)
|
||||
|
||||
@staticmethod
|
||||
|
@ -28,7 +28,7 @@ from core.model_runtime.model_providers.volcengine_maas.legacy.errors import (
|
||||
AuthErrors,
|
||||
BadRequestErrors,
|
||||
ConnectionErrors,
|
||||
MaasException,
|
||||
MaasError,
|
||||
RateLimitErrors,
|
||||
ServerUnavailableErrors,
|
||||
)
|
||||
@ -111,7 +111,7 @@ class VolcengineMaaSTextEmbeddingModel(TextEmbeddingModel):
|
||||
def _validate_credentials_v2(self, model: str, credentials: dict) -> None:
|
||||
try:
|
||||
self._invoke(model=model, credentials=credentials, texts=["ping"])
|
||||
except MaasException as e:
|
||||
except MaasError as e:
|
||||
raise CredentialsValidateFailedError(e.message)
|
||||
|
||||
def _validate_credentials_v3(self, model: str, credentials: dict) -> None:
|
||||
|
@ -23,7 +23,7 @@ def invoke_error_mapping() -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
InvokeRateLimitError: [RateLimitReachedError],
|
||||
InvokeAuthorizationError: [
|
||||
InvalidAuthenticationError,
|
||||
InsufficientAccountBalance,
|
||||
InsufficientAccountBalanceError,
|
||||
InvalidAPIKeyError,
|
||||
],
|
||||
InvokeBadRequestError: [BadRequestError, KeyError],
|
||||
@ -42,7 +42,7 @@ class RateLimitReachedError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class InsufficientAccountBalance(Exception):
|
||||
class InsufficientAccountBalanceError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
|
@ -76,7 +76,7 @@ class Moderation(Extensible, ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def _validate_inputs_and_outputs_config(self, config: dict, is_preset_response_required: bool) -> None:
|
||||
def _validate_inputs_and_outputs_config(cls, config: dict, is_preset_response_required: bool) -> None:
|
||||
# inputs_config
|
||||
inputs_config = config.get("inputs_config")
|
||||
if not isinstance(inputs_config, dict):
|
||||
@ -111,5 +111,5 @@ class Moderation(Extensible, ABC):
|
||||
raise ValueError("outputs_config.preset_response must be less than 100 characters")
|
||||
|
||||
|
||||
class ModerationException(Exception):
|
||||
class ModerationError(Exception):
|
||||
pass
|
||||
|
@ -2,7 +2,7 @@ import logging
|
||||
from typing import Optional
|
||||
|
||||
from core.app.app_config.entities import AppConfig
|
||||
from core.moderation.base import ModerationAction, ModerationException
|
||||
from core.moderation.base import ModerationAction, ModerationError
|
||||
from core.moderation.factory import ModerationFactory
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||
@ -61,7 +61,7 @@ class InputModeration:
|
||||
return False, inputs, query
|
||||
|
||||
if moderation_result.action == ModerationAction.DIRECT_OUTPUT:
|
||||
raise ModerationException(moderation_result.preset_response)
|
||||
raise ModerationError(moderation_result.preset_response)
|
||||
elif moderation_result.action == ModerationAction.OVERRIDDEN:
|
||||
inputs = moderation_result.inputs
|
||||
query = moderation_result.query
|
||||
|
@ -26,6 +26,7 @@ class LangfuseConfig(BaseTracingConfig):
|
||||
host: str = "https://api.langfuse.com"
|
||||
|
||||
@field_validator("host")
|
||||
@classmethod
|
||||
def set_value(cls, v, info: ValidationInfo):
|
||||
if v is None or v == "":
|
||||
v = "https://api.langfuse.com"
|
||||
@ -45,6 +46,7 @@ class LangSmithConfig(BaseTracingConfig):
|
||||
endpoint: str = "https://api.smith.langchain.com"
|
||||
|
||||
@field_validator("endpoint")
|
||||
@classmethod
|
||||
def set_value(cls, v, info: ValidationInfo):
|
||||
if v is None or v == "":
|
||||
v = "https://api.smith.langchain.com"
|
||||
|
@ -15,6 +15,7 @@ class BaseTraceInfo(BaseModel):
|
||||
metadata: dict[str, Any]
|
||||
|
||||
@field_validator("inputs", "outputs")
|
||||
@classmethod
|
||||
def ensure_type(cls, v):
|
||||
if v is None:
|
||||
return None
|
||||
|
@ -101,6 +101,7 @@ class LangfuseTrace(BaseModel):
|
||||
)
|
||||
|
||||
@field_validator("input", "output")
|
||||
@classmethod
|
||||
def ensure_dict(cls, v, info: ValidationInfo):
|
||||
field_name = info.field_name
|
||||
return validate_input_output(v, field_name)
|
||||
@ -171,6 +172,7 @@ class LangfuseSpan(BaseModel):
|
||||
)
|
||||
|
||||
@field_validator("input", "output")
|
||||
@classmethod
|
||||
def ensure_dict(cls, v, info: ValidationInfo):
|
||||
field_name = info.field_name
|
||||
return validate_input_output(v, field_name)
|
||||
@ -196,6 +198,7 @@ class GenerationUsage(BaseModel):
|
||||
totalCost: Optional[float] = None
|
||||
|
||||
@field_validator("input", "output")
|
||||
@classmethod
|
||||
def ensure_dict(cls, v, info: ValidationInfo):
|
||||
field_name = info.field_name
|
||||
return validate_input_output(v, field_name)
|
||||
@ -273,6 +276,7 @@ class LangfuseGeneration(BaseModel):
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
@field_validator("input", "output")
|
||||
@classmethod
|
||||
def ensure_dict(cls, v, info: ValidationInfo):
|
||||
field_name = info.field_name
|
||||
return validate_input_output(v, field_name)
|
||||
|
@ -51,6 +51,7 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel):
|
||||
output_attachments: Optional[dict[str, Any]] = Field(None, description="Output attachments of the run")
|
||||
|
||||
@field_validator("inputs", "outputs")
|
||||
@classmethod
|
||||
def ensure_dict(cls, v, info: ValidationInfo):
|
||||
field_name = info.field_name
|
||||
values = info.data
|
||||
@ -115,6 +116,7 @@ class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel):
|
||||
return v
|
||||
return v
|
||||
|
||||
@classmethod
|
||||
@field_validator("start_time", "end_time")
|
||||
def format_time(cls, v, info: ValidationInfo):
|
||||
if not isinstance(v, datetime):
|
||||
|
@ -27,6 +27,7 @@ class ElasticSearchConfig(BaseModel):
|
||||
password: str
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
if not values["host"]:
|
||||
raise ValueError("config HOST is required")
|
||||
|
@ -28,6 +28,7 @@ class MilvusConfig(BaseModel):
|
||||
database: str = "default"
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
if not values.get("uri"):
|
||||
raise ValueError("config MILVUS_URI is required")
|
||||
|
@ -29,6 +29,7 @@ class OpenSearchConfig(BaseModel):
|
||||
secure: bool = False
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
if not values.get("host"):
|
||||
raise ValueError("config OPENSEARCH_HOST is required")
|
||||
|
@ -32,6 +32,7 @@ class OracleVectorConfig(BaseModel):
|
||||
database: str
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
if not values["host"]:
|
||||
raise ValueError("config ORACLE_HOST is required")
|
||||
|
@ -32,6 +32,7 @@ class PgvectoRSConfig(BaseModel):
|
||||
database: str
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
if not values["host"]:
|
||||
raise ValueError("config PGVECTO_RS_HOST is required")
|
||||
|
@ -25,6 +25,7 @@ class PGVectorConfig(BaseModel):
|
||||
database: str
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
if not values["host"]:
|
||||
raise ValueError("config PGVECTOR_HOST is required")
|
||||
|
@ -34,6 +34,7 @@ class RelytConfig(BaseModel):
|
||||
database: str
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
if not values["host"]:
|
||||
raise ValueError("config RELYT_HOST is required")
|
||||
|
@ -29,6 +29,7 @@ class TiDBVectorConfig(BaseModel):
|
||||
program_name: str
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
if not values["host"]:
|
||||
raise ValueError("config TIDB_VECTOR_HOST is required")
|
||||
|
@ -23,6 +23,7 @@ class WeaviateConfig(BaseModel):
|
||||
batch_size: int = 100
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_config(cls, values: dict) -> dict:
|
||||
if not values["endpoint"]:
|
||||
raise ValueError("config WEAVIATE_ENDPOINT is required")
|
||||
|
@ -69,7 +69,7 @@ class DallE3Tool(BuiltinTool):
|
||||
self.create_blob_message(
|
||||
blob=b64decode(image.b64_json),
|
||||
meta={"mime_type": "image/png"},
|
||||
save_as=self.VARIABLE_KEY.IMAGE.value,
|
||||
save_as=self.VariableKey.IMAGE.value,
|
||||
)
|
||||
)
|
||||
result.append(self.create_text_message(f"\nGenerate image source to Seed ID: {seed_id}"))
|
||||
|
@ -59,7 +59,7 @@ class DallE2Tool(BuiltinTool):
|
||||
self.create_blob_message(
|
||||
blob=b64decode(image.b64_json),
|
||||
meta={"mime_type": "image/png"},
|
||||
save_as=self.VARIABLE_KEY.IMAGE.value,
|
||||
save_as=self.VariableKey.IMAGE.value,
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -66,7 +66,7 @@ class DallE3Tool(BuiltinTool):
|
||||
for image in response.data:
|
||||
mime_type, blob_image = DallE3Tool._decode_image(image.b64_json)
|
||||
blob_message = self.create_blob_message(
|
||||
blob=blob_image, meta={"mime_type": mime_type}, save_as=self.VARIABLE_KEY.IMAGE.value
|
||||
blob=blob_image, meta={"mime_type": mime_type}, save_as=self.VariableKey.IMAGE.value
|
||||
)
|
||||
result.append(blob_message)
|
||||
return result
|
||||
|
@ -34,7 +34,7 @@ class NovitaAiCreateTileTool(BuiltinTool):
|
||||
self.create_blob_message(
|
||||
blob=b64decode(client_result.image_file),
|
||||
meta={"mime_type": f"image/{client_result.image_type}"},
|
||||
save_as=self.VARIABLE_KEY.IMAGE.value,
|
||||
save_as=self.VariableKey.IMAGE.value,
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -40,7 +40,7 @@ class NovitaAiTxt2ImgTool(BuiltinTool, NovitaAiToolBase):
|
||||
self.create_blob_message(
|
||||
blob=b64decode(image_encoded),
|
||||
meta={"mime_type": f"image/{image.image_type}"},
|
||||
save_as=self.VARIABLE_KEY.IMAGE.value,
|
||||
save_as=self.VariableKey.IMAGE.value,
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -46,7 +46,7 @@ class QRCodeGeneratorTool(BuiltinTool):
|
||||
image = self._generate_qrcode(content, border, error_correction)
|
||||
image_bytes = self._image_to_byte_array(image)
|
||||
return self.create_blob_message(
|
||||
blob=image_bytes, meta={"mime_type": "image/png"}, save_as=self.VARIABLE_KEY.IMAGE.value
|
||||
blob=image_bytes, meta={"mime_type": "image/png"}, save_as=self.VariableKey.IMAGE.value
|
||||
)
|
||||
except Exception:
|
||||
logging.exception(f"Failed to generate QR code for content: {content}")
|
||||
|
@ -32,5 +32,5 @@ class FluxTool(BuiltinTool):
|
||||
res = response.json()
|
||||
result = [self.create_json_message(res)]
|
||||
for image in res.get("images", []):
|
||||
result.append(self.create_image_message(image=image.get("url"), save_as=self.VARIABLE_KEY.IMAGE.value))
|
||||
result.append(self.create_image_message(image=image.get("url"), save_as=self.VariableKey.IMAGE.value))
|
||||
return result
|
||||
|
@ -41,5 +41,5 @@ class StableDiffusionTool(BuiltinTool):
|
||||
res = response.json()
|
||||
result = [self.create_json_message(res)]
|
||||
for image in res.get("images", []):
|
||||
result.append(self.create_image_message(image=image.get("url"), save_as=self.VARIABLE_KEY.IMAGE.value))
|
||||
result.append(self.create_image_message(image=image.get("url"), save_as=self.VariableKey.IMAGE.value))
|
||||
return result
|
||||
|
@ -15,16 +15,16 @@ from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class AssembleHeaderException(Exception):
|
||||
class AssembleHeaderError(Exception):
|
||||
def __init__(self, msg):
|
||||
self.message = msg
|
||||
|
||||
|
||||
class Url:
|
||||
def __init__(this, host, path, schema):
|
||||
this.host = host
|
||||
this.path = path
|
||||
this.schema = schema
|
||||
def __init__(self, host, path, schema):
|
||||
self.host = host
|
||||
self.path = path
|
||||
self.schema = schema
|
||||
|
||||
|
||||
# calculate sha256 and encode to base64
|
||||
@ -41,7 +41,7 @@ def parse_url(request_url):
|
||||
schema = request_url[: stidx + 3]
|
||||
edidx = host.index("/")
|
||||
if edidx <= 0:
|
||||
raise AssembleHeaderException("invalid request url:" + request_url)
|
||||
raise AssembleHeaderError("invalid request url:" + request_url)
|
||||
path = host[edidx:]
|
||||
host = host[:edidx]
|
||||
u = Url(host, path, schema)
|
||||
@ -115,7 +115,7 @@ class SparkImgGeneratorTool(BuiltinTool):
|
||||
self.create_blob_message(
|
||||
blob=b64decode(image["base64_image"]),
|
||||
meta={"mime_type": "image/png"},
|
||||
save_as=self.VARIABLE_KEY.IMAGE.value,
|
||||
save_as=self.VariableKey.IMAGE.value,
|
||||
)
|
||||
)
|
||||
return result
|
||||
|
@ -52,5 +52,5 @@ class StableDiffusionTool(BuiltinTool, BaseStabilityAuthorization):
|
||||
raise Exception(response.text)
|
||||
|
||||
return self.create_blob_message(
|
||||
blob=response.content, meta={"mime_type": "image/png"}, save_as=self.VARIABLE_KEY.IMAGE.value
|
||||
blob=response.content, meta={"mime_type": "image/png"}, save_as=self.VariableKey.IMAGE.value
|
||||
)
|
||||
|
@ -260,7 +260,7 @@ class StableDiffusionTool(BuiltinTool):
|
||||
image = response.json()["images"][0]
|
||||
|
||||
return self.create_blob_message(
|
||||
blob=b64decode(image), meta={"mime_type": "image/png"}, save_as=self.VARIABLE_KEY.IMAGE.value
|
||||
blob=b64decode(image), meta={"mime_type": "image/png"}, save_as=self.VariableKey.IMAGE.value
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
@ -294,7 +294,7 @@ class StableDiffusionTool(BuiltinTool):
|
||||
image = response.json()["images"][0]
|
||||
|
||||
return self.create_blob_message(
|
||||
blob=b64decode(image), meta={"mime_type": "image/png"}, save_as=self.VARIABLE_KEY.IMAGE.value
|
||||
blob=b64decode(image), meta={"mime_type": "image/png"}, save_as=self.VariableKey.IMAGE.value
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
@ -45,5 +45,5 @@ class PoiSearchTool(BuiltinTool):
|
||||
).content
|
||||
|
||||
return self.create_blob_message(
|
||||
blob=result, meta={"mime_type": "image/png"}, save_as=self.VARIABLE_KEY.IMAGE.value
|
||||
blob=result, meta={"mime_type": "image/png"}, save_as=self.VariableKey.IMAGE.value
|
||||
)
|
||||
|
@ -32,7 +32,7 @@ class VectorizerTool(BuiltinTool):
|
||||
if image_id.startswith("__test_"):
|
||||
image_binary = b64decode(VECTORIZER_ICON_PNG)
|
||||
else:
|
||||
image_binary = self.get_variable_file(self.VARIABLE_KEY.IMAGE)
|
||||
image_binary = self.get_variable_file(self.VariableKey.IMAGE)
|
||||
if not image_binary:
|
||||
return self.create_text_message("Image not found, please request user to generate image firstly.")
|
||||
|
||||
|
@ -63,7 +63,7 @@ class Tool(BaseModel, ABC):
|
||||
def __init__(self, **data: Any):
|
||||
super().__init__(**data)
|
||||
|
||||
class VARIABLE_KEY(Enum):
|
||||
class VariableKey(Enum):
|
||||
IMAGE = "image"
|
||||
|
||||
def fork_tool_runtime(self, runtime: dict[str, Any]) -> "Tool":
|
||||
@ -142,7 +142,7 @@ class Tool(BaseModel, ABC):
|
||||
if not self.variables:
|
||||
return None
|
||||
|
||||
return self.get_variable(self.VARIABLE_KEY.IMAGE)
|
||||
return self.get_variable(self.VariableKey.IMAGE)
|
||||
|
||||
def get_variable_file(self, name: Union[str, Enum]) -> Optional[bytes]:
|
||||
"""
|
||||
@ -189,7 +189,7 @@ class Tool(BaseModel, ABC):
|
||||
result = []
|
||||
|
||||
for variable in self.variables.pool:
|
||||
if variable.name.startswith(self.VARIABLE_KEY.IMAGE.value):
|
||||
if variable.name.startswith(self.VariableKey.IMAGE.value):
|
||||
result.append(variable)
|
||||
|
||||
return result
|
||||
|
@ -8,7 +8,7 @@ from typing import Any, Optional
|
||||
|
||||
from flask import Flask, current_app
|
||||
|
||||
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException
|
||||
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.node_entities import (
|
||||
NodeRunMetadataKey,
|
||||
@ -669,7 +669,7 @@ class GraphEngine:
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
)
|
||||
except GenerateTaskStoppedException:
|
||||
except GenerateTaskStoppedError:
|
||||
# trigger node run failed event
|
||||
route_node_state.status = RouteNodeState.Status.FAILED
|
||||
route_node_state.failed_reason = "Workflow stopped."
|
||||
|
@ -2,7 +2,7 @@ from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Optional, Union, cast
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor, CodeLanguage
|
||||
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
|
||||
from core.helper.code_executor.code_node_provider import CodeNodeProvider
|
||||
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
|
||||
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
|
||||
@ -61,7 +61,7 @@ class CodeNode(BaseNode):
|
||||
|
||||
# Transform result
|
||||
result = self._transform_result(result, node_data.outputs)
|
||||
except (CodeExecutionException, ValueError) as e:
|
||||
except (CodeExecutionError, ValueError) as e:
|
||||
return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e))
|
||||
|
||||
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=result)
|
||||
|
@ -2,7 +2,7 @@ import os
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor, CodeLanguage
|
||||
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData
|
||||
@ -45,7 +45,7 @@ class TemplateTransformNode(BaseNode):
|
||||
result = CodeExecutor.execute_workflow_code_template(
|
||||
language=CodeLanguage.JINJA2, code=node_data.template, inputs=variables
|
||||
)
|
||||
except CodeExecutionException as e:
|
||||
except CodeExecutionError as e:
|
||||
return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e))
|
||||
|
||||
if len(result["result"]) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH:
|
||||
|
@ -6,7 +6,7 @@ from typing import Any, Optional, cast
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.app_config.entities import FileExtraConfig
|
||||
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException
|
||||
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.file.file_obj import FileTransferMethod, FileType, FileVar
|
||||
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
|
||||
@ -103,7 +103,7 @@ class WorkflowEntry:
|
||||
for callback in callbacks:
|
||||
callback.on_event(event=event)
|
||||
yield event
|
||||
except GenerateTaskStoppedException:
|
||||
except GenerateTaskStoppedError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.exception("Unknown Error when workflow entry running")
|
||||
|
@ -5,7 +5,7 @@ import time
|
||||
import click
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from core.indexing_runner import DocumentIsPausedException, IndexingRunner
|
||||
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
|
||||
from events.event_handlers.document_index_event import document_index_created
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Document
|
||||
@ -43,7 +43,7 @@ def handle(sender, **kwargs):
|
||||
indexing_runner.run(documents)
|
||||
end_at = time.perf_counter()
|
||||
logging.info(click.style("Processed dataset: {} latency: {}".format(dataset_id, end_at - start_at), fg="green"))
|
||||
except DocumentIsPausedException as ex:
|
||||
except DocumentIsPausedError as ex:
|
||||
logging.info(click.style(str(ex), fg="yellow"))
|
||||
except Exception:
|
||||
pass
|
||||
|
@ -5,7 +5,7 @@ from collections.abc import Generator
|
||||
from contextlib import closing
|
||||
|
||||
from flask import Flask
|
||||
from google.cloud import storage as GoogleCloudStorage
|
||||
from google.cloud import storage as google_cloud_storage
|
||||
|
||||
from extensions.storage.base_storage import BaseStorage
|
||||
|
||||
@ -23,9 +23,9 @@ class GoogleStorage(BaseStorage):
|
||||
service_account_json = base64.b64decode(service_account_json_str).decode("utf-8")
|
||||
# convert str to object
|
||||
service_account_obj = json.loads(service_account_json)
|
||||
self.client = GoogleCloudStorage.Client.from_service_account_info(service_account_obj)
|
||||
self.client = google_cloud_storage.Client.from_service_account_info(service_account_obj)
|
||||
else:
|
||||
self.client = GoogleCloudStorage.Client()
|
||||
self.client = google_cloud_storage.Client()
|
||||
|
||||
def save(self, filename, data):
|
||||
bucket = self.client.get_bucket(self.bucket_name)
|
||||
|
@ -31,7 +31,7 @@ from Crypto.Util.py3compat import _copy_bytes, bord
|
||||
from Crypto.Util.strxor import strxor
|
||||
|
||||
|
||||
class PKCS1OAEP_Cipher:
|
||||
class PKCS1OAepCipher:
|
||||
"""Cipher object for PKCS#1 v1.5 OAEP.
|
||||
Do not create directly: use :func:`new` instead."""
|
||||
|
||||
@ -237,4 +237,4 @@ def new(key, hashAlgo=None, mgfunc=None, label=b"", randfunc=None):
|
||||
|
||||
if randfunc is None:
|
||||
randfunc = Random.get_random_bytes
|
||||
return PKCS1OAEP_Cipher(key, hashAlgo, mgfunc, label, randfunc)
|
||||
return PKCS1OAepCipher(key, hashAlgo, mgfunc, label, randfunc)
|
||||
|
@ -84,7 +84,7 @@ def timestamp_value(timestamp):
|
||||
raise ValueError(error)
|
||||
|
||||
|
||||
class str_len:
|
||||
class StrLen:
|
||||
"""Restrict input to an integer in a range (inclusive)"""
|
||||
|
||||
def __init__(self, max_length, argument="argument"):
|
||||
@ -102,7 +102,7 @@ class str_len:
|
||||
return value
|
||||
|
||||
|
||||
class float_range:
|
||||
class FloatRange:
|
||||
"""Restrict input to an float in a range (inclusive)"""
|
||||
|
||||
def __init__(self, low, high, argument="argument"):
|
||||
@ -121,7 +121,7 @@ class float_range:
|
||||
return value
|
||||
|
||||
|
||||
class datetime_string:
|
||||
class DatetimeString:
|
||||
def __init__(self, format, argument="argument"):
|
||||
self.format = format
|
||||
self.argument = argument
|
||||
|
@ -1,6 +1,6 @@
|
||||
import json
|
||||
|
||||
from core.llm_generator.output_parser.errors import OutputParserException
|
||||
from core.llm_generator.output_parser.errors import OutputParserError
|
||||
|
||||
|
||||
def parse_json_markdown(json_string: str) -> dict:
|
||||
@ -33,10 +33,10 @@ def parse_and_check_json_markdown(text: str, expected_keys: list[str]) -> dict:
|
||||
try:
|
||||
json_obj = parse_json_markdown(text)
|
||||
except json.JSONDecodeError as e:
|
||||
raise OutputParserException(f"Got invalid JSON object. Error: {e}")
|
||||
raise OutputParserError(f"Got invalid JSON object. Error: {e}")
|
||||
for key in expected_keys:
|
||||
if key not in json_obj:
|
||||
raise OutputParserException(
|
||||
raise OutputParserError(
|
||||
f"Got invalid return object. Expected key `{key}` " f"to be present, but got {json_obj}"
|
||||
)
|
||||
return json_obj
|
||||
|
@ -15,8 +15,8 @@ select = [
|
||||
"C4", # flake8-comprehensions
|
||||
"F", # pyflakes rules
|
||||
"I", # isort rules
|
||||
"N", # pep8-naming
|
||||
"UP", # pyupgrade rules
|
||||
"B035", # static-key-dict-comprehension
|
||||
"E101", # mixed-spaces-and-tabs
|
||||
"E111", # indentation-with-invalid-multiple
|
||||
"E112", # no-indented-block
|
||||
@ -47,9 +47,10 @@ ignore = [
|
||||
"B006", # mutable-argument-default
|
||||
"B007", # unused-loop-control-variable
|
||||
"B026", # star-arg-unpacking-after-keyword-arg
|
||||
# "B901", # return-in-generator
|
||||
"B904", # raise-without-from-inside-except
|
||||
"B905", # zip-without-explicit-strict
|
||||
"N806", # non-lowercase-variable-in-function
|
||||
"N815", # mixed-case-variable-in-class-scope
|
||||
]
|
||||
|
||||
[tool.ruff.lint.per-file-ignores]
|
||||
@ -65,6 +66,12 @@ ignore = [
|
||||
"F401", # unused-import
|
||||
"F811", # redefined-while-unused
|
||||
]
|
||||
"configs/*" = [
|
||||
"N802", # invalid-function-name
|
||||
]
|
||||
"libs/gmpy2_pkcs10aep_cipher.py" = [
|
||||
"N803", # invalid-argument-name
|
||||
]
|
||||
|
||||
[tool.ruff.format]
|
||||
exclude = [
|
||||
|
@ -32,7 +32,7 @@ from services.errors.account import (
|
||||
NoPermissionError,
|
||||
RateLimitExceededError,
|
||||
RoleAlreadyAssignedError,
|
||||
TenantNotFound,
|
||||
TenantNotFoundError,
|
||||
)
|
||||
from tasks.mail_invite_member_task import send_invite_member_mail_task
|
||||
from tasks.mail_reset_password_task import send_reset_password_mail_task
|
||||
@ -311,13 +311,13 @@ class TenantService:
|
||||
"""Get tenant by account and add the role"""
|
||||
tenant = account.current_tenant
|
||||
if not tenant:
|
||||
raise TenantNotFound("Tenant not found.")
|
||||
raise TenantNotFoundError("Tenant not found.")
|
||||
|
||||
ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first()
|
||||
if ta:
|
||||
tenant.role = ta.role
|
||||
else:
|
||||
raise TenantNotFound("Tenant not found for the account.")
|
||||
raise TenantNotFoundError("Tenant not found for the account.")
|
||||
return tenant
|
||||
|
||||
@staticmethod
|
||||
@ -614,8 +614,8 @@ class RegisterService:
|
||||
"email": account.email,
|
||||
"workspace_id": tenant.id,
|
||||
}
|
||||
expiryHours = dify_config.INVITE_EXPIRY_HOURS
|
||||
redis_client.setex(cls._get_invitation_token_key(token), expiryHours * 60 * 60, json.dumps(invitation_data))
|
||||
expiry_hours = dify_config.INVITE_EXPIRY_HOURS
|
||||
redis_client.setex(cls._get_invitation_token_key(token), expiry_hours * 60 * 60, json.dumps(invitation_data))
|
||||
return token
|
||||
|
||||
@classmethod
|
||||
|
@ -1,7 +1,7 @@
|
||||
from services.errors.base import BaseServiceError
|
||||
|
||||
|
||||
class AccountNotFound(BaseServiceError):
|
||||
class AccountNotFoundError(BaseServiceError):
|
||||
pass
|
||||
|
||||
|
||||
@ -25,7 +25,7 @@ class LinkAccountIntegrateError(BaseServiceError):
|
||||
pass
|
||||
|
||||
|
||||
class TenantNotFound(BaseServiceError):
|
||||
class TenantNotFoundError(BaseServiceError):
|
||||
pass
|
||||
|
||||
|
||||
|
@ -6,7 +6,7 @@ import click
|
||||
from celery import shared_task
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from core.indexing_runner import DocumentIsPausedException, IndexingRunner
|
||||
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
|
||||
from core.rag.extractor.notion_extractor import NotionExtractor
|
||||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||
from extensions.ext_database import db
|
||||
@ -106,7 +106,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
|
||||
logging.info(
|
||||
click.style("update document: {} latency: {}".format(document.id, end_at - start_at), fg="green")
|
||||
)
|
||||
except DocumentIsPausedException as ex:
|
||||
except DocumentIsPausedError as ex:
|
||||
logging.info(click.style(str(ex), fg="yellow"))
|
||||
except Exception:
|
||||
pass
|
||||
|
@ -6,7 +6,7 @@ import click
|
||||
from celery import shared_task
|
||||
|
||||
from configs import dify_config
|
||||
from core.indexing_runner import DocumentIsPausedException, IndexingRunner
|
||||
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, Document
|
||||
from services.feature_service import FeatureService
|
||||
@ -72,7 +72,7 @@ def document_indexing_task(dataset_id: str, document_ids: list):
|
||||
indexing_runner.run(documents)
|
||||
end_at = time.perf_counter()
|
||||
logging.info(click.style("Processed dataset: {} latency: {}".format(dataset_id, end_at - start_at), fg="green"))
|
||||
except DocumentIsPausedException as ex:
|
||||
except DocumentIsPausedError as ex:
|
||||
logging.info(click.style(str(ex), fg="yellow"))
|
||||
except Exception:
|
||||
pass
|
||||
|
@ -6,7 +6,7 @@ import click
|
||||
from celery import shared_task
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from core.indexing_runner import DocumentIsPausedException, IndexingRunner
|
||||
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
|
||||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
@ -69,7 +69,7 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
|
||||
indexing_runner.run([document])
|
||||
end_at = time.perf_counter()
|
||||
logging.info(click.style("update document: {} latency: {}".format(document.id, end_at - start_at), fg="green"))
|
||||
except DocumentIsPausedException as ex:
|
||||
except DocumentIsPausedError as ex:
|
||||
logging.info(click.style(str(ex), fg="yellow"))
|
||||
except Exception:
|
||||
pass
|
||||
|
@ -6,7 +6,7 @@ import click
|
||||
from celery import shared_task
|
||||
|
||||
from configs import dify_config
|
||||
from core.indexing_runner import DocumentIsPausedException, IndexingRunner
|
||||
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
|
||||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
@ -88,7 +88,7 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list):
|
||||
indexing_runner.run(documents)
|
||||
end_at = time.perf_counter()
|
||||
logging.info(click.style("Processed dataset: {} latency: {}".format(dataset_id, end_at - start_at), fg="green"))
|
||||
except DocumentIsPausedException as ex:
|
||||
except DocumentIsPausedError as ex:
|
||||
logging.info(click.style(str(ex), fg="yellow"))
|
||||
except Exception:
|
||||
pass
|
||||
|
@ -5,7 +5,7 @@ import click
|
||||
from celery import shared_task
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from core.indexing_runner import DocumentIsPausedException, IndexingRunner
|
||||
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Document
|
||||
|
||||
@ -39,7 +39,7 @@ def recover_document_indexing_task(dataset_id: str, document_id: str):
|
||||
logging.info(
|
||||
click.style("Processed document: {} latency: {}".format(document.id, end_at - start_at), fg="green")
|
||||
)
|
||||
except DocumentIsPausedException as ex:
|
||||
except DocumentIsPausedError as ex:
|
||||
logging.info(click.style(str(ex), fg="yellow"))
|
||||
except Exception:
|
||||
pass
|
||||
|
@ -70,6 +70,7 @@ class MockTEIClass:
|
||||
},
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def invoke_rerank(server_url: str, query: str, texts: list[str]) -> list[dict]:
|
||||
# Example response:
|
||||
# [
|
||||
|
@ -7,6 +7,7 @@ from _pytest.monkeypatch import MonkeyPatch
|
||||
|
||||
|
||||
class MockedHttp:
|
||||
@staticmethod
|
||||
def httpx_request(
|
||||
method: Literal["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD"], url: str, **kwargs
|
||||
) -> httpx.Response:
|
||||
|
@ -13,7 +13,7 @@ from xinference_client.types import Embedding
|
||||
|
||||
|
||||
class MockTcvectordbClass:
|
||||
def VectorDBClient(
|
||||
def mock_vector_db_client(
|
||||
self,
|
||||
url=None,
|
||||
username="",
|
||||
@ -110,7 +110,7 @@ MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
|
||||
@pytest.fixture
|
||||
def setup_tcvectordb_mock(request, monkeypatch: MonkeyPatch):
|
||||
if MOCK:
|
||||
monkeypatch.setattr(VectorDBClient, "__init__", MockTcvectordbClass.VectorDBClient)
|
||||
monkeypatch.setattr(VectorDBClient, "__init__", MockTcvectordbClass.mock_vector_db_client)
|
||||
monkeypatch.setattr(VectorDBClient, "list_databases", MockTcvectordbClass.list_databases)
|
||||
monkeypatch.setattr(Database, "collection", MockTcvectordbClass.describe_collection)
|
||||
monkeypatch.setattr(Database, "list_collections", MockTcvectordbClass.list_collections)
|
||||
|
@ -10,6 +10,7 @@ MOCK = os.getenv("MOCK_SWITCH", "false") == "true"
|
||||
|
||||
|
||||
class MockedHttp:
|
||||
@staticmethod
|
||||
def httpx_request(
|
||||
method: Literal["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD"], url: str, **kwargs
|
||||
) -> httpx.Response:
|
||||
|
@ -1,11 +1,11 @@
|
||||
import pytest
|
||||
|
||||
from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor
|
||||
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor
|
||||
|
||||
CODE_LANGUAGE = "unsupported_language"
|
||||
|
||||
|
||||
def test_unsupported_with_code_template():
|
||||
with pytest.raises(CodeExecutionException) as e:
|
||||
with pytest.raises(CodeExecutionError) as e:
|
||||
CodeExecutor.execute_workflow_code_template(language=CODE_LANGUAGE, code="", inputs={})
|
||||
assert str(e.value) == f"Unsupported language {CODE_LANGUAGE}"
|
||||
|
Loading…
Reference in New Issue
Block a user