diff --git a/api/controllers/console/app/error.py b/api/controllers/console/app/error.py index b1abb3824..fbe42fbd2 100644 --- a/api/controllers/console/app/error.py +++ b/api/controllers/console/app/error.py @@ -91,3 +91,9 @@ class DraftWorkflowNotExist(BaseHTTPException): error_code = 'draft_workflow_not_exist' description = "Draft workflow need to be initialized." code = 400 + + +class DraftWorkflowNotSync(BaseHTTPException): + error_code = 'draft_workflow_not_sync' + description = "Workflow graph might have been modified, please refresh and resubmit." + code = 400 diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index b88a9b7fc..0345a9e90 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -7,7 +7,7 @@ from werkzeug.exceptions import InternalServerError, NotFound import services from controllers.console import api -from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist +from controllers.console.app.error import ConversationCompletedError, DraftWorkflowNotExist, DraftWorkflowNotSync from controllers.console.app.wraps import get_app_model from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required @@ -20,6 +20,7 @@ from libs.helper import TimestampField, uuid_value from libs.login import current_user, login_required from models.model import App, AppMode from services.app_generate_service import AppGenerateService +from services.errors.app import WorkflowHashNotEqualError from services.workflow_service import WorkflowService logger = logging.getLogger(__name__) @@ -59,6 +60,7 @@ class DraftWorkflowApi(Resource): parser = reqparse.RequestParser() parser.add_argument('graph', type=dict, required=True, nullable=False, location='json') parser.add_argument('features', type=dict, required=True, nullable=False, location='json') + parser.add_argument('hash', type=str, required=False, location='json') args = parser.parse_args() elif 'text/plain' in content_type: try: @@ -71,7 +73,8 @@ class DraftWorkflowApi(Resource): args = { 'graph': data.get('graph'), - 'features': data.get('features') + 'features': data.get('features'), + 'hash': data.get('hash') } except json.JSONDecodeError: return {'message': 'Invalid JSON data'}, 400 @@ -79,15 +82,21 @@ class DraftWorkflowApi(Resource): abort(415) workflow_service = WorkflowService() - workflow = workflow_service.sync_draft_workflow( - app_model=app_model, - graph=args.get('graph'), - features=args.get('features'), - account=current_user - ) + + try: + workflow = workflow_service.sync_draft_workflow( + app_model=app_model, + graph=args.get('graph'), + features=args.get('features'), + unique_hash=args.get('hash'), + account=current_user + ) + except WorkflowHashNotEqualError: + raise DraftWorkflowNotSync() return { "result": "success", + "hash": workflow.unique_hash, "updated_at": TimestampField().format(workflow.updated_at or workflow.created_at) } diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py index 9919a440e..94d16be8d 100644 --- a/api/fields/workflow_fields.py +++ b/api/fields/workflow_fields.py @@ -7,6 +7,7 @@ workflow_fields = { 'id': fields.String, 'graph': fields.Raw(attribute='graph_dict'), 'features': fields.Raw(attribute='features_dict'), + 'hash': fields.String(attribute='unique_hash'), 'created_by': fields.Nested(simple_account_fields, attribute='created_by_account'), 'created_at': TimestampField, 'updated_by': fields.Nested(simple_account_fields, attribute='updated_by_account', allow_null=True), diff --git a/api/models/workflow.py b/api/models/workflow.py index f261c67c7..3f4464103 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -4,6 +4,7 @@ from typing import Optional, Union from core.tools.tool_manager import ToolManager from extensions.ext_database import db +from libs import helper from models import StringUUID from models.account import Account @@ -156,6 +157,21 @@ class Workflow(db.Model): return variables + @property + def unique_hash(self) -> str: + """ + Get hash of workflow. + + :return: hash + """ + entity = { + 'graph': self.graph_dict, + 'features': self.features_dict + } + + return helper.generate_text_hash(json.dumps(entity, sort_keys=True)) + + class WorkflowRunTriggeredFrom(Enum): """ Workflow Run Triggered From Enum diff --git a/api/services/app_service.py b/api/services/app_service.py index 11073af09..585c1c7f8 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -196,6 +196,7 @@ class AppService: app_model=app, graph=workflow.get('graph'), features=workflow.get('features'), + unique_hash=None, account=account ) workflow_service.publish_workflow( diff --git a/api/services/errors/app.py b/api/services/errors/app.py index 7c4ca99c2..87e9e9247 100644 --- a/api/services/errors/app.py +++ b/api/services/errors/app.py @@ -1,2 +1,6 @@ class MoreLikeThisDisabledError(Exception): pass + + +class WorkflowHashNotEqualError(Exception): + pass diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 01fd3aa4a..456ab0dcb 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -21,6 +21,7 @@ from models.workflow import ( WorkflowNodeExecutionTriggeredFrom, WorkflowType, ) +from services.errors.app import WorkflowHashNotEqualError from services.workflow.workflow_converter import WorkflowConverter @@ -63,13 +64,20 @@ class WorkflowService: def sync_draft_workflow(self, app_model: App, graph: dict, features: dict, + unique_hash: Optional[str], account: Account) -> Workflow: """ Sync draft workflow + @throws WorkflowHashNotEqualError """ # fetch draft workflow by app_model workflow = self.get_draft_workflow(app_model=app_model) + if workflow: + # validate unique hash + if workflow.unique_hash != unique_hash: + raise WorkflowHashNotEqualError() + # validate features structure self.validate_features_structure( app_model=app_model,