feat: claude api support (#572)

parent 510389909c
commit 7599f79a17
@@ -18,7 +18,8 @@ from models.model import Account
 import secrets
 import base64
 
-from models.provider import Provider
+from models.provider import Provider, ProviderName
+from services.provider_service import ProviderService
 
 
 @click.command('reset-password', help='Reset the account password.')
@@ -193,9 +194,40 @@ def recreate_all_dataset_indexes():
     click.echo(click.style('Congratulations! Recreate {} dataset indexes.'.format(recreate_count), fg='green'))
 
 
+@click.command('sync-anthropic-hosted-providers', help='Sync anthropic hosted providers.')
+def sync_anthropic_hosted_providers():
+    click.echo(click.style('Start sync anthropic hosted providers.', fg='green'))
+    count = 0
+
+    page = 1
+    while True:
+        try:
+            tenants = db.session.query(Tenant).order_by(Tenant.created_at.desc()).paginate(page=page, per_page=50)
+        except NotFound:
+            break
+
+        page += 1
+        for tenant in tenants:
+            try:
+                click.echo('Syncing tenant anthropic hosted provider: {}'.format(tenant.id))
+                ProviderService.create_system_provider(
+                    tenant,
+                    ProviderName.ANTHROPIC.value,
+                    current_app.config['ANTHROPIC_HOSTED_QUOTA_LIMIT'],
+                    True
+                )
+                count += 1
+            except Exception as e:
+                click.echo(click.style('Sync tenant anthropic hosted provider error: {} {}'.format(e.__class__.__name__, str(e)), fg='red'))
+                continue
+
+    click.echo(click.style('Congratulations! Synced {} anthropic hosted providers.'.format(count), fg='green'))
+
+
 def register_commands(app):
     app.cli.add_command(reset_password)
     app.cli.add_command(reset_email)
     app.cli.add_command(generate_invitation_codes)
     app.cli.add_command(reset_encrypt_key_pair)
     app.cli.add_command(recreate_all_dataset_indexes)
+    app.cli.add_command(sync_anthropic_hosted_providers)
@@ -51,6 +51,8 @@ DEFAULTS = {
     'LOG_LEVEL': 'INFO',
     'DISABLE_PROVIDER_CONFIG_VALIDATION': 'False',
     'DEFAULT_LLM_PROVIDER': 'openai',
+    'OPENAI_HOSTED_QUOTA_LIMIT': 200,
+    'ANTHROPIC_HOSTED_QUOTA_LIMIT': 1000,
     'TENANT_DOCUMENT_COUNT': 100
 }
 
@@ -192,6 +194,10 @@ class Config:
 
         # hosted provider credentials
         self.OPENAI_API_KEY = get_env('OPENAI_API_KEY')
+        self.ANTHROPIC_API_KEY = get_env('ANTHROPIC_API_KEY')
+
+        self.OPENAI_HOSTED_QUOTA_LIMIT = get_env('OPENAI_HOSTED_QUOTA_LIMIT')
+        self.ANTHROPIC_HOSTED_QUOTA_LIMIT = get_env('ANTHROPIC_HOSTED_QUOTA_LIMIT')
 
         # By default it is False
         # You could disable it for compatibility with certain OpenAPI providers
@@ -50,8 +50,8 @@ class ChatMessageAudioApi(Resource):
             raise UnsupportedAudioTypeError()
         except ProviderNotSupportSpeechToTextServiceError:
             raise ProviderNotSupportSpeechToTextError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
@@ -63,8 +63,8 @@ class CompletionMessageApi(Resource):
         except services.errors.app_model_config.AppModelConfigBrokenError:
             logging.exception("App model config broken.")
             raise AppUnavailableError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
@@ -133,8 +133,8 @@ class ChatMessageApi(Resource):
         except services.errors.app_model_config.AppModelConfigBrokenError:
             logging.exception("App model config broken.")
             raise AppUnavailableError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
@@ -164,8 +164,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
             except services.errors.app_model_config.AppModelConfigBrokenError:
                 logging.exception("App model config broken.")
                 yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n"
-            except ProviderTokenNotInitError:
-                yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
+            except ProviderTokenNotInitError as ex:
+                yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
             except QuotaExceededError:
                 yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
             except ModelCurrentlyNotSupportError:
@@ -16,7 +16,7 @@ class ProviderNotInitializeError(BaseHTTPException):
 
 class ProviderQuotaExceededError(BaseHTTPException):
     error_code = 'provider_quota_exceeded'
-    description = "Your quota for Dify Hosted OpenAI has been exhausted. " \
+    description = "Your quota for Dify Hosted Model Provider has been exhausted. " \
                   "Please go to Settings -> Model Provider to complete your own provider credentials."
     code = 400
 
@@ -27,8 +27,8 @@ class IntroductionGenerateApi(Resource):
                 account.current_tenant_id,
                 args['prompt_template']
             )
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
@@ -58,8 +58,8 @@ class RuleGenerateApi(Resource):
                 args['audiences'],
                 args['hoping_to_solve']
             )
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
@@ -269,8 +269,8 @@ class MessageMoreLikeThisApi(Resource):
             raise NotFound("Message Not Exists.")
         except MoreLikeThisDisabledError:
             raise AppMoreLikeThisDisabledError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
@@ -297,8 +297,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
                 yield "data: " + json.dumps(api.handle_error(NotFound("Message Not Exists.")).get_json()) + "\n\n"
             except MoreLikeThisDisabledError:
                 yield "data: " + json.dumps(api.handle_error(AppMoreLikeThisDisabledError()).get_json()) + "\n\n"
-            except ProviderTokenNotInitError:
-                yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
+            except ProviderTokenNotInitError as ex:
+                yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
             except QuotaExceededError:
                 yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
             except ModelCurrentlyNotSupportError:
@@ -339,8 +339,8 @@ class MessageSuggestedQuestionApi(Resource):
             raise NotFound("Message not found")
         except ConversationNotExistsError:
             raise NotFound("Conversation not found")
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
@@ -279,8 +279,8 @@ class DatasetDocumentListApi(Resource):
 
         try:
             documents, batch = DocumentService.save_document_with_dataset_id(dataset, args, current_user)
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
|
|||||||
document_data=args,
|
document_data=args,
|
||||||
account=current_user
|
account=current_user
|
||||||
)
|
)
|
||||||
except ProviderTokenNotInitError:
|
except ProviderTokenNotInitError as ex:
|
||||||
raise ProviderNotInitializeError()
|
raise ProviderNotInitializeError(ex.description)
|
||||||
except QuotaExceededError:
|
except QuotaExceededError:
|
||||||
raise ProviderQuotaExceededError()
|
raise ProviderQuotaExceededError()
|
||||||
except ModelCurrentlyNotSupportError:
|
except ModelCurrentlyNotSupportError:
|
||||||
|
@@ -95,8 +95,8 @@ class HitTestingApi(Resource):
             return {"query": response['query'], 'records': marshal(response['records'], hit_testing_record_fields)}
         except services.errors.index.IndexNotInitializedError:
             raise DatasetNotInitializedError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
@@ -47,8 +47,8 @@ class ChatAudioApi(InstalledAppResource):
             raise UnsupportedAudioTypeError()
         except ProviderNotSupportSpeechToTextServiceError:
             raise ProviderNotSupportSpeechToTextError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
@@ -54,8 +54,8 @@ class CompletionApi(InstalledAppResource):
         except services.errors.app_model_config.AppModelConfigBrokenError:
             logging.exception("App model config broken.")
             raise AppUnavailableError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
@@ -113,8 +113,8 @@ class ChatApi(InstalledAppResource):
         except services.errors.app_model_config.AppModelConfigBrokenError:
             logging.exception("App model config broken.")
             raise AppUnavailableError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
@@ -155,8 +155,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
             except services.errors.app_model_config.AppModelConfigBrokenError:
                 logging.exception("App model config broken.")
                 yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n"
-            except ProviderTokenNotInitError:
-                yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
+            except ProviderTokenNotInitError as ex:
+                yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
             except QuotaExceededError:
                 yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
             except ModelCurrentlyNotSupportError:
@@ -107,8 +107,8 @@ class MessageMoreLikeThisApi(InstalledAppResource):
             raise NotFound("Message Not Exists.")
         except MoreLikeThisDisabledError:
             raise AppMoreLikeThisDisabledError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
@@ -135,8 +135,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
                 yield "data: " + json.dumps(api.handle_error(NotFound("Message Not Exists.")).get_json()) + "\n\n"
             except MoreLikeThisDisabledError:
                 yield "data: " + json.dumps(api.handle_error(AppMoreLikeThisDisabledError()).get_json()) + "\n\n"
-            except ProviderTokenNotInitError:
-                yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
+            except ProviderTokenNotInitError as ex:
+                yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
             except QuotaExceededError:
                 yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
             except ModelCurrentlyNotSupportError:
@@ -174,8 +174,8 @@ class MessageSuggestedQuestionApi(InstalledAppResource):
             raise NotFound("Conversation not found")
         except SuggestedQuestionsAfterAnswerDisabledError:
             raise AppSuggestedQuestionsAfterAnswerDisabledError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
@@ -3,6 +3,7 @@ import base64
 import json
 import logging
 
+from flask import current_app
 from flask_login import login_required, current_user
 from flask_restful import Resource, reqparse, abort
 from werkzeug.exceptions import Forbidden
@@ -34,7 +35,7 @@ class ProviderListApi(Resource):
         plaintext, the rest is replaced by * and the last two bits are displayed in plaintext
         """
 
-        ProviderService.init_supported_provider(current_user.current_tenant, "cloud")
+        ProviderService.init_supported_provider(current_user.current_tenant)
         providers = Provider.query.filter_by(tenant_id=tenant_id).all()
 
         provider_list = [
@@ -50,7 +51,8 @@ class ProviderListApi(Resource):
                     'quota_used': p.quota_used
                 } if p.provider_type == ProviderType.SYSTEM.value else {}),
                 'token': ProviderService.get_obfuscated_api_key(current_user.current_tenant,
-                                                                ProviderName(p.provider_name))
+                                                                ProviderName(p.provider_name), only_custom=True)
+                if p.provider_type == ProviderType.CUSTOM.value else None
             }
             for p in providers
         ]
@@ -121,9 +123,10 @@ class ProviderTokenApi(Resource):
                                       is_valid=token_is_valid)
             db.session.add(provider_model)
 
-        if provider_model.is_valid:
+        if provider in [ProviderName.OPENAI.value, ProviderName.AZURE_OPENAI.value] and provider_model.is_valid:
             other_providers = db.session.query(Provider).filter(
                 Provider.tenant_id == tenant.id,
+                Provider.provider_name.in_([ProviderName.OPENAI.value, ProviderName.AZURE_OPENAI.value]),
                 Provider.provider_name != provider,
                 Provider.provider_type == ProviderType.CUSTOM.value
             ).all()
@@ -133,7 +136,7 @@ class ProviderTokenApi(Resource):
 
         db.session.commit()
 
-        if provider in [ProviderName.ANTHROPIC.value, ProviderName.AZURE_OPENAI.value, ProviderName.COHERE.value,
+        if provider in [ProviderName.AZURE_OPENAI.value, ProviderName.COHERE.value,
                         ProviderName.HUGGINGFACEHUB.value]:
             return {'result': 'success', 'warning': 'MOCK: This provider is not supported yet.'}, 201
 
@@ -157,7 +160,7 @@ class ProviderTokenValidateApi(Resource):
         args = parser.parse_args()
 
         # todo: remove this when the provider is supported
-        if provider in [ProviderName.ANTHROPIC.value, ProviderName.COHERE.value,
+        if provider in [ProviderName.COHERE.value,
                         ProviderName.HUGGINGFACEHUB.value]:
             return {'result': 'success', 'warning': 'MOCK: This provider is not supported yet.'}
 
@@ -203,7 +206,19 @@ class ProviderSystemApi(Resource):
             provider_model.is_valid = args['is_enabled']
             db.session.commit()
         elif not provider_model:
-            ProviderService.create_system_provider(tenant, provider, args['is_enabled'])
+            if provider == ProviderName.OPENAI.value:
+                quota_limit = current_app.config['OPENAI_HOSTED_QUOTA_LIMIT']
+            elif provider == ProviderName.ANTHROPIC.value:
+                quota_limit = current_app.config['ANTHROPIC_HOSTED_QUOTA_LIMIT']
+            else:
+                quota_limit = 0
+
+            ProviderService.create_system_provider(
+                tenant,
+                provider,
+                quota_limit,
+                args['is_enabled']
+            )
         else:
             abort(403)
 
@@ -43,8 +43,8 @@ class AudioApi(AppApiResource):
             raise UnsupportedAudioTypeError()
         except ProviderNotSupportSpeechToTextServiceError:
             raise ProviderNotSupportSpeechToTextError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
@@ -54,8 +54,8 @@ class CompletionApi(AppApiResource):
         except services.errors.app_model_config.AppModelConfigBrokenError:
             logging.exception("App model config broken.")
             raise AppUnavailableError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
@@ -115,8 +115,8 @@ class ChatApi(AppApiResource):
         except services.errors.app_model_config.AppModelConfigBrokenError:
             logging.exception("App model config broken.")
             raise AppUnavailableError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
@@ -156,8 +156,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
             except services.errors.app_model_config.AppModelConfigBrokenError:
                 logging.exception("App model config broken.")
                 yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n"
-            except ProviderTokenNotInitError:
-                yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
+            except ProviderTokenNotInitError as ex:
+                yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
             except QuotaExceededError:
                 yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
             except ModelCurrentlyNotSupportError:
@@ -85,8 +85,8 @@ class DocumentListApi(DatasetApiResource):
                 dataset_process_rule=dataset.latest_process_rule,
                 created_from='api'
             )
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         document = documents[0]
         if doc_type and doc_metadata:
             metadata_schema = DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type]
@@ -45,8 +45,8 @@ class AudioApi(WebApiResource):
             raise UnsupportedAudioTypeError()
         except ProviderNotSupportSpeechToTextServiceError:
             raise ProviderNotSupportSpeechToTextError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
@@ -52,8 +52,8 @@ class CompletionApi(WebApiResource):
         except services.errors.app_model_config.AppModelConfigBrokenError:
             logging.exception("App model config broken.")
             raise AppUnavailableError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
@@ -109,8 +109,8 @@ class ChatApi(WebApiResource):
         except services.errors.app_model_config.AppModelConfigBrokenError:
             logging.exception("App model config broken.")
             raise AppUnavailableError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
@@ -150,8 +150,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
             except services.errors.app_model_config.AppModelConfigBrokenError:
                 logging.exception("App model config broken.")
                 yield "data: " + json.dumps(api.handle_error(AppUnavailableError()).get_json()) + "\n\n"
-            except ProviderTokenNotInitError:
-                yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
+            except ProviderTokenNotInitError as ex:
+                yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
             except QuotaExceededError:
                 yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
             except ModelCurrentlyNotSupportError:
@@ -101,8 +101,8 @@ class MessageMoreLikeThisApi(WebApiResource):
             raise NotFound("Message Not Exists.")
         except MoreLikeThisDisabledError:
             raise AppMoreLikeThisDisabledError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
@@ -129,8 +129,8 @@ def compact_response(response: Union[dict | Generator]) -> Response:
                 yield "data: " + json.dumps(api.handle_error(NotFound("Message Not Exists.")).get_json()) + "\n\n"
             except MoreLikeThisDisabledError:
                 yield "data: " + json.dumps(api.handle_error(AppMoreLikeThisDisabledError()).get_json()) + "\n\n"
-            except ProviderTokenNotInitError:
-                yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError()).get_json()) + "\n\n"
+            except ProviderTokenNotInitError as ex:
+                yield "data: " + json.dumps(api.handle_error(ProviderNotInitializeError(ex.description)).get_json()) + "\n\n"
             except QuotaExceededError:
                 yield "data: " + json.dumps(api.handle_error(ProviderQuotaExceededError()).get_json()) + "\n\n"
             except ModelCurrentlyNotSupportError:
@@ -167,8 +167,8 @@ class MessageSuggestedQuestionApi(WebApiResource):
             raise NotFound("Conversation not found")
         except SuggestedQuestionsAfterAnswerDisabledError:
             raise AppSuggestedQuestionsAfterAnswerDisabledError()
-        except ProviderTokenNotInitError:
-            raise ProviderNotInitializeError()
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
         except QuotaExceededError:
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
@@ -13,8 +13,13 @@ class HostedOpenAICredential(BaseModel):
     api_key: str
 
 
+class HostedAnthropicCredential(BaseModel):
+    api_key: str
+
+
 class HostedLLMCredentials(BaseModel):
     openai: Optional[HostedOpenAICredential] = None
+    anthropic: Optional[HostedAnthropicCredential] = None
 
 
 hosted_llm_credentials = HostedLLMCredentials()
@@ -26,3 +31,6 @@ def init_app(app: Flask):
 
     if app.config.get("OPENAI_API_KEY"):
         hosted_llm_credentials.openai = HostedOpenAICredential(api_key=app.config.get("OPENAI_API_KEY"))
+
+    if app.config.get("ANTHROPIC_API_KEY"):
+        hosted_llm_credentials.anthropic = HostedAnthropicCredential(api_key=app.config.get("ANTHROPIC_API_KEY"))
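With this block in place, an ANTHROPIC_API_KEY configured on the Flask app is picked up at startup just like the OpenAI key. A hypothetical usage sketch (the key value is a placeholder):

```python
from flask import Flask

app = Flask(__name__)
app.config["ANTHROPIC_API_KEY"] = "sk-ant-..."  # placeholder, not a real key

init_app(app)  # the function patched above

# The hosted Anthropic credential now backs the system-provider quota flow.
assert hosted_llm_credentials.anthropic is not None
print(hosted_llm_credentials.anthropic.api_key)
```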
@@ -48,7 +48,7 @@ class LLMCallbackHandler(BaseCallbackHandler):
             })
 
         self.llm_message.prompt = real_prompts
-        self.llm_message.prompt_tokens = self.llm.get_messages_tokens(messages[0])
+        self.llm_message.prompt_tokens = self.llm.get_num_tokens_from_messages(messages[0])
 
     def on_llm_start(
             self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
@@ -118,6 +118,7 @@ class Completion:
         prompt, stop_words = cls.get_main_llm_prompt(
             mode=mode,
             llm=final_llm,
+            model=app_model_config.model_dict,
             pre_prompt=app_model_config.pre_prompt,
             query=query,
             inputs=inputs,
@@ -129,6 +130,7 @@ class Completion:
 
         cls.recale_llm_max_tokens(
             final_llm=final_llm,
+            model=app_model_config.model_dict,
             prompt=prompt,
             mode=mode
         )
@@ -138,7 +140,8 @@ class Completion:
         return response
 
     @classmethod
-    def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, pre_prompt: str, query: str, inputs: dict,
+    def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, model: dict,
+                            pre_prompt: str, query: str, inputs: dict,
                             chain_output: Optional[str],
                             memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \
             Tuple[Union[str | List[BaseMessage]], Optional[List[str]]]:
@@ -151,10 +154,11 @@ class Completion:
 
         if mode == 'completion':
             prompt_template = JinjaPromptTemplate.from_template(
-                template=("""Use the following CONTEXT as your learned knowledge:
-[CONTEXT]
+                template=("""Use the following context as your learned knowledge, inside <context></context> XML tags.
+
+<context>
 {{context}}
-[END CONTEXT]
+</context>
 
 When answer to user:
 - If you don't know, just say that you don't know.
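The bracketed [CONTEXT] markers become <context> XML tags, in line with Anthropic's guidance that Claude attends more reliably to XML-delimited sections; the chat-mode prompt below gets the same treatment. Rendered with an illustrative context value, the new template opens like this:

```python
# Illustrative rendering only; the context string is made up.
context = "Dify is an open-source LLM app development platform."
rendered = (
    "Use the following context as your learned knowledge, "
    "inside <context></context> XML tags.\n\n"
    "<context>\n"
    f"{context}\n"
    "</context>\n"
)
print(rendered)
```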
@@ -204,10 +208,11 @@ And answer according to the language of the user's question.
 
         if chain_output:
             human_inputs['context'] = chain_output
-            human_message_prompt += """Use the following CONTEXT as your learned knowledge.
-[CONTEXT]
+            human_message_prompt += """Use the following context as your learned knowledge, inside <context></context> XML tags.
+
+<context>
 {{context}}
-[END CONTEXT]
+</context>
 
 When answer to user:
 - If you don't know, just say that you don't know.
@@ -219,7 +224,7 @@ And answer according to the language of the user's question.
         if pre_prompt:
             human_message_prompt += pre_prompt
 
-        query_prompt = "\nHuman: {{query}}\nAI: "
+        query_prompt = "\n\nHuman: {{query}}\n\nAssistant: "
 
         if memory:
             # append chat histories
@@ -228,9 +233,11 @@ And answer according to the language of the user's question.
                 inputs=human_inputs
             )
 
-            curr_message_tokens = memory.llm.get_messages_tokens([tmp_human_message])
-            rest_tokens = llm_constant.max_context_token_length[memory.llm.model_name] \
-                - memory.llm.max_tokens - curr_message_tokens
+            curr_message_tokens = memory.llm.get_num_tokens_from_messages([tmp_human_message])
+            model_name = model['name']
+            max_tokens = model.get("completion_params").get('max_tokens')
+            rest_tokens = llm_constant.max_context_token_length[model_name] \
+                - max_tokens - curr_message_tokens
             rest_tokens = max(rest_tokens, 0)
             histories = cls.get_history_messages_from_memory(memory, rest_tokens)
 
@@ -241,7 +248,10 @@ And answer according to the language of the user's question.
             # if histories_param not in human_inputs:
             #     human_inputs[histories_param] = '{{' + histories_param + '}}'
 
-            human_message_prompt += "\n\n" + histories
+            human_message_prompt += "\n\n" if human_message_prompt else ""
+            human_message_prompt += "Here is the chat histories between human and assistant, " \
+                                    "inside <histories></histories> XML tags.\n\n<histories>"
+            human_message_prompt += histories + "</histories>"
 
         human_message_prompt += query_prompt
 
@@ -307,13 +317,15 @@ And answer according to the language of the user's question.
             model=app_model_config.model_dict
         )
 
-        model_limited_tokens = llm_constant.max_context_token_length[llm.model_name]
-        max_tokens = llm.max_tokens
+        model_name = app_model_config.model_dict.get("name")
+        model_limited_tokens = llm_constant.max_context_token_length[model_name]
+        max_tokens = app_model_config.model_dict.get("completion_params").get('max_tokens')
 
         # get prompt without memory and context
         prompt, _ = cls.get_main_llm_prompt(
             mode=mode,
             llm=llm,
+            model=app_model_config.model_dict,
             pre_prompt=app_model_config.pre_prompt,
             query=query,
             inputs=inputs,
@@ -332,16 +344,17 @@ And answer according to the language of the user's question.
         return rest_tokens
 
     @classmethod
-    def recale_llm_max_tokens(cls, final_llm: Union[StreamableOpenAI, StreamableChatOpenAI],
+    def recale_llm_max_tokens(cls, final_llm: BaseLanguageModel, model: dict,
                               prompt: Union[str, List[BaseMessage]], mode: str):
         # recalc max_tokens if sum(prompt_token + max_tokens) over model token limit
-        model_limited_tokens = llm_constant.max_context_token_length[final_llm.model_name]
-        max_tokens = final_llm.max_tokens
+        model_name = model.get("name")
+        model_limited_tokens = llm_constant.max_context_token_length[model_name]
+        max_tokens = model.get("completion_params").get('max_tokens')
 
         if mode == 'completion' and isinstance(final_llm, BaseLLM):
             prompt_tokens = final_llm.get_num_tokens(prompt)
         else:
-            prompt_tokens = final_llm.get_messages_tokens(prompt)
+            prompt_tokens = final_llm.get_num_tokens_from_messages(prompt)
 
         if prompt_tokens + max_tokens > model_limited_tokens:
             max_tokens = max(model_limited_tokens - prompt_tokens, 16)
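A short worked example of the clamp above, using the 100,000-token window that llm_constant now assigns to claude-2 (the token counts are illustrative):

```python
model_limited_tokens = 100000  # claude-2 context window from llm_constant
prompt_tokens = 99500          # tokens already used by the prompt
max_tokens = 2000              # requested completion budget

if prompt_tokens + max_tokens > model_limited_tokens:
    # Shrink the completion budget so prompt + completion fits; floor of 16.
    max_tokens = max(model_limited_tokens - prompt_tokens, 16)

assert max_tokens == 500
```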
@@ -350,9 +363,10 @@ And answer according to the language of the user's question.
     @classmethod
     def generate_more_like_this(cls, task_id: str, app: App, message: Message, pre_prompt: str,
                                 app_model_config: AppModelConfig, user: Account, streaming: bool):
-        llm: StreamableOpenAI = LLMBuilder.to_llm(
+
+        llm = LLMBuilder.to_llm_from_model(
             tenant_id=app.tenant_id,
-            model_name='gpt-3.5-turbo',
+            model=app_model_config.model_dict,
             streaming=streaming
         )
 
@@ -360,6 +374,7 @@ And answer according to the language of the user's question.
         original_prompt, _ = cls.get_main_llm_prompt(
             mode="completion",
             llm=llm,
+            model=app_model_config.model_dict,
             pre_prompt=pre_prompt,
             query=message.query,
             inputs=message.inputs,
@@ -390,6 +405,7 @@ And answer according to the language of the user's question.
 
         cls.recale_llm_max_tokens(
             final_llm=llm,
+            model=app_model_config.model_dict,
             prompt=prompt,
             mode='completion'
         )
@@ -1,6 +1,8 @@
 from _decimal import Decimal
 
 models = {
+    'claude-instant-1': 'anthropic',  # 100,000 tokens
+    'claude-2': 'anthropic',  # 100,000 tokens
     'gpt-4': 'openai',  # 8,192 tokens
     'gpt-4-32k': 'openai',  # 32,768 tokens
     'gpt-3.5-turbo': 'openai',  # 4,096 tokens
@@ -10,10 +12,13 @@ models = {
     'text-curie-001': 'openai',  # 2,049 tokens
     'text-babbage-001': 'openai',  # 2,049 tokens
     'text-ada-001': 'openai',  # 2,049 tokens
-    'text-embedding-ada-002': 'openai'  # 8191 tokens, 1536 dimensions
+    'text-embedding-ada-002': 'openai',  # 8191 tokens, 1536 dimensions
+    'whisper-1': 'openai'
 }
 
 max_context_token_length = {
+    'claude-instant-1': 100000,
+    'claude-2': 100000,
     'gpt-4': 8192,
     'gpt-4-32k': 32768,
     'gpt-3.5-turbo': 4096,
@@ -23,17 +28,21 @@ max_context_token_length = {
     'text-curie-001': 2049,
     'text-babbage-001': 2049,
     'text-ada-001': 2049,
-    'text-embedding-ada-002': 8191
+    'text-embedding-ada-002': 8191,
 }
 
 models_by_mode = {
     'chat': [
+        'claude-instant-1',  # 100,000 tokens
+        'claude-2',  # 100,000 tokens
         'gpt-4',  # 8,192 tokens
         'gpt-4-32k',  # 32,768 tokens
         'gpt-3.5-turbo',  # 4,096 tokens
         'gpt-3.5-turbo-16k',  # 16,384 tokens
     ],
     'completion': [
+        'claude-instant-1',  # 100,000 tokens
+        'claude-2',  # 100,000 tokens
         'gpt-4',  # 8,192 tokens
         'gpt-4-32k',  # 32,768 tokens
         'gpt-3.5-turbo',  # 4,096 tokens
@@ -52,6 +61,14 @@ models_by_mode = {
 model_currency = 'USD'
 
 model_prices = {
+    'claude-instant-1': {
+        'prompt': Decimal('0.00163'),
+        'completion': Decimal('0.00551'),
+    },
+    'claude-2': {
+        'prompt': Decimal('0.01102'),
+        'completion': Decimal('0.03268'),
+    },
     'gpt-4': {
         'prompt': Decimal('0.03'),
         'completion': Decimal('0.06'),
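Assuming these prices are USD per 1,000 tokens, as the existing OpenAI entries suggest, a cost estimate for one exchange looks like:

```python
from decimal import Decimal

# claude-2 prices from the table above; per-1K-token pricing is an assumption
# carried over from how the OpenAI entries are treated.
price = {'prompt': Decimal('0.01102'), 'completion': Decimal('0.03268')}
prompt_tokens, completion_tokens = 1200, 300

total = (price['prompt'] * prompt_tokens
         + price['completion'] * completion_tokens) / 1000
print(total)  # 0.023028 (USD)
```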
@@ -56,7 +56,7 @@ class ConversationMessageTask:
         )
 
     def init(self):
-        provider_name = LLMBuilder.get_default_provider(self.app.tenant_id)
+        provider_name = LLMBuilder.get_default_provider(self.app.tenant_id, self.model_name)
         self.model_dict['provider'] = provider_name
 
         override_model_configs = None
@@ -89,7 +89,7 @@ class ConversationMessageTask:
             system_message = PromptBuilder.to_system_message(self.app_model_config.pre_prompt, self.inputs)
             system_instruction = system_message.content
             llm = LLMBuilder.to_llm(self.tenant_id, self.model_name)
-            system_instruction_tokens = llm.get_messages_tokens([system_message])
+            system_instruction_tokens = llm.get_num_tokens_from_messages([system_message])
 
         if not self.conversation:
             self.is_new_conversation = True
@@ -185,6 +185,7 @@ class ConversationMessageTask:
         if provider and provider.provider_type == ProviderType.SYSTEM.value:
             db.session.query(Provider).filter(
                 Provider.tenant_id == self.app.tenant_id,
+                Provider.provider_name == provider.provider_name,
                 Provider.quota_limit > Provider.quota_used
             ).update({'quota_used': Provider.quota_used + 1})
 
@@ -4,6 +4,7 @@ from typing import List
 from langchain.embeddings.base import Embeddings
 from sqlalchemy.exc import IntegrityError
 
+from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
 from extensions.ext_database import db
 from libs import helper
 from models.dataset import Embedding
@@ -49,6 +50,7 @@ class CacheEmbedding(Embeddings):
             text_embeddings.extend(embedding_results)
         return text_embeddings
 
+    @handle_openai_exceptions
     def embed_query(self, text: str) -> List[float]:
         """Embed query text."""
         # use doc embedding cache or store if not exists
@@ -23,6 +23,10 @@ class LLMGenerator:
     @classmethod
     def generate_conversation_name(cls, tenant_id: str, query, answer):
         prompt = CONVERSATION_TITLE_PROMPT
+
+        if len(query) > 2000:
+            query = query[:300] + "...[TRUNCATED]..." + query[-300:]
+
         prompt = prompt.format(query=query)
         llm: StreamableOpenAI = LLMBuilder.to_llm(
             tenant_id=tenant_id,
@@ -52,7 +56,17 @@ class LLMGenerator:
             if not message.answer:
                 continue
 
-            message_qa_text = "Human:" + message.query + "\nAI:" + message.answer + "\n"
+            if len(message.query) > 2000:
+                query = message.query[:300] + "...[TRUNCATED]..." + message.query[-300:]
+            else:
+                query = message.query
+
+            if len(message.answer) > 2000:
+                answer = message.answer[:300] + "...[TRUNCATED]..." + message.answer[-300:]
+            else:
+                answer = message.answer
+
+            message_qa_text = "\n\nHuman:" + query + "\n\nAssistant:" + answer
             if rest_tokens - TokenCalculator.get_num_tokens(model, context + message_qa_text) > 0:
                 context += message_qa_text
 
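The same head-plus-tail truncation now appears in both generator methods. A hypothetical helper capturing the inlined logic (the diff repeats it rather than factoring it out):

```python
def truncate_middle(text: str, limit: int = 2000, keep: int = 300) -> str:
    """Keep the first and last `keep` characters when `text` exceeds `limit`."""
    if len(text) <= limit:
        return text
    return text[:keep] + "...[TRUNCATED]..." + text[-keep:]

# truncate_middle("x" * 5000) -> 300 chars + "...[TRUNCATED]..." + 300 chars
```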
@@ -17,7 +17,7 @@ class IndexBuilder:
 
         model_credentials = LLMBuilder.get_model_credentials(
             tenant_id=dataset.tenant_id,
-            model_provider=LLMBuilder.get_default_provider(dataset.tenant_id),
+            model_provider=LLMBuilder.get_default_provider(dataset.tenant_id, 'text-embedding-ada-002'),
             model_name='text-embedding-ada-002'
         )
 
@@ -40,6 +40,9 @@ class ProviderTokenNotInitError(Exception):
     """
     description = "Provider Token Not Init"
 
+    def __init__(self, *args, **kwargs):
+        self.description = args[0] if args else self.description
+
 
 class QuotaExceededError(Exception):
     """
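The new constructor lets a caller attach a provider-specific message while keeping the class-level default, which is what the controller changes above rely on:

```python
err = ProviderTokenNotInitError()
print(err.description)  # "Provider Token Not Init" (class default)

err = ProviderTokenNotInitError("Anthropic API key is not configured for this tenant.")
print(err.description)  # override; controllers forward this to the client
```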
@@ -8,9 +8,10 @@ from core.llm.provider.base import BaseProvider
 from core.llm.provider.llm_provider_service import LLMProviderService
 from core.llm.streamable_azure_chat_open_ai import StreamableAzureChatOpenAI
 from core.llm.streamable_azure_open_ai import StreamableAzureOpenAI
+from core.llm.streamable_chat_anthropic import StreamableChatAnthropic
 from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
 from core.llm.streamable_open_ai import StreamableOpenAI
-from models.provider import ProviderType
+from models.provider import ProviderType, ProviderName
 
 
 class LLMBuilder:
@@ -32,43 +33,43 @@ class LLMBuilder:
 
     @classmethod
     def to_llm(cls, tenant_id: str, model_name: str, **kwargs) -> Union[StreamableOpenAI, StreamableChatOpenAI]:
-        provider = cls.get_default_provider(tenant_id)
+        provider = cls.get_default_provider(tenant_id, model_name)
 
         model_credentials = cls.get_model_credentials(tenant_id, provider, model_name)
 
+        llm_cls = None
         mode = cls.get_mode_by_model(model_name)
         if mode == 'chat':
-            if provider == 'openai':
+            if provider == ProviderName.OPENAI.value:
                 llm_cls = StreamableChatOpenAI
-            else:
+            elif provider == ProviderName.AZURE_OPENAI.value:
                 llm_cls = StreamableAzureChatOpenAI
+            elif provider == ProviderName.ANTHROPIC.value:
+                llm_cls = StreamableChatAnthropic
         elif mode == 'completion':
-            if provider == 'openai':
+            if provider == ProviderName.OPENAI.value:
                 llm_cls = StreamableOpenAI
-            else:
+            elif provider == ProviderName.AZURE_OPENAI.value:
                 llm_cls = StreamableAzureOpenAI
-        else:
+
+        if not llm_cls:
             raise ValueError(f"model name {model_name} is not supported.")
 
         model_kwargs = {
+            'model_name': model_name,
+            'temperature': kwargs.get('temperature', 0),
+            'max_tokens': kwargs.get('max_tokens', 256),
             'top_p': kwargs.get('top_p', 1),
             'frequency_penalty': kwargs.get('frequency_penalty', 0),
             'presence_penalty': kwargs.get('presence_penalty', 0),
+            'callbacks': kwargs.get('callbacks', None),
+            'streaming': kwargs.get('streaming', False),
         }
 
-        model_extras_kwargs = model_kwargs if mode == 'completion' else {'model_kwargs': model_kwargs}
+        model_kwargs.update(model_credentials)
+        model_kwargs = llm_cls.get_kwargs_from_model_params(model_kwargs)
 
-        return llm_cls(
-            model_name=model_name,
-            temperature=kwargs.get('temperature', 0),
-            max_tokens=kwargs.get('max_tokens', 256),
-            **model_extras_kwargs,
-            callbacks=kwargs.get('callbacks', None),
-            streaming=kwargs.get('streaming', False),
-            # request_timeout=None
-            **model_credentials
-        )
+        return llm_cls(**model_kwargs)
 
     @classmethod
     def to_llm_from_model(cls, tenant_id: str, model: dict, streaming: bool = False,
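Note: to_llm now assembles one flat parameter dict, merges the provider credentials into it, and lets each wrapper class reshape it via get_kwargs_from_model_params, so provider-specific argument names stay out of the builder. A hedged usage sketch (the tenant id is a placeholder):

    # Builds a StreamableChatAnthropic under the hood, because 'claude-2'
    # is a chat-mode model served by the anthropic provider.
    llm = LLMBuilder.to_llm(
        tenant_id='your-tenant-id',
        model_name='claude-2',
        temperature=0,
        max_tokens=256,
        streaming=True
    )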
@@ -118,14 +119,29 @@ class LLMBuilder:
         return provider_service.get_credentials(model_name)
 
     @classmethod
-    def get_default_provider(cls, tenant_id: str) -> str:
-        provider = BaseProvider.get_valid_provider(tenant_id)
-        if not provider:
-            raise ProviderTokenNotInitError()
+    def get_default_provider(cls, tenant_id: str, model_name: str) -> str:
+        provider_name = llm_constant.models[model_name]
 
-        if provider.provider_type == ProviderType.SYSTEM.value:
-            provider_name = 'openai'
-        else:
-            provider_name = provider.provider_name
+        if provider_name == 'openai':
+            # get the default provider (openai / azure_openai) for the tenant
+            openai_provider = BaseProvider.get_valid_provider(tenant_id, ProviderName.OPENAI.value)
+            azure_openai_provider = BaseProvider.get_valid_provider(tenant_id, ProviderName.AZURE_OPENAI.value)
+
+            provider = None
+            if openai_provider:
+                provider = openai_provider
+            elif azure_openai_provider:
+                provider = azure_openai_provider
+
+            if not provider:
+                raise ProviderTokenNotInitError(
+                    f"No valid {provider_name} model provider credentials found. "
+                    f"Please go to Settings -> Model Provider to complete your provider credentials."
+                )
+
+            if provider.provider_type == ProviderType.SYSTEM.value:
+                provider_name = 'openai'
+            else:
+                provider_name = provider.provider_name
 
         return provider_name
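Note: get_default_provider now derives the provider from the requested model via llm_constant.models, and only for OpenAI-family models falls back to whichever of the tenant's OpenAI / Azure OpenAI credentials is valid. A sketch of the mapping this assumes (entries are illustrative; the real table lives in llm_constant):

    # model name -> provider name
    models = {
        'claude-instant-1': 'anthropic',
        'claude-2': 'anthropic',
        'gpt-3.5-turbo': 'openai',
        'text-embedding-ada-002': 'openai',
        'whisper-1': 'openai',
    }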
@@ -1,23 +1,138 @@
-from typing import Optional
+import json
+import logging
+from typing import Optional, Union
+
+import anthropic
+from langchain.chat_models import ChatAnthropic
+from langchain.schema import HumanMessage
 
+from core import hosted_llm_credentials
+from core.llm.error import ProviderTokenNotInitError
 from core.llm.provider.base import BaseProvider
-from models.provider import ProviderName
+from core.llm.provider.errors import ValidateFailedError
+from models.provider import ProviderName, ProviderType
 
 
 class AnthropicProvider(BaseProvider):
     def get_models(self, model_id: Optional[str] = None) -> list[dict]:
-        credentials = self.get_credentials(model_id)
-        # todo
-        return []
+        return [
+            {
+                'id': 'claude-instant-1',
+                'name': 'claude-instant-1',
+            },
+            {
+                'id': 'claude-2',
+                'name': 'claude-2',
+            },
+        ]
 
     def get_credentials(self, model_id: Optional[str] = None) -> dict:
-        """
-        Returns the API credentials for Azure OpenAI as a dictionary, for the given tenant_id.
-        The dictionary contains keys: azure_api_type, azure_api_version, azure_api_base, and azure_api_key.
-        """
-        return {
-            'anthropic_api_key': self.get_provider_api_key(model_id=model_id)
-        }
+        return self.get_provider_api_key(model_id=model_id)
 
     def get_provider_name(self):
         return ProviderName.ANTHROPIC
+
+    def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
+        """
+        Returns the provider configs.
+        """
+        try:
+            config = self.get_provider_api_key(only_custom=only_custom)
+        except:
+            config = {
+                'anthropic_api_key': ''
+            }
+
+        if obfuscated:
+            if not config.get('anthropic_api_key'):
+                config = {
+                    'anthropic_api_key': ''
+                }
+
+            config['anthropic_api_key'] = self.obfuscated_token(config.get('anthropic_api_key'))
+            return config
+
+        return config
+
+    def get_encrypted_token(self, config: Union[dict | str]):
+        """
+        Returns the encrypted token.
+        """
+        return json.dumps({
+            'anthropic_api_key': self.encrypt_token(config['anthropic_api_key'])
+        })
+
+    def get_decrypted_token(self, token: str):
+        """
+        Returns the decrypted token.
+        """
+        config = json.loads(token)
+        config['anthropic_api_key'] = self.decrypt_token(config['anthropic_api_key'])
+        return config
+
+    def get_token_type(self):
+        return dict
+
+    def config_validate(self, config: Union[dict | str]):
+        """
+        Validates the given config.
+        """
+        # check OpenAI / Azure OpenAI credential is valid
+        openai_provider = BaseProvider.get_valid_provider(self.tenant_id, ProviderName.OPENAI.value)
+        azure_openai_provider = BaseProvider.get_valid_provider(self.tenant_id, ProviderName.AZURE_OPENAI.value)
+
+        provider = None
+        if openai_provider:
+            provider = openai_provider
+        elif azure_openai_provider:
+            provider = azure_openai_provider
+
+        if not provider:
+            raise ValidateFailedError(f"OpenAI or Azure OpenAI provider must be configured first.")
+
+        if provider.provider_type == ProviderType.SYSTEM.value:
+            quota_used = provider.quota_used if provider.quota_used is not None else 0
+            quota_limit = provider.quota_limit if provider.quota_limit is not None else 0
+            if quota_used >= quota_limit:
+                raise ValidateFailedError(f"Your quota for Dify Hosted OpenAI has been exhausted, "
+                                          f"please configure OpenAI or Azure OpenAI provider first.")
+
+        try:
+            if not isinstance(config, dict):
+                raise ValueError('Config must be a object.')
+
+            if 'anthropic_api_key' not in config:
+                raise ValueError('anthropic_api_key must be provided.')
+
+            chat_llm = ChatAnthropic(
+                model='claude-instant-1',
+                anthropic_api_key=config['anthropic_api_key'],
+                max_tokens_to_sample=10,
+                temperature=0,
+                default_request_timeout=60
+            )
+
+            messages = [
+                HumanMessage(
+                    content="ping"
+                )
+            ]
+
+            chat_llm(messages)
+        except anthropic.APIConnectionError as ex:
+            raise ValidateFailedError(f"Anthropic: Connection error, cause: {ex.__cause__}")
+        except (anthropic.APIStatusError, anthropic.RateLimitError) as ex:
+            raise ValidateFailedError(f"Anthropic: Error code: {ex.status_code} - "
+                                      f"{ex.body['error']['type']}: {ex.body['error']['message']}")
+        except Exception as ex:
+            logging.exception('Anthropic config validation failed')
+            raise ex
+
+    def get_hosted_credentials(self) -> Union[str | dict]:
+        if not hosted_llm_credentials.anthropic or not hosted_llm_credentials.anthropic.api_key:
+            raise ProviderTokenNotInitError(
+                f"No valid {self.get_provider_name().value} model provider credentials found. "
+                f"Please go to Settings -> Model Provider to complete your provider credentials."
+            )
+
+        return {'anthropic_api_key': hosted_llm_credentials.anthropic.api_key}
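Note: unlike the plain-string OpenAI token, Anthropic credentials are persisted as a JSON object whose anthropic_api_key field is encrypted. A minimal round-trip sketch, assuming provider is an AnthropicProvider with working encrypt_token/decrypt_token and a placeholder key value:

    token = provider.get_encrypted_token({'anthropic_api_key': 'sk-ant-...'})  # JSON string
    config = provider.get_decrypted_token(token)                               # dict again
    assert 'anthropic_api_key' in config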
@@ -52,12 +52,12 @@ class AzureProvider(BaseProvider):
     def get_provider_name(self):
         return ProviderName.AZURE_OPENAI
 
-    def get_provider_configs(self, obfuscated: bool = False) -> Union[str | dict]:
+    def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
         """
         Returns the provider configs.
         """
         try:
-            config = self.get_provider_api_key()
+            config = self.get_provider_api_key(only_custom=only_custom)
         except:
             config = {
                 'openai_api_type': 'azure',
@@ -81,7 +81,6 @@ class AzureProvider(BaseProvider):
         return config
 
     def get_token_type(self):
-        # TODO: change to dict when implemented
         return dict
 
     def config_validate(self, config: Union[dict | str]):
@@ -2,7 +2,7 @@ import base64
 from abc import ABC, abstractmethod
 from typing import Optional, Union
 
-from core import hosted_llm_credentials
+from core.constant import llm_constant
 from core.llm.error import QuotaExceededError, ModelCurrentlyNotSupportError, ProviderTokenNotInitError
 from extensions.ext_database import db
 from libs import rsa
@@ -14,15 +14,18 @@ class BaseProvider(ABC):
     def __init__(self, tenant_id: str):
         self.tenant_id = tenant_id
 
-    def get_provider_api_key(self, model_id: Optional[str] = None, prefer_custom: bool = True) -> Union[str | dict]:
+    def get_provider_api_key(self, model_id: Optional[str] = None, only_custom: bool = False) -> Union[str | dict]:
         """
         Returns the decrypted API key for the given tenant_id and provider_name.
         If the provider is of type SYSTEM and the quota is exceeded, raises a QuotaExceededError.
         If the provider is not found or not valid, raises a ProviderTokenNotInitError.
         """
-        provider = self.get_provider(prefer_custom)
+        provider = self.get_provider(only_custom)
         if not provider:
-            raise ProviderTokenNotInitError()
+            raise ProviderTokenNotInitError(
+                f"No valid {llm_constant.models[model_id]} model provider credentials found. "
+                f"Please go to Settings -> Model Provider to complete your provider credentials."
+            )
 
         if provider.provider_type == ProviderType.SYSTEM.value:
             quota_used = provider.quota_used if provider.quota_used is not None else 0
@@ -38,18 +41,19 @@ class BaseProvider(ABC):
         else:
             return self.get_decrypted_token(provider.encrypted_config)
 
-    def get_provider(self, prefer_custom: bool) -> Optional[Provider]:
+    def get_provider(self, only_custom: bool = False) -> Optional[Provider]:
         """
         Returns the Provider instance for the given tenant_id and provider_name.
         If both CUSTOM and System providers exist, the preferred provider will be returned based on the prefer_custom flag.
         """
-        return BaseProvider.get_valid_provider(self.tenant_id, self.get_provider_name().value, prefer_custom)
+        return BaseProvider.get_valid_provider(self.tenant_id, self.get_provider_name().value, only_custom)
 
     @classmethod
-    def get_valid_provider(cls, tenant_id: str, provider_name: str = None, prefer_custom: bool = False) -> Optional[Provider]:
+    def get_valid_provider(cls, tenant_id: str, provider_name: str = None, only_custom: bool = False) -> Optional[
+        Provider]:
         """
         Returns the Provider instance for the given tenant_id and provider_name.
-        If both CUSTOM and System providers exist, the preferred provider will be returned based on the prefer_custom flag.
+        If both CUSTOM and System providers exist.
         """
         query = db.session.query(Provider).filter(
             Provider.tenant_id == tenant_id
@@ -58,39 +62,31 @@ class BaseProvider(ABC):
         if provider_name:
             query = query.filter(Provider.provider_name == provider_name)
 
-        providers = query.order_by(Provider.provider_type.desc() if prefer_custom else Provider.provider_type).all()
+        if only_custom:
+            query = query.filter(Provider.provider_type == ProviderType.CUSTOM.value)
 
-        custom_provider = None
-        system_provider = None
+        providers = query.order_by(Provider.provider_type.asc()).all()
 
         for provider in providers:
             if provider.provider_type == ProviderType.CUSTOM.value and provider.is_valid and provider.encrypted_config:
-                custom_provider = provider
+                return provider
             elif provider.provider_type == ProviderType.SYSTEM.value and provider.is_valid:
-                system_provider = provider
+                return provider
 
-        if custom_provider:
-            return custom_provider
-        elif system_provider:
-            return system_provider
-        else:
-            return None
+        return None
 
-    def get_hosted_credentials(self) -> str:
-        if self.get_provider_name() != ProviderName.OPENAI:
-            raise ProviderTokenNotInitError()
-
-        if not hosted_llm_credentials.openai or not hosted_llm_credentials.openai.api_key:
-            raise ProviderTokenNotInitError()
-
-        return hosted_llm_credentials.openai.api_key
+    def get_hosted_credentials(self) -> Union[str | dict]:
+        raise ProviderTokenNotInitError(
+            f"No valid {self.get_provider_name().value} model provider credentials found. "
+            f"Please go to Settings -> Model Provider to complete your provider credentials."
+        )
 
-    def get_provider_configs(self, obfuscated: bool = False) -> Union[str | dict]:
+    def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
         """
         Returns the provider configs.
         """
         try:
-            config = self.get_provider_api_key()
+            config = self.get_provider_api_key(only_custom=only_custom)
         except:
             config = ''
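Note: ordering by Provider.provider_type.asc() sorts CUSTOM rows before SYSTEM rows ('custom' < 'system' lexically), so the loop returns a tenant's own valid credentials first and falls back to the hosted SYSTEM provider otherwise; this replaces the old prefer_custom flag with a fixed preference. Equivalent selection logic in plain Python (a sketch, not part of the patch):

    def pick_provider(providers):
        # 'custom' sorts before 'system', so custom credentials win when valid.
        for p in sorted(providers, key=lambda p: p.provider_type):
            if p.provider_type == 'custom' and p.is_valid and p.encrypted_config:
                return p
            if p.provider_type == 'system' and p.is_valid:
                return p
        return None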
@@ -31,11 +31,11 @@ class LLMProviderService:
     def get_credentials(self, model_id: Optional[str] = None) -> dict:
         return self.provider.get_credentials(model_id)
 
-    def get_provider_configs(self, obfuscated: bool = False) -> Union[str | dict]:
-        return self.provider.get_provider_configs(obfuscated)
+    def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]:
+        return self.provider.get_provider_configs(obfuscated=obfuscated, only_custom=only_custom)
 
-    def get_provider_db_record(self, prefer_custom: bool = False) -> Optional[Provider]:
-        return self.provider.get_provider(prefer_custom)
+    def get_provider_db_record(self) -> Optional[Provider]:
+        return self.provider.get_provider()
 
     def config_validate(self, config: Union[dict | str]):
         """
@@ -4,6 +4,8 @@ from typing import Optional, Union
 import openai
 from openai.error import AuthenticationError, OpenAIError
 
+from core import hosted_llm_credentials
+from core.llm.error import ProviderTokenNotInitError
 from core.llm.moderation import Moderation
 from core.llm.provider.base import BaseProvider
 from core.llm.provider.errors import ValidateFailedError
@@ -42,3 +44,12 @@ class OpenAIProvider(BaseProvider):
         except Exception as ex:
             logging.exception('OpenAI config validation failed')
             raise ex
+
+    def get_hosted_credentials(self) -> Union[str | dict]:
+        if not hosted_llm_credentials.openai or not hosted_llm_credentials.openai.api_key:
+            raise ProviderTokenNotInitError(
+                f"No valid {self.get_provider_name().value} model provider credentials found. "
+                f"Please go to Settings -> Model Provider to complete your provider credentials."
+            )
+
+        return hosted_llm_credentials.openai.api_key
@@ -1,11 +1,11 @@
-from langchain.callbacks.manager import CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun, Callbacks
-from langchain.schema import BaseMessage, ChatResult, LLMResult
+from langchain.callbacks.manager import Callbacks
+from langchain.schema import BaseMessage, LLMResult
 from langchain.chat_models import AzureChatOpenAI
 from typing import Optional, List, Dict, Any
 
 from pydantic import root_validator
 
-from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
+from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
 
 
 class StreamableAzureChatOpenAI(AzureChatOpenAI):
@@ -46,30 +46,7 @@ class StreamableAzureChatOpenAI(AzureChatOpenAI):
             "organization": self.openai_organization if self.openai_organization else None,
         }
 
-    def get_messages_tokens(self, messages: List[BaseMessage]) -> int:
-        """Get the number of tokens in a list of messages.
-
-        Args:
-            messages: The messages to count the tokens of.
-
-        Returns:
-            The number of tokens in the messages.
-        """
-        tokens_per_message = 5
-        tokens_per_request = 3
-
-        message_tokens = tokens_per_request
-        message_strs = ''
-        for message in messages:
-            message_strs += message.content
-            message_tokens += tokens_per_message
-
-        # calc once
-        message_tokens += self.get_num_tokens(message_strs)
-
-        return message_tokens
-
-    @handle_llm_exceptions
+    @handle_openai_exceptions
     def generate(
             self,
             messages: List[List[BaseMessage]],
@@ -79,12 +56,18 @@ class StreamableAzureChatOpenAI(AzureChatOpenAI):
     ) -> LLMResult:
         return super().generate(messages, stop, callbacks, **kwargs)
 
-    @handle_llm_exceptions_async
-    async def agenerate(
-            self,
-            messages: List[List[BaseMessage]],
-            stop: Optional[List[str]] = None,
-            callbacks: Callbacks = None,
-            **kwargs: Any,
-    ) -> LLMResult:
-        return await super().agenerate(messages, stop, callbacks, **kwargs)
+    @classmethod
+    def get_kwargs_from_model_params(cls, params: dict):
+        model_kwargs = {
+            'top_p': params.get('top_p', 1),
+            'frequency_penalty': params.get('frequency_penalty', 0),
+            'presence_penalty': params.get('presence_penalty', 0),
+        }
+
+        del params['top_p']
+        del params['frequency_penalty']
+        del params['presence_penalty']
+
+        params['model_kwargs'] = model_kwargs
+
+        return params
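Note: langchain's OpenAI chat classes take sampling penalties through model_kwargs rather than as constructor arguments, which is what get_kwargs_from_model_params encodes. A worked example of the transformation (input dict is illustrative):

    params = {
        'model_name': 'gpt-3.5-turbo',
        'temperature': 0,
        'max_tokens': 256,
        'top_p': 1,
        'frequency_penalty': 0,
        'presence_penalty': 0,
    }
    params = StreamableAzureChatOpenAI.get_kwargs_from_model_params(params)
    # -> {'model_name': 'gpt-3.5-turbo', 'temperature': 0, 'max_tokens': 256,
    #     'model_kwargs': {'top_p': 1, 'frequency_penalty': 0, 'presence_penalty': 0}}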
@@ -5,7 +5,7 @@ from typing import Optional, List, Dict, Mapping, Any
 
 from pydantic import root_validator
 
-from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
+from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
 
 
 class StreamableAzureOpenAI(AzureOpenAI):
@@ -50,7 +50,7 @@ class StreamableAzureOpenAI(AzureOpenAI):
         "organization": self.openai_organization if self.openai_organization else None,
         }}
 
-    @handle_llm_exceptions
+    @handle_openai_exceptions
     def generate(
             self,
             prompts: List[str],
@@ -60,12 +60,6 @@ class StreamableAzureOpenAI(AzureOpenAI):
     ) -> LLMResult:
         return super().generate(prompts, stop, callbacks, **kwargs)
 
-    @handle_llm_exceptions_async
-    async def agenerate(
-            self,
-            prompts: List[str],
-            stop: Optional[List[str]] = None,
-            callbacks: Callbacks = None,
-            **kwargs: Any,
-    ) -> LLMResult:
-        return await super().agenerate(prompts, stop, callbacks, **kwargs)
+    @classmethod
+    def get_kwargs_from_model_params(cls, params: dict):
+        return params
api/core/llm/streamable_chat_anthropic.py (new file, 39 lines)
@@ -0,0 +1,39 @@
+from typing import List, Optional, Any, Dict
+
+from langchain.callbacks.manager import Callbacks
+from langchain.chat_models import ChatAnthropic
+from langchain.schema import BaseMessage, LLMResult
+
+from core.llm.wrappers.anthropic_wrapper import handle_anthropic_exceptions
+
+
+class StreamableChatAnthropic(ChatAnthropic):
+    """
+    Wrapper around Anthropic's large language model.
+    """
+
+    @handle_anthropic_exceptions
+    def generate(
+            self,
+            messages: List[List[BaseMessage]],
+            stop: Optional[List[str]] = None,
+            callbacks: Callbacks = None,
+            *,
+            tags: Optional[List[str]] = None,
+            metadata: Optional[Dict[str, Any]] = None,
+            **kwargs: Any,
+    ) -> LLMResult:
+        return super().generate(messages, stop, callbacks, tags=tags, metadata=metadata, **kwargs)
+
+    @classmethod
+    def get_kwargs_from_model_params(cls, params: dict):
+        params['model'] = params.get('model_name')
+        del params['model_name']
+
+        params['max_tokens_to_sample'] = params.get('max_tokens')
+        del params['max_tokens']
+
+        del params['frequency_penalty']
+        del params['presence_penalty']
+
+        return params
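Note: ChatAnthropic expects model and max_tokens_to_sample, and the Anthropic API has no frequency/presence penalties, so the class method above renames the former and drops the latter. A worked example (input dict is illustrative):

    params = {
        'model_name': 'claude-2',
        'temperature': 0,
        'max_tokens': 256,
        'frequency_penalty': 0,
        'presence_penalty': 0,
    }
    params = StreamableChatAnthropic.get_kwargs_from_model_params(params)
    # -> {'temperature': 0, 'model': 'claude-2', 'max_tokens_to_sample': 256}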
@@ -7,7 +7,7 @@ from typing import Optional, List, Dict, Any
 
 from pydantic import root_validator
 
-from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
+from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
 
 
 class StreamableChatOpenAI(ChatOpenAI):
@@ -48,30 +48,7 @@ class StreamableChatOpenAI(ChatOpenAI):
             "organization": self.openai_organization if self.openai_organization else None,
         }
 
-    def get_messages_tokens(self, messages: List[BaseMessage]) -> int:
-        """Get the number of tokens in a list of messages.
-
-        Args:
-            messages: The messages to count the tokens of.
-
-        Returns:
-            The number of tokens in the messages.
-        """
-        tokens_per_message = 5
-        tokens_per_request = 3
-
-        message_tokens = tokens_per_request
-        message_strs = ''
-        for message in messages:
-            message_strs += message.content
-            message_tokens += tokens_per_message
-
-        # calc once
-        message_tokens += self.get_num_tokens(message_strs)
-
-        return message_tokens
-
-    @handle_llm_exceptions
+    @handle_openai_exceptions
     def generate(
             self,
             messages: List[List[BaseMessage]],
@@ -81,12 +58,18 @@ class StreamableChatOpenAI(ChatOpenAI):
     ) -> LLMResult:
         return super().generate(messages, stop, callbacks, **kwargs)
 
-    @handle_llm_exceptions_async
-    async def agenerate(
-            self,
-            messages: List[List[BaseMessage]],
-            stop: Optional[List[str]] = None,
-            callbacks: Callbacks = None,
-            **kwargs: Any,
-    ) -> LLMResult:
-        return await super().agenerate(messages, stop, callbacks, **kwargs)
+    @classmethod
+    def get_kwargs_from_model_params(cls, params: dict):
+        model_kwargs = {
+            'top_p': params.get('top_p', 1),
+            'frequency_penalty': params.get('frequency_penalty', 0),
+            'presence_penalty': params.get('presence_penalty', 0),
+        }
+
+        del params['top_p']
+        del params['frequency_penalty']
+        del params['presence_penalty']
+
+        params['model_kwargs'] = model_kwargs
+
+        return params
@@ -6,7 +6,7 @@ from typing import Optional, List, Dict, Any, Mapping
 from langchain import OpenAI
 from pydantic import root_validator
 
-from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
+from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
 
 
 class StreamableOpenAI(OpenAI):
@@ -49,7 +49,7 @@ class StreamableOpenAI(OpenAI):
         "organization": self.openai_organization if self.openai_organization else None,
         }}
 
-    @handle_llm_exceptions
+    @handle_openai_exceptions
     def generate(
             self,
             prompts: List[str],
@@ -59,12 +59,6 @@ class StreamableOpenAI(OpenAI):
     ) -> LLMResult:
         return super().generate(prompts, stop, callbacks, **kwargs)
 
-    @handle_llm_exceptions_async
-    async def agenerate(
-            self,
-            prompts: List[str],
-            stop: Optional[List[str]] = None,
-            callbacks: Callbacks = None,
-            **kwargs: Any,
-    ) -> LLMResult:
-        return await super().agenerate(prompts, stop, callbacks, **kwargs)
+    @classmethod
+    def get_kwargs_from_model_params(cls, params: dict):
+        return params
@@ -1,6 +1,7 @@
 import openai
 
+from core.llm.wrappers.openai_wrapper import handle_openai_exceptions
 from models.provider import ProviderName
-from core.llm.error_handle_wraps import handle_llm_exceptions
 from core.llm.provider.base import BaseProvider
 
 
@@ -13,7 +14,7 @@ class Whisper:
         self.client = openai.Audio
         self.credentials = provider.get_credentials()
 
-    @handle_llm_exceptions
+    @handle_openai_exceptions
     def transcribe(self, file):
         return self.client.transcribe(
             model='whisper-1',
api/core/llm/wrappers/anthropic_wrapper.py (new file, 27 lines)
@@ -0,0 +1,27 @@
+import logging
+from functools import wraps
+
+import anthropic
+
+from core.llm.error import LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, \
+    LLMBadRequestError
+
+
+def handle_anthropic_exceptions(func):
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        try:
+            return func(*args, **kwargs)
+        except anthropic.APIConnectionError as e:
+            logging.exception("Failed to connect to Anthropic API.")
+            raise LLMAPIConnectionError(f"Anthropic: The server could not be reached, cause: {e.__cause__}")
+        except anthropic.RateLimitError:
+            raise LLMRateLimitError("Anthropic: A 429 status code was received; we should back off a bit.")
+        except anthropic.AuthenticationError as e:
+            raise LLMAuthorizationError(f"Anthropic: {e.message}")
+        except anthropic.BadRequestError as e:
+            raise LLMBadRequestError(f"Anthropic: {e.message}")
+        except anthropic.APIStatusError as e:
+            raise LLMAPIUnavailableError(f"Anthropic: code: {e.status_code}, cause: {e.message}")
+
+    return wrapper
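Note: the decorator translates anthropic SDK exceptions into the app's own LLM*Error hierarchy so callers handle one error family regardless of provider. A minimal usage sketch (the decorated function and its body are hypothetical, written against the anthropic 0.3.x client API):

    import anthropic

    @handle_anthropic_exceptions
    def ping(client: anthropic.Anthropic) -> str:
        # Any anthropic.* error raised in here surfaces as an LLM*Error instead.
        resp = client.completions.create(
            model='claude-instant-1',
            prompt=f"{anthropic.HUMAN_PROMPT} ping{anthropic.AI_PROMPT}",
            max_tokens_to_sample=5,
        )
        return resp.completion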
@@ -7,7 +7,7 @@ from core.llm.error import LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, \
     LLMBadRequestError
 
 
-def handle_llm_exceptions(func):
+def handle_openai_exceptions(func):
     @wraps(func)
     def wrapper(*args, **kwargs):
         try:
@@ -29,27 +29,3 @@ def handle_llm_exceptions(func):
             raise LLMBadRequestError(e.__class__.__name__ + ":" + str(e))
 
     return wrapper
-
-
-def handle_llm_exceptions_async(func):
-    @wraps(func)
-    async def wrapper(*args, **kwargs):
-        try:
-            return await func(*args, **kwargs)
-        except openai.error.InvalidRequestError as e:
-            logging.exception("Invalid request to OpenAI API.")
-            raise LLMBadRequestError(str(e))
-        except openai.error.APIConnectionError as e:
-            logging.exception("Failed to connect to OpenAI API.")
-            raise LLMAPIConnectionError(e.__class__.__name__ + ":" + str(e))
-        except (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout) as e:
-            logging.exception("OpenAI service unavailable.")
-            raise LLMAPIUnavailableError(e.__class__.__name__ + ":" + str(e))
-        except openai.error.RateLimitError as e:
-            raise LLMRateLimitError(str(e))
-        except openai.error.AuthenticationError as e:
-            raise LLMAuthorizationError(str(e))
-        except openai.error.OpenAIError as e:
-            raise LLMBadRequestError(e.__class__.__name__ + ":" + str(e))
-
-    return wrapper
@@ -1,7 +1,7 @@
 from typing import Any, List, Dict, Union
 
 from langchain.memory.chat_memory import BaseChatMemory
-from langchain.schema import get_buffer_string, BaseMessage, HumanMessage, AIMessage
+from langchain.schema import get_buffer_string, BaseMessage, HumanMessage, AIMessage, BaseLanguageModel
 
 from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
 from core.llm.streamable_open_ai import StreamableOpenAI
@@ -12,8 +12,8 @@ from models.model import Conversation, Message
 class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):
     conversation: Conversation
     human_prefix: str = "Human"
-    ai_prefix: str = "AI"
-    llm: Union[StreamableChatOpenAI | StreamableOpenAI]
+    ai_prefix: str = "Assistant"
+    llm: BaseLanguageModel
     memory_key: str = "chat_history"
     max_token_limit: int = 2000
     message_limit: int = 10
@@ -38,12 +38,12 @@ class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):
             return chat_messages
 
         # prune the chat message if it exceeds the max token limit
-        curr_buffer_length = self.llm.get_messages_tokens(chat_messages)
+        curr_buffer_length = self.llm.get_num_tokens_from_messages(chat_messages)
         if curr_buffer_length > self.max_token_limit:
             pruned_memory = []
             while curr_buffer_length > self.max_token_limit and chat_messages:
                 pruned_memory.append(chat_messages.pop(0))
-                curr_buffer_length = self.llm.get_messages_tokens(chat_messages)
+                curr_buffer_length = self.llm.get_num_tokens_from_messages(chat_messages)
 
         return chat_messages
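Note: switching to langchain's built-in get_num_tokens_from_messages lets this memory work with any BaseLanguageModel (including Anthropic), instead of the removed OpenAI-only get_messages_tokens heuristic. The pruning loop drops the oldest turns first, as this behavioral sketch shows (values are stand-ins):

    chat_messages = ['m1', 'm2', 'm3']          # stand-ins for BaseMessage objects
    count = lambda msgs: 1000 * len(msgs)       # stand-in token counter
    max_token_limit = 2000
    while count(chat_messages) > max_token_limit and chat_messages:
        chat_messages.pop(0)                    # 'm1' goes first; newest turns survive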
@@ -30,7 +30,7 @@ class DatasetTool(BaseTool):
         else:
             model_credentials = LLMBuilder.get_model_credentials(
                 tenant_id=self.dataset.tenant_id,
-                model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id),
+                model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id, 'text-embedding-ada-002'),
                 model_name='text-embedding-ada-002'
             )
 
@@ -60,7 +60,7 @@ class DatasetTool(BaseTool):
     async def _arun(self, tool_input: str) -> str:
         model_credentials = LLMBuilder.get_model_credentials(
             tenant_id=self.dataset.tenant_id,
-            model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id),
+            model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id, 'text-embedding-ada-002'),
             model_name='text-embedding-ada-002'
         )
@@ -1,4 +1,7 @@
+from flask import current_app
+
 from events.tenant_event import tenant_was_updated
+from models.provider import ProviderName
 from services.provider_service import ProviderService
 
 
@@ -6,4 +9,16 @@ from services.provider_service import ProviderService
 def handle(sender, **kwargs):
     tenant = sender
     if tenant.status == 'normal':
-        ProviderService.create_system_provider(tenant)
+        ProviderService.create_system_provider(
+            tenant,
+            ProviderName.OPENAI.value,
+            current_app.config['OPENAI_HOSTED_QUOTA_LIMIT'],
+            True
+        )
+
+        ProviderService.create_system_provider(
+            tenant,
+            ProviderName.ANTHROPIC.value,
+            current_app.config['ANTHROPIC_HOSTED_QUOTA_LIMIT'],
+            True
+        )
@@ -1,4 +1,7 @@
+from flask import current_app
+
 from events.tenant_event import tenant_was_created
+from models.provider import ProviderName
 from services.provider_service import ProviderService
 
 
@@ -6,4 +9,16 @@ from services.provider_service import ProviderService
 def handle(sender, **kwargs):
     tenant = sender
     if tenant.status == 'normal':
-        ProviderService.create_system_provider(tenant)
+        ProviderService.create_system_provider(
+            tenant,
+            ProviderName.OPENAI.value,
+            current_app.config['OPENAI_HOSTED_QUOTA_LIMIT'],
+            True
+        )
+
+        ProviderService.create_system_provider(
+            tenant,
+            ProviderName.ANTHROPIC.value,
+            current_app.config['ANTHROPIC_HOSTED_QUOTA_LIMIT'],
+            True
+        )
|
@ -10,7 +10,7 @@ flask-session2==1.3.1
|
|||||||
flask-cors==3.0.10
|
flask-cors==3.0.10
|
||||||
gunicorn~=20.1.0
|
gunicorn~=20.1.0
|
||||||
gevent~=22.10.2
|
gevent~=22.10.2
|
||||||
langchain==0.0.209
|
langchain==0.0.230
|
||||||
openai~=0.27.5
|
openai~=0.27.5
|
||||||
psycopg2-binary~=2.9.6
|
psycopg2-binary~=2.9.6
|
||||||
pycryptodome==3.17
|
pycryptodome==3.17
|
||||||
@ -35,3 +35,4 @@ docx2txt==0.8
|
|||||||
pypdfium2==4.16.0
|
pypdfium2==4.16.0
|
||||||
resend~=0.5.1
|
resend~=0.5.1
|
||||||
pyjwt~=2.6.0
|
pyjwt~=2.6.0
|
||||||
|
anthropic~=0.3.4
|
||||||
|
@@ -6,6 +6,30 @@ from models.account import Account
 from services.dataset_service import DatasetService
 from core.llm.llm_builder import LLMBuilder
 
+MODEL_PROVIDERS = [
+    'openai',
+    'anthropic',
+]
+
+MODELS_BY_APP_MODE = {
+    'chat': [
+        'claude-instant-1',
+        'claude-2',
+        'gpt-4',
+        'gpt-4-32k',
+        'gpt-3.5-turbo',
+        'gpt-3.5-turbo-16k',
+    ],
+    'completion': [
+        'claude-instant-1',
+        'claude-2',
+        'gpt-4',
+        'gpt-4-32k',
+        'gpt-3.5-turbo',
+        'gpt-3.5-turbo-16k',
+        'text-davinci-003',
+    ]
+}
+
 class AppModelConfigService:
     @staticmethod
@@ -125,7 +149,7 @@ class AppModelConfigService:
         if not isinstance(config["speech_to_text"]["enabled"], bool):
             raise ValueError("enabled in speech_to_text must be of boolean type")
 
-        provider_name = LLMBuilder.get_default_provider(account.current_tenant_id)
+        provider_name = LLMBuilder.get_default_provider(account.current_tenant_id, 'whisper-1')
 
         if config["speech_to_text"]["enabled"] and provider_name != 'openai':
             raise ValueError("provider not support speech to text")
@@ -153,14 +177,14 @@ class AppModelConfigService:
             raise ValueError("model must be of object type")
 
         # model.provider
-        if 'provider' not in config["model"] or config["model"]["provider"] != "openai":
-            raise ValueError("model.provider must be 'openai'")
+        if 'provider' not in config["model"] or config["model"]["provider"] not in MODEL_PROVIDERS:
+            raise ValueError(f"model.provider is required and must be in {str(MODEL_PROVIDERS)}")
 
         # model.name
         if 'name' not in config["model"]:
             raise ValueError("model.name is required")
 
-        if config["model"]["name"] not in llm_constant.models_by_mode[mode]:
+        if config["model"]["name"] not in MODELS_BY_APP_MODE[mode]:
             raise ValueError("model.name must be in the specified model list")
 
         # model.completion_params
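Note: with these constants in place, a model config passes the provider/name checks when the provider is listed in MODEL_PROVIDERS and the model name is listed for the app mode. For example (illustrative config):

    config = {'model': {'provider': 'anthropic', 'name': 'claude-2'}}
    # For mode == 'chat': 'anthropic' is in MODEL_PROVIDERS and 'claude-2' is in
    # MODELS_BY_APP_MODE['chat'], so both validations above pass.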
@@ -27,7 +27,7 @@ class AudioService:
             message = f"Audio size larger than {FILE_SIZE} mb"
             raise AudioTooLargeServiceError(message)
 
-        provider_name = LLMBuilder.get_default_provider(tenant_id)
+        provider_name = LLMBuilder.get_default_provider(tenant_id, 'whisper-1')
         if provider_name != ProviderName.OPENAI.value:
             raise ProviderNotSupportSpeechToTextServiceError()
 
@@ -37,8 +37,3 @@ class AudioService:
         buffer.name = 'temp.mp3'
 
         return Whisper(provider_service.provider).transcribe(buffer)
-
-
-
-
-
@@ -31,7 +31,7 @@ class HitTestingService:
 
         model_credentials = LLMBuilder.get_model_credentials(
             tenant_id=dataset.tenant_id,
-            model_provider=LLMBuilder.get_default_provider(dataset.tenant_id),
+            model_provider=LLMBuilder.get_default_provider(dataset.tenant_id, 'text-embedding-ada-002'),
             model_name='text-embedding-ada-002'
         )
|
|||||||
class ProviderService:
|
class ProviderService:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def init_supported_provider(tenant, edition):
|
def init_supported_provider(tenant):
|
||||||
"""Initialize the model provider, check whether the supported provider has a record"""
|
"""Initialize the model provider, check whether the supported provider has a record"""
|
||||||
|
|
||||||
providers = Provider.query.filter_by(tenant_id=tenant.id).all()
|
need_init_provider_names = [ProviderName.OPENAI.value, ProviderName.AZURE_OPENAI.value, ProviderName.ANTHROPIC.value]
|
||||||
|
|
||||||
openai_provider_exists = False
|
providers = db.session.query(Provider).filter(
|
||||||
azure_openai_provider_exists = False
|
Provider.tenant_id == tenant.id,
|
||||||
|
Provider.provider_type == ProviderType.CUSTOM.value,
|
||||||
# TODO: The cloud version needs to construct the data of the SYSTEM type
|
Provider.provider_name.in_(need_init_provider_names)
|
||||||
|
).all()
|
||||||
|
|
||||||
|
exists_provider_names = []
|
||||||
for provider in providers:
|
for provider in providers:
|
||||||
if provider.provider_name == ProviderName.OPENAI.value and provider.provider_type == ProviderType.CUSTOM.value:
|
exists_provider_names.append(provider.provider_name)
|
||||||
openai_provider_exists = True
|
|
||||||
if provider.provider_name == ProviderName.AZURE_OPENAI.value and provider.provider_type == ProviderType.CUSTOM.value:
|
|
||||||
azure_openai_provider_exists = True
|
|
||||||
|
|
||||||
# Initialize the model provider, check whether the supported provider has a record
|
not_exists_provider_names = list(set(need_init_provider_names) - set(exists_provider_names))
|
||||||
|
|
||||||
# Create default providers if they don't exist
|
if not_exists_provider_names:
|
||||||
if not openai_provider_exists:
|
# Initialize the model provider, check whether the supported provider has a record
|
||||||
openai_provider = Provider(
|
for provider_name in not_exists_provider_names:
|
||||||
tenant_id=tenant.id,
|
provider = Provider(
|
||||||
provider_name=ProviderName.OPENAI.value,
|
tenant_id=tenant.id,
|
||||||
provider_type=ProviderType.CUSTOM.value,
|
provider_name=provider_name,
|
||||||
is_valid=False
|
provider_type=ProviderType.CUSTOM.value,
|
||||||
)
|
is_valid=False
|
||||||
db.session.add(openai_provider)
|
)
|
||||||
|
db.session.add(provider)
|
||||||
|
|
||||||
if not azure_openai_provider_exists:
|
|
||||||
azure_openai_provider = Provider(
|
|
||||||
tenant_id=tenant.id,
|
|
||||||
provider_name=ProviderName.AZURE_OPENAI.value,
|
|
||||||
provider_type=ProviderType.CUSTOM.value,
|
|
||||||
is_valid=False
|
|
||||||
)
|
|
||||||
db.session.add(azure_openai_provider)
|
|
||||||
|
|
||||||
if not openai_provider_exists or not azure_openai_provider_exists:
|
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_obfuscated_api_key(tenant, provider_name: ProviderName):
|
def get_obfuscated_api_key(tenant, provider_name: ProviderName, only_custom: bool = False):
|
||||||
llm_provider_service = LLMProviderService(tenant.id, provider_name.value)
|
llm_provider_service = LLMProviderService(tenant.id, provider_name.value)
|
||||||
return llm_provider_service.get_provider_configs(obfuscated=True)
|
return llm_provider_service.get_provider_configs(obfuscated=True, only_custom=only_custom)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_token_type(tenant, provider_name: ProviderName):
|
def get_token_type(tenant, provider_name: ProviderName):
|
||||||
@ -73,7 +63,7 @@ class ProviderService:
|
|||||||
return llm_provider_service.get_encrypted_token(configs)
|
return llm_provider_service.get_encrypted_token(configs)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_system_provider(tenant: Tenant, provider_name: str = ProviderName.OPENAI.value,
|
def create_system_provider(tenant: Tenant, provider_name: str = ProviderName.OPENAI.value, quota_limit: int = 200,
|
||||||
is_valid: bool = True):
|
is_valid: bool = True):
|
||||||
if current_app.config['EDITION'] != 'CLOUD':
|
if current_app.config['EDITION'] != 'CLOUD':
|
||||||
return
|
return
|
||||||
@ -90,7 +80,7 @@ class ProviderService:
|
|||||||
provider_name=provider_name,
|
provider_name=provider_name,
|
||||||
provider_type=ProviderType.SYSTEM.value,
|
provider_type=ProviderType.SYSTEM.value,
|
||||||
quota_type=ProviderQuotaType.TRIAL.value,
|
quota_type=ProviderQuotaType.TRIAL.value,
|
||||||
quota_limit=200,
|
quota_limit=quota_limit,
|
||||||
encrypted_config='',
|
encrypted_config='',
|
||||||
is_valid=is_valid,
|
is_valid=is_valid,
|
||||||
)
|
)
|
||||||
|
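Note: init_supported_provider now computes the missing providers with a set difference, so supporting another provider later only means extending need_init_provider_names. For example:

    need = ['openai', 'azure_openai', 'anthropic']
    exists = ['openai']
    missing = list(set(need) - set(exists))  # ['azure_openai', 'anthropic'], order not guaranteed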
@@ -1,6 +1,6 @@
 from extensions.ext_database import db
 from models.account import Tenant
-from models.provider import Provider, ProviderType
+from models.provider import Provider, ProviderType, ProviderName
 
 
 class WorkspaceService:
@@ -33,7 +33,7 @@ class WorkspaceService:
             if provider.is_valid and provider.encrypted_config:
                 custom_provider = provider
         elif provider.provider_type == ProviderType.SYSTEM.value:
-            if provider.is_valid:
+            if provider.provider_name == ProviderName.OPENAI.value and provider.is_valid:
                 system_provider = provider
 
         if system_provider and not custom_provider: