mirror of
https://gitee.com/dify_ai/dify.git
synced 2024-11-30 18:27:53 +08:00
feat: chatglm3 support (#1616)
This commit is contained in:
parent
0e627c920f
commit
ea526d0822
@ -1,27 +1,45 @@
|
||||
import decimal
|
||||
import logging
|
||||
from typing import List, Optional, Any
|
||||
|
||||
import openai
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.llms import ChatGLM
|
||||
from langchain.schema import LLMResult
|
||||
from langchain.schema import LLMResult, get_buffer_string
|
||||
|
||||
from core.model_providers.error import LLMBadRequestError
|
||||
from core.model_providers.error import LLMBadRequestError, LLMRateLimitError, LLMAuthorizationError, \
|
||||
LLMAPIUnavailableError, LLMAPIConnectionError
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.model_providers.models.entity.message import PromptMessage, MessageType
|
||||
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
|
||||
from core.third_party.langchain.llms.chat_open_ai import EnhanceChatOpenAI
|
||||
|
||||
|
||||
class ChatGLMModel(BaseLLM):
|
||||
model_mode: ModelMode = ModelMode.COMPLETION
|
||||
model_mode: ModelMode = ModelMode.CHAT
|
||||
|
||||
def _init_client(self) -> Any:
|
||||
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
|
||||
return ChatGLM(
|
||||
|
||||
extra_model_kwargs = {
|
||||
'top_p': provider_model_kwargs.get('top_p')
|
||||
}
|
||||
|
||||
if provider_model_kwargs.get('max_length') is not None:
|
||||
extra_model_kwargs['max_length'] = provider_model_kwargs.get('max_length')
|
||||
|
||||
client = EnhanceChatOpenAI(
|
||||
model_name=self.name,
|
||||
temperature=provider_model_kwargs.get('temperature'),
|
||||
max_tokens=provider_model_kwargs.get('max_tokens'),
|
||||
model_kwargs=extra_model_kwargs,
|
||||
streaming=self.streaming,
|
||||
callbacks=self.callbacks,
|
||||
endpoint_url=self.credentials.get('api_base'),
|
||||
**provider_model_kwargs
|
||||
request_timeout=60,
|
||||
openai_api_key="1",
|
||||
openai_api_base=self.credentials['api_base'] + '/v1'
|
||||
)
|
||||
|
||||
return client
|
||||
|
||||
def _run(self, messages: List[PromptMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
@ -45,19 +63,40 @@ class ChatGLMModel(BaseLLM):
|
||||
:return:
|
||||
"""
|
||||
prompts = self._get_prompt_from_messages(messages)
|
||||
return max(self._client.get_num_tokens(prompts), 0)
|
||||
return max(sum([self._client.get_num_tokens(get_buffer_string([m])) for m in prompts]) - len(prompts), 0)
|
||||
|
||||
def get_currency(self):
|
||||
return 'RMB'
|
||||
|
||||
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
|
||||
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
|
||||
for k, v in provider_model_kwargs.items():
|
||||
if hasattr(self.client, k):
|
||||
setattr(self.client, k, v)
|
||||
extra_model_kwargs = {
|
||||
'top_p': provider_model_kwargs.get('top_p')
|
||||
}
|
||||
|
||||
self.client.temperature = provider_model_kwargs.get('temperature')
|
||||
self.client.max_tokens = provider_model_kwargs.get('max_tokens')
|
||||
self.client.model_kwargs = extra_model_kwargs
|
||||
|
||||
def handle_exceptions(self, ex: Exception) -> Exception:
|
||||
if isinstance(ex, ValueError):
|
||||
return LLMBadRequestError(f"ChatGLM: {str(ex)}")
|
||||
if isinstance(ex, openai.error.InvalidRequestError):
|
||||
logging.warning("Invalid request to ChatGLM API.")
|
||||
return LLMBadRequestError(str(ex))
|
||||
elif isinstance(ex, openai.error.APIConnectionError):
|
||||
logging.warning("Failed to connect to ChatGLM API.")
|
||||
return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex))
|
||||
elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
|
||||
logging.warning("ChatGLM service unavailable.")
|
||||
return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex))
|
||||
elif isinstance(ex, openai.error.RateLimitError):
|
||||
return LLMRateLimitError(str(ex))
|
||||
elif isinstance(ex, openai.error.AuthenticationError):
|
||||
return LLMAuthorizationError(str(ex))
|
||||
elif isinstance(ex, openai.error.OpenAIError):
|
||||
return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
|
||||
else:
|
||||
return ex
|
||||
|
||||
@classmethod
|
||||
def support_streaming(cls):
|
||||
return True
|
@ -2,6 +2,7 @@ import json
|
||||
from json import JSONDecodeError
|
||||
from typing import Type
|
||||
|
||||
import requests
|
||||
from langchain.llms import ChatGLM
|
||||
|
||||
from core.helper import encrypter
|
||||
@ -25,21 +26,26 @@ class ChatGLMProvider(BaseModelProvider):
|
||||
if model_type == ModelType.TEXT_GENERATION:
|
||||
return [
|
||||
{
|
||||
'id': 'chatglm2-6b',
|
||||
'name': 'ChatGLM2-6B',
|
||||
'mode': ModelMode.COMPLETION.value,
|
||||
'id': 'chatglm3-6b',
|
||||
'name': 'ChatGLM3-6B',
|
||||
'mode': ModelMode.CHAT.value,
|
||||
},
|
||||
{
|
||||
'id': 'chatglm-6b',
|
||||
'name': 'ChatGLM-6B',
|
||||
'mode': ModelMode.COMPLETION.value,
|
||||
'id': 'chatglm3-6b-32k',
|
||||
'name': 'ChatGLM3-6B-32K',
|
||||
'mode': ModelMode.CHAT.value,
|
||||
},
|
||||
{
|
||||
'id': 'chatglm2-6b',
|
||||
'name': 'ChatGLM2-6B',
|
||||
'mode': ModelMode.CHAT.value,
|
||||
}
|
||||
]
|
||||
else:
|
||||
return []
|
||||
|
||||
def _get_text_generation_model_mode(self, model_name) -> str:
|
||||
return ModelMode.COMPLETION.value
|
||||
return ModelMode.CHAT.value
|
||||
|
||||
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
|
||||
"""
|
||||
@ -64,16 +70,19 @@ class ChatGLMProvider(BaseModelProvider):
|
||||
:return:
|
||||
"""
|
||||
model_max_tokens = {
|
||||
'chatglm-6b': 2000,
|
||||
'chatglm2-6b': 32000,
|
||||
'chatglm3-6b-32k': 32000,
|
||||
'chatglm3-6b': 8000,
|
||||
'chatglm2-6b': 8000,
|
||||
}
|
||||
|
||||
max_tokens_alias = 'max_length' if model_name == 'chatglm2-6b' else 'max_tokens'
|
||||
|
||||
return ModelKwargsRules(
|
||||
temperature=KwargRule[float](min=0, max=2, default=1, precision=2),
|
||||
top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
|
||||
presence_penalty=KwargRule[float](enabled=False),
|
||||
frequency_penalty=KwargRule[float](enabled=False),
|
||||
max_tokens=KwargRule[int](alias='max_token', min=10, max=model_max_tokens.get(model_name), default=2048, precision=0),
|
||||
max_tokens=KwargRule[int](alias=max_tokens_alias, min=10, max=model_max_tokens.get(model_name), default=2048, precision=0),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -85,16 +94,10 @@ class ChatGLMProvider(BaseModelProvider):
|
||||
raise CredentialsValidateFailedError('ChatGLM Endpoint URL must be provided.')
|
||||
|
||||
try:
|
||||
credential_kwargs = {
|
||||
'endpoint_url': credentials['api_base']
|
||||
}
|
||||
response = requests.get(f"{credentials['api_base']}/v1/models", timeout=5)
|
||||
|
||||
llm = ChatGLM(
|
||||
max_token=10,
|
||||
**credential_kwargs
|
||||
)
|
||||
|
||||
llm("ping")
|
||||
if response.status_code != 200:
|
||||
raise Exception('ChatGLM Endpoint URL is invalid.')
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user