From e4d175780e2eab98fafcc82aa6a42f73b265cba2 Mon Sep 17 00:00:00 2001 From: zxhlyh Date: Tue, 12 Nov 2024 14:38:24 +0800 Subject: [PATCH] fix: retrieval setting validate (#10454) --- .../configuration/dataset-config/index.tsx | 6 +- .../params-config/config-content.tsx | 2 +- .../dataset-config/params-config/index.tsx | 6 +- .../components/app/configuration/index.tsx | 11 +- .../nodes/knowledge-retrieval/default.ts | 13 +- .../nodes/knowledge-retrieval/types.ts | 2 + .../nodes/knowledge-retrieval/use-config.ts | 24 +++- .../nodes/knowledge-retrieval/utils.ts | 115 ++++++++++++------ 8 files changed, 130 insertions(+), 49 deletions(-) diff --git a/web/app/components/app/configuration/dataset-config/index.tsx b/web/app/components/app/configuration/dataset-config/index.tsx index 0d9d575c1..78b49f81d 100644 --- a/web/app/components/app/configuration/dataset-config/index.tsx +++ b/web/app/components/app/configuration/dataset-config/index.tsx @@ -47,12 +47,16 @@ const DatasetConfig: FC = () => { const { currentModel: currentRerankModel, + currentProvider: currentRerankProvider, } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) const onRemove = (id: string) => { const filteredDataSets = dataSet.filter(item => item.id !== id) setDataSet(filteredDataSets) - const retrievalConfig = getMultipleRetrievalConfig(datasetConfigs as any, filteredDataSets, dataSet, !!currentRerankModel) + const retrievalConfig = getMultipleRetrievalConfig(datasetConfigs as any, filteredDataSets, dataSet, { + provider: currentRerankProvider?.provider, + model: currentRerankModel?.model, + }) setDatasetConfigs({ ...(datasetConfigs as any), ...retrievalConfig, diff --git a/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx b/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx index 5bd748382..dcb2b1a3f 100644 --- a/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx +++ b/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx @@ -172,7 +172,7 @@ const ConfigContent: FC = ({ return false return datasetConfigs.reranking_enable - }, [canManuallyToggleRerank, datasetConfigs.reranking_enable]) + }, [canManuallyToggleRerank, datasetConfigs.reranking_enable, isRerankDefaultModelValid]) const handleDisabledSwitchClick = useCallback(() => { if (!currentRerankModel && !showRerankModel) diff --git a/web/app/components/app/configuration/dataset-config/params-config/index.tsx b/web/app/components/app/configuration/dataset-config/params-config/index.tsx index 94920fbd3..7f7a4799d 100644 --- a/web/app/components/app/configuration/dataset-config/params-config/index.tsx +++ b/web/app/components/app/configuration/dataset-config/params-config/index.tsx @@ -43,6 +43,7 @@ const ParamsConfig = ({ const { defaultModel: rerankDefaultModel, currentModel: isRerankDefaultModelValid, + currentProvider: rerankDefaultProvider, } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) const isValid = () => { @@ -91,7 +92,10 @@ const ParamsConfig = ({ reranking_mode: restConfigs.reranking_mode, weights: restConfigs.weights, reranking_enable: restConfigs.reranking_enable, - }, selectedDatasets, selectedDatasets, !!isRerankDefaultModelValid) + }, selectedDatasets, selectedDatasets, { + provider: rerankDefaultProvider?.provider, + model: isRerankDefaultModelValid?.model, + }) setTempDataSetConfigs({ ...retrievalConfig, diff --git a/web/app/components/app/configuration/index.tsx b/web/app/components/app/configuration/index.tsx index 2bb11a870..b5b7e98d4 100644 --- a/web/app/components/app/configuration/index.tsx +++ b/web/app/components/app/configuration/index.tsx @@ -226,6 +226,7 @@ const Configuration: FC = () => { const [rerankSettingModalOpen, setRerankSettingModalOpen] = useState(false) const { currentModel: currentRerankModel, + currentProvider: currentRerankProvider, } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) const handleSelect = (data: DataSet[]) => { if (isEqual(data.map(item => item.id), dataSets.map(item => item.id))) { @@ -279,7 +280,10 @@ const Configuration: FC = () => { reranking_mode: restConfigs.reranking_mode, weights: restConfigs.weights, reranking_enable: restConfigs.reranking_enable, - }, newDatasets, dataSets, !!currentRerankModel) + }, newDatasets, dataSets, { + provider: currentRerankProvider?.provider, + model: currentRerankModel?.model, + }) setDatasetConfigs({ ...retrievalConfig, @@ -620,7 +624,10 @@ const Configuration: FC = () => { syncToPublishedConfig(config) setPublishedConfig(config) - const retrievalConfig = getMultipleRetrievalConfig(modelConfig.dataset_configs, datasets, datasets, !!currentRerankModel) + const retrievalConfig = getMultipleRetrievalConfig(modelConfig.dataset_configs, datasets, datasets, { + provider: currentRerankProvider?.provider, + model: currentRerankModel?.model, + }) setDatasetConfigs({ retrieval_model: RETRIEVE_TYPE.multiWay, ...modelConfig.dataset_configs, diff --git a/web/app/components/workflow/nodes/knowledge-retrieval/default.ts b/web/app/components/workflow/nodes/knowledge-retrieval/default.ts index 03591dd52..e902d29b9 100644 --- a/web/app/components/workflow/nodes/knowledge-retrieval/default.ts +++ b/web/app/components/workflow/nodes/knowledge-retrieval/default.ts @@ -1,7 +1,7 @@ import { BlockEnum } from '../../types' import type { NodeDefault } from '../../types' import type { KnowledgeRetrievalNodeType } from './types' -import { RerankingModeEnum } from '@/models/datasets' +import { checkoutRerankModelConfigedInRetrievalSettings } from './utils' import { ALL_CHAT_AVAILABLE_BLOCKS, ALL_COMPLETION_AVAILABLE_BLOCKS } from '@/app/components/workflow/constants' import { DATASET_DEFAULT } from '@/config' import { RETRIEVE_TYPE } from '@/types/app' @@ -36,12 +36,17 @@ const nodeDefault: NodeDefault = { if (!errorMessages && (!payload.dataset_ids || payload.dataset_ids.length === 0)) errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { field: t(`${i18nPrefix}.nodes.knowledgeRetrieval.knowledge`) }) - if (!errorMessages && payload.retrieval_mode === RETRIEVE_TYPE.multiWay && payload.multiple_retrieval_config?.reranking_mode === RerankingModeEnum.RerankingModel && !payload.multiple_retrieval_config?.reranking_model?.provider && payload.multiple_retrieval_config?.reranking_enable) - errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { field: t(`${i18nPrefix}.errorMsg.fields.rerankModel`) }) - if (!errorMessages && payload.retrieval_mode === RETRIEVE_TYPE.oneWay && !payload.single_retrieval_config?.model?.provider) errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { field: t('common.modelProvider.systemReasoningModel.key') }) + const { _datasets, multiple_retrieval_config, retrieval_mode } = payload + if (retrieval_mode === RETRIEVE_TYPE.multiWay) { + const checked = checkoutRerankModelConfigedInRetrievalSettings(_datasets || [], multiple_retrieval_config) + + if (!errorMessages && !checked) + errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { field: t(`${i18nPrefix}.errorMsg.fields.rerankModel`) }) + } + return { isValid: !errorMessages, errorMessage: errorMessages, diff --git a/web/app/components/workflow/nodes/knowledge-retrieval/types.ts b/web/app/components/workflow/nodes/knowledge-retrieval/types.ts index da9373962..1b85bfc0b 100644 --- a/web/app/components/workflow/nodes/knowledge-retrieval/types.ts +++ b/web/app/components/workflow/nodes/knowledge-retrieval/types.ts @@ -1,6 +1,7 @@ import type { CommonNodeType, ModelConfig, ValueSelector } from '@/app/components/workflow/types' import type { RETRIEVE_TYPE } from '@/types/app' import type { + DataSet, RerankingModeEnum, } from '@/models/datasets' @@ -35,4 +36,5 @@ export type KnowledgeRetrievalNodeType = CommonNodeType & { retrieval_mode: RETRIEVE_TYPE multiple_retrieval_config?: MultipleRetrievalConfig single_retrieval_config?: SingleRetrievalConfig + _datasets?: DataSet[] } diff --git a/web/app/components/workflow/nodes/knowledge-retrieval/use-config.ts b/web/app/components/workflow/nodes/knowledge-retrieval/use-config.ts index 288a718aa..e90fe2c2f 100644 --- a/web/app/components/workflow/nodes/knowledge-retrieval/use-config.ts +++ b/web/app/components/workflow/nodes/knowledge-retrieval/use-config.ts @@ -67,6 +67,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => { const { currentModel: currentRerankModel, + currentProvider: currentRerankProvider, } = useCurrentProviderAndModel( rerankModelList, rerankDefaultModel @@ -163,7 +164,10 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => { draft.retrieval_mode = newMode if (newMode === RETRIEVE_TYPE.multiWay) { const multipleRetrievalConfig = draft.multiple_retrieval_config - draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, selectedDatasets, selectedDatasets, !!currentRerankModel) + draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, selectedDatasets, selectedDatasets, { + provider: currentRerankProvider?.provider, + model: currentRerankModel?.model, + }) } else { const hasSetModel = draft.single_retrieval_config?.model?.provider @@ -180,14 +184,17 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => { } }) setInputs(newInputs) - }, [currentModel?.model, currentModel?.model_properties?.mode, currentProvider?.provider, inputs, setInputs, selectedDatasets, currentRerankModel]) + }, [currentModel?.model, currentModel?.model_properties?.mode, currentProvider?.provider, inputs, setInputs, selectedDatasets, currentRerankModel, currentRerankProvider]) const handleMultipleRetrievalConfigChange = useCallback((newConfig: MultipleRetrievalConfig) => { const newInputs = produce(inputs, (draft) => { - draft.multiple_retrieval_config = getMultipleRetrievalConfig(newConfig!, selectedDatasets, selectedDatasets, !!currentRerankModel) + draft.multiple_retrieval_config = getMultipleRetrievalConfig(newConfig!, selectedDatasets, selectedDatasets, { + provider: currentRerankProvider?.provider, + model: currentRerankModel?.model, + }) }) setInputs(newInputs) - }, [inputs, setInputs, selectedDatasets, currentRerankModel]) + }, [inputs, setInputs, selectedDatasets, currentRerankModel, currentRerankProvider]) // datasets useEffect(() => { @@ -200,6 +207,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => { } const newInputs = produce(inputs, (draft) => { draft.dataset_ids = datasetIds + draft._datasets = selectedDatasets }) setInputs(newInputs) })() @@ -228,10 +236,14 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => { } = getSelectedDatasetsMode(newDatasets) const newInputs = produce(inputs, (draft) => { draft.dataset_ids = newDatasets.map(d => d.id) + draft._datasets = newDatasets if (payload.retrieval_mode === RETRIEVE_TYPE.multiWay && newDatasets.length > 0) { const multipleRetrievalConfig = draft.multiple_retrieval_config - draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, newDatasets, selectedDatasets, !!currentRerankModel) + draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, newDatasets, selectedDatasets, { + provider: currentRerankProvider?.provider, + model: currentRerankModel?.model, + }) } }) setInputs(newInputs) @@ -243,7 +255,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => { || allExternal ) setRerankModelOpen(true) - }, [inputs, setInputs, payload.retrieval_mode, selectedDatasets, currentRerankModel]) + }, [inputs, setInputs, payload.retrieval_mode, selectedDatasets, currentRerankModel, currentRerankProvider]) const filterVar = useCallback((varPayload: Var) => { return varPayload.type === VarType.string diff --git a/web/app/components/workflow/nodes/knowledge-retrieval/utils.ts b/web/app/components/workflow/nodes/knowledge-retrieval/utils.ts index fd3d3ebab..e9da9accc 100644 --- a/web/app/components/workflow/nodes/knowledge-retrieval/utils.ts +++ b/web/app/components/workflow/nodes/knowledge-retrieval/utils.ts @@ -94,9 +94,10 @@ export const getMultipleRetrievalConfig = ( multipleRetrievalConfig: MultipleRetrievalConfig, selectedDatasets: DataSet[], originalDatasets: DataSet[], - isValidRerankModel?: boolean, + validRerankModel?: { provider?: string; model?: string }, ) => { const shouldSetWeightDefaultValue = xorBy(selectedDatasets, originalDatasets, 'id').length > 0 + const rerankModelIsValid = validRerankModel?.provider && validRerankModel?.model const { allHighQuality, @@ -128,18 +129,10 @@ export const getMultipleRetrievalConfig = ( reranking_enable: ((allInternal && allEconomic) || allExternal) ? reranking_enable : true, } - if (allEconomic || mixtureHighQualityAndEconomic || inconsistentEmbeddingModel || allExternal || mixtureInternalAndExternal) - result.reranking_mode = RerankingModeEnum.RerankingModel - - if (allHighQuality && !inconsistentEmbeddingModel && reranking_mode === undefined && allInternal) - result.reranking_mode = RerankingModeEnum.WeightedScore - - if (allHighQuality && !inconsistentEmbeddingModel && (reranking_mode === RerankingModeEnum.WeightedScore || reranking_mode === undefined) && allInternal && !weights) { - if (!isValidRerankModel) - result.reranking_mode = RerankingModeEnum.WeightedScore - else - result.reranking_mode = RerankingModeEnum.RerankingModel + if (!rerankModelIsValid) + result.reranking_model = undefined + const setDefaultWeights = () => { result.weights = { vector_setting: { vector_weight: allHighQualityVectorSearch @@ -160,31 +153,85 @@ export const getMultipleRetrievalConfig = ( } } - if (shouldSetWeightDefaultValue && allHighQuality && !inconsistentEmbeddingModel && (reranking_mode === RerankingModeEnum.WeightedScore || reranking_mode === undefined || !isValidRerankModel) && allInternal && weights) { - if (!isValidRerankModel) - result.reranking_mode = RerankingModeEnum.WeightedScore - else - result.reranking_mode = RerankingModeEnum.RerankingModel + if (allEconomic || mixtureHighQualityAndEconomic || inconsistentEmbeddingModel || allExternal || mixtureInternalAndExternal) { + result.reranking_mode = RerankingModeEnum.RerankingModel - result.weights = { - vector_setting: { - vector_weight: allHighQualityVectorSearch - ? DEFAULT_WEIGHTED_SCORE.allHighQualityVectorSearch.semantic - : allHighQualityFullTextSearch - ? DEFAULT_WEIGHTED_SCORE.allHighQualityFullTextSearch.semantic - : DEFAULT_WEIGHTED_SCORE.other.semantic, - embedding_provider_name: selectedDatasets[0].embedding_model_provider, - embedding_model_name: selectedDatasets[0].embedding_model, - }, - keyword_setting: { - keyword_weight: allHighQualityVectorSearch - ? DEFAULT_WEIGHTED_SCORE.allHighQualityVectorSearch.keyword - : allHighQualityFullTextSearch - ? DEFAULT_WEIGHTED_SCORE.allHighQualityFullTextSearch.keyword - : DEFAULT_WEIGHTED_SCORE.other.keyword, - }, + if (rerankModelIsValid) { + result.reranking_mode = RerankingModeEnum.RerankingModel + result.reranking_model = { + provider: validRerankModel?.provider || '', + model: validRerankModel?.model || '', + } + } + else { + result.reranking_model = undefined + } + } + + if (allHighQuality && !inconsistentEmbeddingModel && allInternal) { + if (!reranking_mode) { + if (validRerankModel?.provider && validRerankModel?.model) { + result.reranking_mode = RerankingModeEnum.RerankingModel + result.reranking_model = { + provider: validRerankModel.provider, + model: validRerankModel.model, + } + } + else { + result.reranking_mode = RerankingModeEnum.WeightedScore + setDefaultWeights() + } + } + + if (reranking_mode === RerankingModeEnum.WeightedScore && !weights) + setDefaultWeights() + + if (reranking_mode === RerankingModeEnum.WeightedScore && weights && shouldSetWeightDefaultValue) { + if (rerankModelIsValid) { + result.reranking_mode = RerankingModeEnum.RerankingModel + result.reranking_model = { + provider: validRerankModel.provider || '', + model: validRerankModel.model || '', + } + } + else { + setDefaultWeights() + } + } + + if (reranking_mode === RerankingModeEnum.RerankingModel && !rerankModelIsValid && shouldSetWeightDefaultValue) { + result.reranking_mode = RerankingModeEnum.WeightedScore + setDefaultWeights() } } return result } + +export const checkoutRerankModelConfigedInRetrievalSettings = ( + datasets: DataSet[], + multipleRetrievalConfig?: MultipleRetrievalConfig, +) => { + if (!multipleRetrievalConfig) + return true + + const { + allEconomic, + allExternal, + } = getSelectedDatasetsMode(datasets) + + const { + reranking_enable, + reranking_mode, + reranking_model, + } = multipleRetrievalConfig + + if (reranking_mode === RerankingModeEnum.RerankingModel && (!reranking_model?.provider || !reranking_model?.model)) { + if ((allEconomic || allExternal) && !reranking_enable) + return true + + return false + } + + return true +}