feat: optimize xinference request max token key and stop reason (#998)

takatost 2023-08-24 18:11:15 +08:00 committed by GitHub
parent 276d3d10a0
commit 9ae91a2ec3
2 changed files with 24 additions and 21 deletions


@@ -2,7 +2,6 @@ import json
 from typing import Type
 import requests
-from xinference.client import RESTfulGenerateModelHandle, RESTfulChatModelHandle, RESTfulChatglmCppChatModelHandle
 from core.helper import encrypter
 from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding
@@ -73,7 +72,7 @@ class XinferenceProvider(BaseModelProvider):
             top_p=KwargRule[float](min=0, max=1, default=0.7),
             presence_penalty=KwargRule[float](enabled=False),
             frequency_penalty=KwargRule[float](enabled=False),
-            max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=256),
+            max_tokens=KwargRule[int](min=10, max=4000, default=256),
         )
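
For context, a minimal sketch (a hypothetical helper, not code from this commit) of what the alias on a KwargRule means for the outgoing request: when an alias is set it replaces the rule name as the key sent to the model server, so dropping it here lets the request carry max_tokens directly.

    from typing import Optional

    # Hypothetical illustration only: how an alias on a kwarg rule renames the request key.
    def request_key(rule_name: str, alias: Optional[str] = None) -> str:
        # An alias, when present, replaces the rule name as the key in the outgoing payload.
        return alias or rule_name

    request_key('max_tokens', alias='max_new_tokens')  # -> 'max_new_tokens' (with the old alias)
    request_key('max_tokens')                          # -> 'max_tokens' (after this change)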


@@ -9,11 +9,11 @@ from xinference.client import RESTfulChatglmCppChatModelHandle, \
 class XinferenceLLM(Xinference):
     def _call(
             self,
             prompt: str,
             stop: Optional[List[str]] = None,
             run_manager: Optional[CallbackManagerForLLMRun] = None,
             **kwargs: Any,
     ) -> str:
         """Call the xinference model and return the output.
@@ -56,10 +56,10 @@ class XinferenceLLM(Xinference):
         if generate_config and generate_config.get("stream"):
             combined_text_output = ""
             for token in self._stream_generate(
                 model=model,
                 prompt=prompt,
                 run_manager=run_manager,
                 generate_config=generate_config,
             ):
                 combined_text_output += token
             return combined_text_output
@@ -73,10 +73,10 @@ class XinferenceLLM(Xinference):
         if generate_config and generate_config.get("stream"):
             combined_text_output = ""
             for token in self._stream_generate(
                 model=model,
                 prompt=prompt,
                 run_manager=run_manager,
                 generate_config=generate_config,
             ):
                 combined_text_output += token
             completion = combined_text_output
@@ -89,13 +89,13 @@ class XinferenceLLM(Xinference):
         return completion

     def _stream_generate(
             self,
             model: Union["RESTfulGenerateModelHandle", "RESTfulChatModelHandle", "RESTfulChatglmCppChatModelHandle"],
             prompt: str,
             run_manager: Optional[CallbackManagerForLLMRun] = None,
-            generate_config: Optional[Union["LlamaCppGenerateConfig", "PytorchGenerateConfig", "ChatglmCppGenerateConfig"]] = None,
+            generate_config: Optional[
+                Union["LlamaCppGenerateConfig", "PytorchGenerateConfig", "ChatglmCppGenerateConfig"]] = None,
     ) -> Generator[str, None, None]:
         """
         Args:
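
As a rough, self-contained sketch of the streaming path shown in _call above (stream_tokens below is a stand-in for self._stream_generate and is not part of the commit): when the generate config requests streaming, the yielded tokens are concatenated into the final completion.

    # Sketch only: how _call consumes a token generator when stream=True.
    def stream_tokens():
        # Stand-in for self._stream_generate(...); the real one yields tokens from the server.
        yield from ("Hello", ", ", "world")

    generate_config = {"stream": True, "max_tokens": 256}

    combined_text_output = ""
    if generate_config.get("stream"):
        for token in stream_tokens():
            combined_text_output += token

    print(combined_text_output)  # Hello, world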
@@ -123,6 +123,10 @@ class XinferenceLLM(Xinference):
                 if choices:
                     choice = choices[0]
                     if isinstance(choice, dict):
+                        if 'finish_reason' in choice and choice['finish_reason'] \
+                                and choice['finish_reason'] in ['stop', 'length']:
+                            break
                         if 'text' in choice:
                             token = choice.get("text", "")
                         elif 'delta' in choice and 'content' in choice['delta']:
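
The added early break reacts to the finish_reason field that Xinference reports in streamed chunks; a simplified, self-contained sketch of the same terminal-state check (the sample chunk literals are illustrative, not taken from the commit):

    # Illustrative only: the same stop/length test applied to one streamed choice dict.
    def is_finished(choice: dict) -> bool:
        # A truthy finish_reason of 'stop' or 'length' means the stream is done.
        return bool(choice.get('finish_reason')) and choice['finish_reason'] in ('stop', 'length')

    is_finished({'text': 'Hi', 'finish_reason': None})   # False: keep streaming
    is_finished({'text': '', 'finish_reason': 'stop'})   # True: break out of the loop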