# dify/api/core/prompt/prompt_transform.py

import enum
2023-10-18 20:02:52 +08:00
import json
import os
import re
from typing import Optional, cast
from core.entities.application_entities import (
AdvancedCompletionPromptTemplateEntity,
ModelConfigEntity,
PromptTemplateEntity,
)
from core.file.file_obj import FileObj
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
PromptMessageRole,
SystemPromptMessage,
TextPromptMessageContent,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
2023-10-18 20:02:52 +08:00
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import PromptTemplateParser
2023-10-18 20:02:52 +08:00
class AppMode(enum.Enum):
    """Application mode a prompt is being built for (see ``PromptTransform.get_prompt``)."""

    COMPLETION = 'completion'
    CHAT = 'chat'

    @classmethod
    def value_of(cls, value: str) -> 'AppMode':
        """
        Look up the member whose value equals ``value``.

        :param value: raw mode string, e.g. ``'chat'``
        :return: the matching member
        :raises ValueError: if no member has that value
        """
        matched = next((member for member in cls if member.value == value), None)
        if matched is None:
            raise ValueError(f'invalid mode value {value}')
        return matched
class ModelMode(enum.Enum):
    """Mode of the target model, parsed from ``model_config.mode``."""

    COMPLETION = 'completion'
    CHAT = 'chat'

    @classmethod
    def value_of(cls, value: str) -> 'ModelMode':
        """
        Look up the member whose value equals ``value``.

        :param value: raw mode string, e.g. ``'completion'``
        :return: the matching member
        :raises ValueError: if no member has that value
        """
        matched = next((member for member in cls if member.value == value), None)
        if matched is None:
            raise ValueError(f'invalid mode value {value}')
        return matched
