feat: chatglm3 support (#1616)
commit ea526d0822 (parent 0e627c920f)
@@ -1,27 +1,45 @@
-import decimal
+import logging
 from typing import List, Optional, Any
 
+import openai
 from langchain.callbacks.manager import Callbacks
-from langchain.llms import ChatGLM
-from langchain.schema import LLMResult
+from langchain.schema import LLMResult, get_buffer_string
 
-from core.model_providers.error import LLMBadRequestError
+from core.model_providers.error import LLMBadRequestError, LLMRateLimitError, LLMAuthorizationError, \
+    LLMAPIUnavailableError, LLMAPIConnectionError
 from core.model_providers.models.llm.base import BaseLLM
 from core.model_providers.models.entity.message import PromptMessage, MessageType
 from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
+from core.third_party.langchain.llms.chat_open_ai import EnhanceChatOpenAI
 
 
 class ChatGLMModel(BaseLLM):
-    model_mode: ModelMode = ModelMode.COMPLETION
+    model_mode: ModelMode = ModelMode.CHAT
 
     def _init_client(self) -> Any:
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
-        return ChatGLM(
+        extra_model_kwargs = {
+            'top_p': provider_model_kwargs.get('top_p')
+        }
+
+        if provider_model_kwargs.get('max_length') is not None:
+            extra_model_kwargs['max_length'] = provider_model_kwargs.get('max_length')
+
+        client = EnhanceChatOpenAI(
+            model_name=self.name,
+            temperature=provider_model_kwargs.get('temperature'),
+            max_tokens=provider_model_kwargs.get('max_tokens'),
+            model_kwargs=extra_model_kwargs,
+            streaming=self.streaming,
             callbacks=self.callbacks,
-            endpoint_url=self.credentials.get('api_base'),
-            **provider_model_kwargs
+            request_timeout=60,
+            openai_api_key="1",
+            openai_api_base=self.credentials['api_base'] + '/v1'
         )
 
+        return client
+
     def _run(self, messages: List[PromptMessage],
              stop: Optional[List[str]] = None,
             callbacks: Callbacks = None,
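For readability, here is the _init_client body as it reads after this hunk, reassembled from the added lines above (indentation reconstructed; it remains a method of ChatGLMModel and relies on the imports added in the same hunk, so it is not a standalone script):

    def _init_client(self) -> Any:
        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)

        # top_p and the optional max_length are forwarded through model_kwargs
        # rather than as constructor arguments.
        extra_model_kwargs = {
            'top_p': provider_model_kwargs.get('top_p')
        }

        if provider_model_kwargs.get('max_length') is not None:
            extra_model_kwargs['max_length'] = provider_model_kwargs.get('max_length')

        # ChatGLM3 is reached through an OpenAI-compatible endpoint, so the client is
        # EnhanceChatOpenAI with a placeholder API key and the configured api_base
        # suffixed with /v1.
        client = EnhanceChatOpenAI(
            model_name=self.name,
            temperature=provider_model_kwargs.get('temperature'),
            max_tokens=provider_model_kwargs.get('max_tokens'),
            model_kwargs=extra_model_kwargs,
            streaming=self.streaming,
            callbacks=self.callbacks,
            request_timeout=60,
            openai_api_key="1",
            openai_api_base=self.credentials['api_base'] + '/v1'
        )

        return client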
@@ -45,19 +63,40 @@ class ChatGLMModel(BaseLLM):
         :return:
         """
         prompts = self._get_prompt_from_messages(messages)
-        return max(self._client.get_num_tokens(prompts), 0)
+        return max(sum([self._client.get_num_tokens(get_buffer_string([m])) for m in prompts]) - len(prompts), 0)
 
     def get_currency(self):
         return 'RMB'
 
     def _set_model_kwargs(self, model_kwargs: ModelKwargs):
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
-        for k, v in provider_model_kwargs.items():
-            if hasattr(self.client, k):
-                setattr(self.client, k, v)
+        extra_model_kwargs = {
+            'top_p': provider_model_kwargs.get('top_p')
+        }
+
+        self.client.temperature = provider_model_kwargs.get('temperature')
+        self.client.max_tokens = provider_model_kwargs.get('max_tokens')
+        self.client.model_kwargs = extra_model_kwargs
 
     def handle_exceptions(self, ex: Exception) -> Exception:
-        if isinstance(ex, ValueError):
-            return LLMBadRequestError(f"ChatGLM: {str(ex)}")
+        if isinstance(ex, openai.error.InvalidRequestError):
+            logging.warning("Invalid request to ChatGLM API.")
+            return LLMBadRequestError(str(ex))
+        elif isinstance(ex, openai.error.APIConnectionError):
+            logging.warning("Failed to connect to ChatGLM API.")
+            return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex))
+        elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)):
+            logging.warning("ChatGLM service unavailable.")
+            return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex))
+        elif isinstance(ex, openai.error.RateLimitError):
+            return LLMRateLimitError(str(ex))
+        elif isinstance(ex, openai.error.AuthenticationError):
+            return LLMAuthorizationError(str(ex))
+        elif isinstance(ex, openai.error.OpenAIError):
+            return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex))
         else:
             return ex
+
+    @classmethod
+    def support_streaming(cls):
+        return True
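The new token count walks the prompt messages one at a time: each message is rendered to its LangChain buffer-string form ("Human: …" / "AI: …") before counting, and one token per message is subtracted (the diff does not say why; it reads as a correction for the prefix/separator the rendering adds). A minimal illustration of the same logic, assuming any client object that exposes a get_num_tokens method:

from langchain.schema import AIMessage, HumanMessage, get_buffer_string

def count_prompt_tokens(client, prompts):
    # Mirrors the patched count: each message is counted via its buffer-string
    # rendering, one token per message is subtracted, and the result is clamped at 0.
    return max(
        sum(client.get_num_tokens(get_buffer_string([m])) for m in prompts) - len(prompts),
        0,
    )

# Hypothetical usage:
# prompts = [HumanMessage(content="Hello"), AIMessage(content="Hi!")]
# count_prompt_tokens(chat_model, prompts)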
@@ -2,6 +2,7 @@ import json
 from json import JSONDecodeError
 from typing import Type
 
+import requests
 from langchain.llms import ChatGLM
 
 from core.helper import encrypter
@@ -25,21 +26,26 @@ class ChatGLMProvider(BaseModelProvider):
         if model_type == ModelType.TEXT_GENERATION:
             return [
                 {
-                    'id': 'chatglm2-6b',
-                    'name': 'ChatGLM2-6B',
-                    'mode': ModelMode.COMPLETION.value,
+                    'id': 'chatglm3-6b',
+                    'name': 'ChatGLM3-6B',
+                    'mode': ModelMode.CHAT.value,
                 },
                 {
-                    'id': 'chatglm-6b',
-                    'name': 'ChatGLM-6B',
-                    'mode': ModelMode.COMPLETION.value,
+                    'id': 'chatglm3-6b-32k',
+                    'name': 'ChatGLM3-6B-32K',
+                    'mode': ModelMode.CHAT.value,
+                },
+                {
+                    'id': 'chatglm2-6b',
+                    'name': 'ChatGLM2-6B',
+                    'mode': ModelMode.CHAT.value,
                 }
             ]
         else:
             return []
 
     def _get_text_generation_model_mode(self, model_name) -> str:
-        return ModelMode.COMPLETION.value
+        return ModelMode.CHAT.value
 
     def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
         """
@@ -64,16 +70,19 @@ class ChatGLMProvider(BaseModelProvider):
         :return:
         """
         model_max_tokens = {
-            'chatglm-6b': 2000,
-            'chatglm2-6b': 32000,
+            'chatglm3-6b-32k': 32000,
+            'chatglm3-6b': 8000,
+            'chatglm2-6b': 8000,
         }
 
+        max_tokens_alias = 'max_length' if model_name == 'chatglm2-6b' else 'max_tokens'
+
         return ModelKwargsRules(
             temperature=KwargRule[float](min=0, max=2, default=1, precision=2),
             top_p=KwargRule[float](min=0, max=1, default=0.7, precision=2),
             presence_penalty=KwargRule[float](enabled=False),
             frequency_penalty=KwargRule[float](enabled=False),
-            max_tokens=KwargRule[int](alias='max_token', min=10, max=model_max_tokens.get(model_name), default=2048, precision=0),
+            max_tokens=KwargRule[int](alias=max_tokens_alias, min=10, max=model_max_tokens.get(model_name), default=2048, precision=0),
         )
 
     @classmethod
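The new max_tokens_alias keeps the legacy 'max_length' kwarg name for chatglm2-6b and switches the ChatGLM3 models to 'max_tokens'. A tiny restatement of that selection; the rationale in the comment is an inference (the ChatGLM3 models go through the OpenAI-compatible endpoint introduced above), not something the diff states:

def pick_max_tokens_alias(model_name: str) -> str:
    # chatglm2-6b keeps the older 'max_length' parameter name; the ChatGLM3 models,
    # reached via the OpenAI-compatible /v1 endpoint, take 'max_tokens'.
    return 'max_length' if model_name == 'chatglm2-6b' else 'max_tokens'

assert pick_max_tokens_alias('chatglm2-6b') == 'max_length'
assert pick_max_tokens_alias('chatglm3-6b-32k') == 'max_tokens'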
@@ -85,16 +94,10 @@ class ChatGLMProvider(BaseModelProvider):
             raise CredentialsValidateFailedError('ChatGLM Endpoint URL must be provided.')
 
         try:
-            credential_kwargs = {
-                'endpoint_url': credentials['api_base']
-            }
-
-            llm = ChatGLM(
-                max_token=10,
-                **credential_kwargs
-            )
-
-            llm("ping")
+            response = requests.get(f"{credentials['api_base']}/v1/models", timeout=5)
+
+            if response.status_code != 200:
+                raise Exception('ChatGLM Endpoint URL is invalid.')
         except Exception as ex:
             raise CredentialsValidateFailedError(str(ex))
 
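Credential validation no longer spins up a LangChain ChatGLM client and sends a test prompt; it simply probes the OpenAI-compatible model listing route. A standalone sketch of the new check; the function name and the import path for CredentialsValidateFailedError are assumptions, and only the try/except body comes from the diff:

import requests

from core.model_providers.providers.base import CredentialsValidateFailedError  # import path assumed

def check_chatglm_endpoint(credentials: dict) -> None:
    # Probe the OpenAI-compatible /v1/models route instead of issuing a completion.
    try:
        response = requests.get(f"{credentials['api_base']}/v1/models", timeout=5)

        if response.status_code != 200:
            raise Exception('ChatGLM Endpoint URL is invalid.')
    except Exception as ex:
        raise CredentialsValidateFailedError(str(ex))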