feat: claude api support (#572)

This commit is contained in:
John Wang 2023-07-17 00:14:19 +08:00 committed by GitHub
parent 510389909c
commit 7599f79a17
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
52 changed files with 637 additions and 349 deletions

View File

@ -18,7 +18,8 @@ from models.model import Account
import secrets import secrets
import base64 import base64
from models.provider import Provider from models.provider import Provider, ProviderName
from services.provider_service import ProviderService
@click.command('reset-password', help='Reset the account password.') @click.command('reset-password', help='Reset the account password.')
@ -193,9 +194,40 @@ def recreate_all_dataset_indexes():
click.echo(click.style('Congratulations! Recreate {} dataset indexes.'.format(recreate_count), fg='green')) click.echo(click.style('Congratulations! Recreate {} dataset indexes.'.format(recreate_count), fg='green'))
@click.command('sync-anthropic-hosted-providers', help='Sync anthropic hosted providers.')
def sync_anthropic_hosted_providers():
click.echo(click.style('Start sync anthropic hosted providers.', fg='green'))
count = 0
page = 1
while True:
try:
tenants = db.session.query(Tenant).order_by(Tenant.created_at.desc()).paginate(page=page, per_page=50)
except NotFound:
break
page += 1
for tenant in tenants:
try:
click.echo('Syncing tenant anthropic hosted provider: {}'.format(tenant.id))
ProviderService.create_system_provider(
tenant,
ProviderName.ANTHROPIC.value,
current_app.config['ANTHROPIC_HOSTED_QUOTA_LIMIT'],
True
)
count += 1
except Exception as e:
click.echo(click.style('Sync tenant anthropic hosted provider error: {} {}'.format(e.__class__.__name__, str(e)), fg='red'))
continue
click.echo(click.style('Congratulations! Synced {} anthropic hosted providers.'.format(count), fg='green'))
def register_commands(app): def register_commands(app):
app.cli.add_command(reset_password) app.cli.add_command(reset_password)
app.cli.add_command(reset_email) app.cli.add_command(reset_email)
app.cli.add_command(generate_invitation_codes) app.cli.add_command(generate_invitation_codes)
app.cli.add_command(reset_encrypt_key_pair) app.cli.add_command(reset_encrypt_key_pair)
app.cli.add_command(recreate_all_dataset_indexes) app.cli.add_command(recreate_all_dataset_indexes)
app.cli.add_command(sync_anthropic_hosted_providers)

View File

@ -51,6 +51,8 @@ DEFAULTS = {
'LOG_LEVEL': 'INFO', 'LOG_LEVEL': 'INFO',
'DISABLE_PROVIDER_CONFIG_VALIDATION': 'False', 'DISABLE_PROVIDER_CONFIG_VALIDATION': 'False',
'DEFAULT_LLM_PROVIDER': 'openai', 'DEFAULT_LLM_PROVIDER': 'openai',
'OPENAI_HOSTED_QUOTA_LIMIT': 200,
'ANTHROPIC_HOSTED_QUOTA_LIMIT': 1000,
'TENANT_DOCUMENT_COUNT': 100 'TENANT_DOCUMENT_COUNT': 100
} }
@ -192,6 +194,10 @@ class Config:
# hosted provider credentials # hosted provider credentials
self.OPENAI_API_KEY = get_env('OPENAI_API_KEY') self.OPENAI_API_KEY = get_env('OPENAI_API_KEY')
self.ANTHROPIC_API_KEY = get_env('ANTHROPIC_API_KEY')
self.OPENAI_HOSTED_QUOTA_LIMIT = get_env('OPENAI_HOSTED_QUOTA_LIMIT')
self.ANTHROPIC_HOSTED_QUOTA_LIMIT = get_env('ANTHROPIC_HOSTED_QUOTA_LIMIT')
# By default it is False # By default it is False
# You could disable it for compatibility with certain OpenAPI providers # You could disable it for compatibility with certain OpenAPI providers

View File

@ -50,8 +50,8 @@ class ChatMessageAudioApi(Resource):
raise UnsupportedAudioTypeError() raise UnsupportedAudioTypeError()
except ProviderNotSupportSpeechToTextServiceError: except ProviderNotSupportSpeechToTextServiceError:
raise ProviderNotSupportSpeechToTextError() raise ProviderNotSupportSpeechToTextError()
except ProviderTokenNotInitError: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError() raise ProviderNotInitializeError(ex.description)
except QuotaExceededError: except QuotaExceededError:
raise ProviderQuotaExceededError() raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:

View File

@ -63,8 +63,8 @@ class CompletionMessageApi(Resource):
except services.errors.app_model_config.AppModelConfigBrokenError: except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.") logging.exception("App model config broken.")
raise AppUnavailableError() raise AppUnavailableError()
except ProviderTokenNotInitError: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError() raise ProviderNotInitializeError(ex.description)
except QuotaExceededError: except QuotaExceededError:
raise ProviderQuotaExceededError() raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:
@ -133,8 +133,8 @@ class ChatMessageApi(Resource):
except services.errors.app_model_config.AppModelConfigBrokenError: except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.") logging.exception("App model config broken.")
raise AppUnavailableError() raise AppUnavailableError()
except ProviderTokenNotInitError: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError() raise ProviderNotInitializeError(ex.description)
except QuotaExceededError: except QuotaExceededError:
raise ProviderQuotaExceededError() raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:
@ -164,8 +164,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
except services.errors.app_model_config.AppModelConfigBrokenError: except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.") logging.exception("App model config broken.")
yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n" yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n"
except ProviderTokenNotInitError: except ProviderTokenNotInitError as ex:
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n" yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
except QuotaExceededError: except QuotaExceededError:
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n" yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:

View File

@ -16,7 +16,7 @@ class ProviderNotInitializeError(BaseHTTPException):
class ProviderQuotaExceededError(BaseHTTPException): class ProviderQuotaExceededError(BaseHTTPException):
error_code = 'provider_quota_exceeded' error_code = 'provider_quota_exceeded'
description = "Your quota for Dify Hosted OpenAI has been exhausted. " \ description = "Your quota for Dify Hosted Model Provider has been exhausted. " \
"Please go to Settings -> Model Provider to complete your own provider credentials." "Please go to Settings -> Model Provider to complete your own provider credentials."
code = 400 code = 400

View File

@ -27,8 +27,8 @@ class IntroductionGenerateApi(Resource):
account.current_tenant_id, account.current_tenant_id,
args['prompt_template'] args['prompt_template']
) )
except ProviderTokenNotInitError: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError() raise ProviderNotInitializeError(ex.description)
except QuotaExceededError: except QuotaExceededError:
raise ProviderQuotaExceededError() raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:
@ -58,8 +58,8 @@ class RuleGenerateApi(Resource):
args['audiences'], args['audiences'],
args['hoping_to_solve'] args['hoping_to_solve']
) )
except ProviderTokenNotInitError: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError() raise ProviderNotInitializeError(ex.description)
except QuotaExceededError: except QuotaExceededError:
raise ProviderQuotaExceededError() raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:

View File

@ -269,8 +269,8 @@ class MessageMoreLikeThisApi(Resource):
raise NotFound("Message Not Exists.") raise NotFound("Message Not Exists.")
except MoreLikeThisDisabledError: except MoreLikeThisDisabledError:
raise AppMoreLikeThisDisabledError() raise AppMoreLikeThisDisabledError()
except ProviderTokenNotInitError: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError() raise ProviderNotInitializeError(ex.description)
except QuotaExceededError: except QuotaExceededError:
raise ProviderQuotaExceededError() raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:
@ -297,8 +297,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
yield "data: " + json.dumps(api.handle_error(NotFound("Message Not Exists.")).get_json()) + "\n\n" yield "data: " + json.dumps(api.handle_error(NotFound("Message Not Exists.")).get_json()) + "\n\n"
except MoreLikeThisDisabledError: except MoreLikeThisDisabledError:
yield "data: " + json.dumps(api.handle_error(AppMoreLikeThisDisabledError()).get_json()) + "\n\n" yield "data: " + json.dumps(api.handle_error(AppMoreLikeThisDisabledError()).get_json()) + "\n\n"
except ProviderTokenNotInitError: except ProviderTokenNotInitError as ex:
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n" yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
except QuotaExceededError: except QuotaExceededError:
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n" yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:
@ -339,8 +339,8 @@ class MessageSuggestedQuestionApi(Resource):
raise NotFound("Message not found") raise NotFound("Message not found")
except ConversationNotExistsError: except ConversationNotExistsError:
raise NotFound("Conversation not found") raise NotFound("Conversation not found")
except ProviderTokenNotInitError: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError() raise ProviderNotInitializeError(ex.description)
except QuotaExceededError: except QuotaExceededError:
raise ProviderQuotaExceededError() raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:

View File

@ -279,8 +279,8 @@ class DatasetDocumentListApi(Resource):
try: try:
documents, batch = DocumentService.save_document_with_dataset_id(dataset, args, current_user) documents, batch = DocumentService.save_document_with_dataset_id(dataset, args, current_user)
except ProviderTokenNotInitError: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError() raise ProviderNotInitializeError(ex.description)
except QuotaExceededError: except QuotaExceededError:
raise ProviderQuotaExceededError() raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:
@ -324,8 +324,8 @@ class DatasetInitApi(Resource):
document_data=args, document_data=args,
account=current_user account=current_user
) )
except ProviderTokenNotInitError: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError() raise ProviderNotInitializeError(ex.description)
except QuotaExceededError: except QuotaExceededError:
raise ProviderQuotaExceededError() raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:

View File

@ -95,8 +95,8 @@ class HitTestingApi(Resource):
return {"query": response['query'], 'records': marshal(response['records'], hit_testing_record_fields)} return {"query": response['query'], 'records': marshal(response['records'], hit_testing_record_fields)}
except services.errors.index.IndexNotInitializedError: except services.errors.index.IndexNotInitializedError:
raise DatasetNotInitializedError() raise DatasetNotInitializedError()
except ProviderTokenNotInitError: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError() raise ProviderNotInitializeError(ex.description)
except QuotaExceededError: except QuotaExceededError:
raise ProviderQuotaExceededError() raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:

View File

@ -47,8 +47,8 @@ class ChatAudioApi(InstalledAppResource):
raise UnsupportedAudioTypeError() raise UnsupportedAudioTypeError()
except ProviderNotSupportSpeechToTextServiceError: except ProviderNotSupportSpeechToTextServiceError:
raise ProviderNotSupportSpeechToTextError() raise ProviderNotSupportSpeechToTextError()
except ProviderTokenNotInitError: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError() raise ProviderNotInitializeError(ex.description)
except QuotaExceededError: except QuotaExceededError:
raise ProviderQuotaExceededError() raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:

