mirror of
https://gitee.com/dify_ai/dify.git
synced 2024-12-02 19:27:48 +08:00
feat: fix azure completion choices return empty (#708)
This commit is contained in:
parent
a856ef387b
commit
e18211ffea
@ -1,5 +1,7 @@
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.callbacks.manager import Callbacks, CallbackManagerForLLMRun
|
||||
from langchain.llms import AzureOpenAI
|
||||
from langchain.llms.openai import _streaming_response_template, completion_with_retry, _update_response, \
|
||||
update_token_usage
|
||||
from langchain.schema import LLMResult
|
||||
from typing import Optional, List, Dict, Mapping, Any, Union, Tuple
|
||||
|
||||
@ -67,3 +69,58 @@ class StreamableAzureOpenAI(AzureOpenAI):
|
||||
@classmethod
|
||||
def get_kwargs_from_model_params(cls, params: dict):
|
||||
return params
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
prompts: List[str],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> LLMResult:
|
||||
"""Call out to OpenAI's endpoint with k unique prompts.
|
||||
|
||||
Args:
|
||||
prompts: The prompts to pass into the model.
|
||||
stop: Optional list of stop words to use when generating.
|
||||
|
||||
Returns:
|
||||
The full LLM output.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
response = openai.generate(["Tell me a joke."])
|
||||
"""
|
||||
params = self._invocation_params
|
||||
params = {**params, **kwargs}
|
||||
sub_prompts = self.get_sub_prompts(params, prompts, stop)
|
||||
choices = []
|
||||
token_usage: Dict[str, int] = {}
|
||||
# Get the token usage from the response.
|
||||
# Includes prompt, completion, and total tokens used.
|
||||
_keys = {"completion_tokens", "prompt_tokens", "total_tokens"}
|
||||
for _prompts in sub_prompts:
|
||||
if self.streaming:
|
||||
if len(_prompts) > 1:
|
||||
raise ValueError("Cannot stream results with multiple prompts.")
|
||||
params["stream"] = True
|
||||
response = _streaming_response_template()
|
||||
for stream_resp in completion_with_retry(
|
||||
self, prompt=_prompts, **params
|
||||
):
|
||||
if len(stream_resp["choices"]) > 0:
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(
|
||||
stream_resp["choices"][0]["text"],
|
||||
verbose=self.verbose,
|
||||
logprobs=stream_resp["choices"][0]["logprobs"],
|
||||
)
|
||||
_update_response(response, stream_resp)
|
||||
choices.extend(response["choices"])
|
||||
else:
|
||||
response = completion_with_retry(self, prompt=_prompts, **params)
|
||||
choices.extend(response["choices"])
|
||||
if not self.streaming:
|
||||
# Can't update token usage if streaming
|
||||
update_token_usage(_keys, response, token_usage)
|
||||
return self.create_llm_result(choices, prompts, token_usage)
|
Loading…
Reference in New Issue
Block a user