mirror of
https://gitee.com/dify_ai/dify.git
synced 2024-12-02 03:07:59 +08:00
Fix/ignore economy dataset (#1043)
Co-authored-by: jyong <jyong@dify.ai>
This commit is contained in:
parent
f9bec1edf8
commit
a55ba6e614
@ -92,11 +92,14 @@ class DatasetListApi(Resource):
|
||||
model_names.append(f"{valid_model['model_name']}:{valid_model['model_provider']['provider_name']}")
|
||||
data = marshal(datasets, dataset_detail_fields)
|
||||
for item in data:
|
||||
item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
|
||||
if item_model in model_names:
|
||||
item['embedding_available'] = True
|
||||
if item['indexing_technique'] == 'high_quality':
|
||||
item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
|
||||
if item_model in model_names:
|
||||
item['embedding_available'] = True
|
||||
else:
|
||||
item['embedding_available'] = False
|
||||
else:
|
||||
item['embedding_available'] = False
|
||||
item['embedding_available'] = True
|
||||
response = {
|
||||
'data': data,
|
||||
'has_more': len(datasets) == limit,
|
||||
@ -122,14 +125,6 @@ class DatasetListApi(Resource):
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
try:
|
||||
ModelFactory.get_embedding_model(
|
||||
tenant_id=current_user.current_tenant_id
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
|
||||
try:
|
||||
dataset = DatasetService.create_empty_dataset(
|
||||
@ -167,6 +162,11 @@ class DatasetApi(Resource):
|
||||
@account_initialization_required
|
||||
def patch(self, dataset_id):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
# check user's model setting
|
||||
DatasetService.check_dataset_model_setting(dataset)
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('name', nullable=False,
|
||||
@ -254,6 +254,7 @@ class DatasetIndexingEstimateApi(Resource):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('info_list', type=dict, required=True, nullable=True, location='json')
|
||||
parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
|
||||
parser.add_argument('indexing_technique', type=str, required=True, nullable=True, location='json')
|
||||
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
|
||||
parser.add_argument('dataset_id', type=str, required=False, nullable=False, location='json')
|
||||
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json')
|
||||
@ -275,7 +276,8 @@ class DatasetIndexingEstimateApi(Resource):
|
||||
try:
|
||||
response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details,
|
||||
args['process_rule'], args['doc_form'],
|
||||
args['doc_language'], args['dataset_id'])
|
||||
args['doc_language'], args['dataset_id'],
|
||||
args['indexing_technique'])
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
@ -290,7 +292,8 @@ class DatasetIndexingEstimateApi(Resource):
|
||||
response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id,
|
||||
args['info_list']['notion_info_list'],
|
||||
args['process_rule'], args['doc_form'],
|
||||
args['doc_language'], args['dataset_id'])
|
||||
args['doc_language'], args['dataset_id'],
|
||||
args['indexing_technique'])
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
|
@ -285,20 +285,6 @@ class DatasetDocumentListApi(Resource):
|
||||
# validate args
|
||||
DocumentService.document_create_args_validate(args)
|
||||
|
||||
# check embedding model setting
|
||||
try:
|
||||
ModelFactory.get_embedding_model(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
|
||||
try:
|
||||
documents, batch = DocumentService.save_document_with_dataset_id(dataset, args, current_user)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
@ -339,15 +325,17 @@ class DatasetInitApi(Resource):
|
||||
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
|
||||
location='json')
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
ModelFactory.get_embedding_model(
|
||||
tenant_id=current_user.current_tenant_id
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
if args['indexing_technique'] == 'high_quality':
|
||||
try:
|
||||
ModelFactory.get_embedding_model(
|
||||
tenant_id=current_user.current_tenant_id
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
|
||||
# validate args
|
||||
DocumentService.document_create_args_validate(args)
|
||||
@ -729,6 +717,12 @@ class DocumentDeleteApi(DocumentResource):
|
||||
def delete(self, dataset_id, document_id):
|
||||
dataset_id = str(dataset_id)
|
||||
document_id = str(document_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
# check user's model setting
|
||||
DatasetService.check_dataset_model_setting(dataset)
|
||||
|
||||
document = self.get_document(dataset_id, document_id)
|
||||
|
||||
try:
|
||||
@ -791,6 +785,12 @@ class DocumentStatusApi(DocumentResource):
|
||||
def patch(self, dataset_id, document_id, action):
|
||||
dataset_id = str(dataset_id)
|
||||
document_id = str(document_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if dataset is None:
|
||||
raise NotFound("Dataset not found.")
|
||||
# check user's model setting
|
||||
DatasetService.check_dataset_model_setting(dataset)
|
||||
|
||||
document = self.get_document(dataset_id, document_id)
|
||||
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
|
@ -149,7 +149,8 @@ class DatasetDocumentSegmentApi(Resource):
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound('Dataset not found.')
|
||||
|
||||
# check user's model setting
|
||||
DatasetService.check_dataset_model_setting(dataset)
|
||||
# The role of the current user in the ta table must be admin or owner
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
@ -158,20 +159,20 @@ class DatasetDocumentSegmentApi(Resource):
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
# check embedding model setting
|
||||
try:
|
||||
ModelFactory.get_embedding_model(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
if dataset.indexing_technique == 'high_quality':
|
||||
# check embedding model setting
|
||||
try:
|
||||
ModelFactory.get_embedding_model(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
|
||||
segment = DocumentSegment.query.filter(
|
||||
DocumentSegment.id == str(segment_id),
|
||||
@ -244,18 +245,19 @@ class DatasetDocumentSegmentAddApi(Resource):
|
||||
if current_user.current_tenant.current_role not in ['admin', 'owner']:
|
||||
raise Forbidden()
|
||||
# check embedding model setting
|
||||
try:
|
||||
ModelFactory.get_embedding_model(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
if dataset.indexing_technique == 'high_quality':
|
||||
try:
|
||||
ModelFactory.get_embedding_model(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
@ -284,25 +286,28 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound('Dataset not found.')
|
||||
# check user's model setting
|
||||
DatasetService.check_dataset_model_setting(dataset)
|
||||
# check document
|
||||
document_id = str(document_id)
|
||||
document = DocumentService.get_document(dataset_id, document_id)
|
||||
if not document:
|
||||
raise NotFound('Document not found.')
|
||||
# check embedding model setting
|
||||
try:
|
||||
ModelFactory.get_embedding_model(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
# check segment
|
||||
if dataset.indexing_technique == 'high_quality':
|
||||
# check embedding model setting
|
||||
try:
|
||||
ModelFactory.get_embedding_model(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
# check segment
|
||||
segment_id = str(segment_id)
|
||||
segment = DocumentSegment.query.filter(
|
||||
DocumentSegment.id == str(segment_id),
|
||||
@ -339,6 +344,8 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound('Dataset not found.')
|
||||
# check user's model setting
|
||||
DatasetService.check_dataset_model_setting(dataset)
|
||||
# check document
|
||||
document_id = str(document_id)
|
||||
document = DocumentService.get_document(dataset_id, document_id)
|
||||
@ -378,18 +385,6 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
|
||||
document = DocumentService.get_document(dataset_id, document_id)
|
||||
if not document:
|
||||
raise NotFound('Document not found.')
|
||||
try:
|
||||
ModelFactory.get_embedding_model(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
# get file from request
|
||||
file = request.files['file']
|
||||
# check file
|
||||
|
@ -67,12 +67,13 @@ class DatesetDocumentStore:
|
||||
|
||||
if max_position is None:
|
||||
max_position = 0
|
||||
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=self._dataset.tenant_id,
|
||||
model_provider_name=self._dataset.embedding_model_provider,
|
||||
model_name=self._dataset.embedding_model
|
||||
)
|
||||
embedding_model = None
|
||||
if self._dataset.indexing_technique == 'high_quality':
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=self._dataset.tenant_id,
|
||||
model_provider_name=self._dataset.embedding_model_provider,
|
||||
model_name=self._dataset.embedding_model
|
||||
)
|
||||
|
||||
for doc in docs:
|
||||
if not isinstance(doc, Document):
|
||||
@ -88,7 +89,7 @@ class DatesetDocumentStore:
|
||||
)
|
||||
|
||||
# calc embedding use tokens
|
||||
tokens = embedding_model.get_num_tokens(doc.page_content)
|
||||
tokens = embedding_model.get_num_tokens(doc.page_content) if embedding_model else 0
|
||||
|
||||
if not segment_document:
|
||||
max_position += 1
|
||||
|
@ -1,10 +1,18 @@
|
||||
import json
|
||||
|
||||
from flask import current_app
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
|
||||
from core.embedding.cached_embedding import CacheEmbedding
|
||||
from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig
|
||||
from core.index.vector_index.vector_index import VectorIndex
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from core.model_providers.models.embedding.openai_embedding import OpenAIEmbedding
|
||||
from core.model_providers.models.entity.model_params import ModelKwargs
|
||||
from core.model_providers.models.llm.openai_model import OpenAIModel
|
||||
from core.model_providers.providers.openai_provider import OpenAIProvider
|
||||
from models.dataset import Dataset
|
||||
from models.provider import Provider, ProviderType
|
||||
|
||||
|
||||
class IndexBuilder:
|
||||
@ -35,4 +43,13 @@ class IndexBuilder:
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise ValueError('Unknown indexing technique')
|
||||
raise ValueError('Unknown indexing technique')
|
||||
|
||||
@classmethod
|
||||
def get_default_high_quality_index(cls, dataset: Dataset):
|
||||
embeddings = OpenAIEmbeddings(openai_api_key=' ')
|
||||
return VectorIndex(
|
||||
dataset=dataset,
|
||||
config=current_app.config,
|
||||
embeddings=embeddings
|
||||
)
|
||||
|
@ -217,25 +217,29 @@ class IndexingRunner:
|
||||
db.session.commit()
|
||||
|
||||
def file_indexing_estimate(self, tenant_id: str, file_details: List[UploadFile], tmp_processing_rule: dict,
|
||||
doc_form: str = None, doc_language: str = 'English', dataset_id: str = None) -> dict:
|
||||
doc_form: str = None, doc_language: str = 'English', dataset_id: str = None,
|
||||
indexing_technique: str = 'economy') -> dict:
|
||||
"""
|
||||
Estimate the indexing for the document.
|
||||
"""
|
||||
embedding_model = None
|
||||
if dataset_id:
|
||||
dataset = Dataset.query.filter_by(
|
||||
id=dataset_id
|
||||
).first()
|
||||
if not dataset:
|
||||
raise ValueError('Dataset not found.')
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality':
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
else:
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
if indexing_technique == 'high_quality':
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
tokens = 0
|
||||
preview_texts = []
|
||||
total_segments = 0
|
||||
@ -263,8 +267,8 @@ class IndexingRunner:
|
||||
for document in documents:
|
||||
if len(preview_texts) < 5:
|
||||
preview_texts.append(document.page_content)
|
||||
|
||||
tokens += embedding_model.get_num_tokens(self.filter_string(document.page_content))
|
||||
if indexing_technique == 'high_quality' or embedding_model:
|
||||
tokens += embedding_model.get_num_tokens(self.filter_string(document.page_content))
|
||||
|
||||
if doc_form and doc_form == 'qa_model':
|
||||
text_generation_model = ModelFactory.get_text_generation_model(
|
||||
@ -286,32 +290,35 @@ class IndexingRunner:
|
||||
return {
|
||||
"total_segments": total_segments,
|
||||
"tokens": tokens,
|
||||
"total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)),
|
||||
"currency": embedding_model.get_currency(),
|
||||
"total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)) if embedding_model else 0,
|
||||
"currency": embedding_model.get_currency() if embedding_model else 'USD',
|
||||
"preview": preview_texts
|
||||
}
|
||||
|
||||
def notion_indexing_estimate(self, tenant_id: str, notion_info_list: list, tmp_processing_rule: dict,
|
||||
doc_form: str = None, doc_language: str = 'English', dataset_id: str = None) -> dict:
|
||||
doc_form: str = None, doc_language: str = 'English', dataset_id: str = None,
|
||||
indexing_technique: str = 'economy') -> dict:
|
||||
"""
|
||||
Estimate the indexing for the document.
|
||||
"""
|
||||
embedding_model = None
|
||||
if dataset_id:
|
||||
dataset = Dataset.query.filter_by(
|
||||
id=dataset_id
|
||||
).first()
|
||||
if not dataset:
|
||||
raise ValueError('Dataset not found.')
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality':
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
else:
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
|
||||
if indexing_technique == 'high_quality':
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
# load data from notion
|
||||
tokens = 0
|
||||
preview_texts = []
|
||||
@ -356,8 +363,8 @@ class IndexingRunner:
|
||||
for document in documents:
|
||||
if len(preview_texts) < 5:
|
||||
preview_texts.append(document.page_content)
|
||||
|
||||
tokens += embedding_model.get_num_tokens(document.page_content)
|
||||
if indexing_technique == 'high_quality' or embedding_model:
|
||||
tokens += embedding_model.get_num_tokens(document.page_content)
|
||||
|
||||
if doc_form and doc_form == 'qa_model':
|
||||
text_generation_model = ModelFactory.get_text_generation_model(
|
||||
@ -379,8 +386,8 @@ class IndexingRunner:
|
||||
return {
|
||||
"total_segments": total_segments,
|
||||
"tokens": tokens,
|
||||
"total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)),
|
||||
"currency": embedding_model.get_currency(),
|
||||
"total_price": '{:f}'.format(embedding_model.calc_tokens_price(tokens)) if embedding_model else 0,
|
||||
"currency": embedding_model.get_currency() if embedding_model else 'USD',
|
||||
"preview": preview_texts
|
||||
}
|
||||
|
||||
@ -657,12 +664,13 @@ class IndexingRunner:
|
||||
"""
|
||||
vector_index = IndexBuilder.get_index(dataset, 'high_quality')
|
||||
keyword_table_index = IndexBuilder.get_index(dataset, 'economy')
|
||||
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
embedding_model = None
|
||||
if dataset.indexing_technique == 'high_quality':
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
|
||||
# chunk nodes by chunk size
|
||||
indexing_start_at = time.perf_counter()
|
||||
@ -672,11 +680,11 @@ class IndexingRunner:
|
||||
# check document is paused
|
||||
self._check_document_paused_status(dataset_document.id)
|
||||
chunk_documents = documents[i:i + chunk_size]
|
||||
|
||||
tokens += sum(
|
||||
embedding_model.get_num_tokens(document.page_content)
|
||||
for document in chunk_documents
|
||||
)
|
||||
if dataset.indexing_technique == 'high_quality' or embedding_model:
|
||||
tokens += sum(
|
||||
embedding_model.get_num_tokens(document.page_content)
|
||||
for document in chunk_documents
|
||||
)
|
||||
|
||||
# save vector index
|
||||
if vector_index:
|
||||
|
@ -1,6 +1,5 @@
|
||||
from events.dataset_event import dataset_was_deleted
|
||||
from events.event_handlers.document_index_event import document_index_created
|
||||
from tasks.clean_dataset_task import clean_dataset_task
|
||||
import datetime
|
||||
import logging
|
||||
import time
|
||||
|
@ -0,0 +1,46 @@
|
||||
"""update_dataset_model_field_null_available
|
||||
|
||||
Revision ID: 4bcffcd64aa4
|
||||
Revises: 853f9b9cd3b6
|
||||
Create Date: 2023-08-28 20:58:50.077056
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '4bcffcd64aa4'
|
||||
down_revision = '853f9b9cd3b6'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('datasets', schema=None) as batch_op:
|
||||
batch_op.alter_column('embedding_model',
|
||||
existing_type=sa.VARCHAR(length=255),
|
||||
nullable=True,
|
||||
existing_server_default=sa.text("'text-embedding-ada-002'::character varying"))
|
||||
batch_op.alter_column('embedding_model_provider',
|
||||
existing_type=sa.VARCHAR(length=255),
|
||||
nullable=True,
|
||||
existing_server_default=sa.text("'openai'::character varying"))
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('datasets', schema=None) as batch_op:
|
||||
batch_op.alter_column('embedding_model_provider',
|
||||
existing_type=sa.VARCHAR(length=255),
|
||||
nullable=False,
|
||||
existing_server_default=sa.text("'openai'::character varying"))
|
||||
batch_op.alter_column('embedding_model',
|
||||
existing_type=sa.VARCHAR(length=255),
|
||||
nullable=False,
|
||||
existing_server_default=sa.text("'text-embedding-ada-002'::character varying"))
|
||||
|
||||
# ### end Alembic commands ###
|
@ -36,10 +36,8 @@ class Dataset(db.Model):
|
||||
updated_by = db.Column(UUID, nullable=True)
|
||||
updated_at = db.Column(db.DateTime, nullable=False,
|
||||
server_default=db.text('CURRENT_TIMESTAMP(0)'))
|
||||
embedding_model = db.Column(db.String(
|
||||
255), nullable=False, server_default=db.text("'text-embedding-ada-002'::character varying"))
|
||||
embedding_model_provider = db.Column(db.String(
|
||||
255), nullable=False, server_default=db.text("'openai'::character varying"))
|
||||
embedding_model = db.Column(db.String(255), nullable=True)
|
||||
embedding_model_provider = db.Column(db.String(255), nullable=True)
|
||||
|
||||
@property
|
||||
def dataset_keyword_table(self):
|
||||
|
@ -10,6 +10,7 @@ from flask import current_app
|
||||
from sqlalchemy import func
|
||||
|
||||
from core.index.index import IndexBuilder
|
||||
from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from extensions.ext_redis import redis_client
|
||||
from flask_login import current_user
|
||||
@ -91,16 +92,18 @@ class DatasetService:
|
||||
if Dataset.query.filter_by(name=name, tenant_id=tenant_id).first():
|
||||
raise DatasetNameDuplicateError(
|
||||
f'Dataset with name {name} already exists.')
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=current_user.current_tenant_id
|
||||
)
|
||||
embedding_model = None
|
||||
if indexing_technique == 'high_quality':
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=current_user.current_tenant_id
|
||||
)
|
||||
dataset = Dataset(name=name, indexing_technique=indexing_technique)
|
||||
# dataset = Dataset(name=name, provider=provider, config=config)
|
||||
dataset.created_by = account.id
|
||||
dataset.updated_by = account.id
|
||||
dataset.tenant_id = tenant_id
|
||||
dataset.embedding_model_provider = embedding_model.model_provider.provider_name
|
||||
dataset.embedding_model = embedding_model.name
|
||||
dataset.embedding_model_provider = embedding_model.model_provider.provider_name if embedding_model else None
|
||||
dataset.embedding_model = embedding_model.name if embedding_model else None
|
||||
db.session.add(dataset)
|
||||
db.session.commit()
|
||||
return dataset
|
||||
@ -115,6 +118,23 @@ class DatasetService:
|
||||
else:
|
||||
return dataset
|
||||
|
||||
@staticmethod
|
||||
def check_dataset_model_setting(dataset):
|
||||
if dataset.indexing_technique == 'high_quality':
|
||||
try:
|
||||
ModelFactory.get_embedding_model(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ValueError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ValueError(f"The dataset in unavailable, due to: "
|
||||
f"{ex.description}")
|
||||
|
||||
@staticmethod
|
||||
def update_dataset(dataset_id, data, user):
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
@ -124,6 +144,19 @@ class DatasetService:
|
||||
if data['indexing_technique'] == 'economy':
|
||||
deal_dataset_vector_index_task.delay(dataset_id, 'remove')
|
||||
elif data['indexing_technique'] == 'high_quality':
|
||||
# check embedding model setting
|
||||
try:
|
||||
ModelFactory.get_embedding_model(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ValueError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ValueError(ex.description)
|
||||
deal_dataset_vector_index_task.delay(dataset_id, 'add')
|
||||
filtered_data = {k: v for k, v in data.items() if v is not None or k == 'description'}
|
||||
|
||||
@ -397,23 +430,23 @@ class DocumentService:
|
||||
|
||||
# check document limit
|
||||
if current_app.config['EDITION'] == 'CLOUD':
|
||||
count = 0
|
||||
if document_data["data_source"]["type"] == "upload_file":
|
||||
upload_file_list = document_data["data_source"]["info_list"]['file_info_list']['file_ids']
|
||||
count = len(upload_file_list)
|
||||
elif document_data["data_source"]["type"] == "notion_import":
|
||||
notion_info_list = document_data["data_source"]['info_list']['notion_info_list']
|
||||
for notion_info in notion_info_list:
|
||||
count = count + len(notion_info['pages'])
|
||||
documents_count = DocumentService.get_tenant_documents_count()
|
||||
total_count = documents_count + count
|
||||
tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
|
||||
if total_count > tenant_document_count:
|
||||
raise ValueError(f"over document limit {tenant_document_count}.")
|
||||
if 'original_document_id' not in document_data or not document_data['original_document_id']:
|
||||
count = 0
|
||||
if document_data["data_source"]["type"] == "upload_file":
|
||||
upload_file_list = document_data["data_source"]["info_list"]['file_info_list']['file_ids']
|
||||
count = len(upload_file_list)
|
||||
elif document_data["data_source"]["type"] == "notion_import":
|
||||
notion_info_list = document_data["data_source"]['info_list']['notion_info_list']
|
||||
for notion_info in notion_info_list:
|
||||
count = count + len(notion_info['pages'])
|
||||
documents_count = DocumentService.get_tenant_documents_count()
|
||||
total_count = documents_count + count
|
||||
tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
|
||||
if total_count > tenant_document_count:
|
||||
raise ValueError(f"over document limit {tenant_document_count}.")
|
||||
# if dataset is empty, update dataset data_source_type
|
||||
if not dataset.data_source_type:
|
||||
dataset.data_source_type = document_data["data_source"]["type"]
|
||||
db.session.commit()
|
||||
|
||||
if not dataset.indexing_technique:
|
||||
if 'indexing_technique' not in document_data \
|
||||
@ -421,6 +454,13 @@ class DocumentService:
|
||||
raise ValueError("Indexing technique is required")
|
||||
|
||||
dataset.indexing_technique = document_data["indexing_technique"]
|
||||
if document_data["indexing_technique"] == 'high_quality':
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=dataset.tenant_id
|
||||
)
|
||||
dataset.embedding_model = embedding_model.name
|
||||
dataset.embedding_model_provider = embedding_model.model_provider.provider_name
|
||||
|
||||
|
||||
documents = []
|
||||
batch = time.strftime('%Y%m%d%H%M%S') + str(random.randint(100000, 999999))
|
||||
@ -466,11 +506,11 @@ class DocumentService:
|
||||
"upload_file_id": file_id,
|
||||
}
|
||||
document = DocumentService.build_document(dataset, dataset_process_rule.id,
|
||||
document_data["data_source"]["type"],
|
||||
document_data["doc_form"],
|
||||
document_data["doc_language"],
|
||||
data_source_info, created_from, position,
|
||||
account, file_name, batch)
|
||||
document_data["data_source"]["type"],
|
||||
document_data["doc_form"],
|
||||
document_data["doc_language"],
|
||||
data_source_info, created_from, position,
|
||||
account, file_name, batch)
|
||||
db.session.add(document)
|
||||
db.session.flush()
|
||||
document_ids.append(document.id)
|
||||
@ -512,11 +552,11 @@ class DocumentService:
|
||||
"type": page['type']
|
||||
}
|
||||
document = DocumentService.build_document(dataset, dataset_process_rule.id,
|
||||
document_data["data_source"]["type"],
|
||||
document_data["doc_form"],
|
||||
document_data["doc_language"],
|
||||
data_source_info, created_from, position,
|
||||
account, page['page_name'], batch)
|
||||
document_data["data_source"]["type"],
|
||||
document_data["doc_form"],
|
||||
document_data["doc_language"],
|
||||
data_source_info, created_from, position,
|
||||
account, page['page_name'], batch)
|
||||
db.session.add(document)
|
||||
db.session.flush()
|
||||
document_ids.append(document.id)
|
||||
@ -536,9 +576,9 @@ class DocumentService:
|
||||
|
||||
@staticmethod
|
||||
def build_document(dataset: Dataset, process_rule_id: str, data_source_type: str, document_form: str,
|
||||
document_language: str, data_source_info: dict, created_from: str, position: int,
|
||||
account: Account,
|
||||
name: str, batch: str):
|
||||
document_language: str, data_source_info: dict, created_from: str, position: int,
|
||||
account: Account,
|
||||
name: str, batch: str):
|
||||
document = Document(
|
||||
tenant_id=dataset.tenant_id,
|
||||
dataset_id=dataset.id,
|
||||
@ -567,6 +607,7 @@ class DocumentService:
|
||||
def update_document_with_dataset_id(dataset: Dataset, document_data: dict,
|
||||
account: Account, dataset_process_rule: Optional[DatasetProcessRule] = None,
|
||||
created_from: str = 'web'):
|
||||
DatasetService.check_dataset_model_setting(dataset)
|
||||
document = DocumentService.get_document(dataset.id, document_data["original_document_id"])
|
||||
if document.display_status != 'available':
|
||||
raise ValueError("Document is not available")
|
||||
@ -674,9 +715,11 @@ class DocumentService:
|
||||
tenant_document_count = int(current_app.config['TENANT_DOCUMENT_COUNT'])
|
||||
if total_count > tenant_document_count:
|
||||
raise ValueError(f"All your documents have overed limit {tenant_document_count}.")
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
embedding_model = None
|
||||
if document_data['indexing_technique'] == 'high_quality':
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
# save dataset
|
||||
dataset = Dataset(
|
||||
tenant_id=tenant_id,
|
||||
@ -684,8 +727,8 @@ class DocumentService:
|
||||
data_source_type=document_data["data_source"]["type"],
|
||||
indexing_technique=document_data["indexing_technique"],
|
||||
created_by=account.id,
|
||||
embedding_model=embedding_model.name,
|
||||
embedding_model_provider=embedding_model.model_provider.provider_name
|
||||
embedding_model=embedding_model.name if embedding_model else None,
|
||||
embedding_model_provider=embedding_model.model_provider.provider_name if embedding_model else None
|
||||
)
|
||||
|
||||
db.session.add(dataset)
|
||||
@ -903,15 +946,15 @@ class SegmentService:
|
||||
content = args['content']
|
||||
doc_id = str(uuid.uuid4())
|
||||
segment_hash = helper.generate_text_hash(content)
|
||||
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
|
||||
# calc embedding use tokens
|
||||
tokens = embedding_model.get_num_tokens(content)
|
||||
tokens = 0
|
||||
if dataset.indexing_technique == 'high_quality':
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
# calc embedding use tokens
|
||||
tokens = embedding_model.get_num_tokens(content)
|
||||
max_position = db.session.query(func.max(DocumentSegment.position)).filter(
|
||||
DocumentSegment.document_id == document.id
|
||||
).scalar()
|
||||
@ -973,15 +1016,16 @@ class SegmentService:
|
||||
kw_index.update_segment_keywords_index(segment.index_node_id, segment.keywords)
|
||||
else:
|
||||
segment_hash = helper.generate_text_hash(content)
|
||||
tokens = 0
|
||||
if dataset.indexing_technique == 'high_quality':
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
|
||||
# calc embedding use tokens
|
||||
tokens = embedding_model.get_num_tokens(content)
|
||||
# calc embedding use tokens
|
||||
tokens = embedding_model.get_num_tokens(content)
|
||||
segment.content = content
|
||||
segment.index_node_hash = segment_hash
|
||||
segment.word_count = len(content)
|
||||
@ -1013,7 +1057,7 @@ class SegmentService:
|
||||
cache_result = redis_client.get(indexing_cache_key)
|
||||
if cache_result is not None:
|
||||
raise ValueError("Segment is deleting.")
|
||||
|
||||
|
||||
# enabled segment need to delete index
|
||||
if segment.enabled:
|
||||
# send delete segment index task
|
||||
|
@ -49,18 +49,20 @@ def batch_create_segment_to_index_task(job_id: str, content: List, dataset_id: s
|
||||
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != 'completed':
|
||||
raise ValueError('Document is not available.')
|
||||
document_segments = []
|
||||
for segment in content:
|
||||
content = segment['content']
|
||||
doc_id = str(uuid.uuid4())
|
||||
segment_hash = helper.generate_text_hash(content)
|
||||
embedding_model = None
|
||||
if dataset.indexing_technique == 'high_quality':
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
|
||||
for segment in content:
|
||||
content = segment['content']
|
||||
doc_id = str(uuid.uuid4())
|
||||
segment_hash = helper.generate_text_hash(content)
|
||||
# calc embedding use tokens
|
||||
tokens = embedding_model.get_num_tokens(content)
|
||||
tokens = embedding_model.get_num_tokens(content) if embedding_model else 0
|
||||
max_position = db.session.query(func.max(DocumentSegment.position)).filter(
|
||||
DocumentSegment.document_id == dataset_document.id
|
||||
).scalar()
|
||||
|
@ -3,8 +3,10 @@ import time
|
||||
|
||||
import click
|
||||
from celery import shared_task
|
||||
from flask import current_app
|
||||
|
||||
from core.index.index import IndexBuilder
|
||||
from core.index.vector_index.vector_index import VectorIndex
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import DocumentSegment, Dataset, DatasetKeywordTable, DatasetQuery, DatasetProcessRule, \
|
||||
AppDatasetJoin, Document
|
||||
@ -35,11 +37,11 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str,
|
||||
documents = db.session.query(Document).filter(Document.dataset_id == dataset_id).all()
|
||||
segments = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all()
|
||||
|
||||
vector_index = IndexBuilder.get_index(dataset, 'high_quality')
|
||||
kw_index = IndexBuilder.get_index(dataset, 'economy')
|
||||
|
||||
# delete from vector index
|
||||
if vector_index:
|
||||
if dataset.indexing_technique == 'high_quality':
|
||||
vector_index = IndexBuilder.get_default_high_quality_index(dataset)
|
||||
try:
|
||||
vector_index.delete()
|
||||
except Exception:
|
||||
|
@ -31,7 +31,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
|
||||
raise Exception('Dataset not found')
|
||||
|
||||
if action == "remove":
|
||||
index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=True)
|
||||
index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=False)
|
||||
index.delete()
|
||||
elif action == "add":
|
||||
dataset_documents = db.session.query(DatasetDocument).filter(
|
||||
@ -43,7 +43,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
|
||||
|
||||
if dataset_documents:
|
||||
# save vector index
|
||||
index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=True)
|
||||
index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=False)
|
||||
documents = []
|
||||
for dataset_document in dataset_documents:
|
||||
# delete from vector index
|
||||
|
Loading…
Reference in New Issue
Block a user