mirror of
https://gitee.com/agents-flex/agents-flex.git
synced 2024-12-02 03:48:11 +08:00
refactor: optimize Agent and Chain
This commit is contained in:
parent
6acf994d8a
commit
0082ffb21c
@ -21,18 +21,16 @@ import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* input and output agent
|
||||
*/
|
||||
public abstract class IOAgent extends Agent {
|
||||
public IOAgent() {
|
||||
|
||||
public abstract class DefaultAgent extends Agent {
|
||||
public DefaultAgent() {
|
||||
}
|
||||
|
||||
public IOAgent(Object id) {
|
||||
public DefaultAgent(Object id) {
|
||||
super(id);
|
||||
}
|
||||
|
||||
public IOAgent(Object id, String name) {
|
||||
public DefaultAgent(Object id, String name) {
|
||||
super(id, name);
|
||||
}
|
||||
|
||||
@ -58,7 +56,7 @@ public abstract class IOAgent extends Agent {
|
||||
return Output.ofValue(execute(value, chain));
|
||||
}
|
||||
|
||||
public abstract Object execute(Object param, Chain chain);
|
||||
public abstract Object execute(Object parameter, Chain chain);
|
||||
|
||||
|
||||
}
|
@ -97,7 +97,7 @@ public class LLMAgent extends Agent {
|
||||
|
||||
@Override
|
||||
public Output execute(Map<String, Object> variables, Chain chain) {
|
||||
SimplePrompt simplePrompt = promptTemplate.format(chain == null ? variables : chain.getMemory().getAll());
|
||||
SimplePrompt simplePrompt = promptTemplate.format(variables);
|
||||
AiMessageResponse response = llm.chat(simplePrompt, chatOptions);
|
||||
|
||||
if (chain != null) {
|
||||
|
@ -16,6 +16,7 @@
|
||||
package com.agentsflex.chain;
|
||||
|
||||
import com.agentsflex.agent.Agent;
|
||||
import com.agentsflex.agent.Output;
|
||||
import com.agentsflex.agent.Parameter;
|
||||
import com.agentsflex.chain.event.OnErrorEvent;
|
||||
import com.agentsflex.chain.event.OnFinishedEvent;
|
||||
@ -40,7 +41,7 @@ public abstract class Chain implements Serializable {
|
||||
private List<Chain> children;
|
||||
private ChainStatus status = ChainStatus.READY;
|
||||
|
||||
private String errorMessage;
|
||||
private String message;
|
||||
|
||||
//理论上是线程安全的,所有有多线程写入的情况,但是只有全部写入完成后才会去通知监听器
|
||||
private List<Parameter> waitInputParameters = new ArrayList<>();
|
||||
@ -230,6 +231,23 @@ public abstract class Chain implements Serializable {
|
||||
return null;
|
||||
}
|
||||
|
||||
public void execute(Object variable) {
|
||||
Map<String, Object> variables = new HashMap<>(1);
|
||||
variables.put(Output.DEFAULT_VALUE_KEY, variable);
|
||||
this.execute(variables);
|
||||
}
|
||||
|
||||
|
||||
public <T> T executeForResult(Object variable) {
|
||||
Map<String, Object> variables = new HashMap<>(1);
|
||||
variables.put(Output.DEFAULT_VALUE_KEY, variable);
|
||||
this.execute(variables);
|
||||
|
||||
//noinspection unchecked
|
||||
return (T) this.getMemory().get(Output.DEFAULT_VALUE_KEY);
|
||||
}
|
||||
|
||||
|
||||
public void execute(Map<String, Object> variables) {
|
||||
runInLifeCycle(variables, this::executeInternal);
|
||||
}
|
||||
@ -306,18 +324,19 @@ public abstract class Chain implements Serializable {
|
||||
}
|
||||
}
|
||||
|
||||
public void stopNormal() {
|
||||
public void stopNormal(String message) {
|
||||
this.message = message;
|
||||
setStatus(ChainStatus.FINISHED_NORMAL);
|
||||
if (parent != null) {
|
||||
parent.stopNormal();
|
||||
parent.stopNormal(message);
|
||||
}
|
||||
}
|
||||
|
||||
public void stopError(String errorMessage) {
|
||||
this.errorMessage = errorMessage;
|
||||
public void stopError(String message) {
|
||||
this.message = message;
|
||||
setStatus(ChainStatus.FINISHED_ABNORMAL);
|
||||
if (parent != null) {
|
||||
parent.stopError(errorMessage);
|
||||
parent.stopError(message);
|
||||
}
|
||||
}
|
||||
|
||||
@ -328,8 +347,8 @@ public abstract class Chain implements Serializable {
|
||||
}
|
||||
}
|
||||
|
||||
public String getErrorMessage() {
|
||||
return errorMessage;
|
||||
public String getMessage() {
|
||||
return message;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -1,37 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2023-2025, Agents-Flex (fuhai999@gmail.com).
|
||||
* <p>
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
* <p>
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* <p>
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package com.agentsflex.chain;
|
||||
|
||||
import com.agentsflex.agent.Output;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* input and output chain
|
||||
*/
|
||||
public class IOChain extends SequentialChain {
|
||||
|
||||
public <T> T execute(Object input) {
|
||||
Map<String, Object> variables = new HashMap<>(1);
|
||||
variables.put(Output.DEFAULT_VALUE_KEY, input);
|
||||
|
||||
super.execute(variables);
|
||||
|
||||
//noinspection unchecked
|
||||
return (T) this.getMemory().get(Output.DEFAULT_VALUE_KEY);
|
||||
}
|
||||
}
|
@ -20,6 +20,24 @@ import com.agentsflex.chain.Chain;
|
||||
import java.util.Map;
|
||||
|
||||
public class EndNode extends AbstractBaseNode {
|
||||
private boolean isNormal = true;
|
||||
private String message;
|
||||
|
||||
public String getMessage() {
|
||||
return message;
|
||||
}
|
||||
|
||||
public void setMessage(String message) {
|
||||
this.message = message;
|
||||
}
|
||||
|
||||
public boolean isNormal() {
|
||||
return isNormal;
|
||||
}
|
||||
|
||||
public void setNormal(boolean normal) {
|
||||
isNormal = normal;
|
||||
}
|
||||
|
||||
public EndNode() {
|
||||
this.name = "end";
|
||||
@ -27,7 +45,11 @@ public class EndNode extends AbstractBaseNode {
|
||||
|
||||
@Override
|
||||
public Map<String, Object> execute(Chain chain) {
|
||||
chain.stopNormal();
|
||||
if (isNormal) {
|
||||
chain.stopNormal(message);
|
||||
} else {
|
||||
chain.stopError(message);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
@ -0,0 +1,15 @@
|
||||
package com.agentsflex.core.test.io;
|
||||
|
||||
import com.agentsflex.agent.DefaultAgent;
|
||||
import com.agentsflex.chain.Chain;
|
||||
|
||||
public class Agent1 extends DefaultAgent {
|
||||
public Agent1(Object id) {
|
||||
super(id);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object execute(Object parameter, Chain chain) {
|
||||
return "001:" + parameter;
|
||||
}
|
||||
}
|
@ -0,0 +1,16 @@
|
||||
package com.agentsflex.core.test.io;
|
||||
|
||||
import com.agentsflex.agent.DefaultAgent;
|
||||
import com.agentsflex.chain.Chain;
|
||||
|
||||
public class Agent2 extends DefaultAgent {
|
||||
|
||||
public Agent2(Object id) {
|
||||
super(id);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object execute(Object parameter, Chain chain) {
|
||||
return "002:" + parameter;
|
||||
}
|
||||
}
|
@ -0,0 +1,29 @@
|
||||
package com.agentsflex.core.test.io;
|
||||
|
||||
import com.agentsflex.chain.*;
|
||||
|
||||
public class AgentChainTest {
|
||||
|
||||
public static void main(String[] args) {
|
||||
|
||||
SequentialChain ioChain1 = new SequentialChain();
|
||||
ioChain1.addNode(new Agent1("agent1"));
|
||||
ioChain1.addNode(new Agent2("agent2"));
|
||||
|
||||
SequentialChain ioChain2 = new SequentialChain();
|
||||
ioChain2.addNode(new Agent1("agent3"));
|
||||
ioChain2.addNode(new Agent2("agent4"));
|
||||
ioChain2.addNode(ioChain1);
|
||||
|
||||
ioChain2.registerEventListener(new ChainEventListener() {
|
||||
@Override
|
||||
public void onEvent(ChainEvent event, Chain chain) {
|
||||
System.out.println(event);
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
Object result = ioChain2.executeForResult("your params");
|
||||
System.out.println(result);
|
||||
}
|
||||
}
|
@ -1,15 +0,0 @@
|
||||
package com.agentsflex.core.test.io;
|
||||
|
||||
import com.agentsflex.agent.IOAgent;
|
||||
import com.agentsflex.chain.Chain;
|
||||
|
||||
public class IOAgent1 extends IOAgent {
|
||||
public IOAgent1(Object id) {
|
||||
super(id);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object execute(Object param, Chain chain) {
|
||||
return "001:" + param;
|
||||
}
|
||||
}
|
@ -1,15 +0,0 @@
|
||||
package com.agentsflex.core.test.io;
|
||||
|
||||
import com.agentsflex.agent.IOAgent;
|
||||
import com.agentsflex.chain.Chain;
|
||||
|
||||
public class IOAgent2 extends IOAgent {
|
||||
|
||||
public IOAgent2(Object id) {
|
||||
super(id);
|
||||
}
|
||||
@Override
|
||||
public Object execute(Object param, Chain chain) {
|
||||
return "002:" + param;
|
||||
}
|
||||
}
|
@ -1,32 +0,0 @@
|
||||
package com.agentsflex.core.test.io;
|
||||
|
||||
import com.agentsflex.chain.Chain;
|
||||
import com.agentsflex.chain.ChainEvent;
|
||||
import com.agentsflex.chain.ChainEventListener;
|
||||
import com.agentsflex.chain.IOChain;
|
||||
|
||||
public class IOChainTest {
|
||||
|
||||
public static void main(String[] args) {
|
||||
|
||||
IOChain ioChain1 = new IOChain();
|
||||
ioChain1.addNode(new IOAgent1("agent1"));
|
||||
ioChain1.addNode(new IOAgent2("agent2"));
|
||||
|
||||
IOChain ioChain2 = new IOChain();
|
||||
ioChain2.addNode(new IOAgent1("agent3"));
|
||||
ioChain2.addNode(new IOAgent2("agent4"));
|
||||
ioChain2.addNode(ioChain1);
|
||||
|
||||
ioChain2.registerEventListener(new ChainEventListener() {
|
||||
@Override
|
||||
public void onEvent(ChainEvent event, Chain chain) {
|
||||
System.out.println(event);
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
Object result = ioChain2.execute("your params");
|
||||
System.out.println(result);
|
||||
}
|
||||
}
|
@ -1,7 +1,6 @@
|
||||
# Agents-Flex ChangeLog
|
||||
|
||||
## v1.0.0-beta.2
|
||||
- 新增:新增 IOAgent 以及 IOChain,方便用于单一的输入和输出的场景
|
||||
- 新增:MilvusVectorStore 用于对 Milvus 向量数据库的支持,感谢 @xgc
|
||||
- 新增:执行链的路由节点新增对 QLExpress 和 Groovy 的规则支持
|
||||
- 优化:重构让 FunctionMessage 继承 AiMessage
|
||||
@ -11,5 +10,6 @@
|
||||
- 优化:优化执行链 Chain 以及 Node 节点,方便更加容易的创建和配置
|
||||
- 优化:重命名 BaseFunctionMessageParser 为 DefaultFunctionMessageParser
|
||||
- 优化:修改 TextParser 为 JSONObjectParser
|
||||
- 测试:多场景大量测试 Agent 以及 Chain,已初步具备编排能力
|
||||
- 文档:https://agentsflex.com 官网上线
|
||||
- 文档:完善基础文档
|
||||
|
Loading…
Reference in New Issue
Block a user