mirror of
https://gitee.com/dify_ai/dify.git
synced 2024-12-04 04:07:47 +08:00
feat: add tool labels (#2178)
This commit is contained in:
parent
0940084fd2
commit
7cb75cb2e7
@ -44,6 +44,7 @@ class MessageListApi(AppApiResource):
|
|||||||
'position': fields.Integer,
|
'position': fields.Integer,
|
||||||
'thought': fields.String,
|
'thought': fields.String,
|
||||||
'tool': fields.String,
|
'tool': fields.String,
|
||||||
|
'tool_labels': fields.Raw,
|
||||||
'tool_input': fields.String,
|
'tool_input': fields.String,
|
||||||
'created_at': TimestampField,
|
'created_at': TimestampField,
|
||||||
'observation': fields.String,
|
'observation': fields.String,
|
||||||
|
@ -18,6 +18,7 @@ from core.model_runtime.entities.message_entities import (AssistantPromptMessage
|
|||||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||||
from core.tools.tool_file_manager import ToolFileManager
|
from core.tools.tool_file_manager import ToolFileManager
|
||||||
|
from core.tools.tool_manager import ToolManager
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from core.prompt.prompt_template import PromptTemplateParser
|
from core.prompt.prompt_template import PromptTemplateParser
|
||||||
from events.message_event import message_was_created
|
from events.message_event import message_was_created
|
||||||
@ -281,7 +282,7 @@ class GenerateTaskPipeline:
|
|||||||
|
|
||||||
self._task_state.llm_result.message.content = annotation.content
|
self._task_state.llm_result.message.content = annotation.content
|
||||||
elif isinstance(event, QueueAgentThoughtEvent):
|
elif isinstance(event, QueueAgentThoughtEvent):
|
||||||
agent_thought = (
|
agent_thought: MessageAgentThought = (
|
||||||
db.session.query(MessageAgentThought)
|
db.session.query(MessageAgentThought)
|
||||||
.filter(MessageAgentThought.id == event.agent_thought_id)
|
.filter(MessageAgentThought.id == event.agent_thought_id)
|
||||||
.first()
|
.first()
|
||||||
@ -298,6 +299,7 @@ class GenerateTaskPipeline:
|
|||||||
'thought': agent_thought.thought,
|
'thought': agent_thought.thought,
|
||||||
'observation': agent_thought.observation,
|
'observation': agent_thought.observation,
|
||||||
'tool': agent_thought.tool,
|
'tool': agent_thought.tool,
|
||||||
|
'tool_labels': agent_thought.tool_labels,
|
||||||
'tool_input': agent_thought.tool_input,
|
'tool_input': agent_thought.tool_input,
|
||||||
'created_at': int(self._message.created_at.timestamp()),
|
'created_at': int(self._message.created_at.timestamp()),
|
||||||
'message_files': agent_thought.files
|
'message_files': agent_thought.files
|
||||||
|
@ -396,6 +396,7 @@ class BaseAssistantApplicationRunner(AppRunner):
|
|||||||
message_chain_id=None,
|
message_chain_id=None,
|
||||||
thought='',
|
thought='',
|
||||||
tool=tool_name,
|
tool=tool_name,
|
||||||
|
tool_labels_str='{}',
|
||||||
tool_input=tool_input,
|
tool_input=tool_input,
|
||||||
message=message,
|
message=message,
|
||||||
message_token=0,
|
message_token=0,
|
||||||
@ -469,6 +470,21 @@ class BaseAssistantApplicationRunner(AppRunner):
|
|||||||
agent_thought.tokens = llm_usage.total_tokens
|
agent_thought.tokens = llm_usage.total_tokens
|
||||||
agent_thought.total_price = llm_usage.total_price
|
agent_thought.total_price = llm_usage.total_price
|
||||||
|
|
||||||
|
# check if tool labels is not empty
|
||||||
|
labels = agent_thought.tool_labels or {}
|
||||||
|
tools = agent_thought.tool.split(';') if agent_thought.tool else []
|
||||||
|
for tool in tools:
|
||||||
|
if not tool:
|
||||||
|
continue
|
||||||
|
if tool not in labels:
|
||||||
|
tool_label = ToolManager.get_tool_label(tool)
|
||||||
|
if tool_label:
|
||||||
|
labels[tool] = tool_label.to_dict()
|
||||||
|
else:
|
||||||
|
labels[tool] = {'en_US': tool, 'zh_Hans': tool}
|
||||||
|
|
||||||
|
agent_thought.tool_labels_str = json.dumps(labels)
|
||||||
|
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
def get_history_prompt_messages(self) -> List[PromptMessage]:
|
def get_history_prompt_messages(self) -> List[PromptMessage]:
|
||||||
|
@ -31,6 +31,7 @@ import mimetypes
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
_builtin_providers = {}
|
_builtin_providers = {}
|
||||||
|
_builtin_tools_labels = {}
|
||||||
|
|
||||||
class ToolManager:
|
class ToolManager:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -233,7 +234,7 @@ class ToolManager:
|
|||||||
if len(_builtin_providers) > 0:
|
if len(_builtin_providers) > 0:
|
||||||
return list(_builtin_providers.values())
|
return list(_builtin_providers.values())
|
||||||
|
|
||||||
builtin_providers = []
|
builtin_providers: List[BuiltinToolProviderController] = []
|
||||||
for provider in listdir(path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin')):
|
for provider in listdir(path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin')):
|
||||||
if provider.startswith('__'):
|
if provider.startswith('__'):
|
||||||
continue
|
continue
|
||||||
@ -264,8 +265,30 @@ class ToolManager:
|
|||||||
# cache the builtin providers
|
# cache the builtin providers
|
||||||
for provider in builtin_providers:
|
for provider in builtin_providers:
|
||||||
_builtin_providers[provider.identity.name] = provider
|
_builtin_providers[provider.identity.name] = provider
|
||||||
|
for tool in provider.get_tools():
|
||||||
|
_builtin_tools_labels[tool.identity.name] = tool.identity.label
|
||||||
|
|
||||||
return builtin_providers
|
return builtin_providers
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_tool_label(tool_name: str) -> Union[I18nObject, None]:
|
||||||
|
"""
|
||||||
|
get the tool label
|
||||||
|
|
||||||
|
:param tool_name: the name of the tool
|
||||||
|
|
||||||
|
:return: the label of the tool
|
||||||
|
"""
|
||||||
|
global _builtin_tools_labels
|
||||||
|
if len(_builtin_tools_labels) == 0:
|
||||||
|
# init the builtin providers
|
||||||
|
ToolManager.list_builtin_providers()
|
||||||
|
|
||||||
|
if tool_name not in _builtin_tools_labels:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return _builtin_tools_labels[tool_name]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def user_list_providers(
|
def user_list_providers(
|
||||||
user_id: str,
|
user_id: str,
|
||||||
|
@ -49,10 +49,11 @@ agent_thought_fields = {
|
|||||||
'position': fields.Integer,
|
'position': fields.Integer,
|
||||||
'thought': fields.String,
|
'thought': fields.String,
|
||||||
'tool': fields.String,
|
'tool': fields.String,
|
||||||
|
'tool_labels': fields.Raw,
|
||||||
'tool_input': fields.String,
|
'tool_input': fields.String,
|
||||||
'created_at': TimestampField,
|
'created_at': TimestampField,
|
||||||
'observation': fields.String,
|
'observation': fields.String,
|
||||||
'files': fields.List(fields.String)
|
'files': fields.List(fields.String),
|
||||||
}
|
}
|
||||||
|
|
||||||
message_detail_fields = {
|
message_detail_fields = {
|
||||||
|
@ -36,6 +36,7 @@ agent_thought_fields = {
|
|||||||
'position': fields.Integer,
|
'position': fields.Integer,
|
||||||
'thought': fields.String,
|
'thought': fields.String,
|
||||||
'tool': fields.String,
|
'tool': fields.String,
|
||||||
|
'tool_labels': fields.Raw,
|
||||||
'tool_input': fields.String,
|
'tool_input': fields.String,
|
||||||
'created_at': TimestampField,
|
'created_at': TimestampField,
|
||||||
'observation': fields.String,
|
'observation': fields.String,
|
||||||
|
@ -0,0 +1,32 @@
|
|||||||
|
"""add tool labels to agent thought
|
||||||
|
|
||||||
|
Revision ID: 380c6aa5a70d
|
||||||
|
Revises: dfb3b7f477da
|
||||||
|
Create Date: 2024-01-24 10:58:15.644445
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = '380c6aa5a70d'
|
||||||
|
down_revision = 'dfb3b7f477da'
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
|
||||||
|
batch_op.add_column(sa.Column('tool_labels_str', sa.Text(), server_default=sa.text("'{}'::text"), nullable=False))
|
||||||
|
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
|
||||||
|
batch_op.drop_column('tool_labels_str')
|
||||||
|
|
||||||
|
# ### end Alembic commands ###
|
@ -1003,6 +1003,7 @@ class MessageAgentThought(db.Model):
|
|||||||
position = db.Column(db.Integer, nullable=False)
|
position = db.Column(db.Integer, nullable=False)
|
||||||
thought = db.Column(db.Text, nullable=True)
|
thought = db.Column(db.Text, nullable=True)
|
||||||
tool = db.Column(db.Text, nullable=True)
|
tool = db.Column(db.Text, nullable=True)
|
||||||
|
tool_labels_str = db.Column(db.Text, nullable=False, server_default=db.text("'{}'::text"))
|
||||||
tool_input = db.Column(db.Text, nullable=True)
|
tool_input = db.Column(db.Text, nullable=True)
|
||||||
observation = db.Column(db.Text, nullable=True)
|
observation = db.Column(db.Text, nullable=True)
|
||||||
# plugin_id = db.Column(UUID, nullable=True) ## for future design
|
# plugin_id = db.Column(UUID, nullable=True) ## for future design
|
||||||
@ -1031,6 +1032,16 @@ class MessageAgentThought(db.Model):
|
|||||||
else:
|
else:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
@property
|
||||||
|
def tool_labels(self) -> dict:
|
||||||
|
try:
|
||||||
|
if self.tool_labels_str:
|
||||||
|
return json.loads(self.tool_labels_str)
|
||||||
|
else:
|
||||||
|
return {}
|
||||||
|
except Exception as e:
|
||||||
|
return {}
|
||||||
|
|
||||||
class DatasetRetrieverResource(db.Model):
|
class DatasetRetrieverResource(db.Model):
|
||||||
__tablename__ = 'dataset_retriever_resources'
|
__tablename__ = 'dataset_retriever_resources'
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
|
Loading…
Reference in New Issue
Block a user