mirror of
https://gitee.com/agents-flex/agents-flex.git
synced 2024-11-29 18:38:17 +08:00
refactor: optimize chain nodes
This commit is contained in:
parent
362bdbdaa5
commit
00712f1c7c
@ -66,12 +66,12 @@
|
||||
<dependency>
|
||||
<groupId>com.agentsflex</groupId>
|
||||
<artifactId>agents-flex-chain-qlexpress</artifactId>
|
||||
<version>1.0.0-beta.2</version>
|
||||
<version>${agents-flex.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.agentsflex</groupId>
|
||||
<artifactId>agents-flex-chain-groovy</artifactId>
|
||||
<version>1.0.0-beta.2</version>
|
||||
<version>${agents-flex.version}</version>
|
||||
</dependency>
|
||||
<!--chains end-->
|
||||
|
||||
|
@ -16,15 +16,36 @@
|
||||
package com.agentsflex.chain.node;
|
||||
|
||||
import com.agentsflex.chain.Chain;
|
||||
import com.agentsflex.chain.ChainNode;
|
||||
import groovy.lang.Binding;
|
||||
import groovy.lang.GroovyShell;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
public class GroovyRouterNode extends RouterNode {
|
||||
|
||||
private String express;
|
||||
|
||||
public GroovyRouterNode() {
|
||||
|
||||
}
|
||||
|
||||
public GroovyRouterNode(String express) {
|
||||
this.express = express;
|
||||
}
|
||||
|
||||
public GroovyRouterNode(String express, ChainNode... nodes) {
|
||||
this.express = express;
|
||||
this.setNodes(Arrays.asList(nodes));
|
||||
}
|
||||
|
||||
public GroovyRouterNode(List<ChainNode> nodes, String express) {
|
||||
super(nodes);
|
||||
this.express = express;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected String route(Chain chain) {
|
||||
Binding binding = new Binding();
|
||||
|
@ -16,13 +16,34 @@
|
||||
package com.agentsflex.chain.node;
|
||||
|
||||
import com.agentsflex.chain.Chain;
|
||||
import com.agentsflex.chain.ChainNode;
|
||||
import com.ql.util.express.DefaultContext;
|
||||
import com.ql.util.express.ExpressRunner;
|
||||
|
||||
public class QLExpressRouterNode extends RouterNode {
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
public class QLExpressRouterNode extends RouterNode {
|
||||
private String express;
|
||||
|
||||
public QLExpressRouterNode() {
|
||||
|
||||
}
|
||||
|
||||
public QLExpressRouterNode(String express) {
|
||||
this.express = express;
|
||||
}
|
||||
|
||||
public QLExpressRouterNode(String express, ChainNode... nodes) {
|
||||
this.express = express;
|
||||
this.setNodes(Arrays.asList(nodes));
|
||||
}
|
||||
|
||||
public QLExpressRouterNode(List<ChainNode> nodes, String express) {
|
||||
super(nodes);
|
||||
this.express = express;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected String route(Chain chain) {
|
||||
ExpressRunner runner = new ExpressRunner();
|
||||
|
@ -21,6 +21,10 @@ public interface ChainNode {
|
||||
|
||||
Object getId();
|
||||
|
||||
default String getName() {
|
||||
return null;
|
||||
}
|
||||
|
||||
boolean isSkip();
|
||||
|
||||
Map<String, Object> execute(Chain chain);
|
||||
|
@ -22,6 +22,7 @@ import java.util.UUID;
|
||||
public abstract class AbstractBaseNode implements ChainNode {
|
||||
|
||||
protected Object id;
|
||||
protected String name;
|
||||
protected boolean skip;
|
||||
|
||||
public AbstractBaseNode() {
|
||||
@ -37,6 +38,19 @@ public abstract class AbstractBaseNode implements ChainNode {
|
||||
this.id = id;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getName() {
|
||||
return name;
|
||||
}
|
||||
|
||||
public void setName(String name) {
|
||||
this.name = name;
|
||||
}
|
||||
|
||||
public void setSkip(boolean skip) {
|
||||
this.skip = skip;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isSkip() {
|
||||
return skip;
|
||||
|
@ -26,7 +26,10 @@ import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
public class AgentNode extends AbstractBaseNode {
|
||||
private final Agent agent;
|
||||
private Agent agent;
|
||||
|
||||
public AgentNode() {
|
||||
}
|
||||
|
||||
public AgentNode(Agent agent) {
|
||||
this.agent = agent;
|
||||
@ -41,6 +44,10 @@ public class AgentNode extends AbstractBaseNode {
|
||||
return agent;
|
||||
}
|
||||
|
||||
public void setAgent(Agent agent) {
|
||||
this.agent = agent;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Map<String, Object> execute(Chain chain) {
|
||||
List<Parameter> inputParameters = agent.getInputParameters();
|
||||
|
@ -21,6 +21,10 @@ import java.util.Map;
|
||||
|
||||
public class EndNode extends AbstractBaseNode {
|
||||
|
||||
public EndNode() {
|
||||
this.name = "end";
|
||||
}
|
||||
|
||||
@Override
|
||||
public Map<String, Object> execute(Chain chain) {
|
||||
chain.stop();
|
||||
|
@ -28,10 +28,40 @@ public class LLMRouterNode extends RouterNode {
|
||||
private Llm llm;
|
||||
private String prompt;
|
||||
|
||||
public LLMRouterNode() {
|
||||
}
|
||||
|
||||
public LLMRouterNode(Llm llm, String prompt) {
|
||||
this.llm = llm;
|
||||
this.prompt = prompt;
|
||||
}
|
||||
|
||||
public LLMRouterNode(List<ChainNode> nodes, Llm llm, String prompt) {
|
||||
super(nodes);
|
||||
this.llm = llm;
|
||||
this.prompt = prompt;
|
||||
}
|
||||
|
||||
public LLMRouterNode(List<ChainNode> nodes) {
|
||||
super(nodes);
|
||||
}
|
||||
|
||||
public Llm getLlm() {
|
||||
return llm;
|
||||
}
|
||||
|
||||
public void setLlm(Llm llm) {
|
||||
this.llm = llm;
|
||||
}
|
||||
|
||||
public String getPrompt() {
|
||||
return prompt;
|
||||
}
|
||||
|
||||
public void setPrompt(String prompt) {
|
||||
this.prompt = prompt;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected String route(Chain chain) {
|
||||
SimplePromptTemplate promptTemplate = SimplePromptTemplate.create(prompt);
|
||||
|
@ -41,10 +41,17 @@ public abstract class RouterNode extends AbstractBaseNode {
|
||||
}
|
||||
|
||||
List<ChainNode> matchNodes = new ArrayList<>();
|
||||
String[] ids = routeKeys.split(",");
|
||||
for (String id : ids) {
|
||||
String[] idOrNames = routeKeys.split(",");
|
||||
for (String idOrName : idOrNames) {
|
||||
if (StringUtil.noText(idOrName)) {
|
||||
continue;
|
||||
} else {
|
||||
idOrName = idOrName.trim();
|
||||
}
|
||||
for (ChainNode node : this.nodes) {
|
||||
if (Objects.equals(id, String.valueOf(node.getId()))) {
|
||||
if (Objects.equals(idOrName, String.valueOf(node.getId()))
|
||||
|| Objects.equals(idOrName, node.getName())
|
||||
) {
|
||||
matchNodes.add(node);
|
||||
}
|
||||
}
|
||||
@ -57,30 +64,30 @@ public abstract class RouterNode extends AbstractBaseNode {
|
||||
return executeNode(chain, matchNodes.get(0));
|
||||
}
|
||||
|
||||
return onMatchMultiNodes(matchNodes,chain);
|
||||
return onMatchMultiNodes(matchNodes, chain);
|
||||
}
|
||||
|
||||
protected Map<String,Object> onMatchMultiNodes(List<ChainNode> nodes, Chain chain) {
|
||||
protected Map<String, Object> onMatchMultiNodes(List<ChainNode> nodes, Chain chain) {
|
||||
switch (this.multiMatchStrategy) {
|
||||
case ALL:
|
||||
return buildMultiResult(nodes, chain);
|
||||
case FIRST:
|
||||
return executeNode(chain,nodes.get(0));
|
||||
return executeNode(chain, nodes.get(0));
|
||||
case LAST:
|
||||
return executeNode(chain,nodes.get(nodes.size() - 1));
|
||||
return executeNode(chain, nodes.get(nodes.size() - 1));
|
||||
case RANDOM:
|
||||
return executeNode(chain,nodes.get(ThreadLocalRandom.current().nextInt(nodes.size())));
|
||||
return executeNode(chain, nodes.get(ThreadLocalRandom.current().nextInt(nodes.size())));
|
||||
default:
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
private Map<String,Object> buildMultiResult(List<ChainNode> nodes, Chain chain) {
|
||||
Map<String,Object> results = new HashMap<>();
|
||||
private Map<String, Object> buildMultiResult(List<ChainNode> nodes, Chain chain) {
|
||||
Map<String, Object> results = new HashMap<>();
|
||||
for (ChainNode matchNode : nodes) {
|
||||
Map<String, Object> result = executeNode(chain, matchNode);
|
||||
if (result != null){
|
||||
if (result != null) {
|
||||
results.putAll(result);
|
||||
}
|
||||
}
|
||||
@ -109,8 +116,8 @@ public abstract class RouterNode extends AbstractBaseNode {
|
||||
this.nodes = nodes;
|
||||
}
|
||||
|
||||
public void addNode(ChainNode node){
|
||||
if (nodes == null){
|
||||
public void addNode(ChainNode node) {
|
||||
if (nodes == null) {
|
||||
nodes = new ArrayList<>();
|
||||
}
|
||||
nodes.add(node);
|
||||
|
Loading…
Reference in New Issue
Block a user