diff --git a/api/controllers/service_api/app/__init__.py b/api/controllers/service_api/app/__init__.py index d8018ee38..e69de29bb 100644 --- a/api/controllers/service_api/app/__init__.py +++ b/api/controllers/service_api/app/__init__.py @@ -1,27 +0,0 @@ -from extensions.ext_database import db -from models.model import EndUser - - -def create_or_update_end_user_for_user_id(app_model, user_id): - """ - Create or update session terminal based on user ID. - """ - end_user = db.session.query(EndUser) \ - .filter( - EndUser.tenant_id == app_model.tenant_id, - EndUser.session_id == user_id, - EndUser.type == 'service_api' - ).first() - - if end_user is None: - end_user = EndUser( - tenant_id=app_model.tenant_id, - app_id=app_model.id, - type='service_api', - is_anonymous=True, - session_id=user_id - ) - db.session.add(end_user) - db.session.commit() - - return end_user diff --git a/api/controllers/service_api/app/app.py b/api/controllers/service_api/app/app.py index 9cd9770c0..a3151fc4a 100644 --- a/api/controllers/service_api/app/app.py +++ b/api/controllers/service_api/app/app.py @@ -1,16 +1,16 @@ import json from flask import current_app -from flask_restful import fields, marshal_with +from flask_restful import fields, marshal_with, Resource from controllers.service_api import api -from controllers.service_api.wraps import AppApiResource +from controllers.service_api.wraps import validate_app_token from extensions.ext_database import db from models.model import App, AppModelConfig from models.tools import ApiToolProvider -class AppParameterApi(AppApiResource): +class AppParameterApi(Resource): """Resource for app variables.""" variable_fields = { @@ -42,8 +42,9 @@ class AppParameterApi(AppApiResource): 'system_parameters': fields.Nested(system_parameters_fields) } + @validate_app_token @marshal_with(parameters_fields) - def get(self, app_model: App, end_user): + def get(self, app_model: App): """Retrieve app parameters.""" app_model_config = app_model.app_model_config @@ -64,8 +65,9 @@ class AppParameterApi(AppApiResource): } } -class AppMetaApi(AppApiResource): - def get(self, app_model: App, end_user): +class AppMetaApi(Resource): + @validate_app_token + def get(self, app_model: App): """Get app meta""" app_model_config: AppModelConfig = app_model.app_model_config diff --git a/api/controllers/service_api/app/audio.py b/api/controllers/service_api/app/audio.py index d2906b1d6..58ab56a29 100644 --- a/api/controllers/service_api/app/audio.py +++ b/api/controllers/service_api/app/audio.py @@ -1,7 +1,7 @@ import logging from flask import request -from flask_restful import reqparse +from flask_restful import Resource, reqparse from werkzeug.exceptions import InternalServerError import services @@ -17,10 +17,10 @@ from controllers.service_api.app.error import ( ProviderQuotaExceededError, UnsupportedAudioTypeError, ) -from controllers.service_api.wraps import AppApiResource +from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError -from models.model import App, AppModelConfig +from models.model import App, AppModelConfig, EndUser from services.audio_service import AudioService from services.errors.audio import ( AudioTooLargeServiceError, @@ -30,8 +30,9 @@ from services.errors.audio import ( ) -class AudioApi(AppApiResource): - def post(self, app_model: App, end_user): +class AudioApi(Resource): + @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM)) + def post(self, app_model: App, end_user: EndUser): app_model_config: AppModelConfig = app_model.app_model_config if not app_model_config.speech_to_text_dict['enabled']: @@ -73,11 +74,11 @@ class AudioApi(AppApiResource): raise InternalServerError() -class TextApi(AppApiResource): - def post(self, app_model: App, end_user): +class TextApi(Resource): + @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) + def post(self, app_model: App, end_user: EndUser): parser = reqparse.RequestParser() parser.add_argument('text', type=str, required=True, nullable=False, location='json') - parser.add_argument('user', type=str, required=True, nullable=False, location='json') parser.add_argument('streaming', type=bool, required=False, nullable=False, location='json') args = parser.parse_args() @@ -85,7 +86,7 @@ class TextApi(AppApiResource): response = AudioService.transcript_tts( tenant_id=app_model.tenant_id, text=args['text'], - end_user=args['user'], + end_user=end_user, voice=app_model.app_model_config.text_to_speech_dict.get('voice'), streaming=args['streaming'] ) diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index 5331f796e..c6cfb2437 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -4,12 +4,11 @@ from collections.abc import Generator from typing import Union from flask import Response, stream_with_context -from flask_restful import reqparse +from flask_restful import Resource, reqparse from werkzeug.exceptions import InternalServerError, NotFound import services from controllers.service_api import api -from controllers.service_api.app import create_or_update_end_user_for_user_id from controllers.service_api.app.error import ( AppUnavailableError, CompletionRequestError, @@ -19,17 +18,19 @@ from controllers.service_api.app.error import ( ProviderNotInitializeError, ProviderQuotaExceededError, ) -from controllers.service_api.wraps import AppApiResource +from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from core.application_queue_manager import ApplicationQueueManager from core.entities.application_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError from libs.helper import uuid_value +from models.model import App, EndUser from services.completion_service import CompletionService -class CompletionApi(AppApiResource): - def post(self, app_model, end_user): +class CompletionApi(Resource): + @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) + def post(self, app_model: App, end_user: EndUser): if app_model.mode != 'completion': raise AppUnavailableError() @@ -38,16 +39,12 @@ class CompletionApi(AppApiResource): parser.add_argument('query', type=str, location='json', default='') parser.add_argument('files', type=list, required=False, location='json') parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') - parser.add_argument('user', required=True, nullable=False, type=str, location='json') parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json') args = parser.parse_args() streaming = args['response_mode'] == 'streaming' - if end_user is None and args['user'] is not None: - end_user = create_or_update_end_user_for_user_id(app_model, args['user']) - args['auto_generate_name'] = False try: @@ -82,29 +79,20 @@ class CompletionApi(AppApiResource): raise InternalServerError() -class CompletionStopApi(AppApiResource): - def post(self, app_model, end_user, task_id): +class CompletionStopApi(Resource): + @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) + def post(self, app_model: App, end_user: EndUser, task_id): if app_model.mode != 'completion': raise AppUnavailableError() - if end_user is None: - parser = reqparse.RequestParser() - parser.add_argument('user', required=True, nullable=False, type=str, location='json') - args = parser.parse_args() - - user = args.get('user') - if user is not None: - end_user = create_or_update_end_user_for_user_id(app_model, user) - else: - raise ValueError("arg user muse be input.") - ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id) return {'result': 'success'}, 200 -class ChatApi(AppApiResource): - def post(self, app_model, end_user): +class ChatApi(Resource): + @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) + def post(self, app_model: App, end_user: EndUser): if app_model.mode != 'chat': raise NotChatAppError() @@ -114,7 +102,6 @@ class ChatApi(AppApiResource): parser.add_argument('files', type=list, required=False, location='json') parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') parser.add_argument('conversation_id', type=uuid_value, location='json') - parser.add_argument('user', type=str, required=True, nullable=False, location='json') parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json') parser.add_argument('auto_generate_name', type=bool, required=False, default=True, location='json') @@ -122,9 +109,6 @@ class ChatApi(AppApiResource): streaming = args['response_mode'] == 'streaming' - if end_user is None and args['user'] is not None: - end_user = create_or_update_end_user_for_user_id(app_model, args['user']) - try: response = CompletionService.completion( app_model=app_model, @@ -157,22 +141,12 @@ class ChatApi(AppApiResource): raise InternalServerError() -class ChatStopApi(AppApiResource): - def post(self, app_model, end_user, task_id): +class ChatStopApi(Resource): + @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) + def post(self, app_model: App, end_user: EndUser, task_id): if app_model.mode != 'chat': raise NotChatAppError() - if end_user is None: - parser = reqparse.RequestParser() - parser.add_argument('user', required=True, nullable=False, type=str, location='json') - args = parser.parse_args() - - user = args.get('user') - if user is not None: - end_user = create_or_update_end_user_for_user_id(app_model, user) - else: - raise ValueError("arg user muse be input.") - ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id) return {'result': 'success'}, 200 diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py index 3c157bed9..4a5fe2f19 100644 --- a/api/controllers/service_api/app/conversation.py +++ b/api/controllers/service_api/app/conversation.py @@ -1,52 +1,44 @@ -from flask import request -from flask_restful import marshal_with, reqparse +from flask_restful import Resource, marshal_with, reqparse from flask_restful.inputs import int_range from werkzeug.exceptions import NotFound import services from controllers.service_api import api -from controllers.service_api.app import create_or_update_end_user_for_user_id from controllers.service_api.app.error import NotChatAppError -from controllers.service_api.wraps import AppApiResource +from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields from libs.helper import uuid_value +from models.model import App, EndUser from services.conversation_service import ConversationService -class ConversationApi(AppApiResource): +class ConversationApi(Resource): + @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY)) @marshal_with(conversation_infinite_scroll_pagination_fields) - def get(self, app_model, end_user): + def get(self, app_model: App, end_user: EndUser): if app_model.mode != 'chat': raise NotChatAppError() parser = reqparse.RequestParser() parser.add_argument('last_id', type=uuid_value, location='args') parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') - parser.add_argument('user', type=str, location='args') args = parser.parse_args() - if end_user is None and args['user'] is not None: - end_user = create_or_update_end_user_for_user_id(app_model, args['user']) - try: return ConversationService.pagination_by_last_id(app_model, end_user, args['last_id'], args['limit']) except services.errors.conversation.LastConversationNotExistsError: raise NotFound("Last Conversation Not Exists.") -class ConversationDetailApi(AppApiResource): +class ConversationDetailApi(Resource): + @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON)) @marshal_with(simple_conversation_fields) - def delete(self, app_model, end_user, c_id): + def delete(self, app_model: App, end_user: EndUser, c_id): if app_model.mode != 'chat': raise NotChatAppError() conversation_id = str(c_id) - user = request.get_json().get('user') - - if end_user is None and user is not None: - end_user = create_or_update_end_user_for_user_id(app_model, user) - try: ConversationService.delete(app_model, conversation_id, end_user) except services.errors.conversation.ConversationNotExistsError: @@ -54,10 +46,11 @@ class ConversationDetailApi(AppApiResource): return {"result": "success"}, 204 -class ConversationRenameApi(AppApiResource): +class ConversationRenameApi(Resource): + @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON)) @marshal_with(simple_conversation_fields) - def post(self, app_model, end_user, c_id): + def post(self, app_model: App, end_user: EndUser, c_id): if app_model.mode != 'chat': raise NotChatAppError() @@ -65,13 +58,9 @@ class ConversationRenameApi(AppApiResource): parser = reqparse.RequestParser() parser.add_argument('name', type=str, required=False, location='json') - parser.add_argument('user', type=str, location='json') parser.add_argument('auto_generate', type=bool, required=False, default=False, location='json') args = parser.parse_args() - if end_user is None and args['user'] is not None: - end_user = create_or_update_end_user_for_user_id(app_model, args['user']) - try: return ConversationService.rename( app_model, diff --git a/api/controllers/service_api/app/file.py b/api/controllers/service_api/app/file.py index a901375ec..5dbc1b1d1 100644 --- a/api/controllers/service_api/app/file.py +++ b/api/controllers/service_api/app/file.py @@ -1,30 +1,27 @@ from flask import request -from flask_restful import marshal_with +from flask_restful import Resource, marshal_with import services from controllers.service_api import api -from controllers.service_api.app import create_or_update_end_user_for_user_id from controllers.service_api.app.error import ( FileTooLargeError, NoFileUploadedError, TooManyFilesError, UnsupportedFileTypeError, ) -from controllers.service_api.wraps import AppApiResource +from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from fields.file_fields import file_fields +from models.model import App, EndUser from services.file_service import FileService -class FileApi(AppApiResource): +class FileApi(Resource): + @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM)) @marshal_with(file_fields) - def post(self, app_model, end_user): + def post(self, app_model: App, end_user: EndUser): file = request.files['file'] - user_args = request.form.get('user') - - if end_user is None and user_args is not None: - end_user = create_or_update_end_user_for_user_id(app_model, user_args) # check file if 'file' not in request.files: diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py index d90f536a4..0050ab1ae 100644 --- a/api/controllers/service_api/app/message.py +++ b/api/controllers/service_api/app/message.py @@ -1,20 +1,18 @@ -from flask_restful import fields, marshal_with, reqparse +from flask_restful import Resource, fields, marshal_with, reqparse from flask_restful.inputs import int_range from werkzeug.exceptions import NotFound import services from controllers.service_api import api -from controllers.service_api.app import create_or_update_end_user_for_user_id from controllers.service_api.app.error import NotChatAppError -from controllers.service_api.wraps import AppApiResource -from extensions.ext_database import db +from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from fields.conversation_fields import message_file_fields from libs.helper import TimestampField, uuid_value -from models.model import EndUser, Message +from models.model import App, EndUser from services.message_service import MessageService -class MessageListApi(AppApiResource): +class MessageListApi(Resource): feedback_fields = { 'rating': fields.String } @@ -70,8 +68,9 @@ class MessageListApi(AppApiResource): 'data': fields.List(fields.Nested(message_fields)) } + @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY)) @marshal_with(message_infinite_scroll_pagination_fields) - def get(self, app_model, end_user): + def get(self, app_model: App, end_user: EndUser): if app_model.mode != 'chat': raise NotChatAppError() @@ -79,12 +78,8 @@ class MessageListApi(AppApiResource): parser.add_argument('conversation_id', required=True, type=uuid_value, location='args') parser.add_argument('first_id', type=uuid_value, location='args') parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') - parser.add_argument('user', type=str, location='args') args = parser.parse_args() - if end_user is None and args['user'] is not None: - end_user = create_or_update_end_user_for_user_id(app_model, args['user']) - try: return MessageService.pagination_by_first_id(app_model, end_user, args['conversation_id'], args['first_id'], args['limit']) @@ -94,18 +89,15 @@ class MessageListApi(AppApiResource): raise NotFound("First Message Not Exists.") -class MessageFeedbackApi(AppApiResource): - def post(self, app_model, end_user, message_id): +class MessageFeedbackApi(Resource): + @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON)) + def post(self, app_model: App, end_user: EndUser, message_id): message_id = str(message_id) parser = reqparse.RequestParser() parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json') - parser.add_argument('user', type=str, location='json') args = parser.parse_args() - if end_user is None and args['user'] is not None: - end_user = create_or_update_end_user_for_user_id(app_model, args['user']) - try: MessageService.create_feedback(app_model, message_id, end_user, args['rating']) except services.errors.message.MessageNotExistsError: @@ -114,29 +106,17 @@ class MessageFeedbackApi(AppApiResource): return {'result': 'success'} -class MessageSuggestedApi(AppApiResource): - def get(self, app_model, end_user, message_id): +class MessageSuggestedApi(Resource): + @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY)) + def get(self, app_model: App, end_user: EndUser, message_id): message_id = str(message_id) if app_model.mode != 'chat': raise NotChatAppError() - try: - message = db.session.query(Message).filter( - Message.id == message_id, - Message.app_id == app_model.id, - ).first() - if end_user is None and message.from_end_user_id is not None: - user = db.session.query(EndUser) \ - .filter( - EndUser.tenant_id == app_model.tenant_id, - EndUser.id == message.from_end_user_id, - EndUser.type == 'service_api' - ).first() - else: - user = end_user + try: questions = MessageService.get_suggested_questions_after_answer( app_model=app_model, - user=user, + user=end_user, message_id=message_id, check_enabled=False ) diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index a0d89fe62..169c475af 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -1,22 +1,40 @@ +from collections.abc import Callable from datetime import datetime +from enum import Enum from functools import wraps +from typing import Optional from flask import current_app, request from flask_login import user_logged_in from flask_restful import Resource +from pydantic import BaseModel from werkzeug.exceptions import NotFound, Unauthorized from extensions.ext_database import db from libs.login import _get_user from models.account import Account, Tenant, TenantAccountJoin -from models.model import ApiToken, App +from models.model import ApiToken, App, EndUser from services.feature_service import FeatureService -def validate_app_token(view=None): - def decorator(view): - @wraps(view) - def decorated(*args, **kwargs): +class WhereisUserArg(Enum): + """ + Enum for whereis_user_arg. + """ + QUERY = 'query' + JSON = 'json' + FORM = 'form' + + +class FetchUserArg(BaseModel): + fetch_from: WhereisUserArg + required: bool = False + + +def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optional[FetchUserArg] = None): + def decorator(view_func): + @wraps(view_func) + def decorated_view(*args, **kwargs): api_token = validate_and_get_api_token('app') app_model = db.session.query(App).filter(App.id == api_token.app_id).first() @@ -29,16 +47,35 @@ def validate_app_token(view=None): if not app_model.enable_api: raise NotFound() - return view(app_model, None, *args, **kwargs) - return decorated + kwargs['app_model'] = app_model - if view: + if not fetch_user_arg: + # use default-user + user_id = None + else: + if fetch_user_arg.fetch_from == WhereisUserArg.QUERY: + user_id = request.args.get('user') + elif fetch_user_arg.fetch_from == WhereisUserArg.JSON: + user_id = request.get_json().get('user') + elif fetch_user_arg.fetch_from == WhereisUserArg.FORM: + user_id = request.form.get('user') + else: + # use default-user + user_id = None + + if not user_id and fetch_user_arg.required: + raise ValueError("Arg user must be provided.") + + kwargs['end_user'] = create_or_update_end_user_for_user_id(app_model, user_id) + + return view_func(*args, **kwargs) + return decorated_view + + if view is None: + return decorator + else: return decorator(view) - # if view is None, it means that the decorator is used without parentheses - # use the decorator as a function for method_decorators - return decorator - def cloud_edition_billing_resource_check(resource: str, api_token_type: str, @@ -128,8 +165,33 @@ def validate_and_get_api_token(scope=None): return api_token -class AppApiResource(Resource): - method_decorators = [validate_app_token] +def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str] = None) -> EndUser: + """ + Create or update session terminal based on user ID. + """ + if not user_id: + user_id = 'DEFAULT-USER' + + end_user = db.session.query(EndUser) \ + .filter( + EndUser.tenant_id == app_model.tenant_id, + EndUser.app_id == app_model.id, + EndUser.session_id == user_id, + EndUser.type == 'service_api' + ).first() + + if end_user is None: + end_user = EndUser( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + type='service_api', + is_anonymous=True if user_id == 'DEFAULT-USER' else False, + session_id=user_id + ) + db.session.add(end_user) + db.session.commit() + + return end_user class DatasetApiResource(Resource):