View File

@ -54,8 +54,8 @@ class CompletionApi(InstalledAppResource):
except services.errors.app_model_config.AppModelConfigBrokenError: except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.") logging.exception("App model config broken.")
raise AppUnavailableError() raise AppUnavailableError()
except ProviderTokenNotInitError: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError() raise ProviderNotInitializeError(ex.description)
except QuotaExceededError: except QuotaExceededError:
raise ProviderQuotaExceededError() raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:
@ -113,8 +113,8 @@ class ChatApi(InstalledAppResource):
except services.errors.app_model_config.AppModelConfigBrokenError: except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.") logging.exception("App model config broken.")
raise AppUnavailableError() raise AppUnavailableError()
except ProviderTokenNotInitError: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError() raise ProviderNotInitializeError(ex.description)
except QuotaExceededError: except QuotaExceededError:
raise ProviderQuotaExceededError() raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:
@ -155,8 +155,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
except services.errors.app_model_config.AppModelConfigBrokenError: except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.") logging.exception("App model config broken.")
yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n" yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n"
except ProviderTokenNotInitError: except ProviderTokenNotInitError as ex:
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n" yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
except QuotaExceededError: except QuotaExceededError:
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n" yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:

View File

@ -107,8 +107,8 @@ class MessageMoreLikeThisApi(InstalledAppResource):
raise NotFound("Message Not Exists.") raise NotFound("Message Not Exists.")
except MoreLikeThisDisabledError: except MoreLikeThisDisabledError:
raise AppMoreLikeThisDisabledError() raise AppMoreLikeThisDisabledError()
except ProviderTokenNotInitError: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError() raise ProviderNotInitializeError(ex.description)
except QuotaExceededError: except QuotaExceededError:
raise ProviderQuotaExceededError() raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:
@ -135,8 +135,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
yield "data: " + json.dumps(api.handle_error(NotFound("Message Not Exists.")).get_json()) + "\n\n" yield "data: " + json.dumps(api.handle_error(NotFound("Message Not Exists.")).get_json()) + "\n\n"
except MoreLikeThisDisabledError: except MoreLikeThisDisabledError:
yield "data: " + json.dumps(api.handle_error(AppMoreLikeThisDisabledError()).get_json()) + "\n\n" yield "data: " + json.dumps(api.handle_error(AppMoreLikeThisDisabledError()).get_json()) + "\n\n"
except ProviderTokenNotInitError: except ProviderTokenNotInitError as ex:
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n" yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
except QuotaExceededError: except QuotaExceededError:
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n" yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:
@ -174,8 +174,8 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
raise NotFound("Conversation not found") raise NotFound("Conversation not found")
except SuggestedQuestionsAfterAnswerDisabledError: except SuggestedQuestionsAfterAnswerDisabledError:
raise AppSuggestedQuestionsAfterAnswerDisabledError() raise AppSuggestedQuestionsAfterAnswerDisabledError()
except ProviderTokenNotInitError: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError() raise ProviderNotInitializeError(ex.description)
except QuotaExceededError: except QuotaExceededError:
raise ProviderQuotaExceededError() raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:

View File

@ -3,6 +3,7 @@ import base64
import json import json
import logging import logging
from flask import current_app
from flask_login import login_required, current_user from flask_login import login_required, current_user
from flask_restful import Resource, reqparse, abort from flask_restful import Resource, reqparse, abort
from werkzeug.exceptions import Forbidden from werkzeug.exceptions import Forbidden
@ -34,7 +35,7 @@ class ProviderListApi(Resource):
plaintext, the rest is replaced by * and the last two bits are displayed in plaintext plaintext, the rest is replaced by * and the last two bits are displayed in plaintext
""" """
ProviderService.init_supported_provider(current_user.current_tenant, "cloud") ProviderService.init_supported_provider(current_user.current_tenant)
providers = Provider.query.filter_by(tenant_id=tenant_id).all() providers = Provider.query.filter_by(tenant_id=tenant_id).all()
provider_list = [ provider_list = [
@ -50,7 +51,8 @@ class ProviderListApi(Resource):
'quota_used': p.quota_used 'quota_used': p.quota_used
} if p.provider_type == ProviderType.SYSTEM.value else {}), } if p.provider_type == ProviderType.SYSTEM.value else {}),
'token': ProviderService.get_obfuscated_api_key(current_user.current_tenant, 'token': ProviderService.get_obfuscated_api_key(current_user.current_tenant,
ProviderName(p.provider_name)) ProviderName(p.provider_name), only_custom=True)
if p.provider_type == ProviderType.CUSTOM.value else None
} }
for p in providers for p in providers
] ]
@ -121,9 +123,10 @@ class ProviderTokenApi(Resource):
is_valid=token_is_valid) is_valid=token_is_valid)
db.session.add(provider_model) db.session.add(provider_model)
if provider_model.is_valid: if provider in [ProviderName.OPENAI.value, ProviderName.AZURE_OPENAI.value] and provider_model.is_valid:
other_providers = db.session.query(Provider).filter( other_providers = db.session.query(Provider).filter(
Provider.tenant_id == tenant.id, Provider.tenant_id == tenant.id,
Provider.provider_name.in_([ProviderName.OPENAI.value, ProviderName.AZURE_OPENAI.value]),
Provider.provider_name != provider, Provider.provider_name != provider,
Provider.provider_type == ProviderType.CUSTOM.value Provider.provider_type == ProviderType.CUSTOM.value
).all() ).all()
@ -133,7 +136,7 @@ class ProviderTokenApi(Resource):
db.session.commit() db.session.commit()
if provider in [ProviderName.ANTHROPIC.value, ProviderName.AZURE_OPENAI.value, ProviderName.COHERE.value, if provider in [ProviderName.AZURE_OPENAI.value, ProviderName.COHERE.value,
ProviderName.HUGGINGFACEHUB.value]: ProviderName.HUGGINGFACEHUB.value]:
return {'result': 'success', 'warning': 'MOCK: This provider is not supported yet.'}, 201 return {'result': 'success', 'warning': 'MOCK: This provider is not supported yet.'}, 201
@ -157,7 +160,7 @@ class ProviderTokenValidateApi(Resource):
args = parser.parse_args() args = parser.parse_args()
# todo: remove this when the provider is supported # todo: remove this when the provider is supported
if provider in [ProviderName.ANTHROPIC.value, ProviderName.COHERE.value, if provider in [ProviderName.COHERE.value,
ProviderName.HUGGINGFACEHUB.value]: ProviderName.HUGGINGFACEHUB.value]:
return {'result': 'success', 'warning': 'MOCK: This provider is not supported yet.'} return {'result': 'success', 'warning': 'MOCK: This provider is not supported yet.'}
@ -203,7 +206,19 @@ class ProviderSystemApi(Resource):
provider_model.is_valid = args['is_enabled'] provider_model.is_valid = args['is_enabled']
db.session.commit() db.session.commit()
elif not provider_model: elif not provider_model:
ProviderService.create_system_provider(tenant, provider, args['is_enabled']) if provider == ProviderName.OPENAI.value:
quota_limit = current_app.config['OPENAI_HOSTED_QUOTA_LIMIT']
elif provider == ProviderName.ANTHROPIC.value:
quota_limit = current_app.config['ANTHROPIC_HOSTED_QUOTA_LIMIT']
else:
quota_limit = 0
ProviderService.create_system_provider(
tenant,
provider,
quota_limit,
args['is_enabled']
)
else: else:
abort(403) abort(403)

View File

@ -43,8 +43,8 @@ class AudioApi(AppApiResource):
raise UnsupportedAudioTypeError() raise UnsupportedAudioTypeError()
except ProviderNotSupportSpeechToTextServiceError: except ProviderNotSupportSpeechToTextServiceError:
raise ProviderNotSupportSpeechToTextError() raise ProviderNotSupportSpeechToTextError()
except ProviderTokenNotInitError: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError() raise ProviderNotInitializeError(ex.description)
except QuotaExceededError: except QuotaExceededError:
raise ProviderQuotaExceededError() raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:

View File

@ -54,8 +54,8 @@ class CompletionApi(AppApiResource):
except services.errors.app_model_config.AppModelConfigBrokenError: except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.") logging.exception("App model config broken.")
raise AppUnavailableError() raise AppUnavailableError()
except ProviderTokenNotInitError: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError() raise ProviderNotInitializeError(ex.description)
except QuotaExceededError: except QuotaExceededError:
raise ProviderQuotaExceededError() raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:
@ -115,8 +115,8 @@ class ChatApi(AppApiResource):
except services.errors.app_model_config.AppModelConfigBrokenError: except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.") logging.exception("App model config broken.")
raise AppUnavailableError() raise AppUnavailableError()
except ProviderTokenNotInitError: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError() raise ProviderNotInitializeError(ex.description)
except QuotaExceededError: except QuotaExceededError:
raise ProviderQuotaExceededError() raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:
@ -156,8 +156,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
except services.errors.app_model_config.AppModelConfigBrokenError: except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.") logging.exception("App model config broken.")
yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n" yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n"
except ProviderTokenNotInitError: except ProviderTokenNotInitError as ex:
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n" yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
except QuotaExceededError: except QuotaExceededError:
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n" yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:

View File

@ -85,8 +85,8 @@ class DocumentListApi(DatasetApiResource):
dataset_process_rule=dataset.latest_process_rule, dataset_process_rule=dataset.latest_process_rule,
created_from='api' created_from='api'
) )
except ProviderTokenNotInitError: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError() raise ProviderNotInitializeError(ex.description)
document = documents[0] document = documents[0]
if doc_type and doc_metadata: if doc_type and doc_metadata:
metadata_schema = DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type] metadata_schema = DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type]

View File

@ -45,8 +45,8 @@ class AudioApi(WebApiResource):
raise UnsupportedAudioTypeError() raise UnsupportedAudioTypeError()
except ProviderNotSupportSpeechToTextServiceError: except ProviderNotSupportSpeechToTextServiceError:
raise ProviderNotSupportSpeechToTextError() raise ProviderNotSupportSpeechToTextError()
except ProviderTokenNotInitError: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError() raise ProviderNotInitializeError(ex.description)
except QuotaExceededError: except QuotaExceededError:
raise ProviderQuotaExceededError() raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:

View File

