feat: add tool labels (#2178)

This commit is contained in:
Yeuoly 2024-01-24 20:14:45 +08:00 committed by GitHub
parent 0940084fd2
commit 7cb75cb2e7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 90 additions and 3 deletions

View File

@ -44,6 +44,7 @@ class MessageListApi(AppApiResource):
'position': fields.Integer,
'thought': fields.String,
'tool': fields.String,
'tool_labels': fields.Raw,
'tool_input': fields.String,
'created_at': TimestampField,
'observation': fields.String,

View File

@ -18,6 +18,7 @@ from core.model_runtime.entities.message_entities import (AssistantPromptMessage
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.tools.tool_file_manager import ToolFileManager
from core.tools.tool_manager import ToolManager
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.prompt_template import PromptTemplateParser
from events.message_event import message_was_created
@ -281,7 +282,7 @@ class GenerateTaskPipeline:
self._task_state.llm_result.message.content = annotation.content
elif isinstance(event, QueueAgentThoughtEvent):
agent_thought = (
agent_thought: MessageAgentThought = (
db.session.query(MessageAgentThought)
.filter(MessageAgentThought.id == event.agent_thought_id)
.first()
@ -298,6 +299,7 @@ class GenerateTaskPipeline:
'thought': agent_thought.thought,
'observation': agent_thought.observation,
'tool': agent_thought.tool,
'tool_labels': agent_thought.tool_labels,
'tool_input': agent_thought.tool_input,
'created_at': int(self._message.created_at.timestamp()),
'message_files': agent_thought.files

View File

@ -396,6 +396,7 @@ class BaseAssistantApplicationRunner(AppRunner):
message_chain_id=None,
thought='',
tool=tool_name,
tool_labels_str='{}',
tool_input=tool_input,
message=message,
message_token=0,
@ -469,6 +470,21 @@ class BaseAssistantApplicationRunner(AppRunner):
agent_thought.tokens = llm_usage.total_tokens
agent_thought.total_price = llm_usage.total_price
# check if tool labels is not empty
labels = agent_thought.tool_labels or {}
tools = agent_thought.tool.split(';') if agent_thought.tool else []
for tool in tools:
if not tool:
continue
if tool not in labels:
tool_label = ToolManager.get_tool_label(tool)
if tool_label:
labels[tool] = tool_label.to_dict()
else:
labels[tool] = {'en_US': tool, 'zh_Hans': tool}
agent_thought.tool_labels_str = json.dumps(labels)
db.session.commit()
def get_history_prompt_messages(self) -> List[PromptMessage]:

View File

@ -31,6 +31,7 @@ import mimetypes
logger = logging.getLogger(__name__)
_builtin_providers = {}
_builtin_tools_labels = {}
class ToolManager:
@staticmethod
@ -233,7 +234,7 @@ class ToolManager:
if len(_builtin_providers) > 0:
return list(_builtin_providers.values())
builtin_providers = []
builtin_providers: List[BuiltinToolProviderController] = []
for provider in listdir(path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin')):
if provider.startswith('__'):
continue
@ -264,8 +265,30 @@ class ToolManager:
# cache the builtin providers
for provider in builtin_providers:
_builtin_providers[provider.identity.name] = provider
for tool in provider.get_tools():
_builtin_tools_labels[tool.identity.name] = tool.identity.label
return builtin_providers
@staticmethod
def get_tool_label(tool_name: str) -> Union[I18nObject, None]:
"""
get the tool label
:param tool_name: the name of the tool
:return: the label of the tool
"""
global _builtin_tools_labels
if len(_builtin_tools_labels) == 0:
# init the builtin providers
ToolManager.list_builtin_providers()
if tool_name not in _builtin_tools_labels:
return None
return _builtin_tools_labels[tool_name]
@staticmethod
def user_list_providers(
user_id: str,

View File

@ -49,10 +49,11 @@ agent_thought_fields = {
'position': fields.Integer,
'thought': fields.String,
'tool': fields.String,
'tool_labels': fields.Raw,
'tool_input': fields.String,
'created_at': TimestampField,
'observation': fields.String,
'files': fields.List(fields.String)
'files': fields.List(fields.String),
}
message_detail_fields = {

View File

@ -36,6 +36,7 @@ agent_thought_fields = {
'position': fields.Integer,
'thought': fields.String,
'tool': fields.String,
'tool_labels': fields.Raw,
'tool_input': fields.String,
'created_at': TimestampField,
'observation': fields.String,

View File

@ -0,0 +1,32 @@
"""add tool labels to agent thought
Revision ID: 380c6aa5a70d
Revises: dfb3b7f477da
Create Date: 2024-01-24 10:58:15.644445
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '380c6aa5a70d'
down_revision = 'dfb3b7f477da'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
batch_op.add_column(sa.Column('tool_labels_str', sa.Text(), server_default=sa.text("'{}'::text"), nullable=False))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('message_agent_thoughts', schema=None) as batch_op:
batch_op.drop_column('tool_labels_str')
# ### end Alembic commands ###

View File

@ -1003,6 +1003,7 @@ class MessageAgentThought(db.Model):
position = db.Column(db.Integer, nullable=False)
thought = db.Column(db.Text, nullable=True)
tool = db.Column(db.Text, nullable=True)
tool_labels_str = db.Column(db.Text, nullable=False, server_default=db.text("'{}'::text"))
tool_input = db.Column(db.Text, nullable=True)
observation = db.Column(db.Text, nullable=True)
# plugin_id = db.Column(UUID, nullable=True) ## for future design
@ -1030,6 +1031,16 @@ class MessageAgentThought(db.Model):
return json.loads(self.message_files)
else:
return []
@property
def tool_labels(self) -> dict:
try:
if self.tool_labels_str:
return json.loads(self.tool_labels_str)
else:
return {}
except Exception as e:
return {}
class DatasetRetrieverResource(db.Model):
__tablename__ = 'dataset_retriever_resources'