mirror of
https://gitee.com/dify_ai/dify.git
synced 2024-12-12 12:25:09 +08:00
7753ba2d37
Co-authored-by: Joel <iamjoel007@gmail.com> Co-authored-by: Yeuoly <admin@srmxy.cn> Co-authored-by: JzoNg <jzongcode@gmail.com> Co-authored-by: StyleZhang <jasonapring2015@outlook.com> Co-authored-by: jyong <jyong@dify.ai> Co-authored-by: nite-knite <nkCoding@gmail.com> Co-authored-by: jyong <718720800@qq.com>
84 lines
3.8 KiB
Python
84 lines
3.8 KiB
Python
from typing import Optional, cast
|
|
|
|
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
|
from core.memory.token_buffer_memory import TokenBufferMemory
|
|
from core.model_runtime.entities.message_entities import PromptMessage
|
|
from core.model_runtime.entities.model_entities import ModelPropertyKey
|
|
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
|
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
|
|
|
|
|
class PromptTransform:
    """Base helpers for prompt transforms that splice conversation memory into prompts.

    The helpers size the remaining token budget for a model and pull chat
    history out of a ``TokenBufferMemory``, either as a list of prompt messages
    or as a single rendered string.
    """

    def _append_chat_histories(self, memory: TokenBufferMemory,
                               memory_config: MemoryConfig,
                               prompt_messages: list[PromptMessage],
                               model_config: ModelConfigWithCredentialsEntity) -> list[PromptMessage]:
        """Append as much chat history as fits the remaining token budget.

        :param memory: conversation memory to read history from
        :param memory_config: memory settings; its ``window`` controls the message count
        :param prompt_messages: prompt built so far; it is extended in place
        :param model_config: model configuration used to size the token budget
        :return: the same ``prompt_messages`` list object, with history appended
        """
        budget = self._calculate_rest_token(prompt_messages, model_config)
        prompt_messages.extend(
            self._get_history_messages_list_from_memory(memory, memory_config, budget)
        )
        return prompt_messages

    def _calculate_rest_token(self, prompt_messages: list[PromptMessage],
                              model_config: ModelConfigWithCredentialsEntity) -> int:
        """Return how many tokens are left over for chat history.

        The budget is the model's context size, minus the tokens reserved for
        the completion (the ``max_tokens`` parameter), minus the tokens already
        used by ``prompt_messages``. It is floored at 0. When the model schema
        reports no context size (or a zero one), a fixed budget of 2000 is
        returned instead.

        :param prompt_messages: messages already in the prompt
        :param model_config: model configuration carrying schema, credentials and parameters
        :return: remaining token budget, never negative
        """
        context_size = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
        if not context_size:
            # Unknown context window: fall back to a fixed history budget.
            return 2000

        llm = cast(LargeLanguageModel, model_config.provider_model_bundle.model_type_instance)
        used_tokens = llm.get_num_tokens(
            model_config.model,
            model_config.credentials,
            prompt_messages,
        )

        # Tokens reserved for the completion. A rule counts when it is literally
        # named 'max_tokens' or is derived from the 'max_tokens' template.
        # If several rules qualify, the last one wins.
        completion_reserve = 0
        for rule in model_config.model_schema.parameter_rules:
            if 'max_tokens' not in (rule.name, rule.use_template):
                continue
            # Prefer the value stored under the rule's own name, then under its template name.
            params = model_config.parameters
            completion_reserve = params.get(rule.name) or params.get(rule.use_template) or 0

        return max(context_size - completion_reserve - used_tokens, 0)

    def _get_history_messages_from_memory(self, memory: TokenBufferMemory,
                                          memory_config: MemoryConfig,
                                          max_token_limit: int,
                                          human_prefix: Optional[str] = None,
                                          ai_prefix: Optional[str] = None) -> str:
        """Render conversation history from ``memory`` as one prompt string.

        :param memory: conversation memory to read history from
        :param memory_config: memory settings; an enabled window with a positive
            size is forwarded as ``message_limit``
        :param max_token_limit: token cap forwarded to the memory
        :param human_prefix: speaker label for user turns; forwarded only if truthy
        :param ai_prefix: speaker label for assistant turns; forwarded only if truthy
        :return: whatever ``memory.get_history_prompt_text`` returns
        """
        options: dict = {"max_token_limit": max_token_limit}
        # Empty or missing prefixes are not forwarded, so the memory's own defaults apply.
        options.update({
            key: value
            for key, value in (("human_prefix", human_prefix), ("ai_prefix", ai_prefix))
            if value
        })

        window = memory_config.window
        if window.enabled and window.size is not None and window.size > 0:
            options["message_limit"] = window.size

        return memory.get_history_prompt_text(**options)

    def _get_history_messages_list_from_memory(self, memory: TokenBufferMemory,
                                               memory_config: MemoryConfig,
                                               max_token_limit: int) -> list[PromptMessage]:
        """Fetch conversation history from ``memory`` as a list of prompt messages.

        :param memory: conversation memory to read history from
        :param memory_config: memory settings; an enabled window with a positive
            size sets the message limit
        :param max_token_limit: token cap forwarded to the memory
        :return: whatever ``memory.get_history_prompt_messages`` returns
        """
        window = memory_config.window
        use_window = window.enabled and window.size is not None and window.size > 0
        # NOTE(review): when the window is disabled, a fixed limit of 10 messages is
        # passed here, whereas _get_history_messages_from_memory omits message_limit
        # in that case. Confirm this asymmetry is intended.
        message_limit = window.size if use_window else 10

        return memory.get_history_prompt_messages(
            max_token_limit=max_token_limit,
            message_limit=message_limit,
        )