mirror of
https://gitee.com/dify_ai/dify.git
synced 2024-11-30 02:08:37 +08:00
feat: support basic feature of OpenAI new models (#1476)
This commit is contained in:
parent
7b26c9e2ef
commit
d7ae86799c
@ -8,3 +8,4 @@ class ProviderQuotaUnit(Enum):
|
||||
|
||||
class ModelFeature(Enum):
|
||||
AGENT_THOUGHT = 'agent_thought'
|
||||
VISION = 'vision'
|
||||
|
@ -19,6 +19,13 @@ from core.model_providers.models.entity.model_params import ModelMode, ModelKwar
|
||||
AZURE_OPENAI_API_VERSION = '2023-07-01-preview'
|
||||
|
||||
|
||||
FUNCTION_CALL_MODELS = [
|
||||
'gpt-4',
|
||||
'gpt-4-32k',
|
||||
'gpt-35-turbo',
|
||||
'gpt-35-turbo-16k'
|
||||
]
|
||||
|
||||
class AzureOpenAIModel(BaseLLM):
|
||||
def __init__(self, model_provider: BaseModelProvider,
|
||||
name: str,
|
||||
@ -157,3 +164,7 @@ class AzureOpenAIModel(BaseLLM):
|
||||
@property
|
||||
def support_streaming(self):
|
||||
return True
|
||||
|
||||
@property
|
||||
def support_function_call(self):
|
||||
return self.base_model_name in FUNCTION_CALL_MODELS
|
||||
|
@ -310,6 +310,10 @@ class BaseLLM(BaseProviderModel):
|
||||
def support_streaming(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def support_function_call(self):
|
||||
return False
|
||||
|
||||
def _get_prompt_from_messages(self, messages: List[PromptMessage],
|
||||
model_mode: Optional[ModelMode] = None) -> Union[str , List[BaseMessage]]:
|
||||
if not model_mode:
|
||||
|
@ -23,21 +23,36 @@ COMPLETION_MODELS = [
|
||||
]
|
||||
|
||||
CHAT_MODELS = [
|
||||
'gpt-4-1106-preview', # 128,000 tokens
|
||||
'gpt-4-vision-preview', # 128,000 tokens
|
||||
'gpt-4', # 8,192 tokens
|
||||
'gpt-4-32k', # 32,768 tokens
|
||||
'gpt-3.5-turbo-1106', # 16,384 tokens
|
||||
'gpt-3.5-turbo', # 4,096 tokens
|
||||
'gpt-3.5-turbo-16k', # 16,384 tokens
|
||||
]
|
||||
|
||||
MODEL_MAX_TOKENS = {
|
||||
'gpt-4-1106-preview': 128000,
|
||||
'gpt-4-vision-preview': 128000,
|
||||
'gpt-4': 8192,
|
||||
'gpt-4-32k': 32768,
|
||||
'gpt-3.5-turbo-1106': 16384,
|
||||
'gpt-3.5-turbo': 4096,
|
||||
'gpt-3.5-turbo-instruct': 4097,
|
||||
'gpt-3.5-turbo-16k': 16384,
|
||||
'text-davinci-003': 4097,
|
||||
}
|
||||
|
||||
FUNCTION_CALL_MODELS = [
|
||||
'gpt-4-1106-preview',
|
||||
'gpt-4',
|
||||
'gpt-4-32k',
|
||||
'gpt-3.5-turbo-1106',
|
||||
'gpt-3.5-turbo',
|
||||
'gpt-3.5-turbo-16k'
|
||||
]
|
||||
|
||||
|
||||
class OpenAIModel(BaseLLM):
|
||||
def __init__(self, model_provider: BaseModelProvider,
|
||||
@ -50,7 +65,6 @@ class OpenAIModel(BaseLLM):
|
||||
else:
|
||||
self.model_mode = ModelMode.CHAT
|
||||
|
||||
# TODO load price config from configs(db)
|
||||
super().__init__(model_provider, name, model_kwargs, streaming, callbacks)
|
||||
|
||||
def _init_client(self) -> Any:
|
||||
@ -100,7 +114,7 @@ class OpenAIModel(BaseLLM):
|
||||
:param callbacks:
|
||||
:return:
|
||||
"""
|
||||
if self.name == 'gpt-4' \
|
||||
if self.name.startswith('gpt-4') \
|
||||
and self.model_provider.provider.provider_type == ProviderType.SYSTEM.value \
|
||||
and self.model_provider.provider.quota_type == ProviderQuotaType.TRIAL.value:
|
||||
raise ModelCurrentlyNotSupportError("Dify Hosted OpenAI GPT-4 currently not support.")
|
||||
@ -175,6 +189,10 @@ class OpenAIModel(BaseLLM):
|
||||
def support_streaming(self):
|
||||
return True
|
||||
|
||||
@property
|
||||
def support_function_call(self):
|
||||
return self.name in FUNCTION_CALL_MODELS
|
||||
|
||||
# def is_model_valid_or_raise(self):
|
||||
# """
|
||||
# check is a valid model.
|
||||
|
@ -41,9 +41,17 @@ class OpenAIProvider(BaseModelProvider):
|
||||
ModelFeature.AGENT_THOUGHT.value
|
||||
]
|
||||
},
|
||||
{
|
||||
'id': 'gpt-3.5-turbo-1106',
|
||||
'name': 'gpt-3.5-turbo-1106',
|
||||
'mode': ModelMode.CHAT.value,
|
||||
'features': [
|
||||
ModelFeature.AGENT_THOUGHT.value
|
||||
]
|
||||
},
|
||||
{
|
||||
'id': 'gpt-3.5-turbo-instruct',
|
||||
'name': 'GPT-3.5-Turbo-Instruct',
|
||||
'name': 'gpt-3.5-turbo-instruct',
|
||||
'mode': ModelMode.COMPLETION.value,
|
||||
},
|
||||
{
|
||||
@ -62,6 +70,22 @@ class OpenAIProvider(BaseModelProvider):
|
||||
ModelFeature.AGENT_THOUGHT.value
|
||||
]
|
||||
},
|
||||
{
|
||||
'id': 'gpt-4-1106-preview',
|
||||
'name': 'gpt-4-1106-preview',
|
||||
'mode': ModelMode.CHAT.value,
|
||||
'features': [
|
||||
ModelFeature.AGENT_THOUGHT.value
|
||||
]
|
||||
},
|
||||
{
|
||||
'id': 'gpt-4-vision-preview',
|
||||
'name': 'gpt-4-vision-preview',
|
||||
'mode': ModelMode.CHAT.value,
|
||||
'features': [
|
||||
ModelFeature.VISION.value
|
||||
]
|
||||
},
|
||||
{
|
||||
'id': 'gpt-4-32k',
|
||||
'name': 'gpt-4-32k',
|
||||
@ -79,7 +103,7 @@ class OpenAIProvider(BaseModelProvider):
|
||||
|
||||
if self.provider.provider_type == ProviderType.SYSTEM.value \
|
||||
and self.provider.quota_type == ProviderQuotaType.TRIAL.value:
|
||||
models = [item for item in models if item['id'] not in ['gpt-4', 'gpt-4-32k']]
|
||||
models = [item for item in models if not item['id'].startswith('gpt-4')]
|
||||
|
||||
return models
|
||||
elif model_type == ModelType.EMBEDDINGS:
|
||||
@ -141,8 +165,11 @@ class OpenAIProvider(BaseModelProvider):
|
||||
:return:
|
||||
"""
|
||||
model_max_tokens = {
|
||||
'gpt-4-1106-preview': 128000,
|
||||
'gpt-4-vision-preview': 128000,
|
||||
'gpt-4': 8192,
|
||||
'gpt-4-32k': 32768,
|
||||
'gpt-3.5-turbo-1106': 16384,
|
||||
'gpt-3.5-turbo': 4096,
|
||||
'gpt-3.5-turbo-instruct': 4097,
|
||||
'gpt-3.5-turbo-16k': 16384,
|
||||
|
@ -24,12 +24,30 @@
|
||||
"unit": "0.001",
|
||||
"currency": "USD"
|
||||
},
|
||||
"gpt-4-1106-preview": {
|
||||
"prompt": "0.01",
|
||||
"completion": "0.03",
|
||||
"unit": "0.001",
|
||||
"currency": "USD"
|
||||
},
|
||||
"gpt-4-vision-preview": {
|
||||
"prompt": "0.01",
|
||||
"completion": "0.03",
|
||||
"unit": "0.001",
|
||||
"currency": "USD"
|
||||
},
|
||||
"gpt-3.5-turbo": {
|
||||
"prompt": "0.0015",
|
||||
"completion": "0.002",
|
||||
"unit": "0.001",
|
||||
"currency": "USD"
|
||||
},
|
||||
"gpt-3.5-turbo-1106": {
|
||||
"prompt": "0.0010",
|
||||
"completion": "0.002",
|
||||
"unit": "0.001",
|
||||
"currency": "USD"
|
||||
},
|
||||
"gpt-3.5-turbo-instruct": {
|
||||
"prompt": "0.0015",
|
||||
"completion": "0.002",
|
||||
|
@ -73,8 +73,7 @@ class OrchestratorRuleParser:
|
||||
planning_strategy = PlanningStrategy(agent_mode_config.get('strategy', 'router'))
|
||||
|
||||
# only OpenAI chat model (include Azure) support function call, use ReACT instead
|
||||
if agent_model_instance.model_mode != ModelMode.CHAT \
|
||||
or agent_model_instance.model_provider.provider_name not in ['openai', 'azure_openai']:
|
||||
if not agent_model_instance.support_function_call:
|
||||
if planning_strategy == PlanningStrategy.FUNCTION_CALL:
|
||||
planning_strategy = PlanningStrategy.REACT
|
||||
elif planning_strategy == PlanningStrategy.ROUTER:
|
||||
|
Loading…
Reference in New Issue
Block a user