dify/api/services/completion_service.py

248 lines
9.9 KiB
Python
Raw Normal View History

2023-05-15 08:51:32 +08:00
import json
from typing import Generator, Union, Any
2023-05-15 08:51:32 +08:00
from sqlalchemy import and_
from core.application_manager import ApplicationManager
from core.entities.application_entities import InvokeFrom
from core.file.message_file_parser import MessageFileParser
2023-05-15 08:51:32 +08:00
from extensions.ext_database import db
from models.model import Conversation, AppModelConfig, App, Account, EndUser, Message
from services.app_model_config_service import AppModelConfigService
from services.errors.app import MoreLikeThisDisabledError
from services.errors.app_model_config import AppModelConfigBrokenError
from services.errors.conversation import ConversationNotExistsError, ConversationCompletedError
from services.errors.message import MessageNotExistsError
class CompletionService:
    """Service entry points that prepare an app model config and run generation.

    Both public generators resolve which ``AppModelConfig`` to use, then delegate
    to ``ApplicationManager.generate``.
    """

    @classmethod
    def completion(cls, app_model: App, user: Union[Account, EndUser], args: Any,
                   invoke_from: InvokeFrom, streaming: bool = True,
                   is_model_config_override: bool = False) -> Union[dict, Generator]:
        """Run a completion/chat turn for ``app_model``.

        :param app_model: the app being invoked.
        :param user: the caller; only an ``Account`` may override the model config.
        :param args: request arguments. Reads ``inputs`` and ``query`` (required keys)
            and optionally ``files``, ``auto_generate_name``, ``conversation_id`` and,
            when ``is_model_config_override`` is set, ``model_config``.
        :param invoke_from: origin of the invocation, forwarded to the generator.
        :param streaming: forwarded to ``ApplicationManager.generate`` as ``stream``.
        :param is_model_config_override: build the config from ``args['model_config']``
            instead of using the stored one.
        :return: the result of ``ApplicationManager.generate``.
        :raises ValueError: query missing/not a string, or model_config incomplete.
        :raises ConversationNotExistsError: ``conversation_id`` matches no conversation
            of this app and user.
        :raises ConversationCompletedError: the conversation's status is not ``normal``.
        :raises AppModelConfigBrokenError: no usable app model config was found.
        """
        inputs = args['inputs']
        query = args['query']
        files = args.get('files') or []
        auto_generate_name = args.get('auto_generate_name', True)

        if app_model.mode != 'completion' and not query:
            raise ValueError('query is required')
        # Completion apps may send an empty query, but it must still be a string:
        # a None query previously crashed on .replace with an AttributeError.
        if not isinstance(query, str):
            raise ValueError('query must be a string')
        query = query.replace('\x00', '')

        conversation = None
        conversation_id = args.get('conversation_id')
        if conversation_id:
            conversation = cls._get_conversation(app_model, user, conversation_id)
            app_model_config = cls._resolve_conversation_model_config(
                app_model, user, conversation, args, is_model_config_override
            )
        else:
            app_model_config = cls._resolve_app_model_config(
                app_model, user, args, is_model_config_override
            )

        # clean input by app_model_config form rules
        inputs = cls.get_cleaned_inputs(inputs, app_model_config)

        # parse files
        message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
        file_objs = message_file_parser.validate_and_transform_files_arg(
            files,
            app_model_config,
            user
        )

        application_manager = ApplicationManager()
        return application_manager.generate(
            tenant_id=app_model.tenant_id,
            app_id=app_model.id,
            app_model_config_id=app_model_config.id,
            app_model_config_dict=app_model_config.to_dict(),
            app_model_config_override=is_model_config_override,
            user=user,
            invoke_from=invoke_from,
            inputs=inputs,
            query=query,
            files=file_objs,
            conversation=conversation,
            stream=streaming,
            extras={
                "auto_generate_conversation_name": auto_generate_name
            }
        )

    @classmethod
    def _get_conversation(cls, app_model: App, user: Union[Account, EndUser],
                          conversation_id: str) -> Conversation:
        """Load conversation ``conversation_id`` of ``app_model`` owned by ``user``.

        :raises ConversationNotExistsError: no matching conversation.
        :raises ConversationCompletedError: the conversation's status is not ``normal``.
        """
        conversation_filter = [
            Conversation.id == conversation_id,
            Conversation.app_id == app_model.id,
        ]
        if isinstance(user, Account):
            conversation_filter.append(Conversation.from_account_id == user.id)
        else:
            # Parenthesized so a missing user compares against NULL instead of
            # appending a bare None to the filter list.
            conversation_filter.append(Conversation.from_end_user_id == (user.id if user else None))

        conversation = db.session.query(Conversation).filter(and_(*conversation_filter)).first()
        if not conversation:
            raise ConversationNotExistsError()

        # Checked after loading (not filtered in SQL) so that a finished conversation
        # is reported as completed rather than as missing.
        if conversation.status != 'normal':
            raise ConversationCompletedError()

        return conversation

    @classmethod
    def _resolve_conversation_model_config(cls, app_model: App, user: Union[Account, EndUser],
                                           conversation: Conversation, args: Any,
                                           is_model_config_override: bool) -> AppModelConfig:
        """Return the model config to use for a turn in an existing conversation.

        Starts from the conversation's override configs if present, otherwise from the
        stored ``AppModelConfig`` it references. With ``is_model_config_override``, only
        ``model_config.model.completion_params`` from ``args`` is applied on top.

        :raises AppModelConfigBrokenError: the referenced config does not exist.
        :raises Exception: a non-``Account`` user requested an override.
        :raises ValueError: ``model_config.model`` or its ``completion_params`` is missing.
        """
        if not conversation.override_model_configs:
            app_model_config = db.session.query(AppModelConfig).filter(
                AppModelConfig.id == conversation.app_model_config_id,
                AppModelConfig.app_id == app_model.id
            ).first()
            if not app_model_config:
                raise AppModelConfigBrokenError()
        else:
            conversation_override_model_configs = json.loads(conversation.override_model_configs)
            app_model_config = AppModelConfig(
                id=conversation.app_model_config_id,
                app_id=app_model.id,
            )
            app_model_config = app_model_config.from_model_config_dict(conversation_override_model_configs)

        if not is_model_config_override:
            return app_model_config

        # Same restriction as the non-conversation override path.
        if not isinstance(user, Account):
            raise Exception("Only account can override model config")

        model_config_args = args['model_config']
        if 'model' not in model_config_args:
            raise ValueError('model_config.model is required')
        if 'completion_params' not in model_config_args['model']:
            raise ValueError('model_config.model.completion_params is required')

        completion_params = AppModelConfigService.validate_model_completion_params(
            cp=model_config_args['model']['completion_params'],
            model_name=app_model_config.model_dict["name"]
        )

        # Copy BEFORE mutating: the config loaded via db.session.query above is
        # session-attached, and edits to it could be flushed on a later commit.
        app_model_config = app_model_config.copy()
        model_dict = app_model_config.model_dict
        model_dict['completion_params'] = completion_params
        app_model_config.model = json.dumps(model_dict)
        app_model_config.retriever_resource = json.dumps({'enabled': True})
        return app_model_config

    @classmethod
    def _resolve_app_model_config(cls, app_model: App, user: Union[Account, EndUser],
                                  args: Any, is_model_config_override: bool) -> AppModelConfig:
        """Return the app's current model config, or one built from ``args['model_config']``.

        :raises AppModelConfigBrokenError: the app has no model config.
        :raises Exception: a non-``Account`` user requested an override.
        """
        if app_model.app_model_config_id is None:
            raise AppModelConfigBrokenError()

        app_model_config = app_model.app_model_config
        if not app_model_config:
            raise AppModelConfigBrokenError()

        if not is_model_config_override:
            return app_model_config

        if not isinstance(user, Account):
            raise Exception("Only account can override model config")

        # validate config
        model_config = AppModelConfigService.validate_configuration(
            tenant_id=app_model.tenant_id,
            account=user,
            config=args['model_config'],
            app_mode=app_model.mode
        )

        app_model_config = AppModelConfig(
            id=app_model_config.id,
            app_id=app_model.id,
        )
        return app_model_config.from_model_config_dict(model_config)

    @classmethod
    def generate_more_like_this(cls, app_model: App, user: Union[Account, EndUser],
                                message_id: str, invoke_from: InvokeFrom,
                                streaming: bool = True) -> Union[dict, Generator]:
        """Regenerate an existing message's answer with ``temperature`` forced to 0.9.

        The message must belong to ``app_model`` and to ``user`` (``from_source`` is
        ``'api'`` for an ``EndUser``, ``'console'`` otherwise). The new generation reuses
        the message's inputs, query and files and runs with ``conversation=None``.

        :raises ValueError: ``user`` is falsy.
        :raises MessageNotExistsError: no matching message.
        :raises AppModelConfigBrokenError: the app or the message has no model config.
        :raises MoreLikeThisDisabledError: "more like this" is not enabled for the app.
        """
        if not user:
            raise ValueError('user cannot be None')

        is_end_user = isinstance(user, EndUser)
        message = db.session.query(Message).filter(
            Message.id == message_id,
            Message.app_id == app_model.id,
            Message.from_source == ('api' if is_end_user else 'console'),
            Message.from_end_user_id == (user.id if is_end_user else None),
            Message.from_account_id == (user.id if isinstance(user, Account) else None),
        ).first()
        if not message:
            raise MessageNotExistsError()

        current_app_model_config = app_model.app_model_config
        if not current_app_model_config:
            raise AppModelConfigBrokenError()
        more_like_this = current_app_model_config.more_like_this_dict
        if not current_app_model_config.more_like_this or not more_like_this.get("enabled", False):
            raise MoreLikeThisDisabledError()

        message_app_model_config = message.app_model_config
        if not message_app_model_config:
            raise AppModelConfigBrokenError()

        # Work on a copy: message.app_model_config may be a session-attached ORM
        # instance, and the temperature tweak must only affect this generation.
        app_model_config = message_app_model_config.copy()
        model_dict = app_model_config.model_dict
        completion_params = model_dict.get('completion_params') or {}
        completion_params['temperature'] = 0.9
        model_dict['completion_params'] = completion_params
        app_model_config.model = json.dumps(model_dict)

        # parse files
        message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
        file_objs = message_file_parser.transform_message_files(
            message.files, app_model_config
        )

        application_manager = ApplicationManager()
        return application_manager.generate(
            tenant_id=app_model.tenant_id,
            app_id=app_model.id,
            app_model_config_id=app_model_config.id,
            app_model_config_dict=app_model_config.to_dict(),
            app_model_config_override=True,
            user=user,
            invoke_from=invoke_from,
            inputs=message.inputs,
            query=message.query,
            files=file_objs,
            conversation=None,
            stream=streaming,
            extras={
                "auto_generate_conversation_name": False
            }
        )

    @classmethod
    def get_cleaned_inputs(cls, user_inputs: dict, app_model_config: AppModelConfig):
        """Filter and validate ``user_inputs`` against the app's user input form.

        Only variables declared in ``app_model_config.user_input_form_list`` are kept.

        :param user_inputs: raw inputs from the caller; ``None`` is treated as ``{}``.
        :param app_model_config: config providing ``user_input_form_list``.
        :return: dict mapping each declared variable to its cleaned value (or default).
        :raises ValueError: a required variable is missing/empty, a ``select`` value is
            not among its options, or a string exceeds ``max_length``.
        """
        if user_inputs is None:
            user_inputs = {}

        filtered_inputs = {}
        for form_item in app_model_config.user_input_form_list or []:
            if not form_item:
                continue
            # Only the first entry of each form item is used; presumably each item is
            # a single {input_type: input_config} pair.
            input_type, input_config = next(iter(form_item.items()))
            variable = input_config["variable"]
            value = user_inputs.get(variable)

            # Missing and falsy values are treated alike: error if required,
            # otherwise fall back to the configured default.
            if not value:
                if input_config.get("required"):
                    raise ValueError(f"{variable} is required in input form")
                filtered_inputs[variable] = input_config.get("default", "")
                continue

            if input_type == "select":
                options = input_config.get("options") or []
                if value not in options:
                    raise ValueError(f"{variable} in input form must be one of the following: {options}")
            else:
                max_length = input_config.get('max_length')
                # Length limits only make sense for strings.
                if max_length is not None and isinstance(value, str) and len(value) > max_length:
                    raise ValueError(f'{variable} in input form must be less than {max_length} characters')

            # Strip NUL characters from strings; other values pass through unchanged.
            filtered_inputs[variable] = value.replace('\x00', '') if isinstance(value, str) else value

        return filtered_inputs