mirror of
https://gitee.com/dify_ai/dify.git
synced 2024-11-30 02:08:37 +08:00
parent
e409895c02
commit
724e053732
151
api/commands.py
151
api/commands.py
@ -3,12 +3,13 @@ import json
|
||||
import math
|
||||
import random
|
||||
import string
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
|
||||
import click
|
||||
from tqdm import tqdm
|
||||
from flask import current_app
|
||||
from flask import current_app, Flask
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
@ -456,92 +457,92 @@ def update_qdrant_indexes():
|
||||
@click.command('normalization-collections', help='restore all collections in one')
|
||||
def normalization_collections():
|
||||
click.echo(click.style('Start normalization collections.', fg='green'))
|
||||
normalization_count = 0
|
||||
|
||||
normalization_count = []
|
||||
page = 1
|
||||
while True:
|
||||
try:
|
||||
datasets = db.session.query(Dataset).filter(Dataset.indexing_technique == 'high_quality') \
|
||||
.order_by(Dataset.created_at.desc()).paginate(page=page, per_page=50)
|
||||
.order_by(Dataset.created_at.desc()).paginate(page=page, per_page=100)
|
||||
except NotFound:
|
||||
break
|
||||
|
||||
datasets_result = datasets.items
|
||||
page += 1
|
||||
for dataset in datasets:
|
||||
if not dataset.collection_binding_id:
|
||||
try:
|
||||
click.echo('restore dataset index: {}'.format(dataset.id))
|
||||
try:
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
except Exception:
|
||||
provider = Provider(
|
||||
id='provider_id',
|
||||
tenant_id=dataset.tenant_id,
|
||||
provider_name='openai',
|
||||
provider_type=ProviderType.CUSTOM.value,
|
||||
encrypted_config=json.dumps({'openai_api_key': 'TEST'}),
|
||||
is_valid=True,
|
||||
)
|
||||
model_provider = OpenAIProvider(provider=provider)
|
||||
embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
|
||||
model_provider=model_provider)
|
||||
embeddings = CacheEmbedding(embedding_model)
|
||||
dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
|
||||
filter(DatasetCollectionBinding.provider_name == embedding_model.model_provider.provider_name,
|
||||
DatasetCollectionBinding.model_name == embedding_model.name). \
|
||||
order_by(DatasetCollectionBinding.created_at). \
|
||||
first()
|
||||
for i in range(0, len(datasets_result), 5):
|
||||
threads = []
|
||||
sub_datasets = datasets_result[i:i + 5]
|
||||
for dataset in sub_datasets:
|
||||
document_format_thread = threading.Thread(target=deal_dataset_vector, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'dataset': dataset,
|
||||
'normalization_count': normalization_count
|
||||
})
|
||||
threads.append(document_format_thread)
|
||||
document_format_thread.start()
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
if not dataset_collection_binding:
|
||||
dataset_collection_binding = DatasetCollectionBinding(
|
||||
provider_name=embedding_model.model_provider.provider_name,
|
||||
model_name=embedding_model.name,
|
||||
collection_name="Vector_index_" + str(uuid.uuid4()).replace("-", "_") + '_Node'
|
||||
)
|
||||
db.session.add(dataset_collection_binding)
|
||||
db.session.commit()
|
||||
click.echo(click.style('Congratulations! restore {} dataset indexes.'.format(len(normalization_count)), fg='green'))
|
||||
|
||||
from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
|
||||
|
||||
index = QdrantVectorIndex(
|
||||
dataset=dataset,
|
||||
config=QdrantConfig(
|
||||
endpoint=current_app.config.get('QDRANT_URL'),
|
||||
api_key=current_app.config.get('QDRANT_API_KEY'),
|
||||
root_path=current_app.root_path
|
||||
),
|
||||
embeddings=embeddings
|
||||
)
|
||||
if index:
|
||||
index.restore_dataset_in_one(dataset, dataset_collection_binding)
|
||||
else:
|
||||
click.echo('passed.')
|
||||
def deal_dataset_vector(flask_app: Flask, dataset: Dataset, normalization_count: list):
|
||||
with flask_app.app_context():
|
||||
try:
|
||||
click.echo('restore dataset index: {}'.format(dataset.id))
|
||||
try:
|
||||
embedding_model = ModelFactory.get_embedding_model(
|
||||
tenant_id=dataset.tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
except Exception:
|
||||
provider = Provider(
|
||||
id='provider_id',
|
||||
tenant_id=dataset.tenant_id,
|
||||
provider_name='openai',
|
||||
provider_type=ProviderType.CUSTOM.value,
|
||||
encrypted_config=json.dumps({'openai_api_key': 'TEST'}),
|
||||
is_valid=True,
|
||||
)
|
||||
model_provider = OpenAIProvider(provider=provider)
|
||||
embedding_model = OpenAIEmbedding(name="text-embedding-ada-002",
|
||||
model_provider=model_provider)
|
||||
embeddings = CacheEmbedding(embedding_model)
|
||||
dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
|
||||
filter(DatasetCollectionBinding.provider_name == embedding_model.model_provider.provider_name,
|
||||
DatasetCollectionBinding.model_name == embedding_model.name). \
|
||||
order_by(DatasetCollectionBinding.created_at). \
|
||||
first()
|
||||
|
||||
original_index = QdrantVectorIndex(
|
||||
dataset=dataset,
|
||||
config=QdrantConfig(
|
||||
endpoint=current_app.config.get('QDRANT_URL'),
|
||||
api_key=current_app.config.get('QDRANT_API_KEY'),
|
||||
root_path=current_app.root_path
|
||||
),
|
||||
embeddings=embeddings
|
||||
)
|
||||
if original_index:
|
||||
original_index.delete_original_collection(dataset, dataset_collection_binding)
|
||||
normalization_count += 1
|
||||
else:
|
||||
click.echo('passed.')
|
||||
except Exception as e:
|
||||
click.echo(
|
||||
click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
|
||||
fg='red'))
|
||||
continue
|
||||
if not dataset_collection_binding:
|
||||
dataset_collection_binding = DatasetCollectionBinding(
|
||||
provider_name=embedding_model.model_provider.provider_name,
|
||||
model_name=embedding_model.name,
|
||||
collection_name="Vector_index_" + str(uuid.uuid4()).replace("-", "_") + '_Node'
|
||||
)
|
||||
db.session.add(dataset_collection_binding)
|
||||
db.session.commit()
|
||||
|
||||
click.echo(click.style('Congratulations! restore {} dataset indexes.'.format(normalization_count), fg='green'))
|
||||
from core.index.vector_index.qdrant_vector_index import QdrantVectorIndex, QdrantConfig
|
||||
|
||||
index = QdrantVectorIndex(
|
||||
dataset=dataset,
|
||||
config=QdrantConfig(
|
||||
endpoint=current_app.config.get('QDRANT_URL'),
|
||||
api_key=current_app.config.get('QDRANT_API_KEY'),
|
||||
root_path=current_app.root_path
|
||||
),
|
||||
embeddings=embeddings
|
||||
)
|
||||
if index:
|
||||
# index.delete_by_group_id(dataset.id)
|
||||
index.restore_dataset_in_one(dataset, dataset_collection_binding)
|
||||
else:
|
||||
click.echo('passed.')
|
||||
normalization_count.append(1)
|
||||
except Exception as e:
|
||||
click.echo(
|
||||
click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
|
||||
fg='red'))
|
||||
|
||||
|
||||
@click.command('update_app_model_configs', help='Migrate data to support paragraph variable.')
|
||||
|
@ -113,8 +113,10 @@ class BaseVectorIndex(BaseIndex):
|
||||
def delete_by_group_id(self, group_id: str) -> None:
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
|
||||
vector_store.delete()
|
||||
if self.dataset.collection_binding_id:
|
||||
vector_store.delete_by_group_id(group_id)
|
||||
else:
|
||||
vector_store.delete()
|
||||
|
||||
def delete(self) -> None:
|
||||
vector_store = self._get_vector_store()
|
||||
@ -283,7 +285,7 @@ class BaseVectorIndex(BaseIndex):
|
||||
|
||||
if documents:
|
||||
try:
|
||||
self.create_with_collection_name(documents, dataset_collection_binding.collection_name)
|
||||
self.add_texts(documents)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
@ -1390,70 +1390,12 @@ class Qdrant(VectorStore):
|
||||
path=path,
|
||||
**kwargs,
|
||||
)
|
||||
try:
|
||||
# Skip any validation in case of forced collection recreate.
|
||||
if force_recreate:
|
||||
raise ValueError
|
||||
|
||||
# Get the vector configuration of the existing collection and vector, if it
|
||||
# was specified. If the old configuration does not match the current one,
|
||||
# an exception is being thrown.
|
||||
collection_info = client.get_collection(collection_name=collection_name)
|
||||
current_vector_config = collection_info.config.params.vectors
|
||||
if isinstance(current_vector_config, dict) and vector_name is not None:
|
||||
if vector_name not in current_vector_config:
|
||||
raise QdrantException(
|
||||
f"Existing Qdrant collection {collection_name} does not "
|
||||
f"contain vector named {vector_name}. Did you mean one of the "
|
||||
f"existing vectors: {', '.join(current_vector_config.keys())}? "
|
||||
f"If you want to recreate the collection, set `force_recreate` "
|
||||
f"parameter to `True`."
|
||||
)
|
||||
current_vector_config = current_vector_config.get(
|
||||
vector_name
|
||||
) # type: ignore[assignment]
|
||||
elif isinstance(current_vector_config, dict) and vector_name is None:
|
||||
raise QdrantException(
|
||||
f"Existing Qdrant collection {collection_name} uses named vectors. "
|
||||
f"If you want to reuse it, please set `vector_name` to any of the "
|
||||
f"existing named vectors: "
|
||||
f"{', '.join(current_vector_config.keys())}." # noqa
|
||||
f"If you want to recreate the collection, set `force_recreate` "
|
||||
f"parameter to `True`."
|
||||
)
|
||||
elif (
|
||||
not isinstance(current_vector_config, dict) and vector_name is not None
|
||||
):
|
||||
raise QdrantException(
|
||||
f"Existing Qdrant collection {collection_name} doesn't use named "
|
||||
f"vectors. If you want to reuse it, please set `vector_name` to "
|
||||
f"`None`. If you want to recreate the collection, set "
|
||||
f"`force_recreate` parameter to `True`."
|
||||
)
|
||||
|
||||
# Check if the vector configuration has the same dimensionality.
|
||||
if current_vector_config.size != vector_size: # type: ignore[union-attr]
|
||||
raise QdrantException(
|
||||
f"Existing Qdrant collection is configured for vectors with "
|
||||
f"{current_vector_config.size} " # type: ignore[union-attr]
|
||||
f"dimensions. Selected embeddings are {vector_size}-dimensional. "
|
||||
f"If you want to recreate the collection, set `force_recreate` "
|
||||
f"parameter to `True`."
|
||||
)
|
||||
|
||||
current_distance_func = (
|
||||
current_vector_config.distance.name.upper() # type: ignore[union-attr]
|
||||
)
|
||||
if current_distance_func != distance_func:
|
||||
raise QdrantException(
|
||||
f"Existing Qdrant collection is configured for "
|
||||
f"{current_vector_config.distance} " # type: ignore[union-attr]
|
||||
f"similarity. Please set `distance_func` parameter to "
|
||||
f"`{distance_func}` if you want to reuse it. If you want to "
|
||||
f"recreate the collection, set `force_recreate` parameter to "
|
||||
f"`True`."
|
||||
)
|
||||
except (UnexpectedResponse, RpcError, ValueError):
|
||||
all_collection_name = []
|
||||
collections_response = client.get_collections()
|
||||
collection_list = collections_response.collections
|
||||
for collection in collection_list:
|
||||
all_collection_name.append(collection.name)
|
||||
if collection_name not in all_collection_name:
|
||||
vectors_config = rest.VectorParams(
|
||||
size=vector_size,
|
||||
distance=rest.Distance[distance_func],
|
||||
@ -1481,6 +1423,67 @@ class Qdrant(VectorStore):
|
||||
timeout=timeout, # type: ignore[arg-type]
|
||||
)
|
||||
is_new_collection = True
|
||||
if force_recreate:
|
||||
raise ValueError
|
||||
|
||||
# Get the vector configuration of the existing collection and vector, if it
|
||||
# was specified. If the old configuration does not match the current one,
|
||||
# an exception is being thrown.
|
||||
collection_info = client.get_collection(collection_name=collection_name)
|
||||
current_vector_config = collection_info.config.params.vectors
|
||||
if isinstance(current_vector_config, dict) and vector_name is not None:
|
||||
if vector_name not in current_vector_config:
|
||||
raise QdrantException(
|
||||
f"Existing Qdrant collection {collection_name} does not "
|
||||
f"contain vector named {vector_name}. Did you mean one of the "
|
||||
f"existing vectors: {', '.join(current_vector_config.keys())}? "
|
||||
f"If you want to recreate the collection, set `force_recreate` "
|
||||
f"parameter to `True`."
|
||||
)
|
||||
current_vector_config = current_vector_config.get(
|
||||
vector_name
|
||||
) # type: ignore[assignment]
|
||||
elif isinstance(current_vector_config, dict) and vector_name is None:
|
||||
raise QdrantException(
|
||||
f"Existing Qdrant collection {collection_name} uses named vectors. "
|
||||
f"If you want to reuse it, please set `vector_name` to any of the "
|
||||
f"existing named vectors: "
|
||||
f"{', '.join(current_vector_config.keys())}." # noqa
|
||||
f"If you want to recreate the collection, set `force_recreate` "
|
||||
f"parameter to `True`."
|
||||
)
|
||||
elif (
|
||||
not isinstance(current_vector_config, dict) and vector_name is not None
|
||||
):
|
||||
raise QdrantException(
|
||||
f"Existing Qdrant collection {collection_name} doesn't use named "
|
||||
f"vectors. If you want to reuse it, please set `vector_name` to "
|
||||
f"`None`. If you want to recreate the collection, set "
|
||||
f"`force_recreate` parameter to `True`."
|
||||
)
|
||||
|
||||
# Check if the vector configuration has the same dimensionality.
|
||||
if current_vector_config.size != vector_size: # type: ignore[union-attr]
|
||||
raise QdrantException(
|
||||
f"Existing Qdrant collection is configured for vectors with "
|
||||
f"{current_vector_config.size} " # type: ignore[union-attr]
|
||||
f"dimensions. Selected embeddings are {vector_size}-dimensional. "
|
||||
f"If you want to recreate the collection, set `force_recreate` "
|
||||
f"parameter to `True`."
|
||||
)
|
||||
|
||||
current_distance_func = (
|
||||
current_vector_config.distance.name.upper() # type: ignore[union-attr]
|
||||
)
|
||||
if current_distance_func != distance_func:
|
||||
raise QdrantException(
|
||||
f"Existing Qdrant collection is configured for "
|
||||
f"{current_vector_config.distance} " # type: ignore[union-attr]
|
||||
f"similarity. Please set `distance_func` parameter to "
|
||||
f"`{distance_func}` if you want to reuse it. If you want to "
|
||||
f"recreate the collection, set `force_recreate` parameter to "
|
||||
f"`True`."
|
||||
)
|
||||
qdrant = cls(
|
||||
client=client,
|
||||
collection_name=collection_name,
|
||||
|
@ -169,6 +169,19 @@ class QdrantVectorIndex(BaseVectorIndex):
|
||||
],
|
||||
))
|
||||
|
||||
def delete(self) -> None:
|
||||
vector_store = self._get_vector_store()
|
||||
vector_store = cast(self._get_vector_store_class(), vector_store)
|
||||
|
||||
from qdrant_client.http import models
|
||||
vector_store.del_texts(models.Filter(
|
||||
must=[
|
||||
models.FieldCondition(
|
||||
key="group_id",
|
||||
match=models.MatchValue(value=self.dataset.id),
|
||||
),
|
||||
],
|
||||
))
|
||||
|
||||
def _is_origin(self):
|
||||
if self.dataset.index_struct_dict:
|
||||
|
@ -5,4 +5,5 @@ from tasks.clean_dataset_task import clean_dataset_task
|
||||
@dataset_was_deleted.connect
|
||||
def handle(sender, **kwargs):
|
||||
dataset = sender
|
||||
clean_dataset_task.delay(dataset.id, dataset.tenant_id, dataset.indexing_technique, dataset.index_struct)
|
||||
clean_dataset_task.delay(dataset.id, dataset.tenant_id, dataset.indexing_technique,
|
||||
dataset.index_struct, dataset.collection_binding_id)
|
||||
|
@ -13,13 +13,15 @@ from models.dataset import DocumentSegment, Dataset, DatasetKeywordTable, Datase
|
||||
|
||||
|
||||
@shared_task(queue='dataset')
|
||||
def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str, index_struct: str):
|
||||
def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str,
|
||||
index_struct: str, collection_binding_id: str):
|
||||
"""
|
||||
Clean dataset when dataset deleted.
|
||||
:param dataset_id: dataset id
|
||||
:param tenant_id: tenant id
|
||||
:param indexing_technique: indexing technique
|
||||
:param index_struct: index struct dict
|
||||
:param collection_binding_id: collection binding id
|
||||
|
||||
Usage: clean_dataset_task.delay(dataset_id, tenant_id, indexing_technique, index_struct)
|
||||
"""
|
||||
@ -31,9 +33,9 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str,
|
||||
id=dataset_id,
|
||||
tenant_id=tenant_id,
|
||||
indexing_technique=indexing_technique,
|
||||
index_struct=index_struct
|
||||
index_struct=index_struct,
|
||||
collection_binding_id=collection_binding_id
|
||||
)
|
||||
|
||||
documents = db.session.query(Document).filter(Document.dataset_id == dataset_id).all()
|
||||
segments = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all()
|
||||
|
||||
@ -43,7 +45,7 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str,
|
||||
if dataset.indexing_technique == 'high_quality':
|
||||
vector_index = IndexBuilder.get_default_high_quality_index(dataset)
|
||||
try:
|
||||
vector_index.delete()
|
||||
vector_index.delete_by_group_id(dataset.id)
|
||||
except Exception:
|
||||
logging.exception("Delete doc index failed when dataset deleted.")
|
||||
|
||||
|
@ -31,8 +31,8 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
|
||||
raise Exception('Dataset not found')
|
||||
|
||||
if action == "remove":
|
||||
index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=False)
|
||||
index.delete()
|
||||
index = IndexBuilder.get_index(dataset, 'high_quality', ignore_high_quality_check=True)
|
||||
index.delete_by_group_id(dataset.id)
|
||||
elif action == "add":
|
||||
dataset_documents = db.session.query(DatasetDocument).filter(
|
||||
DatasetDocument.dataset_id == dataset_id,
|
||||
|
Loading…
Reference in New Issue
Block a user