refactor: refactor chain

This commit is contained in:
Michael Yang 2024-10-30 12:45:50 +08:00
parent ebafb3eae4
commit b54b3ff1b8
5 changed files with 194 additions and 11 deletions

View File

@ -18,15 +18,16 @@ package com.agentsflex.core.chain;
import com.agentsflex.core.chain.event.*;
import com.agentsflex.core.chain.node.BaseNode;
import com.agentsflex.core.util.CollectionUtil;
import com.agentsflex.core.util.MapUtil;
import com.agentsflex.core.util.NamedThreadPools;
import com.agentsflex.core.util.StringUtil;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
public class Chain extends ChainNode {
public static final String CTX_EXEC_COUNT = "_exec_count";
protected Map<Class<?>, List<ChainEventListener>> eventListeners = new HashMap<>(0);
protected Map<String, Object> executeResult = null;
protected List<ChainOutputListener> outputListeners = new ArrayList<>();
@ -37,6 +38,7 @@ public class Chain extends ChainNode {
protected Chain parent;
protected List<Chain> children;
protected ExecutorService asyncNodeExecutors = NamedThreadPools.newFixedThreadPool("chain-executor");
protected Map<String, NodeContext> nodeContexts = new ConcurrentHashMap<>();
public Chain() {
@ -92,7 +94,6 @@ public class Chain extends ChainNode {
this.outputListeners.add(outputListener);
}
public List<ChainNode> getNodes() {
return nodes;
}
@ -215,16 +216,30 @@ public class Chain extends ChainNode {
return inputParameters;
}
public NodeContext getNodeContext(String nodeId) {
return MapUtil.computeIfAbsent(nodeContexts, nodeId, k -> new NodeContext());
}
protected void executeInternal() {
List<ChainNode> currentNodes = getStartNodes();
while (CollectionUtil.hasItems(currentNodes)) {
ChainNode currentNode = currentNodes.remove(0);
if (currentNodes == null || currentNodes.isEmpty()) {
return;
}
Integer execCount = (Integer) currentNode.getMemory().get(CTX_EXEC_COUNT);
if (execCount == null) execCount = 0;
List<ExecuteNode> waitingExecuteNodes = new ArrayList<>();
for (ChainNode currentNode : currentNodes) {
waitingExecuteNodes.add(new ExecuteNode(currentNode, null, ""));
}
while (CollectionUtil.hasItems(waitingExecuteNodes)) {
ExecuteNode executeNode = waitingExecuteNodes.remove(0);
ChainNode currentNode = executeNode.currentNode;
NodeContext nodeContext = getNodeContext(currentNode.getId());
nodeContext.recordTrigger(executeNode);
NodeCondition nodeCondition = currentNode.getCondition();
if (nodeCondition != null && !nodeCondition.check(this, currentNode)) {
if (nodeCondition != null && !nodeCondition.check(this, nodeContext)) {
continue;
}
@ -235,11 +250,11 @@ public class Chain extends ChainNode {
if (this.getStatus() != ChainStatus.RUNNING) {
break;
}
nodeContext.recordExecute(executeNode);
executeResult = executeNode(currentNode);
this.executeResult = executeResult;
} finally {
ChainContext.clearNode();
currentNode.getMemory().put(CTX_EXEC_COUNT, execCount + 1);
notifyEvent(new OnNodeFinishedEvent(currentNode, executeResult));
}
@ -263,9 +278,9 @@ public class Chain extends ChainNode {
}
EdgeCondition condition = chainEdge.getCondition();
if (condition == null) {
currentNodes.add(nextNode);
waitingExecuteNodes.add(new ExecuteNode(nextNode, currentNode, chainEdge.getId()));
} else if (condition.check(this, chainEdge)) {
currentNodes.add(nextNode);
waitingExecuteNodes.add(new ExecuteNode(nextNode, currentNode, chainEdge.getId()));
}
}
}
@ -410,6 +425,19 @@ public class Chain extends ChainNode {
this.asyncNodeExecutors = asyncNodeExecutors;
}
public static class ExecuteNode {
final ChainNode currentNode;
final ChainNode prevNode;
final String fromEdgeId;
public ExecuteNode(ChainNode currentNode, ChainNode prevNode, String fromEdgeId) {
this.currentNode = currentNode;
this.prevNode = prevNode;
this.fromEdgeId = fromEdgeId;
}
}
@Override
public String toString() {
return "Chain{" +

View File

@ -16,12 +16,21 @@
package com.agentsflex.core.chain;
public class ChainEdge {
private String id;
private String source;
private String target;
private EdgeCondition condition;
private int weight;
private boolean isDefault;
public String getId() {
return id;
}
public void setId(String id) {
this.id = id;
}
public String getSource() {
return source;
}

View File

@ -17,6 +17,6 @@ package com.agentsflex.core.chain;
public interface NodeCondition {
boolean check(Chain chain, ChainNode node);
boolean check(Chain chain, NodeContext context);
}

View File

@ -0,0 +1,76 @@
/*
* Copyright (c) 2023-2025, Agents-Flex (fuhai999@gmail.com).
* <p>
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* <p>
* http://www.apache.org/licenses/LICENSE-2.0
* <p>
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.agentsflex.core.chain;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
public class NodeContext {
public ChainNode currentNode;
public ChainNode prevNode;
public String fromEdgeId;
private AtomicInteger triggerCount = new AtomicInteger(0);
private List<String> triggerEdgeIds = new ArrayList<>();
private AtomicInteger executeCount = new AtomicInteger(0);
private List<String> executeEdgeIds = new ArrayList<>();
public ChainNode getCurrentNode() {
return currentNode;
}
public ChainNode getPrevNode() {
return prevNode;
}
public String getFromEdgeId() {
return fromEdgeId;
}
public int getTriggerCount() {
return triggerCount.get();
}
public List<String> getTriggerEdgeIds() {
return triggerEdgeIds;
}
public int getExecuteCount() {
return executeCount.get();
}
public List<String> getExecuteEdgeIds() {
return executeEdgeIds;
}
public synchronized void recordTrigger(Chain.ExecuteNode executeNode) {
this.currentNode = executeNode.currentNode;
this.prevNode = executeNode.prevNode;
this.fromEdgeId = executeNode.fromEdgeId;
triggerCount.incrementAndGet();
triggerEdgeIds.add(executeNode.fromEdgeId);
}
public void recordExecute(Chain.ExecuteNode executeNode) {
executeCount.incrementAndGet();
executeEdgeIds.add(executeNode.fromEdgeId);
}
}

View File

@ -0,0 +1,70 @@
/*
* Copyright (c) 2023-2025, Agents-Flex (fuhai999@gmail.com).
* <p>
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* <p>
* http://www.apache.org/licenses/LICENSE-2.0
* <p>
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.agentsflex.core.util;
import java.util.Map;
import java.util.function.Function;
public class MapUtil {
private static final boolean IS_JDK8 = (8 == getJvmVersion0());
private MapUtil() {
}
private static String tryTrim(String string) {
return string != null ? string.trim() : "";
}
private static int getJvmVersion0() {
int jvmVersion = -1;
try {
String javaSpecVer = tryTrim(System.getProperty("java.specification.version"));
if (StringUtil.hasText(javaSpecVer)) {
if (javaSpecVer.startsWith("1.")) {
javaSpecVer = javaSpecVer.substring(2);
}
if (javaSpecVer.indexOf('.') == -1) {
jvmVersion = Integer.parseInt(javaSpecVer);
}
}
} catch (Throwable ignore) {
// ignore
}
// default is jdk8
if (jvmVersion == -1) {
jvmVersion = 8;
}
return jvmVersion;
}
/**
* A temporary workaround for Java 8 specific performance issue JDK-8161372 .<br>
* This class should be removed once we drop Java 8 support.
*
* @see <a href=
* "https://bugs.openjdk.java.net/browse/JDK-8161372">https://bugs.openjdk.java.net/browse/JDK-8161372</a>
*/
public static <K, V> V computeIfAbsent(Map<K, V> map, K key, Function<K, V> mappingFunction) {
if (IS_JDK8) {
V value = map.get(key);
if (value != null) {
return value;
}
}
return map.computeIfAbsent(key, mappingFunction);
}
}