mirror of
https://gitee.com/dify_ai/dify.git
synced 2024-11-30 02:08:37 +08:00
Feat/firecrawl data source (#5232)
Co-authored-by: Nicolas <nicolascamara29@gmail.com> Co-authored-by: chenhe <guchenhe@gmail.com> Co-authored-by: takatost <takatost@gmail.com>
This commit is contained in:
parent
918ebe1620
commit
ba5f8afaa8
@ -215,4 +215,5 @@ WORKFLOW_MAX_EXECUTION_TIME=1200
|
||||
WORKFLOW_CALL_MAX_DEPTH=5
|
||||
|
||||
# App configuration
|
||||
APP_MAX_EXECUTION_TIME=1200
|
||||
APP_MAX_EXECUTION_TIME=1200
|
||||
|
||||
|
@ -29,13 +29,13 @@ from .app import (
|
||||
)
|
||||
|
||||
# Import auth controllers
|
||||
from .auth import activate, data_source_oauth, login, oauth
|
||||
from .auth import activate, data_source_bearer_auth, data_source_oauth, login, oauth
|
||||
|
||||
# Import billing controllers
|
||||
from .billing import billing
|
||||
|
||||
# Import datasets controllers
|
||||
from .datasets import data_source, datasets, datasets_document, datasets_segments, file, hit_testing
|
||||
from .datasets import data_source, datasets, datasets_document, datasets_segments, file, hit_testing, website
|
||||
|
||||
# Import explore controllers
|
||||
from .explore import (
|
||||
|
67
api/controllers/console/auth/data_source_bearer_auth.py
Normal file
67
api/controllers/console/auth/data_source_bearer_auth.py
Normal file
@ -0,0 +1,67 @@
|
||||
from flask_login import current_user
|
||||
from flask_restful import Resource, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.auth.error import ApiKeyAuthFailedError
|
||||
from libs.login import login_required
|
||||
from services.auth.api_key_auth_service import ApiKeyAuthService
|
||||
|
||||
from ..setup import setup_required
|
||||
from ..wraps import account_initialization_required
|
||||
|
||||
|
||||
class ApiKeyAuthDataSource(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
# The role of the current user in the table must be admin or owner
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_user.current_tenant_id)
|
||||
if data_source_api_key_bindings:
|
||||
return {
|
||||
'settings': [data_source_api_key_binding.to_dict() for data_source_api_key_binding in
|
||||
data_source_api_key_bindings]}
|
||||
return {'settings': []}
|
||||
|
||||
|
||||
class ApiKeyAuthDataSourceBinding(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
# The role of the current user in the table must be admin or owner
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('category', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('provider', type=str, required=True, nullable=False, location='json')
|
||||
parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json')
|
||||
args = parser.parse_args()
|
||||
ApiKeyAuthService.validate_api_key_auth_args(args)
|
||||
try:
|
||||
ApiKeyAuthService.create_provider_auth(current_user.current_tenant_id, args)
|
||||
except Exception as e:
|
||||
raise ApiKeyAuthFailedError(str(e))
|
||||
return {'result': 'success'}, 200
|
||||
|
||||
|
||||
class ApiKeyAuthDataSourceBindingDelete(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, binding_id):
|
||||
# The role of the current user in the table must be admin or owner
|
||||
if not current_user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
ApiKeyAuthService.delete_provider_auth(current_user.current_tenant_id, binding_id)
|
||||
|
||||
return {'result': 'success'}, 200
|
||||
|
||||
|
||||
api.add_resource(ApiKeyAuthDataSource, '/api-key-auth/data-source')
|
||||
api.add_resource(ApiKeyAuthDataSourceBinding, '/api-key-auth/data-source/binding')
|
||||
api.add_resource(ApiKeyAuthDataSourceBindingDelete, '/api-key-auth/data-source/<uuid:binding_id>')
|
7
api/controllers/console/auth/error.py
Normal file
7
api/controllers/console/auth/error.py
Normal file
@ -0,0 +1,7 @@
|
||||
from libs.exception import BaseHTTPException
|
||||
|
||||
|
||||
class ApiKeyAuthFailedError(BaseHTTPException):
|
||||
error_code = 'auth_failed'
|
||||
description = "{message}"
|
||||
code = 500
|
@ -16,7 +16,7 @@ from extensions.ext_database import db
|
||||
from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields
|
||||
from libs.login import login_required
|
||||
from models.dataset import Document
|
||||
from models.source import DataSourceBinding
|
||||
from models.source import DataSourceOauthBinding
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
from tasks.document_indexing_sync_task import document_indexing_sync_task
|
||||
|
||||
@ -29,9 +29,9 @@ class DataSourceApi(Resource):
|
||||
@marshal_with(integrate_list_fields)
|
||||
def get(self):
|
||||
# get workspace data source integrates
|
||||
data_source_integrates = db.session.query(DataSourceBinding).filter(
|
||||
DataSourceBinding.tenant_id == current_user.current_tenant_id,
|
||||
DataSourceBinding.disabled == False
|
||||
data_source_integrates = db.session.query(DataSourceOauthBinding).filter(
|
||||
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
|
||||
DataSourceOauthBinding.disabled == False
|
||||
).all()
|
||||
|
||||
base_url = request.url_root.rstrip('/')
|
||||
@ -71,7 +71,7 @@ class DataSourceApi(Resource):
|
||||
def patch(self, binding_id, action):
|
||||
binding_id = str(binding_id)
|
||||
action = str(action)
|
||||
data_source_binding = DataSourceBinding.query.filter_by(
|
||||
data_source_binding = DataSourceOauthBinding.query.filter_by(
|
||||
id=binding_id
|
||||
).first()
|
||||
if data_source_binding is None:
|
||||
@ -124,7 +124,7 @@ class DataSourceNotionListApi(Resource):
|
||||
data_source_info = json.loads(document.data_source_info)
|
||||
exist_page_ids.append(data_source_info['notion_page_id'])
|
||||
# get all authorized pages
|
||||
data_source_bindings = DataSourceBinding.query.filter_by(
|
||||
data_source_bindings = DataSourceOauthBinding.query.filter_by(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider='notion',
|
||||
disabled=False
|
||||
@ -163,12 +163,12 @@ class DataSourceNotionApi(Resource):
|
||||
def get(self, workspace_id, page_id, page_type):
|
||||
workspace_id = str(workspace_id)
|
||||
page_id = str(page_id)
|
||||
data_source_binding = DataSourceBinding.query.filter(
|
||||
data_source_binding = DataSourceOauthBinding.query.filter(
|
||||
db.and_(
|
||||
DataSourceBinding.tenant_id == current_user.current_tenant_id,
|
||||
DataSourceBinding.provider == 'notion',
|
||||
DataSourceBinding.disabled == False,
|
||||
DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
|
||||
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
|
||||
DataSourceOauthBinding.provider == 'notion',
|
||||
DataSourceOauthBinding.disabled == False,
|
||||
DataSourceOauthBinding.source_info['workspace_id'] == f'"{workspace_id}"'
|
||||
)
|
||||
).first()
|
||||
if not data_source_binding:
|
||||
|
@ -315,6 +315,22 @@ class DatasetIndexingEstimateApi(Resource):
|
||||
document_model=args['doc_form']
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
elif args['info_list']['data_source_type'] == 'website_crawl':
|
||||
website_info_list = args['info_list']['website_info_list']
|
||||
for url in website_info_list['urls']:
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type="website_crawl",
|
||||
website_info={
|
||||
"provider": website_info_list['provider'],
|
||||
"job_id": website_info_list['job_id'],
|
||||
"url": url,
|
||||
"tenant_id": current_user.current_tenant_id,
|
||||
"mode": 'crawl',
|
||||
"only_main_content": website_info_list['only_main_content']
|
||||
},
|
||||
document_model=args['doc_form']
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
else:
|
||||
raise ValueError('Data source type not support')
|
||||
indexing_runner = IndexingRunner()
|
||||
@ -519,6 +535,7 @@ class DatasetRetrievalSettingMockApi(Resource):
|
||||
raise ValueError(f"Unsupported vector db type {vector_type}.")
|
||||
|
||||
|
||||
|
||||
class DatasetErrorDocs(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
|
@ -465,6 +465,20 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
|
||||
document_model=document.doc_form
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
elif document.data_source_type == 'website_crawl':
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type="website_crawl",
|
||||
website_info={
|
||||
"provider": data_source_info['provider'],
|
||||
"job_id": data_source_info['job_id'],
|
||||
"url": data_source_info['url'],
|
||||
"tenant_id": current_user.current_tenant_id,
|
||||
"mode": data_source_info['mode'],
|
||||
"only_main_content": data_source_info['only_main_content']
|
||||
},
|
||||
document_model=document.doc_form
|
||||
)
|
||||
extract_settings.append(extract_setting)
|
||||
|
||||
else:
|
||||
raise ValueError('Data source type not support')
|
||||
@ -952,6 +966,33 @@ class DocumentRenameApi(DocumentResource):
|
||||
return document
|
||||
|
||||
|
||||
class WebsiteDocumentSyncApi(DocumentResource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id, document_id):
|
||||
"""sync website document."""
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound('Dataset not found.')
|
||||
document_id = str(document_id)
|
||||
document = DocumentService.get_document(dataset.id, document_id)
|
||||
if not document:
|
||||
raise NotFound('Document not found.')
|
||||
if document.tenant_id != current_user.current_tenant_id:
|
||||
raise Forbidden('No permission.')
|
||||
if document.data_source_type != 'website_crawl':
|
||||
raise ValueError('Document is not a website document.')
|
||||
# 403 if document is archived
|
||||
if DocumentService.check_archived(document):
|
||||
raise ArchivedDocumentImmutableError()
|
||||
# sync document
|
||||
DocumentService.sync_website_document(dataset_id, document)
|
||||
|
||||
return {'result': 'success'}, 200
|
||||
|
||||
|
||||
api.add_resource(GetProcessRuleApi, '/datasets/process-rule')
|
||||
api.add_resource(DatasetDocumentListApi,
|
||||
'/datasets/<uuid:dataset_id>/documents')
|
||||
@ -980,3 +1021,5 @@ api.add_resource(DocumentRecoverApi, '/datasets/<uuid:dataset_id>/documents/<uui
|
||||
api.add_resource(DocumentRetryApi, '/datasets/<uuid:dataset_id>/retry')
|
||||
api.add_resource(DocumentRenameApi,
|
||||
'/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/rename')
|
||||
|
||||
api.add_resource(WebsiteDocumentSyncApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/website-sync')
|
||||
|
@ -73,6 +73,12 @@ class InvalidMetadataError(BaseHTTPException):
|
||||
code = 400
|
||||
|
||||
|
||||
class WebsiteCrawlError(BaseHTTPException):
|
||||
error_code = 'crawl_failed'
|
||||
description = "{message}"
|
||||
code = 500
|
||||
|
||||
|
||||
class DatasetInUseError(BaseHTTPException):
|
||||
error_code = 'dataset_in_use'
|
||||
description = "The dataset is being used by some apps. Please remove the dataset from the apps before deleting it."
|
||||
|
49
api/controllers/console/datasets/website.py
Normal file
49
api/controllers/console/datasets/website.py
Normal file
@ -0,0 +1,49 @@
|
||||
from flask_restful import Resource, reqparse
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.datasets.error import WebsiteCrawlError
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from libs.login import login_required
|
||||
from services.website_service import WebsiteService
|
||||
|
||||
|
||||
class WebsiteCrawlApi(Resource):
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('provider', type=str, choices=['firecrawl'],
|
||||
required=True, nullable=True, location='json')
|
||||
parser.add_argument('url', type=str, required=True, nullable=True, location='json')
|
||||
parser.add_argument('options', type=dict, required=True, nullable=True, location='json')
|
||||
args = parser.parse_args()
|
||||
WebsiteService.document_create_args_validate(args)
|
||||
# crawl url
|
||||
try:
|
||||
result = WebsiteService.crawl_url(args)
|
||||
except Exception as e:
|
||||
raise WebsiteCrawlError(str(e))
|
||||
return result, 200
|
||||
|
||||
|
||||
class WebsiteCrawlStatusApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, job_id: str):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('provider', type=str, choices=['firecrawl'], required=True, location='args')
|
||||
args = parser.parse_args()
|
||||
# get crawl status
|
||||
try:
|
||||
result = WebsiteService.get_crawl_status(job_id, args['provider'])
|
||||
except Exception as e:
|
||||
raise WebsiteCrawlError(str(e))
|
||||
return result, 200
|
||||
|
||||
|
||||
api.add_resource(WebsiteCrawlApi, '/website/crawl')
|
||||
api.add_resource(WebsiteCrawlStatusApi, '/website/crawl/status/<string:job_id>')
|
@ -339,7 +339,7 @@ class IndexingRunner:
|
||||
def _extract(self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict) \
|
||||
-> list[Document]:
|
||||
# load file
|
||||
if dataset_document.data_source_type not in ["upload_file", "notion_import"]:
|
||||
if dataset_document.data_source_type not in ["upload_file", "notion_import", "website_crawl"]:
|
||||
return []
|
||||
|
||||
data_source_info = dataset_document.data_source_info_dict
|
||||
@ -375,6 +375,23 @@ class IndexingRunner:
|
||||
document_model=dataset_document.doc_form
|
||||
)
|
||||
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule['mode'])
|
||||
elif dataset_document.data_source_type == 'website_crawl':
|
||||
if (not data_source_info or 'provider' not in data_source_info
|
||||
or 'url' not in data_source_info or 'job_id' not in data_source_info):
|
||||
raise ValueError("no website import info found")
|
||||
extract_setting = ExtractSetting(
|
||||
datasource_type="website_crawl",
|
||||
website_info={
|
||||
"provider": data_source_info['provider'],
|
||||
"job_id": data_source_info['job_id'],
|
||||
"tenant_id": dataset_document.tenant_id,
|
||||
"url": data_source_info['url'],
|
||||
"mode": data_source_info['mode'],
|
||||
"only_main_content": data_source_info['only_main_content']
|
||||
},
|
||||
document_model=dataset_document.doc_form
|
||||
)
|
||||
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule['mode'])
|
||||
# update document status to splitting
|
||||
self._update_document_index_status(
|
||||
document_id=dataset_document.id,
|
||||
|
@ -124,7 +124,7 @@ class TogetherAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
default=float(credentials.get('presence_penalty', 0)),
|
||||
min=-2,
|
||||
max=2
|
||||
)
|
||||
),
|
||||
],
|
||||
pricing=PriceConfig(
|
||||
input=Decimal(cred_with_endpoint.get('input_price', 0)),
|
||||
|
@ -4,3 +4,4 @@ from enum import Enum
|
||||
class DatasourceType(Enum):
|
||||
FILE = "upload_file"
|
||||
NOTION = "notion_import"
|
||||
WEBSITE = "website_crawl"
|
||||
|
@ -1,3 +1,5 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from models.dataset import Document
|
||||
@ -19,14 +21,33 @@ class NotionInfo(BaseModel):
|
||||
super().__init__(**data)
|
||||
|
||||
|
||||
class WebsiteInfo(BaseModel):
|
||||
"""
|
||||
website import info.
|
||||
"""
|
||||
provider: str
|
||||
job_id: str
|
||||
url: str
|
||||
mode: str
|
||||
tenant_id: str
|
||||
only_main_content: bool = False
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def __init__(self, **data) -> None:
|
||||
super().__init__(**data)
|
||||
|
||||
|
||||
class ExtractSetting(BaseModel):
|
||||
"""
|
||||
Model class for provider response.
|
||||
"""
|
||||
datasource_type: str
|
||||
upload_file: UploadFile = None
|
||||
notion_info: NotionInfo = None
|
||||
document_model: str = None
|
||||
upload_file: Optional[UploadFile]
|
||||
notion_info: Optional[NotionInfo]
|
||||
website_info: Optional[WebsiteInfo]
|
||||
document_model: Optional[str]
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
def __init__(self, **data) -> None:
|
||||
|
@ -11,6 +11,7 @@ from core.rag.extractor.csv_extractor import CSVExtractor
|
||||
from core.rag.extractor.entity.datasource_type import DatasourceType
|
||||
from core.rag.extractor.entity.extract_setting import ExtractSetting
|
||||
from core.rag.extractor.excel_extractor import ExcelExtractor
|
||||
from core.rag.extractor.firecrawl.firecrawl_web_extractor import FirecrawlWebExtractor
|
||||
from core.rag.extractor.html_extractor import HtmlExtractor
|
||||
from core.rag.extractor.markdown_extractor import MarkdownExtractor
|
||||
from core.rag.extractor.notion_extractor import NotionExtractor
|
||||
@ -154,5 +155,17 @@ class ExtractProcessor:
|
||||
tenant_id=extract_setting.notion_info.tenant_id,
|
||||
)
|
||||
return extractor.extract()
|
||||
elif extract_setting.datasource_type == DatasourceType.WEBSITE.value:
|
||||
if extract_setting.website_info.provider == 'firecrawl':
|
||||
extractor = FirecrawlWebExtractor(
|
||||
url=extract_setting.website_info.url,
|
||||
job_id=extract_setting.website_info.job_id,
|
||||
tenant_id=extract_setting.website_info.tenant_id,
|
||||
mode=extract_setting.website_info.mode,
|
||||
only_main_content=extract_setting.website_info.only_main_content
|
||||
)
|
||||
return extractor.extract()
|
||||
else:
|
||||
raise ValueError(f"Unsupported website provider: {extract_setting.website_info.provider}")
|
||||
else:
|
||||
raise ValueError(f"Unsupported datasource type: {extract_setting.datasource_type}")
|
||||
|
132
api/core/rag/extractor/firecrawl/firecrawl_app.py
Normal file
132
api/core/rag/extractor/firecrawl/firecrawl_app.py
Normal file
@ -0,0 +1,132 @@
|
||||
import json
|
||||
import time
|
||||
|
||||
import requests
|
||||
|
||||
from extensions.ext_storage import storage
|
||||
|
||||
|
||||
class FirecrawlApp:
|
||||
def __init__(self, api_key=None, base_url=None):
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url or 'https://api.firecrawl.dev'
|
||||
if self.api_key is None and self.base_url == 'https://api.firecrawl.dev':
|
||||
raise ValueError('No API key provided')
|
||||
|
||||
def scrape_url(self, url, params=None) -> dict:
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': f'Bearer {self.api_key}'
|
||||
}
|
||||
json_data = {'url': url}
|
||||
if params:
|
||||
json_data.update(params)
|
||||
response = requests.post(
|
||||
f'{self.base_url}/v0/scrape',
|
||||
headers=headers,
|
||||
json=json_data
|
||||
)
|
||||
if response.status_code == 200:
|
||||
response = response.json()
|
||||
if response['success'] == True:
|
||||
data = response['data']
|
||||
return {
|
||||
'title': data.get('metadata').get('title'),
|
||||
'description': data.get('metadata').get('description'),
|
||||
'source_url': data.get('metadata').get('sourceURL'),
|
||||
'markdown': data.get('markdown')
|
||||
}
|
||||
else:
|
||||
raise Exception(f'Failed to scrape URL. Error: {response["error"]}')
|
||||
|
||||
elif response.status_code in [402, 409, 500]:
|
||||
error_message = response.json().get('error', 'Unknown error occurred')
|
||||
raise Exception(f'Failed to scrape URL. Status code: {response.status_code}. Error: {error_message}')
|
||||
else:
|
||||
raise Exception(f'Failed to scrape URL. Status code: {response.status_code}')
|
||||
|
||||
def crawl_url(self, url, params=None) -> str:
|
||||
start_time = time.time()
|
||||
headers = self._prepare_headers()
|
||||
json_data = {'url': url}
|
||||
if params:
|
||||
json_data.update(params)
|
||||
response = self._post_request(f'{self.base_url}/v0/crawl', json_data, headers)
|
||||
if response.status_code == 200:
|
||||
job_id = response.json().get('jobId')
|
||||
return job_id
|
||||
else:
|
||||
self._handle_error(response, 'start crawl job')
|
||||
|
||||
def check_crawl_status(self, job_id) -> dict:
|
||||
headers = self._prepare_headers()
|
||||
response = self._get_request(f'{self.base_url}/v0/crawl/status/{job_id}', headers)
|
||||
if response.status_code == 200:
|
||||
crawl_status_response = response.json()
|
||||
if crawl_status_response.get('status') == 'completed':
|
||||
total = crawl_status_response.get('total', 0)
|
||||
if total == 0:
|
||||
raise Exception('Failed to check crawl status. Error: No page found')
|
||||
data = crawl_status_response.get('data', [])
|
||||
url_data_list = []
|
||||
for item in data:
|
||||
if isinstance(item, dict) and 'metadata' in item and 'markdown' in item:
|
||||
url_data = {
|
||||
'title': item.get('metadata').get('title'),
|
||||
'description': item.get('metadata').get('description'),
|
||||
'source_url': item.get('metadata').get('sourceURL'),
|
||||
'markdown': item.get('markdown')
|
||||
}
|
||||
url_data_list.append(url_data)
|
||||
if url_data_list:
|
||||
file_key = 'website_files/' + job_id + '.txt'
|
||||
if storage.exists(file_key):
|
||||
storage.delete(file_key)
|
||||
storage.save(file_key, json.dumps(url_data_list).encode('utf-8'))
|
||||
return {
|
||||
'status': 'completed',
|
||||
'total': crawl_status_response.get('total'),
|
||||
'current': crawl_status_response.get('current'),
|
||||
'data': url_data_list
|
||||
}
|
||||
|
||||
else:
|
||||
return {
|
||||
'status': crawl_status_response.get('status'),
|
||||
'total': crawl_status_response.get('total'),
|
||||
'current': crawl_status_response.get('current'),
|
||||
'data': []
|
||||
}
|
||||
|
||||
else:
|
||||
self._handle_error(response, 'check crawl status')
|
||||
|
||||
def _prepare_headers(self):
|
||||
return {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': f'Bearer {self.api_key}'
|
||||
}
|
||||
|
||||
def _post_request(self, url, data, headers, retries=3, backoff_factor=0.5):
|
||||
for attempt in range(retries):
|
||||
response = requests.post(url, headers=headers, json=data)
|
||||
if response.status_code == 502:
|
||||
time.sleep(backoff_factor * (2 ** attempt))
|
||||
else:
|
||||
return response
|
||||
return response
|
||||
|
||||
def _get_request(self, url, headers, retries=3, backoff_factor=0.5):
|
||||
for attempt in range(retries):
|
||||
response = requests.get(url, headers=headers)
|
||||
if response.status_code == 502:
|
||||
time.sleep(backoff_factor * (2 ** attempt))
|
||||
else:
|
||||
return response
|
||||
return response
|
||||
|
||||
def _handle_error(self, response, action):
|
||||
error_message = response.json().get('error', 'Unknown error occurred')
|
||||
raise Exception(f'Failed to {action}. Status code: {response.status_code}. Error: {error_message}')
|
||||
|
||||
|
60
api/core/rag/extractor/firecrawl/firecrawl_web_extractor.py
Normal file
60
api/core/rag/extractor/firecrawl/firecrawl_web_extractor.py
Normal file
@ -0,0 +1,60 @@
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
from services.website_service import WebsiteService
|
||||
|
||||
|
||||
class FirecrawlWebExtractor(BaseExtractor):
|
||||
"""
|
||||
Crawl and scrape websites and return content in clean llm-ready markdown.
|
||||
|
||||
|
||||
Args:
|
||||
url: The URL to scrape.
|
||||
api_key: The API key for Firecrawl.
|
||||
base_url: The base URL for the Firecrawl API. Defaults to 'https://api.firecrawl.dev'.
|
||||
mode: The mode of operation. Defaults to 'scrape'. Options are 'crawl', 'scrape' and 'crawl_return_urls'.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
url: str,
|
||||
job_id: str,
|
||||
tenant_id: str,
|
||||
mode: str = 'crawl',
|
||||
only_main_content: bool = False
|
||||
):
|
||||
"""Initialize with url, api_key, base_url and mode."""
|
||||
self._url = url
|
||||
self.job_id = job_id
|
||||
self.tenant_id = tenant_id
|
||||
self.mode = mode
|
||||
self.only_main_content = only_main_content
|
||||
|
||||
def extract(self) -> list[Document]:
|
||||
"""Extract content from the URL."""
|
||||
documents = []
|
||||
if self.mode == 'crawl':
|
||||
crawl_data = WebsiteService.get_crawl_url_data(self.job_id, 'firecrawl', self._url, self.tenant_id)
|
||||
if crawl_data is None:
|
||||
return []
|
||||
document = Document(page_content=crawl_data.get('markdown', ''),
|
||||
metadata={
|
||||
'source_url': crawl_data.get('source_url'),
|
||||
'description': crawl_data.get('description'),
|
||||
'title': crawl_data.get('title')
|
||||
}
|
||||
)
|
||||
documents.append(document)
|
||||
elif self.mode == 'scrape':
|
||||
scrape_data = WebsiteService.get_scrape_url_data('firecrawl', self._url, self.tenant_id,
|
||||
self.only_main_content)
|
||||
|
||||
document = Document(page_content=scrape_data.get('markdown', ''),
|
||||
metadata={
|
||||
'source_url': scrape_data.get('source_url'),
|
||||
'description': scrape_data.get('description'),
|
||||
'title': scrape_data.get('title')
|
||||
}
|
||||
)
|
||||
documents.append(document)
|
||||
return documents
|
@ -9,7 +9,7 @@ from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Document as DocumentModel
|
||||
from models.source import DataSourceBinding
|
||||
from models.source import DataSourceOauthBinding
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -345,12 +345,12 @@ class NotionExtractor(BaseExtractor):
|
||||
|
||||
@classmethod
|
||||
def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str:
|
||||
data_source_binding = DataSourceBinding.query.filter(
|
||||
data_source_binding = DataSourceOauthBinding.query.filter(
|
||||
db.and_(
|
||||
DataSourceBinding.tenant_id == tenant_id,
|
||||
DataSourceBinding.provider == 'notion',
|
||||
DataSourceBinding.disabled == False,
|
||||
DataSourceBinding.source_info['workspace_id'] == f'"{notion_workspace_id}"'
|
||||
DataSourceOauthBinding.tenant_id == tenant_id,
|
||||
DataSourceOauthBinding.provider == 'notion',
|
||||
DataSourceOauthBinding.disabled == False,
|
||||
DataSourceOauthBinding.source_info['workspace_id'] == f'"{notion_workspace_id}"'
|
||||
)
|
||||
).first()
|
||||
|
||||
|
64
api/libs/bearer_data_source.py
Normal file
64
api/libs/bearer_data_source.py
Normal file
@ -0,0 +1,64 @@
|
||||
# [REVIEW] Implement if Needed? Do we need a new type of data source
|
||||
from abc import abstractmethod
|
||||
|
||||
import requests
|
||||
from api.models.source import DataSourceBearerBinding
|
||||
from flask_login import current_user
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
|
||||
class BearerDataSource:
|
||||
def __init__(self, api_key: str, api_base_url: str):
|
||||
self.api_key = api_key
|
||||
self.api_base_url = api_base_url
|
||||
|
||||
@abstractmethod
|
||||
def validate_bearer_data_source(self):
|
||||
"""
|
||||
Validate the data source
|
||||
"""
|
||||
|
||||
|
||||
class FireCrawlDataSource(BearerDataSource):
|
||||
def validate_bearer_data_source(self):
|
||||
TEST_CRAWL_SITE_URL = "https://www.google.com"
|
||||
FIRECRAWL_API_VERSION = "v0"
|
||||
|
||||
test_api_endpoint = self.api_base_url.rstrip('/') + f"/{FIRECRAWL_API_VERSION}/scrape"
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
data = {
|
||||
"url": TEST_CRAWL_SITE_URL,
|
||||
}
|
||||
|
||||
response = requests.get(test_api_endpoint, headers=headers, json=data)
|
||||
|
||||
return response.json().get("status") == "success"
|
||||
|
||||
def save_credentials(self):
|
||||
# save data source binding
|
||||
data_source_binding = DataSourceBearerBinding.query.filter(
|
||||
db.and_(
|
||||
DataSourceBearerBinding.tenant_id == current_user.current_tenant_id,
|
||||
DataSourceBearerBinding.provider == 'firecrawl',
|
||||
DataSourceBearerBinding.endpoint_url == self.api_base_url,
|
||||
DataSourceBearerBinding.bearer_key == self.api_key
|
||||
)
|
||||
).first()
|
||||
if data_source_binding:
|
||||
data_source_binding.disabled = False
|
||||
db.session.commit()
|
||||
else:
|
||||
new_data_source_binding = DataSourceBearerBinding(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider='firecrawl',
|
||||
endpoint_url=self.api_base_url,
|
||||
bearer_key=self.api_key
|
||||
)
|
||||
db.session.add(new_data_source_binding)
|
||||
db.session.commit()
|
@ -4,7 +4,7 @@ import requests
|
||||
from flask_login import current_user
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.source import DataSourceBinding
|
||||
from models.source import DataSourceOauthBinding
|
||||
|
||||
|
||||
class OAuthDataSource:
|
||||
@ -63,11 +63,11 @@ class NotionOAuth(OAuthDataSource):
|
||||
'total': len(pages)
|
||||
}
|
||||
# save data source binding
|
||||
data_source_binding = DataSourceBinding.query.filter(
|
||||
data_source_binding = DataSourceOauthBinding.query.filter(
|
||||
db.and_(
|
||||
DataSourceBinding.tenant_id == current_user.current_tenant_id,
|
||||
DataSourceBinding.provider == 'notion',
|
||||
DataSourceBinding.access_token == access_token
|
||||
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
|
||||
DataSourceOauthBinding.provider == 'notion',
|
||||
DataSourceOauthBinding.access_token == access_token
|
||||
)
|
||||
).first()
|
||||
if data_source_binding:
|
||||
@ -75,7 +75,7 @@ class NotionOAuth(OAuthDataSource):
|
||||
data_source_binding.disabled = False
|
||||
db.session.commit()
|
||||
else:
|
||||
new_data_source_binding = DataSourceBinding(
|
||||
new_data_source_binding = DataSourceOauthBinding(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
access_token=access_token,
|
||||
source_info=source_info,
|
||||
@ -98,11 +98,11 @@ class NotionOAuth(OAuthDataSource):
|
||||
'total': len(pages)
|
||||
}
|
||||
# save data source binding
|
||||
data_source_binding = DataSourceBinding.query.filter(
|
||||
data_source_binding = DataSourceOauthBinding.query.filter(
|
||||
db.and_(
|
||||
DataSourceBinding.tenant_id == current_user.current_tenant_id,
|
||||
DataSourceBinding.provider == 'notion',
|
||||
DataSourceBinding.access_token == access_token
|
||||
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
|
||||
DataSourceOauthBinding.provider == 'notion',
|
||||
DataSourceOauthBinding.access_token == access_token
|
||||
)
|
||||
).first()
|
||||
if data_source_binding:
|
||||
@ -110,7 +110,7 @@ class NotionOAuth(OAuthDataSource):
|
||||
data_source_binding.disabled = False
|
||||
db.session.commit()
|
||||
else:
|
||||
new_data_source_binding = DataSourceBinding(
|
||||
new_data_source_binding = DataSourceOauthBinding(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
access_token=access_token,
|
||||
source_info=source_info,
|
||||
@ -121,12 +121,12 @@ class NotionOAuth(OAuthDataSource):
|
||||
|
||||
def sync_data_source(self, binding_id: str):
|
||||
# save data source binding
|
||||
data_source_binding = DataSourceBinding.query.filter(
|
||||
data_source_binding = DataSourceOauthBinding.query.filter(
|
||||
db.and_(
|
||||
DataSourceBinding.tenant_id == current_user.current_tenant_id,
|
||||
DataSourceBinding.provider == 'notion',
|
||||
DataSourceBinding.id == binding_id,
|
||||
DataSourceBinding.disabled == False
|
||||
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
|
||||
DataSourceOauthBinding.provider == 'notion',
|
||||
DataSourceOauthBinding.id == binding_id,
|
||||
DataSourceOauthBinding.disabled == False
|
||||
)
|
||||
).first()
|
||||
if data_source_binding:
|
||||
|
@ -0,0 +1,67 @@
|
||||
"""add-api-key-auth-binding
|
||||
|
||||
Revision ID: 7b45942e39bb
|
||||
Revises: 47cc7df8c4f3
|
||||
Create Date: 2024-05-14 07:31:29.702766
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
import models as models
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '7b45942e39bb'
|
||||
down_revision = '4e99a8df00ff'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('data_source_api_key_auth_bindings',
|
||||
sa.Column('id', models.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||
sa.Column('tenant_id', models.StringUUID(), nullable=False),
|
||||
sa.Column('category', sa.String(length=255), nullable=False),
|
||||
sa.Column('provider', sa.String(length=255), nullable=False),
|
||||
sa.Column('credentials', sa.Text(), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
|
||||
sa.Column('disabled', sa.Boolean(), server_default=sa.text('false'), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id', name='data_source_api_key_auth_binding_pkey')
|
||||
)
|
||||
with op.batch_alter_table('data_source_api_key_auth_bindings', schema=None) as batch_op:
|
||||
batch_op.create_index('data_source_api_key_auth_binding_provider_idx', ['provider'], unique=False)
|
||||
batch_op.create_index('data_source_api_key_auth_binding_tenant_id_idx', ['tenant_id'], unique=False)
|
||||
|
||||
with op.batch_alter_table('data_source_bindings', schema=None) as batch_op:
|
||||
batch_op.drop_index('source_binding_tenant_id_idx')
|
||||
batch_op.drop_index('source_info_idx')
|
||||
|
||||
op.rename_table('data_source_bindings', 'data_source_oauth_bindings')
|
||||
|
||||
with op.batch_alter_table('data_source_oauth_bindings', schema=None) as batch_op:
|
||||
batch_op.create_index('source_binding_tenant_id_idx', ['tenant_id'], unique=False)
|
||||
batch_op.create_index('source_info_idx', ['source_info'], unique=False, postgresql_using='gin')
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
|
||||
with op.batch_alter_table('data_source_oauth_bindings', schema=None) as batch_op:
|
||||
batch_op.drop_index('source_info_idx', postgresql_using='gin')
|
||||
batch_op.drop_index('source_binding_tenant_id_idx')
|
||||
|
||||
op.rename_table('data_source_oauth_bindings', 'data_source_bindings')
|
||||
|
||||
with op.batch_alter_table('data_source_bindings', schema=None) as batch_op:
|
||||
batch_op.create_index('source_info_idx', ['source_info'], unique=False)
|
||||
batch_op.create_index('source_binding_tenant_id_idx', ['tenant_id'], unique=False)
|
||||
|
||||
with op.batch_alter_table('data_source_api_key_auth_bindings', schema=None) as batch_op:
|
||||
batch_op.drop_index('data_source_api_key_auth_binding_tenant_id_idx')
|
||||
batch_op.drop_index('data_source_api_key_auth_binding_provider_idx')
|
||||
|
||||
op.drop_table('data_source_api_key_auth_bindings')
|
||||
# ### end Alembic commands ###
|
@ -270,7 +270,7 @@ class Document(db.Model):
|
||||
255), nullable=False, server_default=db.text("'text_model'::character varying"))
|
||||
doc_language = db.Column(db.String(255), nullable=True)
|
||||
|
||||
DATA_SOURCES = ['upload_file', 'notion_import']
|
||||
DATA_SOURCES = ['upload_file', 'notion_import', 'website_crawl']
|
||||
|
||||
@property
|
||||
def display_status(self):
|
||||
@ -322,7 +322,7 @@ class Document(db.Model):
|
||||
'created_at': file_detail.created_at.timestamp()
|
||||
}
|
||||
}
|
||||
elif self.data_source_type == 'notion_import':
|
||||
elif self.data_source_type == 'notion_import' or self.data_source_type == 'website_crawl':
|
||||
return json.loads(self.data_source_info)
|
||||
return {}
|
||||
|
||||
|
@ -1,11 +1,13 @@
|
||||
import json
|
||||
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models import StringUUID
|
||||
|
||||
|
||||
class DataSourceBinding(db.Model):
|
||||
__tablename__ = 'data_source_bindings'
|
||||
class DataSourceOauthBinding(db.Model):
|
||||
__tablename__ = 'data_source_oauth_bindings'
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint('id', name='source_binding_pkey'),
|
||||
db.Index('source_binding_tenant_id_idx', 'tenant_id'),
|
||||
@ -20,3 +22,33 @@ class DataSourceBinding(db.Model):
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
|
||||
disabled = db.Column(db.Boolean, nullable=True, server_default=db.text('false'))
|
||||
|
||||
|
||||
class DataSourceApiKeyAuthBinding(db.Model):
|
||||
__tablename__ = 'data_source_api_key_auth_bindings'
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint('id', name='data_source_api_key_auth_binding_pkey'),
|
||||
db.Index('data_source_api_key_auth_binding_tenant_id_idx', 'tenant_id'),
|
||||
db.Index('data_source_api_key_auth_binding_provider_idx', 'provider'),
|
||||
)
|
||||
|
||||
id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
|
||||
tenant_id = db.Column(StringUUID, nullable=False)
|
||||
category = db.Column(db.String(255), nullable=False)
|
||||
provider = db.Column(db.String(255), nullable=False)
|
||||
credentials = db.Column(db.Text, nullable=True) # JSON
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
|
||||
disabled = db.Column(db.Boolean, nullable=True, server_default=db.text('false'))
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
'id': self.id,
|
||||
'tenant_id': self.tenant_id,
|
||||
'category': self.category,
|
||||
'provider': self.provider,
|
||||
'credentials': json.loads(self.credentials),
|
||||
'created_at': self.created_at.timestamp(),
|
||||
'updated_at': self.updated_at.timestamp(),
|
||||
'disabled': self.disabled
|
||||
}
|
||||
|
@ -78,6 +78,9 @@ CODE_MAX_STRING_LENGTH = "80000"
|
||||
CODE_EXECUTION_ENDPOINT="http://127.0.0.1:8194"
|
||||
CODE_EXECUTION_API_KEY="dify-sandbox"
|
||||
|
||||
FIRECRAWL_API_KEY = "fc-"
|
||||
|
||||
|
||||
|
||||
[tool.poetry]
|
||||
name = "dify-api"
|
||||
|
0
api/services/auth/__init__.py
Normal file
0
api/services/auth/__init__.py
Normal file
10
api/services/auth/api_key_auth_base.py
Normal file
10
api/services/auth/api_key_auth_base.py
Normal file
@ -0,0 +1,10 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class ApiKeyAuthBase(ABC):
|
||||
def __init__(self, credentials: dict):
|
||||
self.credentials = credentials
|
||||
|
||||
@abstractmethod
|
||||
def validate_credentials(self):
|
||||
raise NotImplementedError
|
14
api/services/auth/api_key_auth_factory.py
Normal file
14
api/services/auth/api_key_auth_factory.py
Normal file
@ -0,0 +1,14 @@
|
||||
|
||||
from services.auth.firecrawl import FirecrawlAuth
|
||||
|
||||
|
||||
class ApiKeyAuthFactory:
|
||||
|
||||
def __init__(self, provider: str, credentials: dict):
|
||||
if provider == 'firecrawl':
|
||||
self.auth = FirecrawlAuth(credentials)
|
||||
else:
|
||||
raise ValueError('Invalid provider')
|
||||
|
||||
def validate_credentials(self):
|
||||
return self.auth.validate_credentials()
|
70
api/services/auth/api_key_auth_service.py
Normal file
70
api/services/auth/api_key_auth_service.py
Normal file
@ -0,0 +1,70 @@
|
||||
import json
|
||||
|
||||
from core.helper import encrypter
|
||||
from extensions.ext_database import db
|
||||
from models.source import DataSourceApiKeyAuthBinding
|
||||
from services.auth.api_key_auth_factory import ApiKeyAuthFactory
|
||||
|
||||
|
||||
class ApiKeyAuthService:
|
||||
|
||||
@staticmethod
|
||||
def get_provider_auth_list(tenant_id: str) -> list:
|
||||
data_source_api_key_bindings = db.session.query(DataSourceApiKeyAuthBinding).filter(
|
||||
DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
|
||||
DataSourceApiKeyAuthBinding.disabled.is_(False)
|
||||
).all()
|
||||
return data_source_api_key_bindings
|
||||
|
||||
@staticmethod
|
||||
def create_provider_auth(tenant_id: str, args: dict):
|
||||
auth_result = ApiKeyAuthFactory(args['provider'], args['credentials']).validate_credentials()
|
||||
if auth_result:
|
||||
# Encrypt the api key
|
||||
api_key = encrypter.encrypt_token(tenant_id, args['credentials']['config']['api_key'])
|
||||
args['credentials']['config']['api_key'] = api_key
|
||||
|
||||
data_source_api_key_binding = DataSourceApiKeyAuthBinding()
|
||||
data_source_api_key_binding.tenant_id = tenant_id
|
||||
data_source_api_key_binding.category = args['category']
|
||||
data_source_api_key_binding.provider = args['provider']
|
||||
data_source_api_key_binding.credentials = json.dumps(args['credentials'], ensure_ascii=False)
|
||||
db.session.add(data_source_api_key_binding)
|
||||
db.session.commit()
|
||||
|
||||
@staticmethod
|
||||
def get_auth_credentials(tenant_id: str, category: str, provider: str):
|
||||
data_source_api_key_bindings = db.session.query(DataSourceApiKeyAuthBinding).filter(
|
||||
DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
|
||||
DataSourceApiKeyAuthBinding.category == category,
|
||||
DataSourceApiKeyAuthBinding.provider == provider,
|
||||
DataSourceApiKeyAuthBinding.disabled.is_(False)
|
||||
).first()
|
||||
if not data_source_api_key_bindings:
|
||||
return None
|
||||
credentials = json.loads(data_source_api_key_bindings.credentials)
|
||||
return credentials
|
||||
|
||||
@staticmethod
|
||||
def delete_provider_auth(tenant_id: str, binding_id: str):
|
||||
data_source_api_key_binding = db.session.query(DataSourceApiKeyAuthBinding).filter(
|
||||
DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
|
||||
DataSourceApiKeyAuthBinding.id == binding_id
|
||||
).first()
|
||||
if data_source_api_key_binding:
|
||||
db.session.delete(data_source_api_key_binding)
|
||||
db.session.commit()
|
||||
|
||||
@classmethod
|
||||
def validate_api_key_auth_args(cls, args):
|
||||
if 'category' not in args or not args['category']:
|
||||
raise ValueError('category is required')
|
||||
if 'provider' not in args or not args['provider']:
|
||||
raise ValueError('provider is required')
|
||||
if 'credentials' not in args or not args['credentials']:
|
||||
raise ValueError('credentials is required')
|
||||
if not isinstance(args['credentials'], dict):
|
||||
raise ValueError('credentials must be a dictionary')
|
||||
if 'auth_type' not in args['credentials'] or not args['credentials']['auth_type']:
|
||||
raise ValueError('auth_type is required')
|
||||
|
56
api/services/auth/firecrawl.py
Normal file
56
api/services/auth/firecrawl.py
Normal file
@ -0,0 +1,56 @@
|
||||
import json
|
||||
|
||||
import requests
|
||||
|
||||
from services.auth.api_key_auth_base import ApiKeyAuthBase
|
||||
|
||||
|
||||
class FirecrawlAuth(ApiKeyAuthBase):
|
||||
def __init__(self, credentials: dict):
|
||||
super().__init__(credentials)
|
||||
auth_type = credentials.get('auth_type')
|
||||
if auth_type != 'bearer':
|
||||
raise ValueError('Invalid auth type, Firecrawl auth type must be Bearer')
|
||||
self.api_key = credentials.get('config').get('api_key', None)
|
||||
self.base_url = credentials.get('config').get('base_url', 'https://api.firecrawl.dev')
|
||||
|
||||
if not self.api_key:
|
||||
raise ValueError('No API key provided')
|
||||
|
||||
def validate_credentials(self):
|
||||
headers = self._prepare_headers()
|
||||
options = {
|
||||
'url': 'https://example.com',
|
||||
'crawlerOptions': {
|
||||
'excludes': [],
|
||||
'includes': [],
|
||||
'limit': 1
|
||||
},
|
||||
'pageOptions': {
|
||||
'onlyMainContent': True
|
||||
}
|
||||
}
|
||||
response = self._post_request(f'{self.base_url}/v0/crawl', options, headers)
|
||||
if response.status_code == 200:
|
||||
return True
|
||||
else:
|
||||
self._handle_error(response)
|
||||
|
||||
def _prepare_headers(self):
|
||||
return {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': f'Bearer {self.api_key}'
|
||||
}
|
||||
|
||||
def _post_request(self, url, data, headers):
|
||||
return requests.post(url, headers=headers, json=data)
|
||||
|
||||
def _handle_error(self, response):
|
||||
if response.status_code in [402, 409, 500]:
|
||||
error_message = response.json().get('error', 'Unknown error occurred')
|
||||
raise Exception(f'Failed to authorize. Status code: {response.status_code}. Error: {error_message}')
|
||||
else:
|
||||
if response.text:
|
||||
error_message = json.loads(response.text).get('error', 'Unknown error occurred')
|
||||
raise Exception(f'Failed to authorize. Status code: {response.status_code}. Error: {error_message}')
|
||||
raise Exception(f'Unexpected error occurred while trying to authorize. Status code: {response.status_code}')
|
@ -31,7 +31,7 @@ from models.dataset import (
|
||||
DocumentSegment,
|
||||
)
|
||||
from models.model import UploadFile
|
||||
from models.source import DataSourceBinding
|
||||
from models.source import DataSourceOauthBinding
|
||||
from services.errors.account import NoPermissionError
|
||||
from services.errors.dataset import DatasetInUseError, DatasetNameDuplicateError
|
||||
from services.errors.document import DocumentIndexingError
|
||||
@ -48,6 +48,7 @@ from tasks.document_indexing_update_task import document_indexing_update_task
|
||||
from tasks.duplicate_document_indexing_task import duplicate_document_indexing_task
|
||||
from tasks.recover_document_indexing_task import recover_document_indexing_task
|
||||
from tasks.retry_document_indexing_task import retry_document_indexing_task
|
||||
from tasks.sync_website_document_indexing_task import sync_website_document_indexing_task
|
||||
|
||||
|
||||
class DatasetService:
|
||||
@ -508,18 +509,40 @@ class DocumentService:
|
||||
@staticmethod
|
||||
def retry_document(dataset_id: str, documents: list[Document]):
|
||||
for document in documents:
|
||||
# add retry flag
|
||||
retry_indexing_cache_key = 'document_{}_is_retried'.format(document.id)
|
||||
cache_result = redis_client.get(retry_indexing_cache_key)
|
||||
if cache_result is not None:
|
||||
raise ValueError("Document is being retried, please try again later")
|
||||
# retry document indexing
|
||||
document.indexing_status = 'waiting'
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
# add retry flag
|
||||
retry_indexing_cache_key = 'document_{}_is_retried'.format(document.id)
|
||||
|
||||
redis_client.setex(retry_indexing_cache_key, 600, 1)
|
||||
# trigger async task
|
||||
document_ids = [document.id for document in documents]
|
||||
retry_document_indexing_task.delay(dataset_id, document_ids)
|
||||
|
||||
@staticmethod
|
||||
def sync_website_document(dataset_id: str, document: Document):
|
||||
# add sync flag
|
||||
sync_indexing_cache_key = 'document_{}_is_sync'.format(document.id)
|
||||
cache_result = redis_client.get(sync_indexing_cache_key)
|
||||
if cache_result is not None:
|
||||
raise ValueError("Document is being synced, please try again later")
|
||||
# sync document indexing
|
||||
document.indexing_status = 'waiting'
|
||||
data_source_info = document.data_source_info_dict
|
||||
data_source_info['mode'] = 'scrape'
|
||||
document.data_source_info = json.dumps(data_source_info, ensure_ascii=False)
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
|
||||
redis_client.setex(sync_indexing_cache_key, 600, 1)
|
||||
|
||||
sync_website_document_indexing_task.delay(dataset_id, document.id)
|
||||
@staticmethod
|
||||
def get_documents_position(dataset_id):
|
||||
document = Document.query.filter_by(dataset_id=dataset_id).order_by(Document.position.desc()).first()
|
||||
if document:
|
||||
@ -545,6 +568,9 @@ class DocumentService:
|
||||
notion_info_list = document_data["data_source"]['info_list']['notion_info_list']
|
||||
for notion_info in notion_info_list:
|
||||
count = count + len(notion_info['pages'])
|
||||
elif document_data["data_source"]["type"] == "website_crawl":
|
||||
website_info = document_data["data_source"]['info_list']['website_info_list']
|
||||
count = len(website_info['urls'])
|
||||
batch_upload_limit = int(current_app.config['BATCH_UPLOAD_LIMIT'])
|
||||
if count > batch_upload_limit:
|
||||
raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
|
||||
@ -683,12 +709,12 @@ class DocumentService:
|
||||
exist_document[data_source_info['notion_page_id']] = document.id
|
||||
for notion_info in notion_info_list:
|
||||
workspace_id = notion_info['workspace_id']
|
||||
data_source_binding = DataSourceBinding.query.filter(
|
||||
data_source_binding = DataSourceOauthBinding.query.filter(
|
||||
db.and_(
|
||||
DataSourceBinding.tenant_id == current_user.current_tenant_id,
|
||||
DataSourceBinding.provider == 'notion',
|
||||
DataSourceBinding.disabled == False,
|
||||
DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
|
||||
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
|
||||
DataSourceOauthBinding.provider == 'notion',
|
||||
DataSourceOauthBinding.disabled == False,
|
||||
DataSourceOauthBinding.source_info['workspace_id'] == f'"{workspace_id}"'
|
||||
)
|
||||
).first()
|
||||
if not data_source_binding:
|
||||
@ -717,6 +743,28 @@ class DocumentService:
|
||||
# delete not selected documents
|
||||
if len(exist_document) > 0:
|
||||
clean_notion_document_task.delay(list(exist_document.values()), dataset.id)
|
||||
elif document_data["data_source"]["type"] == "website_crawl":
|
||||
website_info = document_data["data_source"]['info_list']['website_info_list']
|
||||
urls = website_info['urls']
|
||||
for url in urls:
|
||||
data_source_info = {
|
||||
'url': url,
|
||||
'provider': website_info['provider'],
|
||||
'job_id': website_info['job_id'],
|
||||
'only_main_content': website_info.get('only_main_content', False),
|
||||
'mode': 'crawl',
|
||||
}
|
||||
document = DocumentService.build_document(dataset, dataset_process_rule.id,
|
||||
document_data["data_source"]["type"],
|
||||
document_data["doc_form"],
|
||||
document_data["doc_language"],
|
||||
data_source_info, created_from, position,
|
||||
account, url, batch)
|
||||
db.session.add(document)
|
||||
db.session.flush()
|
||||
document_ids.append(document.id)
|
||||
documents.append(document)
|
||||
position += 1
|
||||
db.session.commit()
|
||||
|
||||
# trigger async task
|
||||
@ -818,12 +866,12 @@ class DocumentService:
|
||||
notion_info_list = document_data["data_source"]['info_list']['notion_info_list']
|
||||
for notion_info in notion_info_list:
|
||||
workspace_id = notion_info['workspace_id']
|
||||
data_source_binding = DataSourceBinding.query.filter(
|
||||
data_source_binding = DataSourceOauthBinding.query.filter(
|
||||
db.and_(
|
||||
DataSourceBinding.tenant_id == current_user.current_tenant_id,
|
||||
DataSourceBinding.provider == 'notion',
|
||||
DataSourceBinding.disabled == False,
|
||||
DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
|
||||
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
|
||||
DataSourceOauthBinding.provider == 'notion',
|
||||
DataSourceOauthBinding.disabled == False,
|
||||
DataSourceOauthBinding.source_info['workspace_id'] == f'"{workspace_id}"'
|
||||
)
|
||||
).first()
|
||||
if not data_source_binding:
|
||||
@ -835,6 +883,17 @@ class DocumentService:
|
||||
"notion_page_icon": page['page_icon'],
|
||||
"type": page['type']
|
||||
}
|
||||
elif document_data["data_source"]["type"] == "website_crawl":
|
||||
website_info = document_data["data_source"]['info_list']['website_info_list']
|
||||
urls = website_info['urls']
|
||||
for url in urls:
|
||||
data_source_info = {
|
||||
'url': url,
|
||||
'provider': website_info['provider'],
|
||||
'job_id': website_info['job_id'],
|
||||
'only_main_content': website_info.get('only_main_content', False),
|
||||
'mode': 'crawl',
|
||||
}
|
||||
document.data_source_type = document_data["data_source"]["type"]
|
||||
document.data_source_info = json.dumps(data_source_info)
|
||||
document.name = file_name
|
||||
@ -873,6 +932,9 @@ class DocumentService:
|
||||
notion_info_list = document_data["data_source"]['info_list']['notion_info_list']
|
||||
for notion_info in notion_info_list:
|
||||
count = count + len(notion_info['pages'])
|
||||
elif document_data["data_source"]["type"] == "website_crawl":
|
||||
website_info = document_data["data_source"]['info_list']['website_info_list']
|
||||
count = len(website_info['urls'])
|
||||
batch_upload_limit = int(current_app.config['BATCH_UPLOAD_LIMIT'])
|
||||
if count > batch_upload_limit:
|
||||
raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
|
||||
@ -973,6 +1035,10 @@ class DocumentService:
|
||||
if 'notion_info_list' not in args['data_source']['info_list'] or not args['data_source']['info_list'][
|
||||
'notion_info_list']:
|
||||
raise ValueError("Notion source info is required")
|
||||
if args['data_source']['type'] == 'website_crawl':
|
||||
if 'website_info_list' not in args['data_source']['info_list'] or not args['data_source']['info_list'][
|
||||
'website_info_list']:
|
||||
raise ValueError("Website source info is required")
|
||||
|
||||
@classmethod
|
||||
def process_rule_args_validate(cls, args: dict):
|
||||
|
171
api/services/website_service.py
Normal file
171
api/services/website_service.py
Normal file
@ -0,0 +1,171 @@
|
||||
import datetime
|
||||
import json
|
||||
|
||||
from flask_login import current_user
|
||||
|
||||
from core.helper import encrypter
|
||||
from core.rag.extractor.firecrawl.firecrawl_app import FirecrawlApp
|
||||
from extensions.ext_redis import redis_client
|
||||
from extensions.ext_storage import storage
|
||||
from services.auth.api_key_auth_service import ApiKeyAuthService
|
||||
|
||||
|
||||
class WebsiteService:
|
||||
|
||||
@classmethod
|
||||
def document_create_args_validate(cls, args: dict):
|
||||
if 'url' not in args or not args['url']:
|
||||
raise ValueError('url is required')
|
||||
if 'options' not in args or not args['options']:
|
||||
raise ValueError('options is required')
|
||||
if 'limit' not in args['options'] or not args['options']['limit']:
|
||||
raise ValueError('limit is required')
|
||||
|
||||
@classmethod
|
||||
def crawl_url(cls, args: dict) -> dict:
|
||||
provider = args.get('provider')
|
||||
url = args.get('url')
|
||||
options = args.get('options')
|
||||
credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id,
|
||||
'website',
|
||||
provider)
|
||||
if provider == 'firecrawl':
|
||||
# decrypt api_key
|
||||
api_key = encrypter.decrypt_token(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
token=credentials.get('config').get('api_key')
|
||||
)
|
||||
firecrawl_app = FirecrawlApp(api_key=api_key,
|
||||
base_url=credentials.get('config').get('base_url', None))
|
||||
crawl_sub_pages = options.get('crawl_sub_pages', False)
|
||||
only_main_content = options.get('only_main_content', False)
|
||||
if not crawl_sub_pages:
|
||||
params = {
|
||||
'crawlerOptions': {
|
||||
"includes": [],
|
||||
"excludes": [],
|
||||
"generateImgAltText": True,
|
||||
"limit": 1,
|
||||
'returnOnlyUrls': False,
|
||||
'pageOptions': {
|
||||
'onlyMainContent': only_main_content,
|
||||
"includeHtml": False
|
||||
}
|
||||
}
|
||||
}
|
||||
else:
|
||||
includes = options.get('includes').split(',') if options.get('includes') else []
|
||||
excludes = options.get('excludes').split(',') if options.get('excludes') else []
|
||||
params = {
|
||||
'crawlerOptions': {
|
||||
"includes": includes if includes else [],
|
||||
"excludes": excludes if excludes else [],
|
||||
"generateImgAltText": True,
|
||||
"limit": options.get('limit', 1),
|
||||
'returnOnlyUrls': False,
|
||||
'pageOptions': {
|
||||
'onlyMainContent': only_main_content,
|
||||
"includeHtml": False
|
||||
}
|
||||
}
|
||||
}
|
||||
if options.get('max_depth'):
|
||||
params['crawlerOptions']['maxDepth'] = options.get('max_depth')
|
||||
job_id = firecrawl_app.crawl_url(url, params)
|
||||
website_crawl_time_cache_key = f'website_crawl_{job_id}'
|
||||
time = str(datetime.datetime.now().timestamp())
|
||||
redis_client.setex(website_crawl_time_cache_key, 3600, time)
|
||||
return {
|
||||
'status': 'active',
|
||||
'job_id': job_id
|
||||
}
|
||||
else:
|
||||
raise ValueError('Invalid provider')
|
||||
|
||||
@classmethod
|
||||
def get_crawl_status(cls, job_id: str, provider: str) -> dict:
|
||||
credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id,
|
||||
'website',
|
||||
provider)
|
||||
if provider == 'firecrawl':
|
||||
# decrypt api_key
|
||||
api_key = encrypter.decrypt_token(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
token=credentials.get('config').get('api_key')
|
||||
)
|
||||
firecrawl_app = FirecrawlApp(api_key=api_key,
|
||||
base_url=credentials.get('config').get('base_url', None))
|
||||
result = firecrawl_app.check_crawl_status(job_id)
|
||||
crawl_status_data = {
|
||||
'status': result.get('status', 'active'),
|
||||
'job_id': job_id,
|
||||
'total': result.get('total', 0),
|
||||
'current': result.get('current', 0),
|
||||
'data': result.get('data', [])
|
||||
}
|
||||
if crawl_status_data['status'] == 'completed':
|
||||
website_crawl_time_cache_key = f'website_crawl_{job_id}'
|
||||
start_time = redis_client.get(website_crawl_time_cache_key)
|
||||
if start_time:
|
||||
end_time = datetime.datetime.now().timestamp()
|
||||
time_consuming = abs(end_time - float(start_time))
|
||||
crawl_status_data['time_consuming'] = f"{time_consuming:.2f}"
|
||||
redis_client.delete(website_crawl_time_cache_key)
|
||||
else:
|
||||
raise ValueError('Invalid provider')
|
||||
return crawl_status_data
|
||||
|
||||
@classmethod
|
||||
def get_crawl_url_data(cls, job_id: str, provider: str, url: str, tenant_id: str) -> dict | None:
|
||||
credentials = ApiKeyAuthService.get_auth_credentials(tenant_id,
|
||||
'website',
|
||||
provider)
|
||||
if provider == 'firecrawl':
|
||||
file_key = 'website_files/' + job_id + '.txt'
|
||||
if storage.exists(file_key):
|
||||
data = storage.load_once(file_key)
|
||||
if data:
|
||||
data = json.loads(data.decode('utf-8'))
|
||||
else:
|
||||
# decrypt api_key
|
||||
api_key = encrypter.decrypt_token(
|
||||
tenant_id=tenant_id,
|
||||
token=credentials.get('config').get('api_key')
|
||||
)
|
||||
firecrawl_app = FirecrawlApp(api_key=api_key,
|
||||
base_url=credentials.get('config').get('base_url', None))
|
||||
result = firecrawl_app.check_crawl_status(job_id)
|
||||
if result.get('status') != 'completed':
|
||||
raise ValueError('Crawl job is not completed')
|
||||
data = result.get('data')
|
||||
if data:
|
||||
for item in data:
|
||||
if item.get('source_url') == url:
|
||||
return item
|
||||
return None
|
||||
else:
|
||||
raise ValueError('Invalid provider')
|
||||
|
||||
@classmethod
|
||||
def get_scrape_url_data(cls, provider: str, url: str, tenant_id: str, only_main_content: bool) -> dict | None:
|
||||
credentials = ApiKeyAuthService.get_auth_credentials(tenant_id,
|
||||
'website',
|
||||
provider)
|
||||
if provider == 'firecrawl':
|
||||
# decrypt api_key
|
||||
api_key = encrypter.decrypt_token(
|
||||
tenant_id=tenant_id,
|
||||
token=credentials.get('config').get('api_key')
|
||||
)
|
||||
firecrawl_app = FirecrawlApp(api_key=api_key,
|
||||
base_url=credentials.get('config').get('base_url', None))
|
||||
params = {
|
||||
'pageOptions': {
|
||||
'onlyMainContent': only_main_content,
|
||||
"includeHtml": False
|
||||
}
|
||||
}
|
||||
result = firecrawl_app.scrape_url(url, params)
|
||||
return result
|
||||
else:
|
||||
raise ValueError('Invalid provider')
|
@ -11,7 +11,7 @@ from core.rag.extractor.notion_extractor import NotionExtractor
|
||||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from models.source import DataSourceBinding
|
||||
from models.source import DataSourceOauthBinding
|
||||
|
||||
|
||||
@shared_task(queue='dataset')
|
||||
@ -43,12 +43,12 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
|
||||
page_id = data_source_info['notion_page_id']
|
||||
page_type = data_source_info['type']
|
||||
page_edited_time = data_source_info['last_edited_time']
|
||||
data_source_binding = DataSourceBinding.query.filter(
|
||||
data_source_binding = DataSourceOauthBinding.query.filter(
|
||||
db.and_(
|
||||
DataSourceBinding.tenant_id == document.tenant_id,
|
||||
DataSourceBinding.provider == 'notion',
|
||||
DataSourceBinding.disabled == False,
|
||||
DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"'
|
||||
DataSourceOauthBinding.tenant_id == document.tenant_id,
|
||||
DataSourceOauthBinding.provider == 'notion',
|
||||
DataSourceOauthBinding.disabled == False,
|
||||
DataSourceOauthBinding.source_info['workspace_id'] == f'"{workspace_id}"'
|
||||
)
|
||||
).first()
|
||||
if not data_source_binding:
|
||||
|
90
api/tasks/sync_website_document_indexing_task.py
Normal file
90
api/tasks/sync_website_document_indexing_task.py
Normal file
@ -0,0 +1,90 @@
|
||||
import datetime
|
||||
import logging
|
||||
import time
|
||||
|
||||
import click
|
||||
from celery import shared_task
|
||||
|
||||
from core.indexing_runner import IndexingRunner
|
||||
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
|
||||
@shared_task(queue='dataset')
|
||||
def sync_website_document_indexing_task(dataset_id: str, document_id: str):
|
||||
"""
|
||||
Async process document
|
||||
:param dataset_id:
|
||||
:param document_id:
|
||||
|
||||
Usage: sunc_website_document_indexing_task.delay(dataset_id, document_id)
|
||||
"""
|
||||
start_at = time.perf_counter()
|
||||
|
||||
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
|
||||
|
||||
sync_indexing_cache_key = 'document_{}_is_sync'.format(document_id)
|
||||
# check document limit
|
||||
features = FeatureService.get_features(dataset.tenant_id)
|
||||
try:
|
||||
if features.billing.enabled:
|
||||
vector_space = features.vector_space
|
||||
if 0 < vector_space.limit <= vector_space.size:
|
||||
raise ValueError("Your total number of documents plus the number of uploads have over the limit of "
|
||||
"your subscription.")
|
||||
except Exception as e:
|
||||
document = db.session.query(Document).filter(
|
||||
Document.id == document_id,
|
||||
Document.dataset_id == dataset_id
|
||||
).first()
|
||||
if document:
|
||||
document.indexing_status = 'error'
|
||||
document.error = str(e)
|
||||
document.stopped_at = datetime.datetime.utcnow()
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
redis_client.delete(sync_indexing_cache_key)
|
||||
return
|
||||
|
||||
logging.info(click.style('Start sync website document: {}'.format(document_id), fg='green'))
|
||||
document = db.session.query(Document).filter(
|
||||
Document.id == document_id,
|
||||
Document.dataset_id == dataset_id
|
||||
).first()
|
||||
try:
|
||||
if document:
|
||||
# clean old data
|
||||
index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
|
||||
|
||||
segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
|
||||
if segments:
|
||||
index_node_ids = [segment.index_node_id for segment in segments]
|
||||
# delete from vector index
|
||||
index_processor.clean(dataset, index_node_ids)
|
||||
|
||||
for segment in segments:
|
||||
db.session.delete(segment)
|
||||
db.session.commit()
|
||||
|
||||
document.indexing_status = 'parsing'
|
||||
document.processing_started_at = datetime.datetime.utcnow()
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
|
||||
indexing_runner = IndexingRunner()
|
||||
indexing_runner.run([document])
|
||||
redis_client.delete(sync_indexing_cache_key)
|
||||
except Exception as ex:
|
||||
document.indexing_status = 'error'
|
||||
document.error = str(ex)
|
||||
document.stopped_at = datetime.datetime.utcnow()
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
logging.info(click.style(str(ex), fg='yellow'))
|
||||
redis_client.delete(sync_indexing_cache_key)
|
||||
pass
|
||||
end_at = time.perf_counter()
|
||||
logging.info(click.style('Sync document: {} latency: {}'.format(document_id, end_at - start_at), fg='green'))
|
@ -0,0 +1,33 @@
|
||||
import os
|
||||
from unittest import mock
|
||||
|
||||
from core.rag.extractor.firecrawl.firecrawl_app import FirecrawlApp
|
||||
from core.rag.extractor.firecrawl.firecrawl_web_extractor import FirecrawlWebExtractor
|
||||
from core.rag.models.document import Document
|
||||
from tests.unit_tests.core.rag.extractor.test_notion_extractor import _mock_response
|
||||
|
||||
|
||||
def test_firecrawl_web_extractor_crawl_mode(mocker):
|
||||
url = "https://firecrawl.dev"
|
||||
api_key = os.getenv('FIRECRAWL_API_KEY') or 'fc-'
|
||||
base_url = 'https://api.firecrawl.dev'
|
||||
firecrawl_app = FirecrawlApp(api_key=api_key,
|
||||
base_url=base_url)
|
||||
params = {
|
||||
'crawlerOptions': {
|
||||
"includes": [],
|
||||
"excludes": [],
|
||||
"generateImgAltText": True,
|
||||
"maxDepth": 1,
|
||||
"limit": 1,
|
||||
'returnOnlyUrls': False,
|
||||
|
||||
}
|
||||
}
|
||||
mocked_firecrawl = {
|
||||
"jobId": "test",
|
||||
}
|
||||
mocker.patch("requests.post", return_value=_mock_response(mocked_firecrawl))
|
||||
job_id = firecrawl_app.crawl_url(url, params)
|
||||
print(job_id)
|
||||
assert isinstance(job_id, str)
|
0
api/tests/unit_tests/oss/__init__.py
Normal file
0
api/tests/unit_tests/oss/__init__.py
Normal file
0
api/tests/unit_tests/oss/local/__init__.py
Normal file
0
api/tests/unit_tests/oss/local/__init__.py
Normal file
Loading…
Reference in New Issue
Block a user