mirror of
https://gitee.com/dify_ai/dify.git
synced 2024-12-04 12:17:52 +08:00
chore: add create_json_message api for tools (#5440)
This commit is contained in:
parent
ba67206bb9
commit
1e28a8c033
@ -32,7 +32,6 @@ from core.model_runtime.entities.model_entities import ModelFeature
|
|||||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from core.tools.entities.tool_entities import (
|
from core.tools.entities.tool_entities import (
|
||||||
ToolInvokeMessage,
|
|
||||||
ToolParameter,
|
ToolParameter,
|
||||||
ToolRuntimeVariablePool,
|
ToolRuntimeVariablePool,
|
||||||
)
|
)
|
||||||
@ -141,24 +140,6 @@ class BaseAgentRunner(AppRunner):
|
|||||||
app_generate_entity.app_config.prompt_template.simple_prompt_template = ''
|
app_generate_entity.app_config.prompt_template.simple_prompt_template = ''
|
||||||
|
|
||||||
return app_generate_entity
|
return app_generate_entity
|
||||||
|
|
||||||
def _convert_tool_response_to_str(self, tool_response: list[ToolInvokeMessage]) -> str:
|
|
||||||
"""
|
|
||||||
Handle tool response
|
|
||||||
"""
|
|
||||||
result = ''
|
|
||||||
for response in tool_response:
|
|
||||||
if response.type == ToolInvokeMessage.MessageType.TEXT:
|
|
||||||
result += response.message
|
|
||||||
elif response.type == ToolInvokeMessage.MessageType.LINK:
|
|
||||||
result += f"result link: {response.message}. please tell user to check it."
|
|
||||||
elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
|
|
||||||
response.type == ToolInvokeMessage.MessageType.IMAGE:
|
|
||||||
result += "image has been created and sent to user already, you do not need to create it, just tell the user to check it now."
|
|
||||||
else:
|
|
||||||
result += f"tool response: {response.message}."
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[PromptMessageTool, Tool]:
|
def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[PromptMessageTool, Tool]:
|
||||||
"""
|
"""
|
||||||
|
@ -95,6 +95,7 @@ class ToolInvokeMessage(BaseModel):
|
|||||||
IMAGE = "image"
|
IMAGE = "image"
|
||||||
LINK = "link"
|
LINK = "link"
|
||||||
BLOB = "blob"
|
BLOB = "blob"
|
||||||
|
JSON = "json"
|
||||||
IMAGE_LINK = "image_link"
|
IMAGE_LINK = "image_link"
|
||||||
FILE_VAR = "file_var"
|
FILE_VAR = "file_var"
|
||||||
|
|
||||||
@ -102,7 +103,7 @@ class ToolInvokeMessage(BaseModel):
|
|||||||
"""
|
"""
|
||||||
plain text, image url or link url
|
plain text, image url or link url
|
||||||
"""
|
"""
|
||||||
message: Union[str, bytes] = None
|
message: Union[str, bytes, dict] = None
|
||||||
meta: dict[str, Any] = None
|
meta: dict[str, Any] = None
|
||||||
save_as: str = ''
|
save_as: str = ''
|
||||||
|
|
||||||
|
@ -8,99 +8,36 @@ from core.tools.tool.builtin_tool import BuiltinTool
|
|||||||
SERP_API_URL = "https://serpapi.com/search"
|
SERP_API_URL = "https://serpapi.com/search"
|
||||||
|
|
||||||
|
|
||||||
class SerpAPI:
|
class GoogleSearchTool(BuiltinTool):
|
||||||
"""
|
|
||||||
SerpAPI tool provider.
|
|
||||||
"""
|
|
||||||
def __init__(self, api_key: str) -> None:
|
|
||||||
"""Initialize SerpAPI tool provider."""
|
|
||||||
self.serpapi_api_key = api_key
|
|
||||||
|
|
||||||
def run(self, query: str, **kwargs: Any) -> str:
|
def _parse_response(self, response: dict) -> dict:
|
||||||
"""Run query through SerpAPI and parse result."""
|
result = {}
|
||||||
typ = kwargs.get("result_type", "text")
|
if "knowledge_graph" in response:
|
||||||
return self._process_response(self.results(query), typ=typ)
|
result["title"] = response["knowledge_graph"].get("title", "")
|
||||||
|
result["description"] = response["knowledge_graph"].get("description", "")
|
||||||
def results(self, query: str) -> dict:
|
if "organic_results" in response:
|
||||||
"""Run query through SerpAPI and return the raw result."""
|
result["organic_results"] = [
|
||||||
params = self.get_params(query)
|
{
|
||||||
response = requests.get(url=SERP_API_URL, params=params)
|
"title": item.get("title", ""),
|
||||||
response.raise_for_status()
|
"link": item.get("link", ""),
|
||||||
return response.json()
|
"snippet": item.get("snippet", "")
|
||||||
|
}
|
||||||
def get_params(self, query: str) -> dict[str, str]:
|
for item in response["organic_results"]
|
||||||
"""Get parameters for SerpAPI."""
|
]
|
||||||
|
return result
|
||||||
|
def _invoke(self,
|
||||||
|
user_id: str,
|
||||||
|
tool_parameters: dict[str, Any],
|
||||||
|
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||||
params = {
|
params = {
|
||||||
"api_key": self.serpapi_api_key,
|
"api_key": self.runtime.credentials['serpapi_api_key'],
|
||||||
"q": query,
|
"q": tool_parameters['query'],
|
||||||
"engine": "google",
|
"engine": "google",
|
||||||
"google_domain": "google.com",
|
"google_domain": "google.com",
|
||||||
"gl": "us",
|
"gl": "us",
|
||||||
"hl": "en"
|
"hl": "en"
|
||||||
}
|
}
|
||||||
return params
|
response = requests.get(url=SERP_API_URL, params=params)
|
||||||
|
response.raise_for_status()
|
||||||
@staticmethod
|
valuable_res = self._parse_response(response.json())
|
||||||
def _process_response(res: dict, typ: str) -> str:
|
return self.create_json_message(valuable_res)
|
||||||
"""
|
|
||||||
Process response from SerpAPI.
|
|
||||||
SerpAPI doc: https://serpapi.com/search-api
|
|
||||||
Google search main results are called organic results
|
|
||||||
"""
|
|
||||||
if "error" in res:
|
|
||||||
raise ValueError(f"Got error from SerpAPI: {res['error']}")
|
|
||||||
toret = ""
|
|
||||||
if typ == "text":
|
|
||||||
if "knowledge_graph" in res and "description" in res["knowledge_graph"]:
|
|
||||||
toret += res["knowledge_graph"]["description"] + "\n"
|
|
||||||
if "organic_results" in res:
|
|
||||||
snippets = [
|
|
||||||
f"content: {item.get('snippet')}\nlink: {item.get('link')}"
|
|
||||||
for item in res["organic_results"]
|
|
||||||
if "snippet" in item
|
|
||||||
]
|
|
||||||
toret += "\n".join(snippets)
|
|
||||||
elif typ == "link":
|
|
||||||
if "knowledge_graph" in res and "source" in res["knowledge_graph"]:
|
|
||||||
toret += res["knowledge_graph"]["source"]["link"]
|
|
||||||
elif "organic_results" in res:
|
|
||||||
links = [
|
|
||||||
f"[{item['title']}]({item['link']})\n"
|
|
||||||
for item in res["organic_results"]
|
|
||||||
if "title" in item and "link" in item
|
|
||||||
]
|
|
||||||
toret += "\n".join(links)
|
|
||||||
elif "related_questions" in res:
|
|
||||||
questions = [
|
|
||||||
f"[{item['question']}]({item['link']})\n"
|
|
||||||
for item in res["related_questions"]
|
|
||||||
if "question" in item and "link" in item
|
|
||||||
]
|
|
||||||
toret += "\n".join(questions)
|
|
||||||
elif "related_searches" in res:
|
|
||||||
searches = [
|
|
||||||
f"[{item['query']}]({item['link']})\n"
|
|
||||||
for item in res["related_searches"]
|
|
||||||
if "query" in item and "link" in item
|
|
||||||
]
|
|
||||||
toret += "\n".join(searches)
|
|
||||||
if not toret:
|
|
||||||
toret = "No good search result found"
|
|
||||||
return toret
|
|
||||||
|
|
||||||
|
|
||||||
class GoogleSearchTool(BuiltinTool):
|
|
||||||
def _invoke(self,
|
|
||||||
user_id: str,
|
|
||||||
tool_parameters: dict[str, Any],
|
|
||||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
|
||||||
"""
|
|
||||||
invoke tools
|
|
||||||
"""
|
|
||||||
query = tool_parameters['query']
|
|
||||||
result_type = tool_parameters['result_type']
|
|
||||||
api_key = self.runtime.credentials['serpapi_api_key']
|
|
||||||
result = SerpAPI(api_key).run(query, result_type=result_type)
|
|
||||||
if result_type == 'text':
|
|
||||||
return self.create_text_message(text=result)
|
|
||||||
return self.create_link_message(link=result)
|
|
||||||
|
@ -25,27 +25,3 @@ parameters:
|
|||||||
pt_BR: used for searching
|
pt_BR: used for searching
|
||||||
llm_description: key words for searching
|
llm_description: key words for searching
|
||||||
form: llm
|
form: llm
|
||||||
- name: result_type
|
|
||||||
type: select
|
|
||||||
required: true
|
|
||||||
options:
|
|
||||||
- value: text
|
|
||||||
label:
|
|
||||||
en_US: text
|
|
||||||
zh_Hans: 文本
|
|
||||||
pt_BR: texto
|
|
||||||
- value: link
|
|
||||||
label:
|
|
||||||
en_US: link
|
|
||||||
zh_Hans: 链接
|
|
||||||
pt_BR: link
|
|
||||||
default: link
|
|
||||||
label:
|
|
||||||
en_US: Result type
|
|
||||||
zh_Hans: 结果类型
|
|
||||||
pt_BR: Result type
|
|
||||||
human_description:
|
|
||||||
en_US: used for selecting the result type, text or link
|
|
||||||
zh_Hans: 用于选择结果类型,使用文本还是链接进行展示
|
|
||||||
pt_BR: used for selecting the result type, text or link
|
|
||||||
form: form
|
|
||||||
|
@ -207,30 +207,7 @@ class Tool(BaseModel, ABC):
|
|||||||
result = [result]
|
result = [result]
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def _convert_tool_response_to_str(self, tool_response: list[ToolInvokeMessage]) -> str:
|
|
||||||
"""
|
|
||||||
Handle tool response
|
|
||||||
"""
|
|
||||||
result = ''
|
|
||||||
for response in tool_response:
|
|
||||||
if response.type == ToolInvokeMessage.MessageType.TEXT:
|
|
||||||
result += response.message
|
|
||||||
elif response.type == ToolInvokeMessage.MessageType.LINK:
|
|
||||||
result += f"result link: {response.message}. please tell user to check it. \n"
|
|
||||||
elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
|
|
||||||
response.type == ToolInvokeMessage.MessageType.IMAGE:
|
|
||||||
result += "image has been created and sent to user already, you do not need to create it, just tell the user to check it now. \n"
|
|
||||||
elif response.type == ToolInvokeMessage.MessageType.BLOB:
|
|
||||||
if len(response.message) > 114:
|
|
||||||
result += str(response.message[:114]) + '...'
|
|
||||||
else:
|
|
||||||
result += str(response.message)
|
|
||||||
else:
|
|
||||||
result += f"tool response: {response.message}. \n"
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
def _transform_tool_parameters_type(self, tool_parameters: dict[str, Any]) -> dict[str, Any]:
|
def _transform_tool_parameters_type(self, tool_parameters: dict[str, Any]) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Transform tool parameters type
|
Transform tool parameters type
|
||||||
@ -355,3 +332,12 @@ class Tool(BaseModel, ABC):
|
|||||||
message=blob, meta=meta,
|
message=blob, meta=meta,
|
||||||
save_as=save_as
|
save_as=save_as
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def create_json_message(self, object: dict) -> ToolInvokeMessage:
|
||||||
|
"""
|
||||||
|
create a json message
|
||||||
|
"""
|
||||||
|
return ToolInvokeMessage(
|
||||||
|
type=ToolInvokeMessage.MessageType.JSON,
|
||||||
|
message=object
|
||||||
|
)
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import json
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from mimetypes import guess_type
|
from mimetypes import guess_type
|
||||||
@ -188,6 +189,8 @@ class ToolEngine:
|
|||||||
elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
|
elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
|
||||||
response.type == ToolInvokeMessage.MessageType.IMAGE:
|
response.type == ToolInvokeMessage.MessageType.IMAGE:
|
||||||
result += "image has been created and sent to user already, you do not need to create it, just tell the user to check it now."
|
result += "image has been created and sent to user already, you do not need to create it, just tell the user to check it now."
|
||||||
|
elif response.type == ToolInvokeMessage.MessageType.JSON:
|
||||||
|
result += f"tool response: {json.dumps(response.message, ensure_ascii=False)}."
|
||||||
else:
|
else:
|
||||||
result += f"tool response: {response.message}."
|
result += f"tool response: {response.message}."
|
||||||
|
|
||||||
|
@ -74,13 +74,14 @@ class ToolNode(BaseNode):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# convert tool messages
|
# convert tool messages
|
||||||
plain_text, files = self._convert_tool_messages(messages)
|
plain_text, files, json = self._convert_tool_messages(messages)
|
||||||
|
|
||||||
return NodeRunResult(
|
return NodeRunResult(
|
||||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||||
outputs={
|
outputs={
|
||||||
'text': plain_text,
|
'text': plain_text,
|
||||||
'files': files
|
'files': files,
|
||||||
|
'json': json
|
||||||
},
|
},
|
||||||
metadata={
|
metadata={
|
||||||
NodeRunMetadataKey.TOOL_INFO: tool_info
|
NodeRunMetadataKey.TOOL_INFO: tool_info
|
||||||
@ -149,8 +150,9 @@ class ToolNode(BaseNode):
|
|||||||
# extract plain text and files
|
# extract plain text and files
|
||||||
files = self._extract_tool_response_binary(messages)
|
files = self._extract_tool_response_binary(messages)
|
||||||
plain_text = self._extract_tool_response_text(messages)
|
plain_text = self._extract_tool_response_text(messages)
|
||||||
|
json = self._extract_tool_response_json(messages)
|
||||||
|
|
||||||
return plain_text, files
|
return plain_text, files, json
|
||||||
|
|
||||||
def _extract_tool_response_binary(self, tool_response: list[ToolInvokeMessage]) -> list[FileVar]:
|
def _extract_tool_response_binary(self, tool_response: list[ToolInvokeMessage]) -> list[FileVar]:
|
||||||
"""
|
"""
|
||||||
@ -203,7 +205,9 @@ class ToolNode(BaseNode):
|
|||||||
f'Link: {message.message}' if message.type == ToolInvokeMessage.MessageType.LINK else ''
|
f'Link: {message.message}' if message.type == ToolInvokeMessage.MessageType.LINK else ''
|
||||||
for message in tool_response
|
for message in tool_response
|
||||||
])
|
])
|
||||||
|
|
||||||
|
def _extract_tool_response_json(self, tool_response: list[ToolInvokeMessage]) -> list[dict]:
|
||||||
|
return [message.message for message in tool_response if message.type == ToolInvokeMessage.MessageType.JSON]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _extract_variable_selector_to_variable_mapping(cls, node_data: ToolNodeData) -> dict[str, list[str]]:
|
def _extract_variable_selector_to_variable_mapping(cls, node_data: ToolNodeData) -> dict[str, list[str]]:
|
||||||
|
Loading…
Reference in New Issue
Block a user