mirror of
https://gitee.com/dify_ai/dify.git
synced 2024-12-02 11:18:19 +08:00
fix score_threshold_enabled name (#1626)
Co-authored-by: jyong <jyong@dify.ai>
This commit is contained in:
parent
603e55f252
commit
74b2260ba6
@ -40,7 +40,7 @@ default_retrieval_model = {
|
||||
'reranking_model_name': ''
|
||||
},
|
||||
'top_k': 2,
|
||||
'score_threshold_enable': False
|
||||
'score_threshold_enabled': False
|
||||
}
|
||||
|
||||
class OrchestratorRuleParser:
|
||||
@ -220,8 +220,8 @@ class OrchestratorRuleParser:
|
||||
# top_k = self._dynamic_calc_retrieve_k(dataset=dataset, top_k=top_k, rest_tokens=rest_tokens)
|
||||
|
||||
score_threshold = None
|
||||
score_threshold_enable = retrieval_model_config.get("score_threshold_enable")
|
||||
if score_threshold_enable:
|
||||
score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
|
||||
if score_threshold_enabled:
|
||||
score_threshold = retrieval_model_config.get("score_threshold")
|
||||
|
||||
tool = DatasetRetrieverTool.from_dataset(
|
||||
@ -239,7 +239,7 @@ class OrchestratorRuleParser:
|
||||
dataset_ids=dataset_ids,
|
||||
tenant_id=kwargs['tenant_id'],
|
||||
top_k=dataset_configs.get('top_k', 2),
|
||||
score_threshold=dataset_configs.get('score_threshold', 0.5) if dataset_configs.get('score_threshold_enable', False) else None,
|
||||
score_threshold=dataset_configs.get('score_threshold', 0.5) if dataset_configs.get('score_threshold_enabled', False) else None,
|
||||
callbacks=[DatasetToolCallbackHandler(conversation_message_task)],
|
||||
conversation_message_task=conversation_message_task,
|
||||
return_resource=return_resource,
|
||||
|
@ -24,7 +24,7 @@ default_retrieval_model = {
|
||||
'reranking_model_name': ''
|
||||
},
|
||||
'top_k': 2,
|
||||
'score_threshold_enable': False
|
||||
'score_threshold_enabled': False
|
||||
}
|
||||
|
||||
|
||||
@ -216,7 +216,7 @@ class DatasetMultiRetrieverTool(BaseTool):
|
||||
'embeddings': embeddings,
|
||||
'score_threshold': retrieval_model[
|
||||
'score_threshold'] if retrieval_model[
|
||||
'score_threshold_enable'] else None,
|
||||
'score_threshold_enabled'] else None,
|
||||
'top_k': self.top_k,
|
||||
'reranking_model': retrieval_model[
|
||||
'reranking_model'] if retrieval_model[
|
||||
|
@ -25,7 +25,7 @@ default_retrieval_model = {
|
||||
'reranking_model_name': ''
|
||||
},
|
||||
'top_k': 2,
|
||||
'score_threshold_enable': False
|
||||
'score_threshold_enabled': False
|
||||
}
|
||||
|
||||
|
||||
@ -110,7 +110,7 @@ class DatasetRetrieverTool(BaseTool):
|
||||
'query': query,
|
||||
'top_k': self.top_k,
|
||||
'score_threshold': retrieval_model['score_threshold'] if retrieval_model[
|
||||
'score_threshold_enable'] else None,
|
||||
'score_threshold_enabled'] else None,
|
||||
'reranking_model': retrieval_model['reranking_model'] if retrieval_model[
|
||||
'reranking_enable'] else None,
|
||||
'all_documents': documents,
|
||||
@ -129,7 +129,7 @@ class DatasetRetrieverTool(BaseTool):
|
||||
'search_method': retrieval_model['search_method'],
|
||||
'embeddings': embeddings,
|
||||
'score_threshold': retrieval_model['score_threshold'] if retrieval_model[
|
||||
'score_threshold_enable'] else None,
|
||||
'score_threshold_enabled'] else None,
|
||||
'top_k': self.top_k,
|
||||
'reranking_model': retrieval_model['reranking_model'] if retrieval_model[
|
||||
'reranking_enable'] else None,
|
||||
@ -148,7 +148,7 @@ class DatasetRetrieverTool(BaseTool):
|
||||
model_name=retrieval_model['reranking_model']['reranking_model_name']
|
||||
)
|
||||
documents = hybrid_rerank.rerank(query, documents,
|
||||
retrieval_model['score_threshold'] if retrieval_model['score_threshold_enable'] else None,
|
||||
retrieval_model['score_threshold'] if retrieval_model['score_threshold_enabled'] else None,
|
||||
self.top_k)
|
||||
else:
|
||||
documents = []
|
||||
|
@ -22,7 +22,7 @@ dataset_retrieval_model_fields = {
|
||||
'reranking_enable': fields.Boolean,
|
||||
'reranking_model': fields.Nested(reranking_model_fields),
|
||||
'top_k': fields.Integer,
|
||||
'score_threshold_enable': fields.Boolean,
|
||||
'score_threshold_enabled': fields.Boolean,
|
||||
'score_threshold': fields.Float
|
||||
}
|
||||
|
||||
|
@ -104,7 +104,7 @@ class Dataset(db.Model):
|
||||
'reranking_model_name': ''
|
||||
},
|
||||
'top_k': 2,
|
||||
'score_threshold_enable': False
|
||||
'score_threshold_enabled': False
|
||||
}
|
||||
return self.retrieval_model if self.retrieval_model else default_retrieval_model
|
||||
|
||||
|
@ -485,7 +485,7 @@ class DocumentService:
|
||||
'reranking_model_name': ''
|
||||
},
|
||||
'top_k': 2,
|
||||
'score_threshold_enable': False
|
||||
'score_threshold_enabled': False
|
||||
}
|
||||
|
||||
dataset.retrieval_model = document_data.get('retrieval_model') if document_data.get('retrieval_model') else default_retrieval_model
|
||||
@ -769,7 +769,7 @@ class DocumentService:
|
||||
'reranking_model_name': ''
|
||||
},
|
||||
'top_k': 2,
|
||||
'score_threshold_enable': False
|
||||
'score_threshold_enabled': False
|
||||
}
|
||||
retrieval_model = default_retrieval_model
|
||||
# save dataset
|
||||
|
@ -25,7 +25,7 @@ default_retrieval_model = {
|
||||
'reranking_model_name': ''
|
||||
},
|
||||
'top_k': 2,
|
||||
'score_threshold_enable': False
|
||||
'score_threshold_enabled': False
|
||||
}
|
||||
|
||||
class HitTestingService:
|
||||
@ -64,7 +64,7 @@ class HitTestingService:
|
||||
'dataset_id': str(dataset.id),
|
||||
'query': query,
|
||||
'top_k': retrieval_model['top_k'],
|
||||
'score_threshold': retrieval_model['score_threshold'] if retrieval_model['score_threshold_enable'] else None,
|
||||
'score_threshold': retrieval_model['score_threshold'] if retrieval_model['score_threshold_enabled'] else None,
|
||||
'reranking_model': retrieval_model['reranking_model'] if retrieval_model['reranking_enable'] else None,
|
||||
'all_documents': all_documents,
|
||||
'search_method': retrieval_model['search_method'],
|
||||
@ -81,7 +81,7 @@ class HitTestingService:
|
||||
'query': query,
|
||||
'search_method': retrieval_model['search_method'],
|
||||
'embeddings': embeddings,
|
||||
'score_threshold': retrieval_model['score_threshold'] if retrieval_model['score_threshold_enable'] else None,
|
||||
'score_threshold': retrieval_model['score_threshold'] if retrieval_model['score_threshold_enabled'] else None,
|
||||
'top_k': retrieval_model['top_k'],
|
||||
'reranking_model': retrieval_model['reranking_model'] if retrieval_model['reranking_enable'] else None,
|
||||
'all_documents': all_documents
|
||||
@ -99,7 +99,7 @@ class HitTestingService:
|
||||
model_name=retrieval_model['reranking_model']['reranking_model_name']
|
||||
)
|
||||
all_documents = hybrid_rerank.rerank(query, all_documents,
|
||||
retrieval_model['score_threshold'] if retrieval_model['score_threshold_enable'] else None,
|
||||
retrieval_model['score_threshold'] if retrieval_model['score_threshold_enabled'] else None,
|
||||
retrieval_model['top_k'])
|
||||
|
||||
end = time.perf_counter()
|
||||
|
@ -15,7 +15,7 @@ default_retrieval_model = {
|
||||
'reranking_model_name': ''
|
||||
},
|
||||
'top_k': 2,
|
||||
'score_threshold_enable': False
|
||||
'score_threshold_enabled': False
|
||||
}
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user