refactor: optimize chain nodes

This commit is contained in:
Michael Yang 2024-05-11 16:08:59 +08:00
parent 362bdbdaa5
commit 00712f1c7c
9 changed files with 125 additions and 17 deletions

View File

@ -66,12 +66,12 @@
<dependency>
<groupId>com.agentsflex</groupId>
<artifactId>agents-flex-chain-qlexpress</artifactId>
<version>1.0.0-beta.2</version>
<version>${agents-flex.version}</version>
</dependency>
<dependency>
<groupId>com.agentsflex</groupId>
<artifactId>agents-flex-chain-groovy</artifactId>
<version>1.0.0-beta.2</version>
<version>${agents-flex.version}</version>
</dependency>
<!--chains end-->

View File

@ -16,15 +16,36 @@
package com.agentsflex.chain.node;
import com.agentsflex.chain.Chain;
import com.agentsflex.chain.ChainNode;
import groovy.lang.Binding;
import groovy.lang.GroovyShell;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
public class GroovyRouterNode extends RouterNode {
private String express;
public GroovyRouterNode() {
}
public GroovyRouterNode(String express) {
this.express = express;
}
public GroovyRouterNode(String express, ChainNode... nodes) {
this.express = express;
this.setNodes(Arrays.asList(nodes));
}
public GroovyRouterNode(List<ChainNode> nodes, String express) {
super(nodes);
this.express = express;
}
@Override
protected String route(Chain chain) {
Binding binding = new Binding();

View File

@ -16,13 +16,34 @@
package com.agentsflex.chain.node;
import com.agentsflex.chain.Chain;
import com.agentsflex.chain.ChainNode;
import com.ql.util.express.DefaultContext;
import com.ql.util.express.ExpressRunner;
public class QLExpressRouterNode extends RouterNode {
import java.util.Arrays;
import java.util.List;
public class QLExpressRouterNode extends RouterNode {
private String express;
public QLExpressRouterNode() {
}
public QLExpressRouterNode(String express) {
this.express = express;
}
public QLExpressRouterNode(String express, ChainNode... nodes) {
this.express = express;
this.setNodes(Arrays.asList(nodes));
}
public QLExpressRouterNode(List<ChainNode> nodes, String express) {
super(nodes);
this.express = express;
}
@Override
protected String route(Chain chain) {
ExpressRunner runner = new ExpressRunner();

View File

@ -21,6 +21,10 @@ public interface ChainNode {
Object getId();
default String getName() {
return null;
}
boolean isSkip();
Map<String, Object> execute(Chain chain);

View File

@ -22,6 +22,7 @@ import java.util.UUID;
public abstract class AbstractBaseNode implements ChainNode {
protected Object id;
protected String name;
protected boolean skip;
public AbstractBaseNode() {
@ -37,6 +38,19 @@ public abstract class AbstractBaseNode implements ChainNode {
this.id = id;
}
@Override
public String getName() {
return name;
}
public void setName(String name) {
this.name = name;
}
public void setSkip(boolean skip) {
this.skip = skip;
}
@Override
public boolean isSkip() {
return skip;

View File

@ -26,7 +26,10 @@ import java.util.List;
import java.util.Map;
public class AgentNode extends AbstractBaseNode {
private final Agent agent;
private Agent agent;
public AgentNode() {
}
public AgentNode(Agent agent) {
this.agent = agent;
@ -41,6 +44,10 @@ public class AgentNode extends AbstractBaseNode {
return agent;
}
public void setAgent(Agent agent) {
this.agent = agent;
}
@Override
public Map<String, Object> execute(Chain chain) {
List<Parameter> inputParameters = agent.getInputParameters();

View File

@ -21,6 +21,10 @@ import java.util.Map;
public class EndNode extends AbstractBaseNode {
public EndNode() {
this.name = "end";
}
@Override
public Map<String, Object> execute(Chain chain) {
chain.stop();

View File

@ -28,10 +28,40 @@ public class LLMRouterNode extends RouterNode {
private Llm llm;
private String prompt;
public LLMRouterNode() {
}
public LLMRouterNode(Llm llm, String prompt) {
this.llm = llm;
this.prompt = prompt;
}
public LLMRouterNode(List<ChainNode> nodes, Llm llm, String prompt) {
super(nodes);
this.llm = llm;
this.prompt = prompt;
}
public LLMRouterNode(List<ChainNode> nodes) {
super(nodes);
}
public Llm getLlm() {
return llm;
}
public void setLlm(Llm llm) {
this.llm = llm;
}
public String getPrompt() {
return prompt;
}
public void setPrompt(String prompt) {
this.prompt = prompt;
}
@Override
protected String route(Chain chain) {
SimplePromptTemplate promptTemplate = SimplePromptTemplate.create(prompt);

View File

@ -41,10 +41,17 @@ public abstract class RouterNode extends AbstractBaseNode {
}
List<ChainNode> matchNodes = new ArrayList<>();
String[] ids = routeKeys.split(",");
for (String id : ids) {
String[] idOrNames = routeKeys.split(",");
for (String idOrName : idOrNames) {
if (StringUtil.noText(idOrName)) {
continue;
} else {
idOrName = idOrName.trim();
}
for (ChainNode node : this.nodes) {
if (Objects.equals(id, String.valueOf(node.getId()))) {
if (Objects.equals(idOrName, String.valueOf(node.getId()))
|| Objects.equals(idOrName, node.getName())
) {
matchNodes.add(node);
}
}
@ -57,30 +64,30 @@ public abstract class RouterNode extends AbstractBaseNode {
return executeNode(chain, matchNodes.get(0));
}
return onMatchMultiNodes(matchNodes,chain);
return onMatchMultiNodes(matchNodes, chain);
}
protected Map<String,Object> onMatchMultiNodes(List<ChainNode> nodes, Chain chain) {
protected Map<String, Object> onMatchMultiNodes(List<ChainNode> nodes, Chain chain) {
switch (this.multiMatchStrategy) {
case ALL:
return buildMultiResult(nodes, chain);
case FIRST:
return executeNode(chain,nodes.get(0));
return executeNode(chain, nodes.get(0));
case LAST:
return executeNode(chain,nodes.get(nodes.size() - 1));
return executeNode(chain, nodes.get(nodes.size() - 1));
case RANDOM:
return executeNode(chain,nodes.get(ThreadLocalRandom.current().nextInt(nodes.size())));
return executeNode(chain, nodes.get(ThreadLocalRandom.current().nextInt(nodes.size())));
default:
return null;
}
}
private Map<String,Object> buildMultiResult(List<ChainNode> nodes, Chain chain) {
Map<String,Object> results = new HashMap<>();
private Map<String, Object> buildMultiResult(List<ChainNode> nodes, Chain chain) {
Map<String, Object> results = new HashMap<>();
for (ChainNode matchNode : nodes) {
Map<String, Object> result = executeNode(chain, matchNode);
if (result != null){
if (result != null) {
results.putAll(result);
}
}
@ -109,8 +116,8 @@ public abstract class RouterNode extends AbstractBaseNode {
this.nodes = nodes;
}
public void addNode(ChainNode node){
if (nodes == null){
public void addNode(ChainNode node) {
if (nodes == null) {
nodes = new ArrayList<>();
}
nodes.add(node);