@ -52,8 +52,8 @@ class CompletionApi(WebApiResource):
except services.errors.app_model_config.AppModelConfigBrokenError: except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.") logging.exception("App model config broken.")
raise AppUnavailableError() raise AppUnavailableError()
except ProviderTokenNotInitError: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError() raise ProviderNotInitializeError(ex.description)
except QuotaExceededError: except QuotaExceededError:
raise ProviderQuotaExceededError() raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:
@ -109,8 +109,8 @@ class ChatApi(WebApiResource):
except services.errors.app_model_config.AppModelConfigBrokenError: except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.") logging.exception("App model config broken.")
raise AppUnavailableError() raise AppUnavailableError()
except ProviderTokenNotInitError: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError() raise ProviderNotInitializeError(ex.description)
except QuotaExceededError: except QuotaExceededError:
raise ProviderQuotaExceededError() raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:
@ -150,8 +150,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
except services.errors.app_model_config.AppModelConfigBrokenError: except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.") logging.exception("App model config broken.")
yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n" yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n"
except ProviderTokenNotInitError: except ProviderTokenNotInitError as ex:
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n" yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
except QuotaExceededError: except QuotaExceededError:
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n" yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:

View File

@ -101,8 +101,8 @@ class MessageMoreLikeThisApi(WebApiResource):
raise NotFound("Message Not Exists.") raise NotFound("Message Not Exists.")
except MoreLikeThisDisabledError: except MoreLikeThisDisabledError:
raise AppMoreLikeThisDisabledError() raise AppMoreLikeThisDisabledError()
except ProviderTokenNotInitError: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError() raise ProviderNotInitializeError(ex.description)
except QuotaExceededError: except QuotaExceededError:
raise ProviderQuotaExceededError() raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:
@ -129,8 +129,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
yield "data: " + json.dumps(api.handle_error(NotFound("Message Not Exists.")).get_json()) + "\n\n" yield "data: " + json.dumps(api.handle_error(NotFound("Message Not Exists.")).get_json()) + "\n\n"
except MoreLikeThisDisabledError: except MoreLikeThisDisabledError:
yield "data: " + json.dumps(api.handle_error(AppMoreLikeThisDisabledError()).get_json()) + "\n\n" yield "data: " + json.dumps(api.handle_error(AppMoreLikeThisDisabledError()).get_json()) + "\n\n"
except ProviderTokenNotInitError: except ProviderTokenNotInitError as ex:
yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n" yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
except QuotaExceededError: except QuotaExceededError:
yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n" yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:
@ -167,8 +167,8 @@ class MessageSuggestedQuestionApi(WebApiResource):
raise NotFound("Conversation not found") raise NotFound("Conversation not found")
except SuggestedQuestionsAfterAnswerDisabledError: except SuggestedQuestionsAfterAnswerDisabledError:
raise AppSuggestedQuestionsAfterAnswerDisabledError() raise AppSuggestedQuestionsAfterAnswerDisabledError()
except ProviderTokenNotInitError: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError() raise ProviderNotInitializeError(ex.description)
except QuotaExceededError: except QuotaExceededError:
raise ProviderQuotaExceededError() raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:

View File

@ -13,8 +13,13 @@ class HostedOpenAICredential(BaseModel):
api_key: str api_key: str
class HostedAnthropicCredential(BaseModel):
api_key: str
class HostedLLMCredentials(BaseModel): class HostedLLMCredentials(BaseModel):
openai: Optional[HostedOpenAICredential] = None openai: Optional[HostedOpenAICredential] = None
anthropic: Optional[HostedAnthropicCredential] = None
hosted_llm_credentials = HostedLLMCredentials() hosted_llm_credentials = HostedLLMCredentials()
@ -26,3 +31,6 @@ def init_app(app: Flask):
if app.config.get("OPENAI_API_KEY"): if app.config.get("OPENAI_API_KEY"):
hosted_llm_credentials.openai = HostedOpenAICredential(api_key=app.config.get("OPENAI_API_KEY")) hosted_llm_credentials.openai = HostedOpenAICredential(api_key=app.config.get("OPENAI_API_KEY"))
if app.config.get("ANTHROPIC_API_KEY"):
hosted_llm_credentials.anthropic = HostedAnthropicCredential(api_key=app.config.get("ANTHROPIC_API_KEY"))

View File

