feat: chatglm3 support (#1616)

takatost 2023-11-25 15:37:07 +08:00 committed by GitHub
parent 0e627c920f
commit ea526d0822
2 changed files with 75 additions and 33 deletions

api/core/model_providers/models/llm/chatglm_model.py

@@ -1,27 +1,45 @@
-import decimal
+import logging
 from typing import List, Optional, Any
 
+import openai
 from langchain.callbacks.manager import Callbacks
-from langchain.llms import ChatGLM
-from langchain.schema import LLMResult
+from langchain.schema import LLMResult, get_buffer_string
 
-from core.model_providers.error import LLMBadRequestError
+from core.model_providers.error import LLMBadRequestError, LLMRateLimitError, LLMAuthorizationError, \
+    LLMAPIUnavailableError, LLMAPIConnectionError
 from core.model_providers.models.llm.base import BaseLLM
 from core.model_providers.models.entity.message import PromptMessage, MessageType
 from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
+from core.third_party.langchain.llms.chat_open_ai import EnhanceChatOpenAI
 
 
 class ChatGLMModel(BaseLLM):
-    model_mode: ModelMode = ModelMode.COMPLETION
+    model_mode: ModelMode = ModelMode.CHAT
 
     def _init_client(self) -> Any:
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
-        return ChatGLM(
+
+        extra_model_kwargs = {
+            'top_p': provider_model_kwargs.get('top_p')
+        }
+
+        if provider_model_kwargs.get('max_length') is not None:
+            extra_model_kwargs['max_length'] = provider_model_kwargs.get('max_length')
+
+        client = EnhanceChatOpenAI(
+            model_name=self.name,
+            temperature=provider_model_kwargs.get('temperature'),
+            max_tokens=provider_model_kwargs.get('max_tokens'),
+            model_kwargs=extra_model_kwargs,
+            streaming=self.streaming,
             callbacks=self.callbacks,
-            endpoint_url=self.credentials.get('api_base'),
-            **provider_model_kwargs
+            request_timeout=60,
+            openai_api_key="1",
+            openai_api_base=self.credentials['api_base'] + '/v1'
         )
 
+        return client
+
     def _run(self, messages: List[PromptMessage],
              stop: Optional[List[str]] = None,
             callbacks: Callbacks = None,
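
Note: the provider now drives ChatGLM3 through Dify's EnhanceChatOpenAI wrapper, since ChatGLM3's reference server exposes an OpenAI-compatible API; /v1 is appended to the configured api_base, and the dummy openai_api_key="1" only satisfies the SDK's non-empty-key requirement (the server is not expected to validate it). A minimal sketch of the equivalent raw call using the pre-1.0 openai package this code targets; the localhost URL is an assumption for a locally served model:

    import openai

    # Assumed local ChatGLM3 server exposing an OpenAI-compatible API under /v1;
    # the key is a placeholder because the server is not expected to check it.
    openai.api_key = "1"
    openai.api_base = "http://localhost:8000/v1"

    response = openai.ChatCompletion.create(
        model="chatglm3-6b",
        messages=[{"role": "user", "content": "ping"}],
    )
    print(response["choices"][0]["message"]["content"])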
@@ -45,19 +63,40 @@ class ChatGLMModel(BaseLLM):
         :return:
         """
         prompts = self._get_prompt_from_messages(messages)
-        return max(self._client.get_num_tokens(prompts), 0)
+        return max(sum([self._client.get_num_tokens(get_buffer_string([m])) for m in prompts]) - len(prompts), 0)
 
     def get_currency(self):
         return 'RMB'
 
     def _set_model_kwargs(self, model_kwargs: ModelKwargs):
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
-        for k, v in provider_model_kwargs.items():
-            if hasattr(self.client, k):
-                setattr(self.client, k, v)
+        extra_model_kwargs = {
+            'top_p': provider_model_kwargs.get('top_p')
+        }
+
+        self.client.temperature = provider_model_kwargs.get('temperature')
+        self.client.max_tokens = provider_model_kwargs.get('max_tokens')
+        self.client.model_kwargs = extra_model_kwargs
 
     def handle_exceptions(self, ex: Exception) -> Exception:
-        if isinstance(ex, ValueError):
-            return LLMBadRequestError(f"ChatGLM: {str(ex)}")
+        if isinstance(ex, openai.error.InvalidRequestError):
+            logging.warning("Invalid request to ChatGLM API.")
+            return LLMBadRequestError(str(ex))
+        elif isinstance(ex, openai.error.APIConnectionError):
+            logging.warning("Failed to connect to ChatGLM API.")
+            return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex))
+        elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
+            logging.warning("ChatGLM service unavailable.")
+            return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex))
+        elif isinstance(ex, openai.error.RateLimitError):
+            return LLMRateLimitError(str(ex))
+        elif isinstance(ex, openai.error.AuthenticationError):
+            return LLMAuthorizationError(str(ex))
+        elif isinstance(ex, openai.error.OpenAIError):
+            return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
         else:
             return ex
+
+    @classmethod
+    def support_streaming(cls):
+        return True
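
Note on the token count above: each PromptMessage is rendered on its own with langchain's get_buffer_string (producing e.g. "Human: hello"), tokenized, and summed, then one token per message is subtracted before flooring at zero, apparently to offset per-message rendering overhead. A standalone sketch of that arithmetic, using tiktoken's cl100k_base as a stand-in tokenizer (an assumption; the wrapper's own counter may differ for non-OpenAI models):

    import tiktoken
    from langchain.schema import AIMessage, HumanMessage, get_buffer_string

    # Stand-in for self._client.get_num_tokens (assumed tiktoken encoding).
    enc = tiktoken.get_encoding("cl100k_base")

    prompts = [HumanMessage(content="Hello"), AIMessage(content="Hi there")]

    # Render each message as "Role: content", count tokens, subtract one per
    # message, and never go below zero.
    num_tokens = max(
        sum(len(enc.encode(get_buffer_string([m]))) for m in prompts) - len(prompts),
        0,
    )
    print(num_tokens)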

api/core/model_providers/providers/chatglm_provider.py

@@ -2,6 +2,7 @@ import json
 from json import JSONDecodeError
 from typing import Type
 
+import requests
 from langchain.llms import ChatGLM
 
 from core.helper import encrypter
@@ -25,21 +26,26 @@ class ChatGLMProvider(BaseModelProvider):
         if model_type == ModelType.TEXT_GENERATION:
             return [
                 {
-                    'id': 'chatglm2-6b',
-                    'name': 'ChatGLM2-6B',
-                    'mode': ModelMode.COMPLETION.value,
+                    'id': 'chatglm3-6b',
+                    'name': 'ChatGLM3-6B',
+                    'mode': ModelMode.CHAT.value,
                 },
                 {
-                    'id': 'chatglm-6b',
-                    'name': 'ChatGLM-6B',
-                    'mode': ModelMode.COMPLETION.value,
+                    'id': 'chatglm3-6b-32k',
+                    'name': 'ChatGLM3-6B-32K',
+                    'mode': ModelMode.CHAT.value,
+                },
+                {
+                    'id': 'chatglm2-6b',
+                    'name': 'ChatGLM2-6B',
+                    'mode': ModelMode.CHAT.value,
                 }
             ]
         else:
             return []
 
     def _get_text_generation_model_mode(self, model_name) -> str:
-        return ModelMode.COMPLETION.value
+        return ModelMode.CHAT.value
 
     def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
         """
@@ -64,16 +70,19 @@ class ChatGLMProvider(BaseModelProvider):
         :return:
         """
         model_max_tokens = {
-            'chatglm-6b': 2000,
-            'chatglm2-6b': 32000,
+            'chatglm3-6b-32k': 32000,
+            'chatglm3-6b': 8000,
+            'chatglm2-6b': 8000,
         }
 
+        max_tokens_alias = 'max_length' if model_name == 'chatglm2-6b' else 'max_tokens'
+
         return ModelKwargsRules(
             temperature=KwargRule[float](min=0, max=2, default=1, precision=2),
             top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
             presence_penalty=KwargRule[float](enabled=False),
             frequency_penalty=KwargRule[float](enabled=False),
-            max_tokens=KwargRule[int](alias='max_token', min=10, max=model_max_tokens.get(model_name), default=2048, precision=0),
+            max_tokens=KwargRule[int](alias=max_tokens_alias, min=10, max=model_max_tokens.get(model_name), default=2048, precision=0),
         )
 
     @classmethod
@@ -85,16 +94,10 @@ class ChatGLMProvider(BaseModelProvider):
             raise CredentialsValidateFailedError('ChatGLM Endpoint URL must be provided.')
 
         try:
-            credential_kwargs = {
-                'endpoint_url': credentials['api_base']
-            }
-
-            llm = ChatGLM(
-                max_token=10,
-                **credential_kwargs
-            )
-
-            llm("ping")
+            response = requests.get(f"{credentials['api_base']}/v1/models", timeout=5)
+
+            if response.status_code != 200:
+                raise Exception('ChatGLM Endpoint URL is invalid.')
         except Exception as ex:
             raise CredentialsValidateFailedError(str(ex))
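
Note: credential validation no longer instantiates langchain's ChatGLM wrapper and sends a "ping" completion; it simply probes the OpenAI-compatible model listing route. A standalone equivalent of the check, where the localhost api_base is an assumption:

    import requests

    # Assumed address of a locally served ChatGLM endpoint.
    api_base = "http://localhost:8000"

    # The validation treats anything but HTTP 200 from /v1/models as an
    # invalid endpoint and surfaces it as a credentials error.
    response = requests.get(f"{api_base}/v1/models", timeout=5)
    if response.status_code != 200:
        raise Exception('ChatGLM Endpoint URL is invalid.')
    print(response.json())  # the models served by the endpoint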