mirror of
https://gitee.com/dify_ai/dify.git
synced 2024-12-02 11:18:19 +08:00
380 lines
15 KiB
Python
380 lines
15 KiB
Python
import json
|
|
import logging
|
|
from typing import List, Dict, Any, Optional
|
|
|
|
import requests
|
|
from flask import current_app
|
|
from langchain.document_loaders.base import BaseLoader
|
|
from langchain.schema import Document
|
|
|
|
from extensions.ext_database import db
|
|
from models.dataset import Document as DocumentModel
|
|
from models.source import DataSourceBinding
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
BLOCK_CHILD_URL_TMPL = "https://api.notion.com/v1/blocks/{block_id}/children"
|
|
DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}/query"
|
|
SEARCH_URL = "https://api.notion.com/v1/search"
|
|
RETRIEVE_PAGE_URL_TMPL = "https://api.notion.com/v1/pages/{page_id}"
|
|
RETRIEVE_DATABASE_URL_TMPL = "https://api.notion.com/v1/databases/{database_id}"
|
|
HEADING_TYPE = ['heading_1', 'heading_2', 'heading_3']
|
|
|
|
|
|
class NotionLoader(BaseLoader):
|
|
def __init__(
|
|
self,
|
|
notion_access_token: str,
|
|
notion_workspace_id: str,
|
|
notion_obj_id: str,
|
|
notion_page_type: str,
|
|
document_model: Optional[DocumentModel] = None
|
|
):
|
|
self._document_model = document_model
|
|
self._notion_workspace_id = notion_workspace_id
|
|
self._notion_obj_id = notion_obj_id
|
|
self._notion_page_type = notion_page_type
|
|
self._notion_access_token = notion_access_token
|
|
|
|
if not self._notion_access_token:
|
|
integration_token = current_app.config.get('NOTION_INTEGRATION_TOKEN')
|
|
if integration_token is None:
|
|
raise ValueError(
|
|
"Must specify `integration_token` or set environment "
|
|
"variable `NOTION_INTEGRATION_TOKEN`."
|
|
)
|
|
|
|
self._notion_access_token = integration_token
|
|
|
|
@classmethod
|
|
def from_document(cls, document_model: DocumentModel):
|
|
data_source_info = document_model.data_source_info_dict
|
|
if not data_source_info or 'notion_page_id' not in data_source_info \
|
|
or 'notion_workspace_id' not in data_source_info:
|
|
raise ValueError("no notion page found")
|
|
|
|
notion_workspace_id = data_source_info['notion_workspace_id']
|
|
notion_obj_id = data_source_info['notion_page_id']
|
|
notion_page_type = data_source_info['type']
|
|
notion_access_token = cls._get_access_token(document_model.tenant_id, notion_workspace_id)
|
|
|
|
return cls(
|
|
notion_access_token=notion_access_token,
|
|
notion_workspace_id=notion_workspace_id,
|
|
notion_obj_id=notion_obj_id,
|
|
notion_page_type=notion_page_type,
|
|
document_model=document_model
|
|
)
|
|
|
|
def load(self) -> List[Document]:
|
|
self.update_last_edited_time(
|
|
self._document_model
|
|
)
|
|
|
|
text_docs = self._load_data_as_documents(self._notion_obj_id, self._notion_page_type)
|
|
|
|
return text_docs
|
|
|
|
def _load_data_as_documents(
|
|
self, notion_obj_id: str, notion_page_type: str
|
|
) -> List[Document]:
|
|
docs = []
|
|
if notion_page_type == 'database':
|
|
# get all the pages in the database
|
|
page_text_documents = self._get_notion_database_data(notion_obj_id)
|
|
docs.extend(page_text_documents)
|
|
elif notion_page_type == 'page':
|
|
page_text_list = self._get_notion_block_data(notion_obj_id)
|
|
for page_text in page_text_list:
|
|
docs.append(Document(page_content=page_text))
|
|
else:
|
|
raise ValueError("notion page type not supported")
|
|
|
|
return docs
|
|
|
|
def _get_notion_database_data(
|
|
self, database_id: str, query_dict: Dict[str, Any] = {}
|
|
) -> List[Document]:
|
|
"""Get all the pages from a Notion database."""
|
|
res = requests.post(
|
|
DATABASE_URL_TMPL.format(database_id=database_id),
|
|
headers={
|
|
"Authorization": "Bearer " + self._notion_access_token,
|
|
"Content-Type": "application/json",
|
|
"Notion-Version": "2022-06-28",
|
|
},
|
|
json=query_dict,
|
|
)
|
|
|
|
data = res.json()
|
|
|
|
database_content_list = []
|
|
if 'results' not in data or data["results"] is None:
|
|
return []
|
|
for result in data["results"]:
|
|
properties = result['properties']
|
|
data = {}
|
|
for property_name, property_value in properties.items():
|
|
type = property_value['type']
|
|
if type == 'multi_select':
|
|
value = []
|
|
multi_select_list = property_value[type]
|
|
for multi_select in multi_select_list:
|
|
value.append(multi_select['name'])
|
|
elif type == 'rich_text' or type == 'title':
|
|
if len(property_value[type]) > 0:
|
|
value = property_value[type][0]['plain_text']
|
|
else:
|
|
value = ''
|
|
elif type == 'select' or type == 'status':
|
|
if property_value[type]:
|
|
value = property_value[type]['name']
|
|
else:
|
|
value = ''
|
|
else:
|
|
value = property_value[type]
|
|
data[property_name] = value
|
|
row_dict = {k: v for k, v in data.items() if v}
|
|
row_content = ''
|
|
for key, value in row_dict.items():
|
|
if isinstance(value, dict):
|
|
value_dict = {k: v for k, v in value.items() if v}
|
|
value_content = ''.join(f'{k}:{v} ' for k, v in value_dict.items())
|
|
row_content = row_content + f'{key}:{value_content}\n'
|
|
else:
|
|
row_content = row_content + f'{key}:{value}\n'
|
|
document = Document(page_content=row_content)
|
|
database_content_list.append(document)
|
|
|
|
return database_content_list
|
|
|
|
def _get_notion_block_data(self, page_id: str) -> List[str]:
|
|
result_lines_arr = []
|
|
cur_block_id = page_id
|
|
while True:
|
|
block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
|
|
query_dict: Dict[str, Any] = {}
|
|
|
|
res = requests.request(
|
|
"GET",
|
|
block_url,
|
|
headers={
|
|
"Authorization": "Bearer " + self._notion_access_token,
|
|
"Content-Type": "application/json",
|
|
"Notion-Version": "2022-06-28",
|
|
},
|
|
json=query_dict
|
|
)
|
|
data = res.json()
|
|
# current block's heading
|
|
heading = ''
|
|
for result in data["results"]:
|
|
result_type = result["type"]
|
|
result_obj = result[result_type]
|
|
cur_result_text_arr = []
|
|
if result_type == 'table':
|
|
result_block_id = result["id"]
|
|
text = self._read_table_rows(result_block_id)
|
|
text += "\n\n"
|
|
result_lines_arr.append(text)
|
|
else:
|
|
if "rich_text" in result_obj:
|
|
for rich_text in result_obj["rich_text"]:
|
|
# skip if doesn't have text object
|
|
if "text" in rich_text:
|
|
text = rich_text["text"]["content"]
|
|
cur_result_text_arr.append(text)
|
|
if result_type in HEADING_TYPE:
|
|
heading = text
|
|
|
|
result_block_id = result["id"]
|
|
has_children = result["has_children"]
|
|
block_type = result["type"]
|
|
if has_children and block_type != 'child_page':
|
|
children_text = self._read_block(
|
|
result_block_id, num_tabs=1
|
|
)
|
|
cur_result_text_arr.append(children_text)
|
|
|
|
cur_result_text = "\n".join(cur_result_text_arr)
|
|
cur_result_text += "\n\n"
|
|
if result_type in HEADING_TYPE:
|
|
result_lines_arr.append(cur_result_text)
|
|
else:
|
|
result_lines_arr.append(f'{heading}\n{cur_result_text}')
|
|
|
|
if data["next_cursor"] is None:
|
|
break
|
|
else:
|
|
cur_block_id = data["next_cursor"]
|
|
return result_lines_arr
|
|
|
|
def _read_block(self, block_id: str, num_tabs: int = 0) -> str:
|
|
"""Read a block."""
|
|
result_lines_arr = []
|
|
cur_block_id = block_id
|
|
while True:
|
|
block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
|
|
query_dict: Dict[str, Any] = {}
|
|
|
|
res = requests.request(
|
|
"GET",
|
|
block_url,
|
|
headers={
|
|
"Authorization": "Bearer " + self._notion_access_token,
|
|
"Content-Type": "application/json",
|
|
"Notion-Version": "2022-06-28",
|
|
},
|
|
json=query_dict
|
|
)
|
|
data = res.json()
|
|
if 'results' not in data or data["results"] is None:
|
|
break
|
|
heading = ''
|
|
for result in data["results"]:
|
|
result_type = result["type"]
|
|
result_obj = result[result_type]
|
|
cur_result_text_arr = []
|
|
if result_type == 'table':
|
|
result_block_id = result["id"]
|
|
text = self._read_table_rows(result_block_id)
|
|
result_lines_arr.append(text)
|
|
else:
|
|
if "rich_text" in result_obj:
|
|
for rich_text in result_obj["rich_text"]:
|
|
# skip if doesn't have text object
|
|
if "text" in rich_text:
|
|
text = rich_text["text"]["content"]
|
|
prefix = "\t" * num_tabs
|
|
cur_result_text_arr.append(prefix + text)
|
|
if result_type in HEADING_TYPE:
|
|
heading = text
|
|
result_block_id = result["id"]
|
|
has_children = result["has_children"]
|
|
block_type = result["type"]
|
|
if has_children and block_type != 'child_page':
|
|
children_text = self._read_block(
|
|
result_block_id, num_tabs=num_tabs + 1
|
|
)
|
|
cur_result_text_arr.append(children_text)
|
|
|
|
cur_result_text = "\n".join(cur_result_text_arr)
|
|
if result_type in HEADING_TYPE:
|
|
result_lines_arr.append(cur_result_text)
|
|
else:
|
|
result_lines_arr.append(f'{heading}\n{cur_result_text}')
|
|
|
|
if data["next_cursor"] is None:
|
|
break
|
|
else:
|
|
cur_block_id = data["next_cursor"]
|
|
|
|
result_lines = "\n".join(result_lines_arr)
|
|
return result_lines
|
|
|
|
def _read_table_rows(self, block_id: str) -> str:
|
|
"""Read table rows."""
|
|
done = False
|
|
result_lines_arr = []
|
|
cur_block_id = block_id
|
|
while not done:
|
|
block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id)
|
|
query_dict: Dict[str, Any] = {}
|
|
|
|
res = requests.request(
|
|
"GET",
|
|
block_url,
|
|
headers={
|
|
"Authorization": "Bearer " + self._notion_access_token,
|
|
"Content-Type": "application/json",
|
|
"Notion-Version": "2022-06-28",
|
|
},
|
|
json=query_dict
|
|
)
|
|
data = res.json()
|
|
# get table headers text
|
|
table_header_cell_texts = []
|
|
tabel_header_cells = data["results"][0]['table_row']['cells']
|
|
for tabel_header_cell in tabel_header_cells:
|
|
if tabel_header_cell:
|
|
for table_header_cell_text in tabel_header_cell:
|
|
text = table_header_cell_text["text"]["content"]
|
|
table_header_cell_texts.append(text)
|
|
# get table columns text and format
|
|
results = data["results"]
|
|
for i in range(len(results) - 1):
|
|
column_texts = []
|
|
tabel_column_cells = data["results"][i + 1]['table_row']['cells']
|
|
for j in range(len(tabel_column_cells)):
|
|
if tabel_column_cells[j]:
|
|
for table_column_cell_text in tabel_column_cells[j]:
|
|
column_text = table_column_cell_text["text"]["content"]
|
|
column_texts.append(f'{table_header_cell_texts[j]}:{column_text}')
|
|
|
|
cur_result_text = "\n".join(column_texts)
|
|
result_lines_arr.append(cur_result_text)
|
|
|
|
if data["next_cursor"] is None:
|
|
done = True
|
|
break
|
|
else:
|
|
cur_block_id = data["next_cursor"]
|
|
|
|
result_lines = "\n".join(result_lines_arr)
|
|
return result_lines
|
|
|
|
def update_last_edited_time(self, document_model: DocumentModel):
|
|
if not document_model:
|
|
return
|
|
|
|
last_edited_time = self.get_notion_last_edited_time()
|
|
data_source_info = document_model.data_source_info_dict
|
|
data_source_info['last_edited_time'] = last_edited_time
|
|
update_params = {
|
|
DocumentModel.data_source_info: json.dumps(data_source_info)
|
|
}
|
|
|
|
DocumentModel.query.filter_by(id=document_model.id).update(update_params)
|
|
db.session.commit()
|
|
|
|
def get_notion_last_edited_time(self) -> str:
|
|
obj_id = self._notion_obj_id
|
|
page_type = self._notion_page_type
|
|
if page_type == 'database':
|
|
retrieve_page_url = RETRIEVE_DATABASE_URL_TMPL.format(database_id=obj_id)
|
|
else:
|
|
retrieve_page_url = RETRIEVE_PAGE_URL_TMPL.format(page_id=obj_id)
|
|
|
|
query_dict: Dict[str, Any] = {}
|
|
|
|
res = requests.request(
|
|
"GET",
|
|
retrieve_page_url,
|
|
headers={
|
|
"Authorization": "Bearer " + self._notion_access_token,
|
|
"Content-Type": "application/json",
|
|
"Notion-Version": "2022-06-28",
|
|
},
|
|
json=query_dict
|
|
)
|
|
|
|
data = res.json()
|
|
return data["last_edited_time"]
|
|
|
|
@classmethod
|
|
def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str:
|
|
data_source_binding = DataSourceBinding.query.filter(
|
|
db.and_(
|
|
DataSourceBinding.tenant_id == tenant_id,
|
|
DataSourceBinding.provider == 'notion',
|
|
DataSourceBinding.disabled == False,
|
|
DataSourceBinding.source_info['workspace_id'] == f'"{notion_workspace_id}"'
|
|
)
|
|
).first()
|
|
|
|
if not data_source_binding:
|
|
raise Exception(f'No notion data source binding found for tenant {tenant_id} '
|
|
f'and notion workspace {notion_workspace_id}')
|
|
|
|
return data_source_binding.access_token
|