fix: remove the openllm pypi package because it is too large (#931)
commit 6c832ee328 (parent 25264e7852)
@@ -1,13 +1,13 @@
 from typing import List, Optional, Any

 from langchain.callbacks.manager import Callbacks
-from langchain.llms import OpenLLM
 from langchain.schema import LLMResult

 from core.model_providers.error import LLMBadRequestError
 from core.model_providers.models.llm.base import BaseLLM
 from core.model_providers.models.entity.message import PromptMessage
 from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
+from core.third_party.langchain.llms.openllm import OpenLLM


 class OpenLLMModel(BaseLLM):
@@ -19,7 +19,7 @@ class OpenLLMModel(BaseLLM):
         client = OpenLLM(
             server_url=self.credentials.get('server_url'),
             callbacks=self.callbacks,
-            **self.provider_model_kwargs
+            llm_kwargs=self.provider_model_kwargs
         )

         return client
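Note on the second hunk: the provider's model kwargs used to be splatted straight into the LangChain constructor; the vendored client instead collects them in a single llm_kwargs dict and ships them to the server as the request's llm_config field. A minimal sketch of the new construction path (the kwargs values here are illustrative):

    from core.third_party.langchain.llms.openllm import OpenLLM

    # Illustrative values; in dify they come from self.credentials and
    # self.provider_model_kwargs.
    provider_model_kwargs = {'temperature': 0.7, 'max_new_tokens': 128}

    client = OpenLLM(
        server_url='http://localhost:3000',
        # Formerly **provider_model_kwargs; now one dict that the client
        # forwards verbatim as 'llm_config' in each generate request.
        llm_kwargs=provider_model_kwargs,
    )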
@@ -1,14 +1,13 @@
 import json
 from typing import Type

-from langchain.llms import OpenLLM
-
 from core.helper import encrypter
 from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
 from core.model_providers.models.llm.openllm_model import OpenLLMModel
 from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError

 from core.model_providers.models.base import BaseProviderModel
+from core.third_party.langchain.llms.openllm import OpenLLM
 from models.provider import ProviderType

@@ -46,11 +45,11 @@ class OpenLLMProvider(BaseModelProvider):
         :return:
         """
         return ModelKwargsRules(
-            temperature=KwargRule[float](min=0, max=2, default=1),
+            temperature=KwargRule[float](min=0.01, max=2, default=1),
             top_p=KwargRule[float](min=0, max=1, default=0.7),
             presence_penalty=KwargRule[float](min=-2, max=2, default=0),
             frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
-            max_tokens=KwargRule[int](min=10, max=4000, default=128),
+            max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=128),
         )

     @classmethod
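Two things change in this hunk. The temperature floor moves from 0 to 0.01, presumably because the backend rejects a temperature of exactly 0. And the alias added to max_tokens lets the rest of dify keep using the generic max_tokens name while OpenLLM receives its expected max_new_tokens key. A rough sketch of the kind of renaming the alias enables (assumed behavior, not the actual ModelKwargsRules implementation):

    # Hypothetical helper: rename generic kwargs to provider-specific ones.
    ALIASES = {'max_tokens': 'max_new_tokens'}

    def to_provider_kwargs(model_kwargs: dict) -> dict:
        return {ALIASES.get(key, key): value for key, value in model_kwargs.items()}

    print(to_provider_kwargs({'max_tokens': 128, 'temperature': 1.0}))
    # -> {'max_new_tokens': 128, 'temperature': 1.0}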
@@ -71,7 +70,9 @@ class OpenLLMProvider(BaseModelProvider):
         }

         llm = OpenLLM(
-            max_tokens=10,
+            llm_kwargs={
+                'max_new_tokens': 10
+            },
             **credential_kwargs
         )

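Credential validation follows the same pattern: it builds the vendored client with a deliberately tiny generation budget and issues a probe call. Roughly (server address and prompt illustrative):

    from core.third_party.langchain.llms.openllm import OpenLLM

    credential_kwargs = {'server_url': 'http://localhost:3000'}  # illustrative

    llm = OpenLLM(
        llm_kwargs={'max_new_tokens': 10},  # keep the probe response short
        **credential_kwargs,
    )
    llm("ping")  # any transport or server error surfaces here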
api/core/third_party/langchain/llms/openllm.py (vendored, new file, 87 lines)
@@ -0,0 +1,87 @@
+from __future__ import annotations
+
+import logging
+from typing import (
+    Any,
+    Dict,
+    List,
+    Optional,
+)
+
+import requests
+from langchain.llms.utils import enforce_stop_tokens
+from pydantic import Field
+
+from langchain.callbacks.manager import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
+from langchain.llms.base import LLM
+
+logger = logging.getLogger(__name__)
+
+
+class OpenLLM(LLM):
+    """OpenLLM, supporting both in-process model
+    instance and remote OpenLLM servers.
+
+    If you have an OpenLLM server running, you can also use it remotely:
+    .. code-block:: python
+
+        from langchain.llms import OpenLLM
+        llm = OpenLLM(server_url='http://localhost:3000')
+        llm("What is the difference between a duck and a goose?")
+    """
+
+    server_url: Optional[str] = None
+    """Optional server URL that currently runs a LLMServer with 'openllm start'."""
+    llm_kwargs: Dict[str, Any] = Field(default_factory=dict)
+    """Keyword arguments to be passed to openllm.LLM"""
+
+    @property
+    def _llm_type(self) -> str:
+        return "openllm"
+
+    def _call(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: CallbackManagerForLLMRun | None = None,
+        **kwargs: Any,
+    ) -> str:
+        params = {
+            "prompt": prompt,
+            "llm_config": self.llm_kwargs
+        }
+
+        headers = {"Content-Type": "application/json"}
+        response = requests.post(
+            f'{self.server_url}/v1/generate',
+            headers=headers,
+            json=params
+        )
+
+        if not response.ok:
+            raise ValueError(f"OpenLLM HTTP {response.status_code} error: {response.text}")
+
+        json_response = response.json()
+        completion = json_response["responses"][0]
+
+        if completion:
+            completion = completion[len(prompt):]
+
+        if stop is not None:
+            completion = enforce_stop_tokens(completion, stop)
+
+        return completion
+
+    async def _acall(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        raise NotImplementedError(
+            "Async call is not supported for OpenLLM at the moment."
+        )
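The vendored class is a deliberately thin HTTP wrapper: no openllm import at all, just requests against the server's /v1/generate endpoint, which is what makes dropping the heavyweight pypi package possible. The _call above boils down to this request (payload shape taken from the code; server address and prompt are illustrative):

    import requests

    response = requests.post(
        'http://localhost:3000/v1/generate',
        headers={'Content-Type': 'application/json'},
        json={'prompt': 'What is a duck?', 'llm_config': {'max_new_tokens': 10}},
    )
    response.raise_for_status()
    # The server includes the prompt at the start of each response;
    # _call strips it before applying stop tokens.
    print(response.json()['responses'][0])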
@@ -49,5 +49,4 @@ huggingface_hub~=0.16.4
 transformers~=4.31.0
 stripe~=5.5.0
 pandas==1.5.3
 xinference==0.2.0
-openllm~=0.2.26
@@ -23,8 +23,7 @@ def decrypt_side_effect(tenant_id, encrypted_key):


 def test_is_credentials_valid_or_raise_valid(mocker):
-    mocker.patch('langchain.llms.openllm.OpenLLM._identifying_params', return_value=None)
-    mocker.patch('langchain.llms.openllm.OpenLLM._call',
+    mocker.patch('core.third_party.langchain.llms.openllm.OpenLLM._call',
                  return_value="abc")

     MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
@@ -35,8 +34,6 @@ def test_is_credentials_valid_or_raise_valid(mocker):


 def test_is_credentials_valid_or_raise_invalid(mocker):
-    mocker.patch('langchain.llms.openllm.OpenLLM._identifying_params', return_value=None)
-
     # raise CredentialsValidateFailedError if credential is not in credentials
     with pytest.raises(CredentialsValidateFailedError):
         MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
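With the wrapper vendored, the tests patch the new import path instead of langchain's, and the _identifying_params patch can be dropped, since the vendored class never imports the openllm package there is nothing to stub out. A minimal pytest-mock sketch of the new style (test name and prompt are illustrative):

    from core.third_party.langchain.llms.openllm import OpenLLM

    def test_call_is_mockable(mocker):
        mocker.patch(
            'core.third_party.langchain.llms.openllm.OpenLLM._call',
            return_value="abc",
        )
        llm = OpenLLM(server_url='http://localhost:3000')
        assert llm("ping") == "abc"  # no real HTTP request is made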