fix: remove openllm pypi package because it is too large (#931)

takatost 2023-08-21 02:12:28 +08:00 committed by GitHub
parent 25264e7852
commit 6c832ee328
5 changed files with 97 additions and 13 deletions

View File

@@ -1,13 +1,13 @@
from typing import List, Optional, Any
from langchain.callbacks.manager import Callbacks
-from langchain.llms import OpenLLM
from langchain.schema import LLMResult
from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
+from core.third_party.langchain.llms.openllm import OpenLLM
class OpenLLMModel(BaseLLM):
@@ -19,7 +19,7 @@ class OpenLLMModel(BaseLLM):
        client = OpenLLM(
            server_url=self.credentials.get('server_url'),
            callbacks=self.callbacks,
-            **self.provider_model_kwargs
+            llm_kwargs=self.provider_model_kwargs
        )

        return client

View File

@@ -1,14 +1,13 @@
import json
from typing import Type
-from langchain.llms import OpenLLM
from core.helper import encrypter
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
from core.model_providers.models.llm.openllm_model import OpenLLMModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.model_providers.models.base import BaseProviderModel
+from core.third_party.langchain.llms.openllm import OpenLLM
from models.provider import ProviderType
@@ -46,11 +45,11 @@ class OpenLLMProvider(BaseModelProvider):
        :return:
        """
        return ModelKwargsRules(
-            temperature=KwargRule[float](min=0, max=2, default=1),
+            temperature=KwargRule[float](min=0.01, max=2, default=1),
            top_p=KwargRule[float](min=0, max=1, default=0.7),
            presence_penalty=KwargRule[float](min=-2, max=2, default=0),
            frequency_penalty=KwargRule[float](min=-2, max=2, default=0),
-            max_tokens=KwargRule[int](min=10, max=4000, default=128),
+            max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=128),
        )

    @classmethod
@@ -71,7 +70,9 @@ class OpenLLMProvider(BaseModelProvider):
        }

        llm = OpenLLM(
-            max_tokens=10,
+            llm_kwargs={
+                'max_new_tokens': 10
+            },
            **credential_kwargs
        )
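
Note: with this change, the credential check above goes through the vendored HTTP client added in the next file instead of the openllm package. A rough sketch of the single request that validation boils down to, assuming a server_url credential of http://localhost:3000 and an illustrative prompt (both are assumptions, not values from this commit):

import requests

# Illustrative only: llm_kwargs={'max_new_tokens': 10} from the validation call
# above is forwarded verbatim as 'llm_config' in the request body.
response = requests.post(
    'http://localhost:3000/v1/generate',              # assumed server_url credential
    headers={'Content-Type': 'application/json'},
    json={'prompt': 'ping', 'llm_config': {'max_new_tokens': 10}},  # 'ping' is an illustrative prompt
)
completion = response.json()['responses'][0]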

View File

@@ -0,0 +1,87 @@
from __future__ import annotations

import logging
from typing import (
    Any,
    Dict,
    List,
    Optional,
)

import requests
from langchain.llms.utils import enforce_stop_tokens
from pydantic import Field

from langchain.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain.llms.base import LLM

logger = logging.getLogger(__name__)


class OpenLLM(LLM):
    """OpenLLM, supporting both in-process model
    instance and remote OpenLLM servers.

    If you have a OpenLLM server running, you can also use it remotely:
        .. code-block:: python

            from langchain.llms import OpenLLM
            llm = OpenLLM(server_url='http://localhost:3000')
            llm("What is the difference between a duck and a goose?")
    """

    server_url: Optional[str] = None
    """Optional server URL that currently runs a LLMServer with 'openllm start'."""
    llm_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Key word arguments to be passed to openllm.LLM"""

    @property
    def _llm_type(self) -> str:
        return "openllm"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: CallbackManagerForLLMRun | None = None,
        **kwargs: Any,
    ) -> str:
        params = {
            "prompt": prompt,
            "llm_config": self.llm_kwargs
        }

        headers = {"Content-Type": "application/json"}
        response = requests.post(
            f'{self.server_url}/v1/generate',
            headers=headers,
            json=params
        )

        if not response.ok:
            raise ValueError(f"OpenLLM HTTP {response.status_code} error: {response.text}")

        json_response = response.json()
        completion = json_response["responses"][0]

        if completion:
            completion = completion[len(prompt):]

        if stop is not None:
            completion = enforce_stop_tokens(completion, stop)

        return completion

    async def _acall(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        raise NotImplementedError(
            "Async call is not supported for OpenLLM at the moment."
        )
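
For reference, a minimal usage sketch of this wrapper from application code, assuming an OpenLLM server started with 'openllm start' is reachable locally (the URL and generation kwargs below are illustrative assumptions):

from core.third_party.langchain.llms.openllm import OpenLLM

# Assumes a running OpenLLM server; llm_kwargs is sent as 'llm_config' with every request.
llm = OpenLLM(
    server_url='http://localhost:3000',      # illustrative server address
    llm_kwargs={'max_new_tokens': 128},      # illustrative generation settings
)
print(llm("What is the difference between a duck and a goose?"))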

View File

@@ -49,5 +49,4 @@ huggingface_hub~=0.16.4
transformers~=4.31.0
stripe~=5.5.0
pandas==1.5.3
-xinference==0.2.0
-openllm~=0.2.26
+xinference==0.2.0

View File

@@ -23,8 +23,7 @@ def decrypt_side_effect(tenant_id, encrypted_key):
def test_is_credentials_valid_or_raise_valid(mocker):
-    mocker.patch('langchain.llms.openllm.OpenLLM._identifying_params', return_value=None)
-    mocker.patch('langchain.llms.openllm.OpenLLM._call',
+    mocker.patch('core.third_party.langchain.llms.openllm.OpenLLM._call',
                 return_value="abc")

    MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
@@ -35,8 +34,6 @@ def test_is_credentials_valid_or_raise_valid(mocker):
def test_is_credentials_valid_or_raise_invalid(mocker):
-    mocker.patch('langchain.llms.openllm.OpenLLM._identifying_params', return_value=None)

    # raise CredentialsValidateFailedError if credential is not in credentials
    with pytest.raises(CredentialsValidateFailedError):
        MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(