class PromptTransform:
    """Build the prompt messages sent to an LLM from simple or advanced prompt templates."""

    # Special-token markup of the form ``<|...|>`` (non-greedy, single line) that is
    # stripped from rendered prompt text.
    _SPECIAL_TOKEN_PATTERN = re.compile(r'<\|.*?\|>')

    def get_prompt(self,
                   app_mode: str,
                   prompt_template_entity: PromptTemplateEntity,
                   inputs: dict,
                   query: str,
                   files: list[FileObj],
                   context: Optional[str],
                   memory: Optional[TokenBufferMemory],
                   model_config: ModelConfigEntity) -> tuple[list[PromptMessage], Optional[list[str]]]:
        """
        Build prompt messages for an app that uses the simple prompt template.

        The layout rules are loaded from a JSON file under ``generate_prompts`` chosen by
        app mode, provider and model name (see ``_prompt_file_name``).

        :param app_mode: app mode value, ``'chat'`` or ``'completion'``
        :param prompt_template_entity: prompt config; its ``simple_prompt_template`` is used
        :param inputs: user variables substituted into the pre-prompt
        :param query: the end user's query
        :param files: files attached to the query
        :param context: retrieved context text, if any
        :param memory: conversation memory, if any
        :param model_config: model configuration (mode, provider, model, schema, parameters)
        :return: ``(prompt_messages, stops)``; ``stops`` is ``None`` for a chat app on a chat
            model, or when the rules define no (or an empty list of) stop words
        :raises ValueError: if ``app_mode`` or ``model_config.mode`` is not a known mode
        """
        app_mode = AppMode.value_of(app_mode)
        model_mode = ModelMode.value_of(model_config.mode)

        prompt_rules = self._read_prompt_rules_from_file(self._prompt_file_name(
            app_mode=app_mode,
            provider=model_config.provider,
            model=model_config.model
        ))

        if app_mode == AppMode.CHAT and model_mode == ModelMode.CHAT:
            # Stop words from the rules are not applied for chat apps on chat models.
            stops = None

            prompt_messages = self._get_simple_chat_app_chat_model_prompt_messages(
                prompt_rules=prompt_rules,
                pre_prompt=prompt_template_entity.simple_prompt_template,
                inputs=inputs,
                query=query,
                files=files,
                context=context,
                memory=memory,
                model_config=model_config
            )
        else:
            # A missing or empty stop-word list is reported as None.
            stops = prompt_rules.get('stops') or None

            prompt_messages = self._get_simple_others_prompt_messages(
                prompt_rules=prompt_rules,
                pre_prompt=prompt_template_entity.simple_prompt_template,
                inputs=inputs,
                query=query,
                files=files,
                context=context,
                memory=memory,
                model_config=model_config
            )

        return prompt_messages, stops

    def get_advanced_prompt(self, app_mode: str,
                            prompt_template_entity: PromptTemplateEntity,
                            inputs: dict,
                            query: str,
                            files: list[FileObj],
                            context: Optional[str],
                            memory: Optional[TokenBufferMemory],
                            model_config: ModelConfigEntity) -> list[PromptMessage]:
        """
        Build prompt messages for an app that uses an advanced (user-authored) template.

        Chat apps use ``query`` and ``memory``; completion apps do not. A completion app on
        a completion model also ignores ``files``.

        :param app_mode: app mode value, ``'chat'`` or ``'completion'``
        :param prompt_template_entity: prompt config holding the advanced completion or
            chat template
        :param inputs: user variables substituted into the template
        :param query: the end user's query
        :param files: files attached to the query
        :param context: retrieved context text, if any
        :param memory: conversation memory, if any
        :param model_config: model configuration
        :return: the prompt messages (empty list if no branch matches)
        :raises ValueError: if ``app_mode`` or ``model_config.mode`` is not a known mode
        """
        app_mode = AppMode.value_of(app_mode)
        model_mode = ModelMode.value_of(model_config.mode)

        prompt_messages = []

        if app_mode == AppMode.CHAT:
            if model_mode == ModelMode.COMPLETION:
                prompt_messages = self._get_chat_app_completion_model_prompt_messages(
                    prompt_template_entity=prompt_template_entity,
                    inputs=inputs,
                    query=query,
                    files=files,
                    context=context,
                    memory=memory,
                    model_config=model_config
                )
            elif model_mode == ModelMode.CHAT:
                prompt_messages = self._get_chat_app_chat_model_prompt_messages(
                    prompt_template_entity=prompt_template_entity,
                    inputs=inputs,
                    query=query,
                    files=files,
                    context=context,
                    memory=memory,
                    model_config=model_config
                )
        elif app_mode == AppMode.COMPLETION:
            if model_mode == ModelMode.CHAT:
                prompt_messages = self._get_completion_app_chat_model_prompt_messages(
                    prompt_template_entity=prompt_template_entity,
                    inputs=inputs,
                    files=files,
                    context=context,
                )
            elif model_mode == ModelMode.COMPLETION:
                prompt_messages = self._get_completion_app_completion_model_prompt_messages(
                    prompt_template_entity=prompt_template_entity,
                    inputs=inputs,
                    context=context,
                )

        return prompt_messages

    def _get_history_messages_from_memory(self, memory: TokenBufferMemory,
                                          max_token_limit: int,
                                          human_prefix: Optional[str] = None,
                                          ai_prefix: Optional[str] = None) -> str:
        """
        Render the conversation history as text limited to ``max_token_limit`` tokens.

        Prefixes are forwarded only when non-empty (presumably so the memory's own defaults
        apply otherwise).
        """
        kwargs = {'max_token_limit': max_token_limit}
        if human_prefix:
            kwargs['human_prefix'] = human_prefix
        if ai_prefix:
            kwargs['ai_prefix'] = ai_prefix

        return memory.get_history_prompt_text(**kwargs)

    def _get_history_messages_list_from_memory(self, memory: TokenBufferMemory,
                                               max_token_limit: int) -> list[PromptMessage]:
        """Return the conversation history as prompt messages limited to ``max_token_limit`` tokens."""
        return memory.get_history_prompt_messages(max_token_limit=max_token_limit)

    def _prompt_file_name(self, app_mode: AppMode, provider: str, model: str) -> str:
        """
        Choose the prompt-rules file name (without extension).

        Baichuan models (the ``baichuan`` provider, or a model name containing ``baichuan``
        served by huggingface_hub / openllm / xinference) use the ``baichuan_*`` rules;
        everything else uses ``common_*``. The suffix is ``completion`` or ``chat`` by app mode.
        """
        if provider == 'baichuan':
            return self._prompt_file_name_for_baichuan(app_mode)

        baichuan_supported_providers = ["huggingface_hub", "openllm", "xinference"]
        if provider in baichuan_supported_providers and 'baichuan' in model.lower():
            return self._prompt_file_name_for_baichuan(app_mode)

        if app_mode == AppMode.COMPLETION:
            return 'common_completion'
        return 'common_chat'

    def _prompt_file_name_for_baichuan(self, app_mode: AppMode) -> str:
        """Return the Baichuan prompt-rules file name for ``app_mode``."""
        if app_mode == AppMode.COMPLETION:
            return 'baichuan_completion'
        return 'baichuan_chat'

    def _read_prompt_rules_from_file(self, prompt_name: str) -> dict:
        """Load ``generate_prompts/<prompt_name>.json`` located next to this module."""
        prompt_path = os.path.join(
            os.path.dirname(os.path.realpath(__file__)),
            'generate_prompts')

        json_file_path = os.path.join(prompt_path, f'{prompt_name}.json')
        with open(json_file_path, encoding='utf-8') as json_file:
            return json.load(json_file)

    def _render_simple_prompt_parts(self, prompt_rules: dict,
                                    pre_prompt: str,
                                    inputs: dict,
                                    context: Optional[str]) -> tuple[str, str]:
        """
        Render the context prompt and the pre-prompt of a simple template.

        :return: ``(context_prompt_content, pre_prompt_content)``. The context part is ``''``
            unless ``context`` is non-empty and the rules define ``context_prompt``; the
            pre-prompt part is ``''`` when ``pre_prompt`` is empty.
        """
        context_prompt_content = ''
        if context and 'context_prompt' in prompt_rules:
            prompt_template = PromptTemplateParser(template=prompt_rules['context_prompt'])
            context_prompt_content = prompt_template.format({'context': context})

        pre_prompt_content = ''
        if pre_prompt:
            prompt_template = PromptTemplateParser(template=pre_prompt)
            # Pass only the inputs the template actually references.
            prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
            pre_prompt_content = prompt_template.format(prompt_inputs)

        return context_prompt_content, pre_prompt_content

    @staticmethod
    def _build_user_prompt_message(text: str, files: list[FileObj]) -> UserPromptMessage:
        """
        Wrap ``text`` in a user message.

        When ``files`` is non-empty, the content becomes a list: the text part first,
        followed by each file's ``prompt_message_content``.
        """
        if not files:
            return UserPromptMessage(content=text)

        prompt_message_contents = [TextPromptMessageContent(data=text)]
        prompt_message_contents.extend(file.prompt_message_content for file in files)
        return UserPromptMessage(content=prompt_message_contents)

    def _get_simple_chat_app_chat_model_prompt_messages(self, prompt_rules: dict,
                                                        pre_prompt: str,
                                                        inputs: dict,
                                                        query: str,
                                                        context: Optional[str],
                                                        files: list[FileObj],
                                                        memory: Optional[TokenBufferMemory],
                                                        model_config: ModelConfigEntity) -> list[PromptMessage]:
        """
        Build messages for a chat app on a chat model from the simple rules.

        The messages are, in order:
        1. An optional system message (context and pre-prompt in ``system_prompt_orders``
           order, with special tokens stripped).
        2. The chat history that fits the remaining token budget.
        3. The user query, with files attached.
        """
        prompt_messages = []

        context_prompt_content, pre_prompt_content = self._render_simple_prompt_parts(
            prompt_rules=prompt_rules,
            pre_prompt=pre_prompt,
            inputs=inputs,
            context=context
        )

        prompt = ''
        for order in prompt_rules['system_prompt_orders']:
            if order == 'context_prompt':
                prompt += context_prompt_content
            elif order == 'pre_prompt':
                prompt += pre_prompt_content

        prompt = self._SPECIAL_TOKEN_PATTERN.sub('', prompt)

        if prompt:
            prompt_messages.append(SystemPromptMessage(content=prompt))

        self._append_chat_histories(
            memory=memory,
            prompt_messages=prompt_messages,
            model_config=model_config
        )

        prompt_messages.append(self._build_user_prompt_message(query, files))

        return prompt_messages

    def _get_simple_others_prompt_messages(self, prompt_rules: dict,
                                           pre_prompt: str,
                                           inputs: dict,
                                           query: str,
                                           context: Optional[str],
                                           memory: Optional[TokenBufferMemory],
                                           files: list[FileObj],
                                           model_config: ModelConfigEntity) -> list[PromptMessage]:
        """
        Build a single user message from the simple rules.

        ``get_prompt`` uses this for every case except a chat app on a chat model.

        The text is assembled in this order:
        1. Context and pre-prompt, in ``system_prompt_orders`` order.
        2. The history, when memory is given and the rules define ``histories_prompt``.
        3. The rules' ``query_prompt`` (default ``{{query}}``) filled with ``query``.

        Special tokens are stripped from the result, and files are attached to the message.
        """
        context_prompt_content, pre_prompt_content = self._render_simple_prompt_parts(
            prompt_rules=prompt_rules,
            pre_prompt=pre_prompt,
            inputs=inputs,
            context=context
        )

        prompt = ''
        for order in prompt_rules['system_prompt_orders']:
            if order == 'context_prompt':
                prompt += context_prompt_content
            elif order == 'pre_prompt':
                prompt += pre_prompt_content

        query_prompt = prompt_rules.get('query_prompt', '{{query}}')

        if memory and 'histories_prompt' in prompt_rules:
            # Measure the prompt without history to find how many tokens history may use.
            tmp_human_message = UserPromptMessage(
                content=PromptBuilder.parse_prompt(
                    prompt=prompt + query_prompt,
                    inputs={'query': query}
                )
            )
            rest_tokens = self._calculate_rest_token([tmp_human_message], model_config)

            # Human turns get the rules' human prefix and AI turns the assistant prefix
            # (these were previously swapped).
            histories = self._get_history_messages_from_memory(
                memory=memory,
                max_token_limit=rest_tokens,
                human_prefix=prompt_rules.get('human_prefix', 'Human'),
                ai_prefix=prompt_rules.get('assistant_prefix', 'Assistant')
            )
            prompt_template = PromptTemplateParser(template=prompt_rules['histories_prompt'])
            histories_prompt_content = prompt_template.format({'histories': histories})

            # Rebuild the prompt so the history sits where the rules place it; in this
            # layout a non-empty pre-prompt is followed by a newline.
            prompt = ''
            for order in prompt_rules['system_prompt_orders']:
                if order == 'context_prompt':
                    prompt += context_prompt_content
                elif order == 'pre_prompt':
                    prompt += (pre_prompt_content + '\n') if pre_prompt_content else ''
                elif order == 'histories_prompt':
                    prompt += histories_prompt_content

        prompt_template = PromptTemplateParser(template=query_prompt)
        query_prompt_content = prompt_template.format({'query': query})

        prompt += query_prompt_content

        prompt = self._SPECIAL_TOKEN_PATTERN.sub('', prompt)

        # Chat and completion models receive the same single user message here.
        return [self._build_user_prompt_message(prompt, files)]

    def _set_context_variable(self, context: Optional[str], prompt_template: PromptTemplateParser,
                              prompt_inputs: dict) -> None:
        """If the template uses ``#context#``, set it in ``prompt_inputs`` (``''`` when no context)."""
        if '#context#' in prompt_template.variable_keys:
            prompt_inputs['#context#'] = context or ''

    def _set_query_variable(self, query: str, prompt_template: PromptTemplateParser,
                            prompt_inputs: dict) -> None:
        """If the template uses ``#query#``, set it in ``prompt_inputs`` (``''`` when no query)."""
        if '#query#' in prompt_template.variable_keys:
            prompt_inputs['#query#'] = query or ''

    def _set_histories_variable(self, memory: Optional[TokenBufferMemory],
                                raw_prompt: str,
                                role_prefix: AdvancedCompletionPromptTemplateEntity.RolePrefixEntity,
                                prompt_template: PromptTemplateParser,
                                prompt_inputs: dict,
                                model_config: ModelConfigEntity) -> None:
        """
        If the template uses ``#histories#``, fill it with history text.

        The history is limited to the tokens left after rendering ``raw_prompt`` with the
        current ``prompt_inputs`` and an empty history. The value is ``''`` when there is
        no memory.
        """
        if '#histories#' not in prompt_template.variable_keys:
            return

        if not memory:
            prompt_inputs['#histories#'] = ''
            return

        tmp_human_message = UserPromptMessage(
            content=PromptBuilder.parse_prompt(
                prompt=raw_prompt,
                inputs={'#histories#': '', **prompt_inputs}
            )
        )
        rest_tokens = self._calculate_rest_token([tmp_human_message], model_config)

        prompt_inputs['#histories#'] = self._get_history_messages_from_memory(
            memory=memory,
            max_token_limit=rest_tokens,
            human_prefix=role_prefix.user,
            ai_prefix=role_prefix.assistant
        )

    def _append_chat_histories(self, memory: Optional[TokenBufferMemory],
                               prompt_messages: list[PromptMessage],
                               model_config: ModelConfigEntity) -> None:
        """Extend ``prompt_messages`` in place with history messages that fit the remaining token budget."""
        if memory:
            rest_tokens = self._calculate_rest_token(prompt_messages, model_config)
            histories = self._get_history_messages_list_from_memory(memory, rest_tokens)
            prompt_messages.extend(histories)

    def _calculate_rest_token(self, prompt_messages: list[PromptMessage], model_config: ModelConfigEntity) -> int:
        """
        Estimate how many tokens remain for conversation history.

        If the schema declares a context size, the result is:
        context size - configured ``max_tokens`` - tokens already used by ``prompt_messages``,
        clamped at 0. Otherwise the result is 2000.
        """
        rest_tokens = 2000

        model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
        if model_context_tokens:
            model_type_instance = cast(LargeLanguageModel, model_config.provider_model_bundle.model_type_instance)

            curr_message_tokens = model_type_instance.get_num_tokens(
                model_config.model,
                model_config.credentials,
                prompt_messages
            )

            # Reserve the completion budget: the rule named 'max_tokens' or one built from
            # the 'max_tokens' template (the last matching rule wins).
            max_tokens = 0
            for parameter_rule in model_config.model_schema.parameter_rules:
                if (parameter_rule.name == 'max_tokens'
                        or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
                    max_tokens = (model_config.parameters.get(parameter_rule.name)
                                  or model_config.parameters.get(parameter_rule.use_template)) or 0

            rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
            rest_tokens = max(rest_tokens, 0)

        return rest_tokens

    def _format_prompt(self, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> str:
        """Render ``prompt_template`` with ``prompt_inputs`` and strip ``<|...|>`` special tokens."""
        prompt = prompt_template.format(prompt_inputs)
        return self._SPECIAL_TOKEN_PATTERN.sub('', prompt)

    def _get_chat_app_completion_model_prompt_messages(self,
                                                       prompt_template_entity: PromptTemplateEntity,
                                                       inputs: dict,
                                                       query: str,
                                                       files: list[FileObj],
                                                       context: Optional[str],
                                                       memory: Optional[TokenBufferMemory],
                                                       model_config: ModelConfigEntity) -> list[PromptMessage]:
        """
        Render the advanced completion template for a chat app into one user message.

        The message includes context, query and history, with files attached.
        """
        advanced_template = prompt_template_entity.advanced_completion_prompt_template
        raw_prompt = advanced_template.prompt
        role_prefix = advanced_template.role_prefix

        prompt_template = PromptTemplateParser(template=raw_prompt)
        prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}

        self._set_context_variable(context, prompt_template, prompt_inputs)
        self._set_query_variable(query, prompt_template, prompt_inputs)
        # Histories go last: their token budget is measured with the other variables filled in.
        self._set_histories_variable(
            memory=memory,
            raw_prompt=raw_prompt,
            role_prefix=role_prefix,
            prompt_template=prompt_template,
            prompt_inputs=prompt_inputs,
            model_config=model_config
        )

        prompt = self._format_prompt(prompt_template, prompt_inputs)

        return [self._build_user_prompt_message(prompt, files)]

    def _build_advanced_chat_messages(self, raw_prompt_list: list,
                                      inputs: dict,
                                      context: Optional[str]) -> list[PromptMessage]:
        """
        Render each templated message of an advanced chat template.

        Each message's text is filled with ``inputs`` and ``context``. User and assistant
        messages are always kept; system messages only when their rendered text is
        non-empty. Messages with any other role are dropped.
        """
        prompt_messages = []
        for prompt_item in raw_prompt_list:
            prompt_template = PromptTemplateParser(template=prompt_item.text)
            prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}

            self._set_context_variable(context, prompt_template, prompt_inputs)

            prompt = self._format_prompt(prompt_template, prompt_inputs)

            if prompt_item.role == PromptMessageRole.USER:
                prompt_messages.append(UserPromptMessage(content=prompt))
            elif prompt_item.role == PromptMessageRole.SYSTEM and prompt:
                prompt_messages.append(SystemPromptMessage(content=prompt))
            elif prompt_item.role == PromptMessageRole.ASSISTANT:
                prompt_messages.append(AssistantPromptMessage(content=prompt))

        return prompt_messages

    def _get_chat_app_chat_model_prompt_messages(self,
                                                 prompt_template_entity: PromptTemplateEntity,
                                                 inputs: dict,
                                                 query: str,
                                                 files: list[FileObj],
                                                 context: Optional[str],
                                                 memory: Optional[TokenBufferMemory],
                                                 model_config: ModelConfigEntity) -> list[PromptMessage]:
        """
        Build messages for a chat app on a chat model from the advanced chat template.

        The list contains the template messages, then the history that fits, then the
        user query with files attached.
        """
        prompt_messages = self._build_advanced_chat_messages(
            prompt_template_entity.advanced_chat_prompt_template.messages,
            inputs,
            context
        )

        self._append_chat_histories(memory, prompt_messages, model_config)

        prompt_messages.append(self._build_user_prompt_message(query, files))

        return prompt_messages

    def _get_completion_app_completion_model_prompt_messages(self,
                                                             prompt_template_entity: PromptTemplateEntity,
                                                             inputs: dict,
                                                             context: Optional[str]) -> list[PromptMessage]:
        """Render the advanced completion template (inputs and context only) into one user message."""
        raw_prompt = prompt_template_entity.advanced_completion_prompt_template.prompt

        prompt_template = PromptTemplateParser(template=raw_prompt)
        prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}

        self._set_context_variable(context, prompt_template, prompt_inputs)

        prompt = self._format_prompt(prompt_template, prompt_inputs)

        return [UserPromptMessage(content=prompt)]

    def _get_completion_app_chat_model_prompt_messages(self,
                                                       prompt_template_entity: PromptTemplateEntity,
                                                       inputs: dict,
                                                       files: list[FileObj],
                                                       context: Optional[str]) -> list[PromptMessage]:
        """
        Build messages for a completion app on a chat model from the advanced chat template.

        Any files are attached to the last user message only.
        """
        prompt_messages = self._build_advanced_chat_messages(
            prompt_template_entity.advanced_chat_prompt_template.messages,
            inputs,
            context
        )

        if files:
            for prompt_message in reversed(prompt_messages):
                if prompt_message.role == PromptMessageRole.USER:
                    prompt_message_contents = [TextPromptMessageContent(data=prompt_message.content)]
                    prompt_message_contents.extend(file.prompt_message_content for file in files)
                    prompt_message.content = prompt_message_contents
                    break

        return prompt_messages