dify/api/services/completion_service.py

622 lines
25 KiB
Python
Raw Normal View History

2023-05-15 08:51:32 +08:00
import json
import logging
import threading
import time
import uuid
from typing import Generator, Union, Any, Optional, List
2023-05-15 08:51:32 +08:00
from flask import current_app, Flask
from redis.client import PubSub
from sqlalchemy import and_
from core.completion import Completion
from core.conversation_message_task import PubHandler, ConversationTaskStoppedException, \
ConversationTaskInterruptException
from core.file.message_file_parser import MessageFileParser
from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
LLMRateLimitError, \
2023-05-15 08:51:32 +08:00
LLMAuthorizationError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.model_providers.models.entity.message import PromptMessageFile
2023-05-15 08:51:32 +08:00
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.model import Conversation, AppModelConfig, App, Account, EndUser, Message
from services.app_model_config_service import AppModelConfigService
from services.errors.app import MoreLikeThisDisabledError
from services.errors.app_model_config import AppModelConfigBrokenError
from services.errors.completion import CompletionStoppedError
from services.errors.conversation import ConversationNotExistsError, ConversationCompletedError
from services.errors.message import MessageNotExistsError
class CompletionService:
    """Run app generations in background threads and relay their output over Redis pub/sub.

    Each public entry point resolves the model configuration on the request thread. It then
    starts :meth:`generate_worker` in a separate thread, plus a watchdog thread
    (:meth:`countdown_and_close`). Finally, :meth:`compact_response` turns the events published
    on the task's channel into either one response dict (blocking) or a generator of
    server-sent-event frames (streaming).
    """

    @classmethod
    def completion(cls, app_model: App, user: Union[Account, EndUser], args: Any,
                   from_source: str, streaming: bool = True,
                   is_model_config_override: bool = False) -> Union[dict, Generator]:
        """Run a completion/chat generation for ``app_model`` on behalf of ``user``.

        ``args`` is read as a mapping:

        - Required keys: ``inputs`` and ``query``.
        - Optional keys: ``files``, ``auto_generate_name``, ``conversation_id``,
          ``model_config`` and ``retriever_from``.

        :param from_source: ``'console'`` scopes the conversation lookup to the account id;
            any other value scopes it to the end-user id.
        :param streaming: return a generator of SSE frames when true, a single dict otherwise.
        :param is_model_config_override: build the config from ``args['model_config']``
            instead of the stored one.
        :raises ValueError: a non-completion app got an empty query, or the override config
            lacks ``model`` / ``model.completion_params``.
        :raises ConversationNotExistsError: ``conversation_id`` matches no normal conversation
            of this app and user.
        :raises AppModelConfigBrokenError: no usable app model config could be resolved.
        """
        inputs = args['inputs']
        query = args['query']
        files = args['files'] if 'files' in args and args['files'] else []
        auto_generate_name = args['auto_generate_name'] if 'auto_generate_name' in args else True

        if app_model.mode != 'completion' and not query:
            raise ValueError('query is required')

        # Completion apps may send no query at all; normalise to '' before stripping NUL chars.
        query = (query or '').replace('\x00', '')

        conversation_id = args['conversation_id'] if 'conversation_id' in args else None
        conversation = None
        if conversation_id:
            conversation_filter = [
                Conversation.id == conversation_id,
                Conversation.app_id == app_model.id,
                Conversation.status == 'normal'
            ]
            if from_source == 'console':
                conversation_filter.append(Conversation.from_account_id == user.id)
            else:
                # The conditional belongs on the right-hand side only; without the parentheses a
                # falsy user would append a bare None instead of a column comparison.
                conversation_filter.append(Conversation.from_end_user_id == (user.id if user else None))

            conversation = db.session.query(Conversation).filter(and_(*conversation_filter)).first()
            if not conversation:
                raise ConversationNotExistsError()

            # NOTE(review): unreachable as written. The filter above already requires
            # status == 'normal', so a completed conversation surfaces as
            # ConversationNotExistsError. Confirm which error callers expect.
            if conversation.status != 'normal':
                raise ConversationCompletedError()

            if not conversation.override_model_configs:
                app_model_config = db.session.query(AppModelConfig).filter(
                    AppModelConfig.id == conversation.app_model_config_id,
                    AppModelConfig.app_id == app_model.id
                ).first()
                if not app_model_config:
                    raise AppModelConfigBrokenError()
            else:
                conversation_override_model_configs = json.loads(conversation.override_model_configs)
                app_model_config = AppModelConfig(
                    id=conversation.app_model_config_id,
                    app_id=app_model.id,
                )
                app_model_config = app_model_config.from_model_config_dict(conversation_override_model_configs)

            if is_model_config_override:
                # Only the completion params are taken from the override; the rest of the
                # conversation's config is kept.
                model_config_arg = args['model_config']
                if 'model' not in model_config_arg:
                    raise ValueError('model_config.model is required')
                if 'completion_params' not in model_config_arg['model']:
                    raise ValueError('model_config.model.completion_params is required')

                completion_params = AppModelConfigService.validate_model_completion_params(
                    cp=model_config_arg['model']['completion_params'],
                    model_name=app_model_config.model_dict["name"]
                )
                app_model_config_model = app_model_config.model_dict
                app_model_config_model['completion_params'] = completion_params

                # Apply every change to a copy. The config above may be the object loaded
                # through db.session, and it must not be left modified in that session.
                app_model_config = app_model_config.copy()
                app_model_config.model = json.dumps(app_model_config_model)
                app_model_config.retriever_resource = json.dumps({'enabled': True})
        else:
            if app_model.app_model_config_id is None:
                raise AppModelConfigBrokenError()

            app_model_config = app_model.app_model_config
            if not app_model_config:
                raise AppModelConfigBrokenError()

            if is_model_config_override:
                if not isinstance(user, Account):
                    raise Exception("Only account can override model config")

                # validate config
                model_config = AppModelConfigService.validate_configuration(
                    tenant_id=app_model.tenant_id,
                    account=user,
                    config=args['model_config'],
                    mode=app_model.mode
                )

                app_model_config = AppModelConfig(
                    id=app_model_config.id,
                    app_id=app_model.id,
                )
                app_model_config = app_model_config.from_model_config_dict(model_config)

        # clean input by app_model_config form rules
        inputs = cls.get_cleaned_inputs(inputs, app_model_config)

        # parse files
        message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
        file_objs = message_file_parser.validate_and_transform_files_arg(files, app_model_config, user)

        return cls._generate_in_background(
            app_model=app_model,
            user=user,
            streaming=streaming,
            app_model_config=app_model_config.copy(),
            query=query,
            inputs=inputs,
            files=file_objs,
            detached_conversation=conversation,
            is_model_config_override=is_model_config_override,
            retriever_from=args['retriever_from'] if 'retriever_from' in args else 'dev',
            auto_generate_name=auto_generate_name,
            from_source=from_source,
        )

    @classmethod
    def _generate_in_background(cls, app_model: App, user: Union[Account, EndUser], streaming: bool,
                                **worker_kwargs) -> Union[dict, Generator]:
        """Start :meth:`generate_worker` in a thread and relay its published events.

        Steps:

        1. Subscribe to the task channel before starting the worker (presumably so that no
           event published by the worker is missed).
        2. Start the worker thread.
        3. Start the watchdog via :meth:`countdown_and_close`.
        4. Return the result of :meth:`compact_response`.

        :param worker_kwargs: remaining keyword arguments, forwarded verbatim to generate_worker.
        """
        flask_app = current_app._get_current_object()
        generate_task_id = str(uuid.uuid4())

        pubsub = redis_client.pubsub()
        # The channel name is derived from the user object exactly as the caller passed it.
        pubsub.subscribe(PubHandler.generate_channel_name(user, generate_task_id))

        # Re-query the user so a freshly loaded Account/EndUser is handed to the other threads.
        user = cls.get_real_user_instead_of_proxy_obj(user)

        generate_worker_thread = threading.Thread(target=cls.generate_worker, kwargs={
            'flask_app': flask_app,
            'generate_task_id': generate_task_id,
            'detached_app_model': app_model,
            'detached_user': user,
            'streaming': streaming,
            **worker_kwargs,
        })
        generate_worker_thread.start()

        # wait for 10 minutes to close the thread
        cls.countdown_and_close(flask_app, generate_worker_thread, pubsub, user, generate_task_id)

        return cls.compact_response(pubsub, streaming)

    @classmethod
    def get_real_user_instead_of_proxy_obj(cls, user: Union[Account, EndUser]):
        """Re-query ``user`` by id and return the loaded Account/EndUser instance.

        Returns None when the row no longer exists (``.first()`` finds nothing).

        :raises Exception: ``user`` is neither an Account nor an EndUser.
        """
        if isinstance(user, Account):
            model_cls = Account
        elif isinstance(user, EndUser):
            model_cls = EndUser
        else:
            raise Exception("Unknown user type")

        return db.session.query(model_cls).filter(model_cls.id == user.id).first()

    @classmethod
    def generate_worker(cls, flask_app: Flask, generate_task_id: str, detached_app_model: App,
                        app_model_config: AppModelConfig,
                        query: str, inputs: dict, files: List[PromptMessageFile],
                        detached_user: Union[Account, EndUser],
                        detached_conversation: Optional[Conversation], streaming: bool, is_model_config_override: bool,
                        retriever_from: str = 'dev', auto_generate_name: bool = True, from_source: str = 'console'):
        """Thread target: run :meth:`Completion.generate` inside an app context.

        Error handling:

        - Stop/interrupt exceptions end the task silently.
        - Known LLM/provider errors and ``ValueError`` are published to the task channel as-is.
        - Authorization errors are published with a fixed message.
        - Anything else is logged, then published.

        The thread's DB session is always removed at the end.
        """
        with flask_app.app_context():
            # fixed the state of the model object when it detached from the original session
            user = db.session.merge(detached_user)
            app_model = db.session.merge(detached_app_model)
            conversation = db.session.merge(detached_conversation) if detached_conversation else None

            try:
                # run
                Completion.generate(
                    task_id=generate_task_id,
                    app=app_model,
                    app_model_config=app_model_config,
                    query=query,
                    inputs=inputs,
                    user=user,
                    files=files,
                    conversation=conversation,
                    streaming=streaming,
                    is_override=is_model_config_override,
                    retriever_from=retriever_from,
                    auto_generate_name=auto_generate_name,
                    from_source=from_source
                )
            except (ConversationTaskInterruptException, ConversationTaskStoppedException):
                pass
            except (ValueError, LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
                    LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError,
                    ModelCurrentlyNotSupportError) as e:
                PubHandler.pub_error(user, generate_task_id, e)
            except LLMAuthorizationError:
                PubHandler.pub_error(user, generate_task_id, LLMAuthorizationError('Incorrect API key provided'))
            except Exception as e:
                logging.exception("Unknown Error in completion")
                PubHandler.pub_error(user, generate_task_id, e)
            finally:
                db.session.remove()

    @classmethod
    def countdown_and_close(cls, flask_app: Flask, worker_thread, pubsub, detached_user,
                            generate_task_id) -> threading.Thread:
        """Start and return a watchdog thread for ``worker_thread``.

        While the worker is alive, the watchdog polls once per second for up to ``timeout``
        (600) iterations and publishes a ping about every 10 iterations. If the worker is
        still alive after that, it is asked to stop via ``PubHandler.stop``. The watchdog
        then closes ``pubsub`` (best effort) and removes its own DB session.
        """
        # wait for 10 minutes to close the thread
        timeout = 600

        def close_pubsub():
            with flask_app.app_context():
                try:
                    user = db.session.merge(detached_user)

                    sleep_iterations = 0
                    while sleep_iterations < timeout and worker_thread.is_alive():
                        if sleep_iterations > 0 and sleep_iterations % 10 == 0:
                            PubHandler.ping(user, generate_task_id)
                        time.sleep(1)
                        sleep_iterations += 1

                    if worker_thread.is_alive():
                        PubHandler.stop(user, generate_task_id)

                    try:
                        pubsub.close()
                    except Exception:
                        # Closing is best effort; failures here are deliberately ignored.
                        pass
                finally:
                    db.session.remove()

        countdown_thread = threading.Thread(target=close_pubsub)
        countdown_thread.start()

        return countdown_thread

    @classmethod
    def generate_more_like_this(cls, app_model: App, user: Union[Account, EndUser],
                                message_id: str, streaming: bool = True,
                                retriever_from: str = 'dev') -> Union[dict, Generator]:
        """Regenerate an existing message of ``user`` with temperature forced to 0.9.

        :raises ValueError: ``user`` is falsy.
        :raises MessageNotExistsError: the message does not belong to this app and user.
        :raises AppModelConfigBrokenError: the app or the message has no model config.
        :raises MoreLikeThisDisabledError: the app's current config does not enable the feature.
        """
        if not user:
            raise ValueError('user cannot be None')

        is_end_user = isinstance(user, EndUser)
        message = db.session.query(Message).filter(
            Message.id == message_id,
            Message.app_id == app_model.id,
            Message.from_source == ('api' if is_end_user else 'console'),
            Message.from_end_user_id == (user.id if is_end_user else None),
            Message.from_account_id == (user.id if isinstance(user, Account) else None),
        ).first()

        if not message:
            raise MessageNotExistsError()

        current_app_model_config = app_model.app_model_config
        if not current_app_model_config:
            raise AppModelConfigBrokenError()

        more_like_this = current_app_model_config.more_like_this_dict
        if not current_app_model_config.more_like_this or more_like_this.get("enabled", False) is False:
            raise MoreLikeThisDisabledError()

        message_app_model_config = message.app_model_config
        if not message_app_model_config:
            raise AppModelConfigBrokenError()

        # Edit a copy so the message's session-attached config object is never modified.
        app_model_config = message_app_model_config.copy()
        model_dict = app_model_config.model_dict
        completion_params = model_dict.get('completion_params') or {}
        completion_params['temperature'] = 0.9
        model_dict['completion_params'] = completion_params
        app_model_config.model = json.dumps(model_dict)

        # parse files
        message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
        file_objs = message_file_parser.transform_message_files(message.files, app_model_config)

        # NOTE(review): from_source is not forwarded, so generate_worker's default 'console' is
        # used even for end users (the query above matches 'api' for them). Confirm intent.
        return cls._generate_in_background(
            app_model=app_model,
            user=user,
            streaming=streaming,
            app_model_config=app_model_config.copy(),
            query=message.query,
            inputs=message.inputs,
            files=file_objs,
            detached_conversation=None,
            is_model_config_override=True,
            retriever_from=retriever_from,
            auto_generate_name=False,
        )

    @classmethod
    def get_cleaned_inputs(cls, user_inputs: dict, app_model_config: AppModelConfig):
        """Validate ``user_inputs`` against the app's input form and return the accepted values.

        Each entry of ``app_model_config.user_input_form_list`` is read as a mapping whose
        first item is ``{input_type: input_config}``.

        - Missing or falsy values raise for required fields; otherwise they fall back to
          ``default`` (or ``''``).
        - ``select`` values must be one of ``options``.
        - Other types honour ``max_length``.
        - String values have NUL characters removed.

        :raises ValueError: a required field is missing, a select value is not an allowed
            option, or a value exceeds ``max_length``.
        """
        if user_inputs is None:
            user_inputs = {}

        filtered_inputs = {}

        # Filter input variables from form configuration, handle required fields, default values, and option values
        for form_item in app_model_config.user_input_form_list:
            input_type, input_config = next(iter(form_item.items()))
            variable = input_config["variable"]

            if variable not in user_inputs or not user_inputs[variable]:
                if input_config.get("required"):
                    raise ValueError(f"{variable} is required in input form")
                filtered_inputs[variable] = input_config.get("default", "")
                continue

            value = user_inputs[variable]

            if input_type == "select":
                options = input_config.get("options", [])
                if value not in options:
                    raise ValueError(f"{variable} in input form must be one of the following: {options}")
            elif 'max_length' in input_config:
                max_length = input_config['max_length']
                if len(value) > max_length:
                    raise ValueError(f'{variable} in input form must be less than {max_length} characters')

            # Strip NUL characters from text, mirroring the query handling in completion();
            # non-string values are passed through unchanged instead of raising AttributeError.
            filtered_inputs[variable] = value.replace('\x00', '') if isinstance(value, str) else value

        return filtered_inputs

    @staticmethod
    def _to_sse(payload: dict) -> str:
        """Serialise ``payload`` as one server-sent-events ``data:`` frame."""
        return "data: " + json.dumps(payload) + "\n\n"

    @staticmethod
    def _unsubscribe_quietly(pubsub: PubSub, channel: str) -> None:
        """Best-effort unsubscribe from ``channel``; connection errors are ignored.

        Both the builtin ``ConnectionError`` and redis-py's ``ConnectionError`` are caught.
        The latter derives from ``RedisError`` rather than the builtin.
        """
        from redis.exceptions import ConnectionError as RedisConnectionError

        try:
            pubsub.unsubscribe(channel)
        except (ConnectionError, RedisConnectionError):
            # The watchdog thread may already have closed this pubsub; nothing left to clean up.
            pass

    @classmethod
    def compact_response(cls, pubsub: PubSub, streaming: bool = False) -> Union[dict, Generator]:
        """Turn the events published on ``pubsub``'s channel into the API response.

        - Blocking mode returns a dict built from the ``annotation`` event, or from the
          ``message`` and ``message_end`` events.
        - Streaming mode returns a generator of SSE frames that ends at the ``end`` event.
        - Error events are re-raised via :meth:`handle_error`.

        The DB session is removed and the channel unsubscribed once consumption finishes.
        """
        generate_channel = list(pubsub.channels.keys())[0].decode('utf-8')

        if not streaming:
            try:
                message_result = {}
                for message in pubsub.listen():
                    if message["type"] != "message":
                        continue

                    result = json.loads(message["data"].decode('utf-8'))
                    if result.get('error'):
                        cls.handle_error(result)

                    if result['event'] == 'annotation' and 'data' in result:
                        message_result['annotation'] = result.get('data')
                        return cls.get_blocking_annotation_message_response_data(message_result)
                    if result['event'] == 'message' and 'data' in result:
                        message_result['message'] = result.get('data')
                    if result['event'] == 'message_end' and 'data' in result:
                        message_result['message_end'] = result.get('data')
                        return cls.get_blocking_message_response_data(message_result)
            except ValueError as e:
                # NOTE(review): this looks inverted relative to the streaming branch, which
                # ignores the closed-file error. Here the closed-file error is logged and
                # re-raised, while any other ValueError becomes CompletionStoppedError.
                # Behaviour kept as-is; confirm intent.
                if e.args[0] != "I/O operation on closed file.":  # ignore this error
                    raise CompletionStoppedError()
                else:
                    logging.exception(e)
                    raise
            finally:
                db.session.remove()
                cls._unsubscribe_quietly(pubsub, generate_channel)
        else:
            # Events whose payload is reshaped by a dedicated formatter before being sent.
            event_formatters = {
                'message': cls.get_message_response_data,
                'message_replace': cls.get_message_replace_response_data,
                'chain': cls.get_chain_response_data,
                'agent_thought': cls.get_agent_thought_response_data,
                'annotation': cls.get_annotation_response_data,
                'message_end': cls.get_message_end_data,
            }

            def generate() -> Generator:
                try:
                    for message in pubsub.listen():
                        if message["type"] != "message":
                            continue

                        result = json.loads(message["data"].decode('utf-8'))
                        if result.get('error'):
                            cls.handle_error(result)

                        event = result.get('event')
                        if event == "end":
                            logging.debug("%s finished", generate_channel)
                            break

                        formatter = event_formatters.get(event)
                        if formatter is not None:
                            yield cls._to_sse(formatter(result.get('data')))
                        elif event == 'ping':
                            yield "event: ping\n\n"
                        else:
                            # Unknown events are forwarded verbatim.
                            yield cls._to_sse(result)
                except ValueError as e:
                    if e.args[0] != "I/O operation on closed file.":  # ignore this error
                        logging.exception(e)
                        raise
                finally:
                    db.session.remove()
                    cls._unsubscribe_quietly(pubsub, generate_channel)

            return generate()

    @staticmethod
    def _add_conversation_id(response_data: dict, source: dict) -> dict:
        """Copy ``conversation_id`` from ``source`` into ``response_data`` for chat-mode events."""
        if source.get('mode') == 'chat':
            response_data['conversation_id'] = source.get('conversation_id')
        return response_data

    @classmethod
    def get_message_response_data(cls, data: dict):
        """Build the streaming payload for a ``message`` event."""
        return cls._add_conversation_id({
            'event': 'message',
            'task_id': data.get('task_id'),
            'id': data.get('message_id'),
            'answer': data.get('text'),
            'created_at': int(time.time())
        }, data)

    @classmethod
    def get_message_replace_response_data(cls, data: dict):
        """Build the streaming payload for a ``message_replace`` event."""
        return cls._add_conversation_id({
            'event': 'message_replace',
            'task_id': data.get('task_id'),
            'id': data.get('message_id'),
            'answer': data.get('text'),
            'created_at': int(time.time())
        }, data)

    @classmethod
    def get_blocking_message_response_data(cls, data: dict):
        """Build the blocking-mode response from the collected ``message``/``message_end`` data.

        ``retriever_resources`` from ``message_end`` (when present) go under ``metadata``.
        """
        message = data.get('message')
        response_data = cls._add_conversation_id({
            'event': 'message',
            'task_id': message.get('task_id'),
            'id': message.get('message_id'),
            'answer': message.get('text'),
            'metadata': {},
            'created_at': int(time.time())
        }, message)

        if 'message_end' in data:
            message_end = data.get('message_end')
            if 'retriever_resources' in message_end:
                response_data['metadata']['retriever_resources'] = message_end.get('retriever_resources')

        return response_data

    @classmethod
    def get_blocking_annotation_message_response_data(cls, data: dict):
        """Build the blocking-mode response from a collected ``annotation`` event."""
        message = data.get('annotation')
        return cls._add_conversation_id({
            'event': 'annotation',
            'task_id': message.get('task_id'),
            'id': message.get('message_id'),
            'answer': message.get('text'),
            'metadata': {},
            'created_at': int(time.time()),
            'annotation_id': message.get('annotation_id'),
            'annotation_author_name': message.get('annotation_author_name')
        }, message)

    @classmethod
    def get_message_end_data(cls, data: dict):
        """Build the streaming payload for a ``message_end`` event."""
        response_data = {
            'event': 'message_end',
            'task_id': data.get('task_id'),
            'id': data.get('message_id')
        }
        if 'retriever_resources' in data:
            response_data['retriever_resources'] = data.get('retriever_resources')
        return cls._add_conversation_id(response_data, data)

    @classmethod
    def get_chain_response_data(cls, data: dict):
        """Build the streaming payload for a ``chain`` event."""
        return cls._add_conversation_id({
            'event': 'chain',
            'id': data.get('chain_id'),
            'task_id': data.get('task_id'),
            'message_id': data.get('message_id'),
            'type': data.get('type'),
            'input': data.get('input'),
            'output': data.get('output'),
            'created_at': int(time.time())
        }, data)

    @classmethod
    def get_agent_thought_response_data(cls, data: dict):
        """Build the streaming payload for an ``agent_thought`` event."""
        return cls._add_conversation_id({
            'event': 'agent_thought',
            'id': data.get('id'),
            'chain_id': data.get('chain_id'),
            'task_id': data.get('task_id'),
            'message_id': data.get('message_id'),
            'position': data.get('position'),
            'thought': data.get('thought'),
            'tool': data.get('tool'),
            'tool_input': data.get('tool_input'),
            'created_at': int(time.time())
        }, data)

    @classmethod
    def get_annotation_response_data(cls, data: dict):
        """Build the streaming payload for an ``annotation`` event."""
        return cls._add_conversation_id({
            'event': 'annotation',
            'task_id': data.get('task_id'),
            'id': data.get('message_id'),
            'answer': data.get('text'),
            'created_at': int(time.time()),
            'annotation_id': data.get('annotation_id'),
            'annotation_author_name': data.get('annotation_author_name'),
        }, data)

    @classmethod
    def handle_error(cls, result: dict):
        """Raise the exception described by an error event published by the worker.

        Always raises:

        - Known error names map to their LLM error classes (``ValueError`` becomes
          ``LLMBadRequestError``), with ``description`` as the message.
        - ``LLMAuthorizationError`` is raised with a fixed message.
        - Any other name becomes a plain ``Exception(description)``.
        """
        logging.debug("error: %s", result)
        error = result.get('error')
        description = result.get('description')

        # handle errors
        llm_errors = {
            'ValueError': LLMBadRequestError,
            'LLMBadRequestError': LLMBadRequestError,
            'LLMAPIConnectionError': LLMAPIConnectionError,
            'LLMAPIUnavailableError': LLMAPIUnavailableError,
            'LLMRateLimitError': LLMRateLimitError,
            'ProviderTokenNotInitError': ProviderTokenNotInitError,
            'QuotaExceededError': QuotaExceededError,
            'ModelCurrentlyNotSupportError': ModelCurrentlyNotSupportError
        }

        if error in llm_errors:
            raise llm_errors[error](description)
        elif error == 'LLMAuthorizationError':
            raise LLMAuthorizationError('Incorrect API key provided')
        else:
            raise Exception(description)