From 9ae91a2ec3921ee1cfcbc91dd9ed6191849f8601 Mon Sep 17 00:00:00 2001
From: takatost
Date: Thu, 24 Aug 2023 18:11:15 +0800
Subject: [PATCH] feat: optimize xinference request max token key and stop reason (#998)

---
 .../providers/xinference_provider.py |  3 +-
 .../langchain/llms/xinference_llm.py  | 42 ++++++++++---------
 2 files changed, 24 insertions(+), 21 deletions(-)

diff --git a/api/core/model_providers/providers/xinference_provider.py b/api/core/model_providers/providers/xinference_provider.py
index a2412220b..0f3243d80 100644
--- a/api/core/model_providers/providers/xinference_provider.py
+++ b/api/core/model_providers/providers/xinference_provider.py
@@ -2,7 +2,6 @@ import json
 from typing import Type
 
 import requests
-from xinference.client import RESTfulGenerateModelHandle, RESTfulChatModelHandle, RESTfulChatglmCppChatModelHandle
 
 from core.helper import encrypter
 from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding
@@ -73,7 +72,7 @@ class XinferenceProvider(BaseModelProvider):
             top_p=KwargRule[float](min=0, max=1, default=0.7),
             presence_penalty=KwargRule[float](enabled=False),
             frequency_penalty=KwargRule[float](enabled=False),
-            max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=4000, default=256),
+            max_tokens=KwargRule[int](min=10, max=4000, default=256),
         )
diff --git a/api/core/third_party/langchain/llms/xinference_llm.py b/api/core/third_party/langchain/llms/xinference_llm.py
index 1a01ee55b..7010e56d2 100644
--- a/api/core/third_party/langchain/llms/xinference_llm.py
+++ b/api/core/third_party/langchain/llms/xinference_llm.py
@@ -9,11 +9,11 @@ from xinference.client import RESTfulChatglmCppChatModelHandle, \
 
 class XinferenceLLM(Xinference):
     def _call(
-        self,
-        prompt: str,
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
-        **kwargs: Any,
+            self,
+            prompt: str,
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
     ) -> str:
         """Call the xinference model and return the output.
@@ -56,10 +56,10 @@ class XinferenceLLM(Xinference):
             if generate_config and generate_config.get("stream"):
                 combined_text_output = ""
                 for token in self._stream_generate(
-                    model=model,
-                    prompt=prompt,
-                    run_manager=run_manager,
-                    generate_config=generate_config,
+                        model=model,
+                        prompt=prompt,
+                        run_manager=run_manager,
+                        generate_config=generate_config,
                 ):
                     combined_text_output += token
                 return combined_text_output
@@ -73,10 +73,10 @@ class XinferenceLLM(Xinference):
             if generate_config and generate_config.get("stream"):
                 combined_text_output = ""
                 for token in self._stream_generate(
-                    model=model,
-                    prompt=prompt,
-                    run_manager=run_manager,
-                    generate_config=generate_config,
+                        model=model,
+                        prompt=prompt,
+                        run_manager=run_manager,
+                        generate_config=generate_config,
                 ):
                     combined_text_output += token
                 completion = combined_text_output
@@ -89,13 +89,13 @@ class XinferenceLLM(Xinference):
 
         return completion
 
-
     def _stream_generate(
-        self,
-        model: Union["RESTfulGenerateModelHandle", "RESTfulChatModelHandle", "RESTfulChatglmCppChatModelHandle"],
-        prompt: str,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
-        generate_config: Optional[Union["LlamaCppGenerateConfig", "PytorchGenerateConfig", "ChatglmCppGenerateConfig"]] = None,
+            self,
+            model: Union["RESTfulGenerateModelHandle", "RESTfulChatModelHandle", "RESTfulChatglmCppChatModelHandle"],
+            prompt: str,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            generate_config: Optional[
+                Union["LlamaCppGenerateConfig", "PytorchGenerateConfig", "ChatglmCppGenerateConfig"]] = None,
     ) -> Generator[str, None, None]:
         """
         Args:
@@ -123,6 +123,10 @@ class XinferenceLLM(Xinference):
             if choices:
                 choice = choices[0]
                 if isinstance(choice, dict):
+                    if 'finish_reason' in choice and choice['finish_reason'] \
+                            and choice['finish_reason'] in ['stop', 'length']:
+                        break
+
                     if 'text' in choice:
                         token = choice.get("text", "")
                     elif 'delta' in choice and 'content' in choice['delta']:
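
Reviewer note: on the provider side, dropping alias='max_new_tokens' presumably means the max_tokens kwarg is now forwarded to Xinference under its own name. The sketch below only illustrates that alias-mapping idea; KwargRule here is a simplified stand-in defined for illustration (not Dify's actual class), and to_request_key is a hypothetical helper.

from dataclasses import dataclass
from typing import Optional


@dataclass
class KwargRule:
    """Simplified stand-in for Dify's KwargRule, for illustration only."""
    alias: Optional[str] = None
    min: Optional[int] = None
    max: Optional[int] = None
    default: Optional[int] = None


def to_request_key(name: str, rule: KwargRule) -> str:
    # Hypothetical helper: an alias renames the kwarg before it is sent
    # to the provider; without one, the original name is kept.
    return rule.alias or name


# Before this patch: max_tokens was aliased, so requests carried 'max_new_tokens'.
old_rule = KwargRule(alias='max_new_tokens', min=10, max=4000, default=256)
assert to_request_key('max_tokens', old_rule) == 'max_new_tokens'

# After this patch: no alias, so requests carry 'max_tokens' directly.
new_rule = KwargRule(min=10, max=4000, default=256)
assert to_request_key('max_tokens', new_rule) == 'max_tokens'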
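
Reviewer note: the sketch below illustrates the stop-reason handling this patch adds to _stream_generate: once a streamed choice reports a terminal finish_reason ('stop' or 'length'), the loop breaks instead of continuing to read the stream. It is a self-contained approximation, not the project's actual streaming client; stream_tokens and fake_stream are hypothetical names, and the chunk dicts only mimic the shape of Xinference streaming responses.

from typing import Any, Dict, Generator, Iterable, List


def stream_tokens(chunks: Iterable[Dict[str, Any]]) -> Generator[str, None, None]:
    """Yield text tokens from Xinference-style streaming chunks.

    Mirrors the patched loop: stop consuming the stream as soon as a
    choice carries a terminal finish_reason ('stop' or 'length').
    """
    for chunk in chunks:
        choices: List[Dict[str, Any]] = chunk.get("choices", [])
        if not choices:
            continue
        choice = choices[0]

        # Behavior added by this patch: break on a terminal finish_reason.
        if choice.get("finish_reason") in ("stop", "length"):
            break

        if "text" in choice:
            yield choice.get("text", "")
        elif "delta" in choice and "content" in choice["delta"]:
            yield choice["delta"]["content"]


# Hypothetical chunks shaped like Xinference streaming responses.
fake_stream = [
    {"choices": [{"text": "Hello", "finish_reason": None}]},
    {"choices": [{"text": " world", "finish_reason": None}]},
    {"choices": [{"text": "", "finish_reason": "stop"}]},
    {"choices": [{"text": "ignored after stop", "finish_reason": None}]},
]

assert "".join(stream_tokens(fake_stream)) == "Hello world"

Breaking on 'length' as well as 'stop' means a completion truncated at the token limit also ends the stream cleanly rather than leaving the reader waiting for further chunks.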