mirror of
https://gitee.com/dify_ai/dify.git
synced 2024-11-30 02:08:37 +08:00
chore(api/services): apply ruff reformatting (#7599)
Co-authored-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
parent
979422cdc6
commit
17fd773a30
@ -74,7 +74,6 @@ exclude = [
|
||||
"controllers/**/*.py",
|
||||
"models/**/*.py",
|
||||
"migrations/**/*",
|
||||
"services/**/*.py",
|
||||
]
|
||||
|
||||
[tool.pytest_env]
|
||||
|
@ -1,3 +1,3 @@
|
||||
from . import errors
|
||||
|
||||
__all__ = ['errors']
|
||||
__all__ = ["errors"]
|
||||
|
@ -39,12 +39,7 @@ from tasks.mail_reset_password_task import send_reset_password_mail_task
|
||||
|
||||
|
||||
class AccountService:
|
||||
|
||||
reset_password_rate_limiter = RateLimiter(
|
||||
prefix="reset_password_rate_limit",
|
||||
max_attempts=5,
|
||||
time_window=60 * 60
|
||||
)
|
||||
reset_password_rate_limiter = RateLimiter(prefix="reset_password_rate_limit", max_attempts=5, time_window=60 * 60)
|
||||
|
||||
@staticmethod
|
||||
def load_user(user_id: str) -> None | Account:
|
||||
@ -55,12 +50,15 @@ class AccountService:
|
||||
if account.status in [AccountStatus.BANNED.value, AccountStatus.CLOSED.value]:
|
||||
raise Unauthorized("Account is banned or closed.")
|
||||
|
||||
current_tenant: TenantAccountJoin = TenantAccountJoin.query.filter_by(account_id=account.id, current=True).first()
|
||||
current_tenant: TenantAccountJoin = TenantAccountJoin.query.filter_by(
|
||||
account_id=account.id, current=True
|
||||
).first()
|
||||
if current_tenant:
|
||||
account.current_tenant_id = current_tenant.tenant_id
|
||||
else:
|
||||
available_ta = TenantAccountJoin.query.filter_by(account_id=account.id) \
|
||||
.order_by(TenantAccountJoin.id.asc()).first()
|
||||
available_ta = (
|
||||
TenantAccountJoin.query.filter_by(account_id=account.id).order_by(TenantAccountJoin.id.asc()).first()
|
||||
)
|
||||
if not available_ta:
|
||||
return None
|
||||
|
||||
@ -74,14 +72,13 @@ class AccountService:
|
||||
|
||||
return account
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_account_jwt_token(account, *, exp: timedelta = timedelta(days=30)):
|
||||
payload = {
|
||||
"user_id": account.id,
|
||||
"exp": datetime.now(timezone.utc).replace(tzinfo=None) + exp,
|
||||
"iss": dify_config.EDITION,
|
||||
"sub": 'Console API Passport',
|
||||
"sub": "Console API Passport",
|
||||
}
|
||||
|
||||
token = PassportService().issue(payload)
|
||||
@ -93,10 +90,10 @@ class AccountService:
|
||||
|
||||
account = Account.query.filter_by(email=email).first()
|
||||
if not account:
|
||||
raise AccountLoginError('Invalid email or password.')
|
||||
raise AccountLoginError("Invalid email or password.")
|
||||
|
||||
if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value:
|
||||
raise AccountLoginError('Account is banned or closed.')
|
||||
raise AccountLoginError("Account is banned or closed.")
|
||||
|
||||
if account.status == AccountStatus.PENDING.value:
|
||||
account.status = AccountStatus.ACTIVE.value
|
||||
@ -104,7 +101,7 @@ class AccountService:
|
||||
db.session.commit()
|
||||
|
||||
if account.password is None or not compare_password(password, account.password, account.password_salt):
|
||||
raise AccountLoginError('Invalid email or password.')
|
||||
raise AccountLoginError("Invalid email or password.")
|
||||
return account
|
||||
|
||||
@staticmethod
|
||||
@ -129,11 +126,9 @@ class AccountService:
|
||||
return account
|
||||
|
||||
@staticmethod
|
||||
def create_account(email: str,
|
||||
name: str,
|
||||
interface_language: str,
|
||||
password: Optional[str] = None,
|
||||
interface_theme: str = 'light') -> Account:
|
||||
def create_account(
|
||||
email: str, name: str, interface_language: str, password: Optional[str] = None, interface_theme: str = "light"
|
||||
) -> Account:
|
||||
"""create account"""
|
||||
account = Account()
|
||||
account.email = email
|
||||
@ -155,7 +150,7 @@ class AccountService:
|
||||
account.interface_theme = interface_theme
|
||||
|
||||
# Set timezone based on language
|
||||
account.timezone = language_timezone_mapping.get(interface_language, 'UTC')
|
||||
account.timezone = language_timezone_mapping.get(interface_language, "UTC")
|
||||
|
||||
db.session.add(account)
|
||||
db.session.commit()
|
||||
@ -166,8 +161,9 @@ class AccountService:
|
||||
"""Link account integrate"""
|
||||
try:
|
||||
# Query whether there is an existing binding record for the same provider
|
||||
account_integrate: Optional[AccountIntegrate] = AccountIntegrate.query.filter_by(account_id=account.id,
|
||||
provider=provider).first()
|
||||
account_integrate: Optional[AccountIntegrate] = AccountIntegrate.query.filter_by(
|
||||
account_id=account.id, provider=provider
|
||||
).first()
|
||||
|
||||
if account_integrate:
|
||||
# If it exists, update the record
|
||||
@ -176,15 +172,16 @@ class AccountService:
|
||||
account_integrate.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
else:
|
||||
# If it does not exist, create a new record
|
||||
account_integrate = AccountIntegrate(account_id=account.id, provider=provider, open_id=open_id,
|
||||
encrypted_token="")
|
||||
account_integrate = AccountIntegrate(
|
||||
account_id=account.id, provider=provider, open_id=open_id, encrypted_token=""
|
||||
)
|
||||
db.session.add(account_integrate)
|
||||
|
||||
db.session.commit()
|
||||
logging.info(f'Account {account.id} linked {provider} account {open_id}.')
|
||||
logging.info(f"Account {account.id} linked {provider} account {open_id}.")
|
||||
except Exception as e:
|
||||
logging.exception(f'Failed to link {provider} account {open_id} to Account {account.id}')
|
||||
raise LinkAccountIntegrateError('Failed to link account.') from e
|
||||
logging.exception(f"Failed to link {provider} account {open_id} to Account {account.id}")
|
||||
raise LinkAccountIntegrateError("Failed to link account.") from e
|
||||
|
||||
@staticmethod
|
||||
def close_account(account: Account) -> None:
|
||||
@ -218,7 +215,7 @@ class AccountService:
|
||||
AccountService.update_last_login(account, ip_address=ip_address)
|
||||
exp = timedelta(days=30)
|
||||
token = AccountService.get_account_jwt_token(account, exp=exp)
|
||||
redis_client.set(_get_login_cache_key(account_id=account.id, token=token), '1', ex=int(exp.total_seconds()))
|
||||
redis_client.set(_get_login_cache_key(account_id=account.id, token=token), "1", ex=int(exp.total_seconds()))
|
||||
return token
|
||||
|
||||
@staticmethod
|
||||
@ -236,22 +233,18 @@ class AccountService:
|
||||
if cls.reset_password_rate_limiter.is_rate_limited(account.email):
|
||||
raise RateLimitExceededError(f"Rate limit exceeded for email: {account.email}. Please try again later.")
|
||||
|
||||
token = TokenManager.generate_token(account, 'reset_password')
|
||||
send_reset_password_mail_task.delay(
|
||||
language=account.interface_language,
|
||||
to=account.email,
|
||||
token=token
|
||||
)
|
||||
token = TokenManager.generate_token(account, "reset_password")
|
||||
send_reset_password_mail_task.delay(language=account.interface_language, to=account.email, token=token)
|
||||
cls.reset_password_rate_limiter.increment_rate_limit(account.email)
|
||||
return token
|
||||
|
||||
@classmethod
|
||||
def revoke_reset_password_token(cls, token: str):
|
||||
TokenManager.revoke_token(token, 'reset_password')
|
||||
TokenManager.revoke_token(token, "reset_password")
|
||||
|
||||
@classmethod
|
||||
def get_reset_password_data(cls, token: str) -> Optional[dict[str, Any]]:
|
||||
return TokenManager.get_token_data(token, 'reset_password')
|
||||
return TokenManager.get_token_data(token, "reset_password")
|
||||
|
||||
|
||||
def _get_login_cache_key(*, account_id: str, token: str):
|
||||
@ -259,7 +252,6 @@ def _get_login_cache_key(*, account_id: str, token: str):
|
||||
|
||||
|
||||
class TenantService:
|
||||
|
||||
@staticmethod
|
||||
def create_tenant(name: str) -> Tenant:
|
||||
"""Create tenant"""
|
||||
@ -275,31 +267,28 @@ class TenantService:
|
||||
@staticmethod
|
||||
def create_owner_tenant_if_not_exist(account: Account):
|
||||
"""Create owner tenant if not exist"""
|
||||
available_ta = TenantAccountJoin.query.filter_by(account_id=account.id) \
|
||||
.order_by(TenantAccountJoin.id.asc()).first()
|
||||
available_ta = (
|
||||
TenantAccountJoin.query.filter_by(account_id=account.id).order_by(TenantAccountJoin.id.asc()).first()
|
||||
)
|
||||
|
||||
if available_ta:
|
||||
return
|
||||
|
||||
tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
|
||||
TenantService.create_tenant_member(tenant, account, role='owner')
|
||||
TenantService.create_tenant_member(tenant, account, role="owner")
|
||||
account.current_tenant = tenant
|
||||
db.session.commit()
|
||||
tenant_was_created.send(tenant)
|
||||
|
||||
@staticmethod
|
||||
def create_tenant_member(tenant: Tenant, account: Account, role: str = 'normal') -> TenantAccountJoin:
|
||||
def create_tenant_member(tenant: Tenant, account: Account, role: str = "normal") -> TenantAccountJoin:
|
||||
"""Create tenant member"""
|
||||
if role == TenantAccountJoinRole.OWNER.value:
|
||||
if TenantService.has_roles(tenant, [TenantAccountJoinRole.OWNER]):
|
||||
logging.error(f'Tenant {tenant.id} has already an owner.')
|
||||
raise Exception('Tenant already has an owner.')
|
||||
logging.error(f"Tenant {tenant.id} has already an owner.")
|
||||
raise Exception("Tenant already has an owner.")
|
||||
|
||||
ta = TenantAccountJoin(
|
||||
tenant_id=tenant.id,
|
||||
account_id=account.id,
|
||||
role=role
|
||||
)
|
||||
ta = TenantAccountJoin(tenant_id=tenant.id, account_id=account.id, role=role)
|
||||
db.session.add(ta)
|
||||
db.session.commit()
|
||||
return ta
|
||||
@ -307,9 +296,12 @@ class TenantService:
|
||||
@staticmethod
|
||||
def get_join_tenants(account: Account) -> list[Tenant]:
|
||||
"""Get account join tenants"""
|
||||
return db.session.query(Tenant).join(
|
||||
TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id
|
||||
).filter(TenantAccountJoin.account_id == account.id, Tenant.status == TenantStatus.NORMAL).all()
|
||||
return (
|
||||
db.session.query(Tenant)
|
||||
.join(TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id)
|
||||
.filter(TenantAccountJoin.account_id == account.id, Tenant.status == TenantStatus.NORMAL)
|
||||
.all()
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_current_tenant_by_account(account: Account):
|
||||
@ -333,16 +325,23 @@ class TenantService:
|
||||
if tenant_id is None:
|
||||
raise ValueError("Tenant ID must be provided.")
|
||||
|
||||
tenant_account_join = db.session.query(TenantAccountJoin).join(Tenant, TenantAccountJoin.tenant_id == Tenant.id).filter(
|
||||
TenantAccountJoin.account_id == account.id,
|
||||
TenantAccountJoin.tenant_id == tenant_id,
|
||||
Tenant.status == TenantStatus.NORMAL,
|
||||
).first()
|
||||
tenant_account_join = (
|
||||
db.session.query(TenantAccountJoin)
|
||||
.join(Tenant, TenantAccountJoin.tenant_id == Tenant.id)
|
||||
.filter(
|
||||
TenantAccountJoin.account_id == account.id,
|
||||
TenantAccountJoin.tenant_id == tenant_id,
|
||||
Tenant.status == TenantStatus.NORMAL,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not tenant_account_join:
|
||||
raise AccountNotLinkTenantError("Tenant not found or account is not a member of the tenant.")
|
||||
else:
|
||||
TenantAccountJoin.query.filter(TenantAccountJoin.account_id == account.id, TenantAccountJoin.tenant_id != tenant_id).update({'current': False})
|
||||
TenantAccountJoin.query.filter(
|
||||
TenantAccountJoin.account_id == account.id, TenantAccountJoin.tenant_id != tenant_id
|
||||
).update({"current": False})
|
||||
tenant_account_join.current = True
|
||||
# Set the current tenant for the account
|
||||
account.current_tenant_id = tenant_account_join.tenant_id
|
||||
@ -354,9 +353,7 @@ class TenantService:
|
||||
query = (
|
||||
db.session.query(Account, TenantAccountJoin.role)
|
||||
.select_from(Account)
|
||||
.join(
|
||||
TenantAccountJoin, Account.id == TenantAccountJoin.account_id
|
||||
)
|
||||
.join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id)
|
||||
.filter(TenantAccountJoin.tenant_id == tenant.id)
|
||||
)
|
||||
|
||||
@ -375,11 +372,9 @@ class TenantService:
|
||||
query = (
|
||||
db.session.query(Account, TenantAccountJoin.role)
|
||||
.select_from(Account)
|
||||
.join(
|
||||
TenantAccountJoin, Account.id == TenantAccountJoin.account_id
|
||||
)
|
||||
.join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id)
|
||||
.filter(TenantAccountJoin.tenant_id == tenant.id)
|
||||
.filter(TenantAccountJoin.role == 'dataset_operator')
|
||||
.filter(TenantAccountJoin.role == "dataset_operator")
|
||||
)
|
||||
|
||||
# Initialize an empty list to store the updated accounts
|
||||
@ -395,20 +390,25 @@ class TenantService:
|
||||
def has_roles(tenant: Tenant, roles: list[TenantAccountJoinRole]) -> bool:
|
||||
"""Check if user has any of the given roles for a tenant"""
|
||||
if not all(isinstance(role, TenantAccountJoinRole) for role in roles):
|
||||
raise ValueError('all roles must be TenantAccountJoinRole')
|
||||
raise ValueError("all roles must be TenantAccountJoinRole")
|
||||
|
||||
return db.session.query(TenantAccountJoin).filter(
|
||||
TenantAccountJoin.tenant_id == tenant.id,
|
||||
TenantAccountJoin.role.in_([role.value for role in roles])
|
||||
).first() is not None
|
||||
return (
|
||||
db.session.query(TenantAccountJoin)
|
||||
.filter(
|
||||
TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role.in_([role.value for role in roles])
|
||||
)
|
||||
.first()
|
||||
is not None
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_user_role(account: Account, tenant: Tenant) -> Optional[TenantAccountJoinRole]:
|
||||
"""Get the role of the current account for a given tenant"""
|
||||
join = db.session.query(TenantAccountJoin).filter(
|
||||
TenantAccountJoin.tenant_id == tenant.id,
|
||||
TenantAccountJoin.account_id == account.id
|
||||
).first()
|
||||
join = (
|
||||
db.session.query(TenantAccountJoin)
|
||||
.filter(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id)
|
||||
.first()
|
||||
)
|
||||
return join.role if join else None
|
||||
|
||||
@staticmethod
|
||||
@ -420,29 +420,26 @@ class TenantService:
|
||||
def check_member_permission(tenant: Tenant, operator: Account, member: Account, action: str) -> None:
|
||||
"""Check member permission"""
|
||||
perms = {
|
||||
'add': [TenantAccountRole.OWNER, TenantAccountRole.ADMIN],
|
||||
'remove': [TenantAccountRole.OWNER],
|
||||
'update': [TenantAccountRole.OWNER]
|
||||
"add": [TenantAccountRole.OWNER, TenantAccountRole.ADMIN],
|
||||
"remove": [TenantAccountRole.OWNER],
|
||||
"update": [TenantAccountRole.OWNER],
|
||||
}
|
||||
if action not in ['add', 'remove', 'update']:
|
||||
if action not in ["add", "remove", "update"]:
|
||||
raise InvalidActionError("Invalid action.")
|
||||
|
||||
if member:
|
||||
if operator.id == member.id:
|
||||
raise CannotOperateSelfError("Cannot operate self.")
|
||||
|
||||
ta_operator = TenantAccountJoin.query.filter_by(
|
||||
tenant_id=tenant.id,
|
||||
account_id=operator.id
|
||||
).first()
|
||||
ta_operator = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=operator.id).first()
|
||||
|
||||
if not ta_operator or ta_operator.role not in perms[action]:
|
||||
raise NoPermissionError(f'No permission to {action} member.')
|
||||
raise NoPermissionError(f"No permission to {action} member.")
|
||||
|
||||
@staticmethod
|
||||
def remove_member_from_tenant(tenant: Tenant, account: Account, operator: Account) -> None:
|
||||
"""Remove member from tenant"""
|
||||
if operator.id == account.id and TenantService.check_member_permission(tenant, operator, account, 'remove'):
|
||||
if operator.id == account.id and TenantService.check_member_permission(tenant, operator, account, "remove"):
|
||||
raise CannotOperateSelfError("Cannot operate self.")
|
||||
|
||||
ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first()
|
||||
@ -455,23 +452,17 @@ class TenantService:
|
||||
@staticmethod
|
||||
def update_member_role(tenant: Tenant, member: Account, new_role: str, operator: Account) -> None:
|
||||
"""Update member role"""
|
||||
TenantService.check_member_permission(tenant, operator, member, 'update')
|
||||
TenantService.check_member_permission(tenant, operator, member, "update")
|
||||
|
||||
target_member_join = TenantAccountJoin.query.filter_by(
|
||||
tenant_id=tenant.id,
|
||||
account_id=member.id
|
||||
).first()
|
||||
target_member_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=member.id).first()
|
||||
|
||||
if target_member_join.role == new_role:
|
||||
raise RoleAlreadyAssignedError("The provided role is already assigned to the member.")
|
||||
|
||||
if new_role == 'owner':
|
||||
if new_role == "owner":
|
||||
# Find the current owner and change their role to 'admin'
|
||||
current_owner_join = TenantAccountJoin.query.filter_by(
|
||||
tenant_id=tenant.id,
|
||||
role='owner'
|
||||
).first()
|
||||
current_owner_join.role = 'admin'
|
||||
current_owner_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, role="owner").first()
|
||||
current_owner_join.role = "admin"
|
||||
|
||||
# Update the role of the target member
|
||||
target_member_join.role = new_role
|
||||
@ -480,8 +471,8 @@ class TenantService:
|
||||
@staticmethod
|
||||
def dissolve_tenant(tenant: Tenant, operator: Account) -> None:
|
||||
"""Dissolve tenant"""
|
||||
if not TenantService.check_member_permission(tenant, operator, operator, 'remove'):
|
||||
raise NoPermissionError('No permission to dissolve tenant.')
|
||||
if not TenantService.check_member_permission(tenant, operator, operator, "remove"):
|
||||
raise NoPermissionError("No permission to dissolve tenant.")
|
||||
db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id).delete()
|
||||
db.session.delete(tenant)
|
||||
db.session.commit()
|
||||
@ -494,10 +485,9 @@ class TenantService:
|
||||
|
||||
|
||||
class RegisterService:
|
||||
|
||||
@classmethod
|
||||
def _get_invitation_token_key(cls, token: str) -> str:
|
||||
return f'member_invite:token:{token}'
|
||||
return f"member_invite:token:{token}"
|
||||
|
||||
@classmethod
|
||||
def setup(cls, email: str, name: str, password: str, ip_address: str) -> None:
|
||||
@ -523,9 +513,7 @@ class RegisterService:
|
||||
|
||||
TenantService.create_owner_tenant_if_not_exist(account)
|
||||
|
||||
dify_setup = DifySetup(
|
||||
version=dify_config.CURRENT_VERSION
|
||||
)
|
||||
dify_setup = DifySetup(version=dify_config.CURRENT_VERSION)
|
||||
db.session.add(dify_setup)
|
||||
db.session.commit()
|
||||
except Exception as e:
|
||||
@ -535,34 +523,35 @@ class RegisterService:
|
||||
db.session.query(Tenant).delete()
|
||||
db.session.commit()
|
||||
|
||||
logging.exception(f'Setup failed: {e}')
|
||||
raise ValueError(f'Setup failed: {e}')
|
||||
logging.exception(f"Setup failed: {e}")
|
||||
raise ValueError(f"Setup failed: {e}")
|
||||
|
||||
@classmethod
|
||||
def register(cls, email, name,
|
||||
password: Optional[str] = None,
|
||||
open_id: Optional[str] = None,
|
||||
provider: Optional[str] = None,
|
||||
language: Optional[str] = None,
|
||||
status: Optional[AccountStatus] = None) -> Account:
|
||||
def register(
|
||||
cls,
|
||||
email,
|
||||
name,
|
||||
password: Optional[str] = None,
|
||||
open_id: Optional[str] = None,
|
||||
provider: Optional[str] = None,
|
||||
language: Optional[str] = None,
|
||||
status: Optional[AccountStatus] = None,
|
||||
) -> Account:
|
||||
db.session.begin_nested()
|
||||
"""Register account"""
|
||||
try:
|
||||
account = AccountService.create_account(
|
||||
email=email,
|
||||
name=name,
|
||||
interface_language=language if language else languages[0],
|
||||
password=password
|
||||
email=email, name=name, interface_language=language if language else languages[0], password=password
|
||||
)
|
||||
account.status = AccountStatus.ACTIVE.value if not status else status.value
|
||||
account.initialized_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
|
||||
if open_id is not None or provider is not None:
|
||||
AccountService.link_account_integrate(provider, open_id, account)
|
||||
if dify_config.EDITION != 'SELF_HOSTED':
|
||||
if dify_config.EDITION != "SELF_HOSTED":
|
||||
tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
|
||||
|
||||
TenantService.create_tenant_member(tenant, account, role='owner')
|
||||
TenantService.create_tenant_member(tenant, account, role="owner")
|
||||
account.current_tenant = tenant
|
||||
|
||||
tenant_was_created.send(tenant)
|
||||
@ -570,30 +559,29 @@ class RegisterService:
|
||||
db.session.commit()
|
||||
except Exception as e:
|
||||
db.session.rollback()
|
||||
logging.error(f'Register failed: {e}')
|
||||
raise AccountRegisterError(f'Registration failed: {e}') from e
|
||||
logging.error(f"Register failed: {e}")
|
||||
raise AccountRegisterError(f"Registration failed: {e}") from e
|
||||
|
||||
return account
|
||||
|
||||
@classmethod
|
||||
def invite_new_member(cls, tenant: Tenant, email: str, language: str, role: str = 'normal', inviter: Account = None) -> str:
|
||||
def invite_new_member(
|
||||
cls, tenant: Tenant, email: str, language: str, role: str = "normal", inviter: Account = None
|
||||
) -> str:
|
||||
"""Invite new member"""
|
||||
account = Account.query.filter_by(email=email).first()
|
||||
|
||||
if not account:
|
||||
TenantService.check_member_permission(tenant, inviter, None, 'add')
|
||||
name = email.split('@')[0]
|
||||
TenantService.check_member_permission(tenant, inviter, None, "add")
|
||||
name = email.split("@")[0]
|
||||
|
||||
account = cls.register(email=email, name=name, language=language, status=AccountStatus.PENDING)
|
||||
# Create new tenant member for invited tenant
|
||||
TenantService.create_tenant_member(tenant, account, role)
|
||||
TenantService.switch_tenant(account, tenant.id)
|
||||
else:
|
||||
TenantService.check_member_permission(tenant, inviter, account, 'add')
|
||||
ta = TenantAccountJoin.query.filter_by(
|
||||
tenant_id=tenant.id,
|
||||
account_id=account.id
|
||||
).first()
|
||||
TenantService.check_member_permission(tenant, inviter, account, "add")
|
||||
ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first()
|
||||
|
||||
if not ta:
|
||||
TenantService.create_tenant_member(tenant, account, role)
|
||||
@ -609,7 +597,7 @@ class RegisterService:
|
||||
language=account.interface_language,
|
||||
to=email,
|
||||
token=token,
|
||||
inviter_name=inviter.name if inviter else 'Dify',
|
||||
inviter_name=inviter.name if inviter else "Dify",
|
||||
workspace_name=tenant.name,
|
||||
)
|
||||
|
||||
@ -619,23 +607,19 @@ class RegisterService:
|
||||
def generate_invite_token(cls, tenant: Tenant, account: Account) -> str:
|
||||
token = str(uuid.uuid4())
|
||||
invitation_data = {
|
||||
'account_id': account.id,
|
||||
'email': account.email,
|
||||
'workspace_id': tenant.id,
|
||||
"account_id": account.id,
|
||||
"email": account.email,
|
||||
"workspace_id": tenant.id,
|
||||
}
|
||||
expiryHours = dify_config.INVITE_EXPIRY_HOURS
|
||||
redis_client.setex(
|
||||
cls._get_invitation_token_key(token),
|
||||
expiryHours * 60 * 60,
|
||||
json.dumps(invitation_data)
|
||||
)
|
||||
redis_client.setex(cls._get_invitation_token_key(token), expiryHours * 60 * 60, json.dumps(invitation_data))
|
||||
return token
|
||||
|
||||
@classmethod
|
||||
def revoke_token(cls, workspace_id: str, email: str, token: str):
|
||||
if workspace_id and email:
|
||||
email_hash = sha256(email.encode()).hexdigest()
|
||||
cache_key = 'member_invite_token:{}, {}:{}'.format(workspace_id, email_hash, token)
|
||||
cache_key = "member_invite_token:{}, {}:{}".format(workspace_id, email_hash, token)
|
||||
redis_client.delete(cache_key)
|
||||
else:
|
||||
redis_client.delete(cls._get_invitation_token_key(token))
|
||||
@ -646,17 +630,21 @@ class RegisterService:
|
||||
if not invitation_data:
|
||||
return None
|
||||
|
||||
tenant = db.session.query(Tenant).filter(
|
||||
Tenant.id == invitation_data['workspace_id'],
|
||||
Tenant.status == 'normal'
|
||||
).first()
|
||||
tenant = (
|
||||
db.session.query(Tenant)
|
||||
.filter(Tenant.id == invitation_data["workspace_id"], Tenant.status == "normal")
|
||||
.first()
|
||||
)
|
||||
|
||||
if not tenant:
|
||||
return None
|
||||
|
||||
tenant_account = db.session.query(Account, TenantAccountJoin.role).join(
|
||||
TenantAccountJoin, Account.id == TenantAccountJoin.account_id
|
||||
).filter(Account.email == invitation_data['email'], TenantAccountJoin.tenant_id == tenant.id).first()
|
||||
tenant_account = (
|
||||
db.session.query(Account, TenantAccountJoin.role)
|
||||
.join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id)
|
||||
.filter(Account.email == invitation_data["email"], TenantAccountJoin.tenant_id == tenant.id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not tenant_account:
|
||||
return None
|
||||
@ -665,29 +653,29 @@ class RegisterService:
|
||||
if not account:
|
||||
return None
|
||||
|
||||
if invitation_data['account_id'] != str(account.id):
|
||||
if invitation_data["account_id"] != str(account.id):
|
||||
return None
|
||||
|
||||
return {
|
||||
'account': account,
|
||||
'data': invitation_data,
|
||||
'tenant': tenant,
|
||||
"account": account,
|
||||
"data": invitation_data,
|
||||
"tenant": tenant,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _get_invitation_by_token(cls, token: str, workspace_id: str, email: str) -> Optional[dict[str, str]]:
|
||||
if workspace_id is not None and email is not None:
|
||||
email_hash = sha256(email.encode()).hexdigest()
|
||||
cache_key = f'member_invite_token:{workspace_id}, {email_hash}:{token}'
|
||||
cache_key = f"member_invite_token:{workspace_id}, {email_hash}:{token}"
|
||||
account_id = redis_client.get(cache_key)
|
||||
|
||||
if not account_id:
|
||||
return None
|
||||
|
||||
return {
|
||||
'account_id': account_id.decode('utf-8'),
|
||||
'email': email,
|
||||
'workspace_id': workspace_id,
|
||||
"account_id": account_id.decode("utf-8"),
|
||||
"email": email,
|
||||
"workspace_id": workspace_id,
|
||||
}
|
||||
else:
|
||||
data = redis_client.get(cls._get_invitation_token_key(token))
|
||||
|
@ -1,4 +1,3 @@
|
||||
|
||||
import copy
|
||||
|
||||
from core.prompt.prompt_templates.advanced_prompt_templates import (
|
||||
@ -17,59 +16,78 @@ from models.model import AppMode
|
||||
|
||||
|
||||
class AdvancedPromptTemplateService:
|
||||
|
||||
@classmethod
|
||||
def get_prompt(cls, args: dict) -> dict:
|
||||
app_mode = args['app_mode']
|
||||
model_mode = args['model_mode']
|
||||
model_name = args['model_name']
|
||||
has_context = args['has_context']
|
||||
app_mode = args["app_mode"]
|
||||
model_mode = args["model_mode"]
|
||||
model_name = args["model_name"]
|
||||
has_context = args["has_context"]
|
||||
|
||||
if 'baichuan' in model_name.lower():
|
||||
if "baichuan" in model_name.lower():
|
||||
return cls.get_baichuan_prompt(app_mode, model_mode, has_context)
|
||||
else:
|
||||
return cls.get_common_prompt(app_mode, model_mode, has_context)
|
||||
|
||||
@classmethod
|
||||
def get_common_prompt(cls, app_mode: str, model_mode:str, has_context: str) -> dict:
|
||||
def get_common_prompt(cls, app_mode: str, model_mode: str, has_context: str) -> dict:
|
||||
context_prompt = copy.deepcopy(CONTEXT)
|
||||
|
||||
if app_mode == AppMode.CHAT.value:
|
||||
if model_mode == "completion":
|
||||
return cls.get_completion_prompt(copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt)
|
||||
return cls.get_completion_prompt(
|
||||
copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt
|
||||
)
|
||||
elif model_mode == "chat":
|
||||
return cls.get_chat_prompt(copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt)
|
||||
elif app_mode == AppMode.COMPLETION.value:
|
||||
if model_mode == "completion":
|
||||
return cls.get_completion_prompt(copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt)
|
||||
return cls.get_completion_prompt(
|
||||
copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, context_prompt
|
||||
)
|
||||
elif model_mode == "chat":
|
||||
return cls.get_chat_prompt(copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt)
|
||||
|
||||
return cls.get_chat_prompt(
|
||||
copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, context_prompt
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_completion_prompt(cls, prompt_template: dict, has_context: str, context: str) -> dict:
|
||||
if has_context == 'true':
|
||||
prompt_template['completion_prompt_config']['prompt']['text'] = context + prompt_template['completion_prompt_config']['prompt']['text']
|
||||
|
||||
if has_context == "true":
|
||||
prompt_template["completion_prompt_config"]["prompt"]["text"] = (
|
||||
context + prompt_template["completion_prompt_config"]["prompt"]["text"]
|
||||
)
|
||||
|
||||
return prompt_template
|
||||
|
||||
@classmethod
|
||||
def get_chat_prompt(cls, prompt_template: dict, has_context: str, context: str) -> dict:
|
||||
if has_context == 'true':
|
||||
prompt_template['chat_prompt_config']['prompt'][0]['text'] = context + prompt_template['chat_prompt_config']['prompt'][0]['text']
|
||||
|
||||
if has_context == "true":
|
||||
prompt_template["chat_prompt_config"]["prompt"][0]["text"] = (
|
||||
context + prompt_template["chat_prompt_config"]["prompt"][0]["text"]
|
||||
)
|
||||
|
||||
return prompt_template
|
||||
|
||||
@classmethod
|
||||
def get_baichuan_prompt(cls, app_mode: str, model_mode:str, has_context: str) -> dict:
|
||||
def get_baichuan_prompt(cls, app_mode: str, model_mode: str, has_context: str) -> dict:
|
||||
baichuan_context_prompt = copy.deepcopy(BAICHUAN_CONTEXT)
|
||||
|
||||
if app_mode == AppMode.CHAT.value:
|
||||
if model_mode == "completion":
|
||||
return cls.get_completion_prompt(copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, baichuan_context_prompt)
|
||||
return cls.get_completion_prompt(
|
||||
copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG), has_context, baichuan_context_prompt
|
||||
)
|
||||
elif model_mode == "chat":
|
||||
return cls.get_chat_prompt(copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt)
|
||||
return cls.get_chat_prompt(
|
||||
copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt
|
||||
)
|
||||
elif app_mode == AppMode.COMPLETION.value:
|
||||
if model_mode == "completion":
|
||||
return cls.get_completion_prompt(copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, baichuan_context_prompt)
|
||||
return cls.get_completion_prompt(
|
||||
copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG),
|
||||
has_context,
|
||||
baichuan_context_prompt,
|
||||
)
|
||||
elif model_mode == "chat":
|
||||
return cls.get_chat_prompt(copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt)
|
||||
return cls.get_chat_prompt(
|
||||
copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt
|
||||
)
|
||||
|
@ -10,59 +10,65 @@ from models.model import App, Conversation, EndUser, Message, MessageAgentThough
|
||||
|
||||
class AgentService:
|
||||
@classmethod
|
||||
def get_agent_logs(cls, app_model: App,
|
||||
conversation_id: str,
|
||||
message_id: str) -> dict:
|
||||
def get_agent_logs(cls, app_model: App, conversation_id: str, message_id: str) -> dict:
|
||||
"""
|
||||
Service to get agent logs
|
||||
"""
|
||||
conversation: Conversation = db.session.query(Conversation).filter(
|
||||
Conversation.id == conversation_id,
|
||||
Conversation.app_id == app_model.id,
|
||||
).first()
|
||||
conversation: Conversation = (
|
||||
db.session.query(Conversation)
|
||||
.filter(
|
||||
Conversation.id == conversation_id,
|
||||
Conversation.app_id == app_model.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not conversation:
|
||||
raise ValueError(f"Conversation not found: {conversation_id}")
|
||||
|
||||
message: Message = db.session.query(Message).filter(
|
||||
Message.id == message_id,
|
||||
Message.conversation_id == conversation_id,
|
||||
).first()
|
||||
message: Message = (
|
||||
db.session.query(Message)
|
||||
.filter(
|
||||
Message.id == message_id,
|
||||
Message.conversation_id == conversation_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not message:
|
||||
raise ValueError(f"Message not found: {message_id}")
|
||||
|
||||
|
||||
agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
|
||||
|
||||
if conversation.from_end_user_id:
|
||||
# only select name field
|
||||
executor = db.session.query(EndUser, EndUser.name).filter(
|
||||
EndUser.id == conversation.from_end_user_id
|
||||
).first()
|
||||
executor = (
|
||||
db.session.query(EndUser, EndUser.name).filter(EndUser.id == conversation.from_end_user_id).first()
|
||||
)
|
||||
else:
|
||||
executor = db.session.query(Account, Account.name).filter(
|
||||
Account.id == conversation.from_account_id
|
||||
).first()
|
||||
|
||||
executor = (
|
||||
db.session.query(Account, Account.name).filter(Account.id == conversation.from_account_id).first()
|
||||
)
|
||||
|
||||
if executor:
|
||||
executor = executor.name
|
||||
else:
|
||||
executor = 'Unknown'
|
||||
executor = "Unknown"
|
||||
|
||||
timezone = pytz.timezone(current_user.timezone)
|
||||
|
||||
result = {
|
||||
'meta': {
|
||||
'status': 'success',
|
||||
'executor': executor,
|
||||
'start_time': message.created_at.astimezone(timezone).isoformat(),
|
||||
'elapsed_time': message.provider_response_latency,
|
||||
'total_tokens': message.answer_tokens + message.message_tokens,
|
||||
'agent_mode': app_model.app_model_config.agent_mode_dict.get('strategy', 'react'),
|
||||
'iterations': len(agent_thoughts),
|
||||
"meta": {
|
||||
"status": "success",
|
||||
"executor": executor,
|
||||
"start_time": message.created_at.astimezone(timezone).isoformat(),
|
||||
"elapsed_time": message.provider_response_latency,
|
||||
"total_tokens": message.answer_tokens + message.message_tokens,
|
||||
"agent_mode": app_model.app_model_config.agent_mode_dict.get("strategy", "react"),
|
||||
"iterations": len(agent_thoughts),
|
||||
},
|
||||
'iterations': [],
|
||||
'files': message.files,
|
||||
"iterations": [],
|
||||
"files": message.files,
|
||||
}
|
||||
|
||||
agent_config = AgentConfigManager.convert(app_model.app_model_config.to_dict())
|
||||
@ -86,12 +92,12 @@ class AgentService:
|
||||
tool_input = tool_inputs.get(tool_name, {})
|
||||
tool_output = tool_outputs.get(tool_name, {})
|
||||
tool_meta_data = tool_meta.get(tool_name, {})
|
||||
tool_config = tool_meta_data.get('tool_config', {})
|
||||
if tool_config.get('tool_provider_type', '') != 'dataset-retrieval':
|
||||
tool_config = tool_meta_data.get("tool_config", {})
|
||||
if tool_config.get("tool_provider_type", "") != "dataset-retrieval":
|
||||
tool_icon = ToolManager.get_tool_icon(
|
||||
tenant_id=app_model.tenant_id,
|
||||
provider_type=tool_config.get('tool_provider_type', ''),
|
||||
provider_id=tool_config.get('tool_provider', ''),
|
||||
provider_type=tool_config.get("tool_provider_type", ""),
|
||||
provider_id=tool_config.get("tool_provider", ""),
|
||||
)
|
||||
if not tool_icon:
|
||||
tool_entity = find_agent_tool(tool_name)
|
||||
@ -102,30 +108,34 @@ class AgentService:
|
||||
provider_id=tool_entity.provider_id,
|
||||
)
|
||||
else:
|
||||
tool_icon = ''
|
||||
tool_icon = ""
|
||||
|
||||
tool_calls.append({
|
||||
'status': 'success' if not tool_meta_data.get('error') else 'error',
|
||||
'error': tool_meta_data.get('error'),
|
||||
'time_cost': tool_meta_data.get('time_cost', 0),
|
||||
'tool_name': tool_name,
|
||||
'tool_label': tool_label,
|
||||
'tool_input': tool_input,
|
||||
'tool_output': tool_output,
|
||||
'tool_parameters': tool_meta_data.get('tool_parameters', {}),
|
||||
'tool_icon': tool_icon,
|
||||
})
|
||||
tool_calls.append(
|
||||
{
|
||||
"status": "success" if not tool_meta_data.get("error") else "error",
|
||||
"error": tool_meta_data.get("error"),
|
||||
"time_cost": tool_meta_data.get("time_cost", 0),
|
||||
"tool_name": tool_name,
|
||||
"tool_label": tool_label,
|
||||
"tool_input": tool_input,
|
||||
"tool_output": tool_output,
|
||||
"tool_parameters": tool_meta_data.get("tool_parameters", {}),
|
||||
"tool_icon": tool_icon,
|
||||
}
|
||||
)
|
||||
|
||||
result['iterations'].append({
|
||||
'tokens': agent_thought.tokens,
|
||||
'tool_calls': tool_calls,
|
||||
'tool_raw': {
|
||||
'inputs': agent_thought.tool_input,
|
||||
'outputs': agent_thought.observation,
|
||||
},
|
||||
'thought': agent_thought.thought,
|
||||
'created_at': agent_thought.created_at.isoformat(),
|
||||
'files': agent_thought.files,
|
||||
})
|
||||
result["iterations"].append(
|
||||
{
|
||||
"tokens": agent_thought.tokens,
|
||||
"tool_calls": tool_calls,
|
||||
"tool_raw": {
|
||||
"inputs": agent_thought.tool_input,
|
||||
"outputs": agent_thought.observation,
|
||||
},
|
||||
"thought": agent_thought.thought,
|
||||
"created_at": agent_thought.created_at.isoformat(),
|
||||
"files": agent_thought.files,
|
||||
}
|
||||
)
|
||||
|
||||
return result
|
||||
return result
|
||||
|
@ -23,21 +23,18 @@ class AppAnnotationService:
|
||||
@classmethod
|
||||
def up_insert_app_annotation_from_message(cls, args: dict, app_id: str) -> MessageAnnotation:
|
||||
# get app info
|
||||
app = db.session.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.tenant_id == current_user.current_tenant_id,
|
||||
App.status == 'normal'
|
||||
).first()
|
||||
app = (
|
||||
db.session.query(App)
|
||||
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
.first()
|
||||
)
|
||||
|
||||
if not app:
|
||||
raise NotFound("App not found")
|
||||
if args.get('message_id'):
|
||||
message_id = str(args['message_id'])
|
||||
if args.get("message_id"):
|
||||
message_id = str(args["message_id"])
|
||||
# get message info
|
||||
message = db.session.query(Message).filter(
|
||||
Message.id == message_id,
|
||||
Message.app_id == app.id
|
||||
).first()
|
||||
message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app.id).first()
|
||||
|
||||
if not message:
|
||||
raise NotFound("Message Not Exists.")
|
||||
@ -45,159 +42,166 @@ class AppAnnotationService:
|
||||
annotation = message.annotation
|
||||
# save the message annotation
|
||||
if annotation:
|
||||
annotation.content = args['answer']
|
||||
annotation.question = args['question']
|
||||
annotation.content = args["answer"]
|
||||
annotation.question = args["question"]
|
||||
else:
|
||||
annotation = MessageAnnotation(
|
||||
app_id=app.id,
|
||||
conversation_id=message.conversation_id,
|
||||
message_id=message.id,
|
||||
content=args['answer'],
|
||||
question=args['question'],
|
||||
account_id=current_user.id
|
||||
content=args["answer"],
|
||||
question=args["question"],
|
||||
account_id=current_user.id,
|
||||
)
|
||||
else:
|
||||
annotation = MessageAnnotation(
|
||||
app_id=app.id,
|
||||
content=args['answer'],
|
||||
question=args['question'],
|
||||
account_id=current_user.id
|
||||
app_id=app.id, content=args["answer"], question=args["question"], account_id=current_user.id
|
||||
)
|
||||
db.session.add(annotation)
|
||||
db.session.commit()
|
||||
# if annotation reply is enabled , add annotation to index
|
||||
annotation_setting = db.session.query(AppAnnotationSetting).filter(
|
||||
AppAnnotationSetting.app_id == app_id).first()
|
||||
annotation_setting = (
|
||||
db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first()
|
||||
)
|
||||
if annotation_setting:
|
||||
add_annotation_to_index_task.delay(annotation.id, args['question'], current_user.current_tenant_id,
|
||||
app_id, annotation_setting.collection_binding_id)
|
||||
add_annotation_to_index_task.delay(
|
||||
annotation.id,
|
||||
args["question"],
|
||||
current_user.current_tenant_id,
|
||||
app_id,
|
||||
annotation_setting.collection_binding_id,
|
||||
)
|
||||
return annotation
|
||||
|
||||
@classmethod
|
||||
def enable_app_annotation(cls, args: dict, app_id: str) -> dict:
|
||||
enable_app_annotation_key = 'enable_app_annotation_{}'.format(str(app_id))
|
||||
enable_app_annotation_key = "enable_app_annotation_{}".format(str(app_id))
|
||||
cache_result = redis_client.get(enable_app_annotation_key)
|
||||
if cache_result is not None:
|
||||
return {
|
||||
'job_id': cache_result,
|
||||
'job_status': 'processing'
|
||||
}
|
||||
return {"job_id": cache_result, "job_status": "processing"}
|
||||
|
||||
# async job
|
||||
job_id = str(uuid.uuid4())
|
||||
enable_app_annotation_job_key = 'enable_app_annotation_job_{}'.format(str(job_id))
|
||||
enable_app_annotation_job_key = "enable_app_annotation_job_{}".format(str(job_id))
|
||||
# send batch add segments task
|
||||
redis_client.setnx(enable_app_annotation_job_key, 'waiting')
|
||||
enable_annotation_reply_task.delay(str(job_id), app_id, current_user.id, current_user.current_tenant_id,
|
||||
args['score_threshold'],
|
||||
args['embedding_provider_name'], args['embedding_model_name'])
|
||||
return {
|
||||
'job_id': job_id,
|
||||
'job_status': 'waiting'
|
||||
}
|
||||
redis_client.setnx(enable_app_annotation_job_key, "waiting")
|
||||
enable_annotation_reply_task.delay(
|
||||
str(job_id),
|
||||
app_id,
|
||||
current_user.id,
|
||||
current_user.current_tenant_id,
|
||||
args["score_threshold"],
|
||||
args["embedding_provider_name"],
|
||||
args["embedding_model_name"],
|
||||
)
|
||||
return {"job_id": job_id, "job_status": "waiting"}
|
||||
|
||||
@classmethod
|
||||
def disable_app_annotation(cls, app_id: str) -> dict:
|
||||
disable_app_annotation_key = 'disable_app_annotation_{}'.format(str(app_id))
|
||||
disable_app_annotation_key = "disable_app_annotation_{}".format(str(app_id))
|
||||
cache_result = redis_client.get(disable_app_annotation_key)
|
||||
if cache_result is not None:
|
||||
return {
|
||||
'job_id': cache_result,
|
||||
'job_status': 'processing'
|
||||
}
|
||||
return {"job_id": cache_result, "job_status": "processing"}
|
||||
|
||||
# async job
|
||||
job_id = str(uuid.uuid4())
|
||||
disable_app_annotation_job_key = 'disable_app_annotation_job_{}'.format(str(job_id))
|
||||
disable_app_annotation_job_key = "disable_app_annotation_job_{}".format(str(job_id))
|
||||
# send batch add segments task
|
||||
redis_client.setnx(disable_app_annotation_job_key, 'waiting')
|
||||
redis_client.setnx(disable_app_annotation_job_key, "waiting")
|
||||
disable_annotation_reply_task.delay(str(job_id), app_id, current_user.current_tenant_id)
|
||||
return {
|
||||
'job_id': job_id,
|
||||
'job_status': 'waiting'
|
||||
}
|
||||
return {"job_id": job_id, "job_status": "waiting"}
|
||||
|
||||
@classmethod
|
||||
def get_annotation_list_by_app_id(cls, app_id: str, page: int, limit: int, keyword: str):
|
||||
# get app info
|
||||
app = db.session.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.tenant_id == current_user.current_tenant_id,
|
||||
App.status == 'normal'
|
||||
).first()
|
||||
app = (
|
||||
db.session.query(App)
|
||||
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
.first()
|
||||
)
|
||||
|
||||
if not app:
|
||||
raise NotFound("App not found")
|
||||
if keyword:
|
||||
annotations = (db.session.query(MessageAnnotation)
|
||||
.filter(MessageAnnotation.app_id == app_id)
|
||||
.filter(
|
||||
or_(
|
||||
MessageAnnotation.question.ilike('%{}%'.format(keyword)),
|
||||
MessageAnnotation.content.ilike('%{}%'.format(keyword))
|
||||
annotations = (
|
||||
db.session.query(MessageAnnotation)
|
||||
.filter(MessageAnnotation.app_id == app_id)
|
||||
.filter(
|
||||
or_(
|
||||
MessageAnnotation.question.ilike("%{}%".format(keyword)),
|
||||
MessageAnnotation.content.ilike("%{}%".format(keyword)),
|
||||
)
|
||||
)
|
||||
.order_by(MessageAnnotation.created_at.desc())
|
||||
.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
|
||||
)
|
||||
.order_by(MessageAnnotation.created_at.desc())
|
||||
.paginate(page=page, per_page=limit, max_per_page=100, error_out=False))
|
||||
else:
|
||||
annotations = (db.session.query(MessageAnnotation)
|
||||
.filter(MessageAnnotation.app_id == app_id)
|
||||
.order_by(MessageAnnotation.created_at.desc())
|
||||
.paginate(page=page, per_page=limit, max_per_page=100, error_out=False))
|
||||
annotations = (
|
||||
db.session.query(MessageAnnotation)
|
||||
.filter(MessageAnnotation.app_id == app_id)
|
||||
.order_by(MessageAnnotation.created_at.desc())
|
||||
.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
|
||||
)
|
||||
return annotations.items, annotations.total
|
||||
|
||||
@classmethod
|
||||
def export_annotation_list_by_app_id(cls, app_id: str):
|
||||
# get app info
|
||||
app = db.session.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.tenant_id == current_user.current_tenant_id,
|
||||
App.status == 'normal'
|
||||
).first()
|
||||
app = (
|
||||
db.session.query(App)
|
||||
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
.first()
|
||||
)
|
||||
|
||||
if not app:
|
||||
raise NotFound("App not found")
|
||||
annotations = (db.session.query(MessageAnnotation)
|
||||
.filter(MessageAnnotation.app_id == app_id)
|
||||
.order_by(MessageAnnotation.created_at.desc()).all())
|
||||
annotations = (
|
||||
db.session.query(MessageAnnotation)
|
||||
.filter(MessageAnnotation.app_id == app_id)
|
||||
.order_by(MessageAnnotation.created_at.desc())
|
||||
.all()
|
||||
)
|
||||
return annotations
|
||||
|
||||
@classmethod
|
||||
def insert_app_annotation_directly(cls, args: dict, app_id: str) -> MessageAnnotation:
|
||||
# get app info
|
||||
app = db.session.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.tenant_id == current_user.current_tenant_id,
|
||||
App.status == 'normal'
|
||||
).first()
|
||||
app = (
|
||||
db.session.query(App)
|
||||
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
.first()
|
||||
)
|
||||
|
||||
if not app:
|
||||
raise NotFound("App not found")
|
||||
|
||||
annotation = MessageAnnotation(
|
||||
app_id=app.id,
|
||||
content=args['answer'],
|
||||
question=args['question'],
|
||||
account_id=current_user.id
|
||||
app_id=app.id, content=args["answer"], question=args["question"], account_id=current_user.id
|
||||
)
|
||||
db.session.add(annotation)
|
||||
db.session.commit()
|
||||
# if annotation reply is enabled , add annotation to index
|
||||
annotation_setting = db.session.query(AppAnnotationSetting).filter(
|
||||
AppAnnotationSetting.app_id == app_id).first()
|
||||
annotation_setting = (
|
||||
db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first()
|
||||
)
|
||||
if annotation_setting:
|
||||
add_annotation_to_index_task.delay(annotation.id, args['question'], current_user.current_tenant_id,
|
||||
app_id, annotation_setting.collection_binding_id)
|
||||
add_annotation_to_index_task.delay(
|
||||
annotation.id,
|
||||
args["question"],
|
||||
current_user.current_tenant_id,
|
||||
app_id,
|
||||
annotation_setting.collection_binding_id,
|
||||
)
|
||||
return annotation
|
||||
|
||||
@classmethod
|
||||
def update_app_annotation_directly(cls, args: dict, app_id: str, annotation_id: str):
|
||||
# get app info
|
||||
app = db.session.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.tenant_id == current_user.current_tenant_id,
|
||||
App.status == 'normal'
|
||||
).first()
|
||||
app = (
|
||||
db.session.query(App)
|
||||
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
.first()
|
||||
)
|
||||
|
||||
if not app:
|
||||
raise NotFound("App not found")
|
||||
@ -207,30 +211,34 @@ class AppAnnotationService:
|
||||
if not annotation:
|
||||
raise NotFound("Annotation not found")
|
||||
|
||||
annotation.content = args['answer']
|
||||
annotation.question = args['question']
|
||||
annotation.content = args["answer"]
|
||||
annotation.question = args["question"]
|
||||
|
||||
db.session.commit()
|
||||
# if annotation reply is enabled , add annotation to index
|
||||
app_annotation_setting = db.session.query(AppAnnotationSetting).filter(
|
||||
AppAnnotationSetting.app_id == app_id
|
||||
).first()
|
||||
app_annotation_setting = (
|
||||
db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first()
|
||||
)
|
||||
|
||||
if app_annotation_setting:
|
||||
update_annotation_to_index_task.delay(annotation.id, annotation.question,
|
||||
current_user.current_tenant_id,
|
||||
app_id, app_annotation_setting.collection_binding_id)
|
||||
update_annotation_to_index_task.delay(
|
||||
annotation.id,
|
||||
annotation.question,
|
||||
current_user.current_tenant_id,
|
||||
app_id,
|
||||
app_annotation_setting.collection_binding_id,
|
||||
)
|
||||
|
||||
return annotation
|
||||
|
||||
@classmethod
|
||||
def delete_app_annotation(cls, app_id: str, annotation_id: str):
|
||||
# get app info
|
||||
app = db.session.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.tenant_id == current_user.current_tenant_id,
|
||||
App.status == 'normal'
|
||||
).first()
|
||||
app = (
|
||||
db.session.query(App)
|
||||
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
.first()
|
||||
)
|
||||
|
||||
if not app:
|
||||
raise NotFound("App not found")
|
||||
@ -242,33 +250,34 @@ class AppAnnotationService:
|
||||
|
||||
db.session.delete(annotation)
|
||||
|
||||
annotation_hit_histories = (db.session.query(AppAnnotationHitHistory)
|
||||
.filter(AppAnnotationHitHistory.annotation_id == annotation_id)
|
||||
.all()
|
||||
)
|
||||
annotation_hit_histories = (
|
||||
db.session.query(AppAnnotationHitHistory)
|
||||
.filter(AppAnnotationHitHistory.annotation_id == annotation_id)
|
||||
.all()
|
||||
)
|
||||
if annotation_hit_histories:
|
||||
for annotation_hit_history in annotation_hit_histories:
|
||||
db.session.delete(annotation_hit_history)
|
||||
|
||||
db.session.commit()
|
||||
# if annotation reply is enabled , delete annotation index
|
||||
app_annotation_setting = db.session.query(AppAnnotationSetting).filter(
|
||||
AppAnnotationSetting.app_id == app_id
|
||||
).first()
|
||||
app_annotation_setting = (
|
||||
db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first()
|
||||
)
|
||||
|
||||
if app_annotation_setting:
|
||||
delete_annotation_index_task.delay(annotation.id, app_id,
|
||||
current_user.current_tenant_id,
|
||||
app_annotation_setting.collection_binding_id)
|
||||
delete_annotation_index_task.delay(
|
||||
annotation.id, app_id, current_user.current_tenant_id, app_annotation_setting.collection_binding_id
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def batch_import_app_annotations(cls, app_id, file: FileStorage) -> dict:
|
||||
# get app info
|
||||
app = db.session.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.tenant_id == current_user.current_tenant_id,
|
||||
App.status == 'normal'
|
||||
).first()
|
||||
app = (
|
||||
db.session.query(App)
|
||||
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
.first()
|
||||
)
|
||||
|
||||
if not app:
|
||||
raise NotFound("App not found")
|
||||
@ -278,10 +287,7 @@ class AppAnnotationService:
|
||||
df = pd.read_csv(file)
|
||||
result = []
|
||||
for index, row in df.iterrows():
|
||||
content = {
|
||||
'question': row[0],
|
||||
'answer': row[1]
|
||||
}
|
||||
content = {"question": row[0], "answer": row[1]}
|
||||
result.append(content)
|
||||
if len(result) == 0:
|
||||
raise ValueError("The CSV file is empty.")
|
||||
@ -293,28 +299,24 @@ class AppAnnotationService:
|
||||
raise ValueError("The number of annotations exceeds the limit of your subscription.")
|
||||
# async job
|
||||
job_id = str(uuid.uuid4())
|
||||
indexing_cache_key = 'app_annotation_batch_import_{}'.format(str(job_id))
|
||||
indexing_cache_key = "app_annotation_batch_import_{}".format(str(job_id))
|
||||
# send batch add segments task
|
||||
redis_client.setnx(indexing_cache_key, 'waiting')
|
||||
batch_import_annotations_task.delay(str(job_id), result, app_id,
|
||||
current_user.current_tenant_id, current_user.id)
|
||||
redis_client.setnx(indexing_cache_key, "waiting")
|
||||
batch_import_annotations_task.delay(
|
||||
str(job_id), result, app_id, current_user.current_tenant_id, current_user.id
|
||||
)
|
||||
except Exception as e:
|
||||
return {
|
||||
'error_msg': str(e)
|
||||
}
|
||||
return {
|
||||
'job_id': job_id,
|
||||
'job_status': 'waiting'
|
||||
}
|
||||
return {"error_msg": str(e)}
|
||||
return {"job_id": job_id, "job_status": "waiting"}
|
||||
|
||||
@classmethod
|
||||
def get_annotation_hit_histories(cls, app_id: str, annotation_id: str, page, limit):
|
||||
# get app info
|
||||
app = db.session.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.tenant_id == current_user.current_tenant_id,
|
||||
App.status == 'normal'
|
||||
).first()
|
||||
app = (
|
||||
db.session.query(App)
|
||||
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
.first()
|
||||
)
|
||||
|
||||
if not app:
|
||||
raise NotFound("App not found")
|
||||
@ -324,12 +326,15 @@ class AppAnnotationService:
|
||||
if not annotation:
|
||||
raise NotFound("Annotation not found")
|
||||
|
||||
annotation_hit_histories = (db.session.query(AppAnnotationHitHistory)
|
||||
.filter(AppAnnotationHitHistory.app_id == app_id,
|
||||
AppAnnotationHitHistory.annotation_id == annotation_id,
|
||||
)
|
||||
.order_by(AppAnnotationHitHistory.created_at.desc())
|
||||
.paginate(page=page, per_page=limit, max_per_page=100, error_out=False))
|
||||
annotation_hit_histories = (
|
||||
db.session.query(AppAnnotationHitHistory)
|
||||
.filter(
|
||||
AppAnnotationHitHistory.app_id == app_id,
|
||||
AppAnnotationHitHistory.annotation_id == annotation_id,
|
||||
)
|
||||
.order_by(AppAnnotationHitHistory.created_at.desc())
|
||||
.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
|
||||
)
|
||||
return annotation_hit_histories.items, annotation_hit_histories.total
|
||||
|
||||
@classmethod
|
||||
@ -341,15 +346,21 @@ class AppAnnotationService:
|
||||
return annotation
|
||||
|
||||
@classmethod
|
||||
def add_annotation_history(cls, annotation_id: str, app_id: str, annotation_question: str,
|
||||
annotation_content: str, query: str, user_id: str,
|
||||
message_id: str, from_source: str, score: float):
|
||||
def add_annotation_history(
|
||||
cls,
|
||||
annotation_id: str,
|
||||
app_id: str,
|
||||
annotation_question: str,
|
||||
annotation_content: str,
|
||||
query: str,
|
||||
user_id: str,
|
||||
message_id: str,
|
||||
from_source: str,
|
||||
score: float,
|
||||
):
|
||||
# add hit count to annotation
|
||||
db.session.query(MessageAnnotation).filter(
|
||||
MessageAnnotation.id == annotation_id
|
||||
).update(
|
||||
{MessageAnnotation.hit_count: MessageAnnotation.hit_count + 1},
|
||||
synchronize_session=False
|
||||
db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).update(
|
||||
{MessageAnnotation.hit_count: MessageAnnotation.hit_count + 1}, synchronize_session=False
|
||||
)
|
||||
|
||||
annotation_hit_history = AppAnnotationHitHistory(
|
||||
@ -361,7 +372,7 @@ class AppAnnotationService:
|
||||
score=score,
|
||||
message_id=message_id,
|
||||
annotation_question=annotation_question,
|
||||
annotation_content=annotation_content
|
||||
annotation_content=annotation_content,
|
||||
)
|
||||
db.session.add(annotation_hit_history)
|
||||
db.session.commit()
|
||||
@ -369,17 +380,18 @@ class AppAnnotationService:
|
||||
@classmethod
|
||||
def get_app_annotation_setting_by_app_id(cls, app_id: str):
|
||||
# get app info
|
||||
app = db.session.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.tenant_id == current_user.current_tenant_id,
|
||||
App.status == 'normal'
|
||||
).first()
|
||||
app = (
|
||||
db.session.query(App)
|
||||
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
.first()
|
||||
)
|
||||
|
||||
if not app:
|
||||
raise NotFound("App not found")
|
||||
|
||||
annotation_setting = db.session.query(AppAnnotationSetting).filter(
|
||||
AppAnnotationSetting.app_id == app_id).first()
|
||||
annotation_setting = (
|
||||
db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first()
|
||||
)
|
||||
if annotation_setting:
|
||||
collection_binding_detail = annotation_setting.collection_binding_detail
|
||||
return {
|
||||
@ -388,32 +400,34 @@ class AppAnnotationService:
|
||||
"score_threshold": annotation_setting.score_threshold,
|
||||
"embedding_model": {
|
||||
"embedding_provider_name": collection_binding_detail.provider_name,
|
||||
"embedding_model_name": collection_binding_detail.model_name
|
||||
}
|
||||
"embedding_model_name": collection_binding_detail.model_name,
|
||||
},
|
||||
}
|
||||
return {
|
||||
"enabled": False
|
||||
}
|
||||
return {"enabled": False}
|
||||
|
||||
@classmethod
|
||||
def update_app_annotation_setting(cls, app_id: str, annotation_setting_id: str, args: dict):
|
||||
# get app info
|
||||
app = db.session.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.tenant_id == current_user.current_tenant_id,
|
||||
App.status == 'normal'
|
||||
).first()
|
||||
app = (
|
||||
db.session.query(App)
|
||||
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
|
||||
.first()
|
||||
)
|
||||
|
||||
if not app:
|
||||
raise NotFound("App not found")
|
||||
|
||||
annotation_setting = db.session.query(AppAnnotationSetting).filter(
|
||||
AppAnnotationSetting.app_id == app_id,
|
||||
AppAnnotationSetting.id == annotation_setting_id,
|
||||
).first()
|
||||
annotation_setting = (
|
||||
db.session.query(AppAnnotationSetting)
|
||||
.filter(
|
||||
AppAnnotationSetting.app_id == app_id,
|
||||
AppAnnotationSetting.id == annotation_setting_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not annotation_setting:
|
||||
raise NotFound("App annotation not found")
|
||||
annotation_setting.score_threshold = args['score_threshold']
|
||||
annotation_setting.score_threshold = args["score_threshold"]
|
||||
annotation_setting.updated_user_id = current_user.id
|
||||
annotation_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
|
||||
db.session.add(annotation_setting)
|
||||
@ -427,6 +441,6 @@ class AppAnnotationService:
|
||||
"score_threshold": annotation_setting.score_threshold,
|
||||
"embedding_model": {
|
||||
"embedding_provider_name": collection_binding_detail.provider_name,
|
||||
"embedding_model_name": collection_binding_detail.model_name
|
||||
}
|
||||
"embedding_model_name": collection_binding_detail.model_name,
|
||||
},
|
||||
}
|
||||
|
@ -5,13 +5,14 @@ from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint
|
||||
|
||||
|
||||
class APIBasedExtensionService:
|
||||
|
||||
@staticmethod
|
||||
def get_all_by_tenant_id(tenant_id: str) -> list[APIBasedExtension]:
|
||||
extension_list = db.session.query(APIBasedExtension) \
|
||||
.filter_by(tenant_id=tenant_id) \
|
||||
.order_by(APIBasedExtension.created_at.desc()) \
|
||||
.all()
|
||||
extension_list = (
|
||||
db.session.query(APIBasedExtension)
|
||||
.filter_by(tenant_id=tenant_id)
|
||||
.order_by(APIBasedExtension.created_at.desc())
|
||||
.all()
|
||||
)
|
||||
|
||||
for extension in extension_list:
|
||||
extension.api_key = decrypt_token(extension.tenant_id, extension.api_key)
|
||||
@ -35,10 +36,12 @@ class APIBasedExtensionService:
|
||||
|
||||
@staticmethod
|
||||
def get_with_tenant_id(tenant_id: str, api_based_extension_id: str) -> APIBasedExtension:
|
||||
extension = db.session.query(APIBasedExtension) \
|
||||
.filter_by(tenant_id=tenant_id) \
|
||||
.filter_by(id=api_based_extension_id) \
|
||||
extension = (
|
||||
db.session.query(APIBasedExtension)
|
||||
.filter_by(tenant_id=tenant_id)
|
||||
.filter_by(id=api_based_extension_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not extension:
|
||||
raise ValueError("API based extension is not found")
|
||||
@ -55,20 +58,24 @@ class APIBasedExtensionService:
|
||||
|
||||
if not extension_data.id:
|
||||
# case one: check new data, name must be unique
|
||||
is_name_existed = db.session.query(APIBasedExtension) \
|
||||
.filter_by(tenant_id=extension_data.tenant_id) \
|
||||
.filter_by(name=extension_data.name) \
|
||||
is_name_existed = (
|
||||
db.session.query(APIBasedExtension)
|
||||
.filter_by(tenant_id=extension_data.tenant_id)
|
||||
.filter_by(name=extension_data.name)
|
||||
.first()
|
||||
)
|
||||
|
||||
if is_name_existed:
|
||||
raise ValueError("name must be unique, it is already existed")
|
||||
else:
|
||||
# case two: check existing data, name must be unique
|
||||
is_name_existed = db.session.query(APIBasedExtension) \
|
||||
.filter_by(tenant_id=extension_data.tenant_id) \
|
||||
.filter_by(name=extension_data.name) \
|
||||
.filter(APIBasedExtension.id != extension_data.id) \
|
||||
is_name_existed = (
|
||||
db.session.query(APIBasedExtension)
|
||||
.filter_by(tenant_id=extension_data.tenant_id)
|
||||
.filter_by(name=extension_data.name)
|
||||
.filter(APIBasedExtension.id != extension_data.id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if is_name_existed:
|
||||
raise ValueError("name must be unique, it is already existed")
|
||||
@ -92,7 +99,7 @@ class APIBasedExtensionService:
|
||||
try:
|
||||
client = APIBasedExtensionRequestor(extension_data.api_endpoint, extension_data.api_key)
|
||||
resp = client.request(point=APIBasedExtensionPoint.PING, params={})
|
||||
if resp.get('result') != 'pong':
|
||||
if resp.get("result") != "pong":
|
||||
raise ValueError(resp)
|
||||
except Exception as e:
|
||||
raise ValueError("connection error: {}".format(e))
|
||||
|
@ -75,43 +75,44 @@ class AppDslService:
|
||||
# check or repair dsl version
|
||||
import_data = cls._check_or_fix_dsl(import_data)
|
||||
|
||||
app_data = import_data.get('app')
|
||||
app_data = import_data.get("app")
|
||||
if not app_data:
|
||||
raise ValueError("Missing app in data argument")
|
||||
|
||||
# get app basic info
|
||||
name = args.get("name") if args.get("name") else app_data.get('name')
|
||||
description = args.get("description") if args.get("description") else app_data.get('description', '')
|
||||
icon_type = args.get("icon_type") if args.get("icon_type") else app_data.get('icon_type')
|
||||
icon = args.get("icon") if args.get("icon") else app_data.get('icon')
|
||||
icon_background = args.get("icon_background") if args.get("icon_background") \
|
||||
else app_data.get('icon_background')
|
||||
name = args.get("name") if args.get("name") else app_data.get("name")
|
||||
description = args.get("description") if args.get("description") else app_data.get("description", "")
|
||||
icon_type = args.get("icon_type") if args.get("icon_type") else app_data.get("icon_type")
|
||||
icon = args.get("icon") if args.get("icon") else app_data.get("icon")
|
||||
icon_background = (
|
||||
args.get("icon_background") if args.get("icon_background") else app_data.get("icon_background")
|
||||
)
|
||||
|
||||
# import dsl and create app
|
||||
app_mode = AppMode.value_of(app_data.get('mode'))
|
||||
app_mode = AppMode.value_of(app_data.get("mode"))
|
||||
if app_mode in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]:
|
||||
app = cls._import_and_create_new_workflow_based_app(
|
||||
tenant_id=tenant_id,
|
||||
app_mode=app_mode,
|
||||
workflow_data=import_data.get('workflow'),
|
||||
workflow_data=import_data.get("workflow"),
|
||||
account=account,
|
||||
name=name,
|
||||
description=description,
|
||||
icon_type=icon_type,
|
||||
icon=icon,
|
||||
icon_background=icon_background
|
||||
icon_background=icon_background,
|
||||
)
|
||||
elif app_mode in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION]:
|
||||
app = cls._import_and_create_new_model_config_based_app(
|
||||
tenant_id=tenant_id,
|
||||
app_mode=app_mode,
|
||||
model_config_data=import_data.get('model_config'),
|
||||
model_config_data=import_data.get("model_config"),
|
||||
account=account,
|
||||
name=name,
|
||||
description=description,
|
||||
icon_type=icon_type,
|
||||
icon=icon,
|
||||
icon_background=icon_background
|
||||
icon_background=icon_background,
|
||||
)
|
||||
else:
|
||||
raise ValueError("Invalid app mode")
|
||||
@ -134,27 +135,26 @@ class AppDslService:
|
||||
# check or repair dsl version
|
||||
import_data = cls._check_or_fix_dsl(import_data)
|
||||
|
||||
app_data = import_data.get('app')
|
||||
app_data = import_data.get("app")
|
||||
if not app_data:
|
||||
raise ValueError("Missing app in data argument")
|
||||
|
||||
# import dsl and overwrite app
|
||||
app_mode = AppMode.value_of(app_data.get('mode'))
|
||||
app_mode = AppMode.value_of(app_data.get("mode"))
|
||||
if app_mode not in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]:
|
||||
raise ValueError("Only support import workflow in advanced-chat or workflow app.")
|
||||
|
||||
if app_data.get('mode') != app_model.mode:
|
||||
raise ValueError(
|
||||
f"App mode {app_data.get('mode')} is not matched with current app mode {app_mode.value}")
|
||||
if app_data.get("mode") != app_model.mode:
|
||||
raise ValueError(f"App mode {app_data.get('mode')} is not matched with current app mode {app_mode.value}")
|
||||
|
||||
return cls._import_and_overwrite_workflow_based_app(
|
||||
app_model=app_model,
|
||||
workflow_data=import_data.get('workflow'),
|
||||
workflow_data=import_data.get("workflow"),
|
||||
account=account,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def export_dsl(cls, app_model: App, include_secret:bool = False) -> str:
|
||||
def export_dsl(cls, app_model: App, include_secret: bool = False) -> str:
|
||||
"""
|
||||
Export app
|
||||
:param app_model: App instance
|
||||
@ -168,14 +168,16 @@ class AppDslService:
|
||||
"app": {
|
||||
"name": app_model.name,
|
||||
"mode": app_model.mode,
|
||||
"icon": '🤖' if app_model.icon_type == 'image' else app_model.icon,
|
||||
"icon_background": '#FFEAD5' if app_model.icon_type == 'image' else app_model.icon_background,
|
||||
"description": app_model.description
|
||||
}
|
||||
"icon": "🤖" if app_model.icon_type == "image" else app_model.icon,
|
||||
"icon_background": "#FFEAD5" if app_model.icon_type == "image" else app_model.icon_background,
|
||||
"description": app_model.description,
|
||||
},
|
||||
}
|
||||
|
||||
if app_mode in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]:
|
||||
cls._append_workflow_export_data(export_data=export_data, app_model=app_model, include_secret=include_secret)
|
||||
cls._append_workflow_export_data(
|
||||
export_data=export_data, app_model=app_model, include_secret=include_secret
|
||||
)
|
||||
else:
|
||||
cls._append_model_config_export_data(export_data, app_model)
|
||||
|
||||
@ -188,31 +190,35 @@ class AppDslService:
|
||||
|
||||
:param import_data: import data
|
||||
"""
|
||||
if not import_data.get('version'):
|
||||
import_data['version'] = "0.1.0"
|
||||
if not import_data.get("version"):
|
||||
import_data["version"] = "0.1.0"
|
||||
|
||||
if not import_data.get('kind') or import_data.get('kind') != "app":
|
||||
import_data['kind'] = "app"
|
||||
if not import_data.get("kind") or import_data.get("kind") != "app":
|
||||
import_data["kind"] = "app"
|
||||
|
||||
if import_data.get('version') != current_dsl_version:
|
||||
if import_data.get("version") != current_dsl_version:
|
||||
# Currently only one DSL version, so no difference checks or compatibility fixes will be performed.
|
||||
logger.warning(f"DSL version {import_data.get('version')} is not compatible "
|
||||
f"with current version {current_dsl_version}, related to "
|
||||
f"Dify version {dsl_to_dify_version_mapping.get(current_dsl_version)}.")
|
||||
logger.warning(
|
||||
f"DSL version {import_data.get('version')} is not compatible "
|
||||
f"with current version {current_dsl_version}, related to "
|
||||
f"Dify version {dsl_to_dify_version_mapping.get(current_dsl_version)}."
|
||||
)
|
||||
|
||||
return import_data
|
||||
|
||||
@classmethod
|
||||
def _import_and_create_new_workflow_based_app(cls,
|
||||
tenant_id: str,
|
||||
app_mode: AppMode,
|
||||
workflow_data: dict,
|
||||
account: Account,
|
||||
name: str,
|
||||
description: str,
|
||||
icon_type: str,
|
||||
icon: str,
|
||||
icon_background: str) -> App:
|
||||
def _import_and_create_new_workflow_based_app(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
app_mode: AppMode,
|
||||
workflow_data: dict,
|
||||
account: Account,
|
||||
name: str,
|
||||
description: str,
|
||||
icon_type: str,
|
||||
icon: str,
|
||||
icon_background: str,
|
||||
) -> App:
|
||||
"""
|
||||
Import app dsl and create new workflow based app
|
||||
|
||||
@ -227,8 +233,7 @@ class AppDslService:
|
||||
:param icon_background: app icon background
|
||||
"""
|
||||
if not workflow_data:
|
||||
raise ValueError("Missing workflow in data argument "
|
||||
"when app mode is advanced-chat or workflow")
|
||||
raise ValueError("Missing workflow in data argument " "when app mode is advanced-chat or workflow")
|
||||
|
||||
app = cls._create_app(
|
||||
tenant_id=tenant_id,
|
||||
@ -238,37 +243,32 @@ class AppDslService:
|
||||
description=description,
|
||||
icon_type=icon_type,
|
||||
icon=icon,
|
||||
icon_background=icon_background
|
||||
icon_background=icon_background,
|
||||
)
|
||||
|
||||
# init draft workflow
|
||||
environment_variables_list = workflow_data.get('environment_variables') or []
|
||||
environment_variables_list = workflow_data.get("environment_variables") or []
|
||||
environment_variables = [factory.build_variable_from_mapping(obj) for obj in environment_variables_list]
|
||||
conversation_variables_list = workflow_data.get('conversation_variables') or []
|
||||
conversation_variables_list = workflow_data.get("conversation_variables") or []
|
||||
conversation_variables = [factory.build_variable_from_mapping(obj) for obj in conversation_variables_list]
|
||||
workflow_service = WorkflowService()
|
||||
draft_workflow = workflow_service.sync_draft_workflow(
|
||||
app_model=app,
|
||||
graph=workflow_data.get('graph', {}),
|
||||
features=workflow_data.get('../core/app/features', {}),
|
||||
graph=workflow_data.get("graph", {}),
|
||||
features=workflow_data.get("../core/app/features", {}),
|
||||
unique_hash=None,
|
||||
account=account,
|
||||
environment_variables=environment_variables,
|
||||
conversation_variables=conversation_variables,
|
||||
)
|
||||
workflow_service.publish_workflow(
|
||||
app_model=app,
|
||||
account=account,
|
||||
draft_workflow=draft_workflow
|
||||
)
|
||||
workflow_service.publish_workflow(app_model=app, account=account, draft_workflow=draft_workflow)
|
||||
|
||||
return app
|
||||
|
||||
@classmethod
|
||||
def _import_and_overwrite_workflow_based_app(cls,
|
||||
app_model: App,
|
||||
workflow_data: dict,
|
||||
account: Account) -> Workflow:
|
||||
def _import_and_overwrite_workflow_based_app(
|
||||
cls, app_model: App, workflow_data: dict, account: Account
|
||||
) -> Workflow:
|
||||
"""
|
||||
Import app dsl and overwrite workflow based app
|
||||
|
||||
@ -277,8 +277,7 @@ class AppDslService:
|
||||
:param account: Account instance
|
||||
"""
|
||||
if not workflow_data:
|
||||
raise ValueError("Missing workflow in data argument "
|
||||
"when app mode is advanced-chat or workflow")
|
||||
raise ValueError("Missing workflow in data argument " "when app mode is advanced-chat or workflow")
|
||||
|
||||
# fetch draft workflow by app_model
|
||||
workflow_service = WorkflowService()
|
||||
@ -289,14 +288,14 @@ class AppDslService:
|
||||
unique_hash = None
|
||||
|
||||
# sync draft workflow
|
||||
environment_variables_list = workflow_data.get('environment_variables') or []
|
||||
environment_variables_list = workflow_data.get("environment_variables") or []
|
||||
environment_variables = [factory.build_variable_from_mapping(obj) for obj in environment_variables_list]
|
||||
conversation_variables_list = workflow_data.get('conversation_variables') or []
|
||||
conversation_variables_list = workflow_data.get("conversation_variables") or []
|
||||
conversation_variables = [factory.build_variable_from_mapping(obj) for obj in conversation_variables_list]
|
||||
draft_workflow = workflow_service.sync_draft_workflow(
|
||||
app_model=app_model,
|
||||
graph=workflow_data.get('graph', {}),
|
||||
features=workflow_data.get('features', {}),
|
||||
graph=workflow_data.get("graph", {}),
|
||||
features=workflow_data.get("features", {}),
|
||||
unique_hash=unique_hash,
|
||||
account=account,
|
||||
environment_variables=environment_variables,
|
||||
@ -306,16 +305,18 @@ class AppDslService:
|
||||
return draft_workflow
|
||||
|
||||
@classmethod
|
||||
def _import_and_create_new_model_config_based_app(cls,
|
||||
tenant_id: str,
|
||||
app_mode: AppMode,
|
||||
model_config_data: dict,
|
||||
account: Account,
|
||||
name: str,
|
||||
description: str,
|
||||
icon_type: str,
|
||||
icon: str,
|
||||
icon_background: str) -> App:
|
||||
def _import_and_create_new_model_config_based_app(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
app_mode: AppMode,
|
||||
model_config_data: dict,
|
||||
account: Account,
|
||||
name: str,
|
||||
description: str,
|
||||
icon_type: str,
|
||||
icon: str,
|
||||
icon_background: str,
|
||||
) -> App:
|
||||
"""
|
||||
Import app dsl and create new model config based app
|
||||
|
||||
@ -329,8 +330,7 @@ class AppDslService:
|
||||
:param icon_background: app icon background
|
||||
"""
|
||||
if not model_config_data:
|
||||
raise ValueError("Missing model_config in data argument "
|
||||
"when app mode is chat, agent-chat or completion")
|
||||
raise ValueError("Missing model_config in data argument " "when app mode is chat, agent-chat or completion")
|
||||
|
||||
app = cls._create_app(
|
||||
tenant_id=tenant_id,
|
||||
@ -340,7 +340,7 @@ class AppDslService:
|
||||
description=description,
|
||||
icon_type=icon_type,
|
||||
icon=icon,
|
||||
icon_background=icon_background
|
||||
icon_background=icon_background,
|
||||
)
|
||||
|
||||
app_model_config = AppModelConfig()
|
||||
@ -352,23 +352,22 @@ class AppDslService:
|
||||
|
||||
app.app_model_config_id = app_model_config.id
|
||||
|
||||
app_model_config_was_updated.send(
|
||||
app,
|
||||
app_model_config=app_model_config
|
||||
)
|
||||
app_model_config_was_updated.send(app, app_model_config=app_model_config)
|
||||
|
||||
return app
|
||||
|
||||
@classmethod
|
||||
def _create_app(cls,
|
||||
tenant_id: str,
|
||||
app_mode: AppMode,
|
||||
account: Account,
|
||||
name: str,
|
||||
description: str,
|
||||
icon_type: str,
|
||||
icon: str,
|
||||
icon_background: str) -> App:
|
||||
def _create_app(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
app_mode: AppMode,
|
||||
account: Account,
|
||||
name: str,
|
||||
description: str,
|
||||
icon_type: str,
|
||||
icon: str,
|
||||
icon_background: str,
|
||||
) -> App:
|
||||
"""
|
||||
Create new app
|
||||
|
||||
@ -390,7 +389,7 @@ class AppDslService:
|
||||
icon=icon,
|
||||
icon_background=icon_background,
|
||||
enable_site=True,
|
||||
enable_api=True
|
||||
enable_api=True,
|
||||
)
|
||||
|
||||
db.session.add(app)
|
||||
@ -412,7 +411,7 @@ class AppDslService:
|
||||
if not workflow:
|
||||
raise ValueError("Missing draft workflow configuration, please check.")
|
||||
|
||||
export_data['workflow'] = workflow.to_dict(include_secret=include_secret)
|
||||
export_data["workflow"] = workflow.to_dict(include_secret=include_secret)
|
||||
|
||||
@classmethod
|
||||
def _append_model_config_export_data(cls, export_data: dict, app_model: App) -> None:
|
||||
@ -425,4 +424,4 @@ class AppDslService:
|
||||
if not app_model_config:
|
||||
raise ValueError("Missing app configuration, please check.")
|
||||
|
||||
export_data['model_config'] = app_model_config.to_dict()
|
||||
export_data["model_config"] = app_model_config.to_dict()
|
||||
|
@ -14,14 +14,15 @@ from services.workflow_service import WorkflowService
|
||||
|
||||
|
||||
class AppGenerateService:
|
||||
|
||||
@classmethod
|
||||
def generate(cls, app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
args: Any,
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
):
|
||||
def generate(
|
||||
cls,
|
||||
app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
args: Any,
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
):
|
||||
"""
|
||||
App Content Generate
|
||||
:param app_model: app model
|
||||
@ -37,51 +38,54 @@ class AppGenerateService:
|
||||
try:
|
||||
request_id = rate_limit.enter(request_id)
|
||||
if app_model.mode == AppMode.COMPLETION.value:
|
||||
return rate_limit.generate(CompletionAppGenerator().generate(
|
||||
app_model=app_model,
|
||||
user=user,
|
||||
args=args,
|
||||
invoke_from=invoke_from,
|
||||
stream=streaming
|
||||
), request_id)
|
||||
return rate_limit.generate(
|
||||
CompletionAppGenerator().generate(
|
||||
app_model=app_model, user=user, args=args, invoke_from=invoke_from, stream=streaming
|
||||
),
|
||||
request_id,
|
||||
)
|
||||
elif app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent:
|
||||
return rate_limit.generate(AgentChatAppGenerator().generate(
|
||||
app_model=app_model,
|
||||
user=user,
|
||||
args=args,
|
||||
invoke_from=invoke_from,
|
||||
stream=streaming
|
||||
), request_id)
|
||||
return rate_limit.generate(
|
||||
AgentChatAppGenerator().generate(
|
||||
app_model=app_model, user=user, args=args, invoke_from=invoke_from, stream=streaming
|
||||
),
|
||||
request_id,
|
||||
)
|
||||
elif app_model.mode == AppMode.CHAT.value:
|
||||
return rate_limit.generate(ChatAppGenerator().generate(
|
||||
app_model=app_model,
|
||||
user=user,
|
||||
args=args,
|
||||
invoke_from=invoke_from,
|
||||
stream=streaming
|
||||
), request_id)
|
||||
return rate_limit.generate(
|
||||
ChatAppGenerator().generate(
|
||||
app_model=app_model, user=user, args=args, invoke_from=invoke_from, stream=streaming
|
||||
),
|
||||
request_id,
|
||||
)
|
||||
elif app_model.mode == AppMode.ADVANCED_CHAT.value:
|
||||
workflow = cls._get_workflow(app_model, invoke_from)
|
||||
return rate_limit.generate(AdvancedChatAppGenerator().generate(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
args=args,
|
||||
invoke_from=invoke_from,
|
||||
stream=streaming
|
||||
), request_id)
|
||||
return rate_limit.generate(
|
||||
AdvancedChatAppGenerator().generate(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
args=args,
|
||||
invoke_from=invoke_from,
|
||||
stream=streaming,
|
||||
),
|
||||
request_id,
|
||||
)
|
||||
elif app_model.mode == AppMode.WORKFLOW.value:
|
||||
workflow = cls._get_workflow(app_model, invoke_from)
|
||||
return rate_limit.generate(WorkflowAppGenerator().generate(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
args=args,
|
||||
invoke_from=invoke_from,
|
||||
stream=streaming
|
||||
), request_id)
|
||||
return rate_limit.generate(
|
||||
WorkflowAppGenerator().generate(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
args=args,
|
||||
invoke_from=invoke_from,
|
||||
stream=streaming,
|
||||
),
|
||||
request_id,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f'Invalid app mode {app_model.mode}')
|
||||
raise ValueError(f"Invalid app mode {app_model.mode}")
|
||||
finally:
|
||||
if not streaming:
|
||||
rate_limit.exit(request_id)
|
||||
@ -94,38 +98,31 @@ class AppGenerateService:
|
||||
return max_active_requests
|
||||
|
||||
@classmethod
|
||||
def generate_single_iteration(cls, app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
node_id: str,
|
||||
args: Any,
|
||||
streaming: bool = True):
|
||||
def generate_single_iteration(
|
||||
cls, app_model: App, user: Union[Account, EndUser], node_id: str, args: Any, streaming: bool = True
|
||||
):
|
||||
if app_model.mode == AppMode.ADVANCED_CHAT.value:
|
||||
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
|
||||
return AdvancedChatAppGenerator().single_iteration_generate(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
node_id=node_id,
|
||||
user=user,
|
||||
args=args,
|
||||
stream=streaming
|
||||
app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, stream=streaming
|
||||
)
|
||||
elif app_model.mode == AppMode.WORKFLOW.value:
|
||||
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
|
||||
return WorkflowAppGenerator().single_iteration_generate(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
node_id=node_id,
|
||||
user=user,
|
||||
args=args,
|
||||
stream=streaming
|
||||
app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, stream=streaming
|
||||
)
|
||||
else:
|
||||
raise ValueError(f'Invalid app mode {app_model.mode}')
|
||||
raise ValueError(f"Invalid app mode {app_model.mode}")
|
||||
|
||||
@classmethod
|
||||
def generate_more_like_this(cls, app_model: App, user: Union[Account, EndUser],
|
||||
message_id: str, invoke_from: InvokeFrom, streaming: bool = True) \
|
||||
-> Union[dict, Generator]:
|
||||
def generate_more_like_this(
|
||||
cls,
|
||||
app_model: App,
|
||||
user: Union[Account, EndUser],
|
||||
message_id: str,
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
) -> Union[dict, Generator]:
|
||||
"""
|
||||
Generate more like this
|
||||
:param app_model: app model
|
||||
@ -136,11 +133,7 @@ class AppGenerateService:
|
||||
:return:
|
||||
"""
|
||||
return CompletionAppGenerator().generate_more_like_this(
|
||||
app_model=app_model,
|
||||
message_id=message_id,
|
||||
user=user,
|
||||
invoke_from=invoke_from,
|
||||
stream=streaming
|
||||
app_model=app_model, message_id=message_id, user=user, invoke_from=invoke_from, stream=streaming
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -157,12 +150,12 @@ class AppGenerateService:
|
||||
workflow = workflow_service.get_draft_workflow(app_model=app_model)
|
||||
|
||||
if not workflow:
|
||||
raise ValueError('Workflow not initialized')
|
||||
raise ValueError("Workflow not initialized")
|
||||
else:
|
||||
# fetch published workflow by app_model
|
||||
workflow = workflow_service.get_published_workflow(app_model=app_model)
|
||||
|
||||
if not workflow:
|
||||
raise ValueError('Workflow not published')
|
||||
raise ValueError("Workflow not published")
|
||||
|
||||
return workflow
|
||||
|
@ -5,7 +5,6 @@ from models.model import AppMode
|
||||
|
||||
|
||||
class AppModelConfigService:
|
||||
|
||||
@classmethod
|
||||
def validate_configuration(cls, tenant_id: str, config: dict, app_mode: AppMode) -> dict:
|
||||
if app_mode == AppMode.CHAT:
|
||||
|
@ -33,27 +33,22 @@ class AppService:
|
||||
:param args: request args
|
||||
:return:
|
||||
"""
|
||||
filters = [
|
||||
App.tenant_id == tenant_id,
|
||||
App.is_universal == False
|
||||
]
|
||||
filters = [App.tenant_id == tenant_id, App.is_universal == False]
|
||||
|
||||
if args['mode'] == 'workflow':
|
||||
if args["mode"] == "workflow":
|
||||
filters.append(App.mode.in_([AppMode.WORKFLOW.value, AppMode.COMPLETION.value]))
|
||||
elif args['mode'] == 'chat':
|
||||
elif args["mode"] == "chat":
|
||||
filters.append(App.mode.in_([AppMode.CHAT.value, AppMode.ADVANCED_CHAT.value]))
|
||||
elif args['mode'] == 'agent-chat':
|
||||
elif args["mode"] == "agent-chat":
|
||||
filters.append(App.mode == AppMode.AGENT_CHAT.value)
|
||||
elif args['mode'] == 'channel':
|
||||
elif args["mode"] == "channel":
|
||||
filters.append(App.mode == AppMode.CHANNEL.value)
|
||||
|
||||
if args.get('name'):
|
||||
name = args['name'][:30]
|
||||
filters.append(App.name.ilike(f'%{name}%'))
|
||||
if args.get('tag_ids'):
|
||||
target_ids = TagService.get_target_ids_by_tag_ids('app',
|
||||
tenant_id,
|
||||
args['tag_ids'])
|
||||
if args.get("name"):
|
||||
name = args["name"][:30]
|
||||
filters.append(App.name.ilike(f"%{name}%"))
|
||||
if args.get("tag_ids"):
|
||||
target_ids = TagService.get_target_ids_by_tag_ids("app", tenant_id, args["tag_ids"])
|
||||
if target_ids:
|
||||
filters.append(App.id.in_(target_ids))
|
||||
else:
|
||||
@ -61,9 +56,9 @@ class AppService:
|
||||
|
||||
app_models = db.paginate(
|
||||
db.select(App).where(*filters).order_by(App.created_at.desc()),
|
||||
page=args['page'],
|
||||
per_page=args['limit'],
|
||||
error_out=False
|
||||
page=args["page"],
|
||||
per_page=args["limit"],
|
||||
error_out=False,
|
||||
)
|
||||
|
||||
return app_models
|
||||
@ -75,21 +70,20 @@ class AppService:
|
||||
:param args: request args
|
||||
:param account: Account instance
|
||||
"""
|
||||
app_mode = AppMode.value_of(args['mode'])
|
||||
app_mode = AppMode.value_of(args["mode"])
|
||||
app_template = default_app_templates[app_mode]
|
||||
|
||||
# get model config
|
||||
default_model_config = app_template.get('model_config')
|
||||
default_model_config = app_template.get("model_config")
|
||||
default_model_config = default_model_config.copy() if default_model_config else None
|
||||
if default_model_config and 'model' in default_model_config:
|
||||
if default_model_config and "model" in default_model_config:
|
||||
# get model provider
|
||||
model_manager = ModelManager()
|
||||
|
||||
# get default model instance
|
||||
try:
|
||||
model_instance = model_manager.get_default_model_instance(
|
||||
tenant_id=account.current_tenant_id,
|
||||
model_type=ModelType.LLM
|
||||
tenant_id=account.current_tenant_id, model_type=ModelType.LLM
|
||||
)
|
||||
except (ProviderTokenNotInitError, LLMBadRequestError):
|
||||
model_instance = None
|
||||
@ -98,39 +92,41 @@ class AppService:
|
||||
model_instance = None
|
||||
|
||||
if model_instance:
|
||||
if model_instance.model == default_model_config['model']['name'] and model_instance.provider == default_model_config['model']['provider']:
|
||||
default_model_dict = default_model_config['model']
|
||||
if (
|
||||
model_instance.model == default_model_config["model"]["name"]
|
||||
and model_instance.provider == default_model_config["model"]["provider"]
|
||||
):
|
||||
default_model_dict = default_model_config["model"]
|
||||
else:
|
||||
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
|
||||
model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
|
||||
|
||||
default_model_dict = {
|
||||
'provider': model_instance.provider,
|
||||
'name': model_instance.model,
|
||||
'mode': model_schema.model_properties.get(ModelPropertyKey.MODE),
|
||||
'completion_params': {}
|
||||
"provider": model_instance.provider,
|
||||
"name": model_instance.model,
|
||||
"mode": model_schema.model_properties.get(ModelPropertyKey.MODE),
|
||||
"completion_params": {},
|
||||
}
|
||||
else:
|
||||
provider, model = model_manager.get_default_provider_model_name(
|
||||
tenant_id=account.current_tenant_id,
|
||||
model_type=ModelType.LLM
|
||||
tenant_id=account.current_tenant_id, model_type=ModelType.LLM
|
||||
)
|
||||
default_model_config['model']['provider'] = provider
|
||||
default_model_config['model']['name'] = model
|
||||
default_model_dict = default_model_config['model']
|
||||
default_model_config["model"]["provider"] = provider
|
||||
default_model_config["model"]["name"] = model
|
||||
default_model_dict = default_model_config["model"]
|
||||
|
||||
default_model_config['model'] = json.dumps(default_model_dict)
|
||||
default_model_config["model"] = json.dumps(default_model_dict)
|
||||
|
||||
app = App(**app_template['app'])
|
||||
app.name = args['name']
|
||||
app.description = args.get('description', '')
|
||||
app.mode = args['mode']
|
||||
app.icon_type = args.get('icon_type', 'emoji')
|
||||
app.icon = args['icon']
|
||||
app.icon_background = args['icon_background']
|
||||
app = App(**app_template["app"])
|
||||
app.name = args["name"]
|
||||
app.description = args.get("description", "")
|
||||
app.mode = args["mode"]
|
||||
app.icon_type = args.get("icon_type", "emoji")
|
||||
app.icon = args["icon"]
|
||||
app.icon_background = args["icon_background"]
|
||||
app.tenant_id = tenant_id
|
||||
app.api_rph = args.get('api_rph', 0)
|
||||
app.api_rpm = args.get('api_rpm', 0)
|
||||
app.api_rph = args.get("api_rph", 0)
|
||||
app.api_rpm = args.get("api_rpm", 0)
|
||||
|
||||
db.session.add(app)
|
||||
db.session.flush()
|
||||
@ -158,7 +154,7 @@ class AppService:
|
||||
model_config: AppModelConfig = app.app_model_config
|
||||
agent_mode = model_config.agent_mode_dict
|
||||
# decrypt agent tool parameters if it's secret-input
|
||||
for tool in agent_mode.get('tools') or []:
|
||||
for tool in agent_mode.get("tools") or []:
|
||||
if not isinstance(tool, dict) or len(tool.keys()) <= 3:
|
||||
continue
|
||||
agent_tool_entity = AgentToolEntity(**tool)
|
||||
@ -174,7 +170,7 @@ class AppService:
|
||||
tool_runtime=tool_runtime,
|
||||
provider_name=agent_tool_entity.provider_id,
|
||||
provider_type=agent_tool_entity.provider_type,
|
||||
identity_id=f'AGENT.{app.id}'
|
||||
identity_id=f"AGENT.{app.id}",
|
||||
)
|
||||
|
||||
# get decrypted parameters
|
||||
@ -185,7 +181,7 @@ class AppService:
|
||||
masked_parameter = {}
|
||||
|
||||
# override tool parameters
|
||||
tool['tool_parameters'] = masked_parameter
|
||||
tool["tool_parameters"] = masked_parameter
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
@ -215,12 +211,12 @@ class AppService:
|
||||
:param args: request args
|
||||
:return: App instance
|
||||
"""
|
||||
app.name = args.get('name')
|
||||
app.description = args.get('description', '')
|
||||
app.max_active_requests = args.get('max_active_requests')
|
||||
app.icon_type = args.get('icon_type', 'emoji')
|
||||
app.icon = args.get('icon')
|
||||
app.icon_background = args.get('icon_background')
|
||||
app.name = args.get("name")
|
||||
app.description = args.get("description", "")
|
||||
app.max_active_requests = args.get("max_active_requests")
|
||||
app.icon_type = args.get("icon_type", "emoji")
|
||||
app.icon = args.get("icon")
|
||||
app.icon_background = args.get("icon_background")
|
||||
app.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
db.session.commit()
|
||||
|
||||
@ -298,10 +294,7 @@ class AppService:
|
||||
db.session.commit()
|
||||
|
||||
# Trigger asynchronous deletion of app and related data
|
||||
remove_app_and_related_data_task.delay(
|
||||
tenant_id=app.tenant_id,
|
||||
app_id=app.id
|
||||
)
|
||||
remove_app_and_related_data_task.delay(tenant_id=app.tenant_id, app_id=app.id)
|
||||
|
||||
def get_app_meta(self, app_model: App) -> dict:
|
||||
"""
|
||||
@ -311,9 +304,7 @@ class AppService:
|
||||
"""
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
|
||||
meta = {
|
||||
'tool_icons': {}
|
||||
}
|
||||
meta = {"tool_icons": {}}
|
||||
|
||||
if app_mode in [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]:
|
||||
workflow = app_model.workflow
|
||||
@ -321,17 +312,19 @@ class AppService:
|
||||
return meta
|
||||
|
||||
graph = workflow.graph_dict
|
||||
nodes = graph.get('nodes', [])
|
||||
nodes = graph.get("nodes", [])
|
||||
tools = []
|
||||
for node in nodes:
|
||||
if node.get('data', {}).get('type') == 'tool':
|
||||
node_data = node.get('data', {})
|
||||
tools.append({
|
||||
'provider_type': node_data.get('provider_type'),
|
||||
'provider_id': node_data.get('provider_id'),
|
||||
'tool_name': node_data.get('tool_name'),
|
||||
'tool_parameters': {}
|
||||
})
|
||||
if node.get("data", {}).get("type") == "tool":
|
||||
node_data = node.get("data", {})
|
||||
tools.append(
|
||||
{
|
||||
"provider_type": node_data.get("provider_type"),
|
||||
"provider_id": node_data.get("provider_id"),
|
||||
"tool_name": node_data.get("tool_name"),
|
||||
"tool_parameters": {},
|
||||
}
|
||||
)
|
||||
else:
|
||||
app_model_config: AppModelConfig = app_model.app_model_config
|
||||
|
||||
@ -341,30 +334,26 @@ class AppService:
|
||||
agent_config = app_model_config.agent_mode_dict or {}
|
||||
|
||||
# get all tools
|
||||
tools = agent_config.get('tools', [])
|
||||
tools = agent_config.get("tools", [])
|
||||
|
||||
url_prefix = (dify_config.CONSOLE_API_URL
|
||||
+ "/console/api/workspaces/current/tool-provider/builtin/")
|
||||
url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/builtin/"
|
||||
|
||||
for tool in tools:
|
||||
keys = list(tool.keys())
|
||||
if len(keys) >= 4:
|
||||
# current tool standard
|
||||
provider_type = tool.get('provider_type')
|
||||
provider_id = tool.get('provider_id')
|
||||
tool_name = tool.get('tool_name')
|
||||
if provider_type == 'builtin':
|
||||
meta['tool_icons'][tool_name] = url_prefix + provider_id + '/icon'
|
||||
elif provider_type == 'api':
|
||||
provider_type = tool.get("provider_type")
|
||||
provider_id = tool.get("provider_id")
|
||||
tool_name = tool.get("tool_name")
|
||||
if provider_type == "builtin":
|
||||
meta["tool_icons"][tool_name] = url_prefix + provider_id + "/icon"
|
||||
elif provider_type == "api":
|
||||
try:
|
||||
provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
|
||||
ApiToolProvider.id == provider_id
|
||||
).first()
|
||||
meta['tool_icons'][tool_name] = json.loads(provider.icon)
|
||||
provider: ApiToolProvider = (
|
||||
db.session.query(ApiToolProvider).filter(ApiToolProvider.id == provider_id).first()
|
||||
)
|
||||
meta["tool_icons"][tool_name] = json.loads(provider.icon)
|
||||
except:
|
||||
meta['tool_icons'][tool_name] = {
|
||||
"background": "#252525",
|
||||
"content": "\ud83d\ude01"
|
||||
}
|
||||
meta["tool_icons"][tool_name] = {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
|
||||
return meta
|
||||
|
@ -17,7 +17,7 @@ from services.errors.audio import (
|
||||
|
||||
FILE_SIZE = 30
|
||||
FILE_SIZE_LIMIT = FILE_SIZE * 1024 * 1024
|
||||
ALLOWED_EXTENSIONS = ['mp3', 'mp4', 'mpeg', 'mpga', 'm4a', 'wav', 'webm', 'amr']
|
||||
ALLOWED_EXTENSIONS = ["mp3", "mp4", "mpeg", "mpga", "m4a", "wav", "webm", "amr"]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -31,19 +31,19 @@ class AudioService:
|
||||
raise ValueError("Speech to text is not enabled")
|
||||
|
||||
features_dict = workflow.features_dict
|
||||
if 'speech_to_text' not in features_dict or not features_dict['speech_to_text'].get('enabled'):
|
||||
if "speech_to_text" not in features_dict or not features_dict["speech_to_text"].get("enabled"):
|
||||
raise ValueError("Speech to text is not enabled")
|
||||
else:
|
||||
app_model_config: AppModelConfig = app_model.app_model_config
|
||||
|
||||
if not app_model_config.speech_to_text_dict['enabled']:
|
||||
if not app_model_config.speech_to_text_dict["enabled"]:
|
||||
raise ValueError("Speech to text is not enabled")
|
||||
|
||||
if file is None:
|
||||
raise NoAudioUploadedServiceError()
|
||||
|
||||
extension = file.mimetype
|
||||
if extension not in [f'audio/{ext}' for ext in ALLOWED_EXTENSIONS]:
|
||||
if extension not in [f"audio/{ext}" for ext in ALLOWED_EXTENSIONS]:
|
||||
raise UnsupportedAudioTypeServiceError()
|
||||
|
||||
file_content = file.read()
|
||||
@ -55,20 +55,25 @@ class AudioService:
|
||||
|
||||
model_manager = ModelManager()
|
||||
model_instance = model_manager.get_default_model_instance(
|
||||
tenant_id=app_model.tenant_id,
|
||||
model_type=ModelType.SPEECH2TEXT
|
||||
tenant_id=app_model.tenant_id, model_type=ModelType.SPEECH2TEXT
|
||||
)
|
||||
if model_instance is None:
|
||||
raise ProviderNotSupportSpeechToTextServiceError()
|
||||
|
||||
buffer = io.BytesIO(file_content)
|
||||
buffer.name = 'temp.mp3'
|
||||
buffer.name = "temp.mp3"
|
||||
|
||||
return {"text": model_instance.invoke_speech2text(file=buffer, user=end_user)}
|
||||
|
||||
@classmethod
|
||||
def transcript_tts(cls, app_model: App, text: Optional[str] = None,
|
||||
voice: Optional[str] = None, end_user: Optional[str] = None, message_id: Optional[str] = None):
|
||||
def transcript_tts(
|
||||
cls,
|
||||
app_model: App,
|
||||
text: Optional[str] = None,
|
||||
voice: Optional[str] = None,
|
||||
end_user: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
):
|
||||
from collections.abc import Generator
|
||||
|
||||
from flask import Response, stream_with_context
|
||||
@ -84,65 +89,56 @@ class AudioService:
|
||||
raise ValueError("TTS is not enabled")
|
||||
|
||||
features_dict = workflow.features_dict
|
||||
if 'text_to_speech' not in features_dict or not features_dict['text_to_speech'].get('enabled'):
|
||||
if "text_to_speech" not in features_dict or not features_dict["text_to_speech"].get("enabled"):
|
||||
raise ValueError("TTS is not enabled")
|
||||
|
||||
voice = features_dict['text_to_speech'].get('voice') if voice is None else voice
|
||||
voice = features_dict["text_to_speech"].get("voice") if voice is None else voice
|
||||
else:
|
||||
text_to_speech_dict = app_model.app_model_config.text_to_speech_dict
|
||||
|
||||
if not text_to_speech_dict.get('enabled'):
|
||||
if not text_to_speech_dict.get("enabled"):
|
||||
raise ValueError("TTS is not enabled")
|
||||
|
||||
voice = text_to_speech_dict.get('voice') if voice is None else voice
|
||||
voice = text_to_speech_dict.get("voice") if voice is None else voice
|
||||
|
||||
model_manager = ModelManager()
|
||||
model_instance = model_manager.get_default_model_instance(
|
||||
tenant_id=app_model.tenant_id,
|
||||
model_type=ModelType.TTS
|
||||
tenant_id=app_model.tenant_id, model_type=ModelType.TTS
|
||||
)
|
||||
try:
|
||||
if not voice:
|
||||
voices = model_instance.get_tts_voices()
|
||||
if voices:
|
||||
voice = voices[0].get('value')
|
||||
voice = voices[0].get("value")
|
||||
else:
|
||||
raise ValueError("Sorry, no voice available.")
|
||||
|
||||
return model_instance.invoke_tts(
|
||||
content_text=text_content.strip(),
|
||||
user=end_user,
|
||||
tenant_id=app_model.tenant_id,
|
||||
voice=voice
|
||||
content_text=text_content.strip(), user=end_user, tenant_id=app_model.tenant_id, voice=voice
|
||||
)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
if message_id:
|
||||
message = db.session.query(Message).filter(
|
||||
Message.id == message_id
|
||||
).first()
|
||||
if message.answer == '' and message.status == 'normal':
|
||||
message = db.session.query(Message).filter(Message.id == message_id).first()
|
||||
if message.answer == "" and message.status == "normal":
|
||||
return None
|
||||
|
||||
else:
|
||||
response = invoke_tts(message.answer, app_model=app_model, voice=voice)
|
||||
if isinstance(response, Generator):
|
||||
return Response(stream_with_context(response), content_type='audio/mpeg')
|
||||
return Response(stream_with_context(response), content_type="audio/mpeg")
|
||||
return response
|
||||
else:
|
||||
response = invoke_tts(text, app_model, voice)
|
||||
if isinstance(response, Generator):
|
||||
return Response(stream_with_context(response), content_type='audio/mpeg')
|
||||
return Response(stream_with_context(response), content_type="audio/mpeg")
|
||||
return response
|
||||
|
||||
@classmethod
|
||||
def transcript_tts_voices(cls, tenant_id: str, language: str):
|
||||
model_manager = ModelManager()
|
||||
model_instance = model_manager.get_default_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
model_type=ModelType.TTS
|
||||
)
|
||||
model_instance = model_manager.get_default_model_instance(tenant_id=tenant_id, model_type=ModelType.TTS)
|
||||
if model_instance is None:
|
||||
raise ProviderNotSupportTextToSpeechServiceError()
|
||||
|
||||
|
@ -1,14 +1,12 @@
|
||||
|
||||
from services.auth.firecrawl import FirecrawlAuth
|
||||
|
||||
|
||||
class ApiKeyAuthFactory:
|
||||
|
||||
def __init__(self, provider: str, credentials: dict):
|
||||
if provider == 'firecrawl':
|
||||
if provider == "firecrawl":
|
||||
self.auth = FirecrawlAuth(credentials)
|
||||
else:
|
||||
raise ValueError('Invalid provider')
|
||||
raise ValueError("Invalid provider")
|
||||
|
||||
def validate_credentials(self):
|
||||
return self.auth.validate_credentials()
|
||||
|
@ -7,39 +7,43 @@ from services.auth.api_key_auth_factory import ApiKeyAuthFactory
|
||||
|
||||
|
||||
class ApiKeyAuthService:
|
||||
|
||||
@staticmethod
|
||||
def get_provider_auth_list(tenant_id: str) -> list:
|
||||
data_source_api_key_bindings = db.session.query(DataSourceApiKeyAuthBinding).filter(
|
||||
DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
|
||||
DataSourceApiKeyAuthBinding.disabled.is_(False)
|
||||
).all()
|
||||
data_source_api_key_bindings = (
|
||||
db.session.query(DataSourceApiKeyAuthBinding)
|
||||
.filter(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.disabled.is_(False))
|
||||
.all()
|
||||
)
|
||||
return data_source_api_key_bindings
|
||||
|
||||
@staticmethod
|
||||
def create_provider_auth(tenant_id: str, args: dict):
|
||||
auth_result = ApiKeyAuthFactory(args['provider'], args['credentials']).validate_credentials()
|
||||
auth_result = ApiKeyAuthFactory(args["provider"], args["credentials"]).validate_credentials()
|
||||
if auth_result:
|
||||
# Encrypt the api key
|
||||
api_key = encrypter.encrypt_token(tenant_id, args['credentials']['config']['api_key'])
|
||||
args['credentials']['config']['api_key'] = api_key
|
||||
api_key = encrypter.encrypt_token(tenant_id, args["credentials"]["config"]["api_key"])
|
||||
args["credentials"]["config"]["api_key"] = api_key
|
||||
|
||||
data_source_api_key_binding = DataSourceApiKeyAuthBinding()
|
||||
data_source_api_key_binding.tenant_id = tenant_id
|
||||
data_source_api_key_binding.category = args['category']
|
||||
data_source_api_key_binding.provider = args['provider']
|
||||
data_source_api_key_binding.credentials = json.dumps(args['credentials'], ensure_ascii=False)
|
||||
data_source_api_key_binding.category = args["category"]
|
||||
data_source_api_key_binding.provider = args["provider"]
|
||||
data_source_api_key_binding.credentials = json.dumps(args["credentials"], ensure_ascii=False)
|
||||
db.session.add(data_source_api_key_binding)
|
||||
db.session.commit()
|
||||
|
||||
@staticmethod
|
||||
def get_auth_credentials(tenant_id: str, category: str, provider: str):
|
||||
data_source_api_key_bindings = db.session.query(DataSourceApiKeyAuthBinding).filter(
|
||||
DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
|
||||
DataSourceApiKeyAuthBinding.category == category,
|
||||
DataSourceApiKeyAuthBinding.provider == provider,
|
||||
DataSourceApiKeyAuthBinding.disabled.is_(False)
|
||||
).first()
|
||||
data_source_api_key_bindings = (
|
||||
db.session.query(DataSourceApiKeyAuthBinding)
|
||||
.filter(
|
||||
DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
|
||||
DataSourceApiKeyAuthBinding.category == category,
|
||||
DataSourceApiKeyAuthBinding.provider == provider,
|
||||
DataSourceApiKeyAuthBinding.disabled.is_(False),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not data_source_api_key_bindings:
|
||||
return None
|
||||
credentials = json.loads(data_source_api_key_bindings.credentials)
|
||||
@ -47,24 +51,24 @@ class ApiKeyAuthService:
|
||||
|
||||
@staticmethod
|
||||
def delete_provider_auth(tenant_id: str, binding_id: str):
|
||||
data_source_api_key_binding = db.session.query(DataSourceApiKeyAuthBinding).filter(
|
||||
DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
|
||||
DataSourceApiKeyAuthBinding.id == binding_id
|
||||
).first()
|
||||
data_source_api_key_binding = (
|
||||
db.session.query(DataSourceApiKeyAuthBinding)
|
||||
.filter(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.id == binding_id)
|
||||
.first()
|
||||
)
|
||||
if data_source_api_key_binding:
|
||||
db.session.delete(data_source_api_key_binding)
|
||||
db.session.commit()
|
||||
|
||||
@classmethod
|
||||
def validate_api_key_auth_args(cls, args):
|
||||
if 'category' not in args or not args['category']:
|
||||
raise ValueError('category is required')
|
||||
if 'provider' not in args or not args['provider']:
|
||||
raise ValueError('provider is required')
|
||||
if 'credentials' not in args or not args['credentials']:
|
||||
raise ValueError('credentials is required')
|
||||
if not isinstance(args['credentials'], dict):
|
||||
raise ValueError('credentials must be a dictionary')
|
||||
if 'auth_type' not in args['credentials'] or not args['credentials']['auth_type']:
|
||||
raise ValueError('auth_type is required')
|
||||
|
||||
if "category" not in args or not args["category"]:
|
||||
raise ValueError("category is required")
|
||||
if "provider" not in args or not args["provider"]:
|
||||
raise ValueError("provider is required")
|
||||
if "credentials" not in args or not args["credentials"]:
|
||||
raise ValueError("credentials is required")
|
||||
if not isinstance(args["credentials"], dict):
|
||||
raise ValueError("credentials must be a dictionary")
|
||||
if "auth_type" not in args["credentials"] or not args["credentials"]["auth_type"]:
|
||||
raise ValueError("auth_type is required")
|
||||
|
@ -8,49 +8,40 @@ from services.auth.api_key_auth_base import ApiKeyAuthBase
|
||||
class FirecrawlAuth(ApiKeyAuthBase):
|
||||
def __init__(self, credentials: dict):
|
||||
super().__init__(credentials)
|
||||
auth_type = credentials.get('auth_type')
|
||||
if auth_type != 'bearer':
|
||||
raise ValueError('Invalid auth type, Firecrawl auth type must be Bearer')
|
||||
self.api_key = credentials.get('config').get('api_key', None)
|
||||
self.base_url = credentials.get('config').get('base_url', 'https://api.firecrawl.dev')
|
||||
auth_type = credentials.get("auth_type")
|
||||
if auth_type != "bearer":
|
||||
raise ValueError("Invalid auth type, Firecrawl auth type must be Bearer")
|
||||
self.api_key = credentials.get("config").get("api_key", None)
|
||||
self.base_url = credentials.get("config").get("base_url", "https://api.firecrawl.dev")
|
||||
|
||||
if not self.api_key:
|
||||
raise ValueError('No API key provided')
|
||||
raise ValueError("No API key provided")
|
||||
|
||||
def validate_credentials(self):
|
||||
headers = self._prepare_headers()
|
||||
options = {
|
||||
'url': 'https://example.com',
|
||||
'crawlerOptions': {
|
||||
'excludes': [],
|
||||
'includes': [],
|
||||
'limit': 1
|
||||
},
|
||||
'pageOptions': {
|
||||
'onlyMainContent': True
|
||||
}
|
||||
"url": "https://example.com",
|
||||
"crawlerOptions": {"excludes": [], "includes": [], "limit": 1},
|
||||
"pageOptions": {"onlyMainContent": True},
|
||||
}
|
||||
response = self._post_request(f'{self.base_url}/v0/crawl', options, headers)
|
||||
response = self._post_request(f"{self.base_url}/v0/crawl", options, headers)
|
||||
if response.status_code == 200:
|
||||
return True
|
||||
else:
|
||||
self._handle_error(response)
|
||||
|
||||
def _prepare_headers(self):
|
||||
return {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': f'Bearer {self.api_key}'
|
||||
}
|
||||
return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
|
||||
|
||||
def _post_request(self, url, data, headers):
|
||||
return requests.post(url, headers=headers, json=data)
|
||||
|
||||
def _handle_error(self, response):
|
||||
if response.status_code in [402, 409, 500]:
|
||||
error_message = response.json().get('error', 'Unknown error occurred')
|
||||
raise Exception(f'Failed to authorize. Status code: {response.status_code}. Error: {error_message}')
|
||||
error_message = response.json().get("error", "Unknown error occurred")
|
||||
raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}")
|
||||
else:
|
||||
if response.text:
|
||||
error_message = json.loads(response.text).get('error', 'Unknown error occurred')
|
||||
raise Exception(f'Failed to authorize. Status code: {response.status_code}. Error: {error_message}')
|
||||
raise Exception(f'Unexpected error occurred while trying to authorize. Status code: {response.status_code}')
|
||||
error_message = json.loads(response.text).get("error", "Unknown error occurred")
|
||||
raise Exception(f"Failed to authorize. Status code: {response.status_code}. Error: {error_message}")
|
||||
raise Exception(f"Unexpected error occurred while trying to authorize. Status code: {response.status_code}")
|
||||
|
@ -7,58 +7,40 @@ from models.account import TenantAccountJoin, TenantAccountRole
|
||||
|
||||
|
||||
class BillingService:
|
||||
base_url = os.environ.get('BILLING_API_URL', 'BILLING_API_URL')
|
||||
secret_key = os.environ.get('BILLING_API_SECRET_KEY', 'BILLING_API_SECRET_KEY')
|
||||
base_url = os.environ.get("BILLING_API_URL", "BILLING_API_URL")
|
||||
secret_key = os.environ.get("BILLING_API_SECRET_KEY", "BILLING_API_SECRET_KEY")
|
||||
|
||||
@classmethod
|
||||
def get_info(cls, tenant_id: str):
|
||||
params = {'tenant_id': tenant_id}
|
||||
params = {"tenant_id": tenant_id}
|
||||
|
||||
billing_info = cls._send_request('GET', '/subscription/info', params=params)
|
||||
billing_info = cls._send_request("GET", "/subscription/info", params=params)
|
||||
|
||||
return billing_info
|
||||
|
||||
@classmethod
|
||||
def get_subscription(cls, plan: str,
|
||||
interval: str,
|
||||
prefilled_email: str = '',
|
||||
tenant_id: str = ''):
|
||||
params = {
|
||||
'plan': plan,
|
||||
'interval': interval,
|
||||
'prefilled_email': prefilled_email,
|
||||
'tenant_id': tenant_id
|
||||
}
|
||||
return cls._send_request('GET', '/subscription/payment-link', params=params)
|
||||
def get_subscription(cls, plan: str, interval: str, prefilled_email: str = "", tenant_id: str = ""):
|
||||
params = {"plan": plan, "interval": interval, "prefilled_email": prefilled_email, "tenant_id": tenant_id}
|
||||
return cls._send_request("GET", "/subscription/payment-link", params=params)
|
||||
|
||||
@classmethod
|
||||
def get_model_provider_payment_link(cls,
|
||||
provider_name: str,
|
||||
tenant_id: str,
|
||||
account_id: str,
|
||||
prefilled_email: str):
|
||||
def get_model_provider_payment_link(cls, provider_name: str, tenant_id: str, account_id: str, prefilled_email: str):
|
||||
params = {
|
||||
'provider_name': provider_name,
|
||||
'tenant_id': tenant_id,
|
||||
'account_id': account_id,
|
||||
'prefilled_email': prefilled_email
|
||||
"provider_name": provider_name,
|
||||
"tenant_id": tenant_id,
|
||||
"account_id": account_id,
|
||||
"prefilled_email": prefilled_email,
|
||||
}
|
||||
return cls._send_request('GET', '/model-provider/payment-link', params=params)
|
||||
return cls._send_request("GET", "/model-provider/payment-link", params=params)
|
||||
|
||||
@classmethod
|
||||
def get_invoices(cls, prefilled_email: str = '', tenant_id: str = ''):
|
||||
params = {
|
||||
'prefilled_email': prefilled_email,
|
||||
'tenant_id': tenant_id
|
||||
}
|
||||
return cls._send_request('GET', '/invoices', params=params)
|
||||
def get_invoices(cls, prefilled_email: str = "", tenant_id: str = ""):
|
||||
params = {"prefilled_email": prefilled_email, "tenant_id": tenant_id}
|
||||
return cls._send_request("GET", "/invoices", params=params)
|
||||
|
||||
@classmethod
|
||||
def _send_request(cls, method, endpoint, json=None, params=None):
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Billing-Api-Secret-Key": cls.secret_key
|
||||
}
|
||||
headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key}
|
||||
|
||||
url = f"{cls.base_url}{endpoint}"
|
||||
response = requests.request(method, url, json=json, params=params, headers=headers)
|
||||
@ -69,10 +51,11 @@ class BillingService:
|
||||
def is_tenant_owner_or_admin(current_user):
|
||||
tenant_id = current_user.current_tenant_id
|
||||
|
||||
join = db.session.query(TenantAccountJoin).filter(
|
||||
TenantAccountJoin.tenant_id == tenant_id,
|
||||
TenantAccountJoin.account_id == current_user.id
|
||||
).first()
|
||||
join = (
|
||||
db.session.query(TenantAccountJoin)
|
||||
.filter(TenantAccountJoin.tenant_id == tenant_id, TenantAccountJoin.account_id == current_user.id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not TenantAccountRole.is_privileged_role(join.role):
|
||||
raise ValueError('Only team owner or team admin can perform this action')
|
||||
raise ValueError("Only team owner or team admin can perform this action")
|
||||
|
@ -2,12 +2,15 @@ from extensions.ext_code_based_extension import code_based_extension
|
||||
|
||||
|
||||
class CodeBasedExtensionService:
|
||||
|
||||
@staticmethod
|
||||
def get_code_based_extension(module: str) -> list[dict]:
|
||||
module_extensions = code_based_extension.module_extensions(module)
|
||||
return [{
|
||||
'name': module_extension.name,
|
||||
'label': module_extension.label,
|
||||
'form_schema': module_extension.form_schema
|
||||
} for module_extension in module_extensions if not module_extension.builtin]
|
||||
return [
|
||||
{
|
||||
"name": module_extension.name,
|
||||
"label": module_extension.label,
|
||||
"form_schema": module_extension.form_schema,
|
||||
}
|
||||
for module_extension in module_extensions
|
||||
if not module_extension.builtin
|
||||
]
|
||||
|
@ -15,22 +15,27 @@ from services.errors.message import MessageNotExistsError
|
||||
|
||||
class ConversationService:
|
||||
@classmethod
|
||||
def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account, EndUser]],
|
||||
last_id: Optional[str], limit: int,
|
||||
invoke_from: InvokeFrom,
|
||||
include_ids: Optional[list] = None,
|
||||
exclude_ids: Optional[list] = None,
|
||||
sort_by: str = '-updated_at') -> InfiniteScrollPagination:
|
||||
def pagination_by_last_id(
|
||||
cls,
|
||||
app_model: App,
|
||||
user: Optional[Union[Account, EndUser]],
|
||||
last_id: Optional[str],
|
||||
limit: int,
|
||||
invoke_from: InvokeFrom,
|
||||
include_ids: Optional[list] = None,
|
||||
exclude_ids: Optional[list] = None,
|
||||
sort_by: str = "-updated_at",
|
||||
) -> InfiniteScrollPagination:
|
||||
if not user:
|
||||
return InfiniteScrollPagination(data=[], limit=limit, has_more=False)
|
||||
|
||||
base_query = db.session.query(Conversation).filter(
|
||||
Conversation.is_deleted == False,
|
||||
Conversation.app_id == app_model.id,
|
||||
Conversation.from_source == ('api' if isinstance(user, EndUser) else 'console'),
|
||||
Conversation.from_source == ("api" if isinstance(user, EndUser) else "console"),
|
||||
Conversation.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
|
||||
Conversation.from_account_id == (user.id if isinstance(user, Account) else None),
|
||||
or_(Conversation.invoke_from.is_(None), Conversation.invoke_from == invoke_from.value)
|
||||
or_(Conversation.invoke_from.is_(None), Conversation.invoke_from == invoke_from.value),
|
||||
)
|
||||
|
||||
if include_ids is not None:
|
||||
@ -58,28 +63,26 @@ class ConversationService:
|
||||
has_more = False
|
||||
if len(conversations) == limit:
|
||||
current_page_last_conversation = conversations[-1]
|
||||
rest_filter_condition = cls._build_filter_condition(sort_field, sort_direction,
|
||||
current_page_last_conversation, is_next_page=True)
|
||||
rest_filter_condition = cls._build_filter_condition(
|
||||
sort_field, sort_direction, current_page_last_conversation, is_next_page=True
|
||||
)
|
||||
rest_count = base_query.filter(rest_filter_condition).count()
|
||||
|
||||
if rest_count > 0:
|
||||
has_more = True
|
||||
|
||||
return InfiniteScrollPagination(
|
||||
data=conversations,
|
||||
limit=limit,
|
||||
has_more=has_more
|
||||
)
|
||||
return InfiniteScrollPagination(data=conversations, limit=limit, has_more=has_more)
|
||||
|
||||
@classmethod
|
||||
def _get_sort_params(cls, sort_by: str) -> tuple[str, callable]:
|
||||
if sort_by.startswith('-'):
|
||||
if sort_by.startswith("-"):
|
||||
return sort_by[1:], desc
|
||||
return sort_by, asc
|
||||
|
||||
@classmethod
|
||||
def _build_filter_condition(cls, sort_field: str, sort_direction: callable, reference_conversation: Conversation,
|
||||
is_next_page: bool = False):
|
||||
def _build_filter_condition(
|
||||
cls, sort_field: str, sort_direction: callable, reference_conversation: Conversation, is_next_page: bool = False
|
||||
):
|
||||
field_value = getattr(reference_conversation, sort_field)
|
||||
if (sort_direction == desc and not is_next_page) or (sort_direction == asc and is_next_page):
|
||||
return getattr(Conversation, sort_field) < field_value
|
||||
@ -87,8 +90,14 @@ class ConversationService:
|
||||
return getattr(Conversation, sort_field) > field_value
|
||||
|
||||
@classmethod
|
||||
def rename(cls, app_model: App, conversation_id: str,
|
||||
user: Optional[Union[Account, EndUser]], name: str, auto_generate: bool):
|
||||
def rename(
|
||||
cls,
|
||||
app_model: App,
|
||||
conversation_id: str,
|
||||
user: Optional[Union[Account, EndUser]],
|
||||
name: str,
|
||||
auto_generate: bool,
|
||||
):
|
||||
conversation = cls.get_conversation(app_model, conversation_id, user)
|
||||
|
||||
if auto_generate:
|
||||
@ -103,11 +112,12 @@ class ConversationService:
|
||||
@classmethod
|
||||
def auto_generate_name(cls, app_model: App, conversation: Conversation):
|
||||
# get conversation first message
|
||||
message = db.session.query(Message) \
|
||||
.filter(
|
||||
Message.app_id == app_model.id,
|
||||
Message.conversation_id == conversation.id
|
||||
).order_by(Message.created_at.asc()).first()
|
||||
message = (
|
||||
db.session.query(Message)
|
||||
.filter(Message.app_id == app_model.id, Message.conversation_id == conversation.id)
|
||||
.order_by(Message.created_at.asc())
|
||||
.first()
|
||||
)
|
||||
|
||||
if not message:
|
||||
raise MessageNotExistsError()
|
||||
@ -127,15 +137,18 @@ class ConversationService:
|
||||
|
||||
@classmethod
|
||||
def get_conversation(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]):
|
||||
conversation = db.session.query(Conversation) \
|
||||
conversation = (
|
||||
db.session.query(Conversation)
|
||||
.filter(
|
||||
Conversation.id == conversation_id,
|
||||
Conversation.app_id == app_model.id,
|
||||
Conversation.from_source == ('api' if isinstance(user, EndUser) else 'console'),
|
||||
Conversation.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
|
||||
Conversation.from_account_id == (user.id if isinstance(user, Account) else None),
|
||||
Conversation.is_deleted == False
|
||||
).first()
|
||||
Conversation.id == conversation_id,
|
||||
Conversation.app_id == app_model.id,
|
||||
Conversation.from_source == ("api" if isinstance(user, EndUser) else "console"),
|
||||
Conversation.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
|
||||
Conversation.from_account_id == (user.id if isinstance(user, Account) else None),
|
||||
Conversation.is_deleted == False,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not conversation:
|
||||
raise ConversationNotExistsError()
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -4,15 +4,12 @@ import requests
|
||||
|
||||
|
||||
class EnterpriseRequest:
|
||||
base_url = os.environ.get('ENTERPRISE_API_URL', 'ENTERPRISE_API_URL')
|
||||
secret_key = os.environ.get('ENTERPRISE_API_SECRET_KEY', 'ENTERPRISE_API_SECRET_KEY')
|
||||
base_url = os.environ.get("ENTERPRISE_API_URL", "ENTERPRISE_API_URL")
|
||||
secret_key = os.environ.get("ENTERPRISE_API_SECRET_KEY", "ENTERPRISE_API_SECRET_KEY")
|
||||
|
||||
@classmethod
|
||||
def send_request(cls, method, endpoint, json=None, params=None):
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Enterprise-Api-Secret-Key": cls.secret_key
|
||||
}
|
||||
headers = {"Content-Type": "application/json", "Enterprise-Api-Secret-Key": cls.secret_key}
|
||||
|
||||
url = f"{cls.base_url}{endpoint}"
|
||||
response = requests.request(method, url, json=json, params=params, headers=headers)
|
||||
|
@ -2,11 +2,10 @@ from services.enterprise.base import EnterpriseRequest
|
||||
|
||||
|
||||
class EnterpriseService:
|
||||
|
||||
@classmethod
|
||||
def get_info(cls):
|
||||
return EnterpriseRequest.send_request('GET', '/info')
|
||||
return EnterpriseRequest.send_request("GET", "/info")
|
||||
|
||||
@classmethod
|
||||
def get_app_web_sso_enabled(cls, app_code):
|
||||
return EnterpriseRequest.send_request('GET', f'/app-sso-setting?appCode={app_code}')
|
||||
return EnterpriseRequest.send_request("GET", f"/app-sso-setting?appCode={app_code}")
|
||||
|
@ -22,14 +22,16 @@ class CustomConfigurationStatus(Enum):
|
||||
"""
|
||||
Enum class for custom configuration status.
|
||||
"""
|
||||
ACTIVE = 'active'
|
||||
NO_CONFIGURE = 'no-configure'
|
||||
|
||||
ACTIVE = "active"
|
||||
NO_CONFIGURE = "no-configure"
|
||||
|
||||
|
||||
class CustomConfigurationResponse(BaseModel):
|
||||
"""
|
||||
Model class for provider custom configuration response.
|
||||
"""
|
||||
|
||||
status: CustomConfigurationStatus
|
||||
|
||||
|
||||
@ -37,6 +39,7 @@ class SystemConfigurationResponse(BaseModel):
|
||||
"""
|
||||
Model class for provider system configuration response.
|
||||
"""
|
||||
|
||||
enabled: bool
|
||||
current_quota_type: Optional[ProviderQuotaType] = None
|
||||
quota_configurations: list[QuotaConfiguration] = []
|
||||
@ -46,6 +49,7 @@ class ProviderResponse(BaseModel):
|
||||
"""
|
||||
Model class for provider response.
|
||||
"""
|
||||
|
||||
provider: str
|
||||
label: I18nObject
|
||||
description: Optional[I18nObject] = None
|
||||
@ -67,18 +71,15 @@ class ProviderResponse(BaseModel):
|
||||
def __init__(self, **data) -> None:
|
||||
super().__init__(**data)
|
||||
|
||||
url_prefix = (dify_config.CONSOLE_API_URL
|
||||
+ f"/console/api/workspaces/current/model-providers/{self.provider}")
|
||||
url_prefix = dify_config.CONSOLE_API_URL + f"/console/api/workspaces/current/model-providers/{self.provider}"
|
||||
if self.icon_small is not None:
|
||||
self.icon_small = I18nObject(
|
||||
en_US=f"{url_prefix}/icon_small/en_US",
|
||||
zh_Hans=f"{url_prefix}/icon_small/zh_Hans"
|
||||
en_US=f"{url_prefix}/icon_small/en_US", zh_Hans=f"{url_prefix}/icon_small/zh_Hans"
|
||||
)
|
||||
|
||||
if self.icon_large is not None:
|
||||
self.icon_large = I18nObject(
|
||||
en_US=f"{url_prefix}/icon_large/en_US",
|
||||
zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
|
||||
en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
|
||||
)
|
||||
|
||||
|
||||
@ -86,6 +87,7 @@ class ProviderWithModelsResponse(BaseModel):
|
||||
"""
|
||||
Model class for provider with models response.
|
||||
"""
|
||||
|
||||
provider: str
|
||||
label: I18nObject
|
||||
icon_small: Optional[I18nObject] = None
|
||||
@ -96,18 +98,15 @@ class ProviderWithModelsResponse(BaseModel):
|
||||
def __init__(self, **data) -> None:
|
||||
super().__init__(**data)
|
||||
|
||||
url_prefix = (dify_config.CONSOLE_API_URL
|
||||
+ f"/console/api/workspaces/current/model-providers/{self.provider}")
|
||||
url_prefix = dify_config.CONSOLE_API_URL + f"/console/api/workspaces/current/model-providers/{self.provider}"
|
||||
if self.icon_small is not None:
|
||||
self.icon_small = I18nObject(
|
||||
en_US=f"{url_prefix}/icon_small/en_US",
|
||||
zh_Hans=f"{url_prefix}/icon_small/zh_Hans"
|
||||
en_US=f"{url_prefix}/icon_small/en_US", zh_Hans=f"{url_prefix}/icon_small/zh_Hans"
|
||||
)
|
||||
|
||||
if self.icon_large is not None:
|
||||
self.icon_large = I18nObject(
|
||||
en_US=f"{url_prefix}/icon_large/en_US",
|
||||
zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
|
||||
en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
|
||||
)
|
||||
|
||||
|
||||
@ -119,18 +118,15 @@ class SimpleProviderEntityResponse(SimpleProviderEntity):
|
||||
def __init__(self, **data) -> None:
|
||||
super().__init__(**data)
|
||||
|
||||
url_prefix = (dify_config.CONSOLE_API_URL
|
||||
+ f"/console/api/workspaces/current/model-providers/{self.provider}")
|
||||
url_prefix = dify_config.CONSOLE_API_URL + f"/console/api/workspaces/current/model-providers/{self.provider}"
|
||||
if self.icon_small is not None:
|
||||
self.icon_small = I18nObject(
|
||||
en_US=f"{url_prefix}/icon_small/en_US",
|
||||
zh_Hans=f"{url_prefix}/icon_small/zh_Hans"
|
||||
en_US=f"{url_prefix}/icon_small/en_US", zh_Hans=f"{url_prefix}/icon_small/zh_Hans"
|
||||
)
|
||||
|
||||
if self.icon_large is not None:
|
||||
self.icon_large = I18nObject(
|
||||
en_US=f"{url_prefix}/icon_large/en_US",
|
||||
zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
|
||||
en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
|
||||
)
|
||||
|
||||
|
||||
@ -138,6 +134,7 @@ class DefaultModelResponse(BaseModel):
|
||||
"""
|
||||
Default model entity.
|
||||
"""
|
||||
|
||||
model: str
|
||||
model_type: ModelType
|
||||
provider: SimpleProviderEntityResponse
|
||||
@ -150,6 +147,7 @@ class ModelWithProviderEntityResponse(ModelWithProviderEntity):
|
||||
"""
|
||||
Model with provider entity.
|
||||
"""
|
||||
|
||||
provider: SimpleProviderEntityResponse
|
||||
|
||||
def __init__(self, model: ModelWithProviderEntity) -> None:
|
||||
|
@ -55,4 +55,3 @@ class RoleAlreadyAssignedError(BaseServiceError):
|
||||
|
||||
class RateLimitExceededError(BaseServiceError):
|
||||
pass
|
||||
|
||||
|
@ -1,3 +1,3 @@
|
||||
class BaseServiceError(Exception):
|
||||
def __init__(self, description: str = None):
|
||||
self.description = description
|
||||
self.description = description
|
||||
|
@ -6,8 +6,8 @@ from services.enterprise.enterprise_service import EnterpriseService
|
||||
|
||||
|
||||
class SubscriptionModel(BaseModel):
|
||||
plan: str = 'sandbox'
|
||||
interval: str = ''
|
||||
plan: str = "sandbox"
|
||||
interval: str = ""
|
||||
|
||||
|
||||
class BillingModel(BaseModel):
|
||||
@ -27,7 +27,7 @@ class FeatureModel(BaseModel):
|
||||
vector_space: LimitationModel = LimitationModel(size=0, limit=5)
|
||||
annotation_quota_limit: LimitationModel = LimitationModel(size=0, limit=10)
|
||||
documents_upload_quota: LimitationModel = LimitationModel(size=0, limit=50)
|
||||
docs_processing: str = 'standard'
|
||||
docs_processing: str = "standard"
|
||||
can_replace_logo: bool = False
|
||||
model_load_balancing_enabled: bool = False
|
||||
dataset_operator_enabled: bool = False
|
||||
@ -38,13 +38,13 @@ class FeatureModel(BaseModel):
|
||||
|
||||
class SystemFeatureModel(BaseModel):
|
||||
sso_enforced_for_signin: bool = False
|
||||
sso_enforced_for_signin_protocol: str = ''
|
||||
sso_enforced_for_signin_protocol: str = ""
|
||||
sso_enforced_for_web: bool = False
|
||||
sso_enforced_for_web_protocol: str = ''
|
||||
sso_enforced_for_web_protocol: str = ""
|
||||
enable_web_sso_switch_component: bool = False
|
||||
|
||||
class FeatureService:
|
||||
|
||||
class FeatureService:
|
||||
@classmethod
|
||||
def get_features(cls, tenant_id: str) -> FeatureModel:
|
||||
features = FeatureModel()
|
||||
@ -76,44 +76,44 @@ class FeatureService:
|
||||
def _fulfill_params_from_billing_api(cls, features: FeatureModel, tenant_id: str):
|
||||
billing_info = BillingService.get_info(tenant_id)
|
||||
|
||||
features.billing.enabled = billing_info['enabled']
|
||||
features.billing.subscription.plan = billing_info['subscription']['plan']
|
||||
features.billing.subscription.interval = billing_info['subscription']['interval']
|
||||
features.billing.enabled = billing_info["enabled"]
|
||||
features.billing.subscription.plan = billing_info["subscription"]["plan"]
|
||||
features.billing.subscription.interval = billing_info["subscription"]["interval"]
|
||||
|
||||
if 'members' in billing_info:
|
||||
features.members.size = billing_info['members']['size']
|
||||
features.members.limit = billing_info['members']['limit']
|
||||
if "members" in billing_info:
|
||||
features.members.size = billing_info["members"]["size"]
|
||||
features.members.limit = billing_info["members"]["limit"]
|
||||
|
||||
if 'apps' in billing_info:
|
||||
features.apps.size = billing_info['apps']['size']
|
||||
features.apps.limit = billing_info['apps']['limit']
|
||||
if "apps" in billing_info:
|
||||
features.apps.size = billing_info["apps"]["size"]
|
||||
features.apps.limit = billing_info["apps"]["limit"]
|
||||
|
||||
if 'vector_space' in billing_info:
|
||||
features.vector_space.size = billing_info['vector_space']['size']
|
||||
features.vector_space.limit = billing_info['vector_space']['limit']
|
||||
if "vector_space" in billing_info:
|
||||
features.vector_space.size = billing_info["vector_space"]["size"]
|
||||
features.vector_space.limit = billing_info["vector_space"]["limit"]
|
||||
|
||||
if 'documents_upload_quota' in billing_info:
|
||||
features.documents_upload_quota.size = billing_info['documents_upload_quota']['size']
|
||||
features.documents_upload_quota.limit = billing_info['documents_upload_quota']['limit']
|
||||
if "documents_upload_quota" in billing_info:
|
||||
features.documents_upload_quota.size = billing_info["documents_upload_quota"]["size"]
|
||||
features.documents_upload_quota.limit = billing_info["documents_upload_quota"]["limit"]
|
||||
|
||||
if 'annotation_quota_limit' in billing_info:
|
||||
features.annotation_quota_limit.size = billing_info['annotation_quota_limit']['size']
|
||||
features.annotation_quota_limit.limit = billing_info['annotation_quota_limit']['limit']
|
||||
if "annotation_quota_limit" in billing_info:
|
||||
features.annotation_quota_limit.size = billing_info["annotation_quota_limit"]["size"]
|
||||
features.annotation_quota_limit.limit = billing_info["annotation_quota_limit"]["limit"]
|
||||
|
||||
if 'docs_processing' in billing_info:
|
||||
features.docs_processing = billing_info['docs_processing']
|
||||
if "docs_processing" in billing_info:
|
||||
features.docs_processing = billing_info["docs_processing"]
|
||||
|
||||
if 'can_replace_logo' in billing_info:
|
||||
features.can_replace_logo = billing_info['can_replace_logo']
|
||||
if "can_replace_logo" in billing_info:
|
||||
features.can_replace_logo = billing_info["can_replace_logo"]
|
||||
|
||||
if 'model_load_balancing_enabled' in billing_info:
|
||||
features.model_load_balancing_enabled = billing_info['model_load_balancing_enabled']
|
||||
if "model_load_balancing_enabled" in billing_info:
|
||||
features.model_load_balancing_enabled = billing_info["model_load_balancing_enabled"]
|
||||
|
||||
@classmethod
|
||||
def _fulfill_params_from_enterprise(cls, features):
|
||||
enterprise_info = EnterpriseService.get_info()
|
||||
|
||||
features.sso_enforced_for_signin = enterprise_info['sso_enforced_for_signin']
|
||||
features.sso_enforced_for_signin_protocol = enterprise_info['sso_enforced_for_signin_protocol']
|
||||
features.sso_enforced_for_web = enterprise_info['sso_enforced_for_web']
|
||||
features.sso_enforced_for_web_protocol = enterprise_info['sso_enforced_for_web_protocol']
|
||||
features.sso_enforced_for_signin = enterprise_info["sso_enforced_for_signin"]
|
||||
features.sso_enforced_for_signin_protocol = enterprise_info["sso_enforced_for_signin_protocol"]
|
||||
features.sso_enforced_for_web = enterprise_info["sso_enforced_for_web"]
|
||||
features.sso_enforced_for_web_protocol = enterprise_info["sso_enforced_for_web_protocol"]
|
||||
|
@ -17,27 +17,45 @@ from models.account import Account
|
||||
from models.model import EndUser, UploadFile
|
||||
from services.errors.file import FileTooLargeError, UnsupportedFileTypeError
|
||||
|
||||
IMAGE_EXTENSIONS = ['jpg', 'jpeg', 'png', 'webp', 'gif', 'svg']
|
||||
IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"]
|
||||
IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS])
|
||||
|
||||
ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'xls', 'docx', 'csv']
|
||||
UNSTRUCTURED_ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'xls',
|
||||
'docx', 'csv', 'eml', 'msg', 'pptx', 'ppt', 'xml', 'epub']
|
||||
ALLOWED_EXTENSIONS = ["txt", "markdown", "md", "pdf", "html", "htm", "xlsx", "xls", "docx", "csv"]
|
||||
UNSTRUCTURED_ALLOWED_EXTENSIONS = [
|
||||
"txt",
|
||||
"markdown",
|
||||
"md",
|
||||
"pdf",
|
||||
"html",
|
||||
"htm",
|
||||
"xlsx",
|
||||
"xls",
|
||||
"docx",
|
||||
"csv",
|
||||
"eml",
|
||||
"msg",
|
||||
"pptx",
|
||||
"ppt",
|
||||
"xml",
|
||||
"epub",
|
||||
]
|
||||
|
||||
PREVIEW_WORDS_LIMIT = 3000
|
||||
|
||||
|
||||
class FileService:
|
||||
|
||||
@staticmethod
|
||||
def upload_file(file: FileStorage, user: Union[Account, EndUser], only_image: bool = False) -> UploadFile:
|
||||
filename = file.filename
|
||||
extension = file.filename.split('.')[-1]
|
||||
extension = file.filename.split(".")[-1]
|
||||
if len(filename) > 200:
|
||||
filename = filename.split('.')[0][:200] + '.' + extension
|
||||
filename = filename.split(".")[0][:200] + "." + extension
|
||||
etl_type = dify_config.ETL_TYPE
|
||||
allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS + IMAGE_EXTENSIONS if etl_type == 'Unstructured' \
|
||||
allowed_extensions = (
|
||||
UNSTRUCTURED_ALLOWED_EXTENSIONS + IMAGE_EXTENSIONS
|
||||
if etl_type == "Unstructured"
|
||||
else ALLOWED_EXTENSIONS + IMAGE_EXTENSIONS
|
||||
)
|
||||
if extension.lower() not in allowed_extensions:
|
||||
raise UnsupportedFileTypeError()
|
||||
elif only_image and extension.lower() not in IMAGE_EXTENSIONS:
|
||||
@ -55,7 +73,7 @@ class FileService:
|
||||
file_size_limit = dify_config.UPLOAD_FILE_SIZE_LIMIT * 1024 * 1024
|
||||
|
||||
if file_size > file_size_limit:
|
||||
message = f'File size exceeded. {file_size} > {file_size_limit}'
|
||||
message = f"File size exceeded. {file_size} > {file_size_limit}"
|
||||
raise FileTooLargeError(message)
|
||||
|
||||
# user uuid as file name
|
||||
@ -67,7 +85,7 @@ class FileService:
|
||||
# end_user
|
||||
current_tenant_id = user.tenant_id
|
||||
|
||||
file_key = 'upload_files/' + current_tenant_id + '/' + file_uuid + '.' + extension
|
||||
file_key = "upload_files/" + current_tenant_id + "/" + file_uuid + "." + extension
|
||||
|
||||
# save file to storage
|
||||
storage.save(file_key, file_content)
|
||||
@ -81,11 +99,11 @@ class FileService:
|
||||
size=file_size,
|
||||
extension=extension,
|
||||
mime_type=file.mimetype,
|
||||
created_by_role=('account' if isinstance(user, Account) else 'end_user'),
|
||||
created_by_role=("account" if isinstance(user, Account) else "end_user"),
|
||||
created_by=user.id,
|
||||
created_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
|
||||
used=False,
|
||||
hash=hashlib.sha3_256(file_content).hexdigest()
|
||||
hash=hashlib.sha3_256(file_content).hexdigest(),
|
||||
)
|
||||
|
||||
db.session.add(upload_file)
|
||||
@ -99,10 +117,10 @@ class FileService:
|
||||
text_name = text_name[:200]
|
||||
# user uuid as file name
|
||||
file_uuid = str(uuid.uuid4())
|
||||
file_key = 'upload_files/' + current_user.current_tenant_id + '/' + file_uuid + '.txt'
|
||||
file_key = "upload_files/" + current_user.current_tenant_id + "/" + file_uuid + ".txt"
|
||||
|
||||
# save file to storage
|
||||
storage.save(file_key, text.encode('utf-8'))
|
||||
storage.save(file_key, text.encode("utf-8"))
|
||||
|
||||
# save file to db
|
||||
upload_file = UploadFile(
|
||||
@ -111,13 +129,13 @@ class FileService:
|
||||
key=file_key,
|
||||
name=text_name,
|
||||
size=len(text),
|
||||
extension='txt',
|
||||
mime_type='text/plain',
|
||||
extension="txt",
|
||||
mime_type="text/plain",
|
||||
created_by=current_user.id,
|
||||
created_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
|
||||
used=True,
|
||||
used_by=current_user.id,
|
||||
used_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
|
||||
used_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
|
||||
)
|
||||
|
||||
db.session.add(upload_file)
|
||||
@ -127,9 +145,7 @@ class FileService:
|
||||
|
||||
@staticmethod
|
||||
def get_file_preview(file_id: str) -> str:
|
||||
upload_file = db.session.query(UploadFile) \
|
||||
.filter(UploadFile.id == file_id) \
|
||||
.first()
|
||||
upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first()
|
||||
|
||||
if not upload_file:
|
||||
raise NotFound("File not found")
|
||||
@ -137,12 +153,12 @@ class FileService:
|
||||
# extract text from file
|
||||
extension = upload_file.extension
|
||||
etl_type = dify_config.ETL_TYPE
|
||||
allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS if etl_type == 'Unstructured' else ALLOWED_EXTENSIONS
|
||||
allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS if etl_type == "Unstructured" else ALLOWED_EXTENSIONS
|
||||
if extension.lower() not in allowed_extensions:
|
||||
raise UnsupportedFileTypeError()
|
||||
|
||||
text = ExtractProcessor.load_from_upload_file(upload_file, return_text=True)
|
||||
text = text[0:PREVIEW_WORDS_LIMIT] if text else ''
|
||||
text = text[0:PREVIEW_WORDS_LIMIT] if text else ""
|
||||
|
||||
return text
|
||||
|
||||
@ -152,9 +168,7 @@ class FileService:
|
||||
if not result:
|
||||
raise NotFound("File not found or signature is invalid")
|
||||
|
||||
upload_file = db.session.query(UploadFile) \
|
||||
.filter(UploadFile.id == file_id) \
|
||||
.first()
|
||||
upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first()
|
||||
|
||||
if not upload_file:
|
||||
raise NotFound("File not found or signature is invalid")
|
||||
@ -170,9 +184,7 @@ class FileService:
|
||||
|
||||
@staticmethod
|
||||
def get_public_image_preview(file_id: str) -> tuple[Generator, str]:
|
||||
upload_file = db.session.query(UploadFile) \
|
||||
.filter(UploadFile.id == file_id) \
|
||||
.first()
|
||||
upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first()
|
||||
|
||||
if not upload_file:
|
||||
raise NotFound("File not found or signature is invalid")
|
||||
|
@ -9,14 +9,11 @@ from models.account import Account
|
||||
from models.dataset import Dataset, DatasetQuery, DocumentSegment
|
||||
|
||||
default_retrieval_model = {
|
||||
'search_method': RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
'reranking_enable': False,
|
||||
'reranking_model': {
|
||||
'reranking_provider_name': '',
|
||||
'reranking_model_name': ''
|
||||
},
|
||||
'top_k': 2,
|
||||
'score_threshold_enabled': False
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
"reranking_enable": False,
|
||||
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||
"top_k": 2,
|
||||
"score_threshold_enabled": False,
|
||||
}
|
||||
|
||||
|
||||
@ -27,9 +24,9 @@ class HitTestingService:
|
||||
return {
|
||||
"query": {
|
||||
"content": query,
|
||||
"tsne_position": {'x': 0, 'y': 0},
|
||||
"tsne_position": {"x": 0, "y": 0},
|
||||
},
|
||||
"records": []
|
||||
"records": [],
|
||||
}
|
||||
|
||||
start = time.perf_counter()
|
||||
@ -38,28 +35,28 @@ class HitTestingService:
|
||||
if not retrieval_model:
|
||||
retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
|
||||
|
||||
all_documents = RetrievalService.retrieve(retrival_method=retrieval_model.get('search_method', 'semantic_search'),
|
||||
dataset_id=dataset.id,
|
||||
query=cls.escape_query_for_search(query),
|
||||
top_k=retrieval_model.get('top_k', 2),
|
||||
score_threshold=retrieval_model.get('score_threshold', .0)
|
||||
if retrieval_model['score_threshold_enabled'] else None,
|
||||
reranking_model=retrieval_model.get('reranking_model', None)
|
||||
if retrieval_model['reranking_enable'] else None,
|
||||
reranking_mode=retrieval_model.get('reranking_mode')
|
||||
if retrieval_model.get('reranking_mode') else 'reranking_model',
|
||||
weights=retrieval_model.get('weights', None),
|
||||
)
|
||||
all_documents = RetrievalService.retrieve(
|
||||
retrival_method=retrieval_model.get("search_method", "semantic_search"),
|
||||
dataset_id=dataset.id,
|
||||
query=cls.escape_query_for_search(query),
|
||||
top_k=retrieval_model.get("top_k", 2),
|
||||
score_threshold=retrieval_model.get("score_threshold", 0.0)
|
||||
if retrieval_model["score_threshold_enabled"]
|
||||
else None,
|
||||
reranking_model=retrieval_model.get("reranking_model", None)
|
||||
if retrieval_model["reranking_enable"]
|
||||
else None,
|
||||
reranking_mode=retrieval_model.get("reranking_mode")
|
||||
if retrieval_model.get("reranking_mode")
|
||||
else "reranking_model",
|
||||
weights=retrieval_model.get("weights", None),
|
||||
)
|
||||
|
||||
end = time.perf_counter()
|
||||
logging.debug(f"Hit testing retrieve in {end - start:0.4f} seconds")
|
||||
|
||||
dataset_query = DatasetQuery(
|
||||
dataset_id=dataset.id,
|
||||
content=query,
|
||||
source='hit_testing',
|
||||
created_by_role='account',
|
||||
created_by=account.id
|
||||
dataset_id=dataset.id, content=query, source="hit_testing", created_by_role="account", created_by=account.id
|
||||
)
|
||||
|
||||
db.session.add(dataset_query)
|
||||
@ -72,14 +69,18 @@ class HitTestingService:
|
||||
i = 0
|
||||
records = []
|
||||
for document in documents:
|
||||
index_node_id = document.metadata['doc_id']
|
||||
index_node_id = document.metadata["doc_id"]
|
||||
|
||||
segment = db.session.query(DocumentSegment).filter(
|
||||
DocumentSegment.dataset_id == dataset.id,
|
||||
DocumentSegment.enabled == True,
|
||||
DocumentSegment.status == 'completed',
|
||||
DocumentSegment.index_node_id == index_node_id
|
||||
).first()
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.filter(
|
||||
DocumentSegment.dataset_id == dataset.id,
|
||||
DocumentSegment.enabled == True,
|
||||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.index_node_id == index_node_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not segment:
|
||||
i += 1
|
||||
@ -87,7 +88,7 @@ class HitTestingService:
|
||||
|
||||
record = {
|
||||
"segment": segment,
|
||||
"score": document.metadata.get('score', None),
|
||||
"score": document.metadata.get("score", None),
|
||||
}
|
||||
|
||||
records.append(record)
|
||||
@ -98,15 +99,15 @@ class HitTestingService:
|
||||
"query": {
|
||||
"content": query,
|
||||
},
|
||||
"records": records
|
||||
"records": records,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def hit_testing_args_check(cls, args):
|
||||
query = args['query']
|
||||
query = args["query"]
|
||||
|
||||
if not query or len(query) > 250:
|
||||
raise ValueError('Query is required and cannot exceed 250 characters')
|
||||
raise ValueError("Query is required and cannot exceed 250 characters")
|
||||
|
||||
@staticmethod
|
||||
def escape_query_for_search(query: str) -> str:
|
||||
|
@ -27,8 +27,14 @@ from services.workflow_service import WorkflowService
|
||||
|
||||
class MessageService:
|
||||
@classmethod
|
||||
def pagination_by_first_id(cls, app_model: App, user: Optional[Union[Account, EndUser]],
|
||||
conversation_id: str, first_id: Optional[str], limit: int) -> InfiniteScrollPagination:
|
||||
def pagination_by_first_id(
|
||||
cls,
|
||||
app_model: App,
|
||||
user: Optional[Union[Account, EndUser]],
|
||||
conversation_id: str,
|
||||
first_id: Optional[str],
|
||||
limit: int,
|
||||
) -> InfiniteScrollPagination:
|
||||
if not user:
|
||||
return InfiniteScrollPagination(data=[], limit=limit, has_more=False)
|
||||
|
||||
@ -36,52 +42,69 @@ class MessageService:
|
||||
return InfiniteScrollPagination(data=[], limit=limit, has_more=False)
|
||||
|
||||
conversation = ConversationService.get_conversation(
|
||||
app_model=app_model,
|
||||
user=user,
|
||||
conversation_id=conversation_id
|
||||
app_model=app_model, user=user, conversation_id=conversation_id
|
||||
)
|
||||
|
||||
if first_id:
|
||||
first_message = db.session.query(Message) \
|
||||
.filter(Message.conversation_id == conversation.id, Message.id == first_id).first()
|
||||
first_message = (
|
||||
db.session.query(Message)
|
||||
.filter(Message.conversation_id == conversation.id, Message.id == first_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not first_message:
|
||||
raise FirstMessageNotExistsError()
|
||||
|
||||
history_messages = db.session.query(Message).filter(
|
||||
Message.conversation_id == conversation.id,
|
||||
Message.created_at < first_message.created_at,
|
||||
Message.id != first_message.id
|
||||
) \
|
||||
.order_by(Message.created_at.desc()).limit(limit).all()
|
||||
history_messages = (
|
||||
db.session.query(Message)
|
||||
.filter(
|
||||
Message.conversation_id == conversation.id,
|
||||
Message.created_at < first_message.created_at,
|
||||
Message.id != first_message.id,
|
||||
)
|
||||
.order_by(Message.created_at.desc())
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
else:
|
||||
history_messages = db.session.query(Message).filter(Message.conversation_id == conversation.id) \
|
||||
.order_by(Message.created_at.desc()).limit(limit).all()
|
||||
history_messages = (
|
||||
db.session.query(Message)
|
||||
.filter(Message.conversation_id == conversation.id)
|
||||
.order_by(Message.created_at.desc())
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
has_more = False
|
||||
if len(history_messages) == limit:
|
||||
current_page_first_message = history_messages[-1]
|
||||
rest_count = db.session.query(Message).filter(
|
||||
Message.conversation_id == conversation.id,
|
||||
Message.created_at < current_page_first_message.created_at,
|
||||
Message.id != current_page_first_message.id
|
||||
).count()
|
||||
rest_count = (
|
||||
db.session.query(Message)
|
||||
.filter(
|
||||
Message.conversation_id == conversation.id,
|
||||
Message.created_at < current_page_first_message.created_at,
|
||||
Message.id != current_page_first_message.id,
|
||||
)
|
||||
.count()
|
||||
)
|
||||
|
||||
if rest_count > 0:
|
||||
has_more = True
|
||||
|
||||
history_messages = list(reversed(history_messages))
|
||||
|
||||
return InfiniteScrollPagination(
|
||||
data=history_messages,
|
||||
limit=limit,
|
||||
has_more=has_more
|
||||
)
|
||||
return InfiniteScrollPagination(data=history_messages, limit=limit, has_more=has_more)
|
||||
|
||||
@classmethod
|
||||
def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account, EndUser]],
|
||||
last_id: Optional[str], limit: int, conversation_id: Optional[str] = None,
|
||||
include_ids: Optional[list] = None) -> InfiniteScrollPagination:
|
||||
def pagination_by_last_id(
|
||||
cls,
|
||||
app_model: App,
|
||||
user: Optional[Union[Account, EndUser]],
|
||||
last_id: Optional[str],
|
||||
limit: int,
|
||||
conversation_id: Optional[str] = None,
|
||||
include_ids: Optional[list] = None,
|
||||
) -> InfiniteScrollPagination:
|
||||
if not user:
|
||||
return InfiniteScrollPagination(data=[], limit=limit, has_more=False)
|
||||
|
||||
@ -89,9 +112,7 @@ class MessageService:
|
||||
|
||||
if conversation_id is not None:
|
||||
conversation = ConversationService.get_conversation(
|
||||
app_model=app_model,
|
||||
user=user,
|
||||
conversation_id=conversation_id
|
||||
app_model=app_model, user=user, conversation_id=conversation_id
|
||||
)
|
||||
|
||||
base_query = base_query.filter(Message.conversation_id == conversation.id)
|
||||
@ -105,10 +126,12 @@ class MessageService:
|
||||
if not last_message:
|
||||
raise LastMessageNotExistsError()
|
||||
|
||||
history_messages = base_query.filter(
|
||||
Message.created_at < last_message.created_at,
|
||||
Message.id != last_message.id
|
||||
).order_by(Message.created_at.desc()).limit(limit).all()
|
||||
history_messages = (
|
||||
base_query.filter(Message.created_at < last_message.created_at, Message.id != last_message.id)
|
||||
.order_by(Message.created_at.desc())
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
else:
|
||||
history_messages = base_query.order_by(Message.created_at.desc()).limit(limit).all()
|
||||
|
||||
@ -116,30 +139,22 @@ class MessageService:
|
||||
if len(history_messages) == limit:
|
||||
current_page_first_message = history_messages[-1]
|
||||
rest_count = base_query.filter(
|
||||
Message.created_at < current_page_first_message.created_at,
|
||||
Message.id != current_page_first_message.id
|
||||
Message.created_at < current_page_first_message.created_at, Message.id != current_page_first_message.id
|
||||
).count()
|
||||
|
||||
if rest_count > 0:
|
||||
has_more = True
|
||||
|
||||
return InfiniteScrollPagination(
|
||||
data=history_messages,
|
||||
limit=limit,
|
||||
has_more=has_more
|
||||
)
|
||||
return InfiniteScrollPagination(data=history_messages, limit=limit, has_more=has_more)
|
||||
|
||||
@classmethod
|
||||
def create_feedback(cls, app_model: App, message_id: str, user: Optional[Union[Account, EndUser]],
|
||||
rating: Optional[str]) -> MessageFeedback:
|
||||
def create_feedback(
|
||||
cls, app_model: App, message_id: str, user: Optional[Union[Account, EndUser]], rating: Optional[str]
|
||||
) -> MessageFeedback:
|
||||
if not user:
|
||||
raise ValueError('user cannot be None')
|
||||
raise ValueError("user cannot be None")
|
||||
|
||||
message = cls.get_message(
|
||||
app_model=app_model,
|
||||
user=user,
|
||||
message_id=message_id
|
||||
)
|
||||
message = cls.get_message(app_model=app_model, user=user, message_id=message_id)
|
||||
|
||||
feedback = message.user_feedback if isinstance(user, EndUser) else message.admin_feedback
|
||||
|
||||
@ -148,14 +163,14 @@ class MessageService:
|
||||
elif rating and feedback:
|
||||
feedback.rating = rating
|
||||
elif not rating and not feedback:
|
||||
raise ValueError('rating cannot be None when feedback not exists')
|
||||
raise ValueError("rating cannot be None when feedback not exists")
|
||||
else:
|
||||
feedback = MessageFeedback(
|
||||
app_id=app_model.id,
|
||||
conversation_id=message.conversation_id,
|
||||
message_id=message.id,
|
||||
rating=rating,
|
||||
from_source=('user' if isinstance(user, EndUser) else 'admin'),
|
||||
from_source=("user" if isinstance(user, EndUser) else "admin"),
|
||||
from_end_user_id=(user.id if isinstance(user, EndUser) else None),
|
||||
from_account_id=(user.id if isinstance(user, Account) else None),
|
||||
)
|
||||
@ -167,13 +182,17 @@ class MessageService:
|
||||
|
||||
@classmethod
|
||||
def get_message(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str):
|
||||
message = db.session.query(Message).filter(
|
||||
Message.id == message_id,
|
||||
Message.app_id == app_model.id,
|
||||
Message.from_source == ('api' if isinstance(user, EndUser) else 'console'),
|
||||
Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
|
||||
Message.from_account_id == (user.id if isinstance(user, Account) else None),
|
||||
).first()
|
||||
message = (
|
||||
db.session.query(Message)
|
||||
.filter(
|
||||
Message.id == message_id,
|
||||
Message.app_id == app_model.id,
|
||||
Message.from_source == ("api" if isinstance(user, EndUser) else "console"),
|
||||
Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
|
||||
Message.from_account_id == (user.id if isinstance(user, Account) else None),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not message:
|
||||
raise MessageNotExistsError()
|
||||
@ -181,27 +200,22 @@ class MessageService:
|
||||
return message
|
||||
|
||||
@classmethod
|
||||
def get_suggested_questions_after_answer(cls, app_model: App, user: Optional[Union[Account, EndUser]],
|
||||
message_id: str, invoke_from: InvokeFrom) -> list[Message]:
|
||||
def get_suggested_questions_after_answer(
|
||||
cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str, invoke_from: InvokeFrom
|
||||
) -> list[Message]:
|
||||
if not user:
|
||||
raise ValueError('user cannot be None')
|
||||
raise ValueError("user cannot be None")
|
||||
|
||||
message = cls.get_message(
|
||||
app_model=app_model,
|
||||
user=user,
|
||||
message_id=message_id
|
||||
)
|
||||
message = cls.get_message(app_model=app_model, user=user, message_id=message_id)
|
||||
|
||||
conversation = ConversationService.get_conversation(
|
||||
app_model=app_model,
|
||||
conversation_id=message.conversation_id,
|
||||
user=user
|
||||
app_model=app_model, conversation_id=message.conversation_id, user=user
|
||||
)
|
||||
|
||||
if not conversation:
|
||||
raise ConversationNotExistsError()
|
||||
|
||||
if conversation.status != 'normal':
|
||||
if conversation.status != "normal":
|
||||
raise ConversationCompletedError()
|
||||
|
||||
model_manager = ModelManager()
|
||||
@ -216,24 +230,23 @@ class MessageService:
|
||||
if workflow is None:
|
||||
return []
|
||||
|
||||
app_config = AdvancedChatAppConfigManager.get_app_config(
|
||||
app_model=app_model,
|
||||
workflow=workflow
|
||||
)
|
||||
app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
|
||||
|
||||
if not app_config.additional_features.suggested_questions_after_answer:
|
||||
raise SuggestedQuestionsAfterAnswerDisabledError()
|
||||
|
||||
model_instance = model_manager.get_default_model_instance(
|
||||
tenant_id=app_model.tenant_id,
|
||||
model_type=ModelType.LLM
|
||||
tenant_id=app_model.tenant_id, model_type=ModelType.LLM
|
||||
)
|
||||
else:
|
||||
if not conversation.override_model_configs:
|
||||
app_model_config = db.session.query(AppModelConfig).filter(
|
||||
AppModelConfig.id == conversation.app_model_config_id,
|
||||
AppModelConfig.app_id == app_model.id
|
||||
).first()
|
||||
app_model_config = (
|
||||
db.session.query(AppModelConfig)
|
||||
.filter(
|
||||
AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id
|
||||
)
|
||||
.first()
|
||||
)
|
||||
else:
|
||||
conversation_override_model_configs = json.loads(conversation.override_model_configs)
|
||||
app_model_config = AppModelConfig(
|
||||
@ -249,16 +262,13 @@ class MessageService:
|
||||
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=app_model.tenant_id,
|
||||
provider=app_model_config.model_dict['provider'],
|
||||
provider=app_model_config.model_dict["provider"],
|
||||
model_type=ModelType.LLM,
|
||||
model=app_model_config.model_dict['name']
|
||||
model=app_model_config.model_dict["name"],
|
||||
)
|
||||
|
||||
# get memory of conversation (read-only)
|
||||
memory = TokenBufferMemory(
|
||||
conversation=conversation,
|
||||
model_instance=model_instance
|
||||
)
|
||||
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
|
||||
|
||||
histories = memory.get_history_prompt_text(
|
||||
max_token_limit=3000,
|
||||
@ -267,18 +277,14 @@ class MessageService:
|
||||
|
||||
with measure_time() as timer:
|
||||
questions = LLMGenerator.generate_suggested_questions_after_answer(
|
||||
tenant_id=app_model.tenant_id,
|
||||
histories=histories
|
||||
tenant_id=app_model.tenant_id, histories=histories
|
||||
)
|
||||
|
||||
# get tracing instance
|
||||
trace_manager = TraceQueueManager(app_id=app_model.id)
|
||||
trace_manager.add_trace_task(
|
||||
TraceTask(
|
||||
TraceTaskName.SUGGESTED_QUESTION_TRACE,
|
||||
message_id=message_id,
|
||||
suggested_question=questions,
|
||||
timer=timer
|
||||
TraceTaskName.SUGGESTED_QUESTION_TRACE, message_id=message_id, suggested_question=questions, timer=timer
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -23,7 +23,6 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ModelLoadBalancingService:
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.provider_manager = ProviderManager()
|
||||
|
||||
@ -46,10 +45,7 @@ class ModelLoadBalancingService:
|
||||
raise ValueError(f"Provider {provider} does not exist.")
|
||||
|
||||
# Enable model load balancing
|
||||
provider_configuration.enable_model_load_balancing(
|
||||
model=model,
|
||||
model_type=ModelType.value_of(model_type)
|
||||
)
|
||||
provider_configuration.enable_model_load_balancing(model=model, model_type=ModelType.value_of(model_type))
|
||||
|
||||
def disable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str) -> None:
|
||||
"""
|
||||
@ -70,13 +66,11 @@ class ModelLoadBalancingService:
|
||||
raise ValueError(f"Provider {provider} does not exist.")
|
||||
|
||||
# disable model load balancing
|
||||
provider_configuration.disable_model_load_balancing(
|
||||
model=model,
|
||||
model_type=ModelType.value_of(model_type)
|
||||
)
|
||||
provider_configuration.disable_model_load_balancing(model=model, model_type=ModelType.value_of(model_type))
|
||||
|
||||
def get_load_balancing_configs(self, tenant_id: str, provider: str, model: str, model_type: str) \
|
||||
-> tuple[bool, list[dict]]:
|
||||
def get_load_balancing_configs(
|
||||
self, tenant_id: str, provider: str, model: str, model_type: str
|
||||
) -> tuple[bool, list[dict]]:
|
||||
"""
|
||||
Get load balancing configurations.
|
||||
:param tenant_id: workspace id
|
||||
@ -107,20 +101,24 @@ class ModelLoadBalancingService:
|
||||
is_load_balancing_enabled = True
|
||||
|
||||
# Get load balancing configurations
|
||||
load_balancing_configs = db.session.query(LoadBalancingModelConfig) \
|
||||
load_balancing_configs = (
|
||||
db.session.query(LoadBalancingModelConfig)
|
||||
.filter(
|
||||
LoadBalancingModelConfig.tenant_id == tenant_id,
|
||||
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
|
||||
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
|
||||
LoadBalancingModelConfig.model_name == model
|
||||
).order_by(LoadBalancingModelConfig.created_at).all()
|
||||
LoadBalancingModelConfig.tenant_id == tenant_id,
|
||||
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
|
||||
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
|
||||
LoadBalancingModelConfig.model_name == model,
|
||||
)
|
||||
.order_by(LoadBalancingModelConfig.created_at)
|
||||
.all()
|
||||
)
|
||||
|
||||
if provider_configuration.custom_configuration.provider:
|
||||
# check if the inherit configuration exists,
|
||||
# inherit is represented for the provider or model custom credentials
|
||||
inherit_config_exists = False
|
||||
for load_balancing_config in load_balancing_configs:
|
||||
if load_balancing_config.name == '__inherit__':
|
||||
if load_balancing_config.name == "__inherit__":
|
||||
inherit_config_exists = True
|
||||
break
|
||||
|
||||
@ -133,7 +131,7 @@ class ModelLoadBalancingService:
|
||||
else:
|
||||
# move the inherit configuration to the first
|
||||
for i, load_balancing_config in enumerate(load_balancing_configs[:]):
|
||||
if load_balancing_config.name == '__inherit__':
|
||||
if load_balancing_config.name == "__inherit__":
|
||||
inherit_config = load_balancing_configs.pop(i)
|
||||
load_balancing_configs.insert(0, inherit_config)
|
||||
|
||||
@ -151,7 +149,7 @@ class ModelLoadBalancingService:
|
||||
provider=provider,
|
||||
model=model,
|
||||
model_type=model_type,
|
||||
config_id=load_balancing_config.id
|
||||
config_id=load_balancing_config.id,
|
||||
)
|
||||
|
||||
try:
|
||||
@ -172,32 +170,32 @@ class ModelLoadBalancingService:
|
||||
if variable in credentials:
|
||||
try:
|
||||
credentials[variable] = encrypter.decrypt_token_with_decoding(
|
||||
credentials.get(variable),
|
||||
decoding_rsa_key,
|
||||
decoding_cipher_rsa
|
||||
credentials.get(variable), decoding_rsa_key, decoding_cipher_rsa
|
||||
)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Obfuscate credentials
|
||||
credentials = provider_configuration.obfuscated_credentials(
|
||||
credentials=credentials,
|
||||
credential_form_schemas=credential_schemas.credential_form_schemas
|
||||
credentials=credentials, credential_form_schemas=credential_schemas.credential_form_schemas
|
||||
)
|
||||
|
||||
datas.append({
|
||||
'id': load_balancing_config.id,
|
||||
'name': load_balancing_config.name,
|
||||
'credentials': credentials,
|
||||
'enabled': load_balancing_config.enabled,
|
||||
'in_cooldown': in_cooldown,
|
||||
'ttl': ttl
|
||||
})
|
||||
datas.append(
|
||||
{
|
||||
"id": load_balancing_config.id,
|
||||
"name": load_balancing_config.name,
|
||||
"credentials": credentials,
|
||||
"enabled": load_balancing_config.enabled,
|
||||
"in_cooldown": in_cooldown,
|
||||
"ttl": ttl,
|
||||
}
|
||||
)
|
||||
|
||||
return is_load_balancing_enabled, datas
|
||||
|
||||
def get_load_balancing_config(self, tenant_id: str, provider: str, model: str, model_type: str, config_id: str) \
|
||||
-> Optional[dict]:
|
||||
def get_load_balancing_config(
|
||||
self, tenant_id: str, provider: str, model: str, model_type: str, config_id: str
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
Get load balancing configuration.
|
||||
:param tenant_id: workspace id
|
||||
@ -219,14 +217,17 @@ class ModelLoadBalancingService:
|
||||
model_type = ModelType.value_of(model_type)
|
||||
|
||||
# Get load balancing configurations
|
||||
load_balancing_model_config = db.session.query(LoadBalancingModelConfig) \
|
||||
load_balancing_model_config = (
|
||||
db.session.query(LoadBalancingModelConfig)
|
||||
.filter(
|
||||
LoadBalancingModelConfig.tenant_id == tenant_id,
|
||||
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
|
||||
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
|
||||
LoadBalancingModelConfig.model_name == model,
|
||||
LoadBalancingModelConfig.id == config_id
|
||||
).first()
|
||||
LoadBalancingModelConfig.tenant_id == tenant_id,
|
||||
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
|
||||
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
|
||||
LoadBalancingModelConfig.model_name == model,
|
||||
LoadBalancingModelConfig.id == config_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not load_balancing_model_config:
|
||||
return None
|
||||
@ -244,19 +245,19 @@ class ModelLoadBalancingService:
|
||||
|
||||
# Obfuscate credentials
|
||||
credentials = provider_configuration.obfuscated_credentials(
|
||||
credentials=credentials,
|
||||
credential_form_schemas=credential_schemas.credential_form_schemas
|
||||
credentials=credentials, credential_form_schemas=credential_schemas.credential_form_schemas
|
||||
)
|
||||
|
||||
return {
|
||||
'id': load_balancing_model_config.id,
|
||||
'name': load_balancing_model_config.name,
|
||||
'credentials': credentials,
|
||||
'enabled': load_balancing_model_config.enabled
|
||||
"id": load_balancing_model_config.id,
|
||||
"name": load_balancing_model_config.name,
|
||||
"credentials": credentials,
|
||||
"enabled": load_balancing_model_config.enabled,
|
||||
}
|
||||
|
||||
def _init_inherit_config(self, tenant_id: str, provider: str, model: str, model_type: ModelType) \
|
||||
-> LoadBalancingModelConfig:
|
||||
def _init_inherit_config(
|
||||
self, tenant_id: str, provider: str, model: str, model_type: ModelType
|
||||
) -> LoadBalancingModelConfig:
|
||||
"""
|
||||
Initialize the inherit configuration.
|
||||
:param tenant_id: workspace id
|
||||
@ -271,18 +272,16 @@ class ModelLoadBalancingService:
|
||||
provider_name=provider,
|
||||
model_type=model_type.to_origin_model_type(),
|
||||
model_name=model,
|
||||
name='__inherit__'
|
||||
name="__inherit__",
|
||||
)
|
||||
db.session.add(inherit_config)
|
||||
db.session.commit()
|
||||
|
||||
return inherit_config
|
||||
|
||||
def update_load_balancing_configs(self, tenant_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
model_type: str,
|
||||
configs: list[dict]) -> None:
|
||||
def update_load_balancing_configs(
|
||||
self, tenant_id: str, provider: str, model: str, model_type: str, configs: list[dict]
|
||||
) -> None:
|
||||
"""
|
||||
Update load balancing configurations.
|
||||
:param tenant_id: workspace id
|
||||
@ -304,15 +303,18 @@ class ModelLoadBalancingService:
|
||||
model_type = ModelType.value_of(model_type)
|
||||
|
||||
if not isinstance(configs, list):
|
||||
raise ValueError('Invalid load balancing configs')
|
||||
raise ValueError("Invalid load balancing configs")
|
||||
|
||||
current_load_balancing_configs = db.session.query(LoadBalancingModelConfig) \
|
||||
current_load_balancing_configs = (
|
||||
db.session.query(LoadBalancingModelConfig)
|
||||
.filter(
|
||||
LoadBalancingModelConfig.tenant_id == tenant_id,
|
||||
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
|
||||
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
|
||||
LoadBalancingModelConfig.model_name == model
|
||||
).all()
|
||||
LoadBalancingModelConfig.tenant_id == tenant_id,
|
||||
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
|
||||
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
|
||||
LoadBalancingModelConfig.model_name == model,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
# id as key, config as value
|
||||
current_load_balancing_configs_dict = {config.id: config for config in current_load_balancing_configs}
|
||||
@ -320,25 +322,25 @@ class ModelLoadBalancingService:
|
||||
|
||||
for config in configs:
|
||||
if not isinstance(config, dict):
|
||||
raise ValueError('Invalid load balancing config')
|
||||
raise ValueError("Invalid load balancing config")
|
||||
|
||||
config_id = config.get('id')
|
||||
name = config.get('name')
|
||||
credentials = config.get('credentials')
|
||||
enabled = config.get('enabled')
|
||||
config_id = config.get("id")
|
||||
name = config.get("name")
|
||||
credentials = config.get("credentials")
|
||||
enabled = config.get("enabled")
|
||||
|
||||
if not name:
|
||||
raise ValueError('Invalid load balancing config name')
|
||||
raise ValueError("Invalid load balancing config name")
|
||||
|
||||
if enabled is None:
|
||||
raise ValueError('Invalid load balancing config enabled')
|
||||
raise ValueError("Invalid load balancing config enabled")
|
||||
|
||||
# is config exists
|
||||
if config_id:
|
||||
config_id = str(config_id)
|
||||
|
||||
if config_id not in current_load_balancing_configs_dict:
|
||||
raise ValueError('Invalid load balancing config id: {}'.format(config_id))
|
||||
raise ValueError("Invalid load balancing config id: {}".format(config_id))
|
||||
|
||||
updated_config_ids.add(config_id)
|
||||
|
||||
@ -347,11 +349,11 @@ class ModelLoadBalancingService:
|
||||
# check duplicate name
|
||||
for current_load_balancing_config in current_load_balancing_configs:
|
||||
if current_load_balancing_config.id != config_id and current_load_balancing_config.name == name:
|
||||
raise ValueError('Load balancing config name {} already exists'.format(name))
|
||||
raise ValueError("Load balancing config name {} already exists".format(name))
|
||||
|
||||
if credentials:
|
||||
if not isinstance(credentials, dict):
|
||||
raise ValueError('Invalid load balancing config credentials')
|
||||
raise ValueError("Invalid load balancing config credentials")
|
||||
|
||||
# validate custom provider config
|
||||
credentials = self._custom_credentials_validate(
|
||||
@ -361,7 +363,7 @@ class ModelLoadBalancingService:
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
load_balancing_model_config=load_balancing_config,
|
||||
validate=False
|
||||
validate=False,
|
||||
)
|
||||
|
||||
# update load balancing config
|
||||
@ -375,19 +377,19 @@ class ModelLoadBalancingService:
|
||||
self._clear_credentials_cache(tenant_id, config_id)
|
||||
else:
|
||||
# create load balancing config
|
||||
if name == '__inherit__':
|
||||
raise ValueError('Invalid load balancing config name')
|
||||
if name == "__inherit__":
|
||||
raise ValueError("Invalid load balancing config name")
|
||||
|
||||
# check duplicate name
|
||||
for current_load_balancing_config in current_load_balancing_configs:
|
||||
if current_load_balancing_config.name == name:
|
||||
raise ValueError('Load balancing config name {} already exists'.format(name))
|
||||
raise ValueError("Load balancing config name {} already exists".format(name))
|
||||
|
||||
if not credentials:
|
||||
raise ValueError('Invalid load balancing config credentials')
|
||||
raise ValueError("Invalid load balancing config credentials")
|
||||
|
||||
if not isinstance(credentials, dict):
|
||||
raise ValueError('Invalid load balancing config credentials')
|
||||
raise ValueError("Invalid load balancing config credentials")
|
||||
|
||||
# validate custom provider config
|
||||
credentials = self._custom_credentials_validate(
|
||||
@ -396,7 +398,7 @@ class ModelLoadBalancingService:
|
||||
model_type=model_type,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
validate=False
|
||||
validate=False,
|
||||
)
|
||||
|
||||
# create load balancing config
|
||||
@ -406,7 +408,7 @@ class ModelLoadBalancingService:
|
||||
model_type=model_type.to_origin_model_type(),
|
||||
model_name=model,
|
||||
name=name,
|
||||
encrypted_config=json.dumps(credentials)
|
||||
encrypted_config=json.dumps(credentials),
|
||||
)
|
||||
|
||||
db.session.add(load_balancing_model_config)
|
||||
@ -420,12 +422,15 @@ class ModelLoadBalancingService:
|
||||
|
||||
self._clear_credentials_cache(tenant_id, config_id)
|
||||
|
||||
def validate_load_balancing_credentials(self, tenant_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
model_type: str,
|
||||
credentials: dict,
|
||||
config_id: Optional[str] = None) -> None:
|
||||
def validate_load_balancing_credentials(
|
||||
self,
|
||||
tenant_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
model_type: str,
|
||||
credentials: dict,
|
||||
config_id: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Validate load balancing credentials.
|
||||
:param tenant_id: workspace id
|
||||
@ -450,14 +455,17 @@ class ModelLoadBalancingService:
|
||||
load_balancing_model_config = None
|
||||
if config_id:
|
||||
# Get load balancing config
|
||||
load_balancing_model_config = db.session.query(LoadBalancingModelConfig) \
|
||||
load_balancing_model_config = (
|
||||
db.session.query(LoadBalancingModelConfig)
|
||||
.filter(
|
||||
LoadBalancingModelConfig.tenant_id == tenant_id,
|
||||
LoadBalancingModelConfig.provider_name == provider,
|
||||
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
|
||||
LoadBalancingModelConfig.model_name == model,
|
||||
LoadBalancingModelConfig.id == config_id
|
||||
).first()
|
||||
LoadBalancingModelConfig.tenant_id == tenant_id,
|
||||
LoadBalancingModelConfig.provider_name == provider,
|
||||
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
|
||||
LoadBalancingModelConfig.model_name == model,
|
||||
LoadBalancingModelConfig.id == config_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not load_balancing_model_config:
|
||||
raise ValueError(f"Load balancing config {config_id} does not exist.")
|
||||
@ -469,16 +477,19 @@ class ModelLoadBalancingService:
|
||||
model_type=model_type,
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
load_balancing_model_config=load_balancing_model_config
|
||||
load_balancing_model_config=load_balancing_model_config,
|
||||
)
|
||||
|
||||
def _custom_credentials_validate(self, tenant_id: str,
|
||||
provider_configuration: ProviderConfiguration,
|
||||
model_type: ModelType,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
load_balancing_model_config: Optional[LoadBalancingModelConfig] = None,
|
||||
validate: bool = True) -> dict:
|
||||
def _custom_credentials_validate(
|
||||
self,
|
||||
tenant_id: str,
|
||||
provider_configuration: ProviderConfiguration,
|
||||
model_type: ModelType,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
load_balancing_model_config: Optional[LoadBalancingModelConfig] = None,
|
||||
validate: bool = True,
|
||||
) -> dict:
|
||||
"""
|
||||
Validate custom credentials.
|
||||
:param tenant_id: workspace id
|
||||
@ -521,12 +532,11 @@ class ModelLoadBalancingService:
|
||||
provider=provider_configuration.provider.provider,
|
||||
model_type=model_type,
|
||||
model=model,
|
||||
credentials=credentials
|
||||
credentials=credentials,
|
||||
)
|
||||
else:
|
||||
credentials = model_provider_factory.provider_credentials_validate(
|
||||
provider=provider_configuration.provider.provider,
|
||||
credentials=credentials
|
||||
provider=provider_configuration.provider.provider, credentials=credentials
|
||||
)
|
||||
|
||||
for key, value in credentials.items():
|
||||
@ -535,8 +545,9 @@ class ModelLoadBalancingService:
|
||||
|
||||
return credentials
|
||||
|
||||
def _get_credential_schema(self, provider_configuration: ProviderConfiguration) \
|
||||
-> ModelCredentialSchema | ProviderCredentialSchema:
|
||||
def _get_credential_schema(
|
||||
self, provider_configuration: ProviderConfiguration
|
||||
) -> ModelCredentialSchema | ProviderCredentialSchema:
|
||||
"""
|
||||
Get form schemas.
|
||||
:param provider_configuration: provider configuration
|
||||
@ -558,9 +569,7 @@ class ModelLoadBalancingService:
|
||||
:return:
|
||||
"""
|
||||
provider_model_credentials_cache = ProviderCredentialsCache(
|
||||
tenant_id=tenant_id,
|
||||
identity_id=config_id,
|
||||
cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL
|
||||
tenant_id=tenant_id, identity_id=config_id, cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL
|
||||
)
|
||||
|
||||
provider_model_credentials_cache.delete()
|
||||
|
@ -73,8 +73,8 @@ class ModelProviderService:
|
||||
system_configuration=SystemConfigurationResponse(
|
||||
enabled=provider_configuration.system_configuration.enabled,
|
||||
current_quota_type=provider_configuration.system_configuration.current_quota_type,
|
||||
quota_configurations=provider_configuration.system_configuration.quota_configurations
|
||||
)
|
||||
quota_configurations=provider_configuration.system_configuration.quota_configurations,
|
||||
),
|
||||
)
|
||||
|
||||
provider_responses.append(provider_response)
|
||||
@ -95,9 +95,9 @@ class ModelProviderService:
|
||||
provider_configurations = self.provider_manager.get_configurations(tenant_id)
|
||||
|
||||
# Get provider available models
|
||||
return [ModelWithProviderEntityResponse(model) for model in provider_configurations.get_models(
|
||||
provider=provider
|
||||
)]
|
||||
return [
|
||||
ModelWithProviderEntityResponse(model) for model in provider_configurations.get_models(provider=provider)
|
||||
]
|
||||
|
||||
def get_provider_credentials(self, tenant_id: str, provider: str) -> dict:
|
||||
"""
|
||||
@ -195,13 +195,12 @@ class ModelProviderService:
|
||||
|
||||
# Get model custom credentials from ProviderModel if exists
|
||||
return provider_configuration.get_custom_model_credentials(
|
||||
model_type=ModelType.value_of(model_type),
|
||||
model=model,
|
||||
obfuscated=True
|
||||
model_type=ModelType.value_of(model_type), model=model, obfuscated=True
|
||||
)
|
||||
|
||||
def model_credentials_validate(self, tenant_id: str, provider: str, model_type: str, model: str,
|
||||
credentials: dict) -> None:
|
||||
def model_credentials_validate(
|
||||
self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict
|
||||
) -> None:
|
||||
"""
|
||||
validate model credentials.
|
||||
|
||||
@ -222,13 +221,12 @@ class ModelProviderService:
|
||||
|
||||
# Validate model credentials
|
||||
provider_configuration.custom_model_credentials_validate(
|
||||
model_type=ModelType.value_of(model_type),
|
||||
model=model,
|
||||
credentials=credentials
|
||||
model_type=ModelType.value_of(model_type), model=model, credentials=credentials
|
||||
)
|
||||
|
||||
def save_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str,
|
||||
credentials: dict) -> None:
|
||||
def save_model_credentials(
|
||||
self, tenant_id: str, provider: str, model_type: str, model: str, credentials: dict
|
||||
) -> None:
|
||||
"""
|
||||
save model credentials.
|
||||
|
||||
@ -249,9 +247,7 @@ class ModelProviderService:
|
||||
|
||||
# Add or update custom model credentials
|
||||
provider_configuration.add_or_update_custom_model_credentials(
|
||||
model_type=ModelType.value_of(model_type),
|
||||
model=model,
|
||||
credentials=credentials
|
||||
model_type=ModelType.value_of(model_type), model=model, credentials=credentials
|
||||
)
|
||||
|
||||
def remove_model_credentials(self, tenant_id: str, provider: str, model_type: str, model: str) -> None:
|
||||
@ -273,10 +269,7 @@ class ModelProviderService:
|
||||
raise ValueError(f"Provider {provider} does not exist.")
|
||||
|
||||
# Remove custom model credentials
|
||||
provider_configuration.delete_custom_model_credentials(
|
||||
model_type=ModelType.value_of(model_type),
|
||||
model=model
|
||||
)
|
||||
provider_configuration.delete_custom_model_credentials(model_type=ModelType.value_of(model_type), model=model)
|
||||
|
||||
def get_models_by_model_type(self, tenant_id: str, model_type: str) -> list[ProviderWithModelsResponse]:
|
||||
"""
|
||||
@ -290,9 +283,7 @@ class ModelProviderService:
|
||||
provider_configurations = self.provider_manager.get_configurations(tenant_id)
|
||||
|
||||
# Get provider available models
|
||||
models = provider_configurations.get_models(
|
||||
model_type=ModelType.value_of(model_type)
|
||||
)
|
||||
models = provider_configurations.get_models(model_type=ModelType.value_of(model_type))
|
||||
|
||||
# Group models by provider
|
||||
provider_models = {}
|
||||
@ -323,16 +314,19 @@ class ModelProviderService:
|
||||
icon_small=first_model.provider.icon_small,
|
||||
icon_large=first_model.provider.icon_large,
|
||||
status=CustomConfigurationStatus.ACTIVE,
|
||||
models=[ProviderModelWithStatusEntity(
|
||||
model=model.model,
|
||||
label=model.label,
|
||||
model_type=model.model_type,
|
||||
features=model.features,
|
||||
fetch_from=model.fetch_from,
|
||||
model_properties=model.model_properties,
|
||||
status=model.status,
|
||||
load_balancing_enabled=model.load_balancing_enabled
|
||||
) for model in models]
|
||||
models=[
|
||||
ProviderModelWithStatusEntity(
|
||||
model=model.model,
|
||||
label=model.label,
|
||||
model_type=model.model_type,
|
||||
features=model.features,
|
||||
fetch_from=model.fetch_from,
|
||||
model_properties=model.model_properties,
|
||||
status=model.status,
|
||||
load_balancing_enabled=model.load_balancing_enabled,
|
||||
)
|
||||
for model in models
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
@ -361,19 +355,13 @@ class ModelProviderService:
|
||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||
|
||||
# fetch credentials
|
||||
credentials = provider_configuration.get_current_credentials(
|
||||
model_type=ModelType.LLM,
|
||||
model=model
|
||||
)
|
||||
credentials = provider_configuration.get_current_credentials(model_type=ModelType.LLM, model=model)
|
||||
|
||||
if not credentials:
|
||||
return []
|
||||
|
||||
# Call get_parameter_rules method of model instance to get model parameter rules
|
||||
return model_type_instance.get_parameter_rules(
|
||||
model=model,
|
||||
credentials=credentials
|
||||
)
|
||||
return model_type_instance.get_parameter_rules(model=model, credentials=credentials)
|
||||
|
||||
def get_default_model_of_model_type(self, tenant_id: str, model_type: str) -> Optional[DefaultModelResponse]:
|
||||
"""
|
||||
@ -384,22 +372,23 @@ class ModelProviderService:
|
||||
:return:
|
||||
"""
|
||||
model_type_enum = ModelType.value_of(model_type)
|
||||
result = self.provider_manager.get_default_model(
|
||||
tenant_id=tenant_id,
|
||||
model_type=model_type_enum
|
||||
)
|
||||
result = self.provider_manager.get_default_model(tenant_id=tenant_id, model_type=model_type_enum)
|
||||
try:
|
||||
return DefaultModelResponse(
|
||||
model=result.model,
|
||||
model_type=result.model_type,
|
||||
provider=SimpleProviderEntityResponse(
|
||||
provider=result.provider.provider,
|
||||
label=result.provider.label,
|
||||
icon_small=result.provider.icon_small,
|
||||
icon_large=result.provider.icon_large,
|
||||
supported_model_types=result.provider.supported_model_types
|
||||
return (
|
||||
DefaultModelResponse(
|
||||
model=result.model,
|
||||
model_type=result.model_type,
|
||||
provider=SimpleProviderEntityResponse(
|
||||
provider=result.provider.provider,
|
||||
label=result.provider.label,
|
||||
icon_small=result.provider.icon_small,
|
||||
icon_large=result.provider.icon_large,
|
||||
supported_model_types=result.provider.supported_model_types,
|
||||
),
|
||||
)
|
||||
) if result else None
|
||||
if result
|
||||
else None
|
||||
)
|
||||
except Exception as e:
|
||||
logger.info(f"get_default_model_of_model_type error: {e}")
|
||||
return None
|
||||
@ -416,13 +405,12 @@ class ModelProviderService:
|
||||
"""
|
||||
model_type_enum = ModelType.value_of(model_type)
|
||||
self.provider_manager.update_default_model_record(
|
||||
tenant_id=tenant_id,
|
||||
model_type=model_type_enum,
|
||||
provider=provider,
|
||||
model=model
|
||||
tenant_id=tenant_id, model_type=model_type_enum, provider=provider, model=model
|
||||
)
|
||||
|
||||
def get_model_provider_icon(self, provider: str, icon_type: str, lang: str) -> tuple[Optional[bytes], Optional[str]]:
|
||||
def get_model_provider_icon(
|
||||
self, provider: str, icon_type: str, lang: str
|
||||
) -> tuple[Optional[bytes], Optional[str]]:
|
||||
"""
|
||||
get model provider icon.
|
||||
|
||||
@ -434,11 +422,11 @@ class ModelProviderService:
|
||||
provider_instance = model_provider_factory.get_provider_instance(provider)
|
||||
provider_schema = provider_instance.get_provider_schema()
|
||||
|
||||
if icon_type.lower() == 'icon_small':
|
||||
if icon_type.lower() == "icon_small":
|
||||
if not provider_schema.icon_small:
|
||||
raise ValueError(f"Provider {provider} does not have small icon.")
|
||||
|
||||
if lang.lower() == 'zh_hans':
|
||||
if lang.lower() == "zh_hans":
|
||||
file_name = provider_schema.icon_small.zh_Hans
|
||||
else:
|
||||
file_name = provider_schema.icon_small.en_US
|
||||
@ -446,13 +434,15 @@ class ModelProviderService:
|
||||
if not provider_schema.icon_large:
|
||||
raise ValueError(f"Provider {provider} does not have large icon.")
|
||||
|
||||
if lang.lower() == 'zh_hans':
|
||||
if lang.lower() == "zh_hans":
|
||||
file_name = provider_schema.icon_large.zh_Hans
|
||||
else:
|
||||
file_name = provider_schema.icon_large.en_US
|
||||
|
||||
root_path = current_app.root_path
|
||||
provider_instance_path = os.path.dirname(os.path.join(root_path, provider_instance.__class__.__module__.replace('.', '/')))
|
||||
provider_instance_path = os.path.dirname(
|
||||
os.path.join(root_path, provider_instance.__class__.__module__.replace(".", "/"))
|
||||
)
|
||||
file_path = os.path.join(provider_instance_path, "_assets")
|
||||
file_path = os.path.join(file_path, file_name)
|
||||
|
||||
@ -460,10 +450,10 @@ class ModelProviderService:
|
||||
return None, None
|
||||
|
||||
mimetype, _ = mimetypes.guess_type(file_path)
|
||||
mimetype = mimetype or 'application/octet-stream'
|
||||
mimetype = mimetype or "application/octet-stream"
|
||||
|
||||
# read binary from file
|
||||
with open(file_path, 'rb') as f:
|
||||
with open(file_path, "rb") as f:
|
||||
byte_data = f.read()
|
||||
return byte_data, mimetype
|
||||
|
||||
@ -509,10 +499,7 @@ class ModelProviderService:
|
||||
raise ValueError(f"Provider {provider} does not exist.")
|
||||
|
||||
# Enable model
|
||||
provider_configuration.enable_model(
|
||||
model=model,
|
||||
model_type=ModelType.value_of(model_type)
|
||||
)
|
||||
provider_configuration.enable_model(model=model, model_type=ModelType.value_of(model_type))
|
||||
|
||||
def disable_model(self, tenant_id: str, provider: str, model: str, model_type: str) -> None:
|
||||
"""
|
||||
@ -533,78 +520,49 @@ class ModelProviderService:
|
||||
raise ValueError(f"Provider {provider} does not exist.")
|
||||
|
||||
# Enable model
|
||||
provider_configuration.disable_model(
|
||||
model=model,
|
||||
model_type=ModelType.value_of(model_type)
|
||||
)
|
||||
provider_configuration.disable_model(model=model, model_type=ModelType.value_of(model_type))
|
||||
|
||||
def free_quota_submit(self, tenant_id: str, provider: str):
|
||||
api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY")
|
||||
api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL")
|
||||
api_url = api_base_url + '/api/v1/providers/apply'
|
||||
api_url = api_base_url + "/api/v1/providers/apply"
|
||||
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': f"Bearer {api_key}"
|
||||
}
|
||||
response = requests.post(api_url, headers=headers, json={'workspace_id': tenant_id, 'provider_name': provider})
|
||||
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
|
||||
response = requests.post(api_url, headers=headers, json={"workspace_id": tenant_id, "provider_name": provider})
|
||||
if not response.ok:
|
||||
logger.error(f"Request FREE QUOTA APPLY SERVER Error: {response.status_code} ")
|
||||
raise ValueError(f"Error: {response.status_code} ")
|
||||
|
||||
if response.json()["code"] != 'success':
|
||||
raise ValueError(
|
||||
f"error: {response.json()['message']}"
|
||||
)
|
||||
if response.json()["code"] != "success":
|
||||
raise ValueError(f"error: {response.json()['message']}")
|
||||
|
||||
rst = response.json()
|
||||
|
||||
if rst['type'] == 'redirect':
|
||||
return {
|
||||
'type': rst['type'],
|
||||
'redirect_url': rst['redirect_url']
|
||||
}
|
||||
if rst["type"] == "redirect":
|
||||
return {"type": rst["type"], "redirect_url": rst["redirect_url"]}
|
||||
else:
|
||||
return {
|
||||
'type': rst['type'],
|
||||
'result': 'success'
|
||||
}
|
||||
return {"type": rst["type"], "result": "success"}
|
||||
|
||||
def free_quota_qualification_verify(self, tenant_id: str, provider: str, token: Optional[str]):
|
||||
api_key = os.environ.get("FREE_QUOTA_APPLY_API_KEY")
|
||||
api_base_url = os.environ.get("FREE_QUOTA_APPLY_BASE_URL")
|
||||
api_url = api_base_url + '/api/v1/providers/qualification-verify'
|
||||
api_url = api_base_url + "/api/v1/providers/qualification-verify"
|
||||
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': f"Bearer {api_key}"
|
||||
}
|
||||
json_data = {'workspace_id': tenant_id, 'provider_name': provider}
|
||||
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
|
||||
json_data = {"workspace_id": tenant_id, "provider_name": provider}
|
||||
if token:
|
||||
json_data['token'] = token
|
||||
response = requests.post(api_url, headers=headers,
|
||||
json=json_data)
|
||||
json_data["token"] = token
|
||||
response = requests.post(api_url, headers=headers, json=json_data)
|
||||
if not response.ok:
|
||||
logger.error(f"Request FREE QUOTA APPLY SERVER Error: {response.status_code} ")
|
||||
raise ValueError(f"Error: {response.status_code} ")
|
||||
|
||||
rst = response.json()
|
||||
if rst["code"] != 'success':
|
||||
raise ValueError(
|
||||
f"error: {rst['message']}"
|
||||
)
|
||||
if rst["code"] != "success":
|
||||
raise ValueError(f"error: {rst['message']}")
|
||||
|
||||
data = rst['data']
|
||||
if data['qualified'] is True:
|
||||
return {
|
||||
'result': 'success',
|
||||
'provider_name': provider,
|
||||
'flag': True
|
||||
}
|
||||
data = rst["data"]
|
||||
if data["qualified"] is True:
|
||||
return {"result": "success", "provider_name": provider, "flag": True}
|
||||
else:
|
||||
return {
|
||||
'result': 'success',
|
||||
'provider_name': provider,
|
||||
'flag': False,
|
||||
'reason': data['reason']
|
||||
}
|
||||
return {"result": "success", "provider_name": provider, "flag": False, "reason": data["reason"]}
|
||||
|
@ -4,17 +4,18 @@ from models.model import App, AppModelConfig
|
||||
|
||||
|
||||
class ModerationService:
|
||||
|
||||
def moderation_for_outputs(self, app_id: str, app_model: App, text: str) -> ModerationOutputsResult:
|
||||
app_model_config: AppModelConfig = None
|
||||
|
||||
app_model_config = db.session.query(AppModelConfig).filter(AppModelConfig.id == app_model.app_model_config_id).first()
|
||||
app_model_config = (
|
||||
db.session.query(AppModelConfig).filter(AppModelConfig.id == app_model.app_model_config_id).first()
|
||||
)
|
||||
|
||||
if not app_model_config:
|
||||
raise ValueError("app model config not found")
|
||||
|
||||
name = app_model_config.sensitive_word_avoidance_dict['type']
|
||||
config = app_model_config.sensitive_word_avoidance_dict['config']
|
||||
name = app_model_config.sensitive_word_avoidance_dict["type"]
|
||||
config = app_model_config.sensitive_word_avoidance_dict["config"]
|
||||
|
||||
moderation = ModerationFactory(name, app_id, app_model.tenant_id, config)
|
||||
return moderation.moderation_for_outputs(text)
|
||||
|
@ -4,15 +4,12 @@ import requests
|
||||
|
||||
|
||||
class OperationService:
|
||||
base_url = os.environ.get('BILLING_API_URL', 'BILLING_API_URL')
|
||||
secret_key = os.environ.get('BILLING_API_SECRET_KEY', 'BILLING_API_SECRET_KEY')
|
||||
base_url = os.environ.get("BILLING_API_URL", "BILLING_API_URL")
|
||||
secret_key = os.environ.get("BILLING_API_SECRET_KEY", "BILLING_API_SECRET_KEY")
|
||||
|
||||
@classmethod
|
||||
def _send_request(cls, method, endpoint, json=None, params=None):
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Billing-Api-Secret-Key": cls.secret_key
|
||||
}
|
||||
headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key}
|
||||
|
||||
url = f"{cls.base_url}{endpoint}"
|
||||
response = requests.request(method, url, json=json, params=params, headers=headers)
|
||||
@ -22,11 +19,11 @@ class OperationService:
|
||||
@classmethod
|
||||
def record_utm(cls, tenant_id: str, utm_info: dict):
|
||||
params = {
|
||||
'tenant_id': tenant_id,
|
||||
'utm_source': utm_info.get('utm_source', ''),
|
||||
'utm_medium': utm_info.get('utm_medium', ''),
|
||||
'utm_campaign': utm_info.get('utm_campaign', ''),
|
||||
'utm_content': utm_info.get('utm_content', ''),
|
||||
'utm_term': utm_info.get('utm_term', '')
|
||||
"tenant_id": tenant_id,
|
||||
"utm_source": utm_info.get("utm_source", ""),
|
||||
"utm_medium": utm_info.get("utm_medium", ""),
|
||||
"utm_campaign": utm_info.get("utm_campaign", ""),
|
||||
"utm_content": utm_info.get("utm_content", ""),
|
||||
"utm_term": utm_info.get("utm_term", ""),
|
||||
}
|
||||
return cls._send_request('POST', '/tenant_utms', params=params)
|
||||
return cls._send_request("POST", "/tenant_utms", params=params)
|
||||
|
@ -12,19 +12,25 @@ class OpsService:
|
||||
:param tracing_provider: tracing provider
|
||||
:return:
|
||||
"""
|
||||
trace_config_data: TraceAppConfig = db.session.query(TraceAppConfig).filter(
|
||||
TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider
|
||||
).first()
|
||||
trace_config_data: TraceAppConfig = (
|
||||
db.session.query(TraceAppConfig)
|
||||
.filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not trace_config_data:
|
||||
return None
|
||||
|
||||
# decrypt_token and obfuscated_token
|
||||
tenant_id = db.session.query(App).filter(App.id == app_id).first().tenant_id
|
||||
decrypt_tracing_config = OpsTraceManager.decrypt_tracing_config(tenant_id, tracing_provider, trace_config_data.tracing_config)
|
||||
if tracing_provider == 'langfuse' and ('project_key' not in decrypt_tracing_config or not decrypt_tracing_config.get('project_key')):
|
||||
decrypt_tracing_config = OpsTraceManager.decrypt_tracing_config(
|
||||
tenant_id, tracing_provider, trace_config_data.tracing_config
|
||||
)
|
||||
if tracing_provider == "langfuse" and (
|
||||
"project_key" not in decrypt_tracing_config or not decrypt_tracing_config.get("project_key")
|
||||
):
|
||||
project_key = OpsTraceManager.get_trace_config_project_key(decrypt_tracing_config, tracing_provider)
|
||||
decrypt_tracing_config['project_key'] = project_key
|
||||
decrypt_tracing_config["project_key"] = project_key
|
||||
|
||||
decrypt_tracing_config = OpsTraceManager.obfuscated_decrypt_token(tracing_provider, decrypt_tracing_config)
|
||||
|
||||
@ -44,8 +50,10 @@ class OpsService:
|
||||
if tracing_provider not in provider_config_map.keys() and tracing_provider:
|
||||
return {"error": f"Invalid tracing provider: {tracing_provider}"}
|
||||
|
||||
config_class, other_keys = provider_config_map[tracing_provider]['config_class'], \
|
||||
provider_config_map[tracing_provider]['other_keys']
|
||||
config_class, other_keys = (
|
||||
provider_config_map[tracing_provider]["config_class"],
|
||||
provider_config_map[tracing_provider]["other_keys"],
|
||||
)
|
||||
default_config_instance = config_class(**tracing_config)
|
||||
for key in other_keys:
|
||||
if key in tracing_config and tracing_config[key] == "":
|
||||
@ -59,9 +67,11 @@ class OpsService:
|
||||
project_key = OpsTraceManager.get_trace_config_project_key(tracing_config, tracing_provider)
|
||||
|
||||
# check if trace config already exists
|
||||
trace_config_data: TraceAppConfig = db.session.query(TraceAppConfig).filter(
|
||||
TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider
|
||||
).first()
|
||||
trace_config_data: TraceAppConfig = (
|
||||
db.session.query(TraceAppConfig)
|
||||
.filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
|
||||
.first()
|
||||
)
|
||||
|
||||
if trace_config_data:
|
||||
return None
|
||||
@ -69,8 +79,8 @@ class OpsService:
|
||||
# get tenant id
|
||||
tenant_id = db.session.query(App).filter(App.id == app_id).first().tenant_id
|
||||
tracing_config = OpsTraceManager.encrypt_tracing_config(tenant_id, tracing_provider, tracing_config)
|
||||
if tracing_provider == 'langfuse':
|
||||
tracing_config['project_key'] = project_key
|
||||
if tracing_provider == "langfuse":
|
||||
tracing_config["project_key"] = project_key
|
||||
trace_config_data = TraceAppConfig(
|
||||
app_id=app_id,
|
||||
tracing_provider=tracing_provider,
|
||||
@ -94,9 +104,11 @@ class OpsService:
|
||||
raise ValueError(f"Invalid tracing provider: {tracing_provider}")
|
||||
|
||||
# check if trace config already exists
|
||||
current_trace_config = db.session.query(TraceAppConfig).filter(
|
||||
TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider
|
||||
).first()
|
||||
current_trace_config = (
|
||||
db.session.query(TraceAppConfig)
|
||||
.filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not current_trace_config:
|
||||
return None
|
||||
@ -126,9 +138,11 @@ class OpsService:
|
||||
:param tracing_provider: tracing provider
|
||||
:return:
|
||||
"""
|
||||
trace_config = db.session.query(TraceAppConfig).filter(
|
||||
TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider
|
||||
).first()
|
||||
trace_config = (
|
||||
db.session.query(TraceAppConfig)
|
||||
.filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not trace_config:
|
||||
return None
|
||||
|
@ -16,7 +16,6 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RecommendedAppService:
|
||||
|
||||
builtin_data: Optional[dict] = None
|
||||
|
||||
@classmethod
|
||||
@ -27,21 +26,21 @@ class RecommendedAppService:
|
||||
:return:
|
||||
"""
|
||||
mode = dify_config.HOSTED_FETCH_APP_TEMPLATES_MODE
|
||||
if mode == 'remote':
|
||||
if mode == "remote":
|
||||
try:
|
||||
result = cls._fetch_recommended_apps_from_dify_official(language)
|
||||
except Exception as e:
|
||||
logger.warning(f'fetch recommended apps from dify official failed: {e}, switch to built-in.')
|
||||
logger.warning(f"fetch recommended apps from dify official failed: {e}, switch to built-in.")
|
||||
result = cls._fetch_recommended_apps_from_builtin(language)
|
||||
elif mode == 'db':
|
||||
elif mode == "db":
|
||||
result = cls._fetch_recommended_apps_from_db(language)
|
||||
elif mode == 'builtin':
|
||||
elif mode == "builtin":
|
||||
result = cls._fetch_recommended_apps_from_builtin(language)
|
||||
else:
|
||||
raise ValueError(f'invalid fetch recommended apps mode: {mode}')
|
||||
raise ValueError(f"invalid fetch recommended apps mode: {mode}")
|
||||
|
||||
if not result.get('recommended_apps') and language != 'en-US':
|
||||
result = cls._fetch_recommended_apps_from_builtin('en-US')
|
||||
if not result.get("recommended_apps") and language != "en-US":
|
||||
result = cls._fetch_recommended_apps_from_builtin("en-US")
|
||||
|
||||
return result
|
||||
|
||||
@ -52,16 +51,18 @@ class RecommendedAppService:
|
||||
:param language: language
|
||||
:return:
|
||||
"""
|
||||
recommended_apps = db.session.query(RecommendedApp).filter(
|
||||
RecommendedApp.is_listed == True,
|
||||
RecommendedApp.language == language
|
||||
).all()
|
||||
recommended_apps = (
|
||||
db.session.query(RecommendedApp)
|
||||
.filter(RecommendedApp.is_listed == True, RecommendedApp.language == language)
|
||||
.all()
|
||||
)
|
||||
|
||||
if len(recommended_apps) == 0:
|
||||
recommended_apps = db.session.query(RecommendedApp).filter(
|
||||
RecommendedApp.is_listed == True,
|
||||
RecommendedApp.language == languages[0]
|
||||
).all()
|
||||
recommended_apps = (
|
||||
db.session.query(RecommendedApp)
|
||||
.filter(RecommendedApp.is_listed == True, RecommendedApp.language == languages[0])
|
||||
.all()
|
||||
)
|
||||
|
||||
categories = set()
|
||||
recommended_apps_result = []
|
||||
@ -75,28 +76,28 @@ class RecommendedAppService:
|
||||
continue
|
||||
|
||||
recommended_app_result = {
|
||||
'id': recommended_app.id,
|
||||
'app': {
|
||||
'id': app.id,
|
||||
'name': app.name,
|
||||
'mode': app.mode,
|
||||
'icon': app.icon,
|
||||
'icon_background': app.icon_background
|
||||
"id": recommended_app.id,
|
||||
"app": {
|
||||
"id": app.id,
|
||||
"name": app.name,
|
||||
"mode": app.mode,
|
||||
"icon": app.icon,
|
||||
"icon_background": app.icon_background,
|
||||
},
|
||||
'app_id': recommended_app.app_id,
|
||||
'description': site.description,
|
||||
'copyright': site.copyright,
|
||||
'privacy_policy': site.privacy_policy,
|
||||
'custom_disclaimer': site.custom_disclaimer,
|
||||
'category': recommended_app.category,
|
||||
'position': recommended_app.position,
|
||||
'is_listed': recommended_app.is_listed
|
||||
"app_id": recommended_app.app_id,
|
||||
"description": site.description,
|
||||
"copyright": site.copyright,
|
||||
"privacy_policy": site.privacy_policy,
|
||||
"custom_disclaimer": site.custom_disclaimer,
|
||||
"category": recommended_app.category,
|
||||
"position": recommended_app.position,
|
||||
"is_listed": recommended_app.is_listed,
|
||||
}
|
||||
recommended_apps_result.append(recommended_app_result)
|
||||
|
||||
categories.add(recommended_app.category) # add category to categories
|
||||
|
||||
return {'recommended_apps': recommended_apps_result, 'categories': sorted(categories)}
|
||||
return {"recommended_apps": recommended_apps_result, "categories": sorted(categories)}
|
||||
|
||||
@classmethod
|
||||
def _fetch_recommended_apps_from_dify_official(cls, language: str) -> dict:
|
||||
@ -106,16 +107,16 @@ class RecommendedAppService:
|
||||
:return:
|
||||
"""
|
||||
domain = dify_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN
|
||||
url = f'{domain}/apps?language={language}'
|
||||
url = f"{domain}/apps?language={language}"
|
||||
response = requests.get(url, timeout=(3, 10))
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f'fetch recommended apps failed, status code: {response.status_code}')
|
||||
raise ValueError(f"fetch recommended apps failed, status code: {response.status_code}")
|
||||
|
||||
result = response.json()
|
||||
|
||||
if "categories" in result:
|
||||
result["categories"] = sorted(result["categories"])
|
||||
|
||||
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
@ -126,7 +127,7 @@ class RecommendedAppService:
|
||||
:return:
|
||||
"""
|
||||
builtin_data = cls._get_builtin_data()
|
||||
return builtin_data.get('recommended_apps', {}).get(language)
|
||||
return builtin_data.get("recommended_apps", {}).get(language)
|
||||
|
||||
@classmethod
|
||||
def get_recommend_app_detail(cls, app_id: str) -> Optional[dict]:
|
||||
@ -136,18 +137,18 @@ class RecommendedAppService:
|
||||
:return:
|
||||
"""
|
||||
mode = dify_config.HOSTED_FETCH_APP_TEMPLATES_MODE
|
||||
if mode == 'remote':
|
||||
if mode == "remote":
|
||||
try:
|
||||
result = cls._fetch_recommended_app_detail_from_dify_official(app_id)
|
||||
except Exception as e:
|
||||
logger.warning(f'fetch recommended app detail from dify official failed: {e}, switch to built-in.')
|
||||
logger.warning(f"fetch recommended app detail from dify official failed: {e}, switch to built-in.")
|
||||
result = cls._fetch_recommended_app_detail_from_builtin(app_id)
|
||||
elif mode == 'db':
|
||||
elif mode == "db":
|
||||
result = cls._fetch_recommended_app_detail_from_db(app_id)
|
||||
elif mode == 'builtin':
|
||||
elif mode == "builtin":
|
||||
result = cls._fetch_recommended_app_detail_from_builtin(app_id)
|
||||
else:
|
||||
raise ValueError(f'invalid fetch recommended app detail mode: {mode}')
|
||||
raise ValueError(f"invalid fetch recommended app detail mode: {mode}")
|
||||
|
||||
return result
|
||||
|
||||
@ -159,7 +160,7 @@ class RecommendedAppService:
|
||||
:return:
|
||||
"""
|
||||
domain = dify_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN
|
||||
url = f'{domain}/apps/{app_id}'
|
||||
url = f"{domain}/apps/{app_id}"
|
||||
response = requests.get(url, timeout=(3, 10))
|
||||
if response.status_code != 200:
|
||||
return None
|
||||
@ -174,10 +175,11 @@ class RecommendedAppService:
|
||||
:return:
|
||||
"""
|
||||
# is in public recommended list
|
||||
recommended_app = db.session.query(RecommendedApp).filter(
|
||||
RecommendedApp.is_listed == True,
|
||||
RecommendedApp.app_id == app_id
|
||||
).first()
|
||||
recommended_app = (
|
||||
db.session.query(RecommendedApp)
|
||||
.filter(RecommendedApp.is_listed == True, RecommendedApp.app_id == app_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not recommended_app:
|
||||
return None
|
||||
@ -188,12 +190,12 @@ class RecommendedAppService:
|
||||
return None
|
||||
|
||||
return {
|
||||
'id': app_model.id,
|
||||
'name': app_model.name,
|
||||
'icon': app_model.icon,
|
||||
'icon_background': app_model.icon_background,
|
||||
'mode': app_model.mode,
|
||||
'export_data': AppDslService.export_dsl(app_model=app_model)
|
||||
"id": app_model.id,
|
||||
"name": app_model.name,
|
||||
"icon": app_model.icon,
|
||||
"icon_background": app_model.icon_background,
|
||||
"mode": app_model.mode,
|
||||
"export_data": AppDslService.export_dsl(app_model=app_model),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@ -204,7 +206,7 @@ class RecommendedAppService:
|
||||
:return:
|
||||
"""
|
||||
builtin_data = cls._get_builtin_data()
|
||||
return builtin_data.get('app_details', {}).get(app_id)
|
||||
return builtin_data.get("app_details", {}).get(app_id)
|
||||
|
||||
@classmethod
|
||||
def _get_builtin_data(cls) -> dict:
|
||||
@ -216,7 +218,7 @@ class RecommendedAppService:
|
||||
return cls.builtin_data
|
||||
|
||||
root_path = current_app.root_path
|
||||
with open(path.join(root_path, 'constants', 'recommended_apps.json'), encoding='utf-8') as f:
|
||||
with open(path.join(root_path, "constants", "recommended_apps.json"), encoding="utf-8") as f:
|
||||
json_data = f.read()
|
||||
data = json.loads(json_data)
|
||||
cls.builtin_data = data
|
||||
@ -229,27 +231,24 @@ class RecommendedAppService:
|
||||
Fetch all recommended apps and export datas
|
||||
:return:
|
||||
"""
|
||||
templates = {
|
||||
"recommended_apps": {},
|
||||
"app_details": {}
|
||||
}
|
||||
templates = {"recommended_apps": {}, "app_details": {}}
|
||||
for language in languages:
|
||||
try:
|
||||
result = cls._fetch_recommended_apps_from_dify_official(language)
|
||||
except Exception as e:
|
||||
logger.warning(f'fetch recommended apps from dify official failed: {e}, skip.')
|
||||
logger.warning(f"fetch recommended apps from dify official failed: {e}, skip.")
|
||||
continue
|
||||
|
||||
templates['recommended_apps'][language] = result
|
||||
templates["recommended_apps"][language] = result
|
||||
|
||||
for recommended_app in result.get('recommended_apps'):
|
||||
app_id = recommended_app.get('app_id')
|
||||
for recommended_app in result.get("recommended_apps"):
|
||||
app_id = recommended_app.get("app_id")
|
||||
|
||||
# get app detail
|
||||
app_detail = cls._fetch_recommended_app_detail_from_dify_official(app_id)
|
||||
if not app_detail:
|
||||
continue
|
||||
|
||||
templates['app_details'][app_id] = app_detail
|
||||
templates["app_details"][app_id] = app_detail
|
||||
|
||||
return templates
|
||||
|
@ -10,46 +10,48 @@ from services.message_service import MessageService
|
||||
|
||||
class SavedMessageService:
|
||||
@classmethod
|
||||
def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account, EndUser]],
|
||||
last_id: Optional[str], limit: int) -> InfiniteScrollPagination:
|
||||
saved_messages = db.session.query(SavedMessage).filter(
|
||||
SavedMessage.app_id == app_model.id,
|
||||
SavedMessage.created_by_role == ('account' if isinstance(user, Account) else 'end_user'),
|
||||
SavedMessage.created_by == user.id
|
||||
).order_by(SavedMessage.created_at.desc()).all()
|
||||
def pagination_by_last_id(
|
||||
cls, app_model: App, user: Optional[Union[Account, EndUser]], last_id: Optional[str], limit: int
|
||||
) -> InfiniteScrollPagination:
|
||||
saved_messages = (
|
||||
db.session.query(SavedMessage)
|
||||
.filter(
|
||||
SavedMessage.app_id == app_model.id,
|
||||
SavedMessage.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
|
||||
SavedMessage.created_by == user.id,
|
||||
)
|
||||
.order_by(SavedMessage.created_at.desc())
|
||||
.all()
|
||||
)
|
||||
message_ids = [sm.message_id for sm in saved_messages]
|
||||
|
||||
return MessageService.pagination_by_last_id(
|
||||
app_model=app_model,
|
||||
user=user,
|
||||
last_id=last_id,
|
||||
limit=limit,
|
||||
include_ids=message_ids
|
||||
app_model=app_model, user=user, last_id=last_id, limit=limit, include_ids=message_ids
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def save(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str):
|
||||
saved_message = db.session.query(SavedMessage).filter(
|
||||
SavedMessage.app_id == app_model.id,
|
||||
SavedMessage.message_id == message_id,
|
||||
SavedMessage.created_by_role == ('account' if isinstance(user, Account) else 'end_user'),
|
||||
SavedMessage.created_by == user.id
|
||||
).first()
|
||||
saved_message = (
|
||||
db.session.query(SavedMessage)
|
||||
.filter(
|
||||
SavedMessage.app_id == app_model.id,
|
||||
SavedMessage.message_id == message_id,
|
||||
SavedMessage.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
|
||||
SavedMessage.created_by == user.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if saved_message:
|
||||
return
|
||||
|
||||
message = MessageService.get_message(
|
||||
app_model=app_model,
|
||||
user=user,
|
||||
message_id=message_id
|
||||
)
|
||||
message = MessageService.get_message(app_model=app_model, user=user, message_id=message_id)
|
||||
|
||||
saved_message = SavedMessage(
|
||||
app_id=app_model.id,
|
||||
message_id=message.id,
|
||||
created_by_role='account' if isinstance(user, Account) else 'end_user',
|
||||
created_by=user.id
|
||||
created_by_role="account" if isinstance(user, Account) else "end_user",
|
||||
created_by=user.id,
|
||||
)
|
||||
|
||||
db.session.add(saved_message)
|
||||
@ -57,12 +59,16 @@ class SavedMessageService:
|
||||
|
||||
@classmethod
|
||||
def delete(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str):
|
||||
saved_message = db.session.query(SavedMessage).filter(
|
||||
SavedMessage.app_id == app_model.id,
|
||||
SavedMessage.message_id == message_id,
|
||||
SavedMessage.created_by_role == ('account' if isinstance(user, Account) else 'end_user'),
|
||||
SavedMessage.created_by == user.id
|
||||
).first()
|
||||
saved_message = (
|
||||
db.session.query(SavedMessage)
|
||||
.filter(
|
||||
SavedMessage.app_id == app_model.id,
|
||||
SavedMessage.message_id == message_id,
|
||||
SavedMessage.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
|
||||
SavedMessage.created_by == user.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not saved_message:
|
||||
return
|
||||
|
@ -12,38 +12,32 @@ from models.model import App, Tag, TagBinding
|
||||
class TagService:
|
||||
@staticmethod
|
||||
def get_tags(tag_type: str, current_tenant_id: str, keyword: str = None) -> list:
|
||||
query = db.session.query(
|
||||
Tag.id, Tag.type, Tag.name, func.count(TagBinding.id).label('binding_count')
|
||||
).outerjoin(
|
||||
TagBinding, Tag.id == TagBinding.tag_id
|
||||
).filter(
|
||||
Tag.type == tag_type,
|
||||
Tag.tenant_id == current_tenant_id
|
||||
query = (
|
||||
db.session.query(Tag.id, Tag.type, Tag.name, func.count(TagBinding.id).label("binding_count"))
|
||||
.outerjoin(TagBinding, Tag.id == TagBinding.tag_id)
|
||||
.filter(Tag.type == tag_type, Tag.tenant_id == current_tenant_id)
|
||||
)
|
||||
if keyword:
|
||||
query = query.filter(db.and_(Tag.name.ilike(f'%{keyword}%')))
|
||||
query = query.group_by(
|
||||
Tag.id
|
||||
)
|
||||
query = query.filter(db.and_(Tag.name.ilike(f"%{keyword}%")))
|
||||
query = query.group_by(Tag.id)
|
||||
results = query.order_by(Tag.created_at.desc()).all()
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
def get_target_ids_by_tag_ids(tag_type: str, current_tenant_id: str, tag_ids: list) -> list:
|
||||
tags = db.session.query(Tag).filter(
|
||||
Tag.id.in_(tag_ids),
|
||||
Tag.tenant_id == current_tenant_id,
|
||||
Tag.type == tag_type
|
||||
).all()
|
||||
tags = (
|
||||
db.session.query(Tag)
|
||||
.filter(Tag.id.in_(tag_ids), Tag.tenant_id == current_tenant_id, Tag.type == tag_type)
|
||||
.all()
|
||||
)
|
||||
if not tags:
|
||||
return []
|
||||
tag_ids = [tag.id for tag in tags]
|
||||
tag_bindings = db.session.query(
|
||||
TagBinding.target_id
|
||||
).filter(
|
||||
TagBinding.tag_id.in_(tag_ids),
|
||||
TagBinding.tenant_id == current_tenant_id
|
||||
).all()
|
||||
tag_bindings = (
|
||||
db.session.query(TagBinding.target_id)
|
||||
.filter(TagBinding.tag_id.in_(tag_ids), TagBinding.tenant_id == current_tenant_id)
|
||||
.all()
|
||||
)
|
||||
if not tag_bindings:
|
||||
return []
|
||||
results = [tag_binding.target_id for tag_binding in tag_bindings]
|
||||
@ -51,27 +45,28 @@ class TagService:
|
||||
|
||||
@staticmethod
|
||||
def get_tags_by_target_id(tag_type: str, current_tenant_id: str, target_id: str) -> list:
|
||||
tags = db.session.query(Tag).join(
|
||||
TagBinding,
|
||||
Tag.id == TagBinding.tag_id
|
||||
).filter(
|
||||
TagBinding.target_id == target_id,
|
||||
TagBinding.tenant_id == current_tenant_id,
|
||||
Tag.tenant_id == current_tenant_id,
|
||||
Tag.type == tag_type
|
||||
).all()
|
||||
tags = (
|
||||
db.session.query(Tag)
|
||||
.join(TagBinding, Tag.id == TagBinding.tag_id)
|
||||
.filter(
|
||||
TagBinding.target_id == target_id,
|
||||
TagBinding.tenant_id == current_tenant_id,
|
||||
Tag.tenant_id == current_tenant_id,
|
||||
Tag.type == tag_type,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
return tags if tags else []
|
||||
|
||||
|
||||
@staticmethod
|
||||
def save_tags(args: dict) -> Tag:
|
||||
tag = Tag(
|
||||
id=str(uuid.uuid4()),
|
||||
name=args['name'],
|
||||
type=args['type'],
|
||||
name=args["name"],
|
||||
type=args["type"],
|
||||
created_by=current_user.id,
|
||||
tenant_id=current_user.current_tenant_id
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
)
|
||||
db.session.add(tag)
|
||||
db.session.commit()
|
||||
@ -82,7 +77,7 @@ class TagService:
|
||||
tag = db.session.query(Tag).filter(Tag.id == tag_id).first()
|
||||
if not tag:
|
||||
raise NotFound("Tag not found")
|
||||
tag.name = args['name']
|
||||
tag.name = args["name"]
|
||||
db.session.commit()
|
||||
return tag
|
||||
|
||||
@ -107,20 +102,21 @@ class TagService:
|
||||
@staticmethod
|
||||
def save_tag_binding(args):
|
||||
# check if target exists
|
||||
TagService.check_target_exists(args['type'], args['target_id'])
|
||||
TagService.check_target_exists(args["type"], args["target_id"])
|
||||
# save tag binding
|
||||
for tag_id in args['tag_ids']:
|
||||
tag_binding = db.session.query(TagBinding).filter(
|
||||
TagBinding.tag_id == tag_id,
|
||||
TagBinding.target_id == args['target_id']
|
||||
).first()
|
||||
for tag_id in args["tag_ids"]:
|
||||
tag_binding = (
|
||||
db.session.query(TagBinding)
|
||||
.filter(TagBinding.tag_id == tag_id, TagBinding.target_id == args["target_id"])
|
||||
.first()
|
||||
)
|
||||
if tag_binding:
|
||||
continue
|
||||
new_tag_binding = TagBinding(
|
||||
tag_id=tag_id,
|
||||
target_id=args['target_id'],
|
||||
target_id=args["target_id"],
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
created_by=current_user.id
|
||||
created_by=current_user.id,
|
||||
)
|
||||
db.session.add(new_tag_binding)
|
||||
db.session.commit()
|
||||
@ -128,34 +124,34 @@ class TagService:
|
||||
@staticmethod
|
||||
def delete_tag_binding(args):
|
||||
# check if target exists
|
||||
TagService.check_target_exists(args['type'], args['target_id'])
|
||||
TagService.check_target_exists(args["type"], args["target_id"])
|
||||
# delete tag binding
|
||||
tag_bindings = db.session.query(TagBinding).filter(
|
||||
TagBinding.target_id == args['target_id'],
|
||||
TagBinding.tag_id == (args['tag_id'])
|
||||
).first()
|
||||
tag_bindings = (
|
||||
db.session.query(TagBinding)
|
||||
.filter(TagBinding.target_id == args["target_id"], TagBinding.tag_id == (args["tag_id"]))
|
||||
.first()
|
||||
)
|
||||
if tag_bindings:
|
||||
db.session.delete(tag_bindings)
|
||||
db.session.commit()
|
||||
|
||||
|
||||
|
||||
@staticmethod
|
||||
def check_target_exists(type: str, target_id: str):
|
||||
if type == 'knowledge':
|
||||
dataset = db.session.query(Dataset).filter(
|
||||
Dataset.tenant_id == current_user.current_tenant_id,
|
||||
Dataset.id == target_id
|
||||
).first()
|
||||
if type == "knowledge":
|
||||
dataset = (
|
||||
db.session.query(Dataset)
|
||||
.filter(Dataset.tenant_id == current_user.current_tenant_id, Dataset.id == target_id)
|
||||
.first()
|
||||
)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found")
|
||||
elif type == 'app':
|
||||
app = db.session.query(App).filter(
|
||||
App.tenant_id == current_user.current_tenant_id,
|
||||
App.id == target_id
|
||||
).first()
|
||||
elif type == "app":
|
||||
app = (
|
||||
db.session.query(App)
|
||||
.filter(App.tenant_id == current_user.current_tenant_id, App.id == target_id)
|
||||
.first()
|
||||
)
|
||||
if not app:
|
||||
raise NotFound("App not found")
|
||||
else:
|
||||
raise NotFound("Invalid binding type")
|
||||
|
||||
|
@ -29,111 +29,107 @@ class ApiToolManageService:
|
||||
@staticmethod
|
||||
def parser_api_schema(schema: str) -> list[ApiToolBundle]:
|
||||
"""
|
||||
parse api schema to tool bundle
|
||||
parse api schema to tool bundle
|
||||
"""
|
||||
try:
|
||||
warnings = {}
|
||||
try:
|
||||
tool_bundles, schema_type = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, warning=warnings)
|
||||
except Exception as e:
|
||||
raise ValueError(f'invalid schema: {str(e)}')
|
||||
|
||||
raise ValueError(f"invalid schema: {str(e)}")
|
||||
|
||||
credentials_schema = [
|
||||
ToolProviderCredentials(
|
||||
name='auth_type',
|
||||
name="auth_type",
|
||||
type=ToolProviderCredentials.CredentialsType.SELECT,
|
||||
required=True,
|
||||
default='none',
|
||||
default="none",
|
||||
options=[
|
||||
ToolCredentialsOption(value='none', label=I18nObject(
|
||||
en_US='None',
|
||||
zh_Hans='无'
|
||||
)),
|
||||
ToolCredentialsOption(value='api_key', label=I18nObject(
|
||||
en_US='Api Key',
|
||||
zh_Hans='Api Key'
|
||||
)),
|
||||
ToolCredentialsOption(value="none", label=I18nObject(en_US="None", zh_Hans="无")),
|
||||
ToolCredentialsOption(value="api_key", label=I18nObject(en_US="Api Key", zh_Hans="Api Key")),
|
||||
],
|
||||
placeholder=I18nObject(
|
||||
en_US='Select auth type',
|
||||
zh_Hans='选择认证方式'
|
||||
)
|
||||
placeholder=I18nObject(en_US="Select auth type", zh_Hans="选择认证方式"),
|
||||
),
|
||||
ToolProviderCredentials(
|
||||
name='api_key_header',
|
||||
name="api_key_header",
|
||||
type=ToolProviderCredentials.CredentialsType.TEXT_INPUT,
|
||||
required=False,
|
||||
placeholder=I18nObject(
|
||||
en_US='Enter api key header',
|
||||
zh_Hans='输入 api key header,如:X-API-KEY'
|
||||
),
|
||||
default='api_key',
|
||||
help=I18nObject(
|
||||
en_US='HTTP header name for api key',
|
||||
zh_Hans='HTTP 头部字段名,用于传递 api key'
|
||||
)
|
||||
placeholder=I18nObject(en_US="Enter api key header", zh_Hans="输入 api key header,如:X-API-KEY"),
|
||||
default="api_key",
|
||||
help=I18nObject(en_US="HTTP header name for api key", zh_Hans="HTTP 头部字段名,用于传递 api key"),
|
||||
),
|
||||
ToolProviderCredentials(
|
||||
name='api_key_value',
|
||||
name="api_key_value",
|
||||
type=ToolProviderCredentials.CredentialsType.TEXT_INPUT,
|
||||
required=False,
|
||||
placeholder=I18nObject(
|
||||
en_US='Enter api key',
|
||||
zh_Hans='输入 api key'
|
||||
),
|
||||
default=''
|
||||
placeholder=I18nObject(en_US="Enter api key", zh_Hans="输入 api key"),
|
||||
default="",
|
||||
),
|
||||
]
|
||||
|
||||
return jsonable_encoder({
|
||||
'schema_type': schema_type,
|
||||
'parameters_schema': tool_bundles,
|
||||
'credentials_schema': credentials_schema,
|
||||
'warning': warnings
|
||||
})
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"schema_type": schema_type,
|
||||
"parameters_schema": tool_bundles,
|
||||
"credentials_schema": credentials_schema,
|
||||
"warning": warnings,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError(f'invalid schema: {str(e)}')
|
||||
raise ValueError(f"invalid schema: {str(e)}")
|
||||
|
||||
@staticmethod
|
||||
def convert_schema_to_tool_bundles(schema: str, extra_info: dict = None) -> list[ApiToolBundle]:
|
||||
"""
|
||||
convert schema to tool bundles
|
||||
convert schema to tool bundles
|
||||
|
||||
:return: the list of tool bundles, description
|
||||
:return: the list of tool bundles, description
|
||||
"""
|
||||
try:
|
||||
tool_bundles = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, extra_info=extra_info)
|
||||
return tool_bundles
|
||||
except Exception as e:
|
||||
raise ValueError(f'invalid schema: {str(e)}')
|
||||
raise ValueError(f"invalid schema: {str(e)}")
|
||||
|
||||
@staticmethod
|
||||
def create_api_tool_provider(
|
||||
user_id: str, tenant_id: str, provider_name: str, icon: dict, credentials: dict,
|
||||
schema_type: str, schema: str, privacy_policy: str, custom_disclaimer: str, labels: list[str]
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
provider_name: str,
|
||||
icon: dict,
|
||||
credentials: dict,
|
||||
schema_type: str,
|
||||
schema: str,
|
||||
privacy_policy: str,
|
||||
custom_disclaimer: str,
|
||||
labels: list[str],
|
||||
):
|
||||
"""
|
||||
create api tool provider
|
||||
create api tool provider
|
||||
"""
|
||||
if schema_type not in [member.value for member in ApiProviderSchemaType]:
|
||||
raise ValueError(f'invalid schema type {schema}')
|
||||
|
||||
raise ValueError(f"invalid schema type {schema}")
|
||||
|
||||
# check if the provider exists
|
||||
provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
ApiToolProvider.name == provider_name,
|
||||
).first()
|
||||
provider: ApiToolProvider = (
|
||||
db.session.query(ApiToolProvider)
|
||||
.filter(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
ApiToolProvider.name == provider_name,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if provider is not None:
|
||||
raise ValueError(f'provider {provider_name} already exists')
|
||||
raise ValueError(f"provider {provider_name} already exists")
|
||||
|
||||
# parse openapi to tool bundle
|
||||
extra_info = {}
|
||||
# extra info like description will be set here
|
||||
tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info)
|
||||
|
||||
|
||||
if len(tool_bundles) > 100:
|
||||
raise ValueError('the number of apis should be less than 100')
|
||||
raise ValueError("the number of apis should be less than 100")
|
||||
|
||||
# create db provider
|
||||
db_provider = ApiToolProvider(
|
||||
@ -142,19 +138,19 @@ class ApiToolManageService:
|
||||
name=provider_name,
|
||||
icon=json.dumps(icon),
|
||||
schema=schema,
|
||||
description=extra_info.get('description', ''),
|
||||
description=extra_info.get("description", ""),
|
||||
schema_type_str=schema_type,
|
||||
tools_str=json.dumps(jsonable_encoder(tool_bundles)),
|
||||
credentials_str={},
|
||||
privacy_policy=privacy_policy,
|
||||
custom_disclaimer=custom_disclaimer
|
||||
custom_disclaimer=custom_disclaimer,
|
||||
)
|
||||
|
||||
if 'auth_type' not in credentials:
|
||||
raise ValueError('auth_type is required')
|
||||
if "auth_type" not in credentials:
|
||||
raise ValueError("auth_type is required")
|
||||
|
||||
# get auth type, none or api key
|
||||
auth_type = ApiProviderAuthType.value_of(credentials['auth_type'])
|
||||
auth_type = ApiProviderAuthType.value_of(credentials["auth_type"])
|
||||
|
||||
# create provider entity
|
||||
provider_controller = ApiToolProviderController.from_db(db_provider, auth_type)
|
||||
@ -172,14 +168,12 @@ class ApiToolManageService:
|
||||
# update labels
|
||||
ToolLabelManager.update_tool_labels(provider_controller, labels)
|
||||
|
||||
return { 'result': 'success' }
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@staticmethod
|
||||
def get_api_tool_provider_remote_schema(
|
||||
user_id: str, tenant_id: str, url: str
|
||||
):
|
||||
def get_api_tool_provider_remote_schema(user_id: str, tenant_id: str, url: str):
|
||||
"""
|
||||
get api tool provider remote schema
|
||||
get api tool provider remote schema
|
||||
"""
|
||||
headers = {
|
||||
"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36 Edg/120.0.0.0",
|
||||
@ -189,84 +183,98 @@ class ApiToolManageService:
|
||||
try:
|
||||
response = get(url, headers=headers, timeout=10)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f'Got status code {response.status_code}')
|
||||
raise ValueError(f"Got status code {response.status_code}")
|
||||
schema = response.text
|
||||
|
||||
# try to parse schema, avoid SSRF attack
|
||||
ApiToolManageService.parser_api_schema(schema)
|
||||
except Exception as e:
|
||||
logger.error(f"parse api schema error: {str(e)}")
|
||||
raise ValueError('invalid schema, please check the url you provided')
|
||||
|
||||
return {
|
||||
'schema': schema
|
||||
}
|
||||
raise ValueError("invalid schema, please check the url you provided")
|
||||
|
||||
return {"schema": schema}
|
||||
|
||||
@staticmethod
|
||||
def list_api_tool_provider_tools(
|
||||
user_id: str, tenant_id: str, provider: str
|
||||
) -> list[UserTool]:
|
||||
def list_api_tool_provider_tools(user_id: str, tenant_id: str, provider: str) -> list[UserTool]:
|
||||
"""
|
||||
list api tool provider tools
|
||||
list api tool provider tools
|
||||
"""
|
||||
provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
ApiToolProvider.name == provider,
|
||||
).first()
|
||||
provider: ApiToolProvider = (
|
||||
db.session.query(ApiToolProvider)
|
||||
.filter(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
ApiToolProvider.name == provider,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if provider is None:
|
||||
raise ValueError(f'you have not added provider {provider}')
|
||||
|
||||
raise ValueError(f"you have not added provider {provider}")
|
||||
|
||||
controller = ToolTransformService.api_provider_to_controller(db_provider=provider)
|
||||
labels = ToolLabelManager.get_tool_labels(controller)
|
||||
|
||||
|
||||
return [
|
||||
ToolTransformService.tool_to_user_tool(
|
||||
tool_bundle,
|
||||
labels=labels,
|
||||
) for tool_bundle in provider.tools
|
||||
)
|
||||
for tool_bundle in provider.tools
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def update_api_tool_provider(
|
||||
user_id: str, tenant_id: str, provider_name: str, original_provider: str, icon: dict, credentials: dict,
|
||||
schema_type: str, schema: str, privacy_policy: str, custom_disclaimer: str, labels: list[str]
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
provider_name: str,
|
||||
original_provider: str,
|
||||
icon: dict,
|
||||
credentials: dict,
|
||||
schema_type: str,
|
||||
schema: str,
|
||||
privacy_policy: str,
|
||||
custom_disclaimer: str,
|
||||
labels: list[str],
|
||||
):
|
||||
"""
|
||||
update api tool provider
|
||||
update api tool provider
|
||||
"""
|
||||
if schema_type not in [member.value for member in ApiProviderSchemaType]:
|
||||
raise ValueError(f'invalid schema type {schema}')
|
||||
|
||||
raise ValueError(f"invalid schema type {schema}")
|
||||
|
||||
# check if the provider exists
|
||||
provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
ApiToolProvider.name == original_provider,
|
||||
).first()
|
||||
provider: ApiToolProvider = (
|
||||
db.session.query(ApiToolProvider)
|
||||
.filter(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
ApiToolProvider.name == original_provider,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if provider is None:
|
||||
raise ValueError(f'api provider {provider_name} does not exists')
|
||||
raise ValueError(f"api provider {provider_name} does not exists")
|
||||
|
||||
# parse openapi to tool bundle
|
||||
extra_info = {}
|
||||
# extra info like description will be set here
|
||||
tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info)
|
||||
|
||||
|
||||
# update db provider
|
||||
provider.name = provider_name
|
||||
provider.icon = json.dumps(icon)
|
||||
provider.schema = schema
|
||||
provider.description = extra_info.get('description', '')
|
||||
provider.description = extra_info.get("description", "")
|
||||
provider.schema_type_str = ApiProviderSchemaType.OPENAPI.value
|
||||
provider.tools_str = json.dumps(jsonable_encoder(tool_bundles))
|
||||
provider.privacy_policy = privacy_policy
|
||||
provider.custom_disclaimer = custom_disclaimer
|
||||
|
||||
if 'auth_type' not in credentials:
|
||||
raise ValueError('auth_type is required')
|
||||
if "auth_type" not in credentials:
|
||||
raise ValueError("auth_type is required")
|
||||
|
||||
# get auth type, none or api key
|
||||
auth_type = ApiProviderAuthType.value_of(credentials['auth_type'])
|
||||
auth_type = ApiProviderAuthType.value_of(credentials["auth_type"])
|
||||
|
||||
# create provider entity
|
||||
provider_controller = ApiToolProviderController.from_db(provider, auth_type)
|
||||
@ -295,84 +303,91 @@ class ApiToolManageService:
|
||||
# update labels
|
||||
ToolLabelManager.update_tool_labels(provider_controller, labels)
|
||||
|
||||
return { 'result': 'success' }
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@staticmethod
|
||||
def delete_api_tool_provider(
|
||||
user_id: str, tenant_id: str, provider_name: str
|
||||
):
|
||||
def delete_api_tool_provider(user_id: str, tenant_id: str, provider_name: str):
|
||||
"""
|
||||
delete tool provider
|
||||
delete tool provider
|
||||
"""
|
||||
provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
ApiToolProvider.name == provider_name,
|
||||
).first()
|
||||
provider: ApiToolProvider = (
|
||||
db.session.query(ApiToolProvider)
|
||||
.filter(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
ApiToolProvider.name == provider_name,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if provider is None:
|
||||
raise ValueError(f'you have not added provider {provider_name}')
|
||||
|
||||
raise ValueError(f"you have not added provider {provider_name}")
|
||||
|
||||
db.session.delete(provider)
|
||||
db.session.commit()
|
||||
|
||||
return { 'result': 'success' }
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@staticmethod
|
||||
def get_api_tool_provider(
|
||||
user_id: str, tenant_id: str, provider: str
|
||||
):
|
||||
def get_api_tool_provider(user_id: str, tenant_id: str, provider: str):
|
||||
"""
|
||||
get api tool provider
|
||||
get api tool provider
|
||||
"""
|
||||
return ToolManager.user_get_api_provider(provider=provider, tenant_id=tenant_id)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def test_api_tool_preview(
|
||||
tenant_id: str,
|
||||
tenant_id: str,
|
||||
provider_name: str,
|
||||
tool_name: str,
|
||||
credentials: dict,
|
||||
parameters: dict,
|
||||
schema_type: str,
|
||||
schema: str
|
||||
tool_name: str,
|
||||
credentials: dict,
|
||||
parameters: dict,
|
||||
schema_type: str,
|
||||
schema: str,
|
||||
):
|
||||
"""
|
||||
test api tool before adding api tool provider
|
||||
test api tool before adding api tool provider
|
||||
"""
|
||||
if schema_type not in [member.value for member in ApiProviderSchemaType]:
|
||||
raise ValueError(f'invalid schema type {schema_type}')
|
||||
|
||||
raise ValueError(f"invalid schema type {schema_type}")
|
||||
|
||||
try:
|
||||
tool_bundles, _ = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema)
|
||||
except Exception as e:
|
||||
raise ValueError('invalid schema')
|
||||
|
||||
raise ValueError("invalid schema")
|
||||
|
||||
# get tool bundle
|
||||
tool_bundle = next(filter(lambda tb: tb.operation_id == tool_name, tool_bundles), None)
|
||||
if tool_bundle is None:
|
||||
raise ValueError(f'invalid tool name {tool_name}')
|
||||
|
||||
db_provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
ApiToolProvider.name == provider_name,
|
||||
).first()
|
||||
raise ValueError(f"invalid tool name {tool_name}")
|
||||
|
||||
db_provider: ApiToolProvider = (
|
||||
db.session.query(ApiToolProvider)
|
||||
.filter(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
ApiToolProvider.name == provider_name,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not db_provider:
|
||||
# create a fake db provider
|
||||
db_provider = ApiToolProvider(
|
||||
tenant_id='', user_id='', name='', icon='',
|
||||
tenant_id="",
|
||||
user_id="",
|
||||
name="",
|
||||
icon="",
|
||||
schema=schema,
|
||||
description='',
|
||||
description="",
|
||||
schema_type_str=ApiProviderSchemaType.OPENAPI.value,
|
||||
tools_str=json.dumps(jsonable_encoder(tool_bundles)),
|
||||
credentials_str=json.dumps(credentials),
|
||||
)
|
||||
|
||||
if 'auth_type' not in credentials:
|
||||
raise ValueError('auth_type is required')
|
||||
if "auth_type" not in credentials:
|
||||
raise ValueError("auth_type is required")
|
||||
|
||||
# get auth type, none or api key
|
||||
auth_type = ApiProviderAuthType.value_of(credentials['auth_type'])
|
||||
auth_type = ApiProviderAuthType.value_of(credentials["auth_type"])
|
||||
|
||||
# create provider entity
|
||||
provider_controller = ApiToolProviderController.from_db(db_provider, auth_type)
|
||||
@ -381,10 +396,7 @@ class ApiToolManageService:
|
||||
|
||||
# decrypt credentials
|
||||
if db_provider.id:
|
||||
tool_configuration = ToolConfigurationManager(
|
||||
tenant_id=tenant_id,
|
||||
provider_controller=provider_controller
|
||||
)
|
||||
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
|
||||
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
|
||||
# check if the credential has changed, save the original credential
|
||||
masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
|
||||
@ -396,27 +408,27 @@ class ApiToolManageService:
|
||||
provider_controller.validate_credentials_format(credentials)
|
||||
# get tool
|
||||
tool = provider_controller.get_tool(tool_name)
|
||||
tool = tool.fork_tool_runtime(runtime={
|
||||
'credentials': credentials,
|
||||
'tenant_id': tenant_id,
|
||||
})
|
||||
tool = tool.fork_tool_runtime(
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
"tenant_id": tenant_id,
|
||||
}
|
||||
)
|
||||
result = tool.validate_credentials(credentials, parameters)
|
||||
except Exception as e:
|
||||
return { 'error': str(e) }
|
||||
|
||||
return { 'result': result or 'empty response' }
|
||||
|
||||
return {"error": str(e)}
|
||||
|
||||
return {"result": result or "empty response"}
|
||||
|
||||
@staticmethod
|
||||
def list_api_tools(
|
||||
user_id: str, tenant_id: str
|
||||
) -> list[UserToolProvider]:
|
||||
def list_api_tools(user_id: str, tenant_id: str) -> list[UserToolProvider]:
|
||||
"""
|
||||
list api tools
|
||||
list api tools
|
||||
"""
|
||||
# get all api providers
|
||||
db_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider).filter(
|
||||
ApiToolProvider.tenant_id == tenant_id
|
||||
).all() or []
|
||||
db_providers: list[ApiToolProvider] = (
|
||||
db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all() or []
|
||||
)
|
||||
|
||||
result: list[UserToolProvider] = []
|
||||
|
||||
@ -425,26 +437,21 @@ class ApiToolManageService:
|
||||
provider_controller = ToolTransformService.api_provider_to_controller(db_provider=provider)
|
||||
labels = ToolLabelManager.get_tool_labels(provider_controller)
|
||||
user_provider = ToolTransformService.api_provider_to_user_provider(
|
||||
provider_controller,
|
||||
db_provider=provider,
|
||||
decrypt_credentials=True
|
||||
provider_controller, db_provider=provider, decrypt_credentials=True
|
||||
)
|
||||
user_provider.labels = labels
|
||||
|
||||
# add icon
|
||||
ToolTransformService.repack_provider(user_provider)
|
||||
|
||||
tools = provider_controller.get_tools(
|
||||
user_id=user_id, tenant_id=tenant_id
|
||||
)
|
||||
tools = provider_controller.get_tools(user_id=user_id, tenant_id=tenant_id)
|
||||
|
||||
for tool in tools:
|
||||
user_provider.tools.append(ToolTransformService.tool_to_user_tool(
|
||||
tenant_id=tenant_id,
|
||||
tool=tool,
|
||||
credentials=user_provider.original_credentials,
|
||||
labels=labels
|
||||
))
|
||||
user_provider.tools.append(
|
||||
ToolTransformService.tool_to_user_tool(
|
||||
tenant_id=tenant_id, tool=tool, credentials=user_provider.original_credentials, labels=labels
|
||||
)
|
||||
)
|
||||
|
||||
result.append(user_provider)
|
||||
|
||||
|
@ -20,21 +20,25 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class BuiltinToolManageService:
|
||||
@staticmethod
|
||||
def list_builtin_tool_provider_tools(
|
||||
user_id: str, tenant_id: str, provider: str
|
||||
) -> list[UserTool]:
|
||||
def list_builtin_tool_provider_tools(user_id: str, tenant_id: str, provider: str) -> list[UserTool]:
|
||||
"""
|
||||
list builtin tool provider tools
|
||||
list builtin tool provider tools
|
||||
"""
|
||||
provider_controller: ToolProviderController = ToolManager.get_builtin_provider(provider)
|
||||
tools = provider_controller.get_tools()
|
||||
|
||||
tool_provider_configurations = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
|
||||
tool_provider_configurations = ToolConfigurationManager(
|
||||
tenant_id=tenant_id, provider_controller=provider_controller
|
||||
)
|
||||
# check if user has added the provider
|
||||
builtin_provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.provider == provider,
|
||||
).first()
|
||||
builtin_provider: BuiltinToolProvider = (
|
||||
db.session.query(BuiltinToolProvider)
|
||||
.filter(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.provider == provider,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
credentials = {}
|
||||
if builtin_provider is not None:
|
||||
@ -44,47 +48,47 @@ class BuiltinToolManageService:
|
||||
|
||||
result = []
|
||||
for tool in tools:
|
||||
result.append(ToolTransformService.tool_to_user_tool(
|
||||
tool=tool,
|
||||
credentials=credentials,
|
||||
tenant_id=tenant_id,
|
||||
labels=ToolLabelManager.get_tool_labels(provider_controller)
|
||||
))
|
||||
result.append(
|
||||
ToolTransformService.tool_to_user_tool(
|
||||
tool=tool,
|
||||
credentials=credentials,
|
||||
tenant_id=tenant_id,
|
||||
labels=ToolLabelManager.get_tool_labels(provider_controller),
|
||||
)
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def list_builtin_provider_credentials_schema(
|
||||
provider_name
|
||||
):
|
||||
def list_builtin_provider_credentials_schema(provider_name):
|
||||
"""
|
||||
list builtin provider credentials schema
|
||||
list builtin provider credentials schema
|
||||
|
||||
:return: the list of tool providers
|
||||
:return: the list of tool providers
|
||||
"""
|
||||
provider = ToolManager.get_builtin_provider(provider_name)
|
||||
return jsonable_encoder([
|
||||
v for _, v in (provider.credentials_schema or {}).items()
|
||||
])
|
||||
return jsonable_encoder([v for _, v in (provider.credentials_schema or {}).items()])
|
||||
|
||||
@staticmethod
|
||||
def update_builtin_tool_provider(
|
||||
user_id: str, tenant_id: str, provider_name: str, credentials: dict
|
||||
):
|
||||
def update_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str, credentials: dict):
|
||||
"""
|
||||
update builtin tool provider
|
||||
update builtin tool provider
|
||||
"""
|
||||
# get if the provider exists
|
||||
provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.provider == provider_name,
|
||||
).first()
|
||||
provider: BuiltinToolProvider = (
|
||||
db.session.query(BuiltinToolProvider)
|
||||
.filter(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.provider == provider_name,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
try:
|
||||
# get provider
|
||||
provider_controller = ToolManager.get_builtin_provider(provider_name)
|
||||
if not provider_controller.need_credentials:
|
||||
raise ValueError(f'provider {provider_name} does not need credentials')
|
||||
raise ValueError(f"provider {provider_name} does not need credentials")
|
||||
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
|
||||
# get original credentials if exists
|
||||
if provider is not None:
|
||||
@ -121,19 +125,21 @@ class BuiltinToolManageService:
|
||||
# delete cache
|
||||
tool_configuration.delete_tool_credentials_cache()
|
||||
|
||||
return {'result': 'success'}
|
||||
return {"result": "success"}
|
||||
|
||||
@staticmethod
|
||||
def get_builtin_tool_provider_credentials(
|
||||
user_id: str, tenant_id: str, provider: str
|
||||
):
|
||||
def get_builtin_tool_provider_credentials(user_id: str, tenant_id: str, provider: str):
|
||||
"""
|
||||
get builtin tool provider credentials
|
||||
get builtin tool provider credentials
|
||||
"""
|
||||
provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.provider == provider,
|
||||
).first()
|
||||
provider: BuiltinToolProvider = (
|
||||
db.session.query(BuiltinToolProvider)
|
||||
.filter(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.provider == provider,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if provider is None:
|
||||
return {}
|
||||
@ -145,19 +151,21 @@ class BuiltinToolManageService:
|
||||
return credentials
|
||||
|
||||
@staticmethod
|
||||
def delete_builtin_tool_provider(
|
||||
user_id: str, tenant_id: str, provider_name: str
|
||||
):
|
||||
def delete_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str):
|
||||
"""
|
||||
delete tool provider
|
||||
delete tool provider
|
||||
"""
|
||||
provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.provider == provider_name,
|
||||
).first()
|
||||
provider: BuiltinToolProvider = (
|
||||
db.session.query(BuiltinToolProvider)
|
||||
.filter(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.provider == provider_name,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if provider is None:
|
||||
raise ValueError(f'you have not added provider {provider_name}')
|
||||
raise ValueError(f"you have not added provider {provider_name}")
|
||||
|
||||
db.session.delete(provider)
|
||||
db.session.commit()
|
||||
@ -167,38 +175,36 @@ class BuiltinToolManageService:
|
||||
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
|
||||
tool_configuration.delete_tool_credentials_cache()
|
||||
|
||||
return {'result': 'success'}
|
||||
return {"result": "success"}
|
||||
|
||||
@staticmethod
|
||||
def get_builtin_tool_provider_icon(
|
||||
provider: str
|
||||
):
|
||||
def get_builtin_tool_provider_icon(provider: str):
|
||||
"""
|
||||
get tool provider icon and it's mimetype
|
||||
get tool provider icon and it's mimetype
|
||||
"""
|
||||
icon_path, mime_type = ToolManager.get_builtin_provider_icon(provider)
|
||||
with open(icon_path, 'rb') as f:
|
||||
with open(icon_path, "rb") as f:
|
||||
icon_bytes = f.read()
|
||||
|
||||
return icon_bytes, mime_type
|
||||
|
||||
@staticmethod
|
||||
def list_builtin_tools(
|
||||
user_id: str, tenant_id: str
|
||||
) -> list[UserToolProvider]:
|
||||
def list_builtin_tools(user_id: str, tenant_id: str) -> list[UserToolProvider]:
|
||||
"""
|
||||
list builtin tools
|
||||
list builtin tools
|
||||
"""
|
||||
# get all builtin providers
|
||||
provider_controllers = ToolManager.list_builtin_providers()
|
||||
|
||||
# get all user added providers
|
||||
db_providers: list[BuiltinToolProvider] = db.session.query(BuiltinToolProvider).filter(
|
||||
BuiltinToolProvider.tenant_id == tenant_id
|
||||
).all() or []
|
||||
db_providers: list[BuiltinToolProvider] = (
|
||||
db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all() or []
|
||||
)
|
||||
|
||||
# find provider
|
||||
find_provider = lambda provider: next(filter(lambda db_provider: db_provider.provider == provider, db_providers), None)
|
||||
find_provider = lambda provider: next(
|
||||
filter(lambda db_provider: db_provider.provider == provider, db_providers), None
|
||||
)
|
||||
|
||||
result: list[UserToolProvider] = []
|
||||
|
||||
@ -209,7 +215,7 @@ class BuiltinToolManageService:
|
||||
include_set=dify_config.POSITION_TOOL_INCLUDES_SET,
|
||||
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET,
|
||||
data=provider_controller,
|
||||
name_func=lambda x: x.identity.name
|
||||
name_func=lambda x: x.identity.name,
|
||||
):
|
||||
continue
|
||||
|
||||
@ -217,7 +223,7 @@ class BuiltinToolManageService:
|
||||
user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider(
|
||||
provider_controller=provider_controller,
|
||||
db_provider=find_provider(provider_controller.identity.name),
|
||||
decrypt_credentials=True
|
||||
decrypt_credentials=True,
|
||||
)
|
||||
|
||||
# add icon
|
||||
@ -225,12 +231,14 @@ class BuiltinToolManageService:
|
||||
|
||||
tools = provider_controller.get_tools()
|
||||
for tool in tools:
|
||||
user_builtin_provider.tools.append(ToolTransformService.tool_to_user_tool(
|
||||
tenant_id=tenant_id,
|
||||
tool=tool,
|
||||
credentials=user_builtin_provider.original_credentials,
|
||||
labels=ToolLabelManager.get_tool_labels(provider_controller)
|
||||
))
|
||||
user_builtin_provider.tools.append(
|
||||
ToolTransformService.tool_to_user_tool(
|
||||
tenant_id=tenant_id,
|
||||
tool=tool,
|
||||
credentials=user_builtin_provider.original_credentials,
|
||||
labels=ToolLabelManager.get_tool_labels(provider_controller),
|
||||
)
|
||||
)
|
||||
|
||||
result.append(user_builtin_provider)
|
||||
except Exception as e:
|
||||
|
@ -5,4 +5,4 @@ from core.tools.entities.values import default_tool_labels
|
||||
class ToolLabelsService:
|
||||
@classmethod
|
||||
def list_tool_labels(cls) -> list[ToolLabel]:
|
||||
return default_tool_labels
|
||||
return default_tool_labels
|
||||
|
@ -11,13 +11,11 @@ class ToolCommonService:
|
||||
@staticmethod
|
||||
def list_tool_providers(user_id: str, tenant_id: str, typ: UserToolProviderTypeLiteral = None):
|
||||
"""
|
||||
list tool providers
|
||||
list tool providers
|
||||
|
||||
:return: the list of tool providers
|
||||
:return: the list of tool providers
|
||||
"""
|
||||
providers = ToolManager.user_list_providers(
|
||||
user_id, tenant_id, typ
|
||||
)
|
||||
providers = ToolManager.user_list_providers(user_id, tenant_id, typ)
|
||||
|
||||
# add icon
|
||||
for provider in providers:
|
||||
@ -26,4 +24,3 @@ class ToolCommonService:
|
||||
result = [provider.to_dict() for provider in providers]
|
||||
|
||||
return result
|
||||
|
@ -22,46 +22,39 @@ from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvi
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ToolTransformService:
|
||||
@staticmethod
|
||||
def get_tool_provider_icon_url(provider_type: str, provider_name: str, icon: str) -> Union[str, dict]:
|
||||
"""
|
||||
get tool provider icon url
|
||||
get tool provider icon url
|
||||
"""
|
||||
url_prefix = (dify_config.CONSOLE_API_URL
|
||||
+ "/console/api/workspaces/current/tool-provider/")
|
||||
|
||||
url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/"
|
||||
|
||||
if provider_type == ToolProviderType.BUILT_IN.value:
|
||||
return url_prefix + 'builtin/' + provider_name + '/icon'
|
||||
return url_prefix + "builtin/" + provider_name + "/icon"
|
||||
elif provider_type in [ToolProviderType.API.value, ToolProviderType.WORKFLOW.value]:
|
||||
try:
|
||||
return json.loads(icon)
|
||||
except:
|
||||
return {
|
||||
"background": "#252525",
|
||||
"content": "\ud83d\ude01"
|
||||
}
|
||||
|
||||
return ''
|
||||
|
||||
return {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def repack_provider(provider: Union[dict, UserToolProvider]):
|
||||
"""
|
||||
repack provider
|
||||
repack provider
|
||||
|
||||
:param provider: the provider dict
|
||||
:param provider: the provider dict
|
||||
"""
|
||||
if isinstance(provider, dict) and 'icon' in provider:
|
||||
provider['icon'] = ToolTransformService.get_tool_provider_icon_url(
|
||||
provider_type=provider['type'],
|
||||
provider_name=provider['name'],
|
||||
icon=provider['icon']
|
||||
if isinstance(provider, dict) and "icon" in provider:
|
||||
provider["icon"] = ToolTransformService.get_tool_provider_icon_url(
|
||||
provider_type=provider["type"], provider_name=provider["name"], icon=provider["icon"]
|
||||
)
|
||||
elif isinstance(provider, UserToolProvider):
|
||||
provider.icon = ToolTransformService.get_tool_provider_icon_url(
|
||||
provider_type=provider.type.value,
|
||||
provider_name=provider.name,
|
||||
icon=provider.icon
|
||||
provider_type=provider.type.value, provider_name=provider.name, icon=provider.icon
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@ -92,14 +85,13 @@ class ToolTransformService:
|
||||
masked_credentials={},
|
||||
is_team_authorization=False,
|
||||
tools=[],
|
||||
labels=provider_controller.tool_labels
|
||||
labels=provider_controller.tool_labels,
|
||||
)
|
||||
|
||||
# get credentials schema
|
||||
schema = provider_controller.get_credentials_schema()
|
||||
for name, value in schema.items():
|
||||
result.masked_credentials[name] = \
|
||||
ToolProviderCredentials.CredentialsType.default(value.type)
|
||||
result.masked_credentials[name] = ToolProviderCredentials.CredentialsType.default(value.type)
|
||||
|
||||
# check if the provider need credentials
|
||||
if not provider_controller.need_credentials:
|
||||
@ -113,8 +105,7 @@ class ToolTransformService:
|
||||
|
||||
# init tool configuration
|
||||
tool_configuration = ToolConfigurationManager(
|
||||
tenant_id=db_provider.tenant_id,
|
||||
provider_controller=provider_controller
|
||||
tenant_id=db_provider.tenant_id, provider_controller=provider_controller
|
||||
)
|
||||
# decrypt the credentials and mask the credentials
|
||||
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials=credentials)
|
||||
@ -124,7 +115,7 @@ class ToolTransformService:
|
||||
result.original_credentials = decrypted_credentials
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@staticmethod
|
||||
def api_provider_to_controller(
|
||||
db_provider: ApiToolProvider,
|
||||
@ -135,25 +126,23 @@ class ToolTransformService:
|
||||
# package tool provider controller
|
||||
controller = ApiToolProviderController.from_db(
|
||||
db_provider=db_provider,
|
||||
auth_type=ApiProviderAuthType.API_KEY if db_provider.credentials['auth_type'] == 'api_key' else
|
||||
ApiProviderAuthType.NONE
|
||||
auth_type=ApiProviderAuthType.API_KEY
|
||||
if db_provider.credentials["auth_type"] == "api_key"
|
||||
else ApiProviderAuthType.NONE,
|
||||
)
|
||||
|
||||
return controller
|
||||
|
||||
|
||||
@staticmethod
|
||||
def workflow_provider_to_controller(
|
||||
db_provider: WorkflowToolProvider
|
||||
) -> WorkflowToolProviderController:
|
||||
def workflow_provider_to_controller(db_provider: WorkflowToolProvider) -> WorkflowToolProviderController:
|
||||
"""
|
||||
convert provider controller to provider
|
||||
"""
|
||||
return WorkflowToolProviderController.from_db(db_provider)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def workflow_provider_to_user_provider(
|
||||
provider_controller: WorkflowToolProviderController,
|
||||
labels: list[str] = None
|
||||
provider_controller: WorkflowToolProviderController, labels: list[str] = None
|
||||
):
|
||||
"""
|
||||
convert provider controller to user provider
|
||||
@ -175,7 +164,7 @@ class ToolTransformService:
|
||||
masked_credentials={},
|
||||
is_team_authorization=True,
|
||||
tools=[],
|
||||
labels=labels or []
|
||||
labels=labels or [],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@ -183,16 +172,16 @@ class ToolTransformService:
|
||||
provider_controller: ApiToolProviderController,
|
||||
db_provider: ApiToolProvider,
|
||||
decrypt_credentials: bool = True,
|
||||
labels: list[str] = None
|
||||
labels: list[str] = None,
|
||||
) -> UserToolProvider:
|
||||
"""
|
||||
convert provider controller to user provider
|
||||
"""
|
||||
username = 'Anonymous'
|
||||
username = "Anonymous"
|
||||
try:
|
||||
username = db_provider.user.name
|
||||
except Exception as e:
|
||||
logger.error(f'failed to get user name for api provider {db_provider.id}: {str(e)}')
|
||||
logger.error(f"failed to get user name for api provider {db_provider.id}: {str(e)}")
|
||||
# add provider into providers
|
||||
credentials = db_provider.credentials
|
||||
result = UserToolProvider(
|
||||
@ -212,14 +201,13 @@ class ToolTransformService:
|
||||
masked_credentials={},
|
||||
is_team_authorization=True,
|
||||
tools=[],
|
||||
labels=labels or []
|
||||
labels=labels or [],
|
||||
)
|
||||
|
||||
if decrypt_credentials:
|
||||
# init tool configuration
|
||||
tool_configuration = ToolConfigurationManager(
|
||||
tenant_id=db_provider.tenant_id,
|
||||
provider_controller=provider_controller
|
||||
tenant_id=db_provider.tenant_id, provider_controller=provider_controller
|
||||
)
|
||||
|
||||
# decrypt the credentials and mask the credentials
|
||||
@ -229,23 +217,25 @@ class ToolTransformService:
|
||||
result.masked_credentials = masked_credentials
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@staticmethod
|
||||
def tool_to_user_tool(
|
||||
tool: Union[ApiToolBundle, WorkflowTool, Tool],
|
||||
credentials: dict = None,
|
||||
tool: Union[ApiToolBundle, WorkflowTool, Tool],
|
||||
credentials: dict = None,
|
||||
tenant_id: str = None,
|
||||
labels: list[str] = None
|
||||
labels: list[str] = None,
|
||||
) -> UserTool:
|
||||
"""
|
||||
convert tool to user tool
|
||||
"""
|
||||
if isinstance(tool, Tool):
|
||||
# fork tool runtime
|
||||
tool = tool.fork_tool_runtime(runtime={
|
||||
'credentials': credentials,
|
||||
'tenant_id': tenant_id,
|
||||
})
|
||||
tool = tool.fork_tool_runtime(
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
"tenant_id": tenant_id,
|
||||
}
|
||||
)
|
||||
|
||||
# get tool parameters
|
||||
parameters = tool.parameters or []
|
||||
@ -270,20 +260,14 @@ class ToolTransformService:
|
||||
label=tool.identity.label,
|
||||
description=tool.description.human,
|
||||
parameters=current_parameters,
|
||||
labels=labels
|
||||
labels=labels,
|
||||
)
|
||||
if isinstance(tool, ApiToolBundle):
|
||||
return UserTool(
|
||||
author=tool.author,
|
||||
name=tool.operation_id,
|
||||
label=I18nObject(
|
||||
en_US=tool.operation_id,
|
||||
zh_Hans=tool.operation_id
|
||||
),
|
||||
description=I18nObject(
|
||||
en_US=tool.summary or '',
|
||||
zh_Hans=tool.summary or ''
|
||||
),
|
||||
label=I18nObject(en_US=tool.operation_id, zh_Hans=tool.operation_id),
|
||||
description=I18nObject(en_US=tool.summary or "", zh_Hans=tool.summary or ""),
|
||||
parameters=tool.parameters,
|
||||
labels=labels
|
||||
)
|
||||
labels=labels,
|
||||
)
|
||||
|
@ -19,10 +19,21 @@ class WorkflowToolManageService:
|
||||
"""
|
||||
Service class for managing workflow tools.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def create_workflow_tool(cls, user_id: str, tenant_id: str, workflow_app_id: str, name: str,
|
||||
label: str, icon: dict, description: str,
|
||||
parameters: list[dict], privacy_policy: str = '', labels: list[str] = None) -> dict:
|
||||
def create_workflow_tool(
|
||||
cls,
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
workflow_app_id: str,
|
||||
name: str,
|
||||
label: str,
|
||||
icon: dict,
|
||||
description: str,
|
||||
parameters: list[dict],
|
||||
privacy_policy: str = "",
|
||||
labels: list[str] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Create a workflow tool.
|
||||
:param user_id: the user id
|
||||
@ -38,27 +49,28 @@ class WorkflowToolManageService:
|
||||
WorkflowToolConfigurationUtils.check_parameter_configurations(parameters)
|
||||
|
||||
# check if the name is unique
|
||||
existing_workflow_tool_provider = db.session.query(WorkflowToolProvider).filter(
|
||||
WorkflowToolProvider.tenant_id == tenant_id,
|
||||
# name or app_id
|
||||
or_(WorkflowToolProvider.name == name, WorkflowToolProvider.app_id == workflow_app_id)
|
||||
).first()
|
||||
existing_workflow_tool_provider = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
.filter(
|
||||
WorkflowToolProvider.tenant_id == tenant_id,
|
||||
# name or app_id
|
||||
or_(WorkflowToolProvider.name == name, WorkflowToolProvider.app_id == workflow_app_id),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if existing_workflow_tool_provider is not None:
|
||||
raise ValueError(f'Tool with name {name} or app_id {workflow_app_id} already exists')
|
||||
|
||||
app: App = db.session.query(App).filter(
|
||||
App.id == workflow_app_id,
|
||||
App.tenant_id == tenant_id
|
||||
).first()
|
||||
raise ValueError(f"Tool with name {name} or app_id {workflow_app_id} already exists")
|
||||
|
||||
app: App = db.session.query(App).filter(App.id == workflow_app_id, App.tenant_id == tenant_id).first()
|
||||
|
||||
if app is None:
|
||||
raise ValueError(f'App {workflow_app_id} not found')
|
||||
|
||||
raise ValueError(f"App {workflow_app_id} not found")
|
||||
|
||||
workflow: Workflow = app.workflow
|
||||
if workflow is None:
|
||||
raise ValueError(f'Workflow not found for app {workflow_app_id}')
|
||||
|
||||
raise ValueError(f"Workflow not found for app {workflow_app_id}")
|
||||
|
||||
workflow_tool_provider = WorkflowToolProvider(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
@ -76,19 +88,26 @@ class WorkflowToolManageService:
|
||||
WorkflowToolProviderController.from_db(workflow_tool_provider)
|
||||
except Exception as e:
|
||||
raise ValueError(str(e))
|
||||
|
||||
|
||||
db.session.add(workflow_tool_provider)
|
||||
db.session.commit()
|
||||
|
||||
return {
|
||||
'result': 'success'
|
||||
}
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@classmethod
|
||||
def update_workflow_tool(cls, user_id: str, tenant_id: str, workflow_tool_id: str,
|
||||
name: str, label: str, icon: dict, description: str,
|
||||
parameters: list[dict], privacy_policy: str = '', labels: list[str] = None) -> dict:
|
||||
def update_workflow_tool(
|
||||
cls,
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
workflow_tool_id: str,
|
||||
name: str,
|
||||
label: str,
|
||||
icon: dict,
|
||||
description: str,
|
||||
parameters: list[dict],
|
||||
privacy_policy: str = "",
|
||||
labels: list[str] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Update a workflow tool.
|
||||
:param user_id: the user id
|
||||
@ -106,35 +125,39 @@ class WorkflowToolManageService:
|
||||
WorkflowToolConfigurationUtils.check_parameter_configurations(parameters)
|
||||
|
||||
# check if the name is unique
|
||||
existing_workflow_tool_provider = db.session.query(WorkflowToolProvider).filter(
|
||||
WorkflowToolProvider.tenant_id == tenant_id,
|
||||
WorkflowToolProvider.name == name,
|
||||
WorkflowToolProvider.id != workflow_tool_id
|
||||
).first()
|
||||
existing_workflow_tool_provider = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
.filter(
|
||||
WorkflowToolProvider.tenant_id == tenant_id,
|
||||
WorkflowToolProvider.name == name,
|
||||
WorkflowToolProvider.id != workflow_tool_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if existing_workflow_tool_provider is not None:
|
||||
raise ValueError(f'Tool with name {name} already exists')
|
||||
|
||||
workflow_tool_provider: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter(
|
||||
WorkflowToolProvider.tenant_id == tenant_id,
|
||||
WorkflowToolProvider.id == workflow_tool_id
|
||||
).first()
|
||||
raise ValueError(f"Tool with name {name} already exists")
|
||||
|
||||
workflow_tool_provider: WorkflowToolProvider = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if workflow_tool_provider is None:
|
||||
raise ValueError(f'Tool {workflow_tool_id} not found')
|
||||
|
||||
app: App = db.session.query(App).filter(
|
||||
App.id == workflow_tool_provider.app_id,
|
||||
App.tenant_id == tenant_id
|
||||
).first()
|
||||
raise ValueError(f"Tool {workflow_tool_id} not found")
|
||||
|
||||
app: App = (
|
||||
db.session.query(App).filter(App.id == workflow_tool_provider.app_id, App.tenant_id == tenant_id).first()
|
||||
)
|
||||
|
||||
if app is None:
|
||||
raise ValueError(f'App {workflow_tool_provider.app_id} not found')
|
||||
|
||||
raise ValueError(f"App {workflow_tool_provider.app_id} not found")
|
||||
|
||||
workflow: Workflow = app.workflow
|
||||
if workflow is None:
|
||||
raise ValueError(f'Workflow not found for app {workflow_tool_provider.app_id}')
|
||||
|
||||
raise ValueError(f"Workflow not found for app {workflow_tool_provider.app_id}")
|
||||
|
||||
workflow_tool_provider.name = name
|
||||
workflow_tool_provider.label = label
|
||||
workflow_tool_provider.icon = json.dumps(icon)
|
||||
@ -154,13 +177,10 @@ class WorkflowToolManageService:
|
||||
|
||||
if labels is not None:
|
||||
ToolLabelManager.update_tool_labels(
|
||||
ToolTransformService.workflow_provider_to_controller(workflow_tool_provider),
|
||||
labels
|
||||
ToolTransformService.workflow_provider_to_controller(workflow_tool_provider), labels
|
||||
)
|
||||
|
||||
return {
|
||||
'result': 'success'
|
||||
}
|
||||
return {"result": "success"}
|
||||
|
||||
@classmethod
|
||||
def list_tenant_workflow_tools(cls, user_id: str, tenant_id: str) -> list[UserToolProvider]:
|
||||
@ -170,9 +190,7 @@ class WorkflowToolManageService:
|
||||
:param tenant_id: the tenant id
|
||||
:return: the list of tools
|
||||
"""
|
||||
db_tools = db.session.query(WorkflowToolProvider).filter(
|
||||
WorkflowToolProvider.tenant_id == tenant_id
|
||||
).all()
|
||||
db_tools = db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all()
|
||||
|
||||
tools = []
|
||||
for provider in db_tools:
|
||||
@ -188,14 +206,12 @@ class WorkflowToolManageService:
|
||||
|
||||
for tool in tools:
|
||||
user_tool_provider = ToolTransformService.workflow_provider_to_user_provider(
|
||||
provider_controller=tool,
|
||||
labels=labels.get(tool.provider_id, [])
|
||||
provider_controller=tool, labels=labels.get(tool.provider_id, [])
|
||||
)
|
||||
ToolTransformService.repack_provider(user_tool_provider)
|
||||
user_tool_provider.tools = [
|
||||
ToolTransformService.tool_to_user_tool(
|
||||
tool.get_tools(user_id, tenant_id)[0],
|
||||
labels=labels.get(tool.provider_id, [])
|
||||
tool.get_tools(user_id, tenant_id)[0], labels=labels.get(tool.provider_id, [])
|
||||
)
|
||||
]
|
||||
result.append(user_tool_provider)
|
||||
@ -211,15 +227,12 @@ class WorkflowToolManageService:
|
||||
:param workflow_app_id: the workflow app id
|
||||
"""
|
||||
db.session.query(WorkflowToolProvider).filter(
|
||||
WorkflowToolProvider.tenant_id == tenant_id,
|
||||
WorkflowToolProvider.id == workflow_tool_id
|
||||
WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id
|
||||
).delete()
|
||||
|
||||
db.session.commit()
|
||||
|
||||
return {
|
||||
'result': 'success'
|
||||
}
|
||||
return {"result": "success"}
|
||||
|
||||
@classmethod
|
||||
def get_workflow_tool_by_tool_id(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> dict:
|
||||
@ -230,40 +243,37 @@ class WorkflowToolManageService:
|
||||
:param workflow_app_id: the workflow app id
|
||||
:return: the tool
|
||||
"""
|
||||
db_tool: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter(
|
||||
WorkflowToolProvider.tenant_id == tenant_id,
|
||||
WorkflowToolProvider.id == workflow_tool_id
|
||||
).first()
|
||||
db_tool: WorkflowToolProvider = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if db_tool is None:
|
||||
raise ValueError(f'Tool {workflow_tool_id} not found')
|
||||
|
||||
workflow_app: App = db.session.query(App).filter(
|
||||
App.id == db_tool.app_id,
|
||||
App.tenant_id == tenant_id
|
||||
).first()
|
||||
raise ValueError(f"Tool {workflow_tool_id} not found")
|
||||
|
||||
workflow_app: App = db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == tenant_id).first()
|
||||
|
||||
if workflow_app is None:
|
||||
raise ValueError(f'App {db_tool.app_id} not found')
|
||||
raise ValueError(f"App {db_tool.app_id} not found")
|
||||
|
||||
tool = ToolTransformService.workflow_provider_to_controller(db_tool)
|
||||
|
||||
return {
|
||||
'name': db_tool.name,
|
||||
'label': db_tool.label,
|
||||
'workflow_tool_id': db_tool.id,
|
||||
'workflow_app_id': db_tool.app_id,
|
||||
'icon': json.loads(db_tool.icon),
|
||||
'description': db_tool.description,
|
||||
'parameters': jsonable_encoder(db_tool.parameter_configurations),
|
||||
'tool': ToolTransformService.tool_to_user_tool(
|
||||
tool.get_tools(user_id, tenant_id)[0],
|
||||
labels=ToolLabelManager.get_tool_labels(tool)
|
||||
"name": db_tool.name,
|
||||
"label": db_tool.label,
|
||||
"workflow_tool_id": db_tool.id,
|
||||
"workflow_app_id": db_tool.app_id,
|
||||
"icon": json.loads(db_tool.icon),
|
||||
"description": db_tool.description,
|
||||
"parameters": jsonable_encoder(db_tool.parameter_configurations),
|
||||
"tool": ToolTransformService.tool_to_user_tool(
|
||||
tool.get_tools(user_id, tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool)
|
||||
),
|
||||
'synced': workflow_app.workflow.version == db_tool.version,
|
||||
'privacy_policy': db_tool.privacy_policy,
|
||||
"synced": workflow_app.workflow.version == db_tool.version,
|
||||
"privacy_policy": db_tool.privacy_policy,
|
||||
}
|
||||
|
||||
|
||||
@classmethod
|
||||
def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_id: str) -> dict:
|
||||
"""
|
||||
@ -273,40 +283,37 @@ class WorkflowToolManageService:
|
||||
:param workflow_app_id: the workflow app id
|
||||
:return: the tool
|
||||
"""
|
||||
db_tool: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter(
|
||||
WorkflowToolProvider.tenant_id == tenant_id,
|
||||
WorkflowToolProvider.app_id == workflow_app_id
|
||||
).first()
|
||||
db_tool: WorkflowToolProvider = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.app_id == workflow_app_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if db_tool is None:
|
||||
raise ValueError(f'Tool {workflow_app_id} not found')
|
||||
|
||||
workflow_app: App = db.session.query(App).filter(
|
||||
App.id == db_tool.app_id,
|
||||
App.tenant_id == tenant_id
|
||||
).first()
|
||||
raise ValueError(f"Tool {workflow_app_id} not found")
|
||||
|
||||
workflow_app: App = db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == tenant_id).first()
|
||||
|
||||
if workflow_app is None:
|
||||
raise ValueError(f'App {db_tool.app_id} not found')
|
||||
raise ValueError(f"App {db_tool.app_id} not found")
|
||||
|
||||
tool = ToolTransformService.workflow_provider_to_controller(db_tool)
|
||||
|
||||
return {
|
||||
'name': db_tool.name,
|
||||
'label': db_tool.label,
|
||||
'workflow_tool_id': db_tool.id,
|
||||
'workflow_app_id': db_tool.app_id,
|
||||
'icon': json.loads(db_tool.icon),
|
||||
'description': db_tool.description,
|
||||
'parameters': jsonable_encoder(db_tool.parameter_configurations),
|
||||
'tool': ToolTransformService.tool_to_user_tool(
|
||||
tool.get_tools(user_id, tenant_id)[0],
|
||||
labels=ToolLabelManager.get_tool_labels(tool)
|
||||
"name": db_tool.name,
|
||||
"label": db_tool.label,
|
||||
"workflow_tool_id": db_tool.id,
|
||||
"workflow_app_id": db_tool.app_id,
|
||||
"icon": json.loads(db_tool.icon),
|
||||
"description": db_tool.description,
|
||||
"parameters": jsonable_encoder(db_tool.parameter_configurations),
|
||||
"tool": ToolTransformService.tool_to_user_tool(
|
||||
tool.get_tools(user_id, tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool)
|
||||
),
|
||||
'synced': workflow_app.workflow.version == db_tool.version,
|
||||
'privacy_policy': db_tool.privacy_policy
|
||||
"synced": workflow_app.workflow.version == db_tool.version,
|
||||
"privacy_policy": db_tool.privacy_policy,
|
||||
}
|
||||
|
||||
|
||||
@classmethod
|
||||
def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> list[dict]:
|
||||
"""
|
||||
@ -316,19 +323,19 @@ class WorkflowToolManageService:
|
||||
:param workflow_app_id: the workflow app id
|
||||
:return: the list of tools
|
||||
"""
|
||||
db_tool: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter(
|
||||
WorkflowToolProvider.tenant_id == tenant_id,
|
||||
WorkflowToolProvider.id == workflow_tool_id
|
||||
).first()
|
||||
db_tool: WorkflowToolProvider = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
.filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if db_tool is None:
|
||||
raise ValueError(f'Tool {workflow_tool_id} not found')
|
||||
raise ValueError(f"Tool {workflow_tool_id} not found")
|
||||
|
||||
tool = ToolTransformService.workflow_provider_to_controller(db_tool)
|
||||
|
||||
return [
|
||||
ToolTransformService.tool_to_user_tool(
|
||||
tool.get_tools(user_id, tenant_id)[0],
|
||||
labels=ToolLabelManager.get_tool_labels(tool)
|
||||
tool.get_tools(user_id, tenant_id)[0], labels=ToolLabelManager.get_tool_labels(tool)
|
||||
)
|
||||
]
|
||||
]
|
||||
|
@ -7,10 +7,10 @@ from models.dataset import Dataset, DocumentSegment
|
||||
|
||||
|
||||
class VectorService:
|
||||
|
||||
@classmethod
|
||||
def create_segments_vector(cls, keywords_list: Optional[list[list[str]]],
|
||||
segments: list[DocumentSegment], dataset: Dataset):
|
||||
def create_segments_vector(
|
||||
cls, keywords_list: Optional[list[list[str]]], segments: list[DocumentSegment], dataset: Dataset
|
||||
):
|
||||
documents = []
|
||||
for segment in segments:
|
||||
document = Document(
|
||||
@ -20,14 +20,12 @@ class VectorService:
|
||||
"doc_hash": segment.index_node_hash,
|
||||
"document_id": segment.document_id,
|
||||
"dataset_id": segment.dataset_id,
|
||||
}
|
||||
},
|
||||
)
|
||||
documents.append(document)
|
||||
if dataset.indexing_technique == 'high_quality':
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
# save vector index
|
||||
vector = Vector(
|
||||
dataset=dataset
|
||||
)
|
||||
vector = Vector(dataset=dataset)
|
||||
vector.add_texts(documents, duplicate_check=True)
|
||||
|
||||
# save keyword index
|
||||
@ -50,13 +48,11 @@ class VectorService:
|
||||
"doc_hash": segment.index_node_hash,
|
||||
"document_id": segment.document_id,
|
||||
"dataset_id": segment.dataset_id,
|
||||
}
|
||||
},
|
||||
)
|
||||
if dataset.indexing_technique == 'high_quality':
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
# update vector index
|
||||
vector = Vector(
|
||||
dataset=dataset
|
||||
)
|
||||
vector = Vector(dataset=dataset)
|
||||
vector.delete_by_ids([segment.index_node_id])
|
||||
vector.add_texts([document], duplicate_check=True)
|
||||
|
||||
|
@ -11,18 +11,29 @@ from services.conversation_service import ConversationService
|
||||
|
||||
class WebConversationService:
|
||||
@classmethod
|
||||
def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account, EndUser]],
|
||||
last_id: Optional[str], limit: int, invoke_from: InvokeFrom,
|
||||
pinned: Optional[bool] = None,
|
||||
sort_by='-updated_at') -> InfiniteScrollPagination:
|
||||
def pagination_by_last_id(
|
||||
cls,
|
||||
app_model: App,
|
||||
user: Optional[Union[Account, EndUser]],
|
||||
last_id: Optional[str],
|
||||
limit: int,
|
||||
invoke_from: InvokeFrom,
|
||||
pinned: Optional[bool] = None,
|
||||
sort_by="-updated_at",
|
||||
) -> InfiniteScrollPagination:
|
||||
include_ids = None
|
||||
exclude_ids = None
|
||||
if pinned is not None:
|
||||
pinned_conversations = db.session.query(PinnedConversation).filter(
|
||||
PinnedConversation.app_id == app_model.id,
|
||||
PinnedConversation.created_by_role == ('account' if isinstance(user, Account) else 'end_user'),
|
||||
PinnedConversation.created_by == user.id
|
||||
).order_by(PinnedConversation.created_at.desc()).all()
|
||||
pinned_conversations = (
|
||||
db.session.query(PinnedConversation)
|
||||
.filter(
|
||||
PinnedConversation.app_id == app_model.id,
|
||||
PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
|
||||
PinnedConversation.created_by == user.id,
|
||||
)
|
||||
.order_by(PinnedConversation.created_at.desc())
|
||||
.all()
|
||||
)
|
||||
pinned_conversation_ids = [pc.conversation_id for pc in pinned_conversations]
|
||||
if pinned:
|
||||
include_ids = pinned_conversation_ids
|
||||
@ -37,32 +48,34 @@ class WebConversationService:
|
||||
invoke_from=invoke_from,
|
||||
include_ids=include_ids,
|
||||
exclude_ids=exclude_ids,
|
||||
sort_by=sort_by
|
||||
sort_by=sort_by,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def pin(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]):
|
||||
pinned_conversation = db.session.query(PinnedConversation).filter(
|
||||
PinnedConversation.app_id == app_model.id,
|
||||
PinnedConversation.conversation_id == conversation_id,
|
||||
PinnedConversation.created_by_role == ('account' if isinstance(user, Account) else 'end_user'),
|
||||
PinnedConversation.created_by == user.id
|
||||
).first()
|
||||
pinned_conversation = (
|
||||
db.session.query(PinnedConversation)
|
||||
.filter(
|
||||
PinnedConversation.app_id == app_model.id,
|
||||
PinnedConversation.conversation_id == conversation_id,
|
||||
PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
|
||||
PinnedConversation.created_by == user.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if pinned_conversation:
|
||||
return
|
||||
|
||||
conversation = ConversationService.get_conversation(
|
||||
app_model=app_model,
|
||||
conversation_id=conversation_id,
|
||||
user=user
|
||||
app_model=app_model, conversation_id=conversation_id, user=user
|
||||
)
|
||||
|
||||
pinned_conversation = PinnedConversation(
|
||||
app_id=app_model.id,
|
||||
conversation_id=conversation.id,
|
||||
created_by_role='account' if isinstance(user, Account) else 'end_user',
|
||||
created_by=user.id
|
||||
created_by_role="account" if isinstance(user, Account) else "end_user",
|
||||
created_by=user.id,
|
||||
)
|
||||
|
||||
db.session.add(pinned_conversation)
|
||||
@ -70,12 +83,16 @@ class WebConversationService:
|
||||
|
||||
@classmethod
|
||||
def unpin(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]):
|
||||
pinned_conversation = db.session.query(PinnedConversation).filter(
|
||||
PinnedConversation.app_id == app_model.id,
|
||||
PinnedConversation.conversation_id == conversation_id,
|
||||
PinnedConversation.created_by_role == ('account' if isinstance(user, Account) else 'end_user'),
|
||||
PinnedConversation.created_by == user.id
|
||||
).first()
|
||||
pinned_conversation = (
|
||||
db.session.query(PinnedConversation)
|
||||
.filter(
|
||||
PinnedConversation.app_id == app_model.id,
|
||||
PinnedConversation.conversation_id == conversation_id,
|
||||
PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
|
||||
PinnedConversation.created_by == user.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not pinned_conversation:
|
||||
return
|
||||
|
@ -11,161 +11,126 @@ from services.auth.api_key_auth_service import ApiKeyAuthService
|
||||
|
||||
|
||||
class WebsiteService:
|
||||
|
||||
@classmethod
|
||||
def document_create_args_validate(cls, args: dict):
|
||||
if 'url' not in args or not args['url']:
|
||||
raise ValueError('url is required')
|
||||
if 'options' not in args or not args['options']:
|
||||
raise ValueError('options is required')
|
||||
if 'limit' not in args['options'] or not args['options']['limit']:
|
||||
raise ValueError('limit is required')
|
||||
if "url" not in args or not args["url"]:
|
||||
raise ValueError("url is required")
|
||||
if "options" not in args or not args["options"]:
|
||||
raise ValueError("options is required")
|
||||
if "limit" not in args["options"] or not args["options"]["limit"]:
|
||||
raise ValueError("limit is required")
|
||||
|
||||
@classmethod
|
||||
def crawl_url(cls, args: dict) -> dict:
|
||||
provider = args.get('provider')
|
||||
url = args.get('url')
|
||||
options = args.get('options')
|
||||
credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id,
|
||||
'website',
|
||||
provider)
|
||||
if provider == 'firecrawl':
|
||||
provider = args.get("provider")
|
||||
url = args.get("url")
|
||||
options = args.get("options")
|
||||
credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id, "website", provider)
|
||||
if provider == "firecrawl":
|
||||
# decrypt api_key
|
||||
api_key = encrypter.decrypt_token(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
token=credentials.get('config').get('api_key')
|
||||
tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key")
|
||||
)
|
||||
firecrawl_app = FirecrawlApp(api_key=api_key,
|
||||
base_url=credentials.get('config').get('base_url', None))
|
||||
crawl_sub_pages = options.get('crawl_sub_pages', False)
|
||||
only_main_content = options.get('only_main_content', False)
|
||||
firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None))
|
||||
crawl_sub_pages = options.get("crawl_sub_pages", False)
|
||||
only_main_content = options.get("only_main_content", False)
|
||||
if not crawl_sub_pages:
|
||||
params = {
|
||||
'crawlerOptions': {
|
||||
"crawlerOptions": {
|
||||
"includes": [],
|
||||
"excludes": [],
|
||||
"generateImgAltText": True,
|
||||
"limit": 1,
|
||||
'returnOnlyUrls': False,
|
||||
'pageOptions': {
|
||||
'onlyMainContent': only_main_content,
|
||||
"includeHtml": False
|
||||
}
|
||||
"returnOnlyUrls": False,
|
||||
"pageOptions": {"onlyMainContent": only_main_content, "includeHtml": False},
|
||||
}
|
||||
}
|
||||
else:
|
||||
includes = options.get('includes').split(',') if options.get('includes') else []
|
||||
excludes = options.get('excludes').split(',') if options.get('excludes') else []
|
||||
includes = options.get("includes").split(",") if options.get("includes") else []
|
||||
excludes = options.get("excludes").split(",") if options.get("excludes") else []
|
||||
params = {
|
||||
'crawlerOptions': {
|
||||
"crawlerOptions": {
|
||||
"includes": includes if includes else [],
|
||||
"excludes": excludes if excludes else [],
|
||||
"generateImgAltText": True,
|
||||
"limit": options.get('limit', 1),
|
||||
'returnOnlyUrls': False,
|
||||
'pageOptions': {
|
||||
'onlyMainContent': only_main_content,
|
||||
"includeHtml": False
|
||||
}
|
||||
"limit": options.get("limit", 1),
|
||||
"returnOnlyUrls": False,
|
||||
"pageOptions": {"onlyMainContent": only_main_content, "includeHtml": False},
|
||||
}
|
||||
}
|
||||
if options.get('max_depth'):
|
||||
params['crawlerOptions']['maxDepth'] = options.get('max_depth')
|
||||
if options.get("max_depth"):
|
||||
params["crawlerOptions"]["maxDepth"] = options.get("max_depth")
|
||||
job_id = firecrawl_app.crawl_url(url, params)
|
||||
website_crawl_time_cache_key = f'website_crawl_{job_id}'
|
||||
website_crawl_time_cache_key = f"website_crawl_{job_id}"
|
||||
time = str(datetime.datetime.now().timestamp())
|
||||
redis_client.setex(website_crawl_time_cache_key, 3600, time)
|
||||
return {
|
||||
'status': 'active',
|
||||
'job_id': job_id
|
||||
}
|
||||
return {"status": "active", "job_id": job_id}
|
||||
else:
|
||||
raise ValueError('Invalid provider')
|
||||
raise ValueError("Invalid provider")
|
||||
|
||||
@classmethod
|
||||
def get_crawl_status(cls, job_id: str, provider: str) -> dict:
|
||||
credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id,
|
||||
'website',
|
||||
provider)
|
||||
if provider == 'firecrawl':
|
||||
credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id, "website", provider)
|
||||
if provider == "firecrawl":
|
||||
# decrypt api_key
|
||||
api_key = encrypter.decrypt_token(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
token=credentials.get('config').get('api_key')
|
||||
tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key")
|
||||
)
|
||||
firecrawl_app = FirecrawlApp(api_key=api_key,
|
||||
base_url=credentials.get('config').get('base_url', None))
|
||||
firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None))
|
||||
result = firecrawl_app.check_crawl_status(job_id)
|
||||
crawl_status_data = {
|
||||
'status': result.get('status', 'active'),
|
||||
'job_id': job_id,
|
||||
'total': result.get('total', 0),
|
||||
'current': result.get('current', 0),
|
||||
'data': result.get('data', [])
|
||||
"status": result.get("status", "active"),
|
||||
"job_id": job_id,
|
||||
"total": result.get("total", 0),
|
||||
"current": result.get("current", 0),
|
||||
"data": result.get("data", []),
|
||||
}
|
||||
if crawl_status_data['status'] == 'completed':
|
||||
website_crawl_time_cache_key = f'website_crawl_{job_id}'
|
||||
if crawl_status_data["status"] == "completed":
|
||||
website_crawl_time_cache_key = f"website_crawl_{job_id}"
|
||||
start_time = redis_client.get(website_crawl_time_cache_key)
|
||||
if start_time:
|
||||
end_time = datetime.datetime.now().timestamp()
|
||||
time_consuming = abs(end_time - float(start_time))
|
||||
crawl_status_data['time_consuming'] = f"{time_consuming:.2f}"
|
||||
crawl_status_data["time_consuming"] = f"{time_consuming:.2f}"
|
||||
redis_client.delete(website_crawl_time_cache_key)
|
||||
else:
|
||||
raise ValueError('Invalid provider')
|
||||
raise ValueError("Invalid provider")
|
||||
return crawl_status_data
|
||||
|
||||
@classmethod
|
||||
def get_crawl_url_data(cls, job_id: str, provider: str, url: str, tenant_id: str) -> dict | None:
|
||||
credentials = ApiKeyAuthService.get_auth_credentials(tenant_id,
|
||||
'website',
|
||||
provider)
|
||||
if provider == 'firecrawl':
|
||||
file_key = 'website_files/' + job_id + '.txt'
|
||||
credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider)
|
||||
if provider == "firecrawl":
|
||||
file_key = "website_files/" + job_id + ".txt"
|
||||
if storage.exists(file_key):
|
||||
data = storage.load_once(file_key)
|
||||
if data:
|
||||
data = json.loads(data.decode('utf-8'))
|
||||
data = json.loads(data.decode("utf-8"))
|
||||
else:
|
||||
# decrypt api_key
|
||||
api_key = encrypter.decrypt_token(
|
||||
tenant_id=tenant_id,
|
||||
token=credentials.get('config').get('api_key')
|
||||
)
|
||||
firecrawl_app = FirecrawlApp(api_key=api_key,
|
||||
base_url=credentials.get('config').get('base_url', None))
|
||||
api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key"))
|
||||
firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None))
|
||||
result = firecrawl_app.check_crawl_status(job_id)
|
||||
if result.get('status') != 'completed':
|
||||
raise ValueError('Crawl job is not completed')
|
||||
data = result.get('data')
|
||||
if result.get("status") != "completed":
|
||||
raise ValueError("Crawl job is not completed")
|
||||
data = result.get("data")
|
||||
if data:
|
||||
for item in data:
|
||||
if item.get('source_url') == url:
|
||||
if item.get("source_url") == url:
|
||||
return item
|
||||
return None
|
||||
else:
|
||||
raise ValueError('Invalid provider')
|
||||
raise ValueError("Invalid provider")
|
||||
|
||||
@classmethod
|
||||
def get_scrape_url_data(cls, provider: str, url: str, tenant_id: str, only_main_content: bool) -> dict | None:
|
||||
credentials = ApiKeyAuthService.get_auth_credentials(tenant_id,
|
||||
'website',
|
||||
provider)
|
||||
if provider == 'firecrawl':
|
||||
credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider)
|
||||
if provider == "firecrawl":
|
||||
# decrypt api_key
|
||||
api_key = encrypter.decrypt_token(
|
||||
tenant_id=tenant_id,
|
||||
token=credentials.get('config').get('api_key')
|
||||
)
|
||||
firecrawl_app = FirecrawlApp(api_key=api_key,
|
||||
base_url=credentials.get('config').get('base_url', None))
|
||||
params = {
|
||||
'pageOptions': {
|
||||
'onlyMainContent': only_main_content,
|
||||
"includeHtml": False
|
||||
}
|
||||
}
|
||||
api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key"))
|
||||
firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None))
|
||||
params = {"pageOptions": {"onlyMainContent": only_main_content, "includeHtml": False}}
|
||||
result = firecrawl_app.scrape_url(url, params)
|
||||
return result
|
||||
else:
|
||||
raise ValueError('Invalid provider')
|
||||
raise ValueError("Invalid provider")
|
||||
|
@ -10,7 +10,6 @@ from models.workflow import WorkflowAppLog, WorkflowRun, WorkflowRunStatus
|
||||
|
||||
|
||||
class WorkflowAppService:
|
||||
|
||||
def get_paginate_workflow_app_logs(self, app_model: App, args: dict) -> Pagination:
|
||||
"""
|
||||
Get paginate workflow app logs
|
||||
@ -18,20 +17,14 @@ class WorkflowAppService:
|
||||
:param args: request args
|
||||
:return:
|
||||
"""
|
||||
query = (
|
||||
db.select(WorkflowAppLog)
|
||||
.where(
|
||||
WorkflowAppLog.tenant_id == app_model.tenant_id,
|
||||
WorkflowAppLog.app_id == app_model.id
|
||||
)
|
||||
query = db.select(WorkflowAppLog).where(
|
||||
WorkflowAppLog.tenant_id == app_model.tenant_id, WorkflowAppLog.app_id == app_model.id
|
||||
)
|
||||
|
||||
status = WorkflowRunStatus.value_of(args.get('status')) if args.get('status') else None
|
||||
keyword = args['keyword']
|
||||
status = WorkflowRunStatus.value_of(args.get("status")) if args.get("status") else None
|
||||
keyword = args["keyword"]
|
||||
if keyword or status:
|
||||
query = query.join(
|
||||
WorkflowRun, WorkflowRun.id == WorkflowAppLog.workflow_run_id
|
||||
)
|
||||
query = query.join(WorkflowRun, WorkflowRun.id == WorkflowAppLog.workflow_run_id)
|
||||
|
||||
if keyword:
|
||||
keyword_like_val = f"%{args['keyword'][:30]}%"
|
||||
@ -39,7 +32,7 @@ class WorkflowAppService:
|
||||
WorkflowRun.inputs.ilike(keyword_like_val),
|
||||
WorkflowRun.outputs.ilike(keyword_like_val),
|
||||
# filter keyword by end user session id if created by end user role
|
||||
and_(WorkflowRun.created_by_role == 'end_user', EndUser.session_id.ilike(keyword_like_val))
|
||||
and_(WorkflowRun.created_by_role == "end_user", EndUser.session_id.ilike(keyword_like_val)),
|
||||
]
|
||||
|
||||
# filter keyword by workflow run id
|
||||
@ -49,23 +42,16 @@ class WorkflowAppService:
|
||||
|
||||
query = query.outerjoin(
|
||||
EndUser,
|
||||
and_(WorkflowRun.created_by == EndUser.id, WorkflowRun.created_by_role == CreatedByRole.END_USER.value)
|
||||
and_(WorkflowRun.created_by == EndUser.id, WorkflowRun.created_by_role == CreatedByRole.END_USER.value),
|
||||
).filter(or_(*keyword_conditions))
|
||||
|
||||
if status:
|
||||
# join with workflow_run and filter by status
|
||||
query = query.filter(
|
||||
WorkflowRun.status == status.value
|
||||
)
|
||||
query = query.filter(WorkflowRun.status == status.value)
|
||||
|
||||
query = query.order_by(WorkflowAppLog.created_at.desc())
|
||||
|
||||
pagination = db.paginate(
|
||||
query,
|
||||
page=args['page'],
|
||||
per_page=args['limit'],
|
||||
error_out=False
|
||||
)
|
||||
pagination = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False)
|
||||
|
||||
return pagination
|
||||
|
||||
|
@ -18,6 +18,7 @@ class WorkflowRunService:
|
||||
:param app_model: app model
|
||||
:param args: request args
|
||||
"""
|
||||
|
||||
class WorkflowWithMessage:
|
||||
message_id: str
|
||||
conversation_id: str
|
||||
@ -33,9 +34,7 @@ class WorkflowRunService:
|
||||
with_message_workflow_runs = []
|
||||
for workflow_run in pagination.data:
|
||||
message = workflow_run.message
|
||||
with_message_workflow_run = WorkflowWithMessage(
|
||||
workflow_run=workflow_run
|
||||
)
|
||||
with_message_workflow_run = WorkflowWithMessage(workflow_run=workflow_run)
|
||||
if message:
|
||||
with_message_workflow_run.message_id = message.id
|
||||
with_message_workflow_run.conversation_id = message.conversation_id
|
||||
@ -53,26 +52,30 @@ class WorkflowRunService:
|
||||
:param app_model: app model
|
||||
:param args: request args
|
||||
"""
|
||||
limit = int(args.get('limit', 20))
|
||||
limit = int(args.get("limit", 20))
|
||||
|
||||
base_query = db.session.query(WorkflowRun).filter(
|
||||
WorkflowRun.tenant_id == app_model.tenant_id,
|
||||
WorkflowRun.app_id == app_model.id,
|
||||
WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.DEBUGGING.value
|
||||
WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.DEBUGGING.value,
|
||||
)
|
||||
|
||||
if args.get('last_id'):
|
||||
if args.get("last_id"):
|
||||
last_workflow_run = base_query.filter(
|
||||
WorkflowRun.id == args.get('last_id'),
|
||||
WorkflowRun.id == args.get("last_id"),
|
||||
).first()
|
||||
|
||||
if not last_workflow_run:
|
||||
raise ValueError('Last workflow run not exists')
|
||||
raise ValueError("Last workflow run not exists")
|
||||
|
||||
workflow_runs = base_query.filter(
|
||||
WorkflowRun.created_at < last_workflow_run.created_at,
|
||||
WorkflowRun.id != last_workflow_run.id
|
||||
).order_by(WorkflowRun.created_at.desc()).limit(limit).all()
|
||||
workflow_runs = (
|
||||
base_query.filter(
|
||||
WorkflowRun.created_at < last_workflow_run.created_at, WorkflowRun.id != last_workflow_run.id
|
||||
)
|
||||
.order_by(WorkflowRun.created_at.desc())
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
else:
|
||||
workflow_runs = base_query.order_by(WorkflowRun.created_at.desc()).limit(limit).all()
|
||||
|
||||
@ -81,17 +84,13 @@ class WorkflowRunService:
|
||||
current_page_first_workflow_run = workflow_runs[-1]
|
||||
rest_count = base_query.filter(
|
||||
WorkflowRun.created_at < current_page_first_workflow_run.created_at,
|
||||
WorkflowRun.id != current_page_first_workflow_run.id
|
||||
WorkflowRun.id != current_page_first_workflow_run.id,
|
||||
).count()
|
||||
|
||||
if rest_count > 0:
|
||||
has_more = True
|
||||
|
||||
return InfiniteScrollPagination(
|
||||
data=workflow_runs,
|
||||
limit=limit,
|
||||
has_more=has_more
|
||||
)
|
||||
return InfiniteScrollPagination(data=workflow_runs, limit=limit, has_more=has_more)
|
||||
|
||||
def get_workflow_run(self, app_model: App, run_id: str) -> WorkflowRun:
|
||||
"""
|
||||
@ -100,11 +99,15 @@ class WorkflowRunService:
|
||||
:param app_model: app model
|
||||
:param run_id: workflow run id
|
||||
"""
|
||||
workflow_run = db.session.query(WorkflowRun).filter(
|
||||
WorkflowRun.tenant_id == app_model.tenant_id,
|
||||
WorkflowRun.app_id == app_model.id,
|
||||
WorkflowRun.id == run_id,
|
||||
).first()
|
||||
workflow_run = (
|
||||
db.session.query(WorkflowRun)
|
||||
.filter(
|
||||
WorkflowRun.tenant_id == app_model.tenant_id,
|
||||
WorkflowRun.app_id == app_model.id,
|
||||
WorkflowRun.id == run_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
return workflow_run
|
||||
|
||||
@ -117,12 +120,17 @@ class WorkflowRunService:
|
||||
if not workflow_run:
|
||||
return []
|
||||
|
||||
node_executions = db.session.query(WorkflowNodeExecution).filter(
|
||||
WorkflowNodeExecution.tenant_id == app_model.tenant_id,
|
||||
WorkflowNodeExecution.app_id == app_model.id,
|
||||
WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
|
||||
WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
|
||||
WorkflowNodeExecution.workflow_run_id == run_id,
|
||||
).order_by(WorkflowNodeExecution.index.desc()).all()
|
||||
node_executions = (
|
||||
db.session.query(WorkflowNodeExecution)
|
||||
.filter(
|
||||
WorkflowNodeExecution.tenant_id == app_model.tenant_id,
|
||||
WorkflowNodeExecution.app_id == app_model.id,
|
||||
WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
|
||||
WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
|
||||
WorkflowNodeExecution.workflow_run_id == run_id,
|
||||
)
|
||||
.order_by(WorkflowNodeExecution.index.desc())
|
||||
.all()
|
||||
)
|
||||
|
||||
return node_executions
|
||||
|
@ -37,11 +37,13 @@ class WorkflowService:
|
||||
Get draft workflow
|
||||
"""
|
||||
# fetch draft workflow by app_model
|
||||
workflow = db.session.query(Workflow).filter(
|
||||
Workflow.tenant_id == app_model.tenant_id,
|
||||
Workflow.app_id == app_model.id,
|
||||
Workflow.version == 'draft'
|
||||
).first()
|
||||
workflow = (
|
||||
db.session.query(Workflow)
|
||||
.filter(
|
||||
Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.version == "draft"
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
# return draft workflow
|
||||
return workflow
|
||||
@ -55,11 +57,15 @@ class WorkflowService:
|
||||
return None
|
||||
|
||||
# fetch published workflow by workflow_id
|
||||
workflow = db.session.query(Workflow).filter(
|
||||
Workflow.tenant_id == app_model.tenant_id,
|
||||
Workflow.app_id == app_model.id,
|
||||
Workflow.id == app_model.workflow_id
|
||||
).first()
|
||||
workflow = (
|
||||
db.session.query(Workflow)
|
||||
.filter(
|
||||
Workflow.tenant_id == app_model.tenant_id,
|
||||
Workflow.app_id == app_model.id,
|
||||
Workflow.id == app_model.workflow_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
return workflow
|
||||
|
||||
@ -85,10 +91,7 @@ class WorkflowService:
|
||||
raise WorkflowHashNotEqualError()
|
||||
|
||||
# validate features structure
|
||||
self.validate_features_structure(
|
||||
app_model=app_model,
|
||||
features=features
|
||||
)
|
||||
self.validate_features_structure(app_model=app_model, features=features)
|
||||
|
||||
# create draft workflow if not found
|
||||
if not workflow:
|
||||
@ -96,7 +99,7 @@ class WorkflowService:
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
type=WorkflowType.from_app_mode(app_model.mode).value,
|
||||
version='draft',
|
||||
version="draft",
|
||||
graph=json.dumps(graph),
|
||||
features=json.dumps(features),
|
||||
created_by=account.id,
|
||||
@ -122,9 +125,7 @@ class WorkflowService:
|
||||
# return draft workflow
|
||||
return workflow
|
||||
|
||||
def publish_workflow(self, app_model: App,
|
||||
account: Account,
|
||||
draft_workflow: Optional[Workflow] = None) -> Workflow:
|
||||
def publish_workflow(self, app_model: App, account: Account, draft_workflow: Optional[Workflow] = None) -> Workflow:
|
||||
"""
|
||||
Publish workflow from draft
|
||||
|
||||
@ -137,7 +138,7 @@ class WorkflowService:
|
||||
draft_workflow = self.get_draft_workflow(app_model=app_model)
|
||||
|
||||
if not draft_workflow:
|
||||
raise ValueError('No valid workflow found.')
|
||||
raise ValueError("No valid workflow found.")
|
||||
|
||||
# create new workflow
|
||||
workflow = Workflow(
|
||||
@ -187,17 +188,16 @@ class WorkflowService:
|
||||
workflow_engine_manager = WorkflowEngineManager()
|
||||
return workflow_engine_manager.get_default_config(node_type, filters)
|
||||
|
||||
def run_draft_workflow_node(self, app_model: App,
|
||||
node_id: str,
|
||||
user_inputs: dict,
|
||||
account: Account) -> WorkflowNodeExecution:
|
||||
def run_draft_workflow_node(
|
||||
self, app_model: App, node_id: str, user_inputs: dict, account: Account
|
||||
) -> WorkflowNodeExecution:
|
||||
"""
|
||||
Run draft workflow node
|
||||
"""
|
||||
# fetch draft workflow by app_model
|
||||
draft_workflow = self.get_draft_workflow(app_model=app_model)
|
||||
if not draft_workflow:
|
||||
raise ValueError('Workflow not initialized')
|
||||
raise ValueError("Workflow not initialized")
|
||||
|
||||
# run draft workflow node
|
||||
workflow_engine_manager = WorkflowEngineManager()
|
||||
@ -226,7 +226,7 @@ class WorkflowService:
|
||||
created_by_role=CreatedByRole.ACCOUNT.value,
|
||||
created_by=account.id,
|
||||
created_at=datetime.now(timezone.utc).replace(tzinfo=None),
|
||||
finished_at=datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
finished_at=datetime.now(timezone.utc).replace(tzinfo=None),
|
||||
)
|
||||
db.session.add(workflow_node_execution)
|
||||
db.session.commit()
|
||||
@ -247,14 +247,15 @@ class WorkflowService:
|
||||
inputs=json.dumps(node_run_result.inputs) if node_run_result.inputs else None,
|
||||
process_data=json.dumps(node_run_result.process_data) if node_run_result.process_data else None,
|
||||
outputs=json.dumps(jsonable_encoder(node_run_result.outputs)) if node_run_result.outputs else None,
|
||||
execution_metadata=(json.dumps(jsonable_encoder(node_run_result.metadata))
|
||||
if node_run_result.metadata else None),
|
||||
execution_metadata=(
|
||||
json.dumps(jsonable_encoder(node_run_result.metadata)) if node_run_result.metadata else None
|
||||
),
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED.value,
|
||||
elapsed_time=time.perf_counter() - start_at,
|
||||
created_by_role=CreatedByRole.ACCOUNT.value,
|
||||
created_by=account.id,
|
||||
created_at=datetime.now(timezone.utc).replace(tzinfo=None),
|
||||
finished_at=datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
finished_at=datetime.now(timezone.utc).replace(tzinfo=None),
|
||||
)
|
||||
else:
|
||||
# create workflow node execution
|
||||
@ -273,7 +274,7 @@ class WorkflowService:
|
||||
created_by_role=CreatedByRole.ACCOUNT.value,
|
||||
created_by=account.id,
|
||||
created_at=datetime.now(timezone.utc).replace(tzinfo=None),
|
||||
finished_at=datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
finished_at=datetime.now(timezone.utc).replace(tzinfo=None),
|
||||
)
|
||||
|
||||
db.session.add(workflow_node_execution)
|
||||
@ -295,16 +296,16 @@ class WorkflowService:
|
||||
workflow_converter = WorkflowConverter()
|
||||
|
||||
if app_model.mode not in [AppMode.CHAT.value, AppMode.COMPLETION.value]:
|
||||
raise ValueError(f'Current App mode: {app_model.mode} is not supported convert to workflow.')
|
||||
raise ValueError(f"Current App mode: {app_model.mode} is not supported convert to workflow.")
|
||||
|
||||
# convert to workflow
|
||||
new_app = workflow_converter.convert_to_workflow(
|
||||
app_model=app_model,
|
||||
account=account,
|
||||
name=args.get('name'),
|
||||
icon_type=args.get('icon_type'),
|
||||
icon=args.get('icon'),
|
||||
icon_background=args.get('icon_background'),
|
||||
name=args.get("name"),
|
||||
icon_type=args.get("icon_type"),
|
||||
icon=args.get("icon"),
|
||||
icon_background=args.get("icon_background"),
|
||||
)
|
||||
|
||||
return new_app
|
||||
@ -312,15 +313,11 @@ class WorkflowService:
|
||||
def validate_features_structure(self, app_model: App, features: dict) -> dict:
|
||||
if app_model.mode == AppMode.ADVANCED_CHAT.value:
|
||||
return AdvancedChatAppConfigManager.config_validate(
|
||||
tenant_id=app_model.tenant_id,
|
||||
config=features,
|
||||
only_structure_validate=True
|
||||
tenant_id=app_model.tenant_id, config=features, only_structure_validate=True
|
||||
)
|
||||
elif app_model.mode == AppMode.WORKFLOW.value:
|
||||
return WorkflowAppConfigManager.config_validate(
|
||||
tenant_id=app_model.tenant_id,
|
||||
config=features,
|
||||
only_structure_validate=True
|
||||
tenant_id=app_model.tenant_id, config=features, only_structure_validate=True
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid app mode: {app_model.mode}")
|
||||
|
@ -1,4 +1,3 @@
|
||||
|
||||
from flask_login import current_user
|
||||
|
||||
from configs import dify_config
|
||||
@ -14,34 +13,40 @@ class WorkspaceService:
|
||||
if not tenant:
|
||||
return None
|
||||
tenant_info = {
|
||||
'id': tenant.id,
|
||||
'name': tenant.name,
|
||||
'plan': tenant.plan,
|
||||
'status': tenant.status,
|
||||
'created_at': tenant.created_at,
|
||||
'in_trail': True,
|
||||
'trial_end_reason': None,
|
||||
'role': 'normal',
|
||||
"id": tenant.id,
|
||||
"name": tenant.name,
|
||||
"plan": tenant.plan,
|
||||
"status": tenant.status,
|
||||
"created_at": tenant.created_at,
|
||||
"in_trail": True,
|
||||
"trial_end_reason": None,
|
||||
"role": "normal",
|
||||
}
|
||||
|
||||
# Get role of user
|
||||
tenant_account_join = db.session.query(TenantAccountJoin).filter(
|
||||
TenantAccountJoin.tenant_id == tenant.id,
|
||||
TenantAccountJoin.account_id == current_user.id
|
||||
).first()
|
||||
tenant_info['role'] = tenant_account_join.role
|
||||
tenant_account_join = (
|
||||
db.session.query(TenantAccountJoin)
|
||||
.filter(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == current_user.id)
|
||||
.first()
|
||||
)
|
||||
tenant_info["role"] = tenant_account_join.role
|
||||
|
||||
can_replace_logo = FeatureService.get_features(tenant_info['id']).can_replace_logo
|
||||
can_replace_logo = FeatureService.get_features(tenant_info["id"]).can_replace_logo
|
||||
|
||||
if can_replace_logo and TenantService.has_roles(tenant,
|
||||
[TenantAccountJoinRole.OWNER, TenantAccountJoinRole.ADMIN]):
|
||||
if can_replace_logo and TenantService.has_roles(
|
||||
tenant, [TenantAccountJoinRole.OWNER, TenantAccountJoinRole.ADMIN]
|
||||
):
|
||||
base_url = dify_config.FILES_URL
|
||||
replace_webapp_logo = f'{base_url}/files/workspaces/{tenant.id}/webapp-logo' if tenant.custom_config_dict.get('replace_webapp_logo') else None
|
||||
remove_webapp_brand = tenant.custom_config_dict.get('remove_webapp_brand', False)
|
||||
replace_webapp_logo = (
|
||||
f"{base_url}/files/workspaces/{tenant.id}/webapp-logo"
|
||||
if tenant.custom_config_dict.get("replace_webapp_logo")
|
||||
else None
|
||||
)
|
||||
remove_webapp_brand = tenant.custom_config_dict.get("remove_webapp_brand", False)
|
||||
|
||||
tenant_info['custom_config'] = {
|
||||
'remove_webapp_brand': remove_webapp_brand,
|
||||
'replace_webapp_logo': replace_webapp_logo,
|
||||
tenant_info["custom_config"] = {
|
||||
"remove_webapp_brand": remove_webapp_brand,
|
||||
"replace_webapp_logo": replace_webapp_logo,
|
||||
}
|
||||
|
||||
return tenant_info
|
||||
|
Loading…
Reference in New Issue
Block a user