mirror of
https://gitee.com/dify_ai/dify.git
synced 2024-12-02 11:18:19 +08:00
feat: optimize xinference request max token key and stop reason (#998)
This commit is contained in:
parent
276d3d10a0
commit
9ae91a2ec3
@ -2,7 +2,6 @@ import json
|
|||||||
from typing import Type
|
from typing import Type
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from xinference.client import RESTfulGenerateModelHandle, RESTfulChatModelHandle, RESTfulChatglmCppChatModelHandle
|
|
||||||
|
|
||||||
from core.helper import encrypter
|
from core.helper import encrypter
|
||||||
from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding
|
from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding
|
||||||
@ -73,7 +72,7 @@ class XinferenceProvider(BaseModelProvider):
|
|||||||
top_p=KwargRule[float](min=0, max=1, default=0.7),
|
top_p=KwargRule[float](min=0, max=1, default=0.7),
|
||||||
presence_penalty=KwargRule[float](enabled=False),
|
presence_penalty=KwargRule[float](enabled=False),
|
||||||
frequency_penalty=KwargRule[float](enabled=False),
|
frequency_penalty=KwargRule[float](enabled=False),
|
||||||
max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=256),
|
max_tokens=KwargRule[int](min=10, max=4000, default=256),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -9,11 +9,11 @@ from xinference.client import RESTfulChatglmCppChatModelHandle, \
|
|||||||
|
|
||||||
class XinferenceLLM(Xinference):
|
class XinferenceLLM(Xinference):
|
||||||
def _call(
|
def _call(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
stop: Optional[List[str]] = None,
|
stop: Optional[List[str]] = None,
|
||||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Call the xinference model and return the output.
|
"""Call the xinference model and return the output.
|
||||||
|
|
||||||
@ -56,10 +56,10 @@ class XinferenceLLM(Xinference):
|
|||||||
if generate_config and generate_config.get("stream"):
|
if generate_config and generate_config.get("stream"):
|
||||||
combined_text_output = ""
|
combined_text_output = ""
|
||||||
for token in self._stream_generate(
|
for token in self._stream_generate(
|
||||||
model=model,
|
model=model,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
run_manager=run_manager,
|
run_manager=run_manager,
|
||||||
generate_config=generate_config,
|
generate_config=generate_config,
|
||||||
):
|
):
|
||||||
combined_text_output += token
|
combined_text_output += token
|
||||||
return combined_text_output
|
return combined_text_output
|
||||||
@ -73,10 +73,10 @@ class XinferenceLLM(Xinference):
|
|||||||
if generate_config and generate_config.get("stream"):
|
if generate_config and generate_config.get("stream"):
|
||||||
combined_text_output = ""
|
combined_text_output = ""
|
||||||
for token in self._stream_generate(
|
for token in self._stream_generate(
|
||||||
model=model,
|
model=model,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
run_manager=run_manager,
|
run_manager=run_manager,
|
||||||
generate_config=generate_config,
|
generate_config=generate_config,
|
||||||
):
|
):
|
||||||
combined_text_output += token
|
combined_text_output += token
|
||||||
completion = combined_text_output
|
completion = combined_text_output
|
||||||
@ -89,13 +89,13 @@ class XinferenceLLM(Xinference):
|
|||||||
|
|
||||||
return completion
|
return completion
|
||||||
|
|
||||||
|
|
||||||
def _stream_generate(
|
def _stream_generate(
|
||||||
self,
|
self,
|
||||||
model: Union["RESTfulGenerateModelHandle", "RESTfulChatModelHandle", "RESTfulChatglmCppChatModelHandle"],
|
model: Union["RESTfulGenerateModelHandle", "RESTfulChatModelHandle", "RESTfulChatglmCppChatModelHandle"],
|
||||||
prompt: str,
|
prompt: str,
|
||||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
generate_config: Optional[Union["LlamaCppGenerateConfig", "PytorchGenerateConfig", "ChatglmCppGenerateConfig"]] = None,
|
generate_config: Optional[
|
||||||
|
Union["LlamaCppGenerateConfig", "PytorchGenerateConfig", "ChatglmCppGenerateConfig"]] = None,
|
||||||
) -> Generator[str, None, None]:
|
) -> Generator[str, None, None]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -123,6 +123,10 @@ class XinferenceLLM(Xinference):
|
|||||||
if choices:
|
if choices:
|
||||||
choice = choices[0]
|
choice = choices[0]
|
||||||
if isinstance(choice, dict):
|
if isinstance(choice, dict):
|
||||||
|
if 'finish_reason' in choice and choice['finish_reason'] \
|
||||||
|
and choice['finish_reason'] in ['stop', 'length']:
|
||||||
|
break
|
||||||
|
|
||||||
if 'text' in choice:
|
if 'text' in choice:
|
||||||
token = choice.get("text", "")
|
token = choice.get("text", "")
|
||||||
elif 'delta' in choice and 'content' in choice['delta']:
|
elif 'delta' in choice and 'content' in choice['delta']:
|
||||||
|
Loading…
Reference in New Issue
Block a user