mirror of
https://gitee.com/dify_ai/dify.git
synced 2024-11-30 18:27:53 +08:00
Fix/multi thread parameter (#1604)
This commit is contained in:
parent
f704094a5f
commit
a5b80c9d1f
@ -192,7 +192,7 @@ class DatasetMultiRetrieverTool(BaseTool):
|
||||
'search_method'] == 'hybrid_search':
|
||||
embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'dataset': dataset,
|
||||
'dataset_id': str(dataset.id),
|
||||
'query': query,
|
||||
'top_k': self.top_k,
|
||||
'score_threshold': self.score_threshold,
|
||||
@ -210,7 +210,7 @@ class DatasetMultiRetrieverTool(BaseTool):
|
||||
full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search,
|
||||
kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'dataset': dataset,
|
||||
'dataset_id': str(dataset.id),
|
||||
'query': query,
|
||||
'search_method': 'hybrid_search',
|
||||
'embeddings': embeddings,
|
||||
|
@ -106,7 +106,7 @@ class DatasetRetrieverTool(BaseTool):
|
||||
if retrieval_model['search_method'] == 'semantic_search' or retrieval_model['search_method'] == 'hybrid_search':
|
||||
embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'dataset': dataset,
|
||||
'dataset_id': str(dataset.id),
|
||||
'query': query,
|
||||
'top_k': self.top_k,
|
||||
'score_threshold': retrieval_model['score_threshold'] if retrieval_model[
|
||||
@ -124,7 +124,7 @@ class DatasetRetrieverTool(BaseTool):
|
||||
if retrieval_model['search_method'] == 'full_text_search' or retrieval_model['search_method'] == 'hybrid_search':
|
||||
full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'dataset': dataset,
|
||||
'dataset_id': str(dataset.id),
|
||||
'query': query,
|
||||
'search_method': retrieval_model['search_method'],
|
||||
'embeddings': embeddings,
|
||||
|
@ -61,7 +61,7 @@ class HitTestingService:
|
||||
if retrieval_model['search_method'] == 'semantic_search' or retrieval_model['search_method'] == 'hybrid_search':
|
||||
embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'dataset': dataset,
|
||||
'dataset_id': str(dataset.id),
|
||||
'query': query,
|
||||
'top_k': retrieval_model['top_k'],
|
||||
'score_threshold': retrieval_model['score_threshold'] if retrieval_model['score_threshold_enable'] else None,
|
||||
@ -77,7 +77,7 @@ class HitTestingService:
|
||||
if retrieval_model['search_method'] == 'full_text_search' or retrieval_model['search_method'] == 'hybrid_search':
|
||||
full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'dataset': dataset,
|
||||
'dataset_id': str(dataset.id),
|
||||
'query': query,
|
||||
'search_method': retrieval_model['search_method'],
|
||||
'embeddings': embeddings,
|
||||
|
@ -4,6 +4,7 @@ from flask import current_app, Flask
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from core.index.vector_index.vector_index import VectorIndex
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset
|
||||
|
||||
default_retrieval_model = {
|
||||
@ -21,10 +22,13 @@ default_retrieval_model = {
|
||||
class RetrievalService:
|
||||
|
||||
@classmethod
|
||||
def embedding_search(cls, flask_app: Flask, dataset: Dataset, query: str,
|
||||
def embedding_search(cls, flask_app: Flask, dataset_id: str, query: str,
|
||||
top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
|
||||
all_documents: list, search_method: str, embeddings: Embeddings):
|
||||
with flask_app.app_context():
|
||||
dataset = db.session.query(Dataset).filter(
|
||||
Dataset.id == dataset_id
|
||||
).first()
|
||||
|
||||
vector_index = VectorIndex(
|
||||
dataset=dataset,
|
||||
@ -56,10 +60,13 @@ class RetrievalService:
|
||||
all_documents.extend(documents)
|
||||
|
||||
@classmethod
|
||||
def full_text_index_search(cls, flask_app: Flask, dataset: Dataset, query: str,
|
||||
def full_text_index_search(cls, flask_app: Flask, dataset_id: str, query: str,
|
||||
top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
|
||||
all_documents: list, search_method: str, embeddings: Embeddings):
|
||||
with flask_app.app_context():
|
||||
dataset = db.session.query(Dataset).filter(
|
||||
Dataset.id == dataset_id
|
||||
).first()
|
||||
|
||||
vector_index = VectorIndex(
|
||||
dataset=dataset,
|
||||
|
Loading…
Reference in New Issue
Block a user