mirror of
https://gitee.com/dify_ai/dify.git
synced 2024-11-30 02:08:37 +08:00
feat: optimize xinference request max token key and stop reason (#998)
This commit is contained in:
parent
276d3d10a0
commit
9ae91a2ec3
@ -2,7 +2,6 @@ import json
|
||||
from typing import Type
|
||||
|
||||
import requests
|
||||
from xinference.client import RESTfulGenerateModelHandle, RESTfulChatModelHandle, RESTfulChatglmCppChatModelHandle
|
||||
|
||||
from core.helper import encrypter
|
||||
from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding
|
||||
@ -73,7 +72,7 @@ class XinferenceProvider(BaseModelProvider):
|
||||
top_p=KwargRule[float](min=0, max=1, default=0.7),
|
||||
presence_penalty=KwargRule[float](enabled=False),
|
||||
frequency_penalty=KwargRule[float](enabled=False),
|
||||
max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=256),
|
||||
max_tokens=KwargRule[int](min=10, max=4000, default=256),
|
||||
)
|
||||
|
||||
|
||||
|
@ -9,11 +9,11 @@ from xinference.client import RESTfulChatglmCppChatModelHandle, \
|
||||
|
||||
class XinferenceLLM(Xinference):
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Call the xinference model and return the output.
|
||||
|
||||
@ -56,10 +56,10 @@ class XinferenceLLM(Xinference):
|
||||
if generate_config and generate_config.get("stream"):
|
||||
combined_text_output = ""
|
||||
for token in self._stream_generate(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
run_manager=run_manager,
|
||||
generate_config=generate_config,
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
run_manager=run_manager,
|
||||
generate_config=generate_config,
|
||||
):
|
||||
combined_text_output += token
|
||||
return combined_text_output
|
||||
@ -73,10 +73,10 @@ class XinferenceLLM(Xinference):
|
||||
if generate_config and generate_config.get("stream"):
|
||||
combined_text_output = ""
|
||||
for token in self._stream_generate(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
run_manager=run_manager,
|
||||
generate_config=generate_config,
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
run_manager=run_manager,
|
||||
generate_config=generate_config,
|
||||
):
|
||||
combined_text_output += token
|
||||
completion = combined_text_output
|
||||
@ -89,13 +89,13 @@ class XinferenceLLM(Xinference):
|
||||
|
||||
return completion
|
||||
|
||||
|
||||
def _stream_generate(
|
||||
self,
|
||||
model: Union["RESTfulGenerateModelHandle", "RESTfulChatModelHandle", "RESTfulChatglmCppChatModelHandle"],
|
||||
prompt: str,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
generate_config: Optional[Union["LlamaCppGenerateConfig", "PytorchGenerateConfig", "ChatglmCppGenerateConfig"]] = None,
|
||||
self,
|
||||
model: Union["RESTfulGenerateModelHandle", "RESTfulChatModelHandle", "RESTfulChatglmCppChatModelHandle"],
|
||||
prompt: str,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
generate_config: Optional[
|
||||
Union["LlamaCppGenerateConfig", "PytorchGenerateConfig", "ChatglmCppGenerateConfig"]] = None,
|
||||
) -> Generator[str, None, None]:
|
||||
"""
|
||||
Args:
|
||||
@ -123,6 +123,10 @@ class XinferenceLLM(Xinference):
|
||||
if choices:
|
||||
choice = choices[0]
|
||||
if isinstance(choice, dict):
|
||||
if 'finish_reason' in choice and choice['finish_reason'] \
|
||||
and choice['finish_reason'] in ['stop', 'length']:
|
||||
break
|
||||
|
||||
if 'text' in choice:
|
||||
token = choice.get("text", "")
|
||||
elif 'delta' in choice and 'content' in choice['delta']:
|
||||
|
Loading…
Reference in New Issue
Block a user