diff --git a/api/libs/bearer_data_source.py b/api/libs/bearer_data_source.py index c1aee7b81..ab6e398eb 100644 --- a/api/libs/bearer_data_source.py +++ b/api/libs/bearer_data_source.py @@ -25,7 +25,7 @@ class FireCrawlDataSource(BearerDataSource): TEST_CRAWL_SITE_URL = "https://www.google.com" FIRECRAWL_API_VERSION = "v0" - test_api_endpoint = self.api_base_url.rstrip('/') + f"/{FIRECRAWL_API_VERSION}/scrape" + test_api_endpoint = self.api_base_url.rstrip("/") + f"/{FIRECRAWL_API_VERSION}/scrape" headers = { "Authorization": f"Bearer {self.api_key}", @@ -45,9 +45,9 @@ class FireCrawlDataSource(BearerDataSource): data_source_binding = DataSourceBearerBinding.query.filter( db.and_( DataSourceBearerBinding.tenant_id == current_user.current_tenant_id, - DataSourceBearerBinding.provider == 'firecrawl', + DataSourceBearerBinding.provider == "firecrawl", DataSourceBearerBinding.endpoint_url == self.api_base_url, - DataSourceBearerBinding.bearer_key == self.api_key + DataSourceBearerBinding.bearer_key == self.api_key, ) ).first() if data_source_binding: @@ -56,9 +56,9 @@ class FireCrawlDataSource(BearerDataSource): else: new_data_source_binding = DataSourceBearerBinding( tenant_id=current_user.current_tenant_id, - provider='firecrawl', + provider="firecrawl", endpoint_url=self.api_base_url, - bearer_key=self.api_key + bearer_key=self.api_key, ) db.session.add(new_data_source_binding) db.session.commit() diff --git a/api/libs/exception.py b/api/libs/exception.py index 567062f06..5970269ec 100644 --- a/api/libs/exception.py +++ b/api/libs/exception.py @@ -4,7 +4,7 @@ from werkzeug.exceptions import HTTPException class BaseHTTPException(HTTPException): - error_code: str = 'unknown' + error_code: str = "unknown" data: Optional[dict] = None def __init__(self, description=None, response=None): @@ -14,4 +14,4 @@ class BaseHTTPException(HTTPException): "code": self.error_code, "message": self.description, "status": self.code, - } \ No newline at end of file + } diff --git a/api/libs/external_api.py b/api/libs/external_api.py index 677ff0fc5..179617ac0 100644 --- a/api/libs/external_api.py +++ b/api/libs/external_api.py @@ -10,7 +10,6 @@ from core.errors.error import AppInvokeQuotaExceededError class ExternalApi(Api): - def handle_error(self, e): """Error handler for the API transforms a raised exception into a Flask response, with the appropriate HTTP status code and body. @@ -29,54 +28,57 @@ class ExternalApi(Api): status_code = e.code default_data = { - 'code': re.sub(r'(? self.max_length: - error = ('Invalid {arg}: {val}. {arg} cannot exceed length {length}' - .format(arg=self.argument, val=value, length=self.max_length)) + error = "Invalid {arg}: {val}. {arg} cannot exceed length {length}".format( + arg=self.argument, val=value, length=self.max_length + ) raise ValueError(error) return value class float_range: - """ Restrict input to an float in a range (inclusive) """ - def __init__(self, low, high, argument='argument'): + """Restrict input to an float in a range (inclusive)""" + + def __init__(self, low, high, argument="argument"): self.low = low self.high = high self.argument = argument @@ -99,15 +100,16 @@ class float_range: def __call__(self, value): value = _get_float(value) if value < self.low or value > self.high: - error = ('Invalid {arg}: {val}. {arg} must be within the range {lo} - {hi}' - .format(arg=self.argument, val=value, lo=self.low, hi=self.high)) + error = "Invalid {arg}: {val}. {arg} must be within the range {lo} - {hi}".format( + arg=self.argument, val=value, lo=self.low, hi=self.high + ) raise ValueError(error) return value class datetime_string: - def __init__(self, format, argument='argument'): + def __init__(self, format, argument="argument"): self.format = format self.argument = argument @@ -115,8 +117,9 @@ class datetime_string: try: datetime.strptime(value, self.format) except ValueError: - error = ('Invalid {arg}: {val}. {arg} must be conform to the format {format}' - .format(arg=self.argument, val=value, format=self.format)) + error = "Invalid {arg}: {val}. {arg} must be conform to the format {format}".format( + arg=self.argument, val=value, format=self.format + ) raise ValueError(error) return value @@ -126,14 +129,14 @@ def _get_float(value): try: return float(value) except (TypeError, ValueError): - raise ValueError('{} is not a valid float'.format(value)) + raise ValueError("{} is not a valid float".format(value)) + def timezone(timezone_string): if timezone_string and timezone_string in available_timezones(): return timezone_string - error = ('{timezone_string} is not a valid timezone.' - .format(timezone_string=timezone_string)) + error = "{timezone_string} is not a valid timezone.".format(timezone_string=timezone_string) raise ValueError(error) @@ -147,8 +150,8 @@ def generate_string(n): def get_remote_ip(request) -> str: - if request.headers.get('CF-Connecting-IP'): - return request.headers.get('Cf-Connecting-Ip') + if request.headers.get("CF-Connecting-IP"): + return request.headers.get("Cf-Connecting-Ip") elif request.headers.getlist("X-Forwarded-For"): return request.headers.getlist("X-Forwarded-For")[0] else: @@ -156,54 +159,45 @@ def get_remote_ip(request) -> str: def generate_text_hash(text: str) -> str: - hash_text = str(text) + 'None' + hash_text = str(text) + "None" return sha256(hash_text.encode()).hexdigest() def compact_generate_response(response: Union[dict, RateLimitGenerator]) -> Response: if isinstance(response, dict): - return Response(response=json.dumps(response), status=200, mimetype='application/json') + return Response(response=json.dumps(response), status=200, mimetype="application/json") else: + def generate() -> Generator: yield from response - return Response(stream_with_context(generate()), status=200, - mimetype='text/event-stream') + return Response(stream_with_context(generate()), status=200, mimetype="text/event-stream") class TokenManager: - @classmethod def generate_token(cls, account: Account, token_type: str, additional_data: dict = None) -> str: old_token = cls._get_current_token_for_account(account.id, token_type) if old_token: if isinstance(old_token, bytes): - old_token = old_token.decode('utf-8') + old_token = old_token.decode("utf-8") cls.revoke_token(old_token, token_type) token = str(uuid.uuid4()) - token_data = { - 'account_id': account.id, - 'email': account.email, - 'token_type': token_type - } + token_data = {"account_id": account.id, "email": account.email, "token_type": token_type} if additional_data: token_data.update(additional_data) - expiry_hours = current_app.config[f'{token_type.upper()}_TOKEN_EXPIRY_HOURS'] + expiry_hours = current_app.config[f"{token_type.upper()}_TOKEN_EXPIRY_HOURS"] token_key = cls._get_token_key(token, token_type) - redis_client.setex( - token_key, - expiry_hours * 60 * 60, - json.dumps(token_data) - ) + redis_client.setex(token_key, expiry_hours * 60 * 60, json.dumps(token_data)) cls._set_current_token_for_account(account.id, token, token_type, expiry_hours) return token @classmethod def _get_token_key(cls, token: str, token_type: str) -> str: - return f'{token_type}:token:{token}' + return f"{token_type}:token:{token}" @classmethod def revoke_token(cls, token: str, token_type: str): @@ -233,7 +227,7 @@ class TokenManager: @classmethod def _get_account_token_key(cls, account_id: str, token_type: str) -> str: - return f'{token_type}:account:{account_id}' + return f"{token_type}:account:{account_id}" class RateLimiter: @@ -250,7 +244,7 @@ class RateLimiter: current_time = int(time.time()) window_start_time = current_time - self.time_window - redis_client.zremrangebyscore(key, '-inf', window_start_time) + redis_client.zremrangebyscore(key, "-inf", window_start_time) attempts = redis_client.zcard(key) if attempts and int(attempts) >= self.max_attempts: diff --git a/api/libs/infinite_scroll_pagination.py b/api/libs/infinite_scroll_pagination.py index a1cb7b78f..133ccb188 100644 --- a/api/libs/infinite_scroll_pagination.py +++ b/api/libs/infinite_scroll_pagination.py @@ -1,4 +1,3 @@ - class InfiniteScrollPagination: def __init__(self, data, limit, has_more): self.data = data diff --git a/api/libs/json_in_md_parser.py b/api/libs/json_in_md_parser.py index 2cf023a39..41d690589 100644 --- a/api/libs/json_in_md_parser.py +++ b/api/libs/json_in_md_parser.py @@ -10,13 +10,13 @@ def parse_json_markdown(json_string: str) -> dict: end_index = json_string.find("```", start_index + len("```json")) if start_index != -1 and end_index != -1: - extracted_content = json_string[start_index + len("```json"):end_index].strip() + extracted_content = json_string[start_index + len("```json") : end_index].strip() # Parse the JSON string into a Python dictionary parsed = json.loads(extracted_content) elif start_index != -1 and end_index == -1 and json_string.endswith("``"): end_index = json_string.find("``", start_index + len("```json")) - extracted_content = json_string[start_index + len("```json"):end_index].strip() + extracted_content = json_string[start_index + len("```json") : end_index].strip() # Parse the JSON string into a Python dictionary parsed = json.loads(extracted_content) @@ -37,7 +37,6 @@ def parse_and_check_json_markdown(text: str, expected_keys: list[str]) -> dict: for key in expected_keys: if key not in json_obj: raise OutputParserException( - f"Got invalid return object. Expected key `{key}` " - f"to be present, but got {json_obj}" + f"Got invalid return object. Expected key `{key}` " f"to be present, but got {json_obj}" ) return json_obj diff --git a/api/libs/login.py b/api/libs/login.py index 14085fe60..7f05eb840 100644 --- a/api/libs/login.py +++ b/api/libs/login.py @@ -51,27 +51,29 @@ def login_required(func): @wraps(func) def decorated_view(*args, **kwargs): - auth_header = request.headers.get('Authorization') - admin_api_key_enable = os.getenv('ADMIN_API_KEY_ENABLE', default='False') - if admin_api_key_enable.lower() == 'true': + auth_header = request.headers.get("Authorization") + admin_api_key_enable = os.getenv("ADMIN_API_KEY_ENABLE", default="False") + if admin_api_key_enable.lower() == "true": if auth_header: - if ' ' not in auth_header: - raise Unauthorized('Invalid Authorization header format. Expected \'Bearer \' format.') + if " " not in auth_header: + raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") auth_scheme, auth_token = auth_header.split(None, 1) auth_scheme = auth_scheme.lower() - if auth_scheme != 'bearer': - raise Unauthorized('Invalid Authorization header format. Expected \'Bearer \' format.') - admin_api_key = os.getenv('ADMIN_API_KEY') + if auth_scheme != "bearer": + raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") + admin_api_key = os.getenv("ADMIN_API_KEY") if admin_api_key: - if os.getenv('ADMIN_API_KEY') == auth_token: - workspace_id = request.headers.get('X-WORKSPACE-ID') + if os.getenv("ADMIN_API_KEY") == auth_token: + workspace_id = request.headers.get("X-WORKSPACE-ID") if workspace_id: - tenant_account_join = db.session.query(Tenant, TenantAccountJoin) \ - .filter(Tenant.id == workspace_id) \ - .filter(TenantAccountJoin.tenant_id == Tenant.id) \ - .filter(TenantAccountJoin.role == 'owner') \ + tenant_account_join = ( + db.session.query(Tenant, TenantAccountJoin) + .filter(Tenant.id == workspace_id) + .filter(TenantAccountJoin.tenant_id == Tenant.id) + .filter(TenantAccountJoin.role == "owner") .one_or_none() + ) if tenant_account_join: tenant, ta = tenant_account_join account = Account.query.filter_by(id=ta.account_id).first() diff --git a/api/libs/oauth.py b/api/libs/oauth.py index dacdee0bc..d8ce1a1e6 100644 --- a/api/libs/oauth.py +++ b/api/libs/oauth.py @@ -35,31 +35,31 @@ class OAuth: class GitHubOAuth(OAuth): - _AUTH_URL = 'https://github.com/login/oauth/authorize' - _TOKEN_URL = 'https://github.com/login/oauth/access_token' - _USER_INFO_URL = 'https://api.github.com/user' - _EMAIL_INFO_URL = 'https://api.github.com/user/emails' + _AUTH_URL = "https://github.com/login/oauth/authorize" + _TOKEN_URL = "https://github.com/login/oauth/access_token" + _USER_INFO_URL = "https://api.github.com/user" + _EMAIL_INFO_URL = "https://api.github.com/user/emails" def get_authorization_url(self): params = { - 'client_id': self.client_id, - 'redirect_uri': self.redirect_uri, - 'scope': 'user:email' # Request only basic user information + "client_id": self.client_id, + "redirect_uri": self.redirect_uri, + "scope": "user:email", # Request only basic user information } return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}" def get_access_token(self, code: str): data = { - 'client_id': self.client_id, - 'client_secret': self.client_secret, - 'code': code, - 'redirect_uri': self.redirect_uri + "client_id": self.client_id, + "client_secret": self.client_secret, + "code": code, + "redirect_uri": self.redirect_uri, } - headers = {'Accept': 'application/json'} + headers = {"Accept": "application/json"} response = requests.post(self._TOKEN_URL, data=data, headers=headers) response_json = response.json() - access_token = response_json.get('access_token') + access_token = response_json.get("access_token") if not access_token: raise ValueError(f"Error in GitHub OAuth: {response_json}") @@ -67,55 +67,51 @@ class GitHubOAuth(OAuth): return access_token def get_raw_user_info(self, token: str): - headers = {'Authorization': f"token {token}"} + headers = {"Authorization": f"token {token}"} response = requests.get(self._USER_INFO_URL, headers=headers) response.raise_for_status() user_info = response.json() email_response = requests.get(self._EMAIL_INFO_URL, headers=headers) email_info = email_response.json() - primary_email = next((email for email in email_info if email['primary'] == True), None) + primary_email = next((email for email in email_info if email["primary"] == True), None) - return {**user_info, 'email': primary_email['email']} + return {**user_info, "email": primary_email["email"]} def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo: - email = raw_info.get('email') + email = raw_info.get("email") if not email: email = f"{raw_info['id']}+{raw_info['login']}@users.noreply.github.com" - return OAuthUserInfo( - id=str(raw_info['id']), - name=raw_info['name'], - email=email - ) + return OAuthUserInfo(id=str(raw_info["id"]), name=raw_info["name"], email=email) class GoogleOAuth(OAuth): - _AUTH_URL = 'https://accounts.google.com/o/oauth2/v2/auth' - _TOKEN_URL = 'https://oauth2.googleapis.com/token' - _USER_INFO_URL = 'https://www.googleapis.com/oauth2/v3/userinfo' + _AUTH_URL = "https://accounts.google.com/o/oauth2/v2/auth" + _TOKEN_URL = "https://oauth2.googleapis.com/token" + _USER_INFO_URL = "https://www.googleapis.com/oauth2/v3/userinfo" def get_authorization_url(self): params = { - 'client_id': self.client_id, - 'response_type': 'code', - 'redirect_uri': self.redirect_uri, - 'scope': 'openid email' + "client_id": self.client_id, + "response_type": "code", + "redirect_uri": self.redirect_uri, + "scope": "openid email", } return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}" def get_access_token(self, code: str): data = { - 'client_id': self.client_id, - 'client_secret': self.client_secret, - 'code': code, - 'grant_type': 'authorization_code', - 'redirect_uri': self.redirect_uri + "client_id": self.client_id, + "client_secret": self.client_secret, + "code": code, + "grant_type": "authorization_code", + "redirect_uri": self.redirect_uri, } - headers = {'Accept': 'application/json'} + headers = {"Accept": "application/json"} response = requests.post(self._TOKEN_URL, data=data, headers=headers) response_json = response.json() - access_token = response_json.get('access_token') + access_token = response_json.get("access_token") if not access_token: raise ValueError(f"Error in Google OAuth: {response_json}") @@ -123,16 +119,10 @@ class GoogleOAuth(OAuth): return access_token def get_raw_user_info(self, token: str): - headers = {'Authorization': f"Bearer {token}"} + headers = {"Authorization": f"Bearer {token}"} response = requests.get(self._USER_INFO_URL, headers=headers) response.raise_for_status() return response.json() def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo: - return OAuthUserInfo( - id=str(raw_info['sub']), - name=None, - email=raw_info['email'] - ) - - + return OAuthUserInfo(id=str(raw_info["sub"]), name=None, email=raw_info["email"]) diff --git a/api/libs/oauth_data_source.py b/api/libs/oauth_data_source.py index 358858ceb..6da1a6d39 100644 --- a/api/libs/oauth_data_source.py +++ b/api/libs/oauth_data_source.py @@ -21,53 +21,49 @@ class OAuthDataSource: class NotionOAuth(OAuthDataSource): - _AUTH_URL = 'https://api.notion.com/v1/oauth/authorize' - _TOKEN_URL = 'https://api.notion.com/v1/oauth/token' + _AUTH_URL = "https://api.notion.com/v1/oauth/authorize" + _TOKEN_URL = "https://api.notion.com/v1/oauth/token" _NOTION_PAGE_SEARCH = "https://api.notion.com/v1/search" _NOTION_BLOCK_SEARCH = "https://api.notion.com/v1/blocks" _NOTION_BOT_USER = "https://api.notion.com/v1/users/me" def get_authorization_url(self): params = { - 'client_id': self.client_id, - 'response_type': 'code', - 'redirect_uri': self.redirect_uri, - 'owner': 'user' + "client_id": self.client_id, + "response_type": "code", + "redirect_uri": self.redirect_uri, + "owner": "user", } return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}" def get_access_token(self, code: str): - data = { - 'code': code, - 'grant_type': 'authorization_code', - 'redirect_uri': self.redirect_uri - } - headers = {'Accept': 'application/json'} + data = {"code": code, "grant_type": "authorization_code", "redirect_uri": self.redirect_uri} + headers = {"Accept": "application/json"} auth = (self.client_id, self.client_secret) response = requests.post(self._TOKEN_URL, data=data, auth=auth, headers=headers) response_json = response.json() - access_token = response_json.get('access_token') + access_token = response_json.get("access_token") if not access_token: raise ValueError(f"Error in Notion OAuth: {response_json}") - workspace_name = response_json.get('workspace_name') - workspace_icon = response_json.get('workspace_icon') - workspace_id = response_json.get('workspace_id') + workspace_name = response_json.get("workspace_name") + workspace_icon = response_json.get("workspace_icon") + workspace_id = response_json.get("workspace_id") # get all authorized pages pages = self.get_authorized_pages(access_token) source_info = { - 'workspace_name': workspace_name, - 'workspace_icon': workspace_icon, - 'workspace_id': workspace_id, - 'pages': pages, - 'total': len(pages) + "workspace_name": workspace_name, + "workspace_icon": workspace_icon, + "workspace_id": workspace_id, + "pages": pages, + "total": len(pages), } # save data source binding data_source_binding = DataSourceOauthBinding.query.filter( db.and_( DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, - DataSourceOauthBinding.provider == 'notion', - DataSourceOauthBinding.access_token == access_token + DataSourceOauthBinding.provider == "notion", + DataSourceOauthBinding.access_token == access_token, ) ).first() if data_source_binding: @@ -79,7 +75,7 @@ class NotionOAuth(OAuthDataSource): tenant_id=current_user.current_tenant_id, access_token=access_token, source_info=source_info, - provider='notion' + provider="notion", ) db.session.add(new_data_source_binding) db.session.commit() @@ -91,18 +87,18 @@ class NotionOAuth(OAuthDataSource): # get all authorized pages pages = self.get_authorized_pages(access_token) source_info = { - 'workspace_name': workspace_name, - 'workspace_icon': workspace_icon, - 'workspace_id': workspace_id, - 'pages': pages, - 'total': len(pages) + "workspace_name": workspace_name, + "workspace_icon": workspace_icon, + "workspace_id": workspace_id, + "pages": pages, + "total": len(pages), } # save data source binding data_source_binding = DataSourceOauthBinding.query.filter( db.and_( DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, - DataSourceOauthBinding.provider == 'notion', - DataSourceOauthBinding.access_token == access_token + DataSourceOauthBinding.provider == "notion", + DataSourceOauthBinding.access_token == access_token, ) ).first() if data_source_binding: @@ -114,7 +110,7 @@ class NotionOAuth(OAuthDataSource): tenant_id=current_user.current_tenant_id, access_token=access_token, source_info=source_info, - provider='notion' + provider="notion", ) db.session.add(new_data_source_binding) db.session.commit() @@ -124,9 +120,9 @@ class NotionOAuth(OAuthDataSource): data_source_binding = DataSourceOauthBinding.query.filter( db.and_( DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, - DataSourceOauthBinding.provider == 'notion', + DataSourceOauthBinding.provider == "notion", DataSourceOauthBinding.id == binding_id, - DataSourceOauthBinding.disabled == False + DataSourceOauthBinding.disabled == False, ) ).first() if data_source_binding: @@ -134,17 +130,17 @@ class NotionOAuth(OAuthDataSource): pages = self.get_authorized_pages(data_source_binding.access_token) source_info = data_source_binding.source_info new_source_info = { - 'workspace_name': source_info['workspace_name'], - 'workspace_icon': source_info['workspace_icon'], - 'workspace_id': source_info['workspace_id'], - 'pages': pages, - 'total': len(pages) + "workspace_name": source_info["workspace_name"], + "workspace_icon": source_info["workspace_icon"], + "workspace_id": source_info["workspace_id"], + "pages": pages, + "total": len(pages), } data_source_binding.source_info = new_source_info data_source_binding.disabled = False db.session.commit() else: - raise ValueError('Data source binding not found') + raise ValueError("Data source binding not found") def get_authorized_pages(self, access_token: str): pages = [] @@ -152,143 +148,121 @@ class NotionOAuth(OAuthDataSource): database_results = self.notion_database_search(access_token) # get page detail for page_result in page_results: - page_id = page_result['id'] - page_name = 'Untitled' - for key in page_result['properties']: - if 'title' in page_result['properties'][key] and page_result['properties'][key]['title']: - title_list = page_result['properties'][key]['title'] - if len(title_list) > 0 and 'plain_text' in title_list[0]: - page_name = title_list[0]['plain_text'] - page_icon = page_result['icon'] + page_id = page_result["id"] + page_name = "Untitled" + for key in page_result["properties"]: + if "title" in page_result["properties"][key] and page_result["properties"][key]["title"]: + title_list = page_result["properties"][key]["title"] + if len(title_list) > 0 and "plain_text" in title_list[0]: + page_name = title_list[0]["plain_text"] + page_icon = page_result["icon"] if page_icon: - icon_type = page_icon['type'] - if icon_type == 'external' or icon_type == 'file': - url = page_icon[icon_type]['url'] - icon = { - 'type': 'url', - 'url': url if url.startswith('http') else f'https://www.notion.so{url}' - } + icon_type = page_icon["type"] + if icon_type == "external" or icon_type == "file": + url = page_icon[icon_type]["url"] + icon = {"type": "url", "url": url if url.startswith("http") else f"https://www.notion.so{url}"} else: - icon = { - 'type': 'emoji', - 'emoji': page_icon[icon_type] - } + icon = {"type": "emoji", "emoji": page_icon[icon_type]} else: icon = None - parent = page_result['parent'] - parent_type = parent['type'] - if parent_type == 'block_id': + parent = page_result["parent"] + parent_type = parent["type"] + if parent_type == "block_id": parent_id = self.notion_block_parent_page_id(access_token, parent[parent_type]) - elif parent_type == 'workspace': - parent_id = 'root' + elif parent_type == "workspace": + parent_id = "root" else: parent_id = parent[parent_type] page = { - 'page_id': page_id, - 'page_name': page_name, - 'page_icon': icon, - 'parent_id': parent_id, - 'type': 'page' + "page_id": page_id, + "page_name": page_name, + "page_icon": icon, + "parent_id": parent_id, + "type": "page", } pages.append(page) # get database detail for database_result in database_results: - page_id = database_result['id'] - if len(database_result['title']) > 0: - page_name = database_result['title'][0]['plain_text'] + page_id = database_result["id"] + if len(database_result["title"]) > 0: + page_name = database_result["title"][0]["plain_text"] else: - page_name = 'Untitled' - page_icon = database_result['icon'] + page_name = "Untitled" + page_icon = database_result["icon"] if page_icon: - icon_type = page_icon['type'] - if icon_type == 'external' or icon_type == 'file': - url = page_icon[icon_type]['url'] - icon = { - 'type': 'url', - 'url': url if url.startswith('http') else f'https://www.notion.so{url}' - } + icon_type = page_icon["type"] + if icon_type == "external" or icon_type == "file": + url = page_icon[icon_type]["url"] + icon = {"type": "url", "url": url if url.startswith("http") else f"https://www.notion.so{url}"} else: - icon = { - 'type': icon_type, - icon_type: page_icon[icon_type] - } + icon = {"type": icon_type, icon_type: page_icon[icon_type]} else: icon = None - parent = database_result['parent'] - parent_type = parent['type'] - if parent_type == 'block_id': + parent = database_result["parent"] + parent_type = parent["type"] + if parent_type == "block_id": parent_id = self.notion_block_parent_page_id(access_token, parent[parent_type]) - elif parent_type == 'workspace': - parent_id = 'root' + elif parent_type == "workspace": + parent_id = "root" else: parent_id = parent[parent_type] page = { - 'page_id': page_id, - 'page_name': page_name, - 'page_icon': icon, - 'parent_id': parent_id, - 'type': 'database' + "page_id": page_id, + "page_name": page_name, + "page_icon": icon, + "parent_id": parent_id, + "type": "database", } pages.append(page) return pages def notion_page_search(self, access_token: str): - data = { - 'filter': { - "value": "page", - "property": "object" - } - } + data = {"filter": {"value": "page", "property": "object"}} headers = { - 'Content-Type': 'application/json', - 'Authorization': f"Bearer {access_token}", - 'Notion-Version': '2022-06-28', + "Content-Type": "application/json", + "Authorization": f"Bearer {access_token}", + "Notion-Version": "2022-06-28", } response = requests.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers) response_json = response.json() - results = response_json.get('results', []) + results = response_json.get("results", []) return results def notion_block_parent_page_id(self, access_token: str, block_id: str): headers = { - 'Authorization': f"Bearer {access_token}", - 'Notion-Version': '2022-06-28', + "Authorization": f"Bearer {access_token}", + "Notion-Version": "2022-06-28", } - response = requests.get(url=f'{self._NOTION_BLOCK_SEARCH}/{block_id}', headers=headers) + response = requests.get(url=f"{self._NOTION_BLOCK_SEARCH}/{block_id}", headers=headers) response_json = response.json() - parent = response_json['parent'] - parent_type = parent['type'] - if parent_type == 'block_id': + parent = response_json["parent"] + parent_type = parent["type"] + if parent_type == "block_id": return self.notion_block_parent_page_id(access_token, parent[parent_type]) return parent[parent_type] def notion_workspace_name(self, access_token: str): headers = { - 'Authorization': f"Bearer {access_token}", - 'Notion-Version': '2022-06-28', + "Authorization": f"Bearer {access_token}", + "Notion-Version": "2022-06-28", } response = requests.get(url=self._NOTION_BOT_USER, headers=headers) response_json = response.json() - if 'object' in response_json and response_json['object'] == 'user': - user_type = response_json['type'] + if "object" in response_json and response_json["object"] == "user": + user_type = response_json["type"] user_info = response_json[user_type] - if 'workspace_name' in user_info: - return user_info['workspace_name'] - return 'workspace' + if "workspace_name" in user_info: + return user_info["workspace_name"] + return "workspace" def notion_database_search(self, access_token: str): - data = { - 'filter': { - "value": "database", - "property": "object" - } - } + data = {"filter": {"value": "database", "property": "object"}} headers = { - 'Content-Type': 'application/json', - 'Authorization': f"Bearer {access_token}", - 'Notion-Version': '2022-06-28', + "Content-Type": "application/json", + "Authorization": f"Bearer {access_token}", + "Notion-Version": "2022-06-28", } response = requests.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers) response_json = response.json() - results = response_json.get('results', []) + results = response_json.get("results", []) return results diff --git a/api/libs/passport.py b/api/libs/passport.py index 34bdc5599..8df4f529b 100644 --- a/api/libs/passport.py +++ b/api/libs/passport.py @@ -9,14 +9,14 @@ class PassportService: self.sk = dify_config.SECRET_KEY def issue(self, payload): - return jwt.encode(payload, self.sk, algorithm='HS256') + return jwt.encode(payload, self.sk, algorithm="HS256") def verify(self, token): try: - return jwt.decode(token, self.sk, algorithms=['HS256']) + return jwt.decode(token, self.sk, algorithms=["HS256"]) except jwt.exceptions.InvalidSignatureError: - raise Unauthorized('Invalid token signature.') + raise Unauthorized("Invalid token signature.") except jwt.exceptions.DecodeError: - raise Unauthorized('Invalid token.') + raise Unauthorized("Invalid token.") except jwt.exceptions.ExpiredSignatureError: - raise Unauthorized('Token has expired.') + raise Unauthorized("Token has expired.") diff --git a/api/libs/password.py b/api/libs/password.py index cdd1d69db..cfcc0db22 100644 --- a/api/libs/password.py +++ b/api/libs/password.py @@ -5,6 +5,7 @@ import re password_pattern = r"^(?=.*[a-zA-Z])(?=.*\d).{8,}$" + def valid_password(password): # Define a regex pattern for password rules pattern = password_pattern @@ -12,11 +13,11 @@ def valid_password(password): if re.match(pattern, password) is not None: return password - raise ValueError('Not a valid password.') + raise ValueError("Not a valid password.") def hash_password(password_str, salt_byte): - dk = hashlib.pbkdf2_hmac('sha256', password_str.encode('utf-8'), salt_byte, 10000) + dk = hashlib.pbkdf2_hmac("sha256", password_str.encode("utf-8"), salt_byte, 10000) return binascii.hexlify(dk) diff --git a/api/libs/rsa.py b/api/libs/rsa.py index 9f29c5881..a578bf3e5 100644 --- a/api/libs/rsa.py +++ b/api/libs/rsa.py @@ -48,7 +48,7 @@ def encrypt(text, public_key): def get_decrypt_decoding(tenant_id): filepath = "privkeys/{tenant_id}".format(tenant_id=tenant_id) + "/private.pem" - cache_key = 'tenant_privkey:{hash}'.format(hash=hashlib.sha3_256(filepath.encode()).hexdigest()) + cache_key = "tenant_privkey:{hash}".format(hash=hashlib.sha3_256(filepath.encode()).hexdigest()) private_key = redis_client.get(cache_key) if not private_key: try: @@ -66,12 +66,12 @@ def get_decrypt_decoding(tenant_id): def decrypt_token_with_decoding(encrypted_text, rsa_key, cipher_rsa): if encrypted_text.startswith(prefix_hybrid): - encrypted_text = encrypted_text[len(prefix_hybrid):] + encrypted_text = encrypted_text[len(prefix_hybrid) :] - enc_aes_key = encrypted_text[:rsa_key.size_in_bytes()] - nonce = encrypted_text[rsa_key.size_in_bytes():rsa_key.size_in_bytes() + 16] - tag = encrypted_text[rsa_key.size_in_bytes() + 16:rsa_key.size_in_bytes() + 32] - ciphertext = encrypted_text[rsa_key.size_in_bytes() + 32:] + enc_aes_key = encrypted_text[: rsa_key.size_in_bytes()] + nonce = encrypted_text[rsa_key.size_in_bytes() : rsa_key.size_in_bytes() + 16] + tag = encrypted_text[rsa_key.size_in_bytes() + 16 : rsa_key.size_in_bytes() + 32] + ciphertext = encrypted_text[rsa_key.size_in_bytes() + 32 :] aes_key = cipher_rsa.decrypt(enc_aes_key) diff --git a/api/libs/smtp.py b/api/libs/smtp.py index bf3a1a92e..bd7de7dd6 100644 --- a/api/libs/smtp.py +++ b/api/libs/smtp.py @@ -5,7 +5,9 @@ from email.mime.text import MIMEText class SMTPClient: - def __init__(self, server: str, port: int, username: str, password: str, _from: str, use_tls=False, opportunistic_tls=False): + def __init__( + self, server: str, port: int, username: str, password: str, _from: str, use_tls=False, opportunistic_tls=False + ): self.server = server self.port = port self._from = _from @@ -25,17 +27,17 @@ class SMTPClient: smtp = smtplib.SMTP_SSL(self.server, self.port, timeout=10) else: smtp = smtplib.SMTP(self.server, self.port, timeout=10) - + if self.username and self.password: smtp.login(self.username, self.password) msg = MIMEMultipart() - msg['Subject'] = mail['subject'] - msg['From'] = self._from - msg['To'] = mail['to'] - msg.attach(MIMEText(mail['html'], 'html')) + msg["Subject"] = mail["subject"] + msg["From"] = self._from + msg["To"] = mail["to"] + msg.attach(MIMEText(mail["html"], "html")) - smtp.sendmail(self._from, mail['to'], msg.as_string()) + smtp.sendmail(self._from, mail["to"], msg.as_string()) except smtplib.SMTPException as e: logging.error(f"SMTP error occurred: {str(e)}") raise diff --git a/api/pyproject.toml b/api/pyproject.toml index 3e107f5e9..60c1c86d0 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -73,12 +73,10 @@ exclude = [ "core/**/*.py", "controllers/**/*.py", "models/**/*.py", - "utils/**/*.py", "migrations/**/*", "services/**/*.py", "tasks/**/*.py", "tests/**/*.py", - "libs/**/*.py", "configs/**/*.py", ]