From 971defbbbd71cf1f63344619044069b37a87ec75 Mon Sep 17 00:00:00 2001 From: guogeer <1500065870@qq.com> Date: Mon, 4 Nov 2024 18:46:39 +0800 Subject: [PATCH] fix: buitin tool aippt (#10234) Co-authored-by: jinqi.guo --- .../provider/builtin/aippt/tools/aippt.py | 78 ++++++++++++------- api/core/workflow/nodes/tool/tool_node.py | 2 +- 2 files changed, 50 insertions(+), 30 deletions(-) diff --git a/api/core/tools/provider/builtin/aippt/tools/aippt.py b/api/core/tools/provider/builtin/aippt/tools/aippt.py index dd9371f70..38123f125 100644 --- a/api/core/tools/provider/builtin/aippt/tools/aippt.py +++ b/api/core/tools/provider/builtin/aippt/tools/aippt.py @@ -4,7 +4,7 @@ from hmac import new as hmac_new from json import loads as json_loads from threading import Lock from time import sleep, time -from typing import Any, Optional +from typing import Any from httpx import get, post from requests import get as requests_get @@ -15,27 +15,27 @@ from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, from core.tools.tool.builtin_tool import BuiltinTool -class AIPPTGenerateTool(BuiltinTool): +class AIPPTGenerateToolAdapter: """ A tool for generating a ppt """ _api_base_url = URL("https://co.aippt.cn/api") _api_token_cache = {} - _api_token_cache_lock: Optional[Lock] = None _style_cache = {} - _style_cache_lock: Optional[Lock] = None + + _api_token_cache_lock = Lock() + _style_cache_lock = Lock() _task = {} _task_type_map = { "auto": 1, "markdown": 7, } + _tool: BuiltinTool - def __init__(self, **kwargs: Any): - super().__init__(**kwargs) - self._api_token_cache_lock = Lock() - self._style_cache_lock = Lock() + def __init__(self, tool: BuiltinTool = None): + self._tool = tool def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]: """ @@ -51,11 +51,11 @@ class AIPPTGenerateTool(BuiltinTool): """ title = tool_parameters.get("title", "") if not title: - return self.create_text_message("Please provide a title for the ppt") + return self._tool.create_text_message("Please provide a title for the ppt") model = tool_parameters.get("model", "aippt") if not model: - return self.create_text_message("Please provide a model for the ppt") + return self._tool.create_text_message("Please provide a model for the ppt") outline = tool_parameters.get("outline", "") @@ -68,8 +68,8 @@ class AIPPTGenerateTool(BuiltinTool): ) # get suit - color = tool_parameters.get("color") - style = tool_parameters.get("style") + color: str = tool_parameters.get("color") + style: str = tool_parameters.get("style") if color == "__default__": color_id = "" @@ -93,9 +93,9 @@ class AIPPTGenerateTool(BuiltinTool): # generate ppt _, ppt_url = self._generate_ppt(task_id=task_id, suit_id=suit_id, user_id=user_id) - return self.create_text_message( + return self._tool.create_text_message( """the ppt has been created successfully,""" - f"""the ppt url is {ppt_url}""" + f"""the ppt url is {ppt_url} .""" """please give the ppt url to user and direct user to download it.""" ) @@ -111,8 +111,8 @@ class AIPPTGenerateTool(BuiltinTool): """ headers = { "x-channel": "", - "x-api-key": self.runtime.credentials["aippt_access_key"], - "x-token": self._get_api_token(credentials=self.runtime.credentials, user_id=user_id), + "x-api-key": self._tool.runtime.credentials["aippt_access_key"], + "x-token": self._get_api_token(credentials=self._tool.runtime.credentials, user_id=user_id), } response = post( str(self._api_base_url / "ai" / "chat" / "v2" / "task"), @@ -139,8 +139,8 @@ class AIPPTGenerateTool(BuiltinTool): headers = { "x-channel": "", - "x-api-key": self.runtime.credentials["aippt_access_key"], - "x-token": self._get_api_token(credentials=self.runtime.credentials, user_id=user_id), + "x-api-key": self._tool.runtime.credentials["aippt_access_key"], + "x-token": self._get_api_token(credentials=self._tool.runtime.credentials, user_id=user_id), } response = requests_get(url=api_url, headers=headers, stream=True, timeout=(10, 60)) @@ -183,8 +183,8 @@ class AIPPTGenerateTool(BuiltinTool): headers = { "x-channel": "", - "x-api-key": self.runtime.credentials["aippt_access_key"], - "x-token": self._get_api_token(credentials=self.runtime.credentials, user_id=user_id), + "x-api-key": self._tool.runtime.credentials["aippt_access_key"], + "x-token": self._get_api_token(credentials=self._tool.runtime.credentials, user_id=user_id), } response = requests_get(url=api_url, headers=headers, stream=True, timeout=(10, 60)) @@ -236,14 +236,15 @@ class AIPPTGenerateTool(BuiltinTool): """ headers = { "x-channel": "", - "x-api-key": self.runtime.credentials["aippt_access_key"], - "x-token": self._get_api_token(credentials=self.runtime.credentials, user_id=user_id), + "x-api-key": self._tool.runtime.credentials["aippt_access_key"], + "x-token": self._get_api_token(credentials=self._tool.runtime.credentials, user_id=user_id), } response = post( str(self._api_base_url / "design" / "v2" / "save"), headers=headers, data={"task_id": task_id, "template_id": suit_id}, + timeout=(10, 60), ) if response.status_code != 200: @@ -350,11 +351,13 @@ class AIPPTGenerateTool(BuiltinTool): return token - @classmethod - def _calculate_sign(cls, access_key: str, secret_key: str, timestamp: int) -> str: + @staticmethod + def _calculate_sign(access_key: str, secret_key: str, timestamp: int) -> str: return b64encode( hmac_new( - key=secret_key.encode("utf-8"), msg=f"GET@/api/grant/token/@{timestamp}".encode(), digestmod=sha1 + key=secret_key.encode("utf-8"), + msg=f"GET@/api/grant/token/@{timestamp}".encode(), + digestmod=sha1, ).digest() ).decode("utf-8") @@ -419,10 +422,12 @@ class AIPPTGenerateTool(BuiltinTool): :param credentials: the credentials :return: Tuple[list[dict[id, color]], list[dict[id, style]] """ - if not self.runtime.credentials.get("aippt_access_key") or not self.runtime.credentials.get("aippt_secret_key"): + if not self._tool.runtime.credentials.get("aippt_access_key") or not self._tool.runtime.credentials.get( + "aippt_secret_key" + ): raise Exception("Please provide aippt credentials") - return self._get_styles(credentials=self.runtime.credentials, user_id=user_id) + return self._get_styles(credentials=self._tool.runtime.credentials, user_id=user_id) def _get_suit(self, style_id: int, colour_id: int) -> int: """ @@ -430,8 +435,8 @@ class AIPPTGenerateTool(BuiltinTool): """ headers = { "x-channel": "", - "x-api-key": self.runtime.credentials["aippt_access_key"], - "x-token": self._get_api_token(credentials=self.runtime.credentials, user_id="__dify_system__"), + "x-api-key": self._tool.runtime.credentials["aippt_access_key"], + "x-token": self._get_api_token(credentials=self._tool.runtime.credentials, user_id="__dify_system__"), } response = get( str(self._api_base_url / "template_component" / "suit" / "search"), @@ -496,3 +501,18 @@ class AIPPTGenerateTool(BuiltinTool): ], ), ] + + +class AIPPTGenerateTool(BuiltinTool): + def __init__(self, **kwargs: Any): + super().__init__(**kwargs) + + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]: + return AIPPTGenerateToolAdapter(self)._invoke(user_id, tool_parameters) + + def get_runtime_parameters(self) -> list[ToolParameter]: + return AIPPTGenerateToolAdapter(self).get_runtime_parameters() + + @classmethod + def _get_api_token(cls, credentials: dict[str, str], user_id: str) -> str: + return AIPPTGenerateToolAdapter()._get_api_token(credentials, user_id) diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index df22130d6..0994ccaed 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -53,7 +53,7 @@ class ToolNode(BaseNode[ToolNodeData]): ) # get parameters - tool_parameters = tool_runtime.get_runtime_parameters() or [] + tool_parameters = tool_runtime.parameters or [] parameters = self._generate_parameters( tool_parameters=tool_parameters, variable_pool=self.graph_runtime_state.variable_pool,