@ -48,7 +48,7 @@ class LLMCallbackHandler(BaseCallbackHandler):
}) })
self.llm_message.prompt = real_prompts self.llm_message.prompt = real_prompts
self.llm_message.prompt_tokens = self.llm.get_messages_tokens(messages[0]) self.llm_message.prompt_tokens = self.llm.get_num_tokens_from_messages(messages[0])
def on_llm_start( def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any

View File

@ -118,6 +118,7 @@ class Completion:
prompt, stop_words = cls.get_main_llm_prompt( prompt, stop_words = cls.get_main_llm_prompt(
mode=mode, mode=mode,
llm=final_llm, llm=final_llm,
model=app_model_config.model_dict,
pre_prompt=app_model_config.pre_prompt, pre_prompt=app_model_config.pre_prompt,
query=query, query=query,
inputs=inputs, inputs=inputs,
@ -129,6 +130,7 @@ class Completion:
cls.recale_llm_max_tokens( cls.recale_llm_max_tokens(
final_llm=final_llm, final_llm=final_llm,
model=app_model_config.model_dict,
prompt=prompt, prompt=prompt,
mode=mode mode=mode
) )
@ -138,7 +140,8 @@ class Completion:
return response return response
@classmethod @classmethod
def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, pre_prompt: str, query: str, inputs: dict, def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, model: dict,
pre_prompt: str, query: str, inputs: dict,
chain_output: Optional[str], chain_output: Optional[str],
memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \ memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \
Tuple[Union[str | List[BaseMessage]], Optional[List[str]]]: Tuple[Union[str | List[BaseMessage]], Optional[List[str]]]:
@ -151,10 +154,11 @@ class Completion:
if mode == 'completion': if mode == 'completion':
prompt_template = JinjaPromptTemplate.from_template( prompt_template = JinjaPromptTemplate.from_template(
template=("""Use the following CONTEXT as your learned knowledge: template=("""Use the following context as your learned knowledge, inside <context></context> XML tags.
[CONTEXT]
<context>
{{context}} {{context}}
[END CONTEXT] </context>
When answer to user: When answer to user:
- If you don't know, just say that you don't know. - If you don't know, just say that you don't know.
@ -204,10 +208,11 @@ And answer according to the language of the user's question.
if chain_output: if chain_output:
human_inputs['context'] = chain_output human_inputs['context'] = chain_output
human_message_prompt += """Use the following CONTEXT as your learned knowledge. human_message_prompt += """Use the following context as your learned knowledge, inside <context></context> XML tags.
[CONTEXT]
<context>
{{context}} {{context}}
[END CONTEXT] </context>
When answer to user: When answer to user:
- If you don't know, just say that you don't know. - If you don't know, just say that you don't know.
@ -219,7 +224,7 @@ And answer according to the language of the user's question.
if pre_prompt: if pre_prompt:
human_message_prompt += pre_prompt human_message_prompt += pre_prompt
query_prompt = "\nHuman: {{query}}\nAI: " query_prompt = "\n\nHuman: {{query}}\n\nAssistant: "
if memory: if memory:
# append chat histories # append chat histories
@ -228,9 +233,11 @@ And answer according to the language of the user's question.
inputs=human_inputs inputs=human_inputs
) )
curr_message_tokens = memory.llm.get_messages_tokens([tmp_human_message]) curr_message_tokens = memory.llm.get_num_tokens_from_messages([tmp_human_message])
rest_tokens = llm_constant.max_context_token_length[memory.llm.model_name] \ model_name = model['name']
- memory.llm.max_tokens - curr_message_tokens max_tokens = model.get("completion_params").get('max_tokens')
rest_tokens = llm_constant.max_context_token_length[model_name] \
- max_tokens - curr_message_tokens
rest_tokens = max(rest_tokens, 0) rest_tokens = max(rest_tokens, 0)
histories = cls.get_history_messages_from_memory(memory, rest_tokens) histories = cls.get_history_messages_from_memory(memory, rest_tokens)
@ -241,7 +248,10 @@ And answer according to the language of the user's question.
# if histories_param not in human_inputs: # if histories_param not in human_inputs:
# human_inputs[histories_param] = '{{' + histories_param + '}}' # human_inputs[histories_param] = '{{' + histories_param + '}}'
human_message_prompt += "\n\n" + histories human_message_prompt += "\n\n" if human_message_prompt else ""
human_message_prompt += "Here is the chat histories between human and assistant, " \
"inside <histories></histories> XML tags.\n\n<histories>"
human_message_prompt += histories + "</histories>"
human_message_prompt += query_prompt human_message_prompt += query_prompt
@ -307,13 +317,15 @@ And answer according to the language of the user's question.
model=app_model_config.model_dict model=app_model_config.model_dict
) )
model_limited_tokens = llm_constant.max_context_token_length[llm.model_name] model_name = app_model_config.model_dict.get("name")
max_tokens = llm.max_tokens model_limited_tokens = llm_constant.max_context_token_length[model_name]
max_tokens = app_model_config.model_dict.get("completion_params").get('max_tokens')
# get prompt without memory and context # get prompt without memory and context
prompt, _ = cls.get_main_llm_prompt( prompt, _ = cls.get_main_llm_prompt(
mode=mode, mode=mode,
llm=llm, llm=llm,
model=app_model_config.model_dict,
pre_prompt=app_model_config.pre_prompt, pre_prompt=app_model_config.pre_prompt,
query=query, query=query,
inputs=inputs, inputs=inputs,
@ -332,16 +344,17 @@ And answer according to the language of the user's question.
return rest_tokens return rest_tokens
@classmethod @classmethod
def recale_llm_max_tokens(cls, final_llm: Union[StreamableOpenAI, StreamableChatOpenAI], def recale_llm_max_tokens(cls, final_llm: BaseLanguageModel, model: dict,
prompt: Union[str, List[BaseMessage]], mode: str): prompt: Union[str, List[BaseMessage]], mode: str):
# recalc max_tokens if sum(prompt_token + max_tokens) over model token limit # recalc max_tokens if sum(prompt_token + max_tokens) over model token limit
model_limited_tokens = llm_constant.max_context_token_length[final_llm.model_name] model_name = model.get("name")
max_tokens = final_llm.max_tokens model_limited_tokens = llm_constant.max_context_token_length[model_name]
max_tokens = model.get("completion_params").get('max_tokens')
if mode == 'completion' and isinstance(final_llm, BaseLLM): if mode == 'completion' and isinstance(final_llm, BaseLLM):
prompt_tokens = final_llm.get_num_tokens(prompt) prompt_tokens = final_llm.get_num_tokens(prompt)
else: else:
prompt_tokens = final_llm.get_messages_tokens(prompt) prompt_tokens = final_llm.get_num_tokens_from_messages(prompt)
if prompt_tokens + max_tokens > model_limited_tokens: if prompt_tokens + max_tokens > model_limited_tokens:
max_tokens = max(model_limited_tokens - prompt_tokens, 16) max_tokens = max(model_limited_tokens - prompt_tokens, 16)
@ -350,9 +363,10 @@ And answer according to the language of the user's question.
@classmethod @classmethod
def generate_more_like_this(cls, task_id: str, app: App, message: Message, pre_prompt: str, def generate_more_like_this(cls, task_id: str, app: App, message: Message, pre_prompt: str,
app_model_config: AppModelConfig, user: Account, streaming: bool): app_model_config: AppModelConfig, user: Account, streaming: bool):
llm: StreamableOpenAI = LLMBuilder.to_llm(
llm = LLMBuilder.to_llm_from_model(
tenant_id=app.tenant_id, tenant_id=app.tenant_id,
model_name='gpt-3.5-turbo', model=app_model_config.model_dict,
streaming=streaming streaming=streaming
) )
@ -360,6 +374,7 @@ And answer according to the language of the user's question.
original_prompt, _ = cls.get_main_llm_prompt( original_prompt, _ = cls.get_main_llm_prompt(
mode="completion", mode="completion",
llm=llm, llm=llm,
model=app_model_config.model_dict,
pre_prompt=pre_prompt, pre_prompt=pre_prompt,
query=message.query, query=message.query,
inputs=message.inputs, inputs=message.inputs,
@ -390,6 +405,7 @@ And answer according to the language of the user's question.
cls.recale_llm_max_tokens( cls.recale_llm_max_tokens(
final_llm=llm, final_llm=llm,
model=app_model_config.model_dict,
prompt=prompt, prompt=prompt,
mode='completion' mode='completion'
) )

View File

@ -1,6 +1,8 @@
from _decimal import Decimal from _decimal import Decimal
models = { models = {
'claude-instant-1': 'anthropic', # 100,000 tokens
'claude-2': 'anthropic', # 100,000 tokens
'gpt-4': 'openai', # 8,192 tokens 'gpt-4': 'openai', # 8,192 tokens
'gpt-4-32k': 'openai', # 32,768 tokens 'gpt-4-32k': 'openai', # 32,768 tokens
'gpt-3.5-turbo': 'openai', # 4,096 tokens 'gpt-3.5-turbo': 'openai', # 4,096 tokens
@ -10,10 +12,13 @@ models = {
'text-curie-001': 'openai', # 2,049 tokens 'text-curie-001': 'openai', # 2,049 tokens
'text-babbage-001': 'openai', # 2,049 tokens 'text-babbage-001': 'openai', # 2,049 tokens
'text-ada-001': 'openai', # 2,049 tokens 'text-ada-001': 'openai', # 2,049 tokens
'text-embedding-ada-002': 'openai' # 8191 tokens, 1536 dimensions 'text-embedding-ada-002': 'openai', # 8191 tokens, 1536 dimensions
'whisper-1': 'openai'
} }
max_context_token_length = { max_context_token_length = {
'claude-instant-1': 100000,
'claude-2': 100000,
'gpt-4': 8192, 'gpt-4': 8192,
'gpt-4-32k': 32768, 'gpt-4-32k': 32768,
'gpt-3.5-turbo': 4096, 'gpt-3.5-turbo': 4096,
@ -23,17 +28,21 @@ max_context_token_length = {
'text-curie-001': 2049, 'text-curie-001': 2049,
'text-babbage-001': 2049, 'text-babbage-001': 2049,
'text-ada-001': 2049, 'text-ada-001': 2049,
'text-embedding-ada-002': 8191 'text-embedding-ada-002': 8191,
} }
models_by_mode = { models_by_mode = {
'chat': [ 'chat': [
'claude-instant-1', # 100,000 tokens
'claude-2', # 100,000 tokens
'gpt-4', # 8,192 tokens 'gpt-4', # 8,192 tokens
'gpt-4-32k', # 32,768 tokens 'gpt-4-32k', # 32,768 tokens
'gpt-3.5-turbo', # 4,096 tokens 'gpt-3.5-turbo', # 4,096 tokens
'gpt-3.5-turbo-16k', # 16,384 tokens 'gpt-3.5-turbo-16k', # 16,384 tokens
], ],
'completion': [ 'completion': [
'claude-instant-1', # 100,000 tokens
'claude-2', # 100,000 tokens
'gpt-4', # 8,192 tokens 'gpt-4', # 8,192 tokens
'gpt-4-32k', # 32,768 tokens 'gpt-4-32k', # 32,768 tokens
'gpt-3.5-turbo', # 4,096 tokens 'gpt-3.5-turbo', # 4,096 tokens
@ -52,6 +61,14 @@ models_by_mode = {
model_currency = 'USD' model_currency = 'USD'
model_prices = { model_prices = {
'claude-instant-1': {
'prompt': Decimal('0.00163'),
'completion': Decimal('0.00551'),
},
'claude-2': {
'prompt': Decimal('0.01102'),
'completion': Decimal('0.03268'),
},
'gpt-4': { 'gpt-4': {
'prompt': Decimal('0.03'), 'prompt': Decimal('0.03'),
'completion': Decimal('0.06'), 'completion': Decimal('0.06'),

View File

@ -56,7 +56,7 @@ class ConversationMessageTask:
) )
def init(self): def init(self):
provider_name = LLMBuilder.get_default_provider(self.app.tenant_id) provider_name = LLMBuilder.get_default_provider(self.app.tenant_id, self.model_name)
self.model_dict['provider'] = provider_name self.model_dict['provider'] = provider_name
override_model_configs = None override_model_configs = None
@ -89,7 +89,7 @@ class ConversationMessageTask:
system_message = PromptBuilder.to_system_message(self.app_model_config.pre_prompt, self.inputs) system_message = PromptBuilder.to_system_message(self.app_model_config.pre_prompt, self.inputs)
system_instruction = system_message.content system_instruction = system_message.content
llm = LLMBuilder.to_llm(self.tenant_id, self.model_name) llm = LLMBuilder.to_llm(self.tenant_id, self.model_name)
system_instruction_tokens = llm.get_messages_tokens([system_message]) system_instruction_tokens = llm.get_num_tokens_from_messages([system_message])
if not self.conversation: if not self.conversation:
self.is_new_conversation = True self.is_new_conversation = True
@ -185,6 +185,7 @@ class ConversationMessageTask:
if provider and provider.provider_type == ProviderType.SYSTEM.value: if provider and provider.provider_type == ProviderType.SYSTEM.value:
db.session.query(Provider).filter( db.session.query(Provider).filter(
Provider.tenant_id == self.app.tenant_id, Provider.tenant_id == self.app.tenant_id,
Provider.provider_name == provider.provider_name,
Provider.quota_limit > Provider.quota_used Provider.quota_limit > Provider.quota_used
).update({'quota_used': Provider.quota_used + 1}) ).update({'quota_used': Provider.quota_used + 1})

View File

@ -4,6 +4,7 @@ from typing import List
from langchain.embeddings.base import Embeddings from langchain.embeddings.base import Embeddings
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
from extensions.ext_database import db from extensions.ext_database import db
from libs import helper from libs import helper
from models.dataset import Embedding from models.dataset import Embedding
@ -49,6 +50,7 @@ class CacheEmbedding(Embeddings):
text_embeddings.extend(embedding_results) text_embeddings.extend(embedding_results)
return text_embeddings return text_embeddings
@handle_openai_exceptions
def embed_query(self, text: str) -> List[float]: def embed_query(self, text: str) -> List[float]:
"""Embed query text.""" """Embed query text."""
# use doc embedding cache or store if not exists # use doc embedding cache or store if not exists

View File

@ -23,6 +23,10 @@ class LLMGenerator:
@classmethod @classmethod
def generate_conversation_name(cls, tenant_id: str, query, answer): def generate_conversation_name(cls, tenant_id: str, query, answer):
prompt = CONVERSATION_TITLE_PROMPT prompt = CONVERSATION_TITLE_PROMPT
if len(query) > 2000:
query = query[:300] + "...[TRUNCATED]..." + query[-300:]
prompt = prompt.format(query=query) prompt = prompt.format(query=query)
llm: StreamableOpenAI = LLMBuilder.to_llm( llm: StreamableOpenAI = LLMBuilder.to_llm(
tenant_id=tenant_id, tenant_id=tenant_id,
@ -52,7 +56,17 @@ class LLMGenerator:
if not message.answer: if not message.answer:
continue continue
message_qa_text = "Human:" + message.query + "\nAI:" + message.answer + "\n" if len(message.query) > 2000:
query = message.query[:300] + "...[TRUNCATED]..." + message.query[-300:]
else:
query = message.query
if len(message.answer) > 2000:
answer = message.answer[:300] + "...[TRUNCATED]..." + message.answer[-300:]
else:
answer = message.answer
message_qa_text = "\n\nHuman:" + query + "\n\nAssistant:" + answer
if rest_tokens - TokenCalculator.get_num_tokens(model, context + message_qa_text) > 0: if rest_tokens - TokenCalculator.get_num_tokens(model, context + message_qa_text) > 0:
context += message_qa_text context += message_qa_text

View File

@ -17,7 +17,7 @@ class IndexBuilder:
model_credentials = LLMBuilder.get_model_credentials( model_credentials = LLMBuilder.get_model_credentials(
tenant_id=dataset.tenant_id, tenant_id=dataset.tenant_id,
model_provider=LLMBuilder.get_default_provider(dataset.tenant_id), model_provider=LLMBuilder.get_default_provider(dataset.tenant_id, 'text-embedding-ada-002'),
model_name='text-embedding-ada-002' model_name='text-embedding-ada-002'
) )

View File

@ -40,6 +40,9 @@ class ProviderTokenNotInitError(Exception):
""" """
description = "Provider Token Not Init" description = "Provider Token Not Init"
def __init__(self, *args, **kwargs):
self.description = args[0] if args else self.description
class QuotaExceededError(Exception): class QuotaExceededError(Exception):
""" """

View File

@ -8,9 +8,10 @@ from core.llm.provider.base import BaseProvider
from core.llm.provider.llm_provider_service import LLMProviderService from core.llm.provider.llm_provider_service import LLMProviderService
from core.llm.streamable_azure_chat_open_ai import StreamableAzureChatOpenAI from core.llm.streamable_azure_chat_open_ai import StreamableAzureChatOpenAI
from core.llm.streamable_azure_open_ai import StreamableAzureOpenAI from core.llm.streamable_azure_open_ai import StreamableAzureOpenAI
from core.llm.streamable_chat_anthropic import StreamableChatAnthropic
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
from core.llm.streamable_open_ai import StreamableOpenAI from core.llm.streamable_open_ai import StreamableOpenAI
from models.provider import ProviderType from models.provider import ProviderType, ProviderName
class LLMBuilder: class LLMBuilder:
@ -32,43 +33,43 @@ class LLMBuilder:
@classmethod @classmethod
def to_llm(cls, tenant_id: str, model_name: str, **kwargs) -> Union[StreamableOpenAI, StreamableChatOpenAI]: def to_llm(cls, tenant_id: str, model_name: str, **kwargs) -> Union[StreamableOpenAI, StreamableChatOpenAI]:
provider = cls.get_default_provider(tenant_id) provider = cls.get_default_provider(tenant_id, model_name)
model_credentials = cls.get_model_credentials(tenant_id, provider, model_name) model_credentials = cls.get_model_credentials(tenant_id, provider, model_name)
llm_cls = None
mode = cls.get_mode_by_model(model_name) mode = cls.get_mode_by_model(model_name)
if mode == 'chat': if mode == 'chat':
if provider == 'openai': if provider == ProviderName.OPENAI.value:
llm_cls = StreamableChatOpenAI llm_cls = StreamableChatOpenAI
else: elif provider == ProviderName.AZURE_OPENAI.value:
llm_cls = StreamableAzureChatOpenAI llm_cls = StreamableAzureChatOpenAI
elif provider == ProviderName.ANTHROPIC.value:
llm_cls = StreamableChatAnthropic
elif mode == 'completion': elif mode == 'completion':
if provider == 'openai': if provider == ProviderName.OPENAI.value:
llm_cls = StreamableOpenAI llm_cls = StreamableOpenAI
else: elif provider == ProviderName.AZURE_OPENAI.value:
llm_cls = StreamableAzureOpenAI llm_cls = StreamableAzureOpenAI
else:
if not llm_cls:
raise ValueError(f"model name {model_name} is not supported.") raise ValueError(f"model name {model_name} is not supported.")
model_kwargs = { model_kwargs = {
'model_name': model_name,
'temperature': kwargs.get('temperature', 0),
'max_tokens': kwargs.get('max_tokens', 256),
'top_p': kwargs.get('top_p', 1), 'top_p': kwargs.get('top_p', 1),
'frequency_penalty': kwargs.get('frequency_penalty', 0), 'frequency_penalty': kwargs.get('frequency_penalty', 0),
'presence_penalty': kwargs.get('presence_penalty', 0), 'presence_penalty': kwargs.get('presence_penalty', 0),
'callbacks': kwargs.get('callbacks', None),
'streaming': kwargs.get('streaming', False),
} }
model_extras_kwargs = model_kwargs if mode == 'completion' else {'model_kwargs': model_kwargs} model_kwargs.update(model_credentials)
model_kwargs = llm_cls.get_kwargs_from_model_params(model_kwargs)
return llm_cls( return llm_cls(**model_kwargs)
model_name=model_name,
temperature=kwargs.get('temperature', 0),
max_tokens=kwargs.get('max_tokens', 256),
**model_extras_kwargs,
callbacks=kwargs.get('callbacks', None),
streaming=kwargs.get('streaming', False),
# request_timeout=None
**model_credentials
)
@classmethod @classmethod
def to_llm_from_model(cls, tenant_id: str, model: dict, streaming: bool = False, def to_llm_from_model(cls, tenant_id: str, model: dict, streaming: bool = False,
@ -118,14 +119,29 @@ class LLMBuilder:
return provider_service.get_credentials(model_name) return provider_service.get_credentials(model_name)
@classmethod @classmethod
def get_default_provider(cls, tenant_id: str) -> str: def get_default_provider(cls, tenant_id: str, model_name: str) -> str:
provider = BaseProvider.get_valid_provider(tenant_id) provider_name = llm_constant.models[model_name]
if not provider:
raise ProviderTokenNotInitError()
if provider.provider_type == ProviderType.SYSTEM.value: if provider_name == 'openai':
provider_name = 'openai' # get the default provider (openai / azure_openai) for the tenant
else: openai_provider = BaseProvider.get_valid_provider(tenant_id, ProviderName.OPENAI.value)
provider_name = provider.provider_name azure_openai_provider = BaseProvider.get_valid_provider(tenant_id, ProviderName.AZURE_OPENAI.value)
provider = None
if openai_provider:
provider = openai_provider
elif azure_openai_provider:
provider = azure_openai_provider
if not provider:
raise ProviderTokenNotInitError(
f"No valid {provider_name} model provider credentials found. "
f"Please go to Settings -> Model Provider to complete your provider credentials."
)
if provider.provider_type == ProviderType.SYSTEM.value:
provider_name = 'openai'
else:
provider_name = provider.provider_name
return provider_name return provider_name

View File

@ -1,23 +1,138 @@
from typing import Optional import json
import logging
from typing import Optional, Union
import anthropic
from langchain.chat_models import ChatAnthropic
from langchain.schema import HumanMessage
from core import hosted_llm_credentials
from core.llm.error import ProviderTokenNotInitError
from core.llm.provider.base import BaseProvider from core.llm.provider.base import BaseProvider
from models.provider import ProviderName from core.llm.provider.errors import ValidateFailedError
from models.provider import ProviderName, ProviderType
class AnthropicProvider(BaseProvider): class AnthropicProvider(BaseProvider):
def get_models(self, model_id: Optional[str] = None) -> list[dict]: def get_models(self, model_id: Optional[str] = None) -> list[dict]:
credentials = self.get_credentials(model_id) return [
# todo {
return [] 'id': 'claude-instant-1',
'name': 'claude-instant-1',
},
{
'id': 'claude-2',
'name': 'claude-2',
},
]
def get_credentials(self, model_id: Optional[str] = None) -> dict: def get_credentials(self, model_id: Optional[str] = None) -> dict:
""" return self.get_provider_api_key(model_id=model_id)
Returns the API credentials for Azure OpenAI as a dictionary, for the given tenant_id.
The dictionary contains keys: azure_api_type, azure_api_version, azure_api_base, and azure_api_key.
"""
return {
'anthropic_api_key': self.get_provider_api_key(model_id=model_id)
}
def get_provider_name(self): def get_provider_name(self):
return ProviderName.ANTHROPIC return ProviderName.ANTHROPIC
def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
"""
Returns the provider configs.
"""
try:
config = self.get_provider_api_key(only_custom=only_custom)
except:
config = {
'anthropic_api_key': ''
}
if obfuscated:
if not config.get('anthropic_api_key'):
config = {
'anthropic_api_key': ''
}
config['anthropic_api_key'] = self.obfuscated_token(config.get('anthropic_api_key'))
return config
return config
def get_encrypted_token(self, config: Union[dict | str]):
"""
Returns the encrypted token.
"""
return json.dumps({
'anthropic_api_key': self.encrypt_token(config['anthropic_api_key'])
})
def get_decrypted_token(self, token: str):
"""
Returns the decrypted token.
"""
config = json.loads(token)
config['anthropic_api_key'] = self.decrypt_token(config['anthropic_api_key'])
return config
def get_token_type(self):
return dict
def config_validate(self, config: Union[dict | str]):
"""
Validates the given config.
"""
# check OpenAI / Azure OpenAI credential is valid
openai_provider = BaseProvider.get_valid_provider(self.tenant_id, ProviderName.OPENAI.value)
azure_openai_provider = BaseProvider.get_valid_provider(self.tenant_id, ProviderName.AZURE_OPENAI.value)
provider = None
if openai_provider:
provider = openai_provider
elif azure_openai_provider:
provider = azure_openai_provider
if not provider:
raise ValidateFailedError(f"OpenAI or Azure OpenAI provider must be configured first.")
if provider.provider_type == ProviderType.SYSTEM.value:
quota_used = provider.quota_used if provider.quota_used is not None else 0
quota_limit = provider.quota_limit if provider.quota_limit is not None else 0
if quota_used >= quota_limit:
raise ValidateFailedError(f"Your quota for Dify Hosted OpenAI has been exhausted, "
f"please configure OpenAI or Azure OpenAI provider first.")
try:
if not isinstance(config, dict):
raise ValueError('Config must be a object.')
if 'anthropic_api_key' not in config:
raise ValueError('anthropic_api_key must be provided.')
chat_llm = ChatAnthropic(
model='claude-instant-1',
anthropic_api_key=config['anthropic_api_key'],
max_tokens_to_sample=10,
temperature=0,
default_request_timeout=60
)
messages = [
HumanMessage(
content="ping"
)
]
chat_llm(messages)
except anthropic.APIConnectionError as ex:
raise ValidateFailedError(f"Anthropic: Connection error, cause: {ex.__cause__}")
except (anthropic.APIStatusError, anthropic.RateLimitError) as ex:
raise ValidateFailedError(f"Anthropic: Error code: {ex.status_code} - "
f"{ex.body['error']['type']}: {ex.body['error']['message']}")
except Exception as ex:
logging.exception('Anthropic config validation failed')
raise ex
def get_hosted_credentials(self) -> Union[str | dict]:
if not hosted_llm_credentials.anthropic or not hosted_llm_credentials.anthropic.api_key:
raise ProviderTokenNotInitError(
f"No valid {self.get_provider_name().value} model provider credentials found. "
f"Please go to Settings -> Model Provider to complete your provider credentials."
)
return {'anthropic_api_key': hosted_llm_credentials.anthropic.api_key}

View File

@ -52,12 +52,12 @@ class AzureProvider(BaseProvider):
def get_provider_name(self): def get_provider_name(self):
return ProviderName.AZURE_OPENAI return ProviderName.AZURE_OPENAI
def get_provider_configs(self, obfuscated: bool = False) -> Union[str | dict]: def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
""" """
Returns the provider configs. Returns the provider configs.
""" """
try: try:
config = self.get_provider_api_key() config = self.get_provider_api_key(only_custom=only_custom)
except: except:
config = { config = {
'openai_api_type': 'azure', 'openai_api_type': 'azure',
@ -81,7 +81,6 @@ class AzureProvider(BaseProvider):
return config return config
def get_token_type(self): def get_token_type(self):
# TODO: change to dict when implemented
return dict return dict
def config_validate(self, config: Union[dict | str]): def config_validate(self, config: Union[dict | str]):

View File

@ -2,7 +2,7 @@ import base64
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Optional, Union from typing import Optional, Union
from core import hosted_llm_credentials from core.constant import llm_constant
from core.llm.error import QuotaExceededError, ModelCurrentlyNotSupportError, ProviderTokenNotInitError from core.llm.error import QuotaExceededError, ModelCurrentlyNotSupportError, ProviderTokenNotInitError
from extensions.ext_database import db from extensions.ext_database import db
from libs import rsa from libs import rsa
@ -14,15 +14,18 @@ class BaseProvider(ABC):
def __init__(self, tenant_id: str): def __init__(self, tenant_id: str):
self.tenant_id = tenant_id self.tenant_id = tenant_id
def get_provider_api_key(self, model_id: Optional[str] = None, prefer_custom: bool = True) -> Union[str | dict]: def get_provider_api_key(self, model_id: Optional[str] = None, only_custom: bool = False) -> Union[str | dict]:
""" """
Returns the decrypted API key for the given tenant_id and provider_name. Returns the decrypted API key for the given tenant_id and provider_name.
If the provider is of type SYSTEM and the quota is exceeded, raises a QuotaExceededError. If the provider is of type SYSTEM and the quota is exceeded, raises a QuotaExceededError.
If the provider is not found or not valid, raises a ProviderTokenNotInitError. If the provider is not found or not valid, raises a ProviderTokenNotInitError.
""" """
provider = self.get_provider(prefer_custom) provider = self.get_provider(only_custom)
if not provider: if not provider:
raise ProviderTokenNotInitError() raise ProviderTokenNotInitError(
f"No valid {llm_constant.models[model_id]} model provider credentials found. "
f"Please go to Settings -> Model Provider to complete your provider credentials."
)
if provider.provider_type == ProviderType.SYSTEM.value: if provider.provider_type == ProviderType.SYSTEM.value:
quota_used = provider.quota_used if provider.quota_used is not None else 0 quota_used = provider.quota_used if provider.quota_used is not None else 0
@ -38,18 +41,19 @@ class BaseProvider(ABC):
else: else:
return self.get_decrypted_token(provider.encrypted_config) return self.get_decrypted_token(provider.encrypted_config)
def get_provider(self, prefer_custom: bool) -> Optional[Provider]: def get_provider(self, only_custom: bool = False) -> Optional[Provider]:
""" """
Returns the Provider instance for the given tenant_id and provider_name. Returns the Provider instance for the given tenant_id and provider_name.
If both CUSTOM and System providers exist, the preferred provider will be returned based on the prefer_custom flag. If both CUSTOM and System providers exist, the preferred provider will be returned based on the prefer_custom flag.
""" """
return BaseProvider.get_valid_provider(self.tenant_id, self.get_provider_name().value, prefer_custom) return BaseProvider.get_valid_provider(self.tenant_id, self.get_provider_name().value, only_custom)
@classmethod @classmethod
def get_valid_provider(cls, tenant_id: str, provider_name: str = None, prefer_custom: bool = False) -> Optional[Provider]: def get_valid_provider(cls, tenant_id: str, provider_name: str = None, only_custom: bool = False) -> Optional[
Provider]:
""" """
Returns the Provider instance for the given tenant_id and provider_name. Returns the Provider instance for the given tenant_id and provider_name.
If both CUSTOM and System providers exist, the preferred provider will be returned based on the prefer_custom flag. If both CUSTOM and System providers exist.
""" """
query = db.session.query(Provider).filter( query = db.session.query(Provider).filter(
Provider.tenant_id == tenant_id Provider.tenant_id == tenant_id
@ -58,39 +62,31 @@ class BaseProvider(ABC):
if provider_name: if provider_name:
query = query.filter(Provider.provider_name == provider_name) query = query.filter(Provider.provider_name == provider_name)
providers = query.order_by(Provider.provider_type.desc() if prefer_custom else Provider.provider_type).all() if only_custom:
query = query.filter(Provider.provider_type == ProviderType.CUSTOM.value)
custom_provider = None providers = query.order_by(Provider.provider_type.asc()).all()
system_provider = None
for provider in providers: for provider in providers:
if provider.provider_type == ProviderType.CUSTOM.value and provider.is_valid and provider.encrypted_config: if provider.provider_type == ProviderType.CUSTOM.value and provider.is_valid and provider.encrypted_config:
custom_provider = provider return provider
elif provider.provider_type == ProviderType.SYSTEM.value and provider.is_valid: elif provider.provider_type == ProviderType.SYSTEM.value and provider.is_valid:
system_provider = provider return provider
if custom_provider: return None
return custom_provider
elif system_provider:
return system_provider
else:
return None
def get_hosted_credentials(self) -> str: def get_hosted_credentials(self) -> Union[str | dict]:
if self.get_provider_name() != ProviderName.OPENAI: raise ProviderTokenNotInitError(
raise ProviderTokenNotInitError() f"No valid {self.get_provider_name().value} model provider credentials found. "
f"Please go to Settings -> Model Provider to complete your provider credentials."
)
if not hosted_llm_credentials.openai or not hosted_llm_credentials.openai.api_key: def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
raise ProviderTokenNotInitError()
return hosted_llm_credentials.openai.api_key
def get_provider_configs(self, obfuscated: bool = False) -> Union[str | dict]:
""" """
Returns the provider configs. Returns the provider configs.
""" """
try: try:
config = self.get_provider_api_key() config = self.get_provider_api_key(only_custom=only_custom)
except: except:
config = '' config = ''

View File

@ -31,11 +31,11 @@ class LLMProviderService:
def get_credentials(self, model_id: Optional[str] = None) -> dict: def get_credentials(self, model_id: Optional[str] = None) -> dict:
return self.provider.get_credentials(model_id) return self.provider.get_credentials(model_id)
def get_provider_configs(self, obfuscated: bool = False) -> Union[str | dict]: def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
return self.provider.get_provider_configs(obfuscated) return self.provider.get_provider_configs(obfuscated=obfuscated, only_custom=only_custom)
def get_provider_db_record(self, prefer_custom: bool = False) -> Optional[Provider]: def get_provider_db_record(self) -> Optional[Provider]:
return self.provider.get_provider(prefer_custom) return self.provider.get_provider()
def config_validate(self, config: Union[dict | str]): def config_validate(self, config: Union[dict | str]):
""" """

View File

@ -4,6 +4,8 @@ from typing import Optional, Union
import openai import openai
from openai.error import AuthenticationError, OpenAIError from openai.error import AuthenticationError, OpenAIError
from core import hosted_llm_credentials
from core.llm.error import ProviderTokenNotInitError
from core.llm.moderation import Moderation from core.llm.moderation import Moderation
from core.llm.provider.base import BaseProvider from core.llm.provider.base import BaseProvider
from core.llm.provider.errors import ValidateFailedError from core.llm.provider.errors import ValidateFailedError
@ -42,3 +44,12 @@ class OpenAIProvider(BaseProvider):
except Exception as ex: except Exception as ex:
logging.exception('OpenAI config validation failed') logging.exception('OpenAI config validation failed')
raise ex raise ex
def get_hosted_credentials(self) -> Union[str | dict]:
if not hosted_llm_credentials.openai or not hosted_llm_credentials.openai.api_key:
raise ProviderTokenNotInitError(
f"No valid {self.get_provider_name().value} model provider credentials found. "
f"Please go to Settings -> Model Provider to complete your provider credentials."
)
return hosted_llm_credentials.openai.api_key

View File

@ -1,11 +1,11 @@
from langchain.callbacks.manager import CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun, Callbacks from langchain.callbacks.manager import Callbacks
from langchain.schema import BaseMessage, ChatResult, LLMResult from langchain.schema import BaseMessage, LLMResult
from langchain.chat_models import AzureChatOpenAI from langchain.chat_models import AzureChatOpenAI
from typing import Optional, List, Dict, Any from typing import Optional, List, Dict, Any
from pydantic import root_validator from pydantic import root_validator
from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
class StreamableAzureChatOpenAI(AzureChatOpenAI): class StreamableAzureChatOpenAI(AzureChatOpenAI):
@ -46,30 +46,7 @@ class StreamableAzureChatOpenAI(AzureChatOpenAI):
"organization": self.openai_organization if self.openai_organization else None, "organization": self.openai_organization if self.openai_organization else None,
} }
def get_messages_tokens(self, messages: List[BaseMessage]) -> int: @handle_openai_exceptions
"""Get the number of tokens in a list of messages.
Args:
messages: The messages to count the tokens of.
Returns:
The number of tokens in the messages.
"""
tokens_per_message = 5
tokens_per_request = 3
message_tokens = tokens_per_request
message_strs = ''
for message in messages:
message_strs += message.content
message_tokens += tokens_per_message
# calc once
message_tokens += self.get_num_tokens(message_strs)
return message_tokens
@handle_llm_exceptions
def generate( def generate(
self, self,
messages: List[List[BaseMessage]], messages: List[List[BaseMessage]],
@ -79,12 +56,18 @@ class StreamableAzureChatOpenAI(AzureChatOpenAI):
) -> LLMResult: ) -> LLMResult:
return super().generate(messages, stop, callbacks, **kwargs) return super().generate(messages, stop, callbacks, **kwargs)
@handle_llm_exceptions_async @classmethod
async def agenerate( def get_kwargs_from_model_params(cls, params: dict):
self, model_kwargs = {
messages: List[List[BaseMessage]], 'top_p': params.get('top_p', 1),
stop: Optional[List[str]] = None, 'frequency_penalty': params.get('frequency_penalty', 0),
callbacks: Callbacks = None, 'presence_penalty': params.get('presence_penalty', 0),
**kwargs: Any, }
) -> LLMResult:
return await super().agenerate(messages, stop, callbacks, **kwargs) del params['top_p']
del params['frequency_penalty']
del params['presence_penalty']
params['model_kwargs'] = model_kwargs
return params

View File

@ -5,7 +5,7 @@ from typing import Optional, List, Dict, Mapping, Any
from pydantic import root_validator from pydantic import root_validator
from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
class StreamableAzureOpenAI(AzureOpenAI): class StreamableAzureOpenAI(AzureOpenAI):
@ -50,7 +50,7 @@ class StreamableAzureOpenAI(AzureOpenAI):
"organization": self.openai_organization if self.openai_organization else None, "organization": self.openai_organization if self.openai_organization else None,
}} }}
@handle_llm_exceptions @handle_openai_exceptions
def generate( def generate(
self, self,
prompts: List[str], prompts: List[str],
@ -60,12 +60,6 @@ class StreamableAzureOpenAI(AzureOpenAI):
) -> LLMResult: ) -> LLMResult:
return super().generate(prompts, stop, callbacks, **kwargs) return super().generate(prompts, stop, callbacks, **kwargs)
@handle_llm_exceptions_async @classmethod
async def agenerate( def get_kwargs_from_model_params(cls, params: dict):
self, return params
prompts: List[str],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult:
return await super().agenerate(prompts, stop, callbacks, **kwargs)

View File

@ -0,0 +1,39 @@
from typing import List, Optional, Any, Dict
from langchain.callbacks.manager import Callbacks
from langchain.chat_models import ChatAnthropic
from langchain.schema import BaseMessage, LLMResult
from core.llm.wrappers.anthropic_wrapper import handle_anthropic_exceptions
class StreamableChatAnthropic(ChatAnthropic):
"""
Wrapper around Anthropic's large language model.
"""
@handle_anthropic_exceptions
def generate(
self,
messages: List[List[BaseMessage]],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
*,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> LLMResult:
return super().generate(messages, stop, callbacks, tags=tags, metadata=metadata, **kwargs)
@classmethod
def get_kwargs_from_model_params(cls, params: dict):
params['model'] = params.get('model_name')
del params['model_name']
params['max_tokens_to_sample'] = params.get('max_tokens')
del params['max_tokens']
del params['frequency_penalty']
del params['presence_penalty']
return params

View File

@ -7,7 +7,7 @@ from typing import Optional, List, Dict, Any
from pydantic import root_validator from pydantic import root_validator
from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
class StreamableChatOpenAI(ChatOpenAI): class StreamableChatOpenAI(ChatOpenAI):
@ -48,30 +48,7 @@ class StreamableChatOpenAI(ChatOpenAI):
"organization": self.openai_organization if self.openai_organization else None, "organization": self.openai_organization if self.openai_organization else None,
} }
def get_messages_tokens(self, messages: List[BaseMessage]) -> int: @handle_openai_exceptions
"""Get the number of tokens in a list of messages.
Args:
messages: The messages to count the tokens of.
Returns:
The number of tokens in the messages.
"""
tokens_per_message = 5
tokens_per_request = 3
message_tokens = tokens_per_request
message_strs = ''
for message in messages:
message_strs += message.content
message_tokens += tokens_per_message
# calc once
message_tokens += self.get_num_tokens(message_strs)
return message_tokens
@handle_llm_exceptions
def generate( def generate(
self, self,
messages: List[List[BaseMessage]], messages: List[List[BaseMessage]],
@ -81,12 +58,18 @@ class StreamableChatOpenAI(ChatOpenAI):
) -> LLMResult: ) -> LLMResult:
return super().generate(messages, stop, callbacks, **kwargs) return super().generate(messages, stop, callbacks, **kwargs)
@handle_llm_exceptions_async @classmethod
async def agenerate( def get_kwargs_from_model_params(cls, params: dict):
self, model_kwargs = {
messages: List[List[BaseMessage]], 'top_p': params.get('top_p', 1),
stop: Optional[List[str]] = None, 'frequency_penalty': params.get('frequency_penalty', 0),
callbacks: Callbacks = None, 'presence_penalty': params.get('presence_penalty', 0),
**kwargs: Any, }
) -> LLMResult:
return await super().agenerate(messages, stop, callbacks, **kwargs) del params['top_p']
del params['frequency_penalty']
del params['presence_penalty']
params['model_kwargs'] = model_kwargs
return params

View File

@ -6,7 +6,7 @@ from typing import Optional, List, Dict, Any, Mapping
from langchain import OpenAI from langchain import OpenAI
from pydantic import root_validator from pydantic import root_validator
from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
class StreamableOpenAI(OpenAI): class StreamableOpenAI(OpenAI):
@ -49,7 +49,7 @@ class StreamableOpenAI(OpenAI):
"organization": self.openai_organization if self.openai_organization else None, "organization": self.openai_organization if self.openai_organization else None,
}} }}
@handle_llm_exceptions @handle_openai_exceptions
def generate( def generate(
self, self,
prompts: List[str], prompts: List[str],
@ -59,12 +59,6 @@ class StreamableOpenAI(OpenAI):
) -> LLMResult: ) -> LLMResult:
return super().generate(prompts, stop, callbacks, **kwargs) return super().generate(prompts, stop, callbacks, **kwargs)
@handle_llm_exceptions_async @classmethod
async def agenerate( def get_kwargs_from_model_params(cls, params: dict):
self, return params
prompts: List[str],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult:
return await super().agenerate(prompts, stop, callbacks, **kwargs)

View File

@ -1,6 +1,7 @@
import openai import openai
from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
from models.provider import ProviderName from models.provider import ProviderName
from core.llm.error_handle_wraps import handle_llm_exceptions
from core.llm.provider.base import BaseProvider from core.llm.provider.base import BaseProvider
@ -13,7 +14,7 @@ class Whisper:
self.client = openai.Audio self.client = openai.Audio
self.credentials = provider.get_credentials() self.credentials = provider.get_credentials()
@handle_llm_exceptions @handle_openai_exceptions
def transcribe(self, file): def transcribe(self, file):
return self.client.transcribe( return self.client.transcribe(
model='whisper-1', model='whisper-1',

View File

@ -0,0 +1,27 @@
import logging
from functools import wraps
import anthropic
from core.llm.error import LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, \
LLMBadRequestError
def handle_anthropic_exceptions(func):
@wraps(func)
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except anthropic.APIConnectionError as e:
logging.exception("Failed to connect to Anthropic API.")
raise LLMAPIConnectionError(f"Anthropic: The server could not be reached, cause: {e.__cause__}")
except anthropic.RateLimitError:
raise LLMRateLimitError("Anthropic: A 429 status code was received; we should back off a bit.")
except anthropic.AuthenticationError as e:
raise LLMAuthorizationError(f"Anthropic: {e.message}")
except anthropic.BadRequestError as e:
raise LLMBadRequestError(f"Anthropic: {e.message}")
except anthropic.APIStatusError as e:
raise LLMAPIUnavailableError(f"Anthropic: code: {e.status_code}, cause: {e.message}")
return wrapper

View File

@ -7,7 +7,7 @@ from core.llm.error import LLMAPIConnectionError, LLMAPIUnavailableError, LLMRat
LLMBadRequestError LLMBadRequestError
def handle_llm_exceptions(func): def handle_openai_exceptions(func):
@wraps(func) @wraps(func)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
try: try:
@ -29,27 +29,3 @@ def handle_llm_exceptions(func):
raise LLMBadRequestError(e.__class__.__name__ + ":" + str(e)) raise LLMBadRequestError(e.__class__.__name__ + ":" + str(e))
return wrapper return wrapper
def handle_llm_exceptions_async(func):
@wraps(func)
async def wrapper(*args, **kwargs):
try:
return await func(*args, **kwargs)
except openai.error.InvalidRequestError as e:
logging.exception("Invalid request to OpenAI API.")
raise LLMBadRequestError(str(e))
except openai.error.APIConnectionError as e:
logging.exception("Failed to connect to OpenAI API.")
raise LLMAPIConnectionError(e.__class__.__name__ + ":" + str(e))
except (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout) as e:
logging.exception("OpenAI service unavailable.")
raise LLMAPIUnavailableError(e.__class__.__name__ + ":" + str(e))
except openai.error.RateLimitError as e:
raise LLMRateLimitError(str(e))
except openai.error.AuthenticationError as e:
raise LLMAuthorizationError(str(e))
except openai.error.OpenAIError as e:
raise LLMBadRequestError(e.__class__.__name__ + ":" + str(e))
return wrapper

View File

@ -1,7 +1,7 @@
from typing import Any, List, Dict, Union from typing import Any, List, Dict, Union
from langchain.memory.chat_memory import BaseChatMemory from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema import get_buffer_string, BaseMessage, HumanMessage, AIMessage from langchain.schema import get_buffer_string, BaseMessage, HumanMessage, AIMessage, BaseLanguageModel
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
from core.llm.streamable_open_ai import StreamableOpenAI from core.llm.streamable_open_ai import StreamableOpenAI
@ -12,8 +12,8 @@ from models.model import Conversation, Message
class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory): class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):
conversation: Conversation conversation: Conversation
human_prefix: str = "Human" human_prefix: str = "Human"
ai_prefix: str = "AI" ai_prefix: str = "Assistant"
llm: Union[StreamableChatOpenAI | StreamableOpenAI] llm: BaseLanguageModel
memory_key: str = "chat_history" memory_key: str = "chat_history"
max_token_limit: int = 2000 max_token_limit: int = 2000
message_limit: int = 10 message_limit: int = 10
@ -38,12 +38,12 @@ class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):
return chat_messages return chat_messages
# prune the chat message if it exceeds the max token limit # prune the chat message if it exceeds the max token limit
curr_buffer_length = self.llm.get_messages_tokens(chat_messages) curr_buffer_length = self.llm.get_num_tokens_from_messages(chat_messages)
if curr_buffer_length > self.max_token_limit: if curr_buffer_length > self.max_token_limit:
pruned_memory = [] pruned_memory = []
while curr_buffer_length > self.max_token_limit and chat_messages: while curr_buffer_length > self.max_token_limit and chat_messages:
pruned_memory.append(chat_messages.pop(0)) pruned_memory.append(chat_messages.pop(0))
curr_buffer_length = self.llm.get_messages_tokens(chat_messages) curr_buffer_length = self.llm.get_num_tokens_from_messages(chat_messages)
return chat_messages return chat_messages

View File

@ -30,7 +30,7 @@ class DatasetTool(BaseTool):
else: else:
model_credentials = LLMBuilder.get_model_credentials( model_credentials = LLMBuilder.get_model_credentials(
tenant_id=self.dataset.tenant_id, tenant_id=self.dataset.tenant_id,
model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id), model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id, 'text-embedding-ada-002'),
model_name='text-embedding-ada-002' model_name='text-embedding-ada-002'
) )
@ -60,7 +60,7 @@ class DatasetTool(BaseTool):
async def _arun(self, tool_input: str) -> str: async def _arun(self, tool_input: str) -> str:
model_credentials = LLMBuilder.get_model_credentials( model_credentials = LLMBuilder.get_model_credentials(
tenant_id=self.dataset.tenant_id, tenant_id=self.dataset.tenant_id,
model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id), model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id, 'text-embedding-ada-002'),
model_name='text-embedding-ada-002' model_name='text-embedding-ada-002'
) )

View File

@ -1,4 +1,7 @@
from flask import current_app
from events.tenant_event import tenant_was_updated from events.tenant_event import tenant_was_updated
from models.provider import ProviderName
from services.provider_service import ProviderService from services.provider_service import ProviderService
@ -6,4 +9,16 @@ from services.provider_service import ProviderService
def handle(sender, **kwargs): def handle(sender, **kwargs):
tenant = sender tenant = sender
if tenant.status == 'normal': if tenant.status == 'normal':
ProviderService.create_system_provider(tenant) ProviderService.create_system_provider(
tenant,
ProviderName.OPENAI.value,
current_app.config['OPENAI_HOSTED_QUOTA_LIMIT'],
True
)
ProviderService.create_system_provider(
tenant,
ProviderName.ANTHROPIC.value,
current_app.config['ANTHROPIC_HOSTED_QUOTA_LIMIT'],
True
)

View File

@ -1,4 +1,7 @@
from flask import current_app
from events.tenant_event import tenant_was_created from events.tenant_event import tenant_was_created
from models.provider import ProviderName
from services.provider_service import ProviderService from services.provider_service import ProviderService
@ -6,4 +9,16 @@ from services.provider_service import ProviderService
def handle(sender, **kwargs): def handle(sender, **kwargs):
tenant = sender tenant = sender
if tenant.status == 'normal': if tenant.status == 'normal':
ProviderService.create_system_provider(tenant) ProviderService.create_system_provider(
tenant,
ProviderName.OPENAI.value,
current_app.config['OPENAI_HOSTED_QUOTA_LIMIT'],
True
)
ProviderService.create_system_provider(
tenant,
ProviderName.ANTHROPIC.value,
current_app.config['ANTHROPIC_HOSTED_QUOTA_LIMIT'],
True
)

View File

@ -10,7 +10,7 @@ flask-session2==1.3.1
flask-cors==3.0.10 flask-cors==3.0.10
gunicorn~=20.1.0 gunicorn~=20.1.0
gevent~=22.10.2 gevent~=22.10.2
langchain==0.0.209 langchain==0.0.230
openai~=0.27.5 openai~=0.27.5
psycopg2-binary~=2.9.6 psycopg2-binary~=2.9.6
pycryptodome==3.17 pycryptodome==3.17
@ -35,3 +35,4 @@ docx2txt==0.8
pypdfium2==4.16.0 pypdfium2==4.16.0
resend~=0.5.1 resend~=0.5.1
pyjwt~=2.6.0 pyjwt~=2.6.0
anthropic~=0.3.4

View File

@ -6,6 +6,30 @@ from models.account import Account
from services.dataset_service import DatasetService from services.dataset_service import DatasetService
from core.llm.llm_builder import LLMBuilder from core.llm.llm_builder import LLMBuilder
MODEL_PROVIDERS = [
'openai',
'anthropic',
]
MODELS_BY_APP_MODE = {
'chat': [
'claude-instant-1',
'claude-2',
'gpt-4',
'gpt-4-32k',
'gpt-3.5-turbo',
'gpt-3.5-turbo-16k',
],
'completion': [
'claude-instant-1',
'claude-2',
'gpt-4',
'gpt-4-32k',
'gpt-3.5-turbo',
'gpt-3.5-turbo-16k',
'text-davinci-003',
]
}
class AppModelConfigService: class AppModelConfigService:
@staticmethod @staticmethod
@ -125,7 +149,7 @@ class AppModelConfigService:
if not isinstance(config["speech_to_text"]["enabled"], bool): if not isinstance(config["speech_to_text"]["enabled"], bool):
raise ValueError("enabled in speech_to_text must be of boolean type") raise ValueError("enabled in speech_to_text must be of boolean type")
provider_name = LLMBuilder.get_default_provider(account.current_tenant_id) provider_name = LLMBuilder.get_default_provider(account.current_tenant_id, 'whisper-1')
if config["speech_to_text"]["enabled"] and provider_name != 'openai': if config["speech_to_text"]["enabled"] and provider_name != 'openai':
raise ValueError("provider not support speech to text") raise ValueError("provider not support speech to text")
@ -153,14 +177,14 @@ class AppModelConfigService:
raise ValueError("model must be of object type") raise ValueError("model must be of object type")
# model.provider # model.provider
if 'provider' not in config["model"] or config["model"]["provider"] != "openai": if 'provider' not in config["model"] or config["model"]["provider"] not in MODEL_PROVIDERS:
raise ValueError("model.provider must be 'openai'") raise ValueError(f"model.provider is required and must be in {str(MODEL_PROVIDERS)}")
# model.name # model.name
if 'name' not in config["model"]: if 'name' not in config["model"]:
raise ValueError("model.name is required") raise ValueError("model.name is required")
if config["model"]["name"] not in llm_constant.models_by_mode[mode]: if config["model"]["name"] not in MODELS_BY_APP_MODE[mode]:
raise ValueError("model.name must be in the specified model list") raise ValueError("model.name must be in the specified model list")
# model.completion_params # model.completion_params

View File

@ -27,7 +27,7 @@ class AudioService:
message = f"Audio size larger than {FILE_SIZE} mb" message = f"Audio size larger than {FILE_SIZE} mb"
raise AudioTooLargeServiceError(message) raise AudioTooLargeServiceError(message)
provider_name = LLMBuilder.get_default_provider(tenant_id) provider_name = LLMBuilder.get_default_provider(tenant_id, 'whisper-1')
if provider_name != ProviderName.OPENAI.value: if provider_name != ProviderName.OPENAI.value:
raise ProviderNotSupportSpeechToTextServiceError() raise ProviderNotSupportSpeechToTextServiceError()
@ -37,8 +37,3 @@ class AudioService:
buffer.name = 'temp.mp3' buffer.name = 'temp.mp3'
return Whisper(provider_service.provider).transcribe(buffer) return Whisper(provider_service.provider).transcribe(buffer)

View File

@ -31,7 +31,7 @@ class HitTestingService:
model_credentials = LLMBuilder.get_model_credentials( model_credentials = LLMBuilder.get_model_credentials(
tenant_id=dataset.tenant_id, tenant_id=dataset.tenant_id,
model_provider=LLMBuilder.get_default_provider(dataset.tenant_id), model_provider=LLMBuilder.get_default_provider(dataset.tenant_id, 'text-embedding-ada-002'),
model_name='text-embedding-ada-002' model_name='text-embedding-ada-002'
) )

View File

@ -10,50 +10,40 @@ from models.provider import *
class ProviderService: class ProviderService:
@staticmethod @staticmethod
def init_supported_provider(tenant, edition): def init_supported_provider(tenant):
"""Initialize the model provider, check whether the supported provider has a record""" """Initialize the model provider, check whether the supported provider has a record"""
providers = Provider.query.filter_by(tenant_id=tenant.id).all() need_init_provider_names = [ProviderName.OPENAI.value, ProviderName.AZURE_OPENAI.value, ProviderName.ANTHROPIC.value]
openai_provider_exists = False providers = db.session.query(Provider).filter(
azure_openai_provider_exists = False Provider.tenant_id == tenant.id,
Provider.provider_type == ProviderType.CUSTOM.value,
# TODO: The cloud version needs to construct the data of the SYSTEM type Provider.provider_name.in_(need_init_provider_names)
).all()
exists_provider_names = []
for provider in providers: for provider in providers:
if provider.provider_name == ProviderName.OPENAI.value and provider.provider_type == ProviderType.CUSTOM.value: exists_provider_names.append(provider.provider_name)
openai_provider_exists = True
if provider.provider_name == ProviderName.AZURE_OPENAI.value and provider.provider_type == ProviderType.CUSTOM.value:
azure_openai_provider_exists = True
# Initialize the model provider, check whether the supported provider has a record not_exists_provider_names = list(set(need_init_provider_names) - set(exists_provider_names))
# Create default providers if they don't exist if not_exists_provider_names:
if not openai_provider_exists: # Initialize the model provider, check whether the supported provider has a record
openai_provider = Provider( for provider_name in not_exists_provider_names:
tenant_id=tenant.id, provider = Provider(
provider_name=ProviderName.OPENAI.value, tenant_id=tenant.id,
provider_type=ProviderType.CUSTOM.value, provider_name=provider_name,
is_valid=False provider_type=ProviderType.CUSTOM.value,
) is_valid=False
db.session.add(openai_provider) )
db.session.add(provider)
if not azure_openai_provider_exists:
azure_openai_provider = Provider(
tenant_id=tenant.id,
provider_name=ProviderName.AZURE_OPENAI.value,
provider_type=ProviderType.CUSTOM.value,
is_valid=False
)
db.session.add(azure_openai_provider)
if not openai_provider_exists or not azure_openai_provider_exists:
db.session.commit() db.session.commit()
@staticmethod @staticmethod
def get_obfuscated_api_key(tenant, provider_name: ProviderName): def get_obfuscated_api_key(tenant, provider_name: ProviderName, only_custom: bool = False):
llm_provider_service = LLMProviderService(tenant.id, provider_name.value) llm_provider_service = LLMProviderService(tenant.id, provider_name.value)
return llm_provider_service.get_provider_configs(obfuscated=True) return llm_provider_service.get_provider_configs(obfuscated=True, only_custom=only_custom)
@staticmethod @staticmethod
def get_token_type(tenant, provider_name: ProviderName): def get_token_type(tenant, provider_name: ProviderName):
@ -73,7 +63,7 @@ class ProviderService:
return llm_provider_service.get_encrypted_token(configs) return llm_provider_service.get_encrypted_token(configs)
@staticmethod @staticmethod
def create_system_provider(tenant: Tenant, provider_name: str = ProviderName.OPENAI.value, def create_system_provider(tenant: Tenant, provider_name: str = ProviderName.OPENAI.value, quota_limit: int = 200,
is_valid: bool = True): is_valid: bool = True):
if current_app.config['EDITION'] != 'CLOUD': if current_app.config['EDITION'] != 'CLOUD':
return return
@ -90,7 +80,7 @@ class ProviderService:
provider_name=provider_name, provider_name=provider_name,
provider_type=ProviderType.SYSTEM.value, provider_type=ProviderType.SYSTEM.value,
quota_type=ProviderQuotaType.TRIAL.value, quota_type=ProviderQuotaType.TRIAL.value,
quota_limit=200, quota_limit=quota_limit,
encrypted_config='', encrypted_config='',
is_valid=is_valid, is_valid=is_valid,
) )

View File

@ -1,6 +1,6 @@
from extensions.ext_database import db from extensions.ext_database import db
from models.account import Tenant from models.account import Tenant
from models.provider import Provider, ProviderType from models.provider import Provider, ProviderType, ProviderName
class WorkspaceService: class WorkspaceService:
@ -33,7 +33,7 @@ class WorkspaceService:
if provider.is_valid and provider.encrypted_config: if provider.is_valid and provider.encrypted_config:
custom_provider = provider custom_provider = provider
elif provider.provider_type == ProviderType.SYSTEM.value: elif provider.provider_type == ProviderType.SYSTEM.value:
if provider.is_valid: if provider.provider_name == ProviderName.OPENAI.value and provider.is_valid:
system_provider = provider system_provider = provider
if system_provider and not custom_provider: if system_provider and not custom_provider: