diff --git a/api/pyproject.toml b/api/pyproject.toml index 9b90b837a..eddeeb0cd 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -74,7 +74,6 @@ exclude = [ "controllers/**/*.py", "models/**/*.py", "migrations/**/*", - "services/**/*.py", ] [tool.pytest_env] diff --git a/api/services/__init__.py b/api/services/__init__.py index 689143631..5163862cc 100644 --- a/api/services/__init__.py +++ b/api/services/__init__.py @@ -1,3 +1,3 @@ from . import errors -__all__ = ['errors'] +__all__ = ["errors"] diff --git a/api/services/account_service.py b/api/services/account_service.py index d73cec269..cd501c979 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -39,12 +39,7 @@ from tasks.mail_reset_password_task import send_reset_password_mail_task class AccountService: - - reset_password_rate_limiter = RateLimiter( - prefix="reset_password_rate_limit", - max_attempts=5, - time_window=60 * 60 - ) + reset_password_rate_limiter = RateLimiter(prefix="reset_password_rate_limit", max_attempts=5, time_window=60 * 60) @staticmethod def load_user(user_id: str) -> None | Account: @@ -55,12 +50,15 @@ class AccountService: if account.status in [AccountStatus.BANNED.value, AccountStatus.CLOSED.value]: raise Unauthorized("Account is banned or closed.") - current_tenant: TenantAccountJoin = TenantAccountJoin.query.filter_by(account_id=account.id, current=True).first() + current_tenant: TenantAccountJoin = TenantAccountJoin.query.filter_by( + account_id=account.id, current=True + ).first() if current_tenant: account.current_tenant_id = current_tenant.tenant_id else: - available_ta = TenantAccountJoin.query.filter_by(account_id=account.id) \ - .order_by(TenantAccountJoin.id.asc()).first() + available_ta = ( + TenantAccountJoin.query.filter_by(account_id=account.id).order_by(TenantAccountJoin.id.asc()).first() + ) if not available_ta: return None @@ -74,14 +72,13 @@ class AccountService: return account - @staticmethod def get_account_jwt_token(account, *, exp: timedelta = timedelta(days=30)): payload = { "user_id": account.id, "exp": datetime.now(timezone.utc).replace(tzinfo=None) + exp, "iss": dify_config.EDITION, - "sub": 'Console API Passport', + "sub": "Console API Passport", } token = PassportService().issue(payload) @@ -93,10 +90,10 @@ class AccountService: account = Account.query.filter_by(email=email).first() if not account: - raise AccountLoginError('Invalid email or password.') + raise AccountLoginError("Invalid email or password.") if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value: - raise AccountLoginError('Account is banned or closed.') + raise AccountLoginError("Account is banned or closed.") if account.status == AccountStatus.PENDING.value: account.status = AccountStatus.ACTIVE.value @@ -104,7 +101,7 @@ class AccountService: db.session.commit() if account.password is None or not compare_password(password, account.password, account.password_salt): - raise AccountLoginError('Invalid email or password.') + raise AccountLoginError("Invalid email or password.") return account @staticmethod @@ -129,11 +126,9 @@ class AccountService: return account @staticmethod - def create_account(email: str, - name: str, - interface_language: str, - password: Optional[str] = None, - interface_theme: str = 'light') -> Account: + def create_account( + email: str, name: str, interface_language: str, password: Optional[str] = None, interface_theme: str = "light" + ) -> Account: """create account""" account = Account() account.email = email @@ -155,7 +150,7 @@ class AccountService: account.interface_theme = interface_theme # Set timezone based on language - account.timezone = language_timezone_mapping.get(interface_language, 'UTC') + account.timezone = language_timezone_mapping.get(interface_language, "UTC") db.session.add(account) db.session.commit() @@ -166,8 +161,9 @@ class AccountService: """Link account integrate""" try: # Query whether there is an existing binding record for the same provider - account_integrate: Optional[AccountIntegrate] = AccountIntegrate.query.filter_by(account_id=account.id, - provider=provider).first() + account_integrate: Optional[AccountIntegrate] = AccountIntegrate.query.filter_by( + account_id=account.id, provider=provider + ).first() if account_integrate: # If it exists, update the record @@ -176,15 +172,16 @@ class AccountService: account_integrate.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) else: # If it does not exist, create a new record - account_integrate = AccountIntegrate(account_id=account.id, provider=provider, open_id=open_id, - encrypted_token="") + account_integrate = AccountIntegrate( + account_id=account.id, provider=provider, open_id=open_id, encrypted_token="" + ) db.session.add(account_integrate) db.session.commit() - logging.info(f'Account {account.id} linked {provider} account {open_id}.') + logging.info(f"Account {account.id} linked {provider} account {open_id}.") except Exception as e: - logging.exception(f'Failed to link {provider} account {open_id} to Account {account.id}') - raise LinkAccountIntegrateError('Failed to link account.') from e + logging.exception(f"Failed to link {provider} account {open_id} to Account {account.id}") + raise LinkAccountIntegrateError("Failed to link account.") from e @staticmethod def close_account(account: Account) -> None: @@ -218,7 +215,7 @@ class AccountService: AccountService.update_last_login(account, ip_address=ip_address) exp = timedelta(days=30) token = AccountService.get_account_jwt_token(account, exp=exp) - redis_client.set(_get_login_cache_key(account_id=account.id, token=token), '1', ex=int(exp.total_seconds())) + redis_client.set(_get_login_cache_key(account_id=account.id, token=token), "1", ex=int(exp.total_seconds())) return token @staticmethod @@ -236,22 +233,18 @@ class AccountService: if cls.reset_password_rate_limiter.is_rate_limited(account.email): raise RateLimitExceededError(f"Rate limit exceeded for email: {account.email}. Please try again later.") - token = TokenManager.generate_token(account, 'reset_password') - send_reset_password_mail_task.delay( - language=account.interface_language, - to=account.email, - token=token - ) + token = TokenManager.generate_token(account, "reset_password") + send_reset_password_mail_task.delay(language=account.interface_language, to=account.email, token=token) cls.reset_password_rate_limiter.increment_rate_limit(account.email) return token @classmethod def revoke_reset_password_token(cls, token: str): - TokenManager.revoke_token(token, 'reset_password') + TokenManager.revoke_token(token, "reset_password") @classmethod def get_reset_password_data(cls, token: str) -> Optional[dict[str, Any]]: - return TokenManager.get_token_data(token, 'reset_password') + return TokenManager.get_token_data(token, "reset_password") def _get_login_cache_key(*, account_id: str, token: str): @@ -259,7 +252,6 @@ def _get_login_cache_key(*, account_id: str, token: str): class TenantService: - @staticmethod def create_tenant(name: str) -> Tenant: """Create tenant""" @@ -275,31 +267,28 @@ class TenantService: @staticmethod def create_owner_tenant_if_not_exist(account: Account): """Create owner tenant if not exist""" - available_ta = TenantAccountJoin.query.filter_by(account_id=account.id) \ - .order_by(TenantAccountJoin.id.asc()).first() + available_ta = ( + TenantAccountJoin.query.filter_by(account_id=account.id).order_by(TenantAccountJoin.id.asc()).first() + ) if available_ta: return tenant = TenantService.create_tenant(f"{account.name}'s Workspace") - TenantService.create_tenant_member(tenant, account, role='owner') + TenantService.create_tenant_member(tenant, account, role="owner") account.current_tenant = tenant db.session.commit() tenant_was_created.send(tenant) @staticmethod - def create_tenant_member(tenant: Tenant, account: Account, role: str = 'normal') -> TenantAccountJoin: + def create_tenant_member(tenant: Tenant, account: Account, role: str = "normal") -> TenantAccountJoin: """Create tenant member""" if role == TenantAccountJoinRole.OWNER.value: if TenantService.has_roles(tenant, [TenantAccountJoinRole.OWNER]): - logging.error(f'Tenant {tenant.id} has already an owner.') - raise Exception('Tenant already has an owner.') + logging.error(f"Tenant {tenant.id} has already an owner.") + raise Exception("Tenant already has an owner.") - ta = TenantAccountJoin( - tenant_id=tenant.id, - account_id=account.id, - role=role - ) + ta = TenantAccountJoin(tenant_id=tenant.id, account_id=account.id, role=role) db.session.add(ta) db.session.commit() return ta @@ -307,9 +296,12 @@ class TenantService: @staticmethod def get_join_tenants(account: Account) -> list[Tenant]: """Get account join tenants""" - return db.session.query(Tenant).join( - TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id - ).filter(TenantAccountJoin.account_id == account.id, Tenant.status == TenantStatus.NORMAL).all() + return ( + db.session.query(Tenant) + .join(TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id) + .filter(TenantAccountJoin.account_id == account.id, Tenant.status == TenantStatus.NORMAL) + .all() + ) @staticmethod def get_current_tenant_by_account(account: Account): @@ -333,16 +325,23 @@ class TenantService: if tenant_id is None: raise ValueError("Tenant ID must be provided.") - tenant_account_join = db.session.query(TenantAccountJoin).join(Tenant, TenantAccountJoin.tenant_id == Tenant.id).filter( - TenantAccountJoin.account_id == account.id, - TenantAccountJoin.tenant_id == tenant_id, - Tenant.status == TenantStatus.NORMAL, - ).first() + tenant_account_join = ( + db.session.query(TenantAccountJoin) + .join(Tenant, TenantAccountJoin.tenant_id == Tenant.id) + .filter( + TenantAccountJoin.account_id == account.id, + TenantAccountJoin.tenant_id == tenant_id, + Tenant.status == TenantStatus.NORMAL, + ) + .first() + ) if not tenant_account_join: raise AccountNotLinkTenantError("Tenant not found or account is not a member of the tenant.") else: - TenantAccountJoin.query.filter(TenantAccountJoin.account_id == account.id, TenantAccountJoin.tenant_id != tenant_id).update({'current': False}) + TenantAccountJoin.query.filter( + TenantAccountJoin.account_id == account.id, TenantAccountJoin.tenant_id != tenant_id + ).update({"current": False}) tenant_account_join.current = True # Set the current tenant for the account account.current_tenant_id = tenant_account_join.tenant_id @@ -354,9 +353,7 @@ class TenantService: query = ( db.session.query(Account, TenantAccountJoin.role) .select_from(Account) - .join( - TenantAccountJoin, Account.id == TenantAccountJoin.account_id - ) + .join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id) .filter(TenantAccountJoin.tenant_id == tenant.id) ) @@ -375,11 +372,9 @@ class TenantService: query = ( db.session.query(Account, TenantAccountJoin.role) .select_from(Account) - .join( - TenantAccountJoin, Account.id == TenantAccountJoin.account_id - ) + .join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id) .filter(TenantAccountJoin.tenant_id == tenant.id) - .filter(TenantAccountJoin.role == 'dataset_operator') + .filter(TenantAccountJoin.role == "dataset_operator") ) # Initialize an empty list to store the updated accounts @@ -395,20 +390,25 @@ class TenantService: def has_roles(tenant: Tenant, roles: list[TenantAccountJoinRole]) -> bool: """Check if user has any of the given roles for a tenant""" if not all(isinstance(role, TenantAccountJoinRole) for role in roles): - raise ValueError('all roles must be TenantAccountJoinRole') + raise ValueError("all roles must be TenantAccountJoinRole") - return db.session.query(TenantAccountJoin).filter( - TenantAccountJoin.tenant_id == tenant.id, - TenantAccountJoin.role.in_([role.value for role in roles]) - ).first() is not None + return ( + db.session.query(TenantAccountJoin) + .filter( + TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role.in_([role.value for role in roles]) + ) + .first() + is not None + ) @staticmethod def get_user_role(account: Account, tenant: Tenant) -> Optional[TenantAccountJoinRole]: """Get the role of the current account for a given tenant""" - join = db.session.query(TenantAccountJoin).filter( - TenantAccountJoin.tenant_id == tenant.id, - TenantAccountJoin.account_id == account.id - ).first() + join = ( + db.session.query(TenantAccountJoin) + .filter(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id) + .first() + ) return join.role if join else None @staticmethod @@ -420,29 +420,26 @@ class TenantService: def check_member_permission(tenant: Tenant, operator: Account, member: Account, action: str) -> None: """Check member permission""" perms = { - 'add': [TenantAccountRole.OWNER, TenantAccountRole.ADMIN], - 'remove': [TenantAccountRole.OWNER], - 'update': [TenantAccountRole.OWNER] + "add": [TenantAccountRole.OWNER, TenantAccountRole.ADMIN], + "remove": [TenantAccountRole.OWNER], + "update": [TenantAccountRole.OWNER], } - if action not in ['add', 'remove', 'update']: + if action not in ["add", "remove", "update"]: raise InvalidActionError("Invalid action.") if member: if operator.id == member.id: raise CannotOperateSelfError("Cannot operate self.") - ta_operator = TenantAccountJoin.query.filter_by( - tenant_id=tenant.id, - account_id=operator.id - ).first() + ta_operator = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=operator.id).first() if not ta_operator or ta_operator.role not in perms[action]: - raise NoPermissionError(f'No permission to {action} member.') + raise NoPermissionError(f"No permission to {action} member.") @staticmethod def remove_member_from_tenant(tenant: Tenant, account: Account, operator: Account) -> None: """Remove member from tenant""" - if operator.id == account.id and TenantService.check_member_permission(tenant, operator, account, 'remove'): + if operator.id == account.id and TenantService.check_member_permission(tenant, operator, account, "remove"): raise CannotOperateSelfError("Cannot operate self.") ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first() @@ -455,23 +452,17 @@ class TenantService: @staticmethod def update_member_role(tenant: Tenant, member: Account, new_role: str, operator: Account) -> None: """Update member role""" - TenantService.check_member_permission(tenant, operator, member, 'update') + TenantService.check_member_permission(tenant, operator, member, "update") - target_member_join = TenantAccountJoin.query.filter_by( - tenant_id=tenant.id, - account_id=member.id - ).first() + target_member_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=member.id).first() if target_member_join.role == new_role: raise RoleAlreadyAssignedError("The provided role is already assigned to the member.") - if new_role == 'owner': + if new_role == "owner": # Find the current owner and change their role to 'admin' - current_owner_join = TenantAccountJoin.query.filter_by( - tenant_id=tenant.id, - role='owner' - ).first() - current_owner_join.role = 'admin' + current_owner_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, role="owner").first() + current_owner_join.role = "admin" # Update the role of the target member target_member_join.role = new_role @@ -480,8 +471,8 @@ class TenantService: @staticmethod def dissolve_tenant(tenant: Tenant, operator: Account) -> None: """Dissolve tenant""" - if not TenantService.check_member_permission(tenant, operator, operator, 'remove'): - raise NoPermissionError('No permission to dissolve tenant.') + if not TenantService.check_member_permission(tenant, operator, operator, "remove"): + raise NoPermissionError("No permission to dissolve tenant.") db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id).delete() db.session.delete(tenant) db.session.commit() @@ -494,10 +485,9 @@ class TenantService: class RegisterService: - @classmethod def _get_invitation_token_key(cls, token: str) -> str: - return f'member_invite:token:{token}' + return f"member_invite:token:{token}" @classmethod def setup(cls, email: str, name: str, password: str, ip_address: str) -> None: @@ -523,9 +513,7 @@ class RegisterService: TenantService.create_owner_tenant_if_not_exist(account) - dify_setup = DifySetup( - version=dify_config.CURRENT_VERSION - ) + dify_setup = DifySetup(version=dify_config.CURRENT_VERSION) db.session.add(dify_setup) db.session.commit() except Exception as e: @@ -535,34 +523,35 @@ class RegisterService: db.session.query(Tenant).delete() db.session.commit() - logging.exception(f'Setup failed: {e}') - raise ValueError(f'Setup failed: {e}') + logging.exception(f"Setup failed: {e}") + raise ValueError(f"Setup failed: {e}") @classmethod - def register(cls, email, name, - password: Optional[str] = None, - open_id: Optional[str] = None, - provider: Optional[str] = None, - language: Optional[str] = None, - status: Optional[AccountStatus] = None) -> Account: + def register( + cls, + email, + name, + password: Optional[str] = None, + open_id: Optional[str] = None, + provider: Optional[str] = None, + language: Optional[str] = None, + status: Optional[AccountStatus] = None, + ) -> Account: db.session.begin_nested() """Register account""" try: account = AccountService.create_account( - email=email, - name=name, - interface_language=language if language else languages[0], - password=password + email=email, name=name, interface_language=language if language else languages[0], password=password ) account.status = AccountStatus.ACTIVE.value if not status else status.value account.initialized_at = datetime.now(timezone.utc).replace(tzinfo=None) if open_id is not None or provider is not None: AccountService.link_account_integrate(provider, open_id, account) - if dify_config.EDITION != 'SELF_HOSTED': + if dify_config.EDITION != "SELF_HOSTED": tenant = TenantService.create_tenant(f"{account.name}'s Workspace") - TenantService.create_tenant_member(tenant, account, role='owner') + TenantService.create_tenant_member(tenant, account, role="owner") account.current_tenant = tenant tenant_was_created.send(tenant) @@ -570,30 +559,29 @@ class RegisterService: db.session.commit() except Exception as e: db.session.rollback() - logging.error(f'Register failed: {e}') - raise AccountRegisterError(f'Registration failed: {e}') from e + logging.error(f"Register failed: {e}") + raise AccountRegisterError(f"Registration failed: {e}") from e return account @classmethod - def invite_new_member(cls, tenant: Tenant, email: str, language: str, role: str = 'normal', inviter: Account = None) -> str: + def invite_new_member( + cls, tenant: Tenant, email: str, language: str, role: str = "normal", inviter: Account = None + ) -> str: """Invite new member""" account = Account.query.filter_by(email=email).first() if not account: - TenantService.check_member_permission(tenant, inviter, None, 'add') - name = email.split('@')[0] + TenantService.check_member_permission(tenant, inviter, None, "add") + name = email.split("@")[0] account = cls.register(email=email, name=name, language=language, status=AccountStatus.PENDING) # Create new tenant member for invited tenant TenantService.create_tenant_member(tenant, account, role) TenantService.switch_tenant(account, tenant.id) else: - TenantService.check_member_permission(tenant, inviter, account, 'add') - ta = TenantAccountJoin.query.filter_by( - tenant_id=tenant.id, - account_id=account.id - ).first() + TenantService.check_member_permission(tenant, inviter, account, "add") + ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first() if not ta: TenantService.create_tenant_member(tenant, account, role) @@ -609,7 +597,7 @@ class RegisterService: language=account.interface_language, to=email, token=token, - inviter_name=inviter.name if inviter else 'Dify', + inviter_name=inviter.name if inviter else "Dify", workspace_name=tenant.name, ) @@ -619,23 +607,19 @@ class RegisterService: def generate_invite_token(cls, tenant: Tenant, account: Account) -> str: token = str(uuid.uuid4()) invitation_data = { - 'account_id': account.id, - 'email': account.email, - 'workspace_id': tenant.id, + "account_id": account.id, + "email": account.email, + "workspace_id": tenant.id, } expiryHours = dify_config.INVITE_EXPIRY_HOURS - redis_client.setex( - cls._get_invitation_token_key(token), - expiryHours * 60 * 60, - json.dumps(invitation_data) - ) + redis_client.setex(cls._get_invitation_token_key(token), expiryHours * 60 * 60, json.dumps(invitation_data)) return token @classmethod def revoke_token(cls, workspace_id: str, email: str, token: str): if workspace_id and email: email_hash = sha256(email.encode()).hexdigest() - cache_key = 'member_invite_token:{}, {}:{}'.format(workspace_id, email_hash, token) + cache_key = "member_invite_token:{}, {}:{}".format(workspace_id, email_hash, token) redis_client.delete(cache_key) else: redis_client.delete(cls._get_invitation_token_key(token)) @@ -646,17 +630,21 @@ class RegisterService: if not invitation_data: return None - tenant = db.session.query(Tenant).filter( - Tenant.id == invitation_data['workspace_id'], - Tenant.status == 'normal' - ).first() + tenant = ( + db.session.query(Tenant) + .filter(Tenant.id == invitation_data["workspace_id"], Tenant.status == "normal") + .first() + ) if not tenant: return None - tenant_account = db.session.query(Account, TenantAccountJoin.role).join( - TenantAccountJoin, Account.id == TenantAccountJoin.account_id - ).filter(Account.email == invitation_data['email'], TenantAccountJoin.tenant_id == tenant.id).first() + tenant_account = ( + db.session.query(Account, TenantAccountJoin.role) + .join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id) + .filter(Account.email == invitation_data["email"], TenantAccountJoin.tenant_id == tenant.id) + .first() + ) if not tenant_account: return None @@ -665,29 +653,29 @@ class RegisterService: if not account: return None - if invitation_data['account_id'] != str(account.id): + if invitation_data["account_id"] != str(account.id): return None return { - 'account': account, - 'data': invitation_data, - 'tenant': tenant, + "account": account, + "data": invitation_data, + "tenant": tenant, } @classmethod def _get_invitation_by_token(cls, token: str, workspace_id: str, email: str) -> Optional[dict[str, str]]: if workspace_id is not None and email is not None: email_hash = sha256(email.encode()).hexdigest() - cache_key = f'member_invite_token:{workspace_id}, {email_hash}:{token}' + cache_key = f"member_invite_token:{workspace_id}, {email_hash}:{token}" account_id = redis_client.get(cache_key) if not account_id: return None return { - 'account_id': account_id.decode('utf-8'), - 'email': email, - 'workspace_id': workspace_id, + "account_id": account_id.decode("utf-8"), + "email": email, + "workspace_id": workspace_id, } else: data = redis_client.get(cls._get_invitation_token_key(token)) diff --git a/api/services/advanced_prompt_template_service.py b/api/services/advanced_prompt_template_service.py index 213df2622..d2cd7bea6 100644 --- a/api/services/advanced_prompt_template_service.py +++ b/api/services/advanced_prompt_template_service.py @@ -1,4 +1,3 @@ - import copy from core.prompt.prompt_templates.advanced_prompt_templates import ( @@ -17,59 +16,78 @@ from models.model import AppMode class AdvancedPromptTemplateService: - @classmethod def get_prompt(cls, args: dict) -> dict: - app_mode = args['app_mode'] - model_mode = args['model_mode'] - model_name = args['model_name'] - has_context = args['has_context'] + app_mode = args["app_mode"] + model_mode = args["model_mode"] + model_name = args["model_name"] + has_context = args["has_context"] - if 'baichuan' in model_name.lower(): + if "baichuan" in model_name.lower(): return cls.get_baichuan_prompt(app_mode, model_mode, has_context) else: return cls.get_common_prompt(app_mode, model_mode, has_context) @classmethod - def get_common_prompt(cls, app_mode: str, model_mode:str, has_context: str) -> dict: + def get_common_prompt(cls, app_mode: str, model_mode: str, has_context: str) -> dict: context_prompt = copy.deepcopy(CONTEXT) if app_mode == AppMode.CHAT.value: if model_mode == "completion": - return cls.get_completion_prompt(copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt) + return cls.get_completion_prompt( + copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt + ) elif model_mode == "chat": return cls.get_chat_prompt(copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt) elif app_mode == AppMode.COMPLETION.value: if model_mode == "completion": - return cls.get_completion_prompt(copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt) + return cls.get_completion_prompt( + copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt + ) elif model_mode == "chat": - return cls.get_chat_prompt(copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt) - + return cls.get_chat_prompt( + copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt + ) + @classmethod def get_completion_prompt(cls, prompt_template: dict, has_context: str, context: str) -> dict: - if has_context == 'true': - prompt_template['completion_prompt_config']['prompt']['text'] = context + prompt_template['completion_prompt_config']['prompt']['text'] - + if has_context == "true": + prompt_template["completion_prompt_config"]["prompt"]["text"] = ( + context + prompt_template["completion_prompt_config"]["prompt"]["text"] + ) + return prompt_template @classmethod def get_chat_prompt(cls, prompt_template: dict, has_context: str, context: str) -> dict: - if has_context == 'true': - prompt_template['chat_prompt_config']['prompt'][0]['text'] = context + prompt_template['chat_prompt_config']['prompt'][0]['text'] - + if has_context == "true": + prompt_template["chat_prompt_config"]["prompt"][0]["text"] = ( + context + prompt_template["chat_prompt_config"]["prompt"][0]["text"] + ) + return prompt_template @classmethod - def get_baichuan_prompt(cls, app_mode: str, model_mode:str, has_context: str) -> dict: + def get_baichuan_prompt(cls, app_mode: str, model_mode: str, has_context: str) -> dict: baichuan_context_prompt = copy.deepcopy(BAICHUAN_CONTEXT) if app_mode == AppMode.CHAT.value: if model_mode == "completion": - return cls.get_completion_prompt(copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, baichuan_context_prompt) + return cls.get_completion_prompt( + copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, baichuan_context_prompt + ) elif model_mode == "chat": - return cls.get_chat_prompt(copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt) + return cls.get_chat_prompt( + copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt + ) elif app_mode == AppMode.COMPLETION.value: if model_mode == "completion": - return cls.get_completion_prompt(copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, baichuan_context_prompt) + return cls.get_completion_prompt( + copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG), + has_context, + baichuan_context_prompt, + ) elif model_mode == "chat": - return cls.get_chat_prompt(copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt) \ No newline at end of file + return cls.get_chat_prompt( + copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt + ) diff --git a/api/services/agent_service.py b/api/services/agent_service.py index ba5fd9332..887fb878b 100644 --- a/api/services/agent_service.py +++ b/api/services/agent_service.py @@ -10,59 +10,65 @@ from models.model import App, Conversation, EndUser, Message, MessageAgentThough class AgentService: @classmethod - def get_agent_logs(cls, app_model: App, - conversation_id: str, - message_id: str) -> dict: + def get_agent_logs(cls, app_model: App, conversation_id: str, message_id: str) -> dict: """ Service to get agent logs """ - conversation: Conversation = db.session.query(Conversation).filter( - Conversation.id == conversation_id, - Conversation.app_id == app_model.id, - ).first() + conversation: Conversation = ( + db.session.query(Conversation) + .filter( + Conversation.id == conversation_id, + Conversation.app_id == app_model.id, + ) + .first() + ) if not conversation: raise ValueError(f"Conversation not found: {conversation_id}") - message: Message = db.session.query(Message).filter( - Message.id == message_id, - Message.conversation_id == conversation_id, - ).first() + message: Message = ( + db.session.query(Message) + .filter( + Message.id == message_id, + Message.conversation_id == conversation_id, + ) + .first() + ) if not message: raise ValueError(f"Message not found: {message_id}") - + agent_thoughts: list[MessageAgentThought] = message.agent_thoughts if conversation.from_end_user_id: # only select name field - executor = db.session.query(EndUser, EndUser.name).filter( - EndUser.id == conversation.from_end_user_id - ).first() + executor = ( + db.session.query(EndUser, EndUser.name).filter(EndUser.id == conversation.from_end_user_id).first() + ) else: - executor = db.session.query(Account, Account.name).filter( - Account.id == conversation.from_account_id - ).first() - + executor = ( + db.session.query(Account, Account.name).filter(Account.id == conversation.from_account_id).first() + ) + if executor: executor = executor.name else: - executor = 'Unknown' + executor = "Unknown" timezone = pytz.timezone(current_user.timezone) result = { - 'meta': { - 'status': 'success', - 'executor': executor, - 'start_time': message.created_at.astimezone(timezone).isoformat(), - 'elapsed_time': message.provider_response_latency, - 'total_tokens': message.answer_tokens + message.message_tokens, - 'agent_mode': app_model.app_model_config.agent_mode_dict.get('strategy', 'react'), - 'iterations': len(agent_thoughts), + "meta": { + "status": "success", + "executor": executor, + "start_time": message.created_at.astimezone(timezone).isoformat(), + "elapsed_time": message.provider_response_latency, + "total_tokens": message.answer_tokens + message.message_tokens, + "agent_mode": app_model.app_model_config.agent_mode_dict.get("strategy", "react"), + "iterations": len(agent_thoughts), }, - 'iterations': [], - 'files': message.files, + "iterations": [], + "files": message.files, } agent_config = AgentConfigManager.convert(app_model.app_model_config.to_dict()) @@ -86,12 +92,12 @@ class AgentService: tool_input = tool_inputs.get(tool_name, {}) tool_output = tool_outputs.get(tool_name, {}) tool_meta_data = tool_meta.get(tool_name, {}) - tool_config = tool_meta_data.get('tool_config', {}) - if tool_config.get('tool_provider_type', '') != 'dataset-retrieval': + tool_config = tool_meta_data.get("tool_config", {}) + if tool_config.get("tool_provider_type", "") != "dataset-retrieval": tool_icon = ToolManager.get_tool_icon( tenant_id=app_model.tenant_id, - provider_type=tool_config.get('tool_provider_type', ''), - provider_id=tool_config.get('tool_provider', ''), + provider_type=tool_config.get("tool_provider_type", ""), + provider_id=tool_config.get("tool_provider", ""), ) if not tool_icon: tool_entity = find_agent_tool(tool_name) @@ -102,30 +108,34 @@ class AgentService: provider_id=tool_entity.provider_id, ) else: - tool_icon = '' + tool_icon = "" - tool_calls.append({ - 'status': 'success' if not tool_meta_data.get('error') else 'error', - 'error': tool_meta_data.get('error'), - 'time_cost': tool_meta_data.get('time_cost', 0), - 'tool_name': tool_name, - 'tool_label': tool_label, - 'tool_input': tool_input, - 'tool_output': tool_output, - 'tool_parameters': tool_meta_data.get('tool_parameters', {}), - 'tool_icon': tool_icon, - }) + tool_calls.append( + { + "status": "success" if not tool_meta_data.get("error") else "error", + "error": tool_meta_data.get("error"), + "time_cost": tool_meta_data.get("time_cost", 0), + "tool_name": tool_name, + "tool_label": tool_label, + "tool_input": tool_input, + "tool_output": tool_output, + "tool_parameters": tool_meta_data.get("tool_parameters", {}), + "tool_icon": tool_icon, + } + ) - result['iterations'].append({ - 'tokens': agent_thought.tokens, - 'tool_calls': tool_calls, - 'tool_raw': { - 'inputs': agent_thought.tool_input, - 'outputs': agent_thought.observation, - }, - 'thought': agent_thought.thought, - 'created_at': agent_thought.created_at.isoformat(), - 'files': agent_thought.files, - }) + result["iterations"].append( + { + "tokens": agent_thought.tokens, + "tool_calls": tool_calls, + "tool_raw": { + "inputs": agent_thought.tool_input, + "outputs": agent_thought.observation, + }, + "thought": agent_thought.thought, + "created_at": agent_thought.created_at.isoformat(), + "files": agent_thought.files, + } + ) - return result \ No newline at end of file + return result diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py index addcde44e..3cc6c51c2 100644 --- a/api/services/annotation_service.py +++ b/api/services/annotation_service.py @@ -23,21 +23,18 @@ class AppAnnotationService: @classmethod def up_insert_app_annotation_from_message(cls, args: dict, app_id: str) -> MessageAnnotation: # get app info - app = db.session.query(App).filter( - App.id == app_id, - App.tenant_id == current_user.current_tenant_id, - App.status == 'normal' - ).first() + app = ( + db.session.query(App) + .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .first() + ) if not app: raise NotFound("App not found") - if args.get('message_id'): - message_id = str(args['message_id']) + if args.get("message_id"): + message_id = str(args["message_id"]) # get message info - message = db.session.query(Message).filter( - Message.id == message_id, - Message.app_id == app.id - ).first() + message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app.id).first() if not message: raise NotFound("Message Not Exists.") @@ -45,159 +42,166 @@ class AppAnnotationService: annotation = message.annotation # save the message annotation if annotation: - annotation.content = args['answer'] - annotation.question = args['question'] + annotation.content = args["answer"] + annotation.question = args["question"] else: annotation = MessageAnnotation( app_id=app.id, conversation_id=message.conversation_id, message_id=message.id, - content=args['answer'], - question=args['question'], - account_id=current_user.id + content=args["answer"], + question=args["question"], + account_id=current_user.id, ) else: annotation = MessageAnnotation( - app_id=app.id, - content=args['answer'], - question=args['question'], - account_id=current_user.id + app_id=app.id, content=args["answer"], question=args["question"], account_id=current_user.id ) db.session.add(annotation) db.session.commit() # if annotation reply is enabled , add annotation to index - annotation_setting = db.session.query(AppAnnotationSetting).filter( - AppAnnotationSetting.app_id == app_id).first() + annotation_setting = ( + db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first() + ) if annotation_setting: - add_annotation_to_index_task.delay(annotation.id, args['question'], current_user.current_tenant_id, - app_id, annotation_setting.collection_binding_id) + add_annotation_to_index_task.delay( + annotation.id, + args["question"], + current_user.current_tenant_id, + app_id, + annotation_setting.collection_binding_id, + ) return annotation @classmethod def enable_app_annotation(cls, args: dict, app_id: str) -> dict: - enable_app_annotation_key = 'enable_app_annotation_{}'.format(str(app_id)) + enable_app_annotation_key = "enable_app_annotation_{}".format(str(app_id)) cache_result = redis_client.get(enable_app_annotation_key) if cache_result is not None: - return { - 'job_id': cache_result, - 'job_status': 'processing' - } + return {"job_id": cache_result, "job_status": "processing"} # async job job_id = str(uuid.uuid4()) - enable_app_annotation_job_key = 'enable_app_annotation_job_{}'.format(str(job_id)) + enable_app_annotation_job_key = "enable_app_annotation_job_{}".format(str(job_id)) # send batch add segments task - redis_client.setnx(enable_app_annotation_job_key, 'waiting') - enable_annotation_reply_task.delay(str(job_id), app_id, current_user.id, current_user.current_tenant_id, - args['score_threshold'], - args['embedding_provider_name'], args['embedding_model_name']) - return { - 'job_id': job_id, - 'job_status': 'waiting' - } + redis_client.setnx(enable_app_annotation_job_key, "waiting") + enable_annotation_reply_task.delay( + str(job_id), + app_id, + current_user.id, + current_user.current_tenant_id, + args["score_threshold"], + args["embedding_provider_name"], + args["embedding_model_name"], + ) + return {"job_id": job_id, "job_status": "waiting"} @classmethod def disable_app_annotation(cls, app_id: str) -> dict: - disable_app_annotation_key = 'disable_app_annotation_{}'.format(str(app_id)) + disable_app_annotation_key = "disable_app_annotation_{}".format(str(app_id)) cache_result = redis_client.get(disable_app_annotation_key) if cache_result is not None: - return { - 'job_id': cache_result, - 'job_status': 'processing' - } + return {"job_id": cache_result, "job_status": "processing"} # async job job_id = str(uuid.uuid4()) - disable_app_annotation_job_key = 'disable_app_annotation_job_{}'.format(str(job_id)) + disable_app_annotation_job_key = "disable_app_annotation_job_{}".format(str(job_id)) # send batch add segments task - redis_client.setnx(disable_app_annotation_job_key, 'waiting') + redis_client.setnx(disable_app_annotation_job_key, "waiting") disable_annotation_reply_task.delay(str(job_id), app_id, current_user.current_tenant_id) - return { - 'job_id': job_id, - 'job_status': 'waiting' - } + return {"job_id": job_id, "job_status": "waiting"} @classmethod def get_annotation_list_by_app_id(cls, app_id: str, page: int, limit: int, keyword: str): # get app info - app = db.session.query(App).filter( - App.id == app_id, - App.tenant_id == current_user.current_tenant_id, - App.status == 'normal' - ).first() + app = ( + db.session.query(App) + .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .first() + ) if not app: raise NotFound("App not found") if keyword: - annotations = (db.session.query(MessageAnnotation) - .filter(MessageAnnotation.app_id == app_id) - .filter( - or_( - MessageAnnotation.question.ilike('%{}%'.format(keyword)), - MessageAnnotation.content.ilike('%{}%'.format(keyword)) + annotations = ( + db.session.query(MessageAnnotation) + .filter(MessageAnnotation.app_id == app_id) + .filter( + or_( + MessageAnnotation.question.ilike("%{}%".format(keyword)), + MessageAnnotation.content.ilike("%{}%".format(keyword)), + ) ) + .order_by(MessageAnnotation.created_at.desc()) + .paginate(page=page, per_page=limit, max_per_page=100, error_out=False) ) - .order_by(MessageAnnotation.created_at.desc()) - .paginate(page=page, per_page=limit, max_per_page=100, error_out=False)) else: - annotations = (db.session.query(MessageAnnotation) - .filter(MessageAnnotation.app_id == app_id) - .order_by(MessageAnnotation.created_at.desc()) - .paginate(page=page, per_page=limit, max_per_page=100, error_out=False)) + annotations = ( + db.session.query(MessageAnnotation) + .filter(MessageAnnotation.app_id == app_id) + .order_by(MessageAnnotation.created_at.desc()) + .paginate(page=page, per_page=limit, max_per_page=100, error_out=False) + ) return annotations.items, annotations.total @classmethod def export_annotation_list_by_app_id(cls, app_id: str): # get app info - app = db.session.query(App).filter( - App.id == app_id, - App.tenant_id == current_user.current_tenant_id, - App.status == 'normal' - ).first() + app = ( + db.session.query(App) + .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .first() + ) if not app: raise NotFound("App not found") - annotations = (db.session.query(MessageAnnotation) - .filter(MessageAnnotation.app_id == app_id) - .order_by(MessageAnnotation.created_at.desc()).all()) + annotations = ( + db.session.query(MessageAnnotation) + .filter(MessageAnnotation.app_id == app_id) + .order_by(MessageAnnotation.created_at.desc()) + .all() + ) return annotations @classmethod def insert_app_annotation_directly(cls, args: dict, app_id: str) -> MessageAnnotation: # get app info - app = db.session.query(App).filter( - App.id == app_id, - App.tenant_id == current_user.current_tenant_id, - App.status == 'normal' - ).first() + app = ( + db.session.query(App) + .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .first() + ) if not app: raise NotFound("App not found") annotation = MessageAnnotation( - app_id=app.id, - content=args['answer'], - question=args['question'], - account_id=current_user.id + app_id=app.id, content=args["answer"], question=args["question"], account_id=current_user.id ) db.session.add(annotation) db.session.commit() # if annotation reply is enabled , add annotation to index - annotation_setting = db.session.query(AppAnnotationSetting).filter( - AppAnnotationSetting.app_id == app_id).first() + annotation_setting = ( + db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first() + ) if annotation_setting: - add_annotation_to_index_task.delay(annotation.id, args['question'], current_user.current_tenant_id, - app_id, annotation_setting.collection_binding_id) + add_annotation_to_index_task.delay( + annotation.id, + args["question"], + current_user.current_tenant_id, + app_id, + annotation_setting.collection_binding_id, + ) return annotation @classmethod def update_app_annotation_directly(cls, args: dict, app_id: str, annotation_id: str): # get app info - app = db.session.query(App).filter( - App.id == app_id, - App.tenant_id == current_user.current_tenant_id, - App.status == 'normal' - ).first() + app = ( + db.session.query(App) + .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .first() + ) if not app: raise NotFound("App not found") @@ -207,30 +211,34 @@ class AppAnnotationService: if not annotation: raise NotFound("Annotation not found") - annotation.content = args['answer'] - annotation.question = args['question'] + annotation.content = args["answer"] + annotation.question = args["question"] db.session.commit() # if annotation reply is enabled , add annotation to index - app_annotation_setting = db.session.query(AppAnnotationSetting).filter( - AppAnnotationSetting.app_id == app_id - ).first() + app_annotation_setting = ( + db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first() + ) if app_annotation_setting: - update_annotation_to_index_task.delay(annotation.id, annotation.question, - current_user.current_tenant_id, - app_id, app_annotation_setting.collection_binding_id) + update_annotation_to_index_task.delay( + annotation.id, + annotation.question, + current_user.current_tenant_id, + app_id, + app_annotation_setting.collection_binding_id, + ) return annotation @classmethod def delete_app_annotation(cls, app_id: str, annotation_id: str): # get app info - app = db.session.query(App).filter( - App.id == app_id, - App.tenant_id == current_user.current_tenant_id, - App.status == 'normal' - ).first() + app = ( + db.session.query(App) + .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .first() + ) if not app: raise NotFound("App not found") @@ -242,33 +250,34 @@ class AppAnnotationService: db.session.delete(annotation) - annotation_hit_histories = (db.session.query(AppAnnotationHitHistory) - .filter(AppAnnotationHitHistory.annotation_id == annotation_id) - .all() - ) + annotation_hit_histories = ( + db.session.query(AppAnnotationHitHistory) + .filter(AppAnnotationHitHistory.annotation_id == annotation_id) + .all() + ) if annotation_hit_histories: for annotation_hit_history in annotation_hit_histories: db.session.delete(annotation_hit_history) db.session.commit() # if annotation reply is enabled , delete annotation index - app_annotation_setting = db.session.query(AppAnnotationSetting).filter( - AppAnnotationSetting.app_id == app_id - ).first() + app_annotation_setting = ( + db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first() + ) if app_annotation_setting: - delete_annotation_index_task.delay(annotation.id, app_id, - current_user.current_tenant_id, - app_annotation_setting.collection_binding_id) + delete_annotation_index_task.delay( + annotation.id, app_id, current_user.current_tenant_id, app_annotation_setting.collection_binding_id + ) @classmethod def batch_import_app_annotations(cls, app_id, file: FileStorage) -> dict: # get app info - app = db.session.query(App).filter( - App.id == app_id, - App.tenant_id == current_user.current_tenant_id, - App.status == 'normal' - ).first() + app = ( + db.session.query(App) + .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .first() + ) if not app: raise NotFound("App not found") @@ -278,10 +287,7 @@ class AppAnnotationService: df = pd.read_csv(file) result = [] for index, row in df.iterrows(): - content = { - 'question': row[0], - 'answer': row[1] - } + content = {"question": row[0], "answer": row[1]} result.append(content) if len(result) == 0: raise ValueError("The CSV file is empty.") @@ -293,28 +299,24 @@ class AppAnnotationService: raise ValueError("The number of annotations exceeds the limit of your subscription.") # async job job_id = str(uuid.uuid4()) - indexing_cache_key = 'app_annotation_batch_import_{}'.format(str(job_id)) + indexing_cache_key = "app_annotation_batch_import_{}".format(str(job_id)) # send batch add segments task - redis_client.setnx(indexing_cache_key, 'waiting') - batch_import_annotations_task.delay(str(job_id), result, app_id, - current_user.current_tenant_id, current_user.id) + redis_client.setnx(indexing_cache_key, "waiting") + batch_import_annotations_task.delay( + str(job_id), result, app_id, current_user.current_tenant_id, current_user.id + ) except Exception as e: - return { - 'error_msg': str(e) - } - return { - 'job_id': job_id, - 'job_status': 'waiting' - } + return {"error_msg": str(e)} + return {"job_id": job_id, "job_status": "waiting"} @classmethod def get_annotation_hit_histories(cls, app_id: str, annotation_id: str, page, limit): # get app info - app = db.session.query(App).filter( - App.id == app_id, - App.tenant_id == current_user.current_tenant_id, - App.status == 'normal' - ).first() + app = ( + db.session.query(App) + .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .first() + ) if not app: raise NotFound("App not found") @@ -324,12 +326,15 @@ class AppAnnotationService: if not annotation: raise NotFound("Annotation not found") - annotation_hit_histories = (db.session.query(AppAnnotationHitHistory) - .filter(AppAnnotationHitHistory.app_id == app_id, - AppAnnotationHitHistory.annotation_id == annotation_id, - ) - .order_by(AppAnnotationHitHistory.created_at.desc()) - .paginate(page=page, per_page=limit, max_per_page=100, error_out=False)) + annotation_hit_histories = ( + db.session.query(AppAnnotationHitHistory) + .filter( + AppAnnotationHitHistory.app_id == app_id, + AppAnnotationHitHistory.annotation_id == annotation_id, + ) + .order_by(AppAnnotationHitHistory.created_at.desc()) + .paginate(page=page, per_page=limit, max_per_page=100, error_out=False) + ) return annotation_hit_histories.items, annotation_hit_histories.total @classmethod @@ -341,15 +346,21 @@ class AppAnnotationService: return annotation @classmethod - def add_annotation_history(cls, annotation_id: str, app_id: str, annotation_question: str, - annotation_content: str, query: str, user_id: str, - message_id: str, from_source: str, score: float): + def add_annotation_history( + cls, + annotation_id: str, + app_id: str, + annotation_question: str, + annotation_content: str, + query: str, + user_id: str, + message_id: str, + from_source: str, + score: float, + ): # add hit count to annotation - db.session.query(MessageAnnotation).filter( - MessageAnnotation.id == annotation_id - ).update( - {MessageAnnotation.hit_count: MessageAnnotation.hit_count + 1}, - synchronize_session=False + db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).update( + {MessageAnnotation.hit_count: MessageAnnotation.hit_count + 1}, synchronize_session=False ) annotation_hit_history = AppAnnotationHitHistory( @@ -361,7 +372,7 @@ class AppAnnotationService: score=score, message_id=message_id, annotation_question=annotation_question, - annotation_content=annotation_content + annotation_content=annotation_content, ) db.session.add(annotation_hit_history) db.session.commit() @@ -369,17 +380,18 @@ class AppAnnotationService: @classmethod def get_app_annotation_setting_by_app_id(cls, app_id: str): # get app info - app = db.session.query(App).filter( - App.id == app_id, - App.tenant_id == current_user.current_tenant_id, - App.status == 'normal' - ).first() + app = ( + db.session.query(App) + .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .first() + ) if not app: raise NotFound("App not found") - annotation_setting = db.session.query(AppAnnotationSetting).filter( - AppAnnotationSetting.app_id == app_id).first() + annotation_setting = ( + db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first() + ) if annotation_setting: collection_binding_detail = annotation_setting.collection_binding_detail return { @@ -388,32 +400,34 @@ class AppAnnotationService: "score_threshold": annotation_setting.score_threshold, "embedding_model": { "embedding_provider_name": collection_binding_detail.provider_name, - "embedding_model_name": collection_binding_detail.model_name - } + "embedding_model_name": collection_binding_detail.model_name, + }, } - return { - "enabled": False - } + return {"enabled": False} @classmethod def update_app_annotation_setting(cls, app_id: str, annotation_setting_id: str, args: dict): # get app info - app = db.session.query(App).filter( - App.id == app_id, - App.tenant_id == current_user.current_tenant_id, - App.status == 'normal' - ).first() + app = ( + db.session.query(App) + .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .first() + ) if not app: raise NotFound("App not found") - annotation_setting = db.session.query(AppAnnotationSetting).filter( - AppAnnotationSetting.app_id == app_id, - AppAnnotationSetting.id == annotation_setting_id, - ).first() + annotation_setting = ( + db.session.query(AppAnnotationSetting) + .filter( + AppAnnotationSetting.app_id == app_id, + AppAnnotationSetting.id == annotation_setting_id, + ) + .first() + ) if not annotation_setting: raise NotFound("App annotation not found") - annotation_setting.score_threshold = args['score_threshold'] + annotation_setting.score_threshold = args["score_threshold"] annotation_setting.updated_user_id = current_user.id annotation_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.add(annotation_setting) @@ -427,6 +441,6 @@ class AppAnnotationService: "score_threshold": annotation_setting.score_threshold, "embedding_model": { "embedding_provider_name": collection_binding_detail.provider_name, - "embedding_model_name": collection_binding_detail.model_name - } + "embedding_model_name": collection_binding_detail.model_name, + }, } diff --git a/api/services/api_based_extension_service.py b/api/services/api_based_extension_service.py index 8441bbedb..601d67d2f 100644 --- a/api/services/api_based_extension_service.py +++ b/api/services/api_based_extension_service.py @@ -5,13 +5,14 @@ from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint class APIBasedExtensionService: - @staticmethod def get_all_by_tenant_id(tenant_id: str) -> list[APIBasedExtension]: - extension_list = db.session.query(APIBasedExtension) \ - .filter_by(tenant_id=tenant_id) \ - .order_by(APIBasedExtension.created_at.desc()) \ - .all() + extension_list = ( + db.session.query(APIBasedExtension) + .filter_by(tenant_id=tenant_id) + .order_by(APIBasedExtension.created_at.desc()) + .all() + ) for extension in extension_list: extension.api_key = decrypt_token(extension.tenant_id, extension.api_key) @@ -35,10 +36,12 @@ class APIBasedExtensionService: @staticmethod def get_with_tenant_id(tenant_id: str, api_based_extension_id: str) -> APIBasedExtension: - extension = db.session.query(APIBasedExtension) \ - .filter_by(tenant_id=tenant_id) \ - .filter_by(id=api_based_extension_id) \ + extension = ( + db.session.query(APIBasedExtension) + .filter_by(tenant_id=tenant_id) + .filter_by(id=api_based_extension_id) .first() + ) if not extension: raise ValueError("API based extension is not found") @@ -55,20 +58,24 @@ class APIBasedExtensionService: if not extension_data.id: # case one: check new data, name must be unique - is_name_existed = db.session.query(APIBasedExtension) \ - .filter_by(tenant_id=extension_data.tenant_id) \ - .filter_by(name=extension_data.name) \ + is_name_existed = ( + db.session.query(APIBasedExtension) + .filter_by(tenant_id=extension_data.tenant_id) + .filter_by(name=extension_data.name) .first() + ) if is_name_existed: raise ValueError("name must be unique, it is already existed") else: # case two: check existing data, name must be unique - is_name_existed = db.session.query(APIBasedExtension) \ - .filter_by(tenant_id=extension_data.tenant_id) \ - .filter_by(name=extension_data.name) \ - .filter(APIBasedExtension.id != extension_data.id) \ + is_name_existed = ( + db.session.query(APIBasedExtension) + .filter_by(tenant_id=extension_data.tenant_id) + .filter_by(name=extension_data.name) + .filter(APIBasedExtension.id != extension_data.id) .first() + ) if is_name_existed: raise ValueError("name must be unique, it is already existed") @@ -92,7 +99,7 @@ class APIBasedExtensionService: try: client = APIBasedExtensionRequestor(extension_data.api_endpoint, extension_data.api_key) resp = client.request(point=APIBasedExtensionPoint.PING, params={}) - if resp.get('result') != 'pong': + if resp.get("result") != "pong": raise ValueError(resp) except Exception as e: raise ValueError("connection error: {}".format(e)) diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 737def336..a938b4f93 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -75,43 +75,44 @@ class AppDslService: # check or repair dsl version import_data = cls._check_or_fix_dsl(import_data) - app_data = import_data.get('app') + app_data = import_data.get("app") if not app_data: raise ValueError("Missing app in data argument") # get app basic info - name = args.get("name") if args.get("name") else app_data.get('name') - description = args.get("description") if args.get("description") else app_data.get('description', '') - icon_type = args.get("icon_type") if args.get("icon_type") else app_data.get('icon_type') - icon = args.get("icon") if args.get("icon") else app_data.get('icon') - icon_background = args.get("icon_background") if args.get("icon_background") \ - else app_data.get('icon_background') + name = args.get("name") if args.get("name") else app_data.get("name") + description = args.get("description") if args.get("description") else app_data.get("description", "") + icon_type = args.get("icon_type") if args.get("icon_type") else app_data.get("icon_type") + icon = args.get("icon") if args.get("icon") else app_data.get("icon") + icon_background = ( + args.get("icon_background") if args.get("icon_background") else app_data.get("icon_background") + ) # import dsl and create app - app_mode = AppMode.value_of(app_data.get('mode')) + app_mode = AppMode.value_of(app_data.get("mode")) if app_mode in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]: app = cls._import_and_create_new_workflow_based_app( tenant_id=tenant_id, app_mode=app_mode, - workflow_data=import_data.get('workflow'), + workflow_data=import_data.get("workflow"), account=account, name=name, description=description, icon_type=icon_type, icon=icon, - icon_background=icon_background + icon_background=icon_background, ) elif app_mode in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION]: app = cls._import_and_create_new_model_config_based_app( tenant_id=tenant_id, app_mode=app_mode, - model_config_data=import_data.get('model_config'), + model_config_data=import_data.get("model_config"), account=account, name=name, description=description, icon_type=icon_type, icon=icon, - icon_background=icon_background + icon_background=icon_background, ) else: raise ValueError("Invalid app mode") @@ -134,27 +135,26 @@ class AppDslService: # check or repair dsl version import_data = cls._check_or_fix_dsl(import_data) - app_data = import_data.get('app') + app_data = import_data.get("app") if not app_data: raise ValueError("Missing app in data argument") # import dsl and overwrite app - app_mode = AppMode.value_of(app_data.get('mode')) + app_mode = AppMode.value_of(app_data.get("mode")) if app_mode not in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]: raise ValueError("Only support import workflow in advanced-chat or workflow app.") - if app_data.get('mode') != app_model.mode: - raise ValueError( - f"App mode {app_data.get('mode')} is not matched with current app mode {app_mode.value}") + if app_data.get("mode") != app_model.mode: + raise ValueError(f"App mode {app_data.get('mode')} is not matched with current app mode {app_mode.value}") return cls._import_and_overwrite_workflow_based_app( app_model=app_model, - workflow_data=import_data.get('workflow'), + workflow_data=import_data.get("workflow"), account=account, ) @classmethod - def export_dsl(cls, app_model: App, include_secret:bool = False) -> str: + def export_dsl(cls, app_model: App, include_secret: bool = False) -> str: """ Export app :param app_model: App instance @@ -168,14 +168,16 @@ class AppDslService: "app": { "name": app_model.name, "mode": app_model.mode, - "icon": '🤖' if app_model.icon_type == 'image' else app_model.icon, - "icon_background": '#FFEAD5' if app_model.icon_type == 'image' else app_model.icon_background, - "description": app_model.description - } + "icon": "🤖" if app_model.icon_type == "image" else app_model.icon, + "icon_background": "#FFEAD5" if app_model.icon_type == "image" else app_model.icon_background, + "description": app_model.description, + }, } if app_mode in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]: - cls._append_workflow_export_data(export_data=export_data, app_model=app_model, include_secret=include_secret) + cls._append_workflow_export_data( + export_data=export_data, app_model=app_model, include_secret=include_secret + ) else: cls._append_model_config_export_data(export_data, app_model) @@ -188,31 +190,35 @@ class AppDslService: :param import_data: import data """ - if not import_data.get('version'): - import_data['version'] = "0.1.0" + if not import_data.get("version"): + import_data["version"] = "0.1.0" - if not import_data.get('kind') or import_data.get('kind') != "app": - import_data['kind'] = "app" + if not import_data.get("kind") or import_data.get("kind") != "app": + import_data["kind"] = "app" - if import_data.get('version') != current_dsl_version: + if import_data.get("version") != current_dsl_version: # Currently only one DSL version, so no difference checks or compatibility fixes will be performed. - logger.warning(f"DSL version {import_data.get('version')} is not compatible " - f"with current version {current_dsl_version}, related to " - f"Dify version {dsl_to_dify_version_mapping.get(current_dsl_version)}.") + logger.warning( + f"DSL version {import_data.get('version')} is not compatible " + f"with current version {current_dsl_version}, related to " + f"Dify version {dsl_to_dify_version_mapping.get(current_dsl_version)}." + ) return import_data @classmethod - def _import_and_create_new_workflow_based_app(cls, - tenant_id: str, - app_mode: AppMode, - workflow_data: dict, - account: Account, - name: str, - description: str, - icon_type: str, - icon: str, - icon_background: str) -> App: + def _import_and_create_new_workflow_based_app( + cls, + tenant_id: str, + app_mode: AppMode, + workflow_data: dict, + account: Account, + name: str, + description: str, + icon_type: str, + icon: str, + icon_background: str, + ) -> App: """ Import app dsl and create new workflow based app @@ -227,8 +233,7 @@ class AppDslService: :param icon_background: app icon background """ if not workflow_data: - raise ValueError("Missing workflow in data argument " - "when app mode is advanced-chat or workflow") + raise ValueError("Missing workflow in data argument " "when app mode is advanced-chat or workflow") app = cls._create_app( tenant_id=tenant_id, @@ -238,37 +243,32 @@ class AppDslService: description=description, icon_type=icon_type, icon=icon, - icon_background=icon_background + icon_background=icon_background, ) # init draft workflow - environment_variables_list = workflow_data.get('environment_variables') or [] + environment_variables_list = workflow_data.get("environment_variables") or [] environment_variables = [factory.build_variable_from_mapping(obj) for obj in environment_variables_list] - conversation_variables_list = workflow_data.get('conversation_variables') or [] + conversation_variables_list = workflow_data.get("conversation_variables") or [] conversation_variables = [factory.build_variable_from_mapping(obj) for obj in conversation_variables_list] workflow_service = WorkflowService() draft_workflow = workflow_service.sync_draft_workflow( app_model=app, - graph=workflow_data.get('graph', {}), - features=workflow_data.get('../core/app/features', {}), + graph=workflow_data.get("graph", {}), + features=workflow_data.get("../core/app/features", {}), unique_hash=None, account=account, environment_variables=environment_variables, conversation_variables=conversation_variables, ) - workflow_service.publish_workflow( - app_model=app, - account=account, - draft_workflow=draft_workflow - ) + workflow_service.publish_workflow(app_model=app, account=account, draft_workflow=draft_workflow) return app @classmethod - def _import_and_overwrite_workflow_based_app(cls, - app_model: App, - workflow_data: dict, - account: Account) -> Workflow: + def _import_and_overwrite_workflow_based_app( + cls, app_model: App, workflow_data: dict, account: Account + ) -> Workflow: """ Import app dsl and overwrite workflow based app @@ -277,8 +277,7 @@ class AppDslService: :param account: Account instance """ if not workflow_data: - raise ValueError("Missing workflow in data argument " - "when app mode is advanced-chat or workflow") + raise ValueError("Missing workflow in data argument " "when app mode is advanced-chat or workflow") # fetch draft workflow by app_model workflow_service = WorkflowService() @@ -289,14 +288,14 @@ class AppDslService: unique_hash = None # sync draft workflow - environment_variables_list = workflow_data.get('environment_variables') or [] + environment_variables_list = workflow_data.get("environment_variables") or [] environment_variables = [factory.build_variable_from_mapping(obj) for obj in environment_variables_list] - conversation_variables_list = workflow_data.get('conversation_variables') or [] + conversation_variables_list = workflow_data.get("conversation_variables") or [] conversation_variables = [factory.build_variable_from_mapping(obj) for obj in conversation_variables_list] draft_workflow = workflow_service.sync_draft_workflow( app_model=app_model, - graph=workflow_data.get('graph', {}), - features=workflow_data.get('features', {}), + graph=workflow_data.get("graph", {}), + features=workflow_data.get("features", {}), unique_hash=unique_hash, account=account, environment_variables=environment_variables, @@ -306,16 +305,18 @@ class AppDslService: return draft_workflow @classmethod - def _import_and_create_new_model_config_based_app(cls, - tenant_id: str, - app_mode: AppMode, - model_config_data: dict, - account: Account, - name: str, - description: str, - icon_type: str, - icon: str, - icon_background: str) -> App: + def _import_and_create_new_model_config_based_app( + cls, + tenant_id: str, + app_mode: AppMode, + model_config_data: dict, + account: Account, + name: str, + description: str, + icon_type: str, + icon: str, + icon_background: str, + ) -> App: """ Import app dsl and create new model config based app @@ -329,8 +330,7 @@ class AppDslService: :param icon_background: app icon background """ if not model_config_data: - raise ValueError("Missing model_config in data argument " - "when app mode is chat, agent-chat or completion") + raise ValueError("Missing model_config in data argument " "when app mode is chat, agent-chat or completion") app = cls._create_app( tenant_id=tenant_id, @@ -340,7 +340,7 @@ class AppDslService: description=description, icon_type=icon_type, icon=icon, - icon_background=icon_background + icon_background=icon_background, ) app_model_config = AppModelConfig() @@ -352,23 +352,22 @@ class AppDslService: app.app_model_config_id = app_model_config.id - app_model_config_was_updated.send( - app, - app_model_config=app_model_config - ) + app_model_config_was_updated.send(app, app_model_config=app_model_config) return app @classmethod - def _create_app(cls, - tenant_id: str, - app_mode: AppMode, - account: Account, - name: str, - description: str, - icon_type: str, - icon: str, - icon_background: str) -> App: + def _create_app( + cls, + tenant_id: str, + app_mode: AppMode, + account: Account, + name: str, + description: str, + icon_type: str, + icon: str, + icon_background: str, + ) -> App: """ Create new app @@ -390,7 +389,7 @@ class AppDslService: icon=icon, icon_background=icon_background, enable_site=True, - enable_api=True + enable_api=True, ) db.session.add(app) @@ -412,7 +411,7 @@ class AppDslService: if not workflow: raise ValueError("Missing draft workflow configuration, please check.") - export_data['workflow'] = workflow.to_dict(include_secret=include_secret) + export_data["workflow"] = workflow.to_dict(include_secret=include_secret) @classmethod def _append_model_config_export_data(cls, export_data: dict, app_model: App) -> None: @@ -425,4 +424,4 @@ class AppDslService: if not app_model_config: raise ValueError("Missing app configuration, please check.") - export_data['model_config'] = app_model_config.to_dict() + export_data["model_config"] = app_model_config.to_dict() diff --git a/api/services/app_generate_service.py b/api/services/app_generate_service.py index cff4ba8af..34fce4630 100644 --- a/api/services/app_generate_service.py +++ b/api/services/app_generate_service.py @@ -14,14 +14,15 @@ from services.workflow_service import WorkflowService class AppGenerateService: - @classmethod - def generate(cls, app_model: App, - user: Union[Account, EndUser], - args: Any, - invoke_from: InvokeFrom, - streaming: bool = True, - ): + def generate( + cls, + app_model: App, + user: Union[Account, EndUser], + args: Any, + invoke_from: InvokeFrom, + streaming: bool = True, + ): """ App Content Generate :param app_model: app model @@ -37,51 +38,54 @@ class AppGenerateService: try: request_id = rate_limit.enter(request_id) if app_model.mode == AppMode.COMPLETION.value: - return rate_limit.generate(CompletionAppGenerator().generate( - app_model=app_model, - user=user, - args=args, - invoke_from=invoke_from, - stream=streaming - ), request_id) + return rate_limit.generate( + CompletionAppGenerator().generate( + app_model=app_model, user=user, args=args, invoke_from=invoke_from, stream=streaming + ), + request_id, + ) elif app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent: - return rate_limit.generate(AgentChatAppGenerator().generate( - app_model=app_model, - user=user, - args=args, - invoke_from=invoke_from, - stream=streaming - ), request_id) + return rate_limit.generate( + AgentChatAppGenerator().generate( + app_model=app_model, user=user, args=args, invoke_from=invoke_from, stream=streaming + ), + request_id, + ) elif app_model.mode == AppMode.CHAT.value: - return rate_limit.generate(ChatAppGenerator().generate( - app_model=app_model, - user=user, - args=args, - invoke_from=invoke_from, - stream=streaming - ), request_id) + return rate_limit.generate( + ChatAppGenerator().generate( + app_model=app_model, user=user, args=args, invoke_from=invoke_from, stream=streaming + ), + request_id, + ) elif app_model.mode == AppMode.ADVANCED_CHAT.value: workflow = cls._get_workflow(app_model, invoke_from) - return rate_limit.generate(AdvancedChatAppGenerator().generate( - app_model=app_model, - workflow=workflow, - user=user, - args=args, - invoke_from=invoke_from, - stream=streaming - ), request_id) + return rate_limit.generate( + AdvancedChatAppGenerator().generate( + app_model=app_model, + workflow=workflow, + user=user, + args=args, + invoke_from=invoke_from, + stream=streaming, + ), + request_id, + ) elif app_model.mode == AppMode.WORKFLOW.value: workflow = cls._get_workflow(app_model, invoke_from) - return rate_limit.generate(WorkflowAppGenerator().generate( - app_model=app_model, - workflow=workflow, - user=user, - args=args, - invoke_from=invoke_from, - stream=streaming - ), request_id) + return rate_limit.generate( + WorkflowAppGenerator().generate( + app_model=app_model, + workflow=workflow, + user=user, + args=args, + invoke_from=invoke_from, + stream=streaming, + ), + request_id, + ) else: - raise ValueError(f'Invalid app mode {app_model.mode}') + raise ValueError(f"Invalid app mode {app_model.mode}") finally: if not streaming: rate_limit.exit(request_id) @@ -94,38 +98,31 @@ class AppGenerateService: return max_active_requests @classmethod - def generate_single_iteration(cls, app_model: App, - user: Union[Account, EndUser], - node_id: str, - args: Any, - streaming: bool = True): + def generate_single_iteration( + cls, app_model: App, user: Union[Account, EndUser], node_id: str, args: Any, streaming: bool = True + ): if app_model.mode == AppMode.ADVANCED_CHAT.value: workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER) return AdvancedChatAppGenerator().single_iteration_generate( - app_model=app_model, - workflow=workflow, - node_id=node_id, - user=user, - args=args, - stream=streaming + app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, stream=streaming ) elif app_model.mode == AppMode.WORKFLOW.value: workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER) return WorkflowAppGenerator().single_iteration_generate( - app_model=app_model, - workflow=workflow, - node_id=node_id, - user=user, - args=args, - stream=streaming + app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, stream=streaming ) else: - raise ValueError(f'Invalid app mode {app_model.mode}') + raise ValueError(f"Invalid app mode {app_model.mode}") @classmethod - def generate_more_like_this(cls, app_model: App, user: Union[Account, EndUser], - message_id: str, invoke_from: InvokeFrom, streaming: bool = True) \ - -> Union[dict, Generator]: + def generate_more_like_this( + cls, + app_model: App, + user: Union[Account, EndUser], + message_id: str, + invoke_from: InvokeFrom, + streaming: bool = True, + ) -> Union[dict, Generator]: """ Generate more like this :param app_model: app model @@ -136,11 +133,7 @@ class AppGenerateService: :return: """ return CompletionAppGenerator().generate_more_like_this( - app_model=app_model, - message_id=message_id, - user=user, - invoke_from=invoke_from, - stream=streaming + app_model=app_model, message_id=message_id, user=user, invoke_from=invoke_from, stream=streaming ) @classmethod @@ -157,12 +150,12 @@ class AppGenerateService: workflow = workflow_service.get_draft_workflow(app_model=app_model) if not workflow: - raise ValueError('Workflow not initialized') + raise ValueError("Workflow not initialized") else: # fetch published workflow by app_model workflow = workflow_service.get_published_workflow(app_model=app_model) if not workflow: - raise ValueError('Workflow not published') + raise ValueError("Workflow not published") return workflow diff --git a/api/services/app_model_config_service.py b/api/services/app_model_config_service.py index c84f6fbf4..a1ad27105 100644 --- a/api/services/app_model_config_service.py +++ b/api/services/app_model_config_service.py @@ -5,7 +5,6 @@ from models.model import AppMode class AppModelConfigService: - @classmethod def validate_configuration(cls, tenant_id: str, config: dict, app_mode: AppMode) -> dict: if app_mode == AppMode.CHAT: diff --git a/api/services/app_service.py b/api/services/app_service.py index 93f7169c1..8a2f8c053 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -33,27 +33,22 @@ class AppService: :param args: request args :return: """ - filters = [ - App.tenant_id == tenant_id, - App.is_universal == False - ] + filters = [App.tenant_id == tenant_id, App.is_universal == False] - if args['mode'] == 'workflow': + if args["mode"] == "workflow": filters.append(App.mode.in_([AppMode.WORKFLOW.value, AppMode.COMPLETION.value])) - elif args['mode'] == 'chat': + elif args["mode"] == "chat": filters.append(App.mode.in_([AppMode.CHAT.value, AppMode.ADVANCED_CHAT.value])) - elif args['mode'] == 'agent-chat': + elif args["mode"] == "agent-chat": filters.append(App.mode == AppMode.AGENT_CHAT.value) - elif args['mode'] == 'channel': + elif args["mode"] == "channel": filters.append(App.mode == AppMode.CHANNEL.value) - if args.get('name'): - name = args['name'][:30] - filters.append(App.name.ilike(f'%{name}%')) - if args.get('tag_ids'): - target_ids = TagService.get_target_ids_by_tag_ids('app', - tenant_id, - args['tag_ids']) + if args.get("name"): + name = args["name"][:30] + filters.append(App.name.ilike(f"%{name}%")) + if args.get("tag_ids"): + target_ids = TagService.get_target_ids_by_tag_ids("app", tenant_id, args["tag_ids"]) if target_ids: filters.append(App.id.in_(target_ids)) else: @@ -61,9 +56,9 @@ class AppService: app_models = db.paginate( db.select(App).where(*filters).order_by(App.created_at.desc()), - page=args['page'], - per_page=args['limit'], - error_out=False + page=args["page"], + per_page=args["limit"], + error_out=False, ) return app_models @@ -75,21 +70,20 @@ class AppService: :param args: request args :param account: Account instance """ - app_mode = AppMode.value_of(args['mode']) + app_mode = AppMode.value_of(args["mode"]) app_template = default_app_templates[app_mode] # get model config - default_model_config = app_template.get('model_config') + default_model_config = app_template.get("model_config") default_model_config = default_model_config.copy() if default_model_config else None - if default_model_config and 'model' in default_model_config: + if default_model_config and "model" in default_model_config: # get model provider model_manager = ModelManager() # get default model instance try: model_instance = model_manager.get_default_model_instance( - tenant_id=account.current_tenant_id, - model_type=ModelType.LLM + tenant_id=account.current_tenant_id, model_type=ModelType.LLM ) except (ProviderTokenNotInitError, LLMBadRequestError): model_instance = None @@ -98,39 +92,41 @@ class AppService: model_instance = None if model_instance: - if model_instance.model == default_model_config['model']['name'] and model_instance.provider == default_model_config['model']['provider']: - default_model_dict = default_model_config['model'] + if ( + model_instance.model == default_model_config["model"]["name"] + and model_instance.provider == default_model_config["model"]["provider"] + ): + default_model_dict = default_model_config["model"] else: llm_model = cast(LargeLanguageModel, model_instance.model_type_instance) model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials) default_model_dict = { - 'provider': model_instance.provider, - 'name': model_instance.model, - 'mode': model_schema.model_properties.get(ModelPropertyKey.MODE), - 'completion_params': {} + "provider": model_instance.provider, + "name": model_instance.model, + "mode": model_schema.model_properties.get(ModelPropertyKey.MODE), + "completion_params": {}, } else: provider, model = model_manager.get_default_provider_model_name( - tenant_id=account.current_tenant_id, - model_type=ModelType.LLM + tenant_id=account.current_tenant_id, model_type=ModelType.LLM ) - default_model_config['model']['provider'] = provider - default_model_config['model']['name'] = model - default_model_dict = default_model_config['model'] + default_model_config["model"]["provider"] = provider + default_model_config["model"]["name"] = model + default_model_dict = default_model_config["model"] - default_model_config['model'] = json.dumps(default_model_dict) + default_model_config["model"] = json.dumps(default_model_dict) - app = App(**app_template['app']) - app.name = args['name'] - app.description = args.get('description', '') - app.mode = args['mode'] - app.icon_type = args.get('icon_type', 'emoji') - app.icon = args['icon'] - app.icon_background = args['icon_background'] + app = App(**app_template["app"]) + app.name = args["name"] + app.description = args.get("description", "") + app.mode = args["mode"] + app.icon_type = args.get("icon_type", "emoji") + app.icon = args["icon"] + app.icon_background = args["icon_background"] app.tenant_id = tenant_id - app.api_rph = args.get('api_rph', 0) - app.api_rpm = args.get('api_rpm', 0) + app.api_rph = args.get("api_rph", 0) + app.api_rpm = args.get("api_rpm", 0) db.session.add(app) db.session.flush() @@ -158,7 +154,7 @@ class AppService: model_config: AppModelConfig = app.app_model_config agent_mode = model_config.agent_mode_dict # decrypt agent tool parameters if it's secret-input - for tool in agent_mode.get('tools') or []: + for tool in agent_mode.get("tools") or []: if not isinstance(tool, dict) or len(tool.keys()) <= 3: continue agent_tool_entity = AgentToolEntity(**tool) @@ -174,7 +170,7 @@ class AppService: tool_runtime=tool_runtime, provider_name=agent_tool_entity.provider_id, provider_type=agent_tool_entity.provider_type, - identity_id=f'AGENT.{app.id}' + identity_id=f"AGENT.{app.id}", ) # get decrypted parameters @@ -185,7 +181,7 @@ class AppService: masked_parameter = {} # override tool parameters - tool['tool_parameters'] = masked_parameter + tool["tool_parameters"] = masked_parameter except Exception as e: pass @@ -215,12 +211,12 @@ class AppService: :param args: request args :return: App instance """ - app.name = args.get('name') - app.description = args.get('description', '') - app.max_active_requests = args.get('max_active_requests') - app.icon_type = args.get('icon_type', 'emoji') - app.icon = args.get('icon') - app.icon_background = args.get('icon_background') + app.name = args.get("name") + app.description = args.get("description", "") + app.max_active_requests = args.get("max_active_requests") + app.icon_type = args.get("icon_type", "emoji") + app.icon = args.get("icon") + app.icon_background = args.get("icon_background") app.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() @@ -298,10 +294,7 @@ class AppService: db.session.commit() # Trigger asynchronous deletion of app and related data - remove_app_and_related_data_task.delay( - tenant_id=app.tenant_id, - app_id=app.id - ) + remove_app_and_related_data_task.delay(tenant_id=app.tenant_id, app_id=app.id) def get_app_meta(self, app_model: App) -> dict: """ @@ -311,9 +304,7 @@ class AppService: """ app_mode = AppMode.value_of(app_model.mode) - meta = { - 'tool_icons': {} - } + meta = {"tool_icons": {}} if app_mode in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]: workflow = app_model.workflow @@ -321,17 +312,19 @@ class AppService: return meta graph = workflow.graph_dict - nodes = graph.get('nodes', []) + nodes = graph.get("nodes", []) tools = [] for node in nodes: - if node.get('data', {}).get('type') == 'tool': - node_data = node.get('data', {}) - tools.append({ - 'provider_type': node_data.get('provider_type'), - 'provider_id': node_data.get('provider_id'), - 'tool_name': node_data.get('tool_name'), - 'tool_parameters': {} - }) + if node.get("data", {}).get("type") == "tool": + node_data = node.get("data", {}) + tools.append( + { + "provider_type": node_data.get("provider_type"), + "provider_id": node_data.get("provider_id"), + "tool_name": node_data.get("tool_name"), + "tool_parameters": {}, + } + ) else: app_model_config: AppModelConfig = app_model.app_model_config @@ -341,30 +334,26 @@ class AppService: agent_config = app_model_config.agent_mode_dict or {} # get all tools - tools = agent_config.get('tools', []) + tools = agent_config.get("tools", []) - url_prefix = (dify_config.CONSOLE_API_URL - + "/console/api/workspaces/current/tool-provider/builtin/") + url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/builtin/" for tool in tools: keys = list(tool.keys()) if len(keys) >= 4: # current tool standard - provider_type = tool.get('provider_type') - provider_id = tool.get('provider_id') - tool_name = tool.get('tool_name') - if provider_type == 'builtin': - meta['tool_icons'][tool_name] = url_prefix + provider_id + '/icon' - elif provider_type == 'api': + provider_type = tool.get("provider_type") + provider_id = tool.get("provider_id") + tool_name = tool.get("tool_name") + if provider_type == "builtin": + meta["tool_icons"][tool_name] = url_prefix + provider_id + "/icon" + elif provider_type == "api": try: - provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( - ApiToolProvider.id == provider_id - ).first() - meta['tool_icons'][tool_name] = json.loads(provider.icon) + provider: ApiToolProvider = ( + db.session.query(ApiToolProvider).filter(ApiToolProvider.id == provider_id).first() + ) + meta["tool_icons"][tool_name] = json.loads(provider.icon) except: - meta['tool_icons'][tool_name] = { - "background": "#252525", - "content": "\ud83d\ude01" - } + meta["tool_icons"][tool_name] = {"background": "#252525", "content": "\ud83d\ude01"} return meta diff --git a/api/services/audio_service.py b/api/services/audio_service.py index 58c950816..05cd1c96a 100644 --- a/api/services/audio_service.py +++ b/api/services/audio_service.py @@ -17,7 +17,7 @@ from services.errors.audio import ( FILE_SIZE = 30 FILE_SIZE_LIMIT = FILE_SIZE * 1024 * 1024 -ALLOWED_EXTENSIONS = ['mp3', 'mp4', 'mpeg', 'mpga', 'm4a', 'wav', 'webm', 'amr'] +ALLOWED_EXTENSIONS = ["mp3", "mp4", "mpeg", "mpga", "m4a", "wav", "webm", "amr"] logger = logging.getLogger(__name__) @@ -31,19 +31,19 @@ class AudioService: raise ValueError("Speech to text is not enabled") features_dict = workflow.features_dict - if 'speech_to_text' not in features_dict or not features_dict['speech_to_text'].get('enabled'): + if "speech_to_text" not in features_dict or not features_dict["speech_to_text"].get("enabled"): raise ValueError("Speech to text is not enabled") else: app_model_config: AppModelConfig = app_model.app_model_config - if not app_model_config.speech_to_text_dict['enabled']: + if not app_model_config.speech_to_text_dict["enabled"]: raise ValueError("Speech to text is not enabled") if file is None: raise NoAudioUploadedServiceError() extension = file.mimetype - if extension not in [f'audio/{ext}' for ext in ALLOWED_EXTENSIONS]: + if extension not in [f"audio/{ext}" for ext in ALLOWED_EXTENSIONS]: raise UnsupportedAudioTypeServiceError() file_content = file.read() @@ -55,20 +55,25 @@ class AudioService: model_manager = ModelManager() model_instance = model_manager.get_default_model_instance( - tenant_id=app_model.tenant_id, - model_type=ModelType.SPEECH2TEXT + tenant_id=app_model.tenant_id, model_type=ModelType.SPEECH2TEXT ) if model_instance is None: raise ProviderNotSupportSpeechToTextServiceError() buffer = io.BytesIO(file_content) - buffer.name = 'temp.mp3' + buffer.name = "temp.mp3" return {"text": model_instance.invoke_speech2text(file=buffer, user=end_user)} @classmethod - def transcript_tts(cls, app_model: App, text: Optional[str] = None, - voice: Optional[str] = None, end_user: Optional[str] = None, message_id: Optional[str] = None): + def transcript_tts( + cls, + app_model: App, + text: Optional[str] = None, + voice: Optional[str] = None, + end_user: Optional[str] = None, + message_id: Optional[str] = None, + ): from collections.abc import Generator from flask import Response, stream_with_context @@ -84,65 +89,56 @@ class AudioService: raise ValueError("TTS is not enabled") features_dict = workflow.features_dict - if 'text_to_speech' not in features_dict or not features_dict['text_to_speech'].get('enabled'): + if "text_to_speech" not in features_dict or not features_dict["text_to_speech"].get("enabled"): raise ValueError("TTS is not enabled") - voice = features_dict['text_to_speech'].get('voice') if voice is None else voice + voice = features_dict["text_to_speech"].get("voice") if voice is None else voice else: text_to_speech_dict = app_model.app_model_config.text_to_speech_dict - if not text_to_speech_dict.get('enabled'): + if not text_to_speech_dict.get("enabled"): raise ValueError("TTS is not enabled") - voice = text_to_speech_dict.get('voice') if voice is None else voice + voice = text_to_speech_dict.get("voice") if voice is None else voice model_manager = ModelManager() model_instance = model_manager.get_default_model_instance( - tenant_id=app_model.tenant_id, - model_type=ModelType.TTS + tenant_id=app_model.tenant_id, model_type=ModelType.TTS ) try: if not voice: voices = model_instance.get_tts_voices() if voices: - voice = voices[0].get('value') + voice = voices[0].get("value") else: raise ValueError("Sorry, no voice available.") return model_instance.invoke_tts( - content_text=text_content.strip(), - user=end_user, - tenant_id=app_model.tenant_id, - voice=voice + content_text=text_content.strip(), user=end_user, tenant_id=app_model.tenant_id, voice=voice ) except Exception as e: raise e if message_id: - message = db.session.query(Message).filter( - Message.id == message_id - ).first() - if message.answer == '' and message.status == 'normal': + message = db.session.query(Message).filter(Message.id == message_id).first() + if message.answer == "" and message.status == "normal": return None else: response = invoke_tts(message.answer, app_model=app_model, voice=voice) if isinstance(response, Generator): - return Response(stream_with_context(response), content_type='audio/mpeg') + return Response(stream_with_context(response), content_type="audio/mpeg") return response else: response = invoke_tts(text, app_model, voice) if isinstance(response, Generator): - return Response(stream_with_context(response), content_type='audio/mpeg') + return Response(stream_with_context(response), content_type="audio/mpeg") return response @classmethod def transcript_tts_voices(cls, tenant_id: str, language: str): model_manager = ModelManager() - model_instance = model_manager.get_default_model_instance( - tenant_id=tenant_id, - model_type=ModelType.TTS - ) + model_instance = model_manager.get_default_model_instance(tenant_id=tenant_id, model_type=ModelType.TTS) if model_instance is None: raise ProviderNotSupportTextToSpeechServiceError() diff --git a/api/services/auth/api_key_auth_factory.py b/api/services/auth/api_key_auth_factory.py index ccd0023c4..ae5b953b4 100644 --- a/api/services/auth/api_key_auth_factory.py +++ b/api/services/auth/api_key_auth_factory.py @@ -1,14 +1,12 @@ - from services.auth.firecrawl import FirecrawlAuth class ApiKeyAuthFactory: - def __init__(self, provider: str, credentials: dict): - if provider == 'firecrawl': + if provider == "firecrawl": self.auth = FirecrawlAuth(credentials) else: - raise ValueError('Invalid provider') + raise ValueError("Invalid provider") def validate_credentials(self): return self.auth.validate_credentials() diff --git a/api/services/auth/api_key_auth_service.py b/api/services/auth/api_key_auth_service.py index 43d0fbf98..e5f4a3ef6 100644 --- a/api/services/auth/api_key_auth_service.py +++ b/api/services/auth/api_key_auth_service.py @@ -7,39 +7,43 @@ from services.auth.api_key_auth_factory import ApiKeyAuthFactory class ApiKeyAuthService: - @staticmethod def get_provider_auth_list(tenant_id: str) -> list: - data_source_api_key_bindings = db.session.query(DataSourceApiKeyAuthBinding).filter( - DataSourceApiKeyAuthBinding.tenant_id == tenant_id, - DataSourceApiKeyAuthBinding.disabled.is_(False) - ).all() + data_source_api_key_bindings = ( + db.session.query(DataSourceApiKeyAuthBinding) + .filter(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.disabled.is_(False)) + .all() + ) return data_source_api_key_bindings @staticmethod def create_provider_auth(tenant_id: str, args: dict): - auth_result = ApiKeyAuthFactory(args['provider'], args['credentials']).validate_credentials() + auth_result = ApiKeyAuthFactory(args["provider"], args["credentials"]).validate_credentials() if auth_result: # Encrypt the api key - api_key = encrypter.encrypt_token(tenant_id, args['credentials']['config']['api_key']) - args['credentials']['config']['api_key'] = api_key + api_key = encrypter.encrypt_token(tenant_id, args["credentials"]["config"]["api_key"]) + args["credentials"]["config"]["api_key"] = api_key data_source_api_key_binding = DataSourceApiKeyAuthBinding() data_source_api_key_binding.tenant_id = tenant_id - data_source_api_key_binding.category = args['category'] - data_source_api_key_binding.provider = args['provider'] - data_source_api_key_binding.credentials = json.dumps(args['credentials'], ensure_ascii=False) + data_source_api_key_binding.category = args["category"] + data_source_api_key_binding.provider = args["provider"] + data_source_api_key_binding.credentials = json.dumps(args["credentials"], ensure_ascii=False) db.session.add(data_source_api_key_binding) db.session.commit() @staticmethod def get_auth_credentials(tenant_id: str, category: str, provider: str): - data_source_api_key_bindings = db.session.query(DataSourceApiKeyAuthBinding).filter( - DataSourceApiKeyAuthBinding.tenant_id == tenant_id, - DataSourceApiKeyAuthBinding.category == category, - DataSourceApiKeyAuthBinding.provider == provider, - DataSourceApiKeyAuthBinding.disabled.is_(False) - ).first() + data_source_api_key_bindings = ( + db.session.query(DataSourceApiKeyAuthBinding) + .filter( + DataSourceApiKeyAuthBinding.tenant_id == tenant_id, + DataSourceApiKeyAuthBinding.category == category, + DataSourceApiKeyAuthBinding.provider == provider, + DataSourceApiKeyAuthBinding.disabled.is_(False), + ) + .first() + ) if not data_source_api_key_bindings: return None credentials = json.loads(data_source_api_key_bindings.credentials) @@ -47,24 +51,24 @@ class ApiKeyAuthService: @staticmethod def delete_provider_auth(tenant_id: str, binding_id: str): - data_source_api_key_binding = db.session.query(DataSourceApiKeyAuthBinding).filter( - DataSourceApiKeyAuthBinding.tenant_id == tenant_id, - DataSourceApiKeyAuthBinding.id == binding_id - ).first() + data_source_api_key_binding = ( + db.session.query(DataSourceApiKeyAuthBinding) + .filter(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.id == binding_id) + .first() + ) if data_source_api_key_binding: db.session.delete(data_source_api_key_binding) db.session.commit() @classmethod def validate_api_key_auth_args(cls, args): - if 'category' not in args or not args['category']: - raise ValueError('category is required') - if 'provider' not in args or not args['provider']: - raise ValueError('provider is required') - if 'credentials' not in args or not args['credentials']: - raise ValueError('credentials is required') - if not isinstance(args['credentials'], dict): - raise ValueError('credentials must be a dictionary') - if 'auth_type' not in args['credentials'] or not args['credentials']['auth_type']: - raise ValueError('auth_type is required') - + if "category" not in args or not args["category"]: + raise ValueError("category is required") + if "provider" not in args or not args["provider"]: + raise ValueError("provider is required") + if "credentials" not in args or not args["credentials"]: + raise ValueError("credentials is required") + if not isinstance(args["credentials"], dict): + raise ValueError("credentials must be a dictionary") + if "auth_type" not in args["credentials"] or not args["credentials"]["auth_type"]: + raise ValueError("auth_type is required") diff --git a/api/services/auth/firecrawl.py b/api/services/auth/firecrawl.py index 69e3fb43c..30e4ee57c 100644 --- a/api/services/auth/firecrawl.py +++ b/api/services/auth/firecrawl.py @@ -8,49 +8,40 @@ from services.auth.api_key_auth_base import ApiKeyAuthBase class FirecrawlAuth(ApiKeyAuthBase): def __init__(self, credentials: dict): super().__init__(credentials) - auth_type = credentials.get('auth_type') - if auth_type != 'bearer': - raise ValueError('Invalid auth type, Firecrawl auth type must be Bearer') - self.api_key = credentials.get('config').get('api_key', None) - self.base_url = credentials.get('config').get('base_url', 'https://api.firecrawl.dev') + auth_type = credentials.get("auth_type") + if auth_type != "bearer": + raise ValueError("Invalid auth type, Firecrawl auth type must be Bearer") + self.api_key = credentials.get("config").get("api_key", None) + self.base_url = credentials.get("config").get("base_url", "https://api.firecrawl.dev") if not self.api_key: - raise ValueError('No API key provided') + raise ValueError("No API key provided") def validate_credentials(self): headers = self._prepare_headers() options = { - 'url': 'https://example.com', - 'crawlerOptions': { - 'excludes': [], - 'includes': [], - 'limit': 1 - }, - 'pageOptions': { - 'onlyMainContent': True - } + "url": "https://example.com", + "crawlerOptions": {"excludes": [], "includes": [], "limit": 1}, + "pageOptions": {"onlyMainContent": True}, } - response = self._post_request(f'{self.base_url}/v0/crawl', options, headers) + response = self._post_request(f"{self.base_url}/v0/crawl", options, headers) if response.status_code == 200: return True else: self._handle_error(response) def _prepare_headers(self): - return { - 'Content-Type': 'application/json', - 'Authorization': f'Bearer {self.api_key}' - } + return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} def _post_request(self, url, data, headers): return requests.post(url, headers=headers, json=data) def _handle_error(self, response): if response.status_code in [402, 409, 500]: - error_message = response.json().get('error', 'Unknown error occurred') - raise Exception(f'Failed to authorize. Status code: {response.status_code}. Error: {error_message}') + error_message = response.json().get("error", "Unknown error occurred") + raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}") else: if response.text: - error_message = json.loads(response.text).get('error', 'Unknown error occurred') - raise Exception(f'Failed to authorize. Status code: {response.status_code}. Error: {error_message}') - raise Exception(f'Unexpected error occurred while trying to authorize. Status code: {response.status_code}') + error_message = json.loads(response.text).get("error", "Unknown error occurred") + raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}") + raise Exception(f"Unexpected error occurred while trying to authorize. Status code: {response.status_code}") diff --git a/api/services/billing_service.py b/api/services/billing_service.py index 539f2712b..911d23464 100644 --- a/api/services/billing_service.py +++ b/api/services/billing_service.py @@ -7,58 +7,40 @@ from models.account import TenantAccountJoin, TenantAccountRole class BillingService: - base_url = os.environ.get('BILLING_API_URL', 'BILLING_API_URL') - secret_key = os.environ.get('BILLING_API_SECRET_KEY', 'BILLING_API_SECRET_KEY') + base_url = os.environ.get("BILLING_API_URL", "BILLING_API_URL") + secret_key = os.environ.get("BILLING_API_SECRET_KEY", "BILLING_API_SECRET_KEY") @classmethod def get_info(cls, tenant_id: str): - params = {'tenant_id': tenant_id} + params = {"tenant_id": tenant_id} - billing_info = cls._send_request('GET', '/subscription/info', params=params) + billing_info = cls._send_request("GET", "/subscription/info", params=params) return billing_info @classmethod - def get_subscription(cls, plan: str, - interval: str, - prefilled_email: str = '', - tenant_id: str = ''): - params = { - 'plan': plan, - 'interval': interval, - 'prefilled_email': prefilled_email, - 'tenant_id': tenant_id - } - return cls._send_request('GET', '/subscription/payment-link', params=params) + def get_subscription(cls, plan: str, interval: str, prefilled_email: str = "", tenant_id: str = ""): + params = {"plan": plan, "interval": interval, "prefilled_email": prefilled_email, "tenant_id": tenant_id} + return cls._send_request("GET", "/subscription/payment-link", params=params) @classmethod - def get_model_provider_payment_link(cls, - provider_name: str, - tenant_id: str, - account_id: str, - prefilled_email: str): + def get_model_provider_payment_link(cls, provider_name: str, tenant_id: str, account_id: str, prefilled_email: str): params = { - 'provider_name': provider_name, - 'tenant_id': tenant_id, - 'account_id': account_id, - 'prefilled_email': prefilled_email + "provider_name": provider_name, + "tenant_id": tenant_id, + "account_id": account_id, + "prefilled_email": prefilled_email, } - return cls._send_request('GET', '/model-provider/payment-link', params=params) + return cls._send_request("GET", "/model-provider/payment-link", params=params) @classmethod - def get_invoices(cls, prefilled_email: str = '', tenant_id: str = ''): - params = { - 'prefilled_email': prefilled_email, - 'tenant_id': tenant_id - } - return cls._send_request('GET', '/invoices', params=params) + def get_invoices(cls, prefilled_email: str = "", tenant_id: str = ""): + params = {"prefilled_email": prefilled_email, "tenant_id": tenant_id} + return cls._send_request("GET", "/invoices", params=params) @classmethod def _send_request(cls, method, endpoint, json=None, params=None): - headers = { - "Content-Type": "application/json", - "Billing-Api-Secret-Key": cls.secret_key - } + headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key} url = f"{cls.base_url}{endpoint}" response = requests.request(method, url, json=json, params=params, headers=headers) @@ -69,10 +51,11 @@ class BillingService: def is_tenant_owner_or_admin(current_user): tenant_id = current_user.current_tenant_id - join = db.session.query(TenantAccountJoin).filter( - TenantAccountJoin.tenant_id == tenant_id, - TenantAccountJoin.account_id == current_user.id - ).first() + join = ( + db.session.query(TenantAccountJoin) + .filter(TenantAccountJoin.tenant_id == tenant_id, TenantAccountJoin.account_id == current_user.id) + .first() + ) if not TenantAccountRole.is_privileged_role(join.role): - raise ValueError('Only team owner or team admin can perform this action') + raise ValueError("Only team owner or team admin can perform this action") diff --git a/api/services/code_based_extension_service.py b/api/services/code_based_extension_service.py index 7b0d50a83..f7597b7f1 100644 --- a/api/services/code_based_extension_service.py +++ b/api/services/code_based_extension_service.py @@ -2,12 +2,15 @@ from extensions.ext_code_based_extension import code_based_extension class CodeBasedExtensionService: - @staticmethod def get_code_based_extension(module: str) -> list[dict]: module_extensions = code_based_extension.module_extensions(module) - return [{ - 'name': module_extension.name, - 'label': module_extension.label, - 'form_schema': module_extension.form_schema - } for module_extension in module_extensions if not module_extension.builtin] + return [ + { + "name": module_extension.name, + "label": module_extension.label, + "form_schema": module_extension.form_schema, + } + for module_extension in module_extensions + if not module_extension.builtin + ] diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py index 053d704e1..7bfe59afa 100644 --- a/api/services/conversation_service.py +++ b/api/services/conversation_service.py @@ -15,22 +15,27 @@ from services.errors.message import MessageNotExistsError class ConversationService: @classmethod - def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account, EndUser]], - last_id: Optional[str], limit: int, - invoke_from: InvokeFrom, - include_ids: Optional[list] = None, - exclude_ids: Optional[list] = None, - sort_by: str = '-updated_at') -> InfiniteScrollPagination: + def pagination_by_last_id( + cls, + app_model: App, + user: Optional[Union[Account, EndUser]], + last_id: Optional[str], + limit: int, + invoke_from: InvokeFrom, + include_ids: Optional[list] = None, + exclude_ids: Optional[list] = None, + sort_by: str = "-updated_at", + ) -> InfiniteScrollPagination: if not user: return InfiniteScrollPagination(data=[], limit=limit, has_more=False) base_query = db.session.query(Conversation).filter( Conversation.is_deleted == False, Conversation.app_id == app_model.id, - Conversation.from_source == ('api' if isinstance(user, EndUser) else 'console'), + Conversation.from_source == ("api" if isinstance(user, EndUser) else "console"), Conversation.from_end_user_id == (user.id if isinstance(user, EndUser) else None), Conversation.from_account_id == (user.id if isinstance(user, Account) else None), - or_(Conversation.invoke_from.is_(None), Conversation.invoke_from == invoke_from.value) + or_(Conversation.invoke_from.is_(None), Conversation.invoke_from == invoke_from.value), ) if include_ids is not None: @@ -58,28 +63,26 @@ class ConversationService: has_more = False if len(conversations) == limit: current_page_last_conversation = conversations[-1] - rest_filter_condition = cls._build_filter_condition(sort_field, sort_direction, - current_page_last_conversation, is_next_page=True) + rest_filter_condition = cls._build_filter_condition( + sort_field, sort_direction, current_page_last_conversation, is_next_page=True + ) rest_count = base_query.filter(rest_filter_condition).count() if rest_count > 0: has_more = True - return InfiniteScrollPagination( - data=conversations, - limit=limit, - has_more=has_more - ) + return InfiniteScrollPagination(data=conversations, limit=limit, has_more=has_more) @classmethod def _get_sort_params(cls, sort_by: str) -> tuple[str, callable]: - if sort_by.startswith('-'): + if sort_by.startswith("-"): return sort_by[1:], desc return sort_by, asc @classmethod - def _build_filter_condition(cls, sort_field: str, sort_direction: callable, reference_conversation: Conversation, - is_next_page: bool = False): + def _build_filter_condition( + cls, sort_field: str, sort_direction: callable, reference_conversation: Conversation, is_next_page: bool = False + ): field_value = getattr(reference_conversation, sort_field) if (sort_direction == desc and not is_next_page) or (sort_direction == asc and is_next_page): return getattr(Conversation, sort_field) < field_value @@ -87,8 +90,14 @@ class ConversationService: return getattr(Conversation, sort_field) > field_value @classmethod - def rename(cls, app_model: App, conversation_id: str, - user: Optional[Union[Account, EndUser]], name: str, auto_generate: bool): + def rename( + cls, + app_model: App, + conversation_id: str, + user: Optional[Union[Account, EndUser]], + name: str, + auto_generate: bool, + ): conversation = cls.get_conversation(app_model, conversation_id, user) if auto_generate: @@ -103,11 +112,12 @@ class ConversationService: @classmethod def auto_generate_name(cls, app_model: App, conversation: Conversation): # get conversation first message - message = db.session.query(Message) \ - .filter( - Message.app_id == app_model.id, - Message.conversation_id == conversation.id - ).order_by(Message.created_at.asc()).first() + message = ( + db.session.query(Message) + .filter(Message.app_id == app_model.id, Message.conversation_id == conversation.id) + .order_by(Message.created_at.asc()) + .first() + ) if not message: raise MessageNotExistsError() @@ -127,15 +137,18 @@ class ConversationService: @classmethod def get_conversation(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]): - conversation = db.session.query(Conversation) \ + conversation = ( + db.session.query(Conversation) .filter( - Conversation.id == conversation_id, - Conversation.app_id == app_model.id, - Conversation.from_source == ('api' if isinstance(user, EndUser) else 'console'), - Conversation.from_end_user_id == (user.id if isinstance(user, EndUser) else None), - Conversation.from_account_id == (user.id if isinstance(user, Account) else None), - Conversation.is_deleted == False - ).first() + Conversation.id == conversation_id, + Conversation.app_id == app_model.id, + Conversation.from_source == ("api" if isinstance(user, EndUser) else "console"), + Conversation.from_end_user_id == (user.id if isinstance(user, EndUser) else None), + Conversation.from_account_id == (user.id if isinstance(user, Account) else None), + Conversation.is_deleted == False, + ) + .first() + ) if not conversation: raise ConversationNotExistsError() diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index d54701486..8649d0fea 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -55,7 +55,6 @@ from tasks.sync_website_document_indexing_task import sync_website_document_inde class DatasetService: - @staticmethod def get_datasets(page, per_page, provider="vendor", tenant_id=None, user=None, search=None, tag_ids=None): query = Dataset.query.filter(Dataset.provider == provider, Dataset.tenant_id == tenant_id).order_by( @@ -64,10 +63,7 @@ class DatasetService: if user: # get permitted dataset ids - dataset_permission = DatasetPermission.query.filter_by( - account_id=user.id, - tenant_id=tenant_id - ).all() + dataset_permission = DatasetPermission.query.filter_by(account_id=user.id, tenant_id=tenant_id).all() permitted_dataset_ids = {dp.dataset_id for dp in dataset_permission} if dataset_permission else None if user.current_role == TenantAccountRole.DATASET_OPERATOR: @@ -83,14 +79,17 @@ class DatasetService: db.or_( Dataset.permission == DatasetPermissionEnum.ALL_TEAM, db.and_(Dataset.permission == DatasetPermissionEnum.ONLY_ME, Dataset.created_by == user.id), - db.and_(Dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM, Dataset.id.in_(permitted_dataset_ids)) + db.and_( + Dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM, + Dataset.id.in_(permitted_dataset_ids), + ), ) ) else: query = query.filter( db.or_( Dataset.permission == DatasetPermissionEnum.ALL_TEAM, - db.and_(Dataset.permission == DatasetPermissionEnum.ONLY_ME, Dataset.created_by == user.id) + db.and_(Dataset.permission == DatasetPermissionEnum.ONLY_ME, Dataset.created_by == user.id), ) ) else: @@ -98,49 +97,40 @@ class DatasetService: query = query.filter(Dataset.permission == DatasetPermissionEnum.ALL_TEAM) if search: - query = query.filter(Dataset.name.ilike(f'%{search}%')) + query = query.filter(Dataset.name.ilike(f"%{search}%")) if tag_ids: - target_ids = TagService.get_target_ids_by_tag_ids('knowledge', tenant_id, tag_ids) + target_ids = TagService.get_target_ids_by_tag_ids("knowledge", tenant_id, tag_ids) if target_ids: query = query.filter(Dataset.id.in_(target_ids)) else: return [], 0 - datasets = query.paginate( - page=page, - per_page=per_page, - max_per_page=100, - error_out=False - ) + datasets = query.paginate(page=page, per_page=per_page, max_per_page=100, error_out=False) return datasets.items, datasets.total @staticmethod def get_process_rules(dataset_id): # get the latest process rule - dataset_process_rule = db.session.query(DatasetProcessRule). \ - filter(DatasetProcessRule.dataset_id == dataset_id). \ - order_by(DatasetProcessRule.created_at.desc()). \ - limit(1). \ - one_or_none() + dataset_process_rule = ( + db.session.query(DatasetProcessRule) + .filter(DatasetProcessRule.dataset_id == dataset_id) + .order_by(DatasetProcessRule.created_at.desc()) + .limit(1) + .one_or_none() + ) if dataset_process_rule: mode = dataset_process_rule.mode rules = dataset_process_rule.rules_dict else: - mode = DocumentService.DEFAULT_RULES['mode'] - rules = DocumentService.DEFAULT_RULES['rules'] - return { - 'mode': mode, - 'rules': rules - } + mode = DocumentService.DEFAULT_RULES["mode"] + rules = DocumentService.DEFAULT_RULES["rules"] + return {"mode": mode, "rules": rules} @staticmethod def get_datasets_by_ids(ids, tenant_id): - datasets = Dataset.query.filter( - Dataset.id.in_(ids), - Dataset.tenant_id == tenant_id - ).paginate( + datasets = Dataset.query.filter(Dataset.id.in_(ids), Dataset.tenant_id == tenant_id).paginate( page=1, per_page=len(ids), max_per_page=len(ids), error_out=False ) return datasets.items, datasets.total @@ -149,15 +139,12 @@ class DatasetService: def create_empty_dataset(tenant_id: str, name: str, indexing_technique: Optional[str], account: Account): # check if dataset name already exists if Dataset.query.filter_by(name=name, tenant_id=tenant_id).first(): - raise DatasetNameDuplicateError( - f'Dataset with name {name} already exists.' - ) + raise DatasetNameDuplicateError(f"Dataset with name {name} already exists.") embedding_model = None - if indexing_technique == 'high_quality': + if indexing_technique == "high_quality": model_manager = ModelManager() embedding_model = model_manager.get_default_model_instance( - tenant_id=tenant_id, - model_type=ModelType.TEXT_EMBEDDING + tenant_id=tenant_id, model_type=ModelType.TEXT_EMBEDDING ) dataset = Dataset(name=name, indexing_technique=indexing_technique) # dataset = Dataset(name=name, provider=provider, config=config) @@ -172,20 +159,18 @@ class DatasetService: @staticmethod def get_dataset(dataset_id): - return Dataset.query.filter_by( - id=dataset_id - ).first() + return Dataset.query.filter_by(id=dataset_id).first() @staticmethod def check_dataset_model_setting(dataset): - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": try: model_manager = ModelManager() model_manager.get_model_instance( tenant_id=dataset.tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model + model=dataset.embedding_model, ) except LLMBadRequestError: raise ValueError( @@ -193,65 +178,56 @@ class DatasetService: "in the Settings -> Model Provider." ) except ProviderTokenNotInitError as ex: - raise ValueError( - f"The dataset in unavailable, due to: " - f"{ex.description}" - ) + raise ValueError(f"The dataset in unavailable, due to: " f"{ex.description}") @staticmethod - def check_embedding_model_setting(tenant_id: str, embedding_model_provider: str, embedding_model:str): + def check_embedding_model_setting(tenant_id: str, embedding_model_provider: str, embedding_model: str): try: model_manager = ModelManager() model_manager.get_model_instance( tenant_id=tenant_id, provider=embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=embedding_model + model=embedding_model, ) except LLMBadRequestError: raise ValueError( - "No Embedding Model available. Please configure a valid provider " - "in the Settings -> Model Provider." + "No Embedding Model available. Please configure a valid provider " "in the Settings -> Model Provider." ) except ProviderTokenNotInitError as ex: - raise ValueError( - f"The dataset in unavailable, due to: " - f"{ex.description}" - ) - + raise ValueError(f"The dataset in unavailable, due to: " f"{ex.description}") @staticmethod def update_dataset(dataset_id, data, user): - data.pop('partial_member_list', None) - filtered_data = {k: v for k, v in data.items() if v is not None or k == 'description'} + data.pop("partial_member_list", None) + filtered_data = {k: v for k, v in data.items() if v is not None or k == "description"} dataset = DatasetService.get_dataset(dataset_id) DatasetService.check_dataset_permission(dataset, user) action = None - if dataset.indexing_technique != data['indexing_technique']: + if dataset.indexing_technique != data["indexing_technique"]: # if update indexing_technique - if data['indexing_technique'] == 'economy': - action = 'remove' - filtered_data['embedding_model'] = None - filtered_data['embedding_model_provider'] = None - filtered_data['collection_binding_id'] = None - elif data['indexing_technique'] == 'high_quality': - action = 'add' + if data["indexing_technique"] == "economy": + action = "remove" + filtered_data["embedding_model"] = None + filtered_data["embedding_model_provider"] = None + filtered_data["collection_binding_id"] = None + elif data["indexing_technique"] == "high_quality": + action = "add" # get embedding model setting try: model_manager = ModelManager() embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, - provider=data['embedding_model_provider'], + provider=data["embedding_model_provider"], model_type=ModelType.TEXT_EMBEDDING, - model=data['embedding_model'] + model=data["embedding_model"], ) - filtered_data['embedding_model'] = embedding_model.model - filtered_data['embedding_model_provider'] = embedding_model.provider + filtered_data["embedding_model"] = embedding_model.model + filtered_data["embedding_model_provider"] = embedding_model.provider dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( - embedding_model.provider, - embedding_model.model + embedding_model.provider, embedding_model.model ) - filtered_data['collection_binding_id'] = dataset_collection_binding.id + filtered_data["collection_binding_id"] = dataset_collection_binding.id except LLMBadRequestError: raise ValueError( "No Embedding Model available. Please configure a valid provider " @@ -260,24 +236,25 @@ class DatasetService: except ProviderTokenNotInitError as ex: raise ValueError(ex.description) else: - if data['embedding_model_provider'] != dataset.embedding_model_provider or \ - data['embedding_model'] != dataset.embedding_model: - action = 'update' + if ( + data["embedding_model_provider"] != dataset.embedding_model_provider + or data["embedding_model"] != dataset.embedding_model + ): + action = "update" try: model_manager = ModelManager() embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, - provider=data['embedding_model_provider'], + provider=data["embedding_model_provider"], model_type=ModelType.TEXT_EMBEDDING, - model=data['embedding_model'] + model=data["embedding_model"], ) - filtered_data['embedding_model'] = embedding_model.model - filtered_data['embedding_model_provider'] = embedding_model.provider + filtered_data["embedding_model"] = embedding_model.model + filtered_data["embedding_model_provider"] = embedding_model.provider dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( - embedding_model.provider, - embedding_model.model + embedding_model.provider, embedding_model.model ) - filtered_data['collection_binding_id'] = dataset_collection_binding.id + filtered_data["collection_binding_id"] = dataset_collection_binding.id except LLMBadRequestError: raise ValueError( "No Embedding Model available. Please configure a valid provider " @@ -286,11 +263,11 @@ class DatasetService: except ProviderTokenNotInitError as ex: raise ValueError(ex.description) - filtered_data['updated_by'] = user.id - filtered_data['updated_at'] = datetime.datetime.now() + filtered_data["updated_by"] = user.id + filtered_data["updated_at"] = datetime.datetime.now() # update Retrieval model - filtered_data['retrieval_model'] = data['retrieval_model'] + filtered_data["retrieval_model"] = data["retrieval_model"] dataset.query.filter_by(id=dataset_id).update(filtered_data) @@ -301,7 +278,6 @@ class DatasetService: @staticmethod def delete_dataset(dataset_id, user): - dataset = DatasetService.get_dataset(dataset_id) if dataset is None: @@ -325,72 +301,57 @@ class DatasetService: @staticmethod def check_dataset_permission(dataset, user): if dataset.tenant_id != user.current_tenant_id: - logging.debug( - f'User {user.id} does not have permission to access dataset {dataset.id}' - ) - raise NoPermissionError( - 'You do not have permission to access this dataset.' - ) + logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}") + raise NoPermissionError("You do not have permission to access this dataset.") if dataset.permission == DatasetPermissionEnum.ONLY_ME and dataset.created_by != user.id: - logging.debug( - f'User {user.id} does not have permission to access dataset {dataset.id}' - ) - raise NoPermissionError( - 'You do not have permission to access this dataset.' - ) - if dataset.permission == 'partial_members': - user_permission = DatasetPermission.query.filter_by( - dataset_id=dataset.id, account_id=user.id - ).first() + logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}") + raise NoPermissionError("You do not have permission to access this dataset.") + if dataset.permission == "partial_members": + user_permission = DatasetPermission.query.filter_by(dataset_id=dataset.id, account_id=user.id).first() if not user_permission and dataset.tenant_id != user.current_tenant_id and dataset.created_by != user.id: - logging.debug( - f'User {user.id} does not have permission to access dataset {dataset.id}' - ) - raise NoPermissionError( - 'You do not have permission to access this dataset.' - ) + logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}") + raise NoPermissionError("You do not have permission to access this dataset.") @staticmethod def check_dataset_operator_permission(user: Account = None, dataset: Dataset = None): if dataset.permission == DatasetPermissionEnum.ONLY_ME: if dataset.created_by != user.id: - raise NoPermissionError('You do not have permission to access this dataset.') + raise NoPermissionError("You do not have permission to access this dataset.") elif dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM: if not any( dp.dataset_id == dataset.id for dp in DatasetPermission.query.filter_by(account_id=user.id).all() ): - raise NoPermissionError('You do not have permission to access this dataset.') + raise NoPermissionError("You do not have permission to access this dataset.") @staticmethod def get_dataset_queries(dataset_id: str, page: int, per_page: int): - dataset_queries = DatasetQuery.query.filter_by(dataset_id=dataset_id) \ - .order_by(db.desc(DatasetQuery.created_at)) \ - .paginate( - page=page, per_page=per_page, max_per_page=100, error_out=False + dataset_queries = ( + DatasetQuery.query.filter_by(dataset_id=dataset_id) + .order_by(db.desc(DatasetQuery.created_at)) + .paginate(page=page, per_page=per_page, max_per_page=100, error_out=False) ) return dataset_queries.items, dataset_queries.total @staticmethod def get_related_apps(dataset_id: str): - return AppDatasetJoin.query.filter(AppDatasetJoin.dataset_id == dataset_id) \ - .order_by(db.desc(AppDatasetJoin.created_at)).all() + return ( + AppDatasetJoin.query.filter(AppDatasetJoin.dataset_id == dataset_id) + .order_by(db.desc(AppDatasetJoin.created_at)) + .all() + ) class DocumentService: DEFAULT_RULES = { - 'mode': 'custom', - 'rules': { - 'pre_processing_rules': [ - {'id': 'remove_extra_spaces', 'enabled': True}, - {'id': 'remove_urls_emails', 'enabled': False} + "mode": "custom", + "rules": { + "pre_processing_rules": [ + {"id": "remove_extra_spaces", "enabled": True}, + {"id": "remove_urls_emails", "enabled": False}, ], - 'segmentation': { - 'delimiter': '\n', - 'max_tokens': 500, - 'chunk_overlap': 50 - } - } + "segmentation": {"delimiter": "\n", "max_tokens": 500, "chunk_overlap": 50}, + }, } DOCUMENT_METADATA_SCHEMA = { @@ -483,58 +444,55 @@ class DocumentService: "commit_date": str, "commit_author": str, }, - "others": dict + "others": dict, } @staticmethod def get_document(dataset_id: str, document_id: str) -> Optional[Document]: - document = db.session.query(Document).filter( - Document.id == document_id, - Document.dataset_id == dataset_id - ).first() + document = ( + db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first() + ) return document @staticmethod def get_document_by_id(document_id: str) -> Optional[Document]: - document = db.session.query(Document).filter( - Document.id == document_id - ).first() + document = db.session.query(Document).filter(Document.id == document_id).first() return document @staticmethod def get_document_by_dataset_id(dataset_id: str) -> list[Document]: - documents = db.session.query(Document).filter( - Document.dataset_id == dataset_id, - Document.enabled == True - ).all() + documents = db.session.query(Document).filter(Document.dataset_id == dataset_id, Document.enabled == True).all() return documents @staticmethod def get_error_documents_by_dataset_id(dataset_id: str) -> list[Document]: - documents = db.session.query(Document).filter( - Document.dataset_id == dataset_id, - Document.indexing_status.in_(['error', 'paused']) - ).all() + documents = ( + db.session.query(Document) + .filter(Document.dataset_id == dataset_id, Document.indexing_status.in_(["error", "paused"])) + .all() + ) return documents @staticmethod def get_batch_documents(dataset_id: str, batch: str) -> list[Document]: - documents = db.session.query(Document).filter( - Document.batch == batch, - Document.dataset_id == dataset_id, - Document.tenant_id == current_user.current_tenant_id - ).all() + documents = ( + db.session.query(Document) + .filter( + Document.batch == batch, + Document.dataset_id == dataset_id, + Document.tenant_id == current_user.current_tenant_id, + ) + .all() + ) return documents @staticmethod def get_document_file_detail(file_id: str): - file_detail = db.session.query(UploadFile). \ - filter(UploadFile.id == file_id). \ - one_or_none() + file_detail = db.session.query(UploadFile).filter(UploadFile.id == file_id).one_or_none() return file_detail @staticmethod @@ -548,13 +506,14 @@ class DocumentService: def delete_document(document): # trigger document_was_deleted signal file_id = None - if document.data_source_type == 'upload_file': + if document.data_source_type == "upload_file": if document.data_source_info: data_source_info = document.data_source_info_dict - if data_source_info and 'upload_file_id' in data_source_info: - file_id = data_source_info['upload_file_id'] - document_was_deleted.send(document.id, dataset_id=document.dataset_id, - doc_form=document.doc_form, file_id=file_id) + if data_source_info and "upload_file_id" in data_source_info: + file_id = data_source_info["upload_file_id"] + document_was_deleted.send( + document.id, dataset_id=document.dataset_id, doc_form=document.doc_form, file_id=file_id + ) db.session.delete(document) db.session.commit() @@ -563,15 +522,15 @@ class DocumentService: def rename_document(dataset_id: str, document_id: str, name: str) -> Document: dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise ValueError('Dataset not found.') + raise ValueError("Dataset not found.") document = DocumentService.get_document(dataset_id, document_id) if not document: - raise ValueError('Document not found.') + raise ValueError("Document not found.") if document.tenant_id != current_user.current_tenant_id: - raise ValueError('No permission.') + raise ValueError("No permission.") document.name = name @@ -592,7 +551,7 @@ class DocumentService: db.session.add(document) db.session.commit() # set document paused flag - indexing_cache_key = 'document_{}_is_paused'.format(document.id) + indexing_cache_key = "document_{}_is_paused".format(document.id) redis_client.setnx(indexing_cache_key, "True") @staticmethod @@ -607,7 +566,7 @@ class DocumentService: db.session.add(document) db.session.commit() # delete paused flag - indexing_cache_key = 'document_{}_is_paused'.format(document.id) + indexing_cache_key = "document_{}_is_paused".format(document.id) redis_client.delete(indexing_cache_key) # trigger async task recover_document_indexing_task.delay(document.dataset_id, document.id) @@ -616,12 +575,12 @@ class DocumentService: def retry_document(dataset_id: str, documents: list[Document]): for document in documents: # add retry flag - retry_indexing_cache_key = 'document_{}_is_retried'.format(document.id) + retry_indexing_cache_key = "document_{}_is_retried".format(document.id) cache_result = redis_client.get(retry_indexing_cache_key) if cache_result is not None: raise ValueError("Document is being retried, please try again later") # retry document indexing - document.indexing_status = 'waiting' + document.indexing_status = "waiting" db.session.add(document) db.session.commit() @@ -633,14 +592,14 @@ class DocumentService: @staticmethod def sync_website_document(dataset_id: str, document: Document): # add sync flag - sync_indexing_cache_key = 'document_{}_is_sync'.format(document.id) + sync_indexing_cache_key = "document_{}_is_sync".format(document.id) cache_result = redis_client.get(sync_indexing_cache_key) if cache_result is not None: raise ValueError("Document is being synced, please try again later") # sync document indexing - document.indexing_status = 'waiting' + document.indexing_status = "waiting" data_source_info = document.data_source_info_dict - data_source_info['mode'] = 'scrape' + data_source_info["mode"] = "scrape" document.data_source_info = json.dumps(data_source_info, ensure_ascii=False) db.session.add(document) db.session.commit() @@ -659,27 +618,28 @@ class DocumentService: @staticmethod def save_document_with_dataset_id( - dataset: Dataset, document_data: dict, - account: Account, dataset_process_rule: Optional[DatasetProcessRule] = None, - created_from: str = 'web' + dataset: Dataset, + document_data: dict, + account: Account, + dataset_process_rule: Optional[DatasetProcessRule] = None, + created_from: str = "web", ): - # check document limit features = FeatureService.get_features(current_user.current_tenant_id) if features.billing.enabled: - if 'original_document_id' not in document_data or not document_data['original_document_id']: + if "original_document_id" not in document_data or not document_data["original_document_id"]: count = 0 if document_data["data_source"]["type"] == "upload_file": - upload_file_list = document_data["data_source"]["info_list"]['file_info_list']['file_ids'] + upload_file_list = document_data["data_source"]["info_list"]["file_info_list"]["file_ids"] count = len(upload_file_list) elif document_data["data_source"]["type"] == "notion_import": - notion_info_list = document_data["data_source"]['info_list']['notion_info_list'] + notion_info_list = document_data["data_source"]["info_list"]["notion_info_list"] for notion_info in notion_info_list: - count = count + len(notion_info['pages']) + count = count + len(notion_info["pages"]) elif document_data["data_source"]["type"] == "website_crawl": - website_info = document_data["data_source"]['info_list']['website_info_list'] - count = len(website_info['urls']) + website_info = document_data["data_source"]["info_list"]["website_info_list"] + count = len(website_info["urls"]) batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT) if count > batch_upload_limit: raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") @@ -691,42 +651,41 @@ class DocumentService: dataset.data_source_type = document_data["data_source"]["type"] if not dataset.indexing_technique: - if 'indexing_technique' not in document_data \ - or document_data['indexing_technique'] not in Dataset.INDEXING_TECHNIQUE_LIST: + if ( + "indexing_technique" not in document_data + or document_data["indexing_technique"] not in Dataset.INDEXING_TECHNIQUE_LIST + ): raise ValueError("Indexing technique is required") dataset.indexing_technique = document_data["indexing_technique"] - if document_data["indexing_technique"] == 'high_quality': + if document_data["indexing_technique"] == "high_quality": model_manager = ModelManager() embedding_model = model_manager.get_default_model_instance( - tenant_id=current_user.current_tenant_id, - model_type=ModelType.TEXT_EMBEDDING + tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING ) dataset.embedding_model = embedding_model.model dataset.embedding_model_provider = embedding_model.provider dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( - embedding_model.provider, - embedding_model.model + embedding_model.provider, embedding_model.model ) dataset.collection_binding_id = dataset_collection_binding.id if not dataset.retrieval_model: default_retrieval_model = { - 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, - 'reranking_enable': False, - 'reranking_model': { - 'reranking_provider_name': '', - 'reranking_model_name': '' - }, - 'top_k': 2, - 'score_threshold_enabled': False + "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "reranking_enable": False, + "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, + "top_k": 2, + "score_threshold_enabled": False, } - dataset.retrieval_model = document_data.get('retrieval_model') if document_data.get( - 'retrieval_model' - ) else default_retrieval_model + dataset.retrieval_model = ( + document_data.get("retrieval_model") + if document_data.get("retrieval_model") + else default_retrieval_model + ) documents = [] - batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999)) + batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999)) if document_data.get("original_document_id"): document = DocumentService.update_document_with_dataset_id(dataset, document_data, account) documents.append(document) @@ -739,14 +698,14 @@ class DocumentService: dataset_id=dataset.id, mode=process_rule["mode"], rules=json.dumps(process_rule["rules"]), - created_by=account.id + created_by=account.id, ) elif process_rule["mode"] == "automatic": dataset_process_rule = DatasetProcessRule( dataset_id=dataset.id, mode=process_rule["mode"], rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES), - created_by=account.id + created_by=account.id, ) db.session.add(dataset_process_rule) db.session.commit() @@ -754,12 +713,13 @@ class DocumentService: document_ids = [] duplicate_document_ids = [] if document_data["data_source"]["type"] == "upload_file": - upload_file_list = document_data["data_source"]["info_list"]['file_info_list']['file_ids'] + upload_file_list = document_data["data_source"]["info_list"]["file_info_list"]["file_ids"] for file_id in upload_file_list: - file = db.session.query(UploadFile).filter( - UploadFile.tenant_id == dataset.tenant_id, - UploadFile.id == file_id - ).first() + file = ( + db.session.query(UploadFile) + .filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id) + .first() + ) # raise error if file not found if not file: @@ -770,34 +730,39 @@ class DocumentService: "upload_file_id": file_id, } # check duplicate - if document_data.get('duplicate', False): + if document_data.get("duplicate", False): document = Document.query.filter_by( dataset_id=dataset.id, tenant_id=current_user.current_tenant_id, - data_source_type='upload_file', + data_source_type="upload_file", enabled=True, - name=file_name + name=file_name, ).first() if document: document.dataset_process_rule_id = dataset_process_rule.id document.updated_at = datetime.datetime.utcnow() document.created_from = created_from - document.doc_form = document_data['doc_form'] - document.doc_language = document_data['doc_language'] + document.doc_form = document_data["doc_form"] + document.doc_language = document_data["doc_language"] document.data_source_info = json.dumps(data_source_info) document.batch = batch - document.indexing_status = 'waiting' + document.indexing_status = "waiting" db.session.add(document) documents.append(document) duplicate_document_ids.append(document.id) continue document = DocumentService.build_document( - dataset, dataset_process_rule.id, + dataset, + dataset_process_rule.id, document_data["data_source"]["type"], document_data["doc_form"], document_data["doc_language"], - data_source_info, created_from, position, - account, file_name, batch + data_source_info, + created_from, + position, + account, + file_name, + batch, ) db.session.add(document) db.session.flush() @@ -805,47 +770,52 @@ class DocumentService: documents.append(document) position += 1 elif document_data["data_source"]["type"] == "notion_import": - notion_info_list = document_data["data_source"]['info_list']['notion_info_list'] + notion_info_list = document_data["data_source"]["info_list"]["notion_info_list"] exist_page_ids = [] exist_document = {} documents = Document.query.filter_by( dataset_id=dataset.id, tenant_id=current_user.current_tenant_id, - data_source_type='notion_import', - enabled=True + data_source_type="notion_import", + enabled=True, ).all() if documents: for document in documents: data_source_info = json.loads(document.data_source_info) - exist_page_ids.append(data_source_info['notion_page_id']) - exist_document[data_source_info['notion_page_id']] = document.id + exist_page_ids.append(data_source_info["notion_page_id"]) + exist_document[data_source_info["notion_page_id"]] = document.id for notion_info in notion_info_list: - workspace_id = notion_info['workspace_id'] + workspace_id = notion_info["workspace_id"] data_source_binding = DataSourceOauthBinding.query.filter( db.and_( DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, - DataSourceOauthBinding.provider == 'notion', + DataSourceOauthBinding.provider == "notion", DataSourceOauthBinding.disabled == False, - DataSourceOauthBinding.source_info['workspace_id'] == f'"{workspace_id}"' + DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"', ) ).first() if not data_source_binding: - raise ValueError('Data source binding not found.') - for page in notion_info['pages']: - if page['page_id'] not in exist_page_ids: + raise ValueError("Data source binding not found.") + for page in notion_info["pages"]: + if page["page_id"] not in exist_page_ids: data_source_info = { "notion_workspace_id": workspace_id, - "notion_page_id": page['page_id'], - "notion_page_icon": page['page_icon'], - "type": page['type'] + "notion_page_id": page["page_id"], + "notion_page_icon": page["page_icon"], + "type": page["type"], } document = DocumentService.build_document( - dataset, dataset_process_rule.id, + dataset, + dataset_process_rule.id, document_data["data_source"]["type"], document_data["doc_form"], document_data["doc_language"], - data_source_info, created_from, position, - account, page['page_name'], batch + data_source_info, + created_from, + position, + account, + page["page_name"], + batch, ) db.session.add(document) db.session.flush() @@ -853,32 +823,37 @@ class DocumentService: documents.append(document) position += 1 else: - exist_document.pop(page['page_id']) + exist_document.pop(page["page_id"]) # delete not selected documents if len(exist_document) > 0: clean_notion_document_task.delay(list(exist_document.values()), dataset.id) elif document_data["data_source"]["type"] == "website_crawl": - website_info = document_data["data_source"]['info_list']['website_info_list'] - urls = website_info['urls'] + website_info = document_data["data_source"]["info_list"]["website_info_list"] + urls = website_info["urls"] for url in urls: data_source_info = { - 'url': url, - 'provider': website_info['provider'], - 'job_id': website_info['job_id'], - 'only_main_content': website_info.get('only_main_content', False), - 'mode': 'crawl', + "url": url, + "provider": website_info["provider"], + "job_id": website_info["job_id"], + "only_main_content": website_info.get("only_main_content", False), + "mode": "crawl", } if len(url) > 255: - document_name = url[:200] + '...' + document_name = url[:200] + "..." else: document_name = url document = DocumentService.build_document( - dataset, dataset_process_rule.id, + dataset, + dataset_process_rule.id, document_data["data_source"]["type"], document_data["doc_form"], document_data["doc_language"], - data_source_info, created_from, position, - account, document_name, batch + data_source_info, + created_from, + position, + account, + document_name, + batch, ) db.session.add(document) db.session.flush() @@ -900,15 +875,22 @@ class DocumentService: can_upload_size = features.documents_upload_quota.limit - features.documents_upload_quota.size if count > can_upload_size: raise ValueError( - f'You have reached the limit of your subscription. Only {can_upload_size} documents can be uploaded.' + f"You have reached the limit of your subscription. Only {can_upload_size} documents can be uploaded." ) @staticmethod def build_document( - dataset: Dataset, process_rule_id: str, data_source_type: str, document_form: str, - document_language: str, data_source_info: dict, created_from: str, position: int, + dataset: Dataset, + process_rule_id: str, + data_source_type: str, + document_form: str, + document_language: str, + data_source_info: dict, + created_from: str, + position: int, account: Account, - name: str, batch: str + name: str, + batch: str, ): document = Document( tenant_id=dataset.tenant_id, @@ -922,7 +904,7 @@ class DocumentService: created_from=created_from, created_by=account.id, doc_form=document_form, - doc_language=document_language + doc_language=document_language, ) return document @@ -932,54 +914,57 @@ class DocumentService: Document.completed_at.isnot(None), Document.enabled == True, Document.archived == False, - Document.tenant_id == current_user.current_tenant_id + Document.tenant_id == current_user.current_tenant_id, ).count() return documents_count @staticmethod def update_document_with_dataset_id( - dataset: Dataset, document_data: dict, - account: Account, dataset_process_rule: Optional[DatasetProcessRule] = None, - created_from: str = 'web' + dataset: Dataset, + document_data: dict, + account: Account, + dataset_process_rule: Optional[DatasetProcessRule] = None, + created_from: str = "web", ): DatasetService.check_dataset_model_setting(dataset) document = DocumentService.get_document(dataset.id, document_data["original_document_id"]) - if document.display_status != 'available': + if document.display_status != "available": raise ValueError("Document is not available") # update document name - if document_data.get('name'): - document.name = document_data['name'] + if document_data.get("name"): + document.name = document_data["name"] # save process rule - if document_data.get('process_rule'): + if document_data.get("process_rule"): process_rule = document_data["process_rule"] if process_rule["mode"] == "custom": dataset_process_rule = DatasetProcessRule( dataset_id=dataset.id, mode=process_rule["mode"], rules=json.dumps(process_rule["rules"]), - created_by=account.id + created_by=account.id, ) elif process_rule["mode"] == "automatic": dataset_process_rule = DatasetProcessRule( dataset_id=dataset.id, mode=process_rule["mode"], rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES), - created_by=account.id + created_by=account.id, ) db.session.add(dataset_process_rule) db.session.commit() document.dataset_process_rule_id = dataset_process_rule.id # update document data source - if document_data.get('data_source'): - file_name = '' + if document_data.get("data_source"): + file_name = "" data_source_info = {} if document_data["data_source"]["type"] == "upload_file": - upload_file_list = document_data["data_source"]["info_list"]['file_info_list']['file_ids'] + upload_file_list = document_data["data_source"]["info_list"]["file_info_list"]["file_ids"] for file_id in upload_file_list: - file = db.session.query(UploadFile).filter( - UploadFile.tenant_id == dataset.tenant_id, - UploadFile.id == file_id - ).first() + file = ( + db.session.query(UploadFile) + .filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id) + .first() + ) # raise error if file not found if not file: @@ -990,42 +975,42 @@ class DocumentService: "upload_file_id": file_id, } elif document_data["data_source"]["type"] == "notion_import": - notion_info_list = document_data["data_source"]['info_list']['notion_info_list'] + notion_info_list = document_data["data_source"]["info_list"]["notion_info_list"] for notion_info in notion_info_list: - workspace_id = notion_info['workspace_id'] + workspace_id = notion_info["workspace_id"] data_source_binding = DataSourceOauthBinding.query.filter( db.and_( DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, - DataSourceOauthBinding.provider == 'notion', + DataSourceOauthBinding.provider == "notion", DataSourceOauthBinding.disabled == False, - DataSourceOauthBinding.source_info['workspace_id'] == f'"{workspace_id}"' + DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"', ) ).first() if not data_source_binding: - raise ValueError('Data source binding not found.') - for page in notion_info['pages']: + raise ValueError("Data source binding not found.") + for page in notion_info["pages"]: data_source_info = { "notion_workspace_id": workspace_id, - "notion_page_id": page['page_id'], - "notion_page_icon": page['page_icon'], - "type": page['type'] + "notion_page_id": page["page_id"], + "notion_page_icon": page["page_icon"], + "type": page["type"], } elif document_data["data_source"]["type"] == "website_crawl": - website_info = document_data["data_source"]['info_list']['website_info_list'] - urls = website_info['urls'] + website_info = document_data["data_source"]["info_list"]["website_info_list"] + urls = website_info["urls"] for url in urls: data_source_info = { - 'url': url, - 'provider': website_info['provider'], - 'job_id': website_info['job_id'], - 'only_main_content': website_info.get('only_main_content', False), - 'mode': 'crawl', + "url": url, + "provider": website_info["provider"], + "job_id": website_info["job_id"], + "only_main_content": website_info.get("only_main_content", False), + "mode": "crawl", } document.data_source_type = document_data["data_source"]["type"] document.data_source_info = json.dumps(data_source_info) document.name = file_name # update document to be waiting - document.indexing_status = 'waiting' + document.indexing_status = "waiting" document.completed_at = None document.processing_started_at = None document.parsing_completed_at = None @@ -1033,13 +1018,11 @@ class DocumentService: document.splitting_completed_at = None document.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) document.created_from = created_from - document.doc_form = document_data['doc_form'] + document.doc_form = document_data["doc_form"] db.session.add(document) db.session.commit() # update document segment - update_params = { - DocumentSegment.status: 're_segment' - } + update_params = {DocumentSegment.status: "re_segment"} DocumentSegment.query.filter_by(document_id=document.id).update(update_params) db.session.commit() # trigger async task @@ -1053,15 +1036,15 @@ class DocumentService: if features.billing.enabled: count = 0 if document_data["data_source"]["type"] == "upload_file": - upload_file_list = document_data["data_source"]["info_list"]['file_info_list']['file_ids'] + upload_file_list = document_data["data_source"]["info_list"]["file_info_list"]["file_ids"] count = len(upload_file_list) elif document_data["data_source"]["type"] == "notion_import": - notion_info_list = document_data["data_source"]['info_list']['notion_info_list'] + notion_info_list = document_data["data_source"]["info_list"]["notion_info_list"] for notion_info in notion_info_list: - count = count + len(notion_info['pages']) + count = count + len(notion_info["pages"]) elif document_data["data_source"]["type"] == "website_crawl": - website_info = document_data["data_source"]['info_list']['website_info_list'] - count = len(website_info['urls']) + website_info = document_data["data_source"]["info_list"]["website_info_list"] + count = len(website_info["urls"]) batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT) if count > batch_upload_limit: raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.") @@ -1071,42 +1054,37 @@ class DocumentService: embedding_model = None dataset_collection_binding_id = None retrieval_model = None - if document_data['indexing_technique'] == 'high_quality': + if document_data["indexing_technique"] == "high_quality": model_manager = ModelManager() embedding_model = model_manager.get_default_model_instance( - tenant_id=current_user.current_tenant_id, - model_type=ModelType.TEXT_EMBEDDING + tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING ) dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( - embedding_model.provider, - embedding_model.model + embedding_model.provider, embedding_model.model ) dataset_collection_binding_id = dataset_collection_binding.id - if document_data.get('retrieval_model'): - retrieval_model = document_data['retrieval_model'] + if document_data.get("retrieval_model"): + retrieval_model = document_data["retrieval_model"] else: default_retrieval_model = { - 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, - 'reranking_enable': False, - 'reranking_model': { - 'reranking_provider_name': '', - 'reranking_model_name': '' - }, - 'top_k': 2, - 'score_threshold_enabled': False + "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "reranking_enable": False, + "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, + "top_k": 2, + "score_threshold_enabled": False, } retrieval_model = default_retrieval_model # save dataset dataset = Dataset( tenant_id=tenant_id, - name='', + name="", data_source_type=document_data["data_source"]["type"], indexing_technique=document_data["indexing_technique"], created_by=account.id, embedding_model=embedding_model.model if embedding_model else None, embedding_model_provider=embedding_model.provider if embedding_model else None, collection_binding_id=dataset_collection_binding_id, - retrieval_model=retrieval_model + retrieval_model=retrieval_model, ) db.session.add(dataset) @@ -1116,236 +1094,259 @@ class DocumentService: cut_length = 18 cut_name = documents[0].name[:cut_length] - dataset.name = cut_name + '...' - dataset.description = 'useful for when you want to answer queries about the ' + documents[0].name + dataset.name = cut_name + "..." + dataset.description = "useful for when you want to answer queries about the " + documents[0].name db.session.commit() return dataset, documents, batch @classmethod def document_create_args_validate(cls, args: dict): - if 'original_document_id' not in args or not args['original_document_id']: + if "original_document_id" not in args or not args["original_document_id"]: DocumentService.data_source_args_validate(args) DocumentService.process_rule_args_validate(args) else: - if ('data_source' not in args and not args['data_source']) \ - and ('process_rule' not in args and not args['process_rule']): + if ("data_source" not in args and not args["data_source"]) and ( + "process_rule" not in args and not args["process_rule"] + ): raise ValueError("Data source or Process rule is required") else: - if args.get('data_source'): + if args.get("data_source"): DocumentService.data_source_args_validate(args) - if args.get('process_rule'): + if args.get("process_rule"): DocumentService.process_rule_args_validate(args) @classmethod def data_source_args_validate(cls, args: dict): - if 'data_source' not in args or not args['data_source']: + if "data_source" not in args or not args["data_source"]: raise ValueError("Data source is required") - if not isinstance(args['data_source'], dict): + if not isinstance(args["data_source"], dict): raise ValueError("Data source is invalid") - if 'type' not in args['data_source'] or not args['data_source']['type']: + if "type" not in args["data_source"] or not args["data_source"]["type"]: raise ValueError("Data source type is required") - if args['data_source']['type'] not in Document.DATA_SOURCES: + if args["data_source"]["type"] not in Document.DATA_SOURCES: raise ValueError("Data source type is invalid") - if 'info_list' not in args['data_source'] or not args['data_source']['info_list']: + if "info_list" not in args["data_source"] or not args["data_source"]["info_list"]: raise ValueError("Data source info is required") - if args['data_source']['type'] == 'upload_file': - if 'file_info_list' not in args['data_source']['info_list'] or not args['data_source']['info_list'][ - 'file_info_list']: + if args["data_source"]["type"] == "upload_file": + if ( + "file_info_list" not in args["data_source"]["info_list"] + or not args["data_source"]["info_list"]["file_info_list"] + ): raise ValueError("File source info is required") - if args['data_source']['type'] == 'notion_import': - if 'notion_info_list' not in args['data_source']['info_list'] or not args['data_source']['info_list'][ - 'notion_info_list']: + if args["data_source"]["type"] == "notion_import": + if ( + "notion_info_list" not in args["data_source"]["info_list"] + or not args["data_source"]["info_list"]["notion_info_list"] + ): raise ValueError("Notion source info is required") - if args['data_source']['type'] == 'website_crawl': - if 'website_info_list' not in args['data_source']['info_list'] or not args['data_source']['info_list'][ - 'website_info_list']: + if args["data_source"]["type"] == "website_crawl": + if ( + "website_info_list" not in args["data_source"]["info_list"] + or not args["data_source"]["info_list"]["website_info_list"] + ): raise ValueError("Website source info is required") @classmethod def process_rule_args_validate(cls, args: dict): - if 'process_rule' not in args or not args['process_rule']: + if "process_rule" not in args or not args["process_rule"]: raise ValueError("Process rule is required") - if not isinstance(args['process_rule'], dict): + if not isinstance(args["process_rule"], dict): raise ValueError("Process rule is invalid") - if 'mode' not in args['process_rule'] or not args['process_rule']['mode']: + if "mode" not in args["process_rule"] or not args["process_rule"]["mode"]: raise ValueError("Process rule mode is required") - if args['process_rule']['mode'] not in DatasetProcessRule.MODES: + if args["process_rule"]["mode"] not in DatasetProcessRule.MODES: raise ValueError("Process rule mode is invalid") - if args['process_rule']['mode'] == 'automatic': - args['process_rule']['rules'] = {} + if args["process_rule"]["mode"] == "automatic": + args["process_rule"]["rules"] = {} else: - if 'rules' not in args['process_rule'] or not args['process_rule']['rules']: + if "rules" not in args["process_rule"] or not args["process_rule"]["rules"]: raise ValueError("Process rule rules is required") - if not isinstance(args['process_rule']['rules'], dict): + if not isinstance(args["process_rule"]["rules"], dict): raise ValueError("Process rule rules is invalid") - if 'pre_processing_rules' not in args['process_rule']['rules'] \ - or args['process_rule']['rules']['pre_processing_rules'] is None: + if ( + "pre_processing_rules" not in args["process_rule"]["rules"] + or args["process_rule"]["rules"]["pre_processing_rules"] is None + ): raise ValueError("Process rule pre_processing_rules is required") - if not isinstance(args['process_rule']['rules']['pre_processing_rules'], list): + if not isinstance(args["process_rule"]["rules"]["pre_processing_rules"], list): raise ValueError("Process rule pre_processing_rules is invalid") unique_pre_processing_rule_dicts = {} - for pre_processing_rule in args['process_rule']['rules']['pre_processing_rules']: - if 'id' not in pre_processing_rule or not pre_processing_rule['id']: + for pre_processing_rule in args["process_rule"]["rules"]["pre_processing_rules"]: + if "id" not in pre_processing_rule or not pre_processing_rule["id"]: raise ValueError("Process rule pre_processing_rules id is required") - if pre_processing_rule['id'] not in DatasetProcessRule.PRE_PROCESSING_RULES: + if pre_processing_rule["id"] not in DatasetProcessRule.PRE_PROCESSING_RULES: raise ValueError("Process rule pre_processing_rules id is invalid") - if 'enabled' not in pre_processing_rule or pre_processing_rule['enabled'] is None: + if "enabled" not in pre_processing_rule or pre_processing_rule["enabled"] is None: raise ValueError("Process rule pre_processing_rules enabled is required") - if not isinstance(pre_processing_rule['enabled'], bool): + if not isinstance(pre_processing_rule["enabled"], bool): raise ValueError("Process rule pre_processing_rules enabled is invalid") - unique_pre_processing_rule_dicts[pre_processing_rule['id']] = pre_processing_rule + unique_pre_processing_rule_dicts[pre_processing_rule["id"]] = pre_processing_rule - args['process_rule']['rules']['pre_processing_rules'] = list(unique_pre_processing_rule_dicts.values()) + args["process_rule"]["rules"]["pre_processing_rules"] = list(unique_pre_processing_rule_dicts.values()) - if 'segmentation' not in args['process_rule']['rules'] \ - or args['process_rule']['rules']['segmentation'] is None: + if ( + "segmentation" not in args["process_rule"]["rules"] + or args["process_rule"]["rules"]["segmentation"] is None + ): raise ValueError("Process rule segmentation is required") - if not isinstance(args['process_rule']['rules']['segmentation'], dict): + if not isinstance(args["process_rule"]["rules"]["segmentation"], dict): raise ValueError("Process rule segmentation is invalid") - if 'separator' not in args['process_rule']['rules']['segmentation'] \ - or not args['process_rule']['rules']['segmentation']['separator']: + if ( + "separator" not in args["process_rule"]["rules"]["segmentation"] + or not args["process_rule"]["rules"]["segmentation"]["separator"] + ): raise ValueError("Process rule segmentation separator is required") - if not isinstance(args['process_rule']['rules']['segmentation']['separator'], str): + if not isinstance(args["process_rule"]["rules"]["segmentation"]["separator"], str): raise ValueError("Process rule segmentation separator is invalid") - if 'max_tokens' not in args['process_rule']['rules']['segmentation'] \ - or not args['process_rule']['rules']['segmentation']['max_tokens']: + if ( + "max_tokens" not in args["process_rule"]["rules"]["segmentation"] + or not args["process_rule"]["rules"]["segmentation"]["max_tokens"] + ): raise ValueError("Process rule segmentation max_tokens is required") - if not isinstance(args['process_rule']['rules']['segmentation']['max_tokens'], int): + if not isinstance(args["process_rule"]["rules"]["segmentation"]["max_tokens"], int): raise ValueError("Process rule segmentation max_tokens is invalid") @classmethod def estimate_args_validate(cls, args: dict): - if 'info_list' not in args or not args['info_list']: + if "info_list" not in args or not args["info_list"]: raise ValueError("Data source info is required") - if not isinstance(args['info_list'], dict): + if not isinstance(args["info_list"], dict): raise ValueError("Data info is invalid") - if 'process_rule' not in args or not args['process_rule']: + if "process_rule" not in args or not args["process_rule"]: raise ValueError("Process rule is required") - if not isinstance(args['process_rule'], dict): + if not isinstance(args["process_rule"], dict): raise ValueError("Process rule is invalid") - if 'mode' not in args['process_rule'] or not args['process_rule']['mode']: + if "mode" not in args["process_rule"] or not args["process_rule"]["mode"]: raise ValueError("Process rule mode is required") - if args['process_rule']['mode'] not in DatasetProcessRule.MODES: + if args["process_rule"]["mode"] not in DatasetProcessRule.MODES: raise ValueError("Process rule mode is invalid") - if args['process_rule']['mode'] == 'automatic': - args['process_rule']['rules'] = {} + if args["process_rule"]["mode"] == "automatic": + args["process_rule"]["rules"] = {} else: - if 'rules' not in args['process_rule'] or not args['process_rule']['rules']: + if "rules" not in args["process_rule"] or not args["process_rule"]["rules"]: raise ValueError("Process rule rules is required") - if not isinstance(args['process_rule']['rules'], dict): + if not isinstance(args["process_rule"]["rules"], dict): raise ValueError("Process rule rules is invalid") - if 'pre_processing_rules' not in args['process_rule']['rules'] \ - or args['process_rule']['rules']['pre_processing_rules'] is None: + if ( + "pre_processing_rules" not in args["process_rule"]["rules"] + or args["process_rule"]["rules"]["pre_processing_rules"] is None + ): raise ValueError("Process rule pre_processing_rules is required") - if not isinstance(args['process_rule']['rules']['pre_processing_rules'], list): + if not isinstance(args["process_rule"]["rules"]["pre_processing_rules"], list): raise ValueError("Process rule pre_processing_rules is invalid") unique_pre_processing_rule_dicts = {} - for pre_processing_rule in args['process_rule']['rules']['pre_processing_rules']: - if 'id' not in pre_processing_rule or not pre_processing_rule['id']: + for pre_processing_rule in args["process_rule"]["rules"]["pre_processing_rules"]: + if "id" not in pre_processing_rule or not pre_processing_rule["id"]: raise ValueError("Process rule pre_processing_rules id is required") - if pre_processing_rule['id'] not in DatasetProcessRule.PRE_PROCESSING_RULES: + if pre_processing_rule["id"] not in DatasetProcessRule.PRE_PROCESSING_RULES: raise ValueError("Process rule pre_processing_rules id is invalid") - if 'enabled' not in pre_processing_rule or pre_processing_rule['enabled'] is None: + if "enabled" not in pre_processing_rule or pre_processing_rule["enabled"] is None: raise ValueError("Process rule pre_processing_rules enabled is required") - if not isinstance(pre_processing_rule['enabled'], bool): + if not isinstance(pre_processing_rule["enabled"], bool): raise ValueError("Process rule pre_processing_rules enabled is invalid") - unique_pre_processing_rule_dicts[pre_processing_rule['id']] = pre_processing_rule + unique_pre_processing_rule_dicts[pre_processing_rule["id"]] = pre_processing_rule - args['process_rule']['rules']['pre_processing_rules'] = list(unique_pre_processing_rule_dicts.values()) + args["process_rule"]["rules"]["pre_processing_rules"] = list(unique_pre_processing_rule_dicts.values()) - if 'segmentation' not in args['process_rule']['rules'] \ - or args['process_rule']['rules']['segmentation'] is None: + if ( + "segmentation" not in args["process_rule"]["rules"] + or args["process_rule"]["rules"]["segmentation"] is None + ): raise ValueError("Process rule segmentation is required") - if not isinstance(args['process_rule']['rules']['segmentation'], dict): + if not isinstance(args["process_rule"]["rules"]["segmentation"], dict): raise ValueError("Process rule segmentation is invalid") - if 'separator' not in args['process_rule']['rules']['segmentation'] \ - or not args['process_rule']['rules']['segmentation']['separator']: + if ( + "separator" not in args["process_rule"]["rules"]["segmentation"] + or not args["process_rule"]["rules"]["segmentation"]["separator"] + ): raise ValueError("Process rule segmentation separator is required") - if not isinstance(args['process_rule']['rules']['segmentation']['separator'], str): + if not isinstance(args["process_rule"]["rules"]["segmentation"]["separator"], str): raise ValueError("Process rule segmentation separator is invalid") - if 'max_tokens' not in args['process_rule']['rules']['segmentation'] \ - or not args['process_rule']['rules']['segmentation']['max_tokens']: + if ( + "max_tokens" not in args["process_rule"]["rules"]["segmentation"] + or not args["process_rule"]["rules"]["segmentation"]["max_tokens"] + ): raise ValueError("Process rule segmentation max_tokens is required") - if not isinstance(args['process_rule']['rules']['segmentation']['max_tokens'], int): + if not isinstance(args["process_rule"]["rules"]["segmentation"]["max_tokens"], int): raise ValueError("Process rule segmentation max_tokens is invalid") class SegmentService: @classmethod def segment_create_args_validate(cls, args: dict, document: Document): - if document.doc_form == 'qa_model': - if 'answer' not in args or not args['answer']: + if document.doc_form == "qa_model": + if "answer" not in args or not args["answer"]: raise ValueError("Answer is required") - if not args['answer'].strip(): + if not args["answer"].strip(): raise ValueError("Answer is empty") - if 'content' not in args or not args['content'] or not args['content'].strip(): + if "content" not in args or not args["content"] or not args["content"].strip(): raise ValueError("Content is empty") @classmethod def create_segment(cls, args: dict, document: Document, dataset: Dataset): - content = args['content'] + content = args["content"] doc_id = str(uuid.uuid4()) segment_hash = helper.generate_text_hash(content) tokens = 0 - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": model_manager = ModelManager() embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model + model=dataset.embedding_model, ) # calc embedding use tokens - tokens = embedding_model.get_text_embedding_num_tokens( - texts=[content] - ) - lock_name = 'add_segment_lock_document_id_{}'.format(document.id) + tokens = embedding_model.get_text_embedding_num_tokens(texts=[content]) + lock_name = "add_segment_lock_document_id_{}".format(document.id) with redis_client.lock(lock_name, timeout=600): - max_position = db.session.query(func.max(DocumentSegment.position)).filter( - DocumentSegment.document_id == document.id - ).scalar() + max_position = ( + db.session.query(func.max(DocumentSegment.position)) + .filter(DocumentSegment.document_id == document.id) + .scalar() + ) segment_document = DocumentSegment( tenant_id=current_user.current_tenant_id, dataset_id=document.dataset_id, @@ -1356,25 +1357,25 @@ class SegmentService: content=content, word_count=len(content), tokens=tokens, - status='completed', + status="completed", indexing_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), completed_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), - created_by=current_user.id + created_by=current_user.id, ) - if document.doc_form == 'qa_model': - segment_document.answer = args['answer'] + if document.doc_form == "qa_model": + segment_document.answer = args["answer"] db.session.add(segment_document) db.session.commit() # save vector index try: - VectorService.create_segments_vector([args['keywords']], [segment_document], dataset) + VectorService.create_segments_vector([args["keywords"]], [segment_document], dataset) except Exception as e: logging.exception("create segment index failed") segment_document.enabled = False segment_document.disabled_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) - segment_document.status = 'error' + segment_document.status = "error" segment_document.error = str(e) db.session.commit() segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_document.id).first() @@ -1382,33 +1383,33 @@ class SegmentService: @classmethod def multi_create_segment(cls, segments: list, document: Document, dataset: Dataset): - lock_name = 'multi_add_segment_lock_document_id_{}'.format(document.id) + lock_name = "multi_add_segment_lock_document_id_{}".format(document.id) with redis_client.lock(lock_name, timeout=600): embedding_model = None - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": model_manager = ModelManager() embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model + model=dataset.embedding_model, ) - max_position = db.session.query(func.max(DocumentSegment.position)).filter( - DocumentSegment.document_id == document.id - ).scalar() + max_position = ( + db.session.query(func.max(DocumentSegment.position)) + .filter(DocumentSegment.document_id == document.id) + .scalar() + ) pre_segment_data_list = [] segment_data_list = [] keywords_list = [] for segment_item in segments: - content = segment_item['content'] + content = segment_item["content"] doc_id = str(uuid.uuid4()) segment_hash = helper.generate_text_hash(content) tokens = 0 - if dataset.indexing_technique == 'high_quality' and embedding_model: + if dataset.indexing_technique == "high_quality" and embedding_model: # calc embedding use tokens - tokens = embedding_model.get_text_embedding_num_tokens( - texts=[content] - ) + tokens = embedding_model.get_text_embedding_num_tokens(texts=[content]) segment_document = DocumentSegment( tenant_id=current_user.current_tenant_id, dataset_id=document.dataset_id, @@ -1419,19 +1420,19 @@ class SegmentService: content=content, word_count=len(content), tokens=tokens, - status='completed', + status="completed", indexing_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), completed_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), - created_by=current_user.id + created_by=current_user.id, ) - if document.doc_form == 'qa_model': - segment_document.answer = segment_item['answer'] + if document.doc_form == "qa_model": + segment_document.answer = segment_item["answer"] db.session.add(segment_document) segment_data_list.append(segment_document) pre_segment_data_list.append(segment_document) - if 'keywords' in segment_item: - keywords_list.append(segment_item['keywords']) + if "keywords" in segment_item: + keywords_list.append(segment_item["keywords"]) else: keywords_list.append(None) @@ -1443,19 +1444,19 @@ class SegmentService: for segment_document in segment_data_list: segment_document.enabled = False segment_document.disabled_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) - segment_document.status = 'error' + segment_document.status = "error" segment_document.error = str(e) db.session.commit() return segment_data_list @classmethod def update_segment(cls, args: dict, segment: DocumentSegment, document: Document, dataset: Dataset): - indexing_cache_key = 'segment_{}_indexing'.format(segment.id) + indexing_cache_key = "segment_{}_indexing".format(segment.id) cache_result = redis_client.get(indexing_cache_key) if cache_result is not None: raise ValueError("Segment is indexing, please try again later") - if 'enabled' in args and args['enabled'] is not None: - action = args['enabled'] + if "enabled" in args and args["enabled"] is not None: + action = args["enabled"] if segment.enabled != action: if not action: segment.enabled = action @@ -1468,25 +1469,25 @@ class SegmentService: disable_segment_from_index_task.delay(segment.id) return segment if not segment.enabled: - if 'enabled' in args and args['enabled'] is not None: - if not args['enabled']: + if "enabled" in args and args["enabled"] is not None: + if not args["enabled"]: raise ValueError("Can't update disabled segment") else: raise ValueError("Can't update disabled segment") try: - content = args['content'] + content = args["content"] if segment.content == content: - if document.doc_form == 'qa_model': - segment.answer = args['answer'] - if args.get('keywords'): - segment.keywords = args['keywords'] + if document.doc_form == "qa_model": + segment.answer = args["answer"] + if args.get("keywords"): + segment.keywords = args["keywords"] segment.enabled = True segment.disabled_at = None segment.disabled_by = None db.session.add(segment) db.session.commit() # update segment index task - if 'keywords' in args: + if "keywords" in args: keyword = Keyword(dataset) keyword.delete_by_ids([segment.index_node_id]) document = RAGDocument( @@ -1496,30 +1497,28 @@ class SegmentService: "doc_hash": segment.index_node_hash, "document_id": segment.document_id, "dataset_id": segment.dataset_id, - } + }, ) - keyword.add_texts([document], keywords_list=[args['keywords']]) + keyword.add_texts([document], keywords_list=[args["keywords"]]) else: segment_hash = helper.generate_text_hash(content) tokens = 0 - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": model_manager = ModelManager() embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model + model=dataset.embedding_model, ) # calc embedding use tokens - tokens = embedding_model.get_text_embedding_num_tokens( - texts=[content] - ) + tokens = embedding_model.get_text_embedding_num_tokens(texts=[content]) segment.content = content segment.index_node_hash = segment_hash segment.word_count = len(content) segment.tokens = tokens - segment.status = 'completed' + segment.status = "completed" segment.indexing_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) segment.completed_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) segment.updated_by = current_user.id @@ -1527,18 +1526,18 @@ class SegmentService: segment.enabled = True segment.disabled_at = None segment.disabled_by = None - if document.doc_form == 'qa_model': - segment.answer = args['answer'] + if document.doc_form == "qa_model": + segment.answer = args["answer"] db.session.add(segment) db.session.commit() # update segment vector index - VectorService.update_segment_vector(args['keywords'], segment, dataset) + VectorService.update_segment_vector(args["keywords"], segment, dataset) except Exception as e: logging.exception("update segment index failed") segment.enabled = False segment.disabled_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) - segment.status = 'error' + segment.status = "error" segment.error = str(e) db.session.commit() segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment.id).first() @@ -1546,7 +1545,7 @@ class SegmentService: @classmethod def delete_segment(cls, segment: DocumentSegment, document: Document, dataset: Dataset): - indexing_cache_key = 'segment_{}_delete_indexing'.format(segment.id) + indexing_cache_key = "segment_{}_delete_indexing".format(segment.id) cache_result = redis_client.get(indexing_cache_key) if cache_result is not None: raise ValueError("Segment is deleting.") @@ -1563,24 +1562,25 @@ class SegmentService: class DatasetCollectionBindingService: @classmethod def get_dataset_collection_binding( - cls, provider_name: str, model_name: str, - collection_type: str = 'dataset' + cls, provider_name: str, model_name: str, collection_type: str = "dataset" ) -> DatasetCollectionBinding: - dataset_collection_binding = db.session.query(DatasetCollectionBinding). \ - filter( - DatasetCollectionBinding.provider_name == provider_name, - DatasetCollectionBinding.model_name == model_name, - DatasetCollectionBinding.type == collection_type - ). \ - order_by(DatasetCollectionBinding.created_at). \ - first() + dataset_collection_binding = ( + db.session.query(DatasetCollectionBinding) + .filter( + DatasetCollectionBinding.provider_name == provider_name, + DatasetCollectionBinding.model_name == model_name, + DatasetCollectionBinding.type == collection_type, + ) + .order_by(DatasetCollectionBinding.created_at) + .first() + ) if not dataset_collection_binding: dataset_collection_binding = DatasetCollectionBinding( provider_name=provider_name, model_name=model_name, collection_name=Dataset.gen_collection_name_by_id(str(uuid.uuid4())), - type=collection_type + type=collection_type, ) db.session.add(dataset_collection_binding) db.session.commit() @@ -1588,16 +1588,16 @@ class DatasetCollectionBindingService: @classmethod def get_dataset_collection_binding_by_id_and_type( - cls, collection_binding_id: str, - collection_type: str = 'dataset' + cls, collection_binding_id: str, collection_type: str = "dataset" ) -> DatasetCollectionBinding: - dataset_collection_binding = db.session.query(DatasetCollectionBinding). \ - filter( - DatasetCollectionBinding.id == collection_binding_id, - DatasetCollectionBinding.type == collection_type - ). \ - order_by(DatasetCollectionBinding.created_at). \ - first() + dataset_collection_binding = ( + db.session.query(DatasetCollectionBinding) + .filter( + DatasetCollectionBinding.id == collection_binding_id, DatasetCollectionBinding.type == collection_type + ) + .order_by(DatasetCollectionBinding.created_at) + .first() + ) return dataset_collection_binding @@ -1605,11 +1605,13 @@ class DatasetCollectionBindingService: class DatasetPermissionService: @classmethod def get_dataset_partial_member_list(cls, dataset_id): - user_list_query = db.session.query( - DatasetPermission.account_id, - ).filter( - DatasetPermission.dataset_id == dataset_id - ).all() + user_list_query = ( + db.session.query( + DatasetPermission.account_id, + ) + .filter(DatasetPermission.dataset_id == dataset_id) + .all() + ) user_list = [] for user in user_list_query: @@ -1626,7 +1628,7 @@ class DatasetPermissionService: permission = DatasetPermission( tenant_id=tenant_id, dataset_id=dataset_id, - account_id=user['user_id'], + account_id=user["user_id"], ) permissions.append(permission) @@ -1639,19 +1641,19 @@ class DatasetPermissionService: @classmethod def check_permission(cls, user, dataset, requested_permission, requested_partial_member_list): if not user.is_dataset_editor: - raise NoPermissionError('User does not have permission to edit this dataset.') + raise NoPermissionError("User does not have permission to edit this dataset.") if user.is_dataset_operator and dataset.permission != requested_permission: - raise NoPermissionError('Dataset operators cannot change the dataset permissions.') + raise NoPermissionError("Dataset operators cannot change the dataset permissions.") - if user.is_dataset_operator and requested_permission == 'partial_members': + if user.is_dataset_operator and requested_permission == "partial_members": if not requested_partial_member_list: - raise ValueError('Partial member list is required when setting to partial members.') + raise ValueError("Partial member list is required when setting to partial members.") local_member_list = cls.get_dataset_partial_member_list(dataset.id) - request_member_list = [user['user_id'] for user in requested_partial_member_list] + request_member_list = [user["user_id"] for user in requested_partial_member_list] if set(local_member_list) != set(request_member_list): - raise ValueError('Dataset operators cannot change the dataset permissions.') + raise ValueError("Dataset operators cannot change the dataset permissions.") @classmethod def clear_partial_member_list(cls, dataset_id): diff --git a/api/services/enterprise/base.py b/api/services/enterprise/base.py index c483d2815..ddee52164 100644 --- a/api/services/enterprise/base.py +++ b/api/services/enterprise/base.py @@ -4,15 +4,12 @@ import requests class EnterpriseRequest: - base_url = os.environ.get('ENTERPRISE_API_URL', 'ENTERPRISE_API_URL') - secret_key = os.environ.get('ENTERPRISE_API_SECRET_KEY', 'ENTERPRISE_API_SECRET_KEY') + base_url = os.environ.get("ENTERPRISE_API_URL", "ENTERPRISE_API_URL") + secret_key = os.environ.get("ENTERPRISE_API_SECRET_KEY", "ENTERPRISE_API_SECRET_KEY") @classmethod def send_request(cls, method, endpoint, json=None, params=None): - headers = { - "Content-Type": "application/json", - "Enterprise-Api-Secret-Key": cls.secret_key - } + headers = {"Content-Type": "application/json", "Enterprise-Api-Secret-Key": cls.secret_key} url = f"{cls.base_url}{endpoint}" response = requests.request(method, url, json=json, params=params, headers=headers) diff --git a/api/services/enterprise/enterprise_service.py b/api/services/enterprise/enterprise_service.py index 6fd72c232..abc01ddf8 100644 --- a/api/services/enterprise/enterprise_service.py +++ b/api/services/enterprise/enterprise_service.py @@ -2,11 +2,10 @@ from services.enterprise.base import EnterpriseRequest class EnterpriseService: - @classmethod def get_info(cls): - return EnterpriseRequest.send_request('GET', '/info') + return EnterpriseRequest.send_request("GET", "/info") @classmethod def get_app_web_sso_enabled(cls, app_code): - return EnterpriseRequest.send_request('GET', f'/app-sso-setting?appCode={app_code}') + return EnterpriseRequest.send_request("GET", f"/app-sso-setting?appCode={app_code}") diff --git a/api/services/entities/model_provider_entities.py b/api/services/entities/model_provider_entities.py index e5e4d7e23..c519f0b0e 100644 --- a/api/services/entities/model_provider_entities.py +++ b/api/services/entities/model_provider_entities.py @@ -22,14 +22,16 @@ class CustomConfigurationStatus(Enum): """ Enum class for custom configuration status. """ - ACTIVE = 'active' - NO_CONFIGURE = 'no-configure' + + ACTIVE = "active" + NO_CONFIGURE = "no-configure" class CustomConfigurationResponse(BaseModel): """ Model class for provider custom configuration response. """ + status: CustomConfigurationStatus @@ -37,6 +39,7 @@ class SystemConfigurationResponse(BaseModel): """ Model class for provider system configuration response. """ + enabled: bool current_quota_type: Optional[ProviderQuotaType] = None quota_configurations: list[QuotaConfiguration] = [] @@ -46,6 +49,7 @@ class ProviderResponse(BaseModel): """ Model class for provider response. """ + provider: str label: I18nObject description: Optional[I18nObject] = None @@ -67,18 +71,15 @@ class ProviderResponse(BaseModel): def __init__(self, **data) -> None: super().__init__(**data) - url_prefix = (dify_config.CONSOLE_API_URL - + f"/console/api/workspaces/current/model-providers/{self.provider}") + url_prefix = dify_config.CONSOLE_API_URL + f"/console/api/workspaces/current/model-providers/{self.provider}" if self.icon_small is not None: self.icon_small = I18nObject( - en_US=f"{url_prefix}/icon_small/en_US", - zh_Hans=f"{url_prefix}/icon_small/zh_Hans" + en_US=f"{url_prefix}/icon_small/en_US", zh_Hans=f"{url_prefix}/icon_small/zh_Hans" ) if self.icon_large is not None: self.icon_large = I18nObject( - en_US=f"{url_prefix}/icon_large/en_US", - zh_Hans=f"{url_prefix}/icon_large/zh_Hans" + en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans" ) @@ -86,6 +87,7 @@ class ProviderWithModelsResponse(BaseModel): """ Model class for provider with models response. """ + provider: str label: I18nObject icon_small: Optional[I18nObject] = None @@ -96,18 +98,15 @@ class ProviderWithModelsResponse(BaseModel): def __init__(self, **data) -> None: super().__init__(**data) - url_prefix = (dify_config.CONSOLE_API_URL - + f"/console/api/workspaces/current/model-providers/{self.provider}") + url_prefix = dify_config.CONSOLE_API_URL + f"/console/api/workspaces/current/model-providers/{self.provider}" if self.icon_small is not None: self.icon_small = I18nObject( - en_US=f"{url_prefix}/icon_small/en_US", - zh_Hans=f"{url_prefix}/icon_small/zh_Hans" + en_US=f"{url_prefix}/icon_small/en_US", zh_Hans=f"{url_prefix}/icon_small/zh_Hans" ) if self.icon_large is not None: self.icon_large = I18nObject( - en_US=f"{url_prefix}/icon_large/en_US", - zh_Hans=f"{url_prefix}/icon_large/zh_Hans" + en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans" ) @@ -119,18 +118,15 @@ class SimpleProviderEntityResponse(SimpleProviderEntity): def __init__(self, **data) -> None: super().__init__(**data) - url_prefix = (dify_config.CONSOLE_API_URL - + f"/console/api/workspaces/current/model-providers/{self.provider}") + url_prefix = dify_config.CONSOLE_API_URL + f"/console/api/workspaces/current/model-providers/{self.provider}" if self.icon_small is not None: self.icon_small = I18nObject( - en_US=f"{url_prefix}/icon_small/en_US", - zh_Hans=f"{url_prefix}/icon_small/zh_Hans" + en_US=f"{url_prefix}/icon_small/en_US", zh_Hans=f"{url_prefix}/icon_small/zh_Hans" ) if self.icon_large is not None: self.icon_large = I18nObject( - en_US=f"{url_prefix}/icon_large/en_US", - zh_Hans=f"{url_prefix}/icon_large/zh_Hans" + en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans" ) @@ -138,6 +134,7 @@ class DefaultModelResponse(BaseModel): """ Default model entity. """ + model: str model_type: ModelType provider: SimpleProviderEntityResponse @@ -150,6 +147,7 @@ class ModelWithProviderEntityResponse(ModelWithProviderEntity): """ Model with provider entity. """ + provider: SimpleProviderEntityResponse def __init__(self, model: ModelWithProviderEntity) -> None: diff --git a/api/services/errors/account.py b/api/services/errors/account.py index ddc2dbdea..cae31c506 100644 --- a/api/services/errors/account.py +++ b/api/services/errors/account.py @@ -55,4 +55,3 @@ class RoleAlreadyAssignedError(BaseServiceError): class RateLimitExceededError(BaseServiceError): pass - diff --git a/api/services/errors/base.py b/api/services/errors/base.py index f5d41e17f..1fed71cf9 100644 --- a/api/services/errors/base.py +++ b/api/services/errors/base.py @@ -1,3 +1,3 @@ class BaseServiceError(Exception): def __init__(self, description: str = None): - self.description = description \ No newline at end of file + self.description = description diff --git a/api/services/feature_service.py b/api/services/feature_service.py index 57400e595..4d5812c6c 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -6,8 +6,8 @@ from services.enterprise.enterprise_service import EnterpriseService class SubscriptionModel(BaseModel): - plan: str = 'sandbox' - interval: str = '' + plan: str = "sandbox" + interval: str = "" class BillingModel(BaseModel): @@ -27,7 +27,7 @@ class FeatureModel(BaseModel): vector_space: LimitationModel = LimitationModel(size=0, limit=5) annotation_quota_limit: LimitationModel = LimitationModel(size=0, limit=10) documents_upload_quota: LimitationModel = LimitationModel(size=0, limit=50) - docs_processing: str = 'standard' + docs_processing: str = "standard" can_replace_logo: bool = False model_load_balancing_enabled: bool = False dataset_operator_enabled: bool = False @@ -38,13 +38,13 @@ class FeatureModel(BaseModel): class SystemFeatureModel(BaseModel): sso_enforced_for_signin: bool = False - sso_enforced_for_signin_protocol: str = '' + sso_enforced_for_signin_protocol: str = "" sso_enforced_for_web: bool = False - sso_enforced_for_web_protocol: str = '' + sso_enforced_for_web_protocol: str = "" enable_web_sso_switch_component: bool = False -class FeatureService: +class FeatureService: @classmethod def get_features(cls, tenant_id: str) -> FeatureModel: features = FeatureModel() @@ -76,44 +76,44 @@ class FeatureService: def _fulfill_params_from_billing_api(cls, features: FeatureModel, tenant_id: str): billing_info = BillingService.get_info(tenant_id) - features.billing.enabled = billing_info['enabled'] - features.billing.subscription.plan = billing_info['subscription']['plan'] - features.billing.subscription.interval = billing_info['subscription']['interval'] + features.billing.enabled = billing_info["enabled"] + features.billing.subscription.plan = billing_info["subscription"]["plan"] + features.billing.subscription.interval = billing_info["subscription"]["interval"] - if 'members' in billing_info: - features.members.size = billing_info['members']['size'] - features.members.limit = billing_info['members']['limit'] + if "members" in billing_info: + features.members.size = billing_info["members"]["size"] + features.members.limit = billing_info["members"]["limit"] - if 'apps' in billing_info: - features.apps.size = billing_info['apps']['size'] - features.apps.limit = billing_info['apps']['limit'] + if "apps" in billing_info: + features.apps.size = billing_info["apps"]["size"] + features.apps.limit = billing_info["apps"]["limit"] - if 'vector_space' in billing_info: - features.vector_space.size = billing_info['vector_space']['size'] - features.vector_space.limit = billing_info['vector_space']['limit'] + if "vector_space" in billing_info: + features.vector_space.size = billing_info["vector_space"]["size"] + features.vector_space.limit = billing_info["vector_space"]["limit"] - if 'documents_upload_quota' in billing_info: - features.documents_upload_quota.size = billing_info['documents_upload_quota']['size'] - features.documents_upload_quota.limit = billing_info['documents_upload_quota']['limit'] + if "documents_upload_quota" in billing_info: + features.documents_upload_quota.size = billing_info["documents_upload_quota"]["size"] + features.documents_upload_quota.limit = billing_info["documents_upload_quota"]["limit"] - if 'annotation_quota_limit' in billing_info: - features.annotation_quota_limit.size = billing_info['annotation_quota_limit']['size'] - features.annotation_quota_limit.limit = billing_info['annotation_quota_limit']['limit'] + if "annotation_quota_limit" in billing_info: + features.annotation_quota_limit.size = billing_info["annotation_quota_limit"]["size"] + features.annotation_quota_limit.limit = billing_info["annotation_quota_limit"]["limit"] - if 'docs_processing' in billing_info: - features.docs_processing = billing_info['docs_processing'] + if "docs_processing" in billing_info: + features.docs_processing = billing_info["docs_processing"] - if 'can_replace_logo' in billing_info: - features.can_replace_logo = billing_info['can_replace_logo'] + if "can_replace_logo" in billing_info: + features.can_replace_logo = billing_info["can_replace_logo"] - if 'model_load_balancing_enabled' in billing_info: - features.model_load_balancing_enabled = billing_info['model_load_balancing_enabled'] + if "model_load_balancing_enabled" in billing_info: + features.model_load_balancing_enabled = billing_info["model_load_balancing_enabled"] @classmethod def _fulfill_params_from_enterprise(cls, features): enterprise_info = EnterpriseService.get_info() - features.sso_enforced_for_signin = enterprise_info['sso_enforced_for_signin'] - features.sso_enforced_for_signin_protocol = enterprise_info['sso_enforced_for_signin_protocol'] - features.sso_enforced_for_web = enterprise_info['sso_enforced_for_web'] - features.sso_enforced_for_web_protocol = enterprise_info['sso_enforced_for_web_protocol'] + features.sso_enforced_for_signin = enterprise_info["sso_enforced_for_signin"] + features.sso_enforced_for_signin_protocol = enterprise_info["sso_enforced_for_signin_protocol"] + features.sso_enforced_for_web = enterprise_info["sso_enforced_for_web"] + features.sso_enforced_for_web_protocol = enterprise_info["sso_enforced_for_web_protocol"] diff --git a/api/services/file_service.py b/api/services/file_service.py index 913996224..5780abb2b 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -17,27 +17,45 @@ from models.account import Account from models.model import EndUser, UploadFile from services.errors.file import FileTooLargeError, UnsupportedFileTypeError -IMAGE_EXTENSIONS = ['jpg', 'jpeg', 'png', 'webp', 'gif', 'svg'] +IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"] IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS]) -ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'xls', 'docx', 'csv'] -UNSTRUCTURED_ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'xls', - 'docx', 'csv', 'eml', 'msg', 'pptx', 'ppt', 'xml', 'epub'] +ALLOWED_EXTENSIONS = ["txt", "markdown", "md", "pdf", "html", "htm", "xlsx", "xls", "docx", "csv"] +UNSTRUCTURED_ALLOWED_EXTENSIONS = [ + "txt", + "markdown", + "md", + "pdf", + "html", + "htm", + "xlsx", + "xls", + "docx", + "csv", + "eml", + "msg", + "pptx", + "ppt", + "xml", + "epub", +] PREVIEW_WORDS_LIMIT = 3000 class FileService: - @staticmethod def upload_file(file: FileStorage, user: Union[Account, EndUser], only_image: bool = False) -> UploadFile: filename = file.filename - extension = file.filename.split('.')[-1] + extension = file.filename.split(".")[-1] if len(filename) > 200: - filename = filename.split('.')[0][:200] + '.' + extension + filename = filename.split(".")[0][:200] + "." + extension etl_type = dify_config.ETL_TYPE - allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS + IMAGE_EXTENSIONS if etl_type == 'Unstructured' \ + allowed_extensions = ( + UNSTRUCTURED_ALLOWED_EXTENSIONS + IMAGE_EXTENSIONS + if etl_type == "Unstructured" else ALLOWED_EXTENSIONS + IMAGE_EXTENSIONS + ) if extension.lower() not in allowed_extensions: raise UnsupportedFileTypeError() elif only_image and extension.lower() not in IMAGE_EXTENSIONS: @@ -55,7 +73,7 @@ class FileService: file_size_limit = dify_config.UPLOAD_FILE_SIZE_LIMIT * 1024 * 1024 if file_size > file_size_limit: - message = f'File size exceeded. {file_size} > {file_size_limit}' + message = f"File size exceeded. {file_size} > {file_size_limit}" raise FileTooLargeError(message) # user uuid as file name @@ -67,7 +85,7 @@ class FileService: # end_user current_tenant_id = user.tenant_id - file_key = 'upload_files/' + current_tenant_id + '/' + file_uuid + '.' + extension + file_key = "upload_files/" + current_tenant_id + "/" + file_uuid + "." + extension # save file to storage storage.save(file_key, file_content) @@ -81,11 +99,11 @@ class FileService: size=file_size, extension=extension, mime_type=file.mimetype, - created_by_role=('account' if isinstance(user, Account) else 'end_user'), + created_by_role=("account" if isinstance(user, Account) else "end_user"), created_by=user.id, created_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), used=False, - hash=hashlib.sha3_256(file_content).hexdigest() + hash=hashlib.sha3_256(file_content).hexdigest(), ) db.session.add(upload_file) @@ -99,10 +117,10 @@ class FileService: text_name = text_name[:200] # user uuid as file name file_uuid = str(uuid.uuid4()) - file_key = 'upload_files/' + current_user.current_tenant_id + '/' + file_uuid + '.txt' + file_key = "upload_files/" + current_user.current_tenant_id + "/" + file_uuid + ".txt" # save file to storage - storage.save(file_key, text.encode('utf-8')) + storage.save(file_key, text.encode("utf-8")) # save file to db upload_file = UploadFile( @@ -111,13 +129,13 @@ class FileService: key=file_key, name=text_name, size=len(text), - extension='txt', - mime_type='text/plain', + extension="txt", + mime_type="text/plain", created_by=current_user.id, created_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), used=True, used_by=current_user.id, - used_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + used_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), ) db.session.add(upload_file) @@ -127,9 +145,7 @@ class FileService: @staticmethod def get_file_preview(file_id: str) -> str: - upload_file = db.session.query(UploadFile) \ - .filter(UploadFile.id == file_id) \ - .first() + upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first() if not upload_file: raise NotFound("File not found") @@ -137,12 +153,12 @@ class FileService: # extract text from file extension = upload_file.extension etl_type = dify_config.ETL_TYPE - allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS if etl_type == 'Unstructured' else ALLOWED_EXTENSIONS + allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS if etl_type == "Unstructured" else ALLOWED_EXTENSIONS if extension.lower() not in allowed_extensions: raise UnsupportedFileTypeError() text = ExtractProcessor.load_from_upload_file(upload_file, return_text=True) - text = text[0:PREVIEW_WORDS_LIMIT] if text else '' + text = text[0:PREVIEW_WORDS_LIMIT] if text else "" return text @@ -152,9 +168,7 @@ class FileService: if not result: raise NotFound("File not found or signature is invalid") - upload_file = db.session.query(UploadFile) \ - .filter(UploadFile.id == file_id) \ - .first() + upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first() if not upload_file: raise NotFound("File not found or signature is invalid") @@ -170,9 +184,7 @@ class FileService: @staticmethod def get_public_image_preview(file_id: str) -> tuple[Generator, str]: - upload_file = db.session.query(UploadFile) \ - .filter(UploadFile.id == file_id) \ - .first() + upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first() if not upload_file: raise NotFound("File not found or signature is invalid") diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index de5f6994b..db9906481 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -9,14 +9,11 @@ from models.account import Account from models.dataset import Dataset, DatasetQuery, DocumentSegment default_retrieval_model = { - 'search_method': RetrievalMethod.SEMANTIC_SEARCH.value, - 'reranking_enable': False, - 'reranking_model': { - 'reranking_provider_name': '', - 'reranking_model_name': '' - }, - 'top_k': 2, - 'score_threshold_enabled': False + "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, + "reranking_enable": False, + "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, + "top_k": 2, + "score_threshold_enabled": False, } @@ -27,9 +24,9 @@ class HitTestingService: return { "query": { "content": query, - "tsne_position": {'x': 0, 'y': 0}, + "tsne_position": {"x": 0, "y": 0}, }, - "records": [] + "records": [], } start = time.perf_counter() @@ -38,28 +35,28 @@ class HitTestingService: if not retrieval_model: retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model - all_documents = RetrievalService.retrieve(retrival_method=retrieval_model.get('search_method', 'semantic_search'), - dataset_id=dataset.id, - query=cls.escape_query_for_search(query), - top_k=retrieval_model.get('top_k', 2), - score_threshold=retrieval_model.get('score_threshold', .0) - if retrieval_model['score_threshold_enabled'] else None, - reranking_model=retrieval_model.get('reranking_model', None) - if retrieval_model['reranking_enable'] else None, - reranking_mode=retrieval_model.get('reranking_mode') - if retrieval_model.get('reranking_mode') else 'reranking_model', - weights=retrieval_model.get('weights', None), - ) + all_documents = RetrievalService.retrieve( + retrival_method=retrieval_model.get("search_method", "semantic_search"), + dataset_id=dataset.id, + query=cls.escape_query_for_search(query), + top_k=retrieval_model.get("top_k", 2), + score_threshold=retrieval_model.get("score_threshold", 0.0) + if retrieval_model["score_threshold_enabled"] + else None, + reranking_model=retrieval_model.get("reranking_model", None) + if retrieval_model["reranking_enable"] + else None, + reranking_mode=retrieval_model.get("reranking_mode") + if retrieval_model.get("reranking_mode") + else "reranking_model", + weights=retrieval_model.get("weights", None), + ) end = time.perf_counter() logging.debug(f"Hit testing retrieve in {end - start:0.4f} seconds") dataset_query = DatasetQuery( - dataset_id=dataset.id, - content=query, - source='hit_testing', - created_by_role='account', - created_by=account.id + dataset_id=dataset.id, content=query, source="hit_testing", created_by_role="account", created_by=account.id ) db.session.add(dataset_query) @@ -72,14 +69,18 @@ class HitTestingService: i = 0 records = [] for document in documents: - index_node_id = document.metadata['doc_id'] + index_node_id = document.metadata["doc_id"] - segment = db.session.query(DocumentSegment).filter( - DocumentSegment.dataset_id == dataset.id, - DocumentSegment.enabled == True, - DocumentSegment.status == 'completed', - DocumentSegment.index_node_id == index_node_id - ).first() + segment = ( + db.session.query(DocumentSegment) + .filter( + DocumentSegment.dataset_id == dataset.id, + DocumentSegment.enabled == True, + DocumentSegment.status == "completed", + DocumentSegment.index_node_id == index_node_id, + ) + .first() + ) if not segment: i += 1 @@ -87,7 +88,7 @@ class HitTestingService: record = { "segment": segment, - "score": document.metadata.get('score', None), + "score": document.metadata.get("score", None), } records.append(record) @@ -98,15 +99,15 @@ class HitTestingService: "query": { "content": query, }, - "records": records + "records": records, } @classmethod def hit_testing_args_check(cls, args): - query = args['query'] + query = args["query"] if not query or len(query) > 250: - raise ValueError('Query is required and cannot exceed 250 characters') + raise ValueError("Query is required and cannot exceed 250 characters") @staticmethod def escape_query_for_search(query: str) -> str: diff --git a/api/services/message_service.py b/api/services/message_service.py index 491a914c7..ecb121c36 100644 --- a/api/services/message_service.py +++ b/api/services/message_service.py @@ -27,8 +27,14 @@ from services.workflow_service import WorkflowService class MessageService: @classmethod - def pagination_by_first_id(cls, app_model: App, user: Optional[Union[Account, EndUser]], - conversation_id: str, first_id: Optional[str], limit: int) -> InfiniteScrollPagination: + def pagination_by_first_id( + cls, + app_model: App, + user: Optional[Union[Account, EndUser]], + conversation_id: str, + first_id: Optional[str], + limit: int, + ) -> InfiniteScrollPagination: if not user: return InfiniteScrollPagination(data=[], limit=limit, has_more=False) @@ -36,52 +42,69 @@ class MessageService: return InfiniteScrollPagination(data=[], limit=limit, has_more=False) conversation = ConversationService.get_conversation( - app_model=app_model, - user=user, - conversation_id=conversation_id + app_model=app_model, user=user, conversation_id=conversation_id ) if first_id: - first_message = db.session.query(Message) \ - .filter(Message.conversation_id == conversation.id, Message.id == first_id).first() + first_message = ( + db.session.query(Message) + .filter(Message.conversation_id == conversation.id, Message.id == first_id) + .first() + ) if not first_message: raise FirstMessageNotExistsError() - history_messages = db.session.query(Message).filter( - Message.conversation_id == conversation.id, - Message.created_at < first_message.created_at, - Message.id != first_message.id - ) \ - .order_by(Message.created_at.desc()).limit(limit).all() + history_messages = ( + db.session.query(Message) + .filter( + Message.conversation_id == conversation.id, + Message.created_at < first_message.created_at, + Message.id != first_message.id, + ) + .order_by(Message.created_at.desc()) + .limit(limit) + .all() + ) else: - history_messages = db.session.query(Message).filter(Message.conversation_id == conversation.id) \ - .order_by(Message.created_at.desc()).limit(limit).all() + history_messages = ( + db.session.query(Message) + .filter(Message.conversation_id == conversation.id) + .order_by(Message.created_at.desc()) + .limit(limit) + .all() + ) has_more = False if len(history_messages) == limit: current_page_first_message = history_messages[-1] - rest_count = db.session.query(Message).filter( - Message.conversation_id == conversation.id, - Message.created_at < current_page_first_message.created_at, - Message.id != current_page_first_message.id - ).count() + rest_count = ( + db.session.query(Message) + .filter( + Message.conversation_id == conversation.id, + Message.created_at < current_page_first_message.created_at, + Message.id != current_page_first_message.id, + ) + .count() + ) if rest_count > 0: has_more = True history_messages = list(reversed(history_messages)) - return InfiniteScrollPagination( - data=history_messages, - limit=limit, - has_more=has_more - ) + return InfiniteScrollPagination(data=history_messages, limit=limit, has_more=has_more) @classmethod - def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account, EndUser]], - last_id: Optional[str], limit: int, conversation_id: Optional[str] = None, - include_ids: Optional[list] = None) -> InfiniteScrollPagination: + def pagination_by_last_id( + cls, + app_model: App, + user: Optional[Union[Account, EndUser]], + last_id: Optional[str], + limit: int, + conversation_id: Optional[str] = None, + include_ids: Optional[list] = None, + ) -> InfiniteScrollPagination: if not user: return InfiniteScrollPagination(data=[], limit=limit, has_more=False) @@ -89,9 +112,7 @@ class MessageService: if conversation_id is not None: conversation = ConversationService.get_conversation( - app_model=app_model, - user=user, - conversation_id=conversation_id + app_model=app_model, user=user, conversation_id=conversation_id ) base_query = base_query.filter(Message.conversation_id == conversation.id) @@ -105,10 +126,12 @@ class MessageService: if not last_message: raise LastMessageNotExistsError() - history_messages = base_query.filter( - Message.created_at < last_message.created_at, - Message.id != last_message.id - ).order_by(Message.created_at.desc()).limit(limit).all() + history_messages = ( + base_query.filter(Message.created_at < last_message.created_at, Message.id != last_message.id) + .order_by(Message.created_at.desc()) + .limit(limit) + .all() + ) else: history_messages = base_query.order_by(Message.created_at.desc()).limit(limit).all() @@ -116,30 +139,22 @@ class MessageService: if len(history_messages) == limit: current_page_first_message = history_messages[-1] rest_count = base_query.filter( - Message.created_at < current_page_first_message.created_at, - Message.id != current_page_first_message.id + Message.created_at < current_page_first_message.created_at, Message.id != current_page_first_message.id ).count() if rest_count > 0: has_more = True - return InfiniteScrollPagination( - data=history_messages, - limit=limit, - has_more=has_more - ) + return InfiniteScrollPagination(data=history_messages, limit=limit, has_more=has_more) @classmethod - def create_feedback(cls, app_model: App, message_id: str, user: Optional[Union[Account, EndUser]], - rating: Optional[str]) -> MessageFeedback: + def create_feedback( + cls, app_model: App, message_id: str, user: Optional[Union[Account, EndUser]], rating: Optional[str] + ) -> MessageFeedback: if not user: - raise ValueError('user cannot be None') + raise ValueError("user cannot be None") - message = cls.get_message( - app_model=app_model, - user=user, - message_id=message_id - ) + message = cls.get_message(app_model=app_model, user=user, message_id=message_id) feedback = message.user_feedback if isinstance(user, EndUser) else message.admin_feedback @@ -148,14 +163,14 @@ class MessageService: elif rating and feedback: feedback.rating = rating elif not rating and not feedback: - raise ValueError('rating cannot be None when feedback not exists') + raise ValueError("rating cannot be None when feedback not exists") else: feedback = MessageFeedback( app_id=app_model.id, conversation_id=message.conversation_id, message_id=message.id, rating=rating, - from_source=('user' if isinstance(user, EndUser) else 'admin'), + from_source=("user" if isinstance(user, EndUser) else "admin"), from_end_user_id=(user.id if isinstance(user, EndUser) else None), from_account_id=(user.id if isinstance(user, Account) else None), ) @@ -167,13 +182,17 @@ class MessageService: @classmethod def get_message(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str): - message = db.session.query(Message).filter( - Message.id == message_id, - Message.app_id == app_model.id, - Message.from_source == ('api' if isinstance(user, EndUser) else 'console'), - Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None), - Message.from_account_id == (user.id if isinstance(user, Account) else None), - ).first() + message = ( + db.session.query(Message) + .filter( + Message.id == message_id, + Message.app_id == app_model.id, + Message.from_source == ("api" if isinstance(user, EndUser) else "console"), + Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None), + Message.from_account_id == (user.id if isinstance(user, Account) else None), + ) + .first() + ) if not message: raise MessageNotExistsError() @@ -181,27 +200,22 @@ class MessageService: return message @classmethod - def get_suggested_questions_after_answer(cls, app_model: App, user: Optional[Union[Account, EndUser]], - message_id: str, invoke_from: InvokeFrom) -> list[Message]: + def get_suggested_questions_after_answer( + cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str, invoke_from: InvokeFrom + ) -> list[Message]: if not user: - raise ValueError('user cannot be None') + raise ValueError("user cannot be None") - message = cls.get_message( - app_model=app_model, - user=user, - message_id=message_id - ) + message = cls.get_message(app_model=app_model, user=user, message_id=message_id) conversation = ConversationService.get_conversation( - app_model=app_model, - conversation_id=message.conversation_id, - user=user + app_model=app_model, conversation_id=message.conversation_id, user=user ) if not conversation: raise ConversationNotExistsError() - if conversation.status != 'normal': + if conversation.status != "normal": raise ConversationCompletedError() model_manager = ModelManager() @@ -216,24 +230,23 @@ class MessageService: if workflow is None: return [] - app_config = AdvancedChatAppConfigManager.get_app_config( - app_model=app_model, - workflow=workflow - ) + app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow) if not app_config.additional_features.suggested_questions_after_answer: raise SuggestedQuestionsAfterAnswerDisabledError() model_instance = model_manager.get_default_model_instance( - tenant_id=app_model.tenant_id, - model_type=ModelType.LLM + tenant_id=app_model.tenant_id, model_type=ModelType.LLM ) else: if not conversation.override_model_configs: - app_model_config = db.session.query(AppModelConfig).filter( - AppModelConfig.id == conversation.app_model_config_id, - AppModelConfig.app_id == app_model.id - ).first() + app_model_config = ( + db.session.query(AppModelConfig) + .filter( + AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id + ) + .first() + ) else: conversation_override_model_configs = json.loads(conversation.override_model_configs) app_model_config = AppModelConfig( @@ -249,16 +262,13 @@ class MessageService: model_instance = model_manager.get_model_instance( tenant_id=app_model.tenant_id, - provider=app_model_config.model_dict['provider'], + provider=app_model_config.model_dict["provider"], model_type=ModelType.LLM, - model=app_model_config.model_dict['name'] + model=app_model_config.model_dict["name"], ) # get memory of conversation (read-only) - memory = TokenBufferMemory( - conversation=conversation, - model_instance=model_instance - ) + memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance) histories = memory.get_history_prompt_text( max_token_limit=3000, @@ -267,18 +277,14 @@ class MessageService: with measure_time() as timer: questions = LLMGenerator.generate_suggested_questions_after_answer( - tenant_id=app_model.tenant_id, - histories=histories + tenant_id=app_model.tenant_id, histories=histories ) # get tracing instance trace_manager = TraceQueueManager(app_id=app_model.id) trace_manager.add_trace_task( TraceTask( - TraceTaskName.SUGGESTED_QUESTION_TRACE, - message_id=message_id, - suggested_question=questions, - timer=timer + TraceTaskName.SUGGESTED_QUESTION_TRACE, message_id=message_id, suggested_question=questions, timer=timer ) ) diff --git a/api/services/model_load_balancing_service.py b/api/services/model_load_balancing_service.py index 80eb72140..e7b9422cf 100644 --- a/api/services/model_load_balancing_service.py +++ b/api/services/model_load_balancing_service.py @@ -23,7 +23,6 @@ logger = logging.getLogger(__name__) class ModelLoadBalancingService: - def __init__(self) -> None: self.provider_manager = ProviderManager() @@ -46,10 +45,7 @@ class ModelLoadBalancingService: raise ValueError(f"Provider {provider} does not exist.") # Enable model load balancing - provider_configuration.enable_model_load_balancing( - model=model, - model_type=ModelType.value_of(model_type) - ) + provider_configuration.enable_model_load_balancing(model=model, model_type=ModelType.value_of(model_type)) def disable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str) -> None: """ @@ -70,13 +66,11 @@ class ModelLoadBalancingService: raise ValueError(f"Provider {provider} does not exist.") # disable model load balancing - provider_configuration.disable_model_load_balancing( - model=model, - model_type=ModelType.value_of(model_type) - ) + provider_configuration.disable_model_load_balancing(model=model, model_type=ModelType.value_of(model_type)) - def get_load_balancing_configs(self, tenant_id: str, provider: str, model: str, model_type: str) \ - -> tuple[bool, list[dict]]: + def get_load_balancing_configs( + self, tenant_id: str, provider: str, model: str, model_type: str + ) -> tuple[bool, list[dict]]: """ Get load balancing configurations. :param tenant_id: workspace id @@ -107,20 +101,24 @@ class ModelLoadBalancingService: is_load_balancing_enabled = True # Get load balancing configurations - load_balancing_configs = db.session.query(LoadBalancingModelConfig) \ + load_balancing_configs = ( + db.session.query(LoadBalancingModelConfig) .filter( - LoadBalancingModelConfig.tenant_id == tenant_id, - LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, - LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), - LoadBalancingModelConfig.model_name == model - ).order_by(LoadBalancingModelConfig.created_at).all() + LoadBalancingModelConfig.tenant_id == tenant_id, + LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, + LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), + LoadBalancingModelConfig.model_name == model, + ) + .order_by(LoadBalancingModelConfig.created_at) + .all() + ) if provider_configuration.custom_configuration.provider: # check if the inherit configuration exists, # inherit is represented for the provider or model custom credentials inherit_config_exists = False for load_balancing_config in load_balancing_configs: - if load_balancing_config.name == '__inherit__': + if load_balancing_config.name == "__inherit__": inherit_config_exists = True break @@ -133,7 +131,7 @@ class ModelLoadBalancingService: else: # move the inherit configuration to the first for i, load_balancing_config in enumerate(load_balancing_configs[:]): - if load_balancing_config.name == '__inherit__': + if load_balancing_config.name == "__inherit__": inherit_config = load_balancing_configs.pop(i) load_balancing_configs.insert(0, inherit_config) @@ -151,7 +149,7 @@ class ModelLoadBalancingService: provider=provider, model=model, model_type=model_type, - config_id=load_balancing_config.id + config_id=load_balancing_config.id, ) try: @@ -172,32 +170,32 @@ class ModelLoadBalancingService: if variable in credentials: try: credentials[variable] = encrypter.decrypt_token_with_decoding( - credentials.get(variable), - decoding_rsa_key, - decoding_cipher_rsa + credentials.get(variable), decoding_rsa_key, decoding_cipher_rsa ) except ValueError: pass # Obfuscate credentials credentials = provider_configuration.obfuscated_credentials( - credentials=credentials, - credential_form_schemas=credential_schemas.credential_form_schemas + credentials=credentials, credential_form_schemas=credential_schemas.credential_form_schemas ) - datas.append({ - 'id': load_balancing_config.id, - 'name': load_balancing_config.name, - 'credentials': credentials, - 'enabled': load_balancing_config.enabled, - 'in_cooldown': in_cooldown, - 'ttl': ttl - }) + datas.append( + { + "id": load_balancing_config.id, + "name": load_balancing_config.name, + "credentials": credentials, + "enabled": load_balancing_config.enabled, + "in_cooldown": in_cooldown, + "ttl": ttl, + } + ) return is_load_balancing_enabled, datas - def get_load_balancing_config(self, tenant_id: str, provider: str, model: str, model_type: str, config_id: str) \ - -> Optional[dict]: + def get_load_balancing_config( + self, tenant_id: str, provider: str, model: str, model_type: str, config_id: str + ) -> Optional[dict]: """ Get load balancing configuration. :param tenant_id: workspace id @@ -219,14 +217,17 @@ class ModelLoadBalancingService: model_type = ModelType.value_of(model_type) # Get load balancing configurations - load_balancing_model_config = db.session.query(LoadBalancingModelConfig) \ + load_balancing_model_config = ( + db.session.query(LoadBalancingModelConfig) .filter( - LoadBalancingModelConfig.tenant_id == tenant_id, - LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, - LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), - LoadBalancingModelConfig.model_name == model, - LoadBalancingModelConfig.id == config_id - ).first() + LoadBalancingModelConfig.tenant_id == tenant_id, + LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, + LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), + LoadBalancingModelConfig.model_name == model, + LoadBalancingModelConfig.id == config_id, + ) + .first() + ) if not load_balancing_model_config: return None @@ -244,19 +245,19 @@ class ModelLoadBalancingService: # Obfuscate credentials credentials = provider_configuration.obfuscated_credentials( - credentials=credentials, - credential_form_schemas=credential_schemas.credential_form_schemas + credentials=credentials, credential_form_schemas=credential_schemas.credential_form_schemas ) return { - 'id': load_balancing_model_config.id, - 'name': load_balancing_model_config.name, - 'credentials': credentials, - 'enabled': load_balancing_model_config.enabled + "id": load_balancing_model_config.id, + "name": load_balancing_model_config.name, + "credentials": credentials, + "enabled": load_balancing_model_config.enabled, } - def _init_inherit_config(self, tenant_id: str, provider: str, model: str, model_type: ModelType) \ - -> LoadBalancingModelConfig: + def _init_inherit_config( + self, tenant_id: str, provider: str, model: str, model_type: ModelType + ) -> LoadBalancingModelConfig: """ Initialize the inherit configuration. :param tenant_id: workspace id @@ -271,18 +272,16 @@ class ModelLoadBalancingService: provider_name=provider, model_type=model_type.to_origin_model_type(), model_name=model, - name='__inherit__' + name="__inherit__", ) db.session.add(inherit_config) db.session.commit() return inherit_config - def update_load_balancing_configs(self, tenant_id: str, - provider: str, - model: str, - model_type: str, - configs: list[dict]) -> None: + def update_load_balancing_configs( + self, tenant_id: str, provider: str, model: str, model_type: str, configs: list[dict] + ) -> None: """ Update load balancing configurations. :param tenant_id: workspace id @@ -304,15 +303,18 @@ class ModelLoadBalancingService: model_type = ModelType.value_of(model_type) if not isinstance(configs, list): - raise ValueError('Invalid load balancing configs') + raise ValueError("Invalid load balancing configs") - current_load_balancing_configs = db.session.query(LoadBalancingModelConfig) \ + current_load_balancing_configs = ( + db.session.query(LoadBalancingModelConfig) .filter( - LoadBalancingModelConfig.tenant_id == tenant_id, - LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, - LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), - LoadBalancingModelConfig.model_name == model - ).all() + LoadBalancingModelConfig.tenant_id == tenant_id, + LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider, + LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), + LoadBalancingModelConfig.model_name == model, + ) + .all() + ) # id as key, config as value current_load_balancing_configs_dict = {config.id: config for config in current_load_balancing_configs} @@ -320,25 +322,25 @@ class ModelLoadBalancingService: for config in configs: if not isinstance(config, dict): - raise ValueError('Invalid load balancing config') + raise ValueError("Invalid load balancing config") - config_id = config.get('id') - name = config.get('name') - credentials = config.get('credentials') - enabled = config.get('enabled') + config_id = config.get("id") + name = config.get("name") + credentials = config.get("credentials") + enabled = config.get("enabled") if not name: - raise ValueError('Invalid load balancing config name') + raise ValueError("Invalid load balancing config name") if enabled is None: - raise ValueError('Invalid load balancing config enabled') + raise ValueError("Invalid load balancing config enabled") # is config exists if config_id: config_id = str(config_id) if config_id not in current_load_balancing_configs_dict: - raise ValueError('Invalid load balancing config id: {}'.format(config_id)) + raise ValueError("Invalid load balancing config id: {}".format(config_id)) updated_config_ids.add(config_id) @@ -347,11 +349,11 @@ class ModelLoadBalancingService: # check duplicate name for current_load_balancing_config in current_load_balancing_configs: if current_load_balancing_config.id != config_id and current_load_balancing_config.name == name: - raise ValueError('Load balancing config name {} already exists'.format(name)) + raise ValueError("Load balancing config name {} already exists".format(name)) if credentials: if not isinstance(credentials, dict): - raise ValueError('Invalid load balancing config credentials') + raise ValueError("Invalid load balancing config credentials") # validate custom provider config credentials = self._custom_credentials_validate( @@ -361,7 +363,7 @@ class ModelLoadBalancingService: model=model, credentials=credentials, load_balancing_model_config=load_balancing_config, - validate=False + validate=False, ) # update load balancing config @@ -375,19 +377,19 @@ class ModelLoadBalancingService: self._clear_credentials_cache(tenant_id, config_id) else: # create load balancing config - if name == '__inherit__': - raise ValueError('Invalid load balancing config name') + if name == "__inherit__": + raise ValueError("Invalid load balancing config name") # check duplicate name for current_load_balancing_config in current_load_balancing_configs: if current_load_balancing_config.name == name: - raise ValueError('Load balancing config name {} already exists'.format(name)) + raise ValueError("Load balancing config name {} already exists".format(name)) if not credentials: - raise ValueError('Invalid load balancing config credentials') + raise ValueError("Invalid load balancing config credentials") if not isinstance(credentials, dict): - raise ValueError('Invalid load balancing config credentials') + raise ValueError("Invalid load balancing config credentials") # validate custom provider config credentials = self._custom_credentials_validate( @@ -396,7 +398,7 @@ class ModelLoadBalancingService: model_type=model_type, model=model, credentials=credentials, - validate=False + validate=False, ) # create load balancing config @@ -406,7 +408,7 @@ class ModelLoadBalancingService: model_type=model_type.to_origin_model_type(), model_name=model, name=name, - encrypted_config=json.dumps(credentials) + encrypted_config=json.dumps(credentials), ) db.session.add(load_balancing_model_config) @@ -420,12 +422,15 @@ class ModelLoadBalancingService: self._clear_credentials_cache(tenant_id, config_id) - def validate_load_balancing_credentials(self, tenant_id: str, - provider: str, - model: str, - model_type: str, - credentials: dict, - config_id: Optional[str] = None) -> None: + def validate_load_balancing_credentials( + self, + tenant_id: str, + provider: str, + model: str, + model_type: str, + credentials: dict, + config_id: Optional[str] = None, + ) -> None: """ Validate load balancing credentials. :param tenant_id: workspace id @@ -450,14 +455,17 @@ class ModelLoadBalancingService: load_balancing_model_config = None if config_id: # Get load balancing config - load_balancing_model_config = db.session.query(LoadBalancingModelConfig) \ + load_balancing_model_config = ( + db.session.query(LoadBalancingModelConfig) .filter( - LoadBalancingModelConfig.tenant_id == tenant_id, - LoadBalancingModelConfig.provider_name == provider, - LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), - LoadBalancingModelConfig.model_name == model, - LoadBalancingModelConfig.id == config_id - ).first() + LoadBalancingModelConfig.tenant_id == tenant_id, + LoadBalancingModelConfig.provider_name == provider, + LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), + LoadBalancingModelConfig.model_name == model, + LoadBalancingModelConfig.id == config_id, + ) + .first() + ) if not load_balancing_model_config: raise ValueError(f"Load balancing config {config_id} does not exist.") @@ -469,16 +477,19 @@ class ModelLoadBalancingService: model_type=model_type, model=model, credentials=credentials, - load_balancing_model_config=load_balancing_model_config + load_balancing_model_config=load_balancing_model_config, ) - def _custom_credentials_validate(self, tenant_id: str, - provider_configuration: ProviderConfiguration, - model_type: ModelType, - model: str, - credentials: dict, - load_balancing_model_config: Optional[LoadBalancingModelConfig] = None, - validate: bool = True) -> dict: + def _custom_credentials_validate( + self, + tenant_id: str, + provider_configuration: ProviderConfiguration, + model_type: ModelType, + model: str, + credentials: dict, + load_balancing_model_config: Optional[LoadBalancingModelConfig] = None, + validate: bool = True, + ) -> dict: """ Validate custom credentials. :param tenant_id: workspace id @@ -521,12 +532,11 @@ class ModelLoadBalancingService: provider=provider_configuration.provider.provider, model_type=model_type, model=model, - credentials=credentials + credentials=credentials, ) else: credentials = model_provider_factory.provider_credentials_validate( - provider=provider_configuration.provider.provider, - credentials=credentials + provider=provider_configuration.provider.provider, credentials=credentials ) for key, value in credentials.items(): @@ -535,8 +545,9 @@ class ModelLoadBalancingService: return credentials - def _get_credential_schema(self, provider_configuration: ProviderConfiguration) \ - -> ModelCredentialSchema | ProviderCredentialSchema: + def _get_credential_schema( + self, provider_configuration: ProviderConfiguration + ) -> ModelCredentialSchema | ProviderCredentialSchema: """ Get form schemas. :param provider_configuration: provider configuration @@ -558,9 +569,7 @@ class ModelLoadBalancingService: :return: """ provider_model_credentials_cache = ProviderCredentialsCache( - tenant_id=tenant_id, - identity_id=config_id, - cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL + tenant_id=tenant_id, identity_id=config_id, cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL ) provider_model_credentials_cache.delete() diff --git a/api/services/model_provider_service.py b/api/services/model_provider_service.py index 86c059fe9..c0f3c4076 100644 --- a/api/services/model_provider_service.py +++ b/api/services/model_provider_service.py @@ -73,8 +73,8 @@ class ModelProviderService: system_configuration=SystemConfigurationResponse( enabled=provider_configuration.system_configuration.enabled, current_quota_type=provider_configuration.system_configuration.current_quota_type, - quota_configurations=provider_configuration.system_configuration.quota_configurations - ) + quota_configurations=provider_configuration.system_configuration.quota_configurations, + ), ) provider_responses.append(provider_response) @@ -95,9 +95,9 @@ class ModelProviderService: provider_configurations = self.provider_manager.get_configurations(tenant_id) # Get provider available models - return [ModelWithProviderEntityResponse(model) for model in provider_configurations.get_models( - provider=provider - )] + return [ + ModelWithProviderEntityResponse(model) for model in provider_configurations.get_models(provider=provider) + ] def get_provider_credentials(self, tenant_id: str, provider: str) -> dict: """ @@ -195,13 +195,12 @@ class ModelProviderService: # Get model custom credentials from ProviderModel if exists return provider_configuration.get_custom_model_credentials( - model_type=ModelType.value_of(model_type), - model=model, - obfuscated=True + model_type=ModelType.value_of(model_type), model=model, obfuscated=True ) - def model_credentials_validate(self, tenant_id: str, provider: str, model_type: str, model: str, - credentials: dict) -> None: + def model_credentials_validate( + self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict + ) -> None: """ validate model credentials. @@ -222,13 +221,12 @@ class ModelProviderService: # Validate model credentials provider_configuration.custom_model_credentials_validate( - model_type=ModelType.value_of(model_type), - model=model, - credentials=credentials + model_type=ModelType.value_of(model_type), model=model, credentials=credentials ) - def save_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str, - credentials: dict) -> None: + def save_model_credentials( + self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict + ) -> None: """ save model credentials. @@ -249,9 +247,7 @@ class ModelProviderService: # Add or update custom model credentials provider_configuration.add_or_update_custom_model_credentials( - model_type=ModelType.value_of(model_type), - model=model, - credentials=credentials + model_type=ModelType.value_of(model_type), model=model, credentials=credentials ) def remove_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str) -> None: @@ -273,10 +269,7 @@ class ModelProviderService: raise ValueError(f"Provider {provider} does not exist.") # Remove custom model credentials - provider_configuration.delete_custom_model_credentials( - model_type=ModelType.value_of(model_type), - model=model - ) + provider_configuration.delete_custom_model_credentials(model_type=ModelType.value_of(model_type), model=model) def get_models_by_model_type(self, tenant_id: str, model_type: str) -> list[ProviderWithModelsResponse]: """ @@ -290,9 +283,7 @@ class ModelProviderService: provider_configurations = self.provider_manager.get_configurations(tenant_id) # Get provider available models - models = provider_configurations.get_models( - model_type=ModelType.value_of(model_type) - ) + models = provider_configurations.get_models(model_type=ModelType.value_of(model_type)) # Group models by provider provider_models = {} @@ -323,16 +314,19 @@ class ModelProviderService: icon_small=first_model.provider.icon_small, icon_large=first_model.provider.icon_large, status=CustomConfigurationStatus.ACTIVE, - models=[ProviderModelWithStatusEntity( - model=model.model, - label=model.label, - model_type=model.model_type, - features=model.features, - fetch_from=model.fetch_from, - model_properties=model.model_properties, - status=model.status, - load_balancing_enabled=model.load_balancing_enabled - ) for model in models] + models=[ + ProviderModelWithStatusEntity( + model=model.model, + label=model.label, + model_type=model.model_type, + features=model.features, + fetch_from=model.fetch_from, + model_properties=model.model_properties, + status=model.status, + load_balancing_enabled=model.load_balancing_enabled, + ) + for model in models + ], ) ) @@ -361,19 +355,13 @@ class ModelProviderService: model_type_instance = cast(LargeLanguageModel, model_type_instance) # fetch credentials - credentials = provider_configuration.get_current_credentials( - model_type=ModelType.LLM, - model=model - ) + credentials = provider_configuration.get_current_credentials(model_type=ModelType.LLM, model=model) if not credentials: return [] # Call get_parameter_rules method of model instance to get model parameter rules - return model_type_instance.get_parameter_rules( - model=model, - credentials=credentials - ) + return model_type_instance.get_parameter_rules(model=model, credentials=credentials) def get_default_model_of_model_type(self, tenant_id: str, model_type: str) -> Optional[DefaultModelResponse]: """ @@ -384,22 +372,23 @@ class ModelProviderService: :return: """ model_type_enum = ModelType.value_of(model_type) - result = self.provider_manager.get_default_model( - tenant_id=tenant_id, - model_type=model_type_enum - ) + result = self.provider_manager.get_default_model(tenant_id=tenant_id, model_type=model_type_enum) try: - return DefaultModelResponse( - model=result.model, - model_type=result.model_type, - provider=SimpleProviderEntityResponse( - provider=result.provider.provider, - label=result.provider.label, - icon_small=result.provider.icon_small, - icon_large=result.provider.icon_large, - supported_model_types=result.provider.supported_model_types + return ( + DefaultModelResponse( + model=result.model, + model_type=result.model_type, + provider=SimpleProviderEntityResponse( + provider=result.provider.provider, + label=result.provider.label, + icon_small=result.provider.icon_small, + icon_large=result.provider.icon_large, + supported_model_types=result.provider.supported_model_types, + ), ) - ) if result else None + if result + else None + ) except Exception as e: logger.info(f"get_default_model_of_model_type error: {e}") return None @@ -416,13 +405,12 @@ class ModelProviderService: """ model_type_enum = ModelType.value_of(model_type) self.provider_manager.update_default_model_record( - tenant_id=tenant_id, - model_type=model_type_enum, - provider=provider, - model=model + tenant_id=tenant_id, model_type=model_type_enum, provider=provider, model=model ) - def get_model_provider_icon(self, provider: str, icon_type: str, lang: str) -> tuple[Optional[bytes], Optional[str]]: + def get_model_provider_icon( + self, provider: str, icon_type: str, lang: str + ) -> tuple[Optional[bytes], Optional[str]]: """ get model provider icon. @@ -434,11 +422,11 @@ class ModelProviderService: provider_instance = model_provider_factory.get_provider_instance(provider) provider_schema = provider_instance.get_provider_schema() - if icon_type.lower() == 'icon_small': + if icon_type.lower() == "icon_small": if not provider_schema.icon_small: raise ValueError(f"Provider {provider} does not have small icon.") - if lang.lower() == 'zh_hans': + if lang.lower() == "zh_hans": file_name = provider_schema.icon_small.zh_Hans else: file_name = provider_schema.icon_small.en_US @@ -446,13 +434,15 @@ class ModelProviderService: if not provider_schema.icon_large: raise ValueError(f"Provider {provider} does not have large icon.") - if lang.lower() == 'zh_hans': + if lang.lower() == "zh_hans": file_name = provider_schema.icon_large.zh_Hans else: file_name = provider_schema.icon_large.en_US root_path = current_app.root_path - provider_instance_path = os.path.dirname(os.path.join(root_path, provider_instance.__class__.__module__.replace('.', '/'))) + provider_instance_path = os.path.dirname( + os.path.join(root_path, provider_instance.__class__.__module__.replace(".", "/")) + ) file_path = os.path.join(provider_instance_path, "_assets") file_path = os.path.join(file_path, file_name) @@ -460,10 +450,10 @@ class ModelProviderService: return None, None mimetype, _ = mimetypes.guess_type(file_path) - mimetype = mimetype or 'application/octet-stream' + mimetype = mimetype or "application/octet-stream" # read binary from file - with open(file_path, 'rb') as f: + with open(file_path, "rb") as f: byte_data = f.read() return byte_data, mimetype @@ -509,10 +499,7 @@ class ModelProviderService: raise ValueError(f"Provider {provider} does not exist.") # Enable model - provider_configuration.enable_model( - model=model, - model_type=ModelType.value_of(model_type) - ) + provider_configuration.enable_model(model=model, model_type=ModelType.value_of(model_type)) def disable_model(self, tenant_id: str, provider: str, model: str, model_type: str) -> None: """ @@ -533,78 +520,49 @@ class ModelProviderService: raise ValueError(f"Provider {provider} does not exist.") # Enable model - provider_configuration.disable_model( - model=model, - model_type=ModelType.value_of(model_type) - ) + provider_configuration.disable_model(model=model, model_type=ModelType.value_of(model_type)) def free_quota_submit(self, tenant_id: str, provider: str): api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY") api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL") - api_url = api_base_url + '/api/v1/providers/apply' + api_url = api_base_url + "/api/v1/providers/apply" - headers = { - 'Content-Type': 'application/json', - 'Authorization': f"Bearer {api_key}" - } - response = requests.post(api_url, headers=headers, json={'workspace_id': tenant_id, 'provider_name': provider}) + headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"} + response = requests.post(api_url, headers=headers, json={"workspace_id": tenant_id, "provider_name": provider}) if not response.ok: logger.error(f"Request FREE QUOTA APPLY SERVER Error: {response.status_code} ") raise ValueError(f"Error: {response.status_code} ") - if response.json()["code"] != 'success': - raise ValueError( - f"error: {response.json()['message']}" - ) + if response.json()["code"] != "success": + raise ValueError(f"error: {response.json()['message']}") rst = response.json() - if rst['type'] == 'redirect': - return { - 'type': rst['type'], - 'redirect_url': rst['redirect_url'] - } + if rst["type"] == "redirect": + return {"type": rst["type"], "redirect_url": rst["redirect_url"]} else: - return { - 'type': rst['type'], - 'result': 'success' - } + return {"type": rst["type"], "result": "success"} def free_quota_qualification_verify(self, tenant_id: str, provider: str, token: Optional[str]): api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY") api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL") - api_url = api_base_url + '/api/v1/providers/qualification-verify' + api_url = api_base_url + "/api/v1/providers/qualification-verify" - headers = { - 'Content-Type': 'application/json', - 'Authorization': f"Bearer {api_key}" - } - json_data = {'workspace_id': tenant_id, 'provider_name': provider} + headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"} + json_data = {"workspace_id": tenant_id, "provider_name": provider} if token: - json_data['token'] = token - response = requests.post(api_url, headers=headers, - json=json_data) + json_data["token"] = token + response = requests.post(api_url, headers=headers, json=json_data) if not response.ok: logger.error(f"Request FREE QUOTA APPLY SERVER Error: {response.status_code} ") raise ValueError(f"Error: {response.status_code} ") rst = response.json() - if rst["code"] != 'success': - raise ValueError( - f"error: {rst['message']}" - ) + if rst["code"] != "success": + raise ValueError(f"error: {rst['message']}") - data = rst['data'] - if data['qualified'] is True: - return { - 'result': 'success', - 'provider_name': provider, - 'flag': True - } + data = rst["data"] + if data["qualified"] is True: + return {"result": "success", "provider_name": provider, "flag": True} else: - return { - 'result': 'success', - 'provider_name': provider, - 'flag': False, - 'reason': data['reason'] - } + return {"result": "success", "provider_name": provider, "flag": False, "reason": data["reason"]} diff --git a/api/services/moderation_service.py b/api/services/moderation_service.py index d472f8cfb..dfb21e767 100644 --- a/api/services/moderation_service.py +++ b/api/services/moderation_service.py @@ -4,17 +4,18 @@ from models.model import App, AppModelConfig class ModerationService: - def moderation_for_outputs(self, app_id: str, app_model: App, text: str) -> ModerationOutputsResult: app_model_config: AppModelConfig = None - app_model_config = db.session.query(AppModelConfig).filter(AppModelConfig.id == app_model.app_model_config_id).first() + app_model_config = ( + db.session.query(AppModelConfig).filter(AppModelConfig.id == app_model.app_model_config_id).first() + ) if not app_model_config: raise ValueError("app model config not found") - name = app_model_config.sensitive_word_avoidance_dict['type'] - config = app_model_config.sensitive_word_avoidance_dict['config'] + name = app_model_config.sensitive_word_avoidance_dict["type"] + config = app_model_config.sensitive_word_avoidance_dict["config"] moderation = ModerationFactory(name, app_id, app_model.tenant_id, config) return moderation.moderation_for_outputs(text) diff --git a/api/services/operation_service.py b/api/services/operation_service.py index 39f249dc2..8c8b64bcd 100644 --- a/api/services/operation_service.py +++ b/api/services/operation_service.py @@ -4,15 +4,12 @@ import requests class OperationService: - base_url = os.environ.get('BILLING_API_URL', 'BILLING_API_URL') - secret_key = os.environ.get('BILLING_API_SECRET_KEY', 'BILLING_API_SECRET_KEY') + base_url = os.environ.get("BILLING_API_URL", "BILLING_API_URL") + secret_key = os.environ.get("BILLING_API_SECRET_KEY", "BILLING_API_SECRET_KEY") @classmethod def _send_request(cls, method, endpoint, json=None, params=None): - headers = { - "Content-Type": "application/json", - "Billing-Api-Secret-Key": cls.secret_key - } + headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key} url = f"{cls.base_url}{endpoint}" response = requests.request(method, url, json=json, params=params, headers=headers) @@ -22,11 +19,11 @@ class OperationService: @classmethod def record_utm(cls, tenant_id: str, utm_info: dict): params = { - 'tenant_id': tenant_id, - 'utm_source': utm_info.get('utm_source', ''), - 'utm_medium': utm_info.get('utm_medium', ''), - 'utm_campaign': utm_info.get('utm_campaign', ''), - 'utm_content': utm_info.get('utm_content', ''), - 'utm_term': utm_info.get('utm_term', '') + "tenant_id": tenant_id, + "utm_source": utm_info.get("utm_source", ""), + "utm_medium": utm_info.get("utm_medium", ""), + "utm_campaign": utm_info.get("utm_campaign", ""), + "utm_content": utm_info.get("utm_content", ""), + "utm_term": utm_info.get("utm_term", ""), } - return cls._send_request('POST', '/tenant_utms', params=params) + return cls._send_request("POST", "/tenant_utms", params=params) diff --git a/api/services/ops_service.py b/api/services/ops_service.py index 7b2edcf7c..0650f2cb2 100644 --- a/api/services/ops_service.py +++ b/api/services/ops_service.py @@ -12,19 +12,25 @@ class OpsService: :param tracing_provider: tracing provider :return: """ - trace_config_data: TraceAppConfig = db.session.query(TraceAppConfig).filter( - TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider - ).first() + trace_config_data: TraceAppConfig = ( + db.session.query(TraceAppConfig) + .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) + .first() + ) if not trace_config_data: return None # decrypt_token and obfuscated_token tenant_id = db.session.query(App).filter(App.id == app_id).first().tenant_id - decrypt_tracing_config = OpsTraceManager.decrypt_tracing_config(tenant_id, tracing_provider, trace_config_data.tracing_config) - if tracing_provider == 'langfuse' and ('project_key' not in decrypt_tracing_config or not decrypt_tracing_config.get('project_key')): + decrypt_tracing_config = OpsTraceManager.decrypt_tracing_config( + tenant_id, tracing_provider, trace_config_data.tracing_config + ) + if tracing_provider == "langfuse" and ( + "project_key" not in decrypt_tracing_config or not decrypt_tracing_config.get("project_key") + ): project_key = OpsTraceManager.get_trace_config_project_key(decrypt_tracing_config, tracing_provider) - decrypt_tracing_config['project_key'] = project_key + decrypt_tracing_config["project_key"] = project_key decrypt_tracing_config = OpsTraceManager.obfuscated_decrypt_token(tracing_provider, decrypt_tracing_config) @@ -44,8 +50,10 @@ class OpsService: if tracing_provider not in provider_config_map.keys() and tracing_provider: return {"error": f"Invalid tracing provider: {tracing_provider}"} - config_class, other_keys = provider_config_map[tracing_provider]['config_class'], \ - provider_config_map[tracing_provider]['other_keys'] + config_class, other_keys = ( + provider_config_map[tracing_provider]["config_class"], + provider_config_map[tracing_provider]["other_keys"], + ) default_config_instance = config_class(**tracing_config) for key in other_keys: if key in tracing_config and tracing_config[key] == "": @@ -59,9 +67,11 @@ class OpsService: project_key = OpsTraceManager.get_trace_config_project_key(tracing_config, tracing_provider) # check if trace config already exists - trace_config_data: TraceAppConfig = db.session.query(TraceAppConfig).filter( - TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider - ).first() + trace_config_data: TraceAppConfig = ( + db.session.query(TraceAppConfig) + .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) + .first() + ) if trace_config_data: return None @@ -69,8 +79,8 @@ class OpsService: # get tenant id tenant_id = db.session.query(App).filter(App.id == app_id).first().tenant_id tracing_config = OpsTraceManager.encrypt_tracing_config(tenant_id, tracing_provider, tracing_config) - if tracing_provider == 'langfuse': - tracing_config['project_key'] = project_key + if tracing_provider == "langfuse": + tracing_config["project_key"] = project_key trace_config_data = TraceAppConfig( app_id=app_id, tracing_provider=tracing_provider, @@ -94,9 +104,11 @@ class OpsService: raise ValueError(f"Invalid tracing provider: {tracing_provider}") # check if trace config already exists - current_trace_config = db.session.query(TraceAppConfig).filter( - TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider - ).first() + current_trace_config = ( + db.session.query(TraceAppConfig) + .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) + .first() + ) if not current_trace_config: return None @@ -126,9 +138,11 @@ class OpsService: :param tracing_provider: tracing provider :return: """ - trace_config = db.session.query(TraceAppConfig).filter( - TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider - ).first() + trace_config = ( + db.session.query(TraceAppConfig) + .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) + .first() + ) if not trace_config: return None diff --git a/api/services/recommended_app_service.py b/api/services/recommended_app_service.py index 1c1c5be17..10abf0a76 100644 --- a/api/services/recommended_app_service.py +++ b/api/services/recommended_app_service.py @@ -16,7 +16,6 @@ logger = logging.getLogger(__name__) class RecommendedAppService: - builtin_data: Optional[dict] = None @classmethod @@ -27,21 +26,21 @@ class RecommendedAppService: :return: """ mode = dify_config.HOSTED_FETCH_APP_TEMPLATES_MODE - if mode == 'remote': + if mode == "remote": try: result = cls._fetch_recommended_apps_from_dify_official(language) except Exception as e: - logger.warning(f'fetch recommended apps from dify official failed: {e}, switch to built-in.') + logger.warning(f"fetch recommended apps from dify official failed: {e}, switch to built-in.") result = cls._fetch_recommended_apps_from_builtin(language) - elif mode == 'db': + elif mode == "db": result = cls._fetch_recommended_apps_from_db(language) - elif mode == 'builtin': + elif mode == "builtin": result = cls._fetch_recommended_apps_from_builtin(language) else: - raise ValueError(f'invalid fetch recommended apps mode: {mode}') + raise ValueError(f"invalid fetch recommended apps mode: {mode}") - if not result.get('recommended_apps') and language != 'en-US': - result = cls._fetch_recommended_apps_from_builtin('en-US') + if not result.get("recommended_apps") and language != "en-US": + result = cls._fetch_recommended_apps_from_builtin("en-US") return result @@ -52,16 +51,18 @@ class RecommendedAppService: :param language: language :return: """ - recommended_apps = db.session.query(RecommendedApp).filter( - RecommendedApp.is_listed == True, - RecommendedApp.language == language - ).all() + recommended_apps = ( + db.session.query(RecommendedApp) + .filter(RecommendedApp.is_listed == True, RecommendedApp.language == language) + .all() + ) if len(recommended_apps) == 0: - recommended_apps = db.session.query(RecommendedApp).filter( - RecommendedApp.is_listed == True, - RecommendedApp.language == languages[0] - ).all() + recommended_apps = ( + db.session.query(RecommendedApp) + .filter(RecommendedApp.is_listed == True, RecommendedApp.language == languages[0]) + .all() + ) categories = set() recommended_apps_result = [] @@ -75,28 +76,28 @@ class RecommendedAppService: continue recommended_app_result = { - 'id': recommended_app.id, - 'app': { - 'id': app.id, - 'name': app.name, - 'mode': app.mode, - 'icon': app.icon, - 'icon_background': app.icon_background + "id": recommended_app.id, + "app": { + "id": app.id, + "name": app.name, + "mode": app.mode, + "icon": app.icon, + "icon_background": app.icon_background, }, - 'app_id': recommended_app.app_id, - 'description': site.description, - 'copyright': site.copyright, - 'privacy_policy': site.privacy_policy, - 'custom_disclaimer': site.custom_disclaimer, - 'category': recommended_app.category, - 'position': recommended_app.position, - 'is_listed': recommended_app.is_listed + "app_id": recommended_app.app_id, + "description": site.description, + "copyright": site.copyright, + "privacy_policy": site.privacy_policy, + "custom_disclaimer": site.custom_disclaimer, + "category": recommended_app.category, + "position": recommended_app.position, + "is_listed": recommended_app.is_listed, } recommended_apps_result.append(recommended_app_result) categories.add(recommended_app.category) # add category to categories - return {'recommended_apps': recommended_apps_result, 'categories': sorted(categories)} + return {"recommended_apps": recommended_apps_result, "categories": sorted(categories)} @classmethod def _fetch_recommended_apps_from_dify_official(cls, language: str) -> dict: @@ -106,16 +107,16 @@ class RecommendedAppService: :return: """ domain = dify_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN - url = f'{domain}/apps?language={language}' + url = f"{domain}/apps?language={language}" response = requests.get(url, timeout=(3, 10)) if response.status_code != 200: - raise ValueError(f'fetch recommended apps failed, status code: {response.status_code}') + raise ValueError(f"fetch recommended apps failed, status code: {response.status_code}") result = response.json() if "categories" in result: result["categories"] = sorted(result["categories"]) - + return result @classmethod @@ -126,7 +127,7 @@ class RecommendedAppService: :return: """ builtin_data = cls._get_builtin_data() - return builtin_data.get('recommended_apps', {}).get(language) + return builtin_data.get("recommended_apps", {}).get(language) @classmethod def get_recommend_app_detail(cls, app_id: str) -> Optional[dict]: @@ -136,18 +137,18 @@ class RecommendedAppService: :return: """ mode = dify_config.HOSTED_FETCH_APP_TEMPLATES_MODE - if mode == 'remote': + if mode == "remote": try: result = cls._fetch_recommended_app_detail_from_dify_official(app_id) except Exception as e: - logger.warning(f'fetch recommended app detail from dify official failed: {e}, switch to built-in.') + logger.warning(f"fetch recommended app detail from dify official failed: {e}, switch to built-in.") result = cls._fetch_recommended_app_detail_from_builtin(app_id) - elif mode == 'db': + elif mode == "db": result = cls._fetch_recommended_app_detail_from_db(app_id) - elif mode == 'builtin': + elif mode == "builtin": result = cls._fetch_recommended_app_detail_from_builtin(app_id) else: - raise ValueError(f'invalid fetch recommended app detail mode: {mode}') + raise ValueError(f"invalid fetch recommended app detail mode: {mode}") return result @@ -159,7 +160,7 @@ class RecommendedAppService: :return: """ domain = dify_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN - url = f'{domain}/apps/{app_id}' + url = f"{domain}/apps/{app_id}" response = requests.get(url, timeout=(3, 10)) if response.status_code != 200: return None @@ -174,10 +175,11 @@ class RecommendedAppService: :return: """ # is in public recommended list - recommended_app = db.session.query(RecommendedApp).filter( - RecommendedApp.is_listed == True, - RecommendedApp.app_id == app_id - ).first() + recommended_app = ( + db.session.query(RecommendedApp) + .filter(RecommendedApp.is_listed == True, RecommendedApp.app_id == app_id) + .first() + ) if not recommended_app: return None @@ -188,12 +190,12 @@ class RecommendedAppService: return None return { - 'id': app_model.id, - 'name': app_model.name, - 'icon': app_model.icon, - 'icon_background': app_model.icon_background, - 'mode': app_model.mode, - 'export_data': AppDslService.export_dsl(app_model=app_model) + "id": app_model.id, + "name": app_model.name, + "icon": app_model.icon, + "icon_background": app_model.icon_background, + "mode": app_model.mode, + "export_data": AppDslService.export_dsl(app_model=app_model), } @classmethod @@ -204,7 +206,7 @@ class RecommendedAppService: :return: """ builtin_data = cls._get_builtin_data() - return builtin_data.get('app_details', {}).get(app_id) + return builtin_data.get("app_details", {}).get(app_id) @classmethod def _get_builtin_data(cls) -> dict: @@ -216,7 +218,7 @@ class RecommendedAppService: return cls.builtin_data root_path = current_app.root_path - with open(path.join(root_path, 'constants', 'recommended_apps.json'), encoding='utf-8') as f: + with open(path.join(root_path, "constants", "recommended_apps.json"), encoding="utf-8") as f: json_data = f.read() data = json.loads(json_data) cls.builtin_data = data @@ -229,27 +231,24 @@ class RecommendedAppService: Fetch all recommended apps and export datas :return: """ - templates = { - "recommended_apps": {}, - "app_details": {} - } + templates = {"recommended_apps": {}, "app_details": {}} for language in languages: try: result = cls._fetch_recommended_apps_from_dify_official(language) except Exception as e: - logger.warning(f'fetch recommended apps from dify official failed: {e}, skip.') + logger.warning(f"fetch recommended apps from dify official failed: {e}, skip.") continue - templates['recommended_apps'][language] = result + templates["recommended_apps"][language] = result - for recommended_app in result.get('recommended_apps'): - app_id = recommended_app.get('app_id') + for recommended_app in result.get("recommended_apps"): + app_id = recommended_app.get("app_id") # get app detail app_detail = cls._fetch_recommended_app_detail_from_dify_official(app_id) if not app_detail: continue - templates['app_details'][app_id] = app_detail + templates["app_details"][app_id] = app_detail return templates diff --git a/api/services/saved_message_service.py b/api/services/saved_message_service.py index f1113c150..9fe3cecce 100644 --- a/api/services/saved_message_service.py +++ b/api/services/saved_message_service.py @@ -10,46 +10,48 @@ from services.message_service import MessageService class SavedMessageService: @classmethod - def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account, EndUser]], - last_id: Optional[str], limit: int) -> InfiniteScrollPagination: - saved_messages = db.session.query(SavedMessage).filter( - SavedMessage.app_id == app_model.id, - SavedMessage.created_by_role == ('account' if isinstance(user, Account) else 'end_user'), - SavedMessage.created_by == user.id - ).order_by(SavedMessage.created_at.desc()).all() + def pagination_by_last_id( + cls, app_model: App, user: Optional[Union[Account, EndUser]], last_id: Optional[str], limit: int + ) -> InfiniteScrollPagination: + saved_messages = ( + db.session.query(SavedMessage) + .filter( + SavedMessage.app_id == app_model.id, + SavedMessage.created_by_role == ("account" if isinstance(user, Account) else "end_user"), + SavedMessage.created_by == user.id, + ) + .order_by(SavedMessage.created_at.desc()) + .all() + ) message_ids = [sm.message_id for sm in saved_messages] return MessageService.pagination_by_last_id( - app_model=app_model, - user=user, - last_id=last_id, - limit=limit, - include_ids=message_ids + app_model=app_model, user=user, last_id=last_id, limit=limit, include_ids=message_ids ) @classmethod def save(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str): - saved_message = db.session.query(SavedMessage).filter( - SavedMessage.app_id == app_model.id, - SavedMessage.message_id == message_id, - SavedMessage.created_by_role == ('account' if isinstance(user, Account) else 'end_user'), - SavedMessage.created_by == user.id - ).first() + saved_message = ( + db.session.query(SavedMessage) + .filter( + SavedMessage.app_id == app_model.id, + SavedMessage.message_id == message_id, + SavedMessage.created_by_role == ("account" if isinstance(user, Account) else "end_user"), + SavedMessage.created_by == user.id, + ) + .first() + ) if saved_message: return - message = MessageService.get_message( - app_model=app_model, - user=user, - message_id=message_id - ) + message = MessageService.get_message(app_model=app_model, user=user, message_id=message_id) saved_message = SavedMessage( app_id=app_model.id, message_id=message.id, - created_by_role='account' if isinstance(user, Account) else 'end_user', - created_by=user.id + created_by_role="account" if isinstance(user, Account) else "end_user", + created_by=user.id, ) db.session.add(saved_message) @@ -57,12 +59,16 @@ class SavedMessageService: @classmethod def delete(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str): - saved_message = db.session.query(SavedMessage).filter( - SavedMessage.app_id == app_model.id, - SavedMessage.message_id == message_id, - SavedMessage.created_by_role == ('account' if isinstance(user, Account) else 'end_user'), - SavedMessage.created_by == user.id - ).first() + saved_message = ( + db.session.query(SavedMessage) + .filter( + SavedMessage.app_id == app_model.id, + SavedMessage.message_id == message_id, + SavedMessage.created_by_role == ("account" if isinstance(user, Account) else "end_user"), + SavedMessage.created_by == user.id, + ) + .first() + ) if not saved_message: return diff --git a/api/services/tag_service.py b/api/services/tag_service.py index d6eba38fb..0c17485a9 100644 --- a/api/services/tag_service.py +++ b/api/services/tag_service.py @@ -12,38 +12,32 @@ from models.model import App, Tag, TagBinding class TagService: @staticmethod def get_tags(tag_type: str, current_tenant_id: str, keyword: str = None) -> list: - query = db.session.query( - Tag.id, Tag.type, Tag.name, func.count(TagBinding.id).label('binding_count') - ).outerjoin( - TagBinding, Tag.id == TagBinding.tag_id - ).filter( - Tag.type == tag_type, - Tag.tenant_id == current_tenant_id + query = ( + db.session.query(Tag.id, Tag.type, Tag.name, func.count(TagBinding.id).label("binding_count")) + .outerjoin(TagBinding, Tag.id == TagBinding.tag_id) + .filter(Tag.type == tag_type, Tag.tenant_id == current_tenant_id) ) if keyword: - query = query.filter(db.and_(Tag.name.ilike(f'%{keyword}%'))) - query = query.group_by( - Tag.id - ) + query = query.filter(db.and_(Tag.name.ilike(f"%{keyword}%"))) + query = query.group_by(Tag.id) results = query.order_by(Tag.created_at.desc()).all() return results @staticmethod def get_target_ids_by_tag_ids(tag_type: str, current_tenant_id: str, tag_ids: list) -> list: - tags = db.session.query(Tag).filter( - Tag.id.in_(tag_ids), - Tag.tenant_id == current_tenant_id, - Tag.type == tag_type - ).all() + tags = ( + db.session.query(Tag) + .filter(Tag.id.in_(tag_ids), Tag.tenant_id == current_tenant_id, Tag.type == tag_type) + .all() + ) if not tags: return [] tag_ids = [tag.id for tag in tags] - tag_bindings = db.session.query( - TagBinding.target_id - ).filter( - TagBinding.tag_id.in_(tag_ids), - TagBinding.tenant_id == current_tenant_id - ).all() + tag_bindings = ( + db.session.query(TagBinding.target_id) + .filter(TagBinding.tag_id.in_(tag_ids), TagBinding.tenant_id == current_tenant_id) + .all() + ) if not tag_bindings: return [] results = [tag_binding.target_id for tag_binding in tag_bindings] @@ -51,27 +45,28 @@ class TagService: @staticmethod def get_tags_by_target_id(tag_type: str, current_tenant_id: str, target_id: str) -> list: - tags = db.session.query(Tag).join( - TagBinding, - Tag.id == TagBinding.tag_id - ).filter( - TagBinding.target_id == target_id, - TagBinding.tenant_id == current_tenant_id, - Tag.tenant_id == current_tenant_id, - Tag.type == tag_type - ).all() + tags = ( + db.session.query(Tag) + .join(TagBinding, Tag.id == TagBinding.tag_id) + .filter( + TagBinding.target_id == target_id, + TagBinding.tenant_id == current_tenant_id, + Tag.tenant_id == current_tenant_id, + Tag.type == tag_type, + ) + .all() + ) return tags if tags else [] - @staticmethod def save_tags(args: dict) -> Tag: tag = Tag( id=str(uuid.uuid4()), - name=args['name'], - type=args['type'], + name=args["name"], + type=args["type"], created_by=current_user.id, - tenant_id=current_user.current_tenant_id + tenant_id=current_user.current_tenant_id, ) db.session.add(tag) db.session.commit() @@ -82,7 +77,7 @@ class TagService: tag = db.session.query(Tag).filter(Tag.id == tag_id).first() if not tag: raise NotFound("Tag not found") - tag.name = args['name'] + tag.name = args["name"] db.session.commit() return tag @@ -107,20 +102,21 @@ class TagService: @staticmethod def save_tag_binding(args): # check if target exists - TagService.check_target_exists(args['type'], args['target_id']) + TagService.check_target_exists(args["type"], args["target_id"]) # save tag binding - for tag_id in args['tag_ids']: - tag_binding = db.session.query(TagBinding).filter( - TagBinding.tag_id == tag_id, - TagBinding.target_id == args['target_id'] - ).first() + for tag_id in args["tag_ids"]: + tag_binding = ( + db.session.query(TagBinding) + .filter(TagBinding.tag_id == tag_id, TagBinding.target_id == args["target_id"]) + .first() + ) if tag_binding: continue new_tag_binding = TagBinding( tag_id=tag_id, - target_id=args['target_id'], + target_id=args["target_id"], tenant_id=current_user.current_tenant_id, - created_by=current_user.id + created_by=current_user.id, ) db.session.add(new_tag_binding) db.session.commit() @@ -128,34 +124,34 @@ class TagService: @staticmethod def delete_tag_binding(args): # check if target exists - TagService.check_target_exists(args['type'], args['target_id']) + TagService.check_target_exists(args["type"], args["target_id"]) # delete tag binding - tag_bindings = db.session.query(TagBinding).filter( - TagBinding.target_id == args['target_id'], - TagBinding.tag_id == (args['tag_id']) - ).first() + tag_bindings = ( + db.session.query(TagBinding) + .filter(TagBinding.target_id == args["target_id"], TagBinding.tag_id == (args["tag_id"])) + .first() + ) if tag_bindings: db.session.delete(tag_bindings) db.session.commit() - - @staticmethod def check_target_exists(type: str, target_id: str): - if type == 'knowledge': - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == current_user.current_tenant_id, - Dataset.id == target_id - ).first() + if type == "knowledge": + dataset = ( + db.session.query(Dataset) + .filter(Dataset.tenant_id == current_user.current_tenant_id, Dataset.id == target_id) + .first() + ) if not dataset: raise NotFound("Dataset not found") - elif type == 'app': - app = db.session.query(App).filter( - App.tenant_id == current_user.current_tenant_id, - App.id == target_id - ).first() + elif type == "app": + app = ( + db.session.query(App) + .filter(App.tenant_id == current_user.current_tenant_id, App.id == target_id) + .first() + ) if not app: raise NotFound("App not found") else: raise NotFound("Invalid binding type") - diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py index ecc065d52..3ded9c098 100644 --- a/api/services/tools/api_tools_manage_service.py +++ b/api/services/tools/api_tools_manage_service.py @@ -29,111 +29,107 @@ class ApiToolManageService: @staticmethod def parser_api_schema(schema: str) -> list[ApiToolBundle]: """ - parse api schema to tool bundle + parse api schema to tool bundle """ try: warnings = {} try: tool_bundles, schema_type = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, warning=warnings) except Exception as e: - raise ValueError(f'invalid schema: {str(e)}') - + raise ValueError(f"invalid schema: {str(e)}") + credentials_schema = [ ToolProviderCredentials( - name='auth_type', + name="auth_type", type=ToolProviderCredentials.CredentialsType.SELECT, required=True, - default='none', + default="none", options=[ - ToolCredentialsOption(value='none', label=I18nObject( - en_US='None', - zh_Hans='无' - )), - ToolCredentialsOption(value='api_key', label=I18nObject( - en_US='Api Key', - zh_Hans='Api Key' - )), + ToolCredentialsOption(value="none", label=I18nObject(en_US="None", zh_Hans="无")), + ToolCredentialsOption(value="api_key", label=I18nObject(en_US="Api Key", zh_Hans="Api Key")), ], - placeholder=I18nObject( - en_US='Select auth type', - zh_Hans='选择认证方式' - ) + placeholder=I18nObject(en_US="Select auth type", zh_Hans="选择认证方式"), ), ToolProviderCredentials( - name='api_key_header', + name="api_key_header", type=ToolProviderCredentials.CredentialsType.TEXT_INPUT, required=False, - placeholder=I18nObject( - en_US='Enter api key header', - zh_Hans='输入 api key header,如:X-API-KEY' - ), - default='api_key', - help=I18nObject( - en_US='HTTP header name for api key', - zh_Hans='HTTP 头部字段名,用于传递 api key' - ) + placeholder=I18nObject(en_US="Enter api key header", zh_Hans="输入 api key header,如:X-API-KEY"), + default="api_key", + help=I18nObject(en_US="HTTP header name for api key", zh_Hans="HTTP 头部字段名,用于传递 api key"), ), ToolProviderCredentials( - name='api_key_value', + name="api_key_value", type=ToolProviderCredentials.CredentialsType.TEXT_INPUT, required=False, - placeholder=I18nObject( - en_US='Enter api key', - zh_Hans='输入 api key' - ), - default='' + placeholder=I18nObject(en_US="Enter api key", zh_Hans="输入 api key"), + default="", ), ] - return jsonable_encoder({ - 'schema_type': schema_type, - 'parameters_schema': tool_bundles, - 'credentials_schema': credentials_schema, - 'warning': warnings - }) + return jsonable_encoder( + { + "schema_type": schema_type, + "parameters_schema": tool_bundles, + "credentials_schema": credentials_schema, + "warning": warnings, + } + ) except Exception as e: - raise ValueError(f'invalid schema: {str(e)}') + raise ValueError(f"invalid schema: {str(e)}") @staticmethod def convert_schema_to_tool_bundles(schema: str, extra_info: dict = None) -> list[ApiToolBundle]: """ - convert schema to tool bundles + convert schema to tool bundles - :return: the list of tool bundles, description + :return: the list of tool bundles, description """ try: tool_bundles = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, extra_info=extra_info) return tool_bundles except Exception as e: - raise ValueError(f'invalid schema: {str(e)}') + raise ValueError(f"invalid schema: {str(e)}") @staticmethod def create_api_tool_provider( - user_id: str, tenant_id: str, provider_name: str, icon: dict, credentials: dict, - schema_type: str, schema: str, privacy_policy: str, custom_disclaimer: str, labels: list[str] + user_id: str, + tenant_id: str, + provider_name: str, + icon: dict, + credentials: dict, + schema_type: str, + schema: str, + privacy_policy: str, + custom_disclaimer: str, + labels: list[str], ): """ - create api tool provider + create api tool provider """ if schema_type not in [member.value for member in ApiProviderSchemaType]: - raise ValueError(f'invalid schema type {schema}') - + raise ValueError(f"invalid schema type {schema}") + # check if the provider exists - provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( - ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.name == provider_name, - ).first() + provider: ApiToolProvider = ( + db.session.query(ApiToolProvider) + .filter( + ApiToolProvider.tenant_id == tenant_id, + ApiToolProvider.name == provider_name, + ) + .first() + ) if provider is not None: - raise ValueError(f'provider {provider_name} already exists') + raise ValueError(f"provider {provider_name} already exists") # parse openapi to tool bundle extra_info = {} # extra info like description will be set here tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info) - + if len(tool_bundles) > 100: - raise ValueError('the number of apis should be less than 100') + raise ValueError("the number of apis should be less than 100") # create db provider db_provider = ApiToolProvider( @@ -142,19 +138,19 @@ class ApiToolManageService: name=provider_name, icon=json.dumps(icon), schema=schema, - description=extra_info.get('description', ''), + description=extra_info.get("description", ""), schema_type_str=schema_type, tools_str=json.dumps(jsonable_encoder(tool_bundles)), credentials_str={}, privacy_policy=privacy_policy, - custom_disclaimer=custom_disclaimer + custom_disclaimer=custom_disclaimer, ) - if 'auth_type' not in credentials: - raise ValueError('auth_type is required') + if "auth_type" not in credentials: + raise ValueError("auth_type is required") # get auth type, none or api key - auth_type = ApiProviderAuthType.value_of(credentials['auth_type']) + auth_type = ApiProviderAuthType.value_of(credentials["auth_type"]) # create provider entity provider_controller = ApiToolProviderController.from_db(db_provider, auth_type) @@ -172,14 +168,12 @@ class ApiToolManageService: # update labels ToolLabelManager.update_tool_labels(provider_controller, labels) - return { 'result': 'success' } - + return {"result": "success"} + @staticmethod - def get_api_tool_provider_remote_schema( - user_id: str, tenant_id: str, url: str - ): + def get_api_tool_provider_remote_schema(user_id: str, tenant_id: str, url: str): """ - get api tool provider remote schema + get api tool provider remote schema """ headers = { "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36 Edg/120.0.0.0", @@ -189,84 +183,98 @@ class ApiToolManageService: try: response = get(url, headers=headers, timeout=10) if response.status_code != 200: - raise ValueError(f'Got status code {response.status_code}') + raise ValueError(f"Got status code {response.status_code}") schema = response.text # try to parse schema, avoid SSRF attack ApiToolManageService.parser_api_schema(schema) except Exception as e: logger.error(f"parse api schema error: {str(e)}") - raise ValueError('invalid schema, please check the url you provided') - - return { - 'schema': schema - } + raise ValueError("invalid schema, please check the url you provided") + + return {"schema": schema} @staticmethod - def list_api_tool_provider_tools( - user_id: str, tenant_id: str, provider: str - ) -> list[UserTool]: + def list_api_tool_provider_tools(user_id: str, tenant_id: str, provider: str) -> list[UserTool]: """ - list api tool provider tools + list api tool provider tools """ - provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( - ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.name == provider, - ).first() + provider: ApiToolProvider = ( + db.session.query(ApiToolProvider) + .filter( + ApiToolProvider.tenant_id == tenant_id, + ApiToolProvider.name == provider, + ) + .first() + ) if provider is None: - raise ValueError(f'you have not added provider {provider}') - + raise ValueError(f"you have not added provider {provider}") + controller = ToolTransformService.api_provider_to_controller(db_provider=provider) labels = ToolLabelManager.get_tool_labels(controller) - + return [ ToolTransformService.tool_to_user_tool( tool_bundle, labels=labels, - ) for tool_bundle in provider.tools + ) + for tool_bundle in provider.tools ] @staticmethod def update_api_tool_provider( - user_id: str, tenant_id: str, provider_name: str, original_provider: str, icon: dict, credentials: dict, - schema_type: str, schema: str, privacy_policy: str, custom_disclaimer: str, labels: list[str] + user_id: str, + tenant_id: str, + provider_name: str, + original_provider: str, + icon: dict, + credentials: dict, + schema_type: str, + schema: str, + privacy_policy: str, + custom_disclaimer: str, + labels: list[str], ): """ - update api tool provider + update api tool provider """ if schema_type not in [member.value for member in ApiProviderSchemaType]: - raise ValueError(f'invalid schema type {schema}') - + raise ValueError(f"invalid schema type {schema}") + # check if the provider exists - provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( - ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.name == original_provider, - ).first() + provider: ApiToolProvider = ( + db.session.query(ApiToolProvider) + .filter( + ApiToolProvider.tenant_id == tenant_id, + ApiToolProvider.name == original_provider, + ) + .first() + ) if provider is None: - raise ValueError(f'api provider {provider_name} does not exists') + raise ValueError(f"api provider {provider_name} does not exists") # parse openapi to tool bundle extra_info = {} # extra info like description will be set here tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info) - + # update db provider provider.name = provider_name provider.icon = json.dumps(icon) provider.schema = schema - provider.description = extra_info.get('description', '') + provider.description = extra_info.get("description", "") provider.schema_type_str = ApiProviderSchemaType.OPENAPI.value provider.tools_str = json.dumps(jsonable_encoder(tool_bundles)) provider.privacy_policy = privacy_policy provider.custom_disclaimer = custom_disclaimer - if 'auth_type' not in credentials: - raise ValueError('auth_type is required') + if "auth_type" not in credentials: + raise ValueError("auth_type is required") # get auth type, none or api key - auth_type = ApiProviderAuthType.value_of(credentials['auth_type']) + auth_type = ApiProviderAuthType.value_of(credentials["auth_type"]) # create provider entity provider_controller = ApiToolProviderController.from_db(provider, auth_type) @@ -295,84 +303,91 @@ class ApiToolManageService: # update labels ToolLabelManager.update_tool_labels(provider_controller, labels) - return { 'result': 'success' } - + return {"result": "success"} + @staticmethod - def delete_api_tool_provider( - user_id: str, tenant_id: str, provider_name: str - ): + def delete_api_tool_provider(user_id: str, tenant_id: str, provider_name: str): """ - delete tool provider + delete tool provider """ - provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( - ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.name == provider_name, - ).first() + provider: ApiToolProvider = ( + db.session.query(ApiToolProvider) + .filter( + ApiToolProvider.tenant_id == tenant_id, + ApiToolProvider.name == provider_name, + ) + .first() + ) if provider is None: - raise ValueError(f'you have not added provider {provider_name}') - + raise ValueError(f"you have not added provider {provider_name}") + db.session.delete(provider) db.session.commit() - return { 'result': 'success' } - + return {"result": "success"} + @staticmethod - def get_api_tool_provider( - user_id: str, tenant_id: str, provider: str - ): + def get_api_tool_provider(user_id: str, tenant_id: str, provider: str): """ - get api tool provider + get api tool provider """ return ToolManager.user_get_api_provider(provider=provider, tenant_id=tenant_id) - + @staticmethod def test_api_tool_preview( - tenant_id: str, + tenant_id: str, provider_name: str, - tool_name: str, - credentials: dict, - parameters: dict, - schema_type: str, - schema: str + tool_name: str, + credentials: dict, + parameters: dict, + schema_type: str, + schema: str, ): """ - test api tool before adding api tool provider + test api tool before adding api tool provider """ if schema_type not in [member.value for member in ApiProviderSchemaType]: - raise ValueError(f'invalid schema type {schema_type}') - + raise ValueError(f"invalid schema type {schema_type}") + try: tool_bundles, _ = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema) except Exception as e: - raise ValueError('invalid schema') - + raise ValueError("invalid schema") + # get tool bundle tool_bundle = next(filter(lambda tb: tb.operation_id == tool_name, tool_bundles), None) if tool_bundle is None: - raise ValueError(f'invalid tool name {tool_name}') - - db_provider: ApiToolProvider = db.session.query(ApiToolProvider).filter( - ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.name == provider_name, - ).first() + raise ValueError(f"invalid tool name {tool_name}") + + db_provider: ApiToolProvider = ( + db.session.query(ApiToolProvider) + .filter( + ApiToolProvider.tenant_id == tenant_id, + ApiToolProvider.name == provider_name, + ) + .first() + ) if not db_provider: # create a fake db provider db_provider = ApiToolProvider( - tenant_id='', user_id='', name='', icon='', + tenant_id="", + user_id="", + name="", + icon="", schema=schema, - description='', + description="", schema_type_str=ApiProviderSchemaType.OPENAPI.value, tools_str=json.dumps(jsonable_encoder(tool_bundles)), credentials_str=json.dumps(credentials), ) - if 'auth_type' not in credentials: - raise ValueError('auth_type is required') + if "auth_type" not in credentials: + raise ValueError("auth_type is required") # get auth type, none or api key - auth_type = ApiProviderAuthType.value_of(credentials['auth_type']) + auth_type = ApiProviderAuthType.value_of(credentials["auth_type"]) # create provider entity provider_controller = ApiToolProviderController.from_db(db_provider, auth_type) @@ -381,10 +396,7 @@ class ApiToolManageService: # decrypt credentials if db_provider.id: - tool_configuration = ToolConfigurationManager( - tenant_id=tenant_id, - provider_controller=provider_controller - ) + tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials) # check if the credential has changed, save the original credential masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials) @@ -396,27 +408,27 @@ class ApiToolManageService: provider_controller.validate_credentials_format(credentials) # get tool tool = provider_controller.get_tool(tool_name) - tool = tool.fork_tool_runtime(runtime={ - 'credentials': credentials, - 'tenant_id': tenant_id, - }) + tool = tool.fork_tool_runtime( + runtime={ + "credentials": credentials, + "tenant_id": tenant_id, + } + ) result = tool.validate_credentials(credentials, parameters) except Exception as e: - return { 'error': str(e) } - - return { 'result': result or 'empty response' } - + return {"error": str(e)} + + return {"result": result or "empty response"} + @staticmethod - def list_api_tools( - user_id: str, tenant_id: str - ) -> list[UserToolProvider]: + def list_api_tools(user_id: str, tenant_id: str) -> list[UserToolProvider]: """ - list api tools + list api tools """ # get all api providers - db_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider).filter( - ApiToolProvider.tenant_id == tenant_id - ).all() or [] + db_providers: list[ApiToolProvider] = ( + db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all() or [] + ) result: list[UserToolProvider] = [] @@ -425,26 +437,21 @@ class ApiToolManageService: provider_controller = ToolTransformService.api_provider_to_controller(db_provider=provider) labels = ToolLabelManager.get_tool_labels(provider_controller) user_provider = ToolTransformService.api_provider_to_user_provider( - provider_controller, - db_provider=provider, - decrypt_credentials=True + provider_controller, db_provider=provider, decrypt_credentials=True ) user_provider.labels = labels # add icon ToolTransformService.repack_provider(user_provider) - tools = provider_controller.get_tools( - user_id=user_id, tenant_id=tenant_id - ) + tools = provider_controller.get_tools(user_id=user_id, tenant_id=tenant_id) for tool in tools: - user_provider.tools.append(ToolTransformService.tool_to_user_tool( - tenant_id=tenant_id, - tool=tool, - credentials=user_provider.original_credentials, - labels=labels - )) + user_provider.tools.append( + ToolTransformService.tool_to_user_tool( + tenant_id=tenant_id, tool=tool, credentials=user_provider.original_credentials, labels=labels + ) + ) result.append(user_provider) diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index ebadbd9be..dc8cebb58 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -20,21 +20,25 @@ logger = logging.getLogger(__name__) class BuiltinToolManageService: @staticmethod - def list_builtin_tool_provider_tools( - user_id: str, tenant_id: str, provider: str - ) -> list[UserTool]: + def list_builtin_tool_provider_tools(user_id: str, tenant_id: str, provider: str) -> list[UserTool]: """ - list builtin tool provider tools + list builtin tool provider tools """ provider_controller: ToolProviderController = ToolManager.get_builtin_provider(provider) tools = provider_controller.get_tools() - tool_provider_configurations = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) + tool_provider_configurations = ToolConfigurationManager( + tenant_id=tenant_id, provider_controller=provider_controller + ) # check if user has added the provider - builtin_provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter( - BuiltinToolProvider.tenant_id == tenant_id, - BuiltinToolProvider.provider == provider, - ).first() + builtin_provider: BuiltinToolProvider = ( + db.session.query(BuiltinToolProvider) + .filter( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.provider == provider, + ) + .first() + ) credentials = {} if builtin_provider is not None: @@ -44,47 +48,47 @@ class BuiltinToolManageService: result = [] for tool in tools: - result.append(ToolTransformService.tool_to_user_tool( - tool=tool, - credentials=credentials, - tenant_id=tenant_id, - labels=ToolLabelManager.get_tool_labels(provider_controller) - )) + result.append( + ToolTransformService.tool_to_user_tool( + tool=tool, + credentials=credentials, + tenant_id=tenant_id, + labels=ToolLabelManager.get_tool_labels(provider_controller), + ) + ) return result @staticmethod - def list_builtin_provider_credentials_schema( - provider_name - ): + def list_builtin_provider_credentials_schema(provider_name): """ - list builtin provider credentials schema + list builtin provider credentials schema - :return: the list of tool providers + :return: the list of tool providers """ provider = ToolManager.get_builtin_provider(provider_name) - return jsonable_encoder([ - v for _, v in (provider.credentials_schema or {}).items() - ]) + return jsonable_encoder([v for _, v in (provider.credentials_schema or {}).items()]) @staticmethod - def update_builtin_tool_provider( - user_id: str, tenant_id: str, provider_name: str, credentials: dict - ): + def update_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str, credentials: dict): """ - update builtin tool provider + update builtin tool provider """ # get if the provider exists - provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter( - BuiltinToolProvider.tenant_id == tenant_id, - BuiltinToolProvider.provider == provider_name, - ).first() + provider: BuiltinToolProvider = ( + db.session.query(BuiltinToolProvider) + .filter( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.provider == provider_name, + ) + .first() + ) try: # get provider provider_controller = ToolManager.get_builtin_provider(provider_name) if not provider_controller.need_credentials: - raise ValueError(f'provider {provider_name} does not need credentials') + raise ValueError(f"provider {provider_name} does not need credentials") tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) # get original credentials if exists if provider is not None: @@ -121,19 +125,21 @@ class BuiltinToolManageService: # delete cache tool_configuration.delete_tool_credentials_cache() - return {'result': 'success'} + return {"result": "success"} @staticmethod - def get_builtin_tool_provider_credentials( - user_id: str, tenant_id: str, provider: str - ): + def get_builtin_tool_provider_credentials(user_id: str, tenant_id: str, provider: str): """ - get builtin tool provider credentials + get builtin tool provider credentials """ - provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter( - BuiltinToolProvider.tenant_id == tenant_id, - BuiltinToolProvider.provider == provider, - ).first() + provider: BuiltinToolProvider = ( + db.session.query(BuiltinToolProvider) + .filter( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.provider == provider, + ) + .first() + ) if provider is None: return {} @@ -145,19 +151,21 @@ class BuiltinToolManageService: return credentials @staticmethod - def delete_builtin_tool_provider( - user_id: str, tenant_id: str, provider_name: str - ): + def delete_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str): """ - delete tool provider + delete tool provider """ - provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter( - BuiltinToolProvider.tenant_id == tenant_id, - BuiltinToolProvider.provider == provider_name, - ).first() + provider: BuiltinToolProvider = ( + db.session.query(BuiltinToolProvider) + .filter( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.provider == provider_name, + ) + .first() + ) if provider is None: - raise ValueError(f'you have not added provider {provider_name}') + raise ValueError(f"you have not added provider {provider_name}") db.session.delete(provider) db.session.commit() @@ -167,38 +175,36 @@ class BuiltinToolManageService: tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) tool_configuration.delete_tool_credentials_cache() - return {'result': 'success'} + return {"result": "success"} @staticmethod - def get_builtin_tool_provider_icon( - provider: str - ): + def get_builtin_tool_provider_icon(provider: str): """ - get tool provider icon and it's mimetype + get tool provider icon and it's mimetype """ icon_path, mime_type = ToolManager.get_builtin_provider_icon(provider) - with open(icon_path, 'rb') as f: + with open(icon_path, "rb") as f: icon_bytes = f.read() return icon_bytes, mime_type @staticmethod - def list_builtin_tools( - user_id: str, tenant_id: str - ) -> list[UserToolProvider]: + def list_builtin_tools(user_id: str, tenant_id: str) -> list[UserToolProvider]: """ - list builtin tools + list builtin tools """ # get all builtin providers provider_controllers = ToolManager.list_builtin_providers() # get all user added providers - db_providers: list[BuiltinToolProvider] = db.session.query(BuiltinToolProvider).filter( - BuiltinToolProvider.tenant_id == tenant_id - ).all() or [] + db_providers: list[BuiltinToolProvider] = ( + db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all() or [] + ) # find provider - find_provider = lambda provider: next(filter(lambda db_provider: db_provider.provider == provider, db_providers), None) + find_provider = lambda provider: next( + filter(lambda db_provider: db_provider.provider == provider, db_providers), None + ) result: list[UserToolProvider] = [] @@ -209,7 +215,7 @@ class BuiltinToolManageService: include_set=dify_config.POSITION_TOOL_INCLUDES_SET, exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, data=provider_controller, - name_func=lambda x: x.identity.name + name_func=lambda x: x.identity.name, ): continue @@ -217,7 +223,7 @@ class BuiltinToolManageService: user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider( provider_controller=provider_controller, db_provider=find_provider(provider_controller.identity.name), - decrypt_credentials=True + decrypt_credentials=True, ) # add icon @@ -225,12 +231,14 @@ class BuiltinToolManageService: tools = provider_controller.get_tools() for tool in tools: - user_builtin_provider.tools.append(ToolTransformService.tool_to_user_tool( - tenant_id=tenant_id, - tool=tool, - credentials=user_builtin_provider.original_credentials, - labels=ToolLabelManager.get_tool_labels(provider_controller) - )) + user_builtin_provider.tools.append( + ToolTransformService.tool_to_user_tool( + tenant_id=tenant_id, + tool=tool, + credentials=user_builtin_provider.original_credentials, + labels=ToolLabelManager.get_tool_labels(provider_controller), + ) + ) result.append(user_builtin_provider) except Exception as e: diff --git a/api/services/tools/tool_labels_service.py b/api/services/tools/tool_labels_service.py index 8a6aa025f..35e58b5ad 100644 --- a/api/services/tools/tool_labels_service.py +++ b/api/services/tools/tool_labels_service.py @@ -5,4 +5,4 @@ from core.tools.entities.values import default_tool_labels class ToolLabelsService: @classmethod def list_tool_labels(cls) -> list[ToolLabel]: - return default_tool_labels \ No newline at end of file + return default_tool_labels diff --git a/api/services/tools/tools_manage_service.py b/api/services/tools/tools_manage_service.py index 76d2f53ae..1c67f7648 100644 --- a/api/services/tools/tools_manage_service.py +++ b/api/services/tools/tools_manage_service.py @@ -11,13 +11,11 @@ class ToolCommonService: @staticmethod def list_tool_providers(user_id: str, tenant_id: str, typ: UserToolProviderTypeLiteral = None): """ - list tool providers + list tool providers - :return: the list of tool providers + :return: the list of tool providers """ - providers = ToolManager.user_list_providers( - user_id, tenant_id, typ - ) + providers = ToolManager.user_list_providers(user_id, tenant_id, typ) # add icon for provider in providers: @@ -26,4 +24,3 @@ class ToolCommonService: result = [provider.to_dict() for provider in providers] return result - \ No newline at end of file diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index cfce3fbd0..6fb0f2f51 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -22,46 +22,39 @@ from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvi logger = logging.getLogger(__name__) + class ToolTransformService: @staticmethod def get_tool_provider_icon_url(provider_type: str, provider_name: str, icon: str) -> Union[str, dict]: """ - get tool provider icon url + get tool provider icon url """ - url_prefix = (dify_config.CONSOLE_API_URL - + "/console/api/workspaces/current/tool-provider/") - + url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/" + if provider_type == ToolProviderType.BUILT_IN.value: - return url_prefix + 'builtin/' + provider_name + '/icon' + return url_prefix + "builtin/" + provider_name + "/icon" elif provider_type in [ToolProviderType.API.value, ToolProviderType.WORKFLOW.value]: try: return json.loads(icon) except: - return { - "background": "#252525", - "content": "\ud83d\ude01" - } - - return '' - + return {"background": "#252525", "content": "\ud83d\ude01"} + + return "" + @staticmethod def repack_provider(provider: Union[dict, UserToolProvider]): """ - repack provider + repack provider - :param provider: the provider dict + :param provider: the provider dict """ - if isinstance(provider, dict) and 'icon' in provider: - provider['icon'] = ToolTransformService.get_tool_provider_icon_url( - provider_type=provider['type'], - provider_name=provider['name'], - icon=provider['icon'] + if isinstance(provider, dict) and "icon" in provider: + provider["icon"] = ToolTransformService.get_tool_provider_icon_url( + provider_type=provider["type"], provider_name=provider["name"], icon=provider["icon"] ) elif isinstance(provider, UserToolProvider): provider.icon = ToolTransformService.get_tool_provider_icon_url( - provider_type=provider.type.value, - provider_name=provider.name, - icon=provider.icon + provider_type=provider.type.value, provider_name=provider.name, icon=provider.icon ) @staticmethod @@ -92,14 +85,13 @@ class ToolTransformService: masked_credentials={}, is_team_authorization=False, tools=[], - labels=provider_controller.tool_labels + labels=provider_controller.tool_labels, ) # get credentials schema schema = provider_controller.get_credentials_schema() for name, value in schema.items(): - result.masked_credentials[name] = \ - ToolProviderCredentials.CredentialsType.default(value.type) + result.masked_credentials[name] = ToolProviderCredentials.CredentialsType.default(value.type) # check if the provider need credentials if not provider_controller.need_credentials: @@ -113,8 +105,7 @@ class ToolTransformService: # init tool configuration tool_configuration = ToolConfigurationManager( - tenant_id=db_provider.tenant_id, - provider_controller=provider_controller + tenant_id=db_provider.tenant_id, provider_controller=provider_controller ) # decrypt the credentials and mask the credentials decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials=credentials) @@ -124,7 +115,7 @@ class ToolTransformService: result.original_credentials = decrypted_credentials return result - + @staticmethod def api_provider_to_controller( db_provider: ApiToolProvider, @@ -135,25 +126,23 @@ class ToolTransformService: # package tool provider controller controller = ApiToolProviderController.from_db( db_provider=db_provider, - auth_type=ApiProviderAuthType.API_KEY if db_provider.credentials['auth_type'] == 'api_key' else - ApiProviderAuthType.NONE + auth_type=ApiProviderAuthType.API_KEY + if db_provider.credentials["auth_type"] == "api_key" + else ApiProviderAuthType.NONE, ) return controller - + @staticmethod - def workflow_provider_to_controller( - db_provider: WorkflowToolProvider - ) -> WorkflowToolProviderController: + def workflow_provider_to_controller(db_provider: WorkflowToolProvider) -> WorkflowToolProviderController: """ convert provider controller to provider """ return WorkflowToolProviderController.from_db(db_provider) - + @staticmethod def workflow_provider_to_user_provider( - provider_controller: WorkflowToolProviderController, - labels: list[str] = None + provider_controller: WorkflowToolProviderController, labels: list[str] = None ): """ convert provider controller to user provider @@ -175,7 +164,7 @@ class ToolTransformService: masked_credentials={}, is_team_authorization=True, tools=[], - labels=labels or [] + labels=labels or [], ) @staticmethod @@ -183,16 +172,16 @@ class ToolTransformService: provider_controller: ApiToolProviderController, db_provider: ApiToolProvider, decrypt_credentials: bool = True, - labels: list[str] = None + labels: list[str] = None, ) -> UserToolProvider: """ convert provider controller to user provider """ - username = 'Anonymous' + username = "Anonymous" try: username = db_provider.user.name except Exception as e: - logger.error(f'failed to get user name for api provider {db_provider.id}: {str(e)}') + logger.error(f"failed to get user name for api provider {db_provider.id}: {str(e)}") # add provider into providers credentials = db_provider.credentials result = UserToolProvider( @@ -212,14 +201,13 @@ class ToolTransformService: masked_credentials={}, is_team_authorization=True, tools=[], - labels=labels or [] + labels=labels or [], ) if decrypt_credentials: # init tool configuration tool_configuration = ToolConfigurationManager( - tenant_id=db_provider.tenant_id, - provider_controller=provider_controller + tenant_id=db_provider.tenant_id, provider_controller=provider_controller ) # decrypt the credentials and mask the credentials @@ -229,23 +217,25 @@ class ToolTransformService: result.masked_credentials = masked_credentials return result - + @staticmethod def tool_to_user_tool( - tool: Union[ApiToolBundle, WorkflowTool, Tool], - credentials: dict = None, + tool: Union[ApiToolBundle, WorkflowTool, Tool], + credentials: dict = None, tenant_id: str = None, - labels: list[str] = None + labels: list[str] = None, ) -> UserTool: """ convert tool to user tool """ if isinstance(tool, Tool): # fork tool runtime - tool = tool.fork_tool_runtime(runtime={ - 'credentials': credentials, - 'tenant_id': tenant_id, - }) + tool = tool.fork_tool_runtime( + runtime={ + "credentials": credentials, + "tenant_id": tenant_id, + } + ) # get tool parameters parameters = tool.parameters or [] @@ -270,20 +260,14 @@ class ToolTransformService: label=tool.identity.label, description=tool.description.human, parameters=current_parameters, - labels=labels + labels=labels, ) if isinstance(tool, ApiToolBundle): return UserTool( author=tool.author, name=tool.operation_id, - label=I18nObject( - en_US=tool.operation_id, - zh_Hans=tool.operation_id - ), - description=I18nObject( - en_US=tool.summary or '', - zh_Hans=tool.summary or '' - ), + label=I18nObject(en_US=tool.operation_id, zh_Hans=tool.operation_id), + description=I18nObject(en_US=tool.summary or "", zh_Hans=tool.summary or ""), parameters=tool.parameters, - labels=labels - ) \ No newline at end of file + labels=labels, + ) diff --git a/api/services/tools/workflow_tools_manage_service.py b/api/services/tools/workflow_tools_manage_service.py index 185483a71..3830e7533 100644 --- a/api/services/tools/workflow_tools_manage_service.py +++ b/api/services/tools/workflow_tools_manage_service.py @@ -19,10 +19,21 @@ class WorkflowToolManageService: """ Service class for managing workflow tools. """ + @classmethod - def create_workflow_tool(cls, user_id: str, tenant_id: str, workflow_app_id: str, name: str, - label: str, icon: dict, description: str, - parameters: list[dict], privacy_policy: str = '', labels: list[str] = None) -> dict: + def create_workflow_tool( + cls, + user_id: str, + tenant_id: str, + workflow_app_id: str, + name: str, + label: str, + icon: dict, + description: str, + parameters: list[dict], + privacy_policy: str = "", + labels: list[str] = None, + ) -> dict: """ Create a workflow tool. :param user_id: the user id @@ -38,27 +49,28 @@ class WorkflowToolManageService: WorkflowToolConfigurationUtils.check_parameter_configurations(parameters) # check if the name is unique - existing_workflow_tool_provider = db.session.query(WorkflowToolProvider).filter( - WorkflowToolProvider.tenant_id == tenant_id, - # name or app_id - or_(WorkflowToolProvider.name == name, WorkflowToolProvider.app_id == workflow_app_id) - ).first() + existing_workflow_tool_provider = ( + db.session.query(WorkflowToolProvider) + .filter( + WorkflowToolProvider.tenant_id == tenant_id, + # name or app_id + or_(WorkflowToolProvider.name == name, WorkflowToolProvider.app_id == workflow_app_id), + ) + .first() + ) if existing_workflow_tool_provider is not None: - raise ValueError(f'Tool with name {name} or app_id {workflow_app_id} already exists') - - app: App = db.session.query(App).filter( - App.id == workflow_app_id, - App.tenant_id == tenant_id - ).first() + raise ValueError(f"Tool with name {name} or app_id {workflow_app_id} already exists") + + app: App = db.session.query(App).filter(App.id == workflow_app_id, App.tenant_id == tenant_id).first() if app is None: - raise ValueError(f'App {workflow_app_id} not found') - + raise ValueError(f"App {workflow_app_id} not found") + workflow: Workflow = app.workflow if workflow is None: - raise ValueError(f'Workflow not found for app {workflow_app_id}') - + raise ValueError(f"Workflow not found for app {workflow_app_id}") + workflow_tool_provider = WorkflowToolProvider( tenant_id=tenant_id, user_id=user_id, @@ -76,19 +88,26 @@ class WorkflowToolManageService: WorkflowToolProviderController.from_db(workflow_tool_provider) except Exception as e: raise ValueError(str(e)) - + db.session.add(workflow_tool_provider) db.session.commit() - return { - 'result': 'success' - } - + return {"result": "success"} @classmethod - def update_workflow_tool(cls, user_id: str, tenant_id: str, workflow_tool_id: str, - name: str, label: str, icon: dict, description: str, - parameters: list[dict], privacy_policy: str = '', labels: list[str] = None) -> dict: + def update_workflow_tool( + cls, + user_id: str, + tenant_id: str, + workflow_tool_id: str, + name: str, + label: str, + icon: dict, + description: str, + parameters: list[dict], + privacy_policy: str = "", + labels: list[str] = None, + ) -> dict: """ Update a workflow tool. :param user_id: the user id @@ -106,35 +125,39 @@ class WorkflowToolManageService: WorkflowToolConfigurationUtils.check_parameter_configurations(parameters) # check if the name is unique - existing_workflow_tool_provider = db.session.query(WorkflowToolProvider).filter( - WorkflowToolProvider.tenant_id == tenant_id, - WorkflowToolProvider.name == name, - WorkflowToolProvider.id != workflow_tool_id - ).first() + existing_workflow_tool_provider = ( + db.session.query(WorkflowToolProvider) + .filter( + WorkflowToolProvider.tenant_id == tenant_id, + WorkflowToolProvider.name == name, + WorkflowToolProvider.id != workflow_tool_id, + ) + .first() + ) if existing_workflow_tool_provider is not None: - raise ValueError(f'Tool with name {name} already exists') - - workflow_tool_provider: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter( - WorkflowToolProvider.tenant_id == tenant_id, - WorkflowToolProvider.id == workflow_tool_id - ).first() + raise ValueError(f"Tool with name {name} already exists") + + workflow_tool_provider: WorkflowToolProvider = ( + db.session.query(WorkflowToolProvider) + .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) + .first() + ) if workflow_tool_provider is None: - raise ValueError(f'Tool {workflow_tool_id} not found') - - app: App = db.session.query(App).filter( - App.id == workflow_tool_provider.app_id, - App.tenant_id == tenant_id - ).first() + raise ValueError(f"Tool {workflow_tool_id} not found") + + app: App = ( + db.session.query(App).filter(App.id == workflow_tool_provider.app_id, App.tenant_id == tenant_id).first() + ) if app is None: - raise ValueError(f'App {workflow_tool_provider.app_id} not found') - + raise ValueError(f"App {workflow_tool_provider.app_id} not found") + workflow: Workflow = app.workflow if workflow is None: - raise ValueError(f'Workflow not found for app {workflow_tool_provider.app_id}') - + raise ValueError(f"Workflow not found for app {workflow_tool_provider.app_id}") + workflow_tool_provider.name = name workflow_tool_provider.label = label workflow_tool_provider.icon = json.dumps(icon) @@ -154,13 +177,10 @@ class WorkflowToolManageService: if labels is not None: ToolLabelManager.update_tool_labels( - ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), - labels + ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels ) - return { - 'result': 'success' - } + return {"result": "success"} @classmethod def list_tenant_workflow_tools(cls, user_id: str, tenant_id: str) -> list[UserToolProvider]: @@ -170,9 +190,7 @@ class WorkflowToolManageService: :param tenant_id: the tenant id :return: the list of tools """ - db_tools = db.session.query(WorkflowToolProvider).filter( - WorkflowToolProvider.tenant_id == tenant_id - ).all() + db_tools = db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all() tools = [] for provider in db_tools: @@ -188,14 +206,12 @@ class WorkflowToolManageService: for tool in tools: user_tool_provider = ToolTransformService.workflow_provider_to_user_provider( - provider_controller=tool, - labels=labels.get(tool.provider_id, []) + provider_controller=tool, labels=labels.get(tool.provider_id, []) ) ToolTransformService.repack_provider(user_tool_provider) user_tool_provider.tools = [ ToolTransformService.tool_to_user_tool( - tool.get_tools(user_id, tenant_id)[0], - labels=labels.get(tool.provider_id, []) + tool.get_tools(user_id, tenant_id)[0], labels=labels.get(tool.provider_id, []) ) ] result.append(user_tool_provider) @@ -211,15 +227,12 @@ class WorkflowToolManageService: :param workflow_app_id: the workflow app id """ db.session.query(WorkflowToolProvider).filter( - WorkflowToolProvider.tenant_id == tenant_id, - WorkflowToolProvider.id == workflow_tool_id + WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id ).delete() db.session.commit() - return { - 'result': 'success' - } + return {"result": "success"} @classmethod def get_workflow_tool_by_tool_id(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> dict: @@ -230,40 +243,37 @@ class WorkflowToolManageService: :param workflow_app_id: the workflow app id :return: the tool """ - db_tool: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter( - WorkflowToolProvider.tenant_id == tenant_id, - WorkflowToolProvider.id == workflow_tool_id - ).first() + db_tool: WorkflowToolProvider = ( + db.session.query(WorkflowToolProvider) + .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) + .first() + ) if db_tool is None: - raise ValueError(f'Tool {workflow_tool_id} not found') - - workflow_app: App = db.session.query(App).filter( - App.id == db_tool.app_id, - App.tenant_id == tenant_id - ).first() + raise ValueError(f"Tool {workflow_tool_id} not found") + + workflow_app: App = db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == tenant_id).first() if workflow_app is None: - raise ValueError(f'App {db_tool.app_id} not found') + raise ValueError(f"App {db_tool.app_id} not found") tool = ToolTransformService.workflow_provider_to_controller(db_tool) return { - 'name': db_tool.name, - 'label': db_tool.label, - 'workflow_tool_id': db_tool.id, - 'workflow_app_id': db_tool.app_id, - 'icon': json.loads(db_tool.icon), - 'description': db_tool.description, - 'parameters': jsonable_encoder(db_tool.parameter_configurations), - 'tool': ToolTransformService.tool_to_user_tool( - tool.get_tools(user_id, tenant_id)[0], - labels=ToolLabelManager.get_tool_labels(tool) + "name": db_tool.name, + "label": db_tool.label, + "workflow_tool_id": db_tool.id, + "workflow_app_id": db_tool.app_id, + "icon": json.loads(db_tool.icon), + "description": db_tool.description, + "parameters": jsonable_encoder(db_tool.parameter_configurations), + "tool": ToolTransformService.tool_to_user_tool( + tool.get_tools(user_id, tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool) ), - 'synced': workflow_app.workflow.version == db_tool.version, - 'privacy_policy': db_tool.privacy_policy, + "synced": workflow_app.workflow.version == db_tool.version, + "privacy_policy": db_tool.privacy_policy, } - + @classmethod def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_id: str) -> dict: """ @@ -273,40 +283,37 @@ class WorkflowToolManageService: :param workflow_app_id: the workflow app id :return: the tool """ - db_tool: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter( - WorkflowToolProvider.tenant_id == tenant_id, - WorkflowToolProvider.app_id == workflow_app_id - ).first() + db_tool: WorkflowToolProvider = ( + db.session.query(WorkflowToolProvider) + .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.app_id == workflow_app_id) + .first() + ) if db_tool is None: - raise ValueError(f'Tool {workflow_app_id} not found') - - workflow_app: App = db.session.query(App).filter( - App.id == db_tool.app_id, - App.tenant_id == tenant_id - ).first() + raise ValueError(f"Tool {workflow_app_id} not found") + + workflow_app: App = db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == tenant_id).first() if workflow_app is None: - raise ValueError(f'App {db_tool.app_id} not found') + raise ValueError(f"App {db_tool.app_id} not found") tool = ToolTransformService.workflow_provider_to_controller(db_tool) return { - 'name': db_tool.name, - 'label': db_tool.label, - 'workflow_tool_id': db_tool.id, - 'workflow_app_id': db_tool.app_id, - 'icon': json.loads(db_tool.icon), - 'description': db_tool.description, - 'parameters': jsonable_encoder(db_tool.parameter_configurations), - 'tool': ToolTransformService.tool_to_user_tool( - tool.get_tools(user_id, tenant_id)[0], - labels=ToolLabelManager.get_tool_labels(tool) + "name": db_tool.name, + "label": db_tool.label, + "workflow_tool_id": db_tool.id, + "workflow_app_id": db_tool.app_id, + "icon": json.loads(db_tool.icon), + "description": db_tool.description, + "parameters": jsonable_encoder(db_tool.parameter_configurations), + "tool": ToolTransformService.tool_to_user_tool( + tool.get_tools(user_id, tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool) ), - 'synced': workflow_app.workflow.version == db_tool.version, - 'privacy_policy': db_tool.privacy_policy + "synced": workflow_app.workflow.version == db_tool.version, + "privacy_policy": db_tool.privacy_policy, } - + @classmethod def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> list[dict]: """ @@ -316,19 +323,19 @@ class WorkflowToolManageService: :param workflow_app_id: the workflow app id :return: the list of tools """ - db_tool: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter( - WorkflowToolProvider.tenant_id == tenant_id, - WorkflowToolProvider.id == workflow_tool_id - ).first() + db_tool: WorkflowToolProvider = ( + db.session.query(WorkflowToolProvider) + .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id) + .first() + ) if db_tool is None: - raise ValueError(f'Tool {workflow_tool_id} not found') + raise ValueError(f"Tool {workflow_tool_id} not found") tool = ToolTransformService.workflow_provider_to_controller(db_tool) return [ ToolTransformService.tool_to_user_tool( - tool.get_tools(user_id, tenant_id)[0], - labels=ToolLabelManager.get_tool_labels(tool) + tool.get_tools(user_id, tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool) ) - ] \ No newline at end of file + ] diff --git a/api/services/vector_service.py b/api/services/vector_service.py index 232d29432..3c6735133 100644 --- a/api/services/vector_service.py +++ b/api/services/vector_service.py @@ -7,10 +7,10 @@ from models.dataset import Dataset, DocumentSegment class VectorService: - @classmethod - def create_segments_vector(cls, keywords_list: Optional[list[list[str]]], - segments: list[DocumentSegment], dataset: Dataset): + def create_segments_vector( + cls, keywords_list: Optional[list[list[str]]], segments: list[DocumentSegment], dataset: Dataset + ): documents = [] for segment in segments: document = Document( @@ -20,14 +20,12 @@ class VectorService: "doc_hash": segment.index_node_hash, "document_id": segment.document_id, "dataset_id": segment.dataset_id, - } + }, ) documents.append(document) - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": # save vector index - vector = Vector( - dataset=dataset - ) + vector = Vector(dataset=dataset) vector.add_texts(documents, duplicate_check=True) # save keyword index @@ -50,13 +48,11 @@ class VectorService: "doc_hash": segment.index_node_hash, "document_id": segment.document_id, "dataset_id": segment.dataset_id, - } + }, ) - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": # update vector index - vector = Vector( - dataset=dataset - ) + vector = Vector(dataset=dataset) vector.delete_by_ids([segment.index_node_id]) vector.add_texts([document], duplicate_check=True) diff --git a/api/services/web_conversation_service.py b/api/services/web_conversation_service.py index 269a04813..d7ccc964c 100644 --- a/api/services/web_conversation_service.py +++ b/api/services/web_conversation_service.py @@ -11,18 +11,29 @@ from services.conversation_service import ConversationService class WebConversationService: @classmethod - def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account, EndUser]], - last_id: Optional[str], limit: int, invoke_from: InvokeFrom, - pinned: Optional[bool] = None, - sort_by='-updated_at') -> InfiniteScrollPagination: + def pagination_by_last_id( + cls, + app_model: App, + user: Optional[Union[Account, EndUser]], + last_id: Optional[str], + limit: int, + invoke_from: InvokeFrom, + pinned: Optional[bool] = None, + sort_by="-updated_at", + ) -> InfiniteScrollPagination: include_ids = None exclude_ids = None if pinned is not None: - pinned_conversations = db.session.query(PinnedConversation).filter( - PinnedConversation.app_id == app_model.id, - PinnedConversation.created_by_role == ('account' if isinstance(user, Account) else 'end_user'), - PinnedConversation.created_by == user.id - ).order_by(PinnedConversation.created_at.desc()).all() + pinned_conversations = ( + db.session.query(PinnedConversation) + .filter( + PinnedConversation.app_id == app_model.id, + PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"), + PinnedConversation.created_by == user.id, + ) + .order_by(PinnedConversation.created_at.desc()) + .all() + ) pinned_conversation_ids = [pc.conversation_id for pc in pinned_conversations] if pinned: include_ids = pinned_conversation_ids @@ -37,32 +48,34 @@ class WebConversationService: invoke_from=invoke_from, include_ids=include_ids, exclude_ids=exclude_ids, - sort_by=sort_by + sort_by=sort_by, ) @classmethod def pin(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]): - pinned_conversation = db.session.query(PinnedConversation).filter( - PinnedConversation.app_id == app_model.id, - PinnedConversation.conversation_id == conversation_id, - PinnedConversation.created_by_role == ('account' if isinstance(user, Account) else 'end_user'), - PinnedConversation.created_by == user.id - ).first() + pinned_conversation = ( + db.session.query(PinnedConversation) + .filter( + PinnedConversation.app_id == app_model.id, + PinnedConversation.conversation_id == conversation_id, + PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"), + PinnedConversation.created_by == user.id, + ) + .first() + ) if pinned_conversation: return conversation = ConversationService.get_conversation( - app_model=app_model, - conversation_id=conversation_id, - user=user + app_model=app_model, conversation_id=conversation_id, user=user ) pinned_conversation = PinnedConversation( app_id=app_model.id, conversation_id=conversation.id, - created_by_role='account' if isinstance(user, Account) else 'end_user', - created_by=user.id + created_by_role="account" if isinstance(user, Account) else "end_user", + created_by=user.id, ) db.session.add(pinned_conversation) @@ -70,12 +83,16 @@ class WebConversationService: @classmethod def unpin(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]): - pinned_conversation = db.session.query(PinnedConversation).filter( - PinnedConversation.app_id == app_model.id, - PinnedConversation.conversation_id == conversation_id, - PinnedConversation.created_by_role == ('account' if isinstance(user, Account) else 'end_user'), - PinnedConversation.created_by == user.id - ).first() + pinned_conversation = ( + db.session.query(PinnedConversation) + .filter( + PinnedConversation.app_id == app_model.id, + PinnedConversation.conversation_id == conversation_id, + PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"), + PinnedConversation.created_by == user.id, + ) + .first() + ) if not pinned_conversation: return diff --git a/api/services/website_service.py b/api/services/website_service.py index c166b0123..6dff35d63 100644 --- a/api/services/website_service.py +++ b/api/services/website_service.py @@ -11,161 +11,126 @@ from services.auth.api_key_auth_service import ApiKeyAuthService class WebsiteService: - @classmethod def document_create_args_validate(cls, args: dict): - if 'url' not in args or not args['url']: - raise ValueError('url is required') - if 'options' not in args or not args['options']: - raise ValueError('options is required') - if 'limit' not in args['options'] or not args['options']['limit']: - raise ValueError('limit is required') + if "url" not in args or not args["url"]: + raise ValueError("url is required") + if "options" not in args or not args["options"]: + raise ValueError("options is required") + if "limit" not in args["options"] or not args["options"]["limit"]: + raise ValueError("limit is required") @classmethod def crawl_url(cls, args: dict) -> dict: - provider = args.get('provider') - url = args.get('url') - options = args.get('options') - credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id, - 'website', - provider) - if provider == 'firecrawl': + provider = args.get("provider") + url = args.get("url") + options = args.get("options") + credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id, "website", provider) + if provider == "firecrawl": # decrypt api_key api_key = encrypter.decrypt_token( - tenant_id=current_user.current_tenant_id, - token=credentials.get('config').get('api_key') + tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key") ) - firecrawl_app = FirecrawlApp(api_key=api_key, - base_url=credentials.get('config').get('base_url', None)) - crawl_sub_pages = options.get('crawl_sub_pages', False) - only_main_content = options.get('only_main_content', False) + firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None)) + crawl_sub_pages = options.get("crawl_sub_pages", False) + only_main_content = options.get("only_main_content", False) if not crawl_sub_pages: params = { - 'crawlerOptions': { + "crawlerOptions": { "includes": [], "excludes": [], "generateImgAltText": True, "limit": 1, - 'returnOnlyUrls': False, - 'pageOptions': { - 'onlyMainContent': only_main_content, - "includeHtml": False - } + "returnOnlyUrls": False, + "pageOptions": {"onlyMainContent": only_main_content, "includeHtml": False}, } } else: - includes = options.get('includes').split(',') if options.get('includes') else [] - excludes = options.get('excludes').split(',') if options.get('excludes') else [] + includes = options.get("includes").split(",") if options.get("includes") else [] + excludes = options.get("excludes").split(",") if options.get("excludes") else [] params = { - 'crawlerOptions': { + "crawlerOptions": { "includes": includes if includes else [], "excludes": excludes if excludes else [], "generateImgAltText": True, - "limit": options.get('limit', 1), - 'returnOnlyUrls': False, - 'pageOptions': { - 'onlyMainContent': only_main_content, - "includeHtml": False - } + "limit": options.get("limit", 1), + "returnOnlyUrls": False, + "pageOptions": {"onlyMainContent": only_main_content, "includeHtml": False}, } } - if options.get('max_depth'): - params['crawlerOptions']['maxDepth'] = options.get('max_depth') + if options.get("max_depth"): + params["crawlerOptions"]["maxDepth"] = options.get("max_depth") job_id = firecrawl_app.crawl_url(url, params) - website_crawl_time_cache_key = f'website_crawl_{job_id}' + website_crawl_time_cache_key = f"website_crawl_{job_id}" time = str(datetime.datetime.now().timestamp()) redis_client.setex(website_crawl_time_cache_key, 3600, time) - return { - 'status': 'active', - 'job_id': job_id - } + return {"status": "active", "job_id": job_id} else: - raise ValueError('Invalid provider') + raise ValueError("Invalid provider") @classmethod def get_crawl_status(cls, job_id: str, provider: str) -> dict: - credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id, - 'website', - provider) - if provider == 'firecrawl': + credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id, "website", provider) + if provider == "firecrawl": # decrypt api_key api_key = encrypter.decrypt_token( - tenant_id=current_user.current_tenant_id, - token=credentials.get('config').get('api_key') + tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key") ) - firecrawl_app = FirecrawlApp(api_key=api_key, - base_url=credentials.get('config').get('base_url', None)) + firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None)) result = firecrawl_app.check_crawl_status(job_id) crawl_status_data = { - 'status': result.get('status', 'active'), - 'job_id': job_id, - 'total': result.get('total', 0), - 'current': result.get('current', 0), - 'data': result.get('data', []) + "status": result.get("status", "active"), + "job_id": job_id, + "total": result.get("total", 0), + "current": result.get("current", 0), + "data": result.get("data", []), } - if crawl_status_data['status'] == 'completed': - website_crawl_time_cache_key = f'website_crawl_{job_id}' + if crawl_status_data["status"] == "completed": + website_crawl_time_cache_key = f"website_crawl_{job_id}" start_time = redis_client.get(website_crawl_time_cache_key) if start_time: end_time = datetime.datetime.now().timestamp() time_consuming = abs(end_time - float(start_time)) - crawl_status_data['time_consuming'] = f"{time_consuming:.2f}" + crawl_status_data["time_consuming"] = f"{time_consuming:.2f}" redis_client.delete(website_crawl_time_cache_key) else: - raise ValueError('Invalid provider') + raise ValueError("Invalid provider") return crawl_status_data @classmethod def get_crawl_url_data(cls, job_id: str, provider: str, url: str, tenant_id: str) -> dict | None: - credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, - 'website', - provider) - if provider == 'firecrawl': - file_key = 'website_files/' + job_id + '.txt' + credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider) + if provider == "firecrawl": + file_key = "website_files/" + job_id + ".txt" if storage.exists(file_key): data = storage.load_once(file_key) if data: - data = json.loads(data.decode('utf-8')) + data = json.loads(data.decode("utf-8")) else: # decrypt api_key - api_key = encrypter.decrypt_token( - tenant_id=tenant_id, - token=credentials.get('config').get('api_key') - ) - firecrawl_app = FirecrawlApp(api_key=api_key, - base_url=credentials.get('config').get('base_url', None)) + api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key")) + firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None)) result = firecrawl_app.check_crawl_status(job_id) - if result.get('status') != 'completed': - raise ValueError('Crawl job is not completed') - data = result.get('data') + if result.get("status") != "completed": + raise ValueError("Crawl job is not completed") + data = result.get("data") if data: for item in data: - if item.get('source_url') == url: + if item.get("source_url") == url: return item return None else: - raise ValueError('Invalid provider') + raise ValueError("Invalid provider") @classmethod def get_scrape_url_data(cls, provider: str, url: str, tenant_id: str, only_main_content: bool) -> dict | None: - credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, - 'website', - provider) - if provider == 'firecrawl': + credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider) + if provider == "firecrawl": # decrypt api_key - api_key = encrypter.decrypt_token( - tenant_id=tenant_id, - token=credentials.get('config').get('api_key') - ) - firecrawl_app = FirecrawlApp(api_key=api_key, - base_url=credentials.get('config').get('base_url', None)) - params = { - 'pageOptions': { - 'onlyMainContent': only_main_content, - "includeHtml": False - } - } + api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key")) + firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None)) + params = {"pageOptions": {"onlyMainContent": only_main_content, "includeHtml": False}} result = firecrawl_app.scrape_url(url, params) return result else: - raise ValueError('Invalid provider') + raise ValueError("Invalid provider") diff --git a/api/services/workflow_app_service.py b/api/services/workflow_app_service.py index c4d3d2763..b4f0882a3 100644 --- a/api/services/workflow_app_service.py +++ b/api/services/workflow_app_service.py @@ -10,7 +10,6 @@ from models.workflow import WorkflowAppLog, WorkflowRun, WorkflowRunStatus class WorkflowAppService: - def get_paginate_workflow_app_logs(self, app_model: App, args: dict) -> Pagination: """ Get paginate workflow app logs @@ -18,20 +17,14 @@ class WorkflowAppService: :param args: request args :return: """ - query = ( - db.select(WorkflowAppLog) - .where( - WorkflowAppLog.tenant_id == app_model.tenant_id, - WorkflowAppLog.app_id == app_model.id - ) + query = db.select(WorkflowAppLog).where( + WorkflowAppLog.tenant_id == app_model.tenant_id, WorkflowAppLog.app_id == app_model.id ) - status = WorkflowRunStatus.value_of(args.get('status')) if args.get('status') else None - keyword = args['keyword'] + status = WorkflowRunStatus.value_of(args.get("status")) if args.get("status") else None + keyword = args["keyword"] if keyword or status: - query = query.join( - WorkflowRun, WorkflowRun.id == WorkflowAppLog.workflow_run_id - ) + query = query.join(WorkflowRun, WorkflowRun.id == WorkflowAppLog.workflow_run_id) if keyword: keyword_like_val = f"%{args['keyword'][:30]}%" @@ -39,7 +32,7 @@ class WorkflowAppService: WorkflowRun.inputs.ilike(keyword_like_val), WorkflowRun.outputs.ilike(keyword_like_val), # filter keyword by end user session id if created by end user role - and_(WorkflowRun.created_by_role == 'end_user', EndUser.session_id.ilike(keyword_like_val)) + and_(WorkflowRun.created_by_role == "end_user", EndUser.session_id.ilike(keyword_like_val)), ] # filter keyword by workflow run id @@ -49,23 +42,16 @@ class WorkflowAppService: query = query.outerjoin( EndUser, - and_(WorkflowRun.created_by == EndUser.id, WorkflowRun.created_by_role == CreatedByRole.END_USER.value) + and_(WorkflowRun.created_by == EndUser.id, WorkflowRun.created_by_role == CreatedByRole.END_USER.value), ).filter(or_(*keyword_conditions)) if status: # join with workflow_run and filter by status - query = query.filter( - WorkflowRun.status == status.value - ) + query = query.filter(WorkflowRun.status == status.value) query = query.order_by(WorkflowAppLog.created_at.desc()) - pagination = db.paginate( - query, - page=args['page'], - per_page=args['limit'], - error_out=False - ) + pagination = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False) return pagination diff --git a/api/services/workflow_run_service.py b/api/services/workflow_run_service.py index ccce38ada..b7b3abeaa 100644 --- a/api/services/workflow_run_service.py +++ b/api/services/workflow_run_service.py @@ -18,6 +18,7 @@ class WorkflowRunService: :param app_model: app model :param args: request args """ + class WorkflowWithMessage: message_id: str conversation_id: str @@ -33,9 +34,7 @@ class WorkflowRunService: with_message_workflow_runs = [] for workflow_run in pagination.data: message = workflow_run.message - with_message_workflow_run = WorkflowWithMessage( - workflow_run=workflow_run - ) + with_message_workflow_run = WorkflowWithMessage(workflow_run=workflow_run) if message: with_message_workflow_run.message_id = message.id with_message_workflow_run.conversation_id = message.conversation_id @@ -53,26 +52,30 @@ class WorkflowRunService: :param app_model: app model :param args: request args """ - limit = int(args.get('limit', 20)) + limit = int(args.get("limit", 20)) base_query = db.session.query(WorkflowRun).filter( WorkflowRun.tenant_id == app_model.tenant_id, WorkflowRun.app_id == app_model.id, - WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.DEBUGGING.value + WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.DEBUGGING.value, ) - if args.get('last_id'): + if args.get("last_id"): last_workflow_run = base_query.filter( - WorkflowRun.id == args.get('last_id'), + WorkflowRun.id == args.get("last_id"), ).first() if not last_workflow_run: - raise ValueError('Last workflow run not exists') + raise ValueError("Last workflow run not exists") - workflow_runs = base_query.filter( - WorkflowRun.created_at < last_workflow_run.created_at, - WorkflowRun.id != last_workflow_run.id - ).order_by(WorkflowRun.created_at.desc()).limit(limit).all() + workflow_runs = ( + base_query.filter( + WorkflowRun.created_at < last_workflow_run.created_at, WorkflowRun.id != last_workflow_run.id + ) + .order_by(WorkflowRun.created_at.desc()) + .limit(limit) + .all() + ) else: workflow_runs = base_query.order_by(WorkflowRun.created_at.desc()).limit(limit).all() @@ -81,17 +84,13 @@ class WorkflowRunService: current_page_first_workflow_run = workflow_runs[-1] rest_count = base_query.filter( WorkflowRun.created_at < current_page_first_workflow_run.created_at, - WorkflowRun.id != current_page_first_workflow_run.id + WorkflowRun.id != current_page_first_workflow_run.id, ).count() if rest_count > 0: has_more = True - return InfiniteScrollPagination( - data=workflow_runs, - limit=limit, - has_more=has_more - ) + return InfiniteScrollPagination(data=workflow_runs, limit=limit, has_more=has_more) def get_workflow_run(self, app_model: App, run_id: str) -> WorkflowRun: """ @@ -100,11 +99,15 @@ class WorkflowRunService: :param app_model: app model :param run_id: workflow run id """ - workflow_run = db.session.query(WorkflowRun).filter( - WorkflowRun.tenant_id == app_model.tenant_id, - WorkflowRun.app_id == app_model.id, - WorkflowRun.id == run_id, - ).first() + workflow_run = ( + db.session.query(WorkflowRun) + .filter( + WorkflowRun.tenant_id == app_model.tenant_id, + WorkflowRun.app_id == app_model.id, + WorkflowRun.id == run_id, + ) + .first() + ) return workflow_run @@ -117,12 +120,17 @@ class WorkflowRunService: if not workflow_run: return [] - node_executions = db.session.query(WorkflowNodeExecution).filter( - WorkflowNodeExecution.tenant_id == app_model.tenant_id, - WorkflowNodeExecution.app_id == app_model.id, - WorkflowNodeExecution.workflow_id == workflow_run.workflow_id, - WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, - WorkflowNodeExecution.workflow_run_id == run_id, - ).order_by(WorkflowNodeExecution.index.desc()).all() + node_executions = ( + db.session.query(WorkflowNodeExecution) + .filter( + WorkflowNodeExecution.tenant_id == app_model.tenant_id, + WorkflowNodeExecution.app_id == app_model.id, + WorkflowNodeExecution.workflow_id == workflow_run.workflow_id, + WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, + WorkflowNodeExecution.workflow_run_id == run_id, + ) + .order_by(WorkflowNodeExecution.index.desc()) + .all() + ) return node_executions diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index c593b66f3..4c3ded14a 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -37,11 +37,13 @@ class WorkflowService: Get draft workflow """ # fetch draft workflow by app_model - workflow = db.session.query(Workflow).filter( - Workflow.tenant_id == app_model.tenant_id, - Workflow.app_id == app_model.id, - Workflow.version == 'draft' - ).first() + workflow = ( + db.session.query(Workflow) + .filter( + Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.version == "draft" + ) + .first() + ) # return draft workflow return workflow @@ -55,11 +57,15 @@ class WorkflowService: return None # fetch published workflow by workflow_id - workflow = db.session.query(Workflow).filter( - Workflow.tenant_id == app_model.tenant_id, - Workflow.app_id == app_model.id, - Workflow.id == app_model.workflow_id - ).first() + workflow = ( + db.session.query(Workflow) + .filter( + Workflow.tenant_id == app_model.tenant_id, + Workflow.app_id == app_model.id, + Workflow.id == app_model.workflow_id, + ) + .first() + ) return workflow @@ -85,10 +91,7 @@ class WorkflowService: raise WorkflowHashNotEqualError() # validate features structure - self.validate_features_structure( - app_model=app_model, - features=features - ) + self.validate_features_structure(app_model=app_model, features=features) # create draft workflow if not found if not workflow: @@ -96,7 +99,7 @@ class WorkflowService: tenant_id=app_model.tenant_id, app_id=app_model.id, type=WorkflowType.from_app_mode(app_model.mode).value, - version='draft', + version="draft", graph=json.dumps(graph), features=json.dumps(features), created_by=account.id, @@ -122,9 +125,7 @@ class WorkflowService: # return draft workflow return workflow - def publish_workflow(self, app_model: App, - account: Account, - draft_workflow: Optional[Workflow] = None) -> Workflow: + def publish_workflow(self, app_model: App, account: Account, draft_workflow: Optional[Workflow] = None) -> Workflow: """ Publish workflow from draft @@ -137,7 +138,7 @@ class WorkflowService: draft_workflow = self.get_draft_workflow(app_model=app_model) if not draft_workflow: - raise ValueError('No valid workflow found.') + raise ValueError("No valid workflow found.") # create new workflow workflow = Workflow( @@ -187,17 +188,16 @@ class WorkflowService: workflow_engine_manager = WorkflowEngineManager() return workflow_engine_manager.get_default_config(node_type, filters) - def run_draft_workflow_node(self, app_model: App, - node_id: str, - user_inputs: dict, - account: Account) -> WorkflowNodeExecution: + def run_draft_workflow_node( + self, app_model: App, node_id: str, user_inputs: dict, account: Account + ) -> WorkflowNodeExecution: """ Run draft workflow node """ # fetch draft workflow by app_model draft_workflow = self.get_draft_workflow(app_model=app_model) if not draft_workflow: - raise ValueError('Workflow not initialized') + raise ValueError("Workflow not initialized") # run draft workflow node workflow_engine_manager = WorkflowEngineManager() @@ -226,7 +226,7 @@ class WorkflowService: created_by_role=CreatedByRole.ACCOUNT.value, created_by=account.id, created_at=datetime.now(timezone.utc).replace(tzinfo=None), - finished_at=datetime.now(timezone.utc).replace(tzinfo=None) + finished_at=datetime.now(timezone.utc).replace(tzinfo=None), ) db.session.add(workflow_node_execution) db.session.commit() @@ -247,14 +247,15 @@ class WorkflowService: inputs=json.dumps(node_run_result.inputs) if node_run_result.inputs else None, process_data=json.dumps(node_run_result.process_data) if node_run_result.process_data else None, outputs=json.dumps(jsonable_encoder(node_run_result.outputs)) if node_run_result.outputs else None, - execution_metadata=(json.dumps(jsonable_encoder(node_run_result.metadata)) - if node_run_result.metadata else None), + execution_metadata=( + json.dumps(jsonable_encoder(node_run_result.metadata)) if node_run_result.metadata else None + ), status=WorkflowNodeExecutionStatus.SUCCEEDED.value, elapsed_time=time.perf_counter() - start_at, created_by_role=CreatedByRole.ACCOUNT.value, created_by=account.id, created_at=datetime.now(timezone.utc).replace(tzinfo=None), - finished_at=datetime.now(timezone.utc).replace(tzinfo=None) + finished_at=datetime.now(timezone.utc).replace(tzinfo=None), ) else: # create workflow node execution @@ -273,7 +274,7 @@ class WorkflowService: created_by_role=CreatedByRole.ACCOUNT.value, created_by=account.id, created_at=datetime.now(timezone.utc).replace(tzinfo=None), - finished_at=datetime.now(timezone.utc).replace(tzinfo=None) + finished_at=datetime.now(timezone.utc).replace(tzinfo=None), ) db.session.add(workflow_node_execution) @@ -295,16 +296,16 @@ class WorkflowService: workflow_converter = WorkflowConverter() if app_model.mode not in [AppMode.CHAT.value, AppMode.COMPLETION.value]: - raise ValueError(f'Current App mode: {app_model.mode} is not supported convert to workflow.') + raise ValueError(f"Current App mode: {app_model.mode} is not supported convert to workflow.") # convert to workflow new_app = workflow_converter.convert_to_workflow( app_model=app_model, account=account, - name=args.get('name'), - icon_type=args.get('icon_type'), - icon=args.get('icon'), - icon_background=args.get('icon_background'), + name=args.get("name"), + icon_type=args.get("icon_type"), + icon=args.get("icon"), + icon_background=args.get("icon_background"), ) return new_app @@ -312,15 +313,11 @@ class WorkflowService: def validate_features_structure(self, app_model: App, features: dict) -> dict: if app_model.mode == AppMode.ADVANCED_CHAT.value: return AdvancedChatAppConfigManager.config_validate( - tenant_id=app_model.tenant_id, - config=features, - only_structure_validate=True + tenant_id=app_model.tenant_id, config=features, only_structure_validate=True ) elif app_model.mode == AppMode.WORKFLOW.value: return WorkflowAppConfigManager.config_validate( - tenant_id=app_model.tenant_id, - config=features, - only_structure_validate=True + tenant_id=app_model.tenant_id, config=features, only_structure_validate=True ) else: raise ValueError(f"Invalid app mode: {app_model.mode}") diff --git a/api/services/workspace_service.py b/api/services/workspace_service.py index 2bcbe5c6f..8fcb12b1c 100644 --- a/api/services/workspace_service.py +++ b/api/services/workspace_service.py @@ -1,4 +1,3 @@ - from flask_login import current_user from configs import dify_config @@ -14,34 +13,40 @@ class WorkspaceService: if not tenant: return None tenant_info = { - 'id': tenant.id, - 'name': tenant.name, - 'plan': tenant.plan, - 'status': tenant.status, - 'created_at': tenant.created_at, - 'in_trail': True, - 'trial_end_reason': None, - 'role': 'normal', + "id": tenant.id, + "name": tenant.name, + "plan": tenant.plan, + "status": tenant.status, + "created_at": tenant.created_at, + "in_trail": True, + "trial_end_reason": None, + "role": "normal", } # Get role of user - tenant_account_join = db.session.query(TenantAccountJoin).filter( - TenantAccountJoin.tenant_id == tenant.id, - TenantAccountJoin.account_id == current_user.id - ).first() - tenant_info['role'] = tenant_account_join.role + tenant_account_join = ( + db.session.query(TenantAccountJoin) + .filter(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == current_user.id) + .first() + ) + tenant_info["role"] = tenant_account_join.role - can_replace_logo = FeatureService.get_features(tenant_info['id']).can_replace_logo + can_replace_logo = FeatureService.get_features(tenant_info["id"]).can_replace_logo - if can_replace_logo and TenantService.has_roles(tenant, - [TenantAccountJoinRole.OWNER, TenantAccountJoinRole.ADMIN]): + if can_replace_logo and TenantService.has_roles( + tenant, [TenantAccountJoinRole.OWNER, TenantAccountJoinRole.ADMIN] + ): base_url = dify_config.FILES_URL - replace_webapp_logo = f'{base_url}/files/workspaces/{tenant.id}/webapp-logo' if tenant.custom_config_dict.get('replace_webapp_logo') else None - remove_webapp_brand = tenant.custom_config_dict.get('remove_webapp_brand', False) + replace_webapp_logo = ( + f"{base_url}/files/workspaces/{tenant.id}/webapp-logo" + if tenant.custom_config_dict.get("replace_webapp_logo") + else None + ) + remove_webapp_brand = tenant.custom_config_dict.get("remove_webapp_brand", False) - tenant_info['custom_config'] = { - 'remove_webapp_brand': remove_webapp_brand, - 'replace_webapp_logo': replace_webapp_logo, + tenant_info["custom_config"] = { + "remove_webapp_brand": remove_webapp_brand, + "replace_webapp_logo": replace_webapp_logo, } return tenant_info