mirror of
https://gitee.com/agents-flex/agents-flex.git
synced 2024-11-29 18:38:17 +08:00
refactor: refactor chain
This commit is contained in:
parent
ebafb3eae4
commit
b54b3ff1b8
@ -18,15 +18,16 @@ package com.agentsflex.core.chain;
|
||||
import com.agentsflex.core.chain.event.*;
|
||||
import com.agentsflex.core.chain.node.BaseNode;
|
||||
import com.agentsflex.core.util.CollectionUtil;
|
||||
import com.agentsflex.core.util.MapUtil;
|
||||
import com.agentsflex.core.util.NamedThreadPools;
|
||||
import com.agentsflex.core.util.StringUtil;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
|
||||
|
||||
public class Chain extends ChainNode {
|
||||
public static final String CTX_EXEC_COUNT = "_exec_count";
|
||||
protected Map<Class<?>, List<ChainEventListener>> eventListeners = new HashMap<>(0);
|
||||
protected Map<String, Object> executeResult = null;
|
||||
protected List<ChainOutputListener> outputListeners = new ArrayList<>();
|
||||
@ -37,6 +38,7 @@ public class Chain extends ChainNode {
|
||||
protected Chain parent;
|
||||
protected List<Chain> children;
|
||||
protected ExecutorService asyncNodeExecutors = NamedThreadPools.newFixedThreadPool("chain-executor");
|
||||
protected Map<String, NodeContext> nodeContexts = new ConcurrentHashMap<>();
|
||||
|
||||
|
||||
public Chain() {
|
||||
@ -92,7 +94,6 @@ public class Chain extends ChainNode {
|
||||
this.outputListeners.add(outputListener);
|
||||
}
|
||||
|
||||
|
||||
public List<ChainNode> getNodes() {
|
||||
return nodes;
|
||||
}
|
||||
@ -215,16 +216,30 @@ public class Chain extends ChainNode {
|
||||
return inputParameters;
|
||||
}
|
||||
|
||||
public NodeContext getNodeContext(String nodeId) {
|
||||
return MapUtil.computeIfAbsent(nodeContexts, nodeId, k -> new NodeContext());
|
||||
}
|
||||
|
||||
protected void executeInternal() {
|
||||
List<ChainNode> currentNodes = getStartNodes();
|
||||
while (CollectionUtil.hasItems(currentNodes)) {
|
||||
ChainNode currentNode = currentNodes.remove(0);
|
||||
if (currentNodes == null || currentNodes.isEmpty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
Integer execCount = (Integer) currentNode.getMemory().get(CTX_EXEC_COUNT);
|
||||
if (execCount == null) execCount = 0;
|
||||
List<ExecuteNode> waitingExecuteNodes = new ArrayList<>();
|
||||
for (ChainNode currentNode : currentNodes) {
|
||||
waitingExecuteNodes.add(new ExecuteNode(currentNode, null, ""));
|
||||
}
|
||||
|
||||
while (CollectionUtil.hasItems(waitingExecuteNodes)) {
|
||||
ExecuteNode executeNode = waitingExecuteNodes.remove(0);
|
||||
ChainNode currentNode = executeNode.currentNode;
|
||||
|
||||
NodeContext nodeContext = getNodeContext(currentNode.getId());
|
||||
nodeContext.recordTrigger(executeNode);
|
||||
|
||||
NodeCondition nodeCondition = currentNode.getCondition();
|
||||
if (nodeCondition != null && !nodeCondition.check(this, currentNode)) {
|
||||
if (nodeCondition != null && !nodeCondition.check(this, nodeContext)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
@ -235,11 +250,11 @@ public class Chain extends ChainNode {
|
||||
if (this.getStatus() != ChainStatus.RUNNING) {
|
||||
break;
|
||||
}
|
||||
nodeContext.recordExecute(executeNode);
|
||||
executeResult = executeNode(currentNode);
|
||||
this.executeResult = executeResult;
|
||||
} finally {
|
||||
ChainContext.clearNode();
|
||||
currentNode.getMemory().put(CTX_EXEC_COUNT, execCount + 1);
|
||||
notifyEvent(new OnNodeFinishedEvent(currentNode, executeResult));
|
||||
}
|
||||
|
||||
@ -263,9 +278,9 @@ public class Chain extends ChainNode {
|
||||
}
|
||||
EdgeCondition condition = chainEdge.getCondition();
|
||||
if (condition == null) {
|
||||
currentNodes.add(nextNode);
|
||||
waitingExecuteNodes.add(new ExecuteNode(nextNode, currentNode, chainEdge.getId()));
|
||||
} else if (condition.check(this, chainEdge)) {
|
||||
currentNodes.add(nextNode);
|
||||
waitingExecuteNodes.add(new ExecuteNode(nextNode, currentNode, chainEdge.getId()));
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -410,6 +425,19 @@ public class Chain extends ChainNode {
|
||||
this.asyncNodeExecutors = asyncNodeExecutors;
|
||||
}
|
||||
|
||||
public static class ExecuteNode {
|
||||
|
||||
final ChainNode currentNode;
|
||||
final ChainNode prevNode;
|
||||
final String fromEdgeId;
|
||||
|
||||
public ExecuteNode(ChainNode currentNode, ChainNode prevNode, String fromEdgeId) {
|
||||
this.currentNode = currentNode;
|
||||
this.prevNode = prevNode;
|
||||
this.fromEdgeId = fromEdgeId;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "Chain{" +
|
||||
|
@ -16,12 +16,21 @@
|
||||
package com.agentsflex.core.chain;
|
||||
|
||||
public class ChainEdge {
|
||||
private String id;
|
||||
private String source;
|
||||
private String target;
|
||||
private EdgeCondition condition;
|
||||
private int weight;
|
||||
private boolean isDefault;
|
||||
|
||||
public String getId() {
|
||||
return id;
|
||||
}
|
||||
|
||||
public void setId(String id) {
|
||||
this.id = id;
|
||||
}
|
||||
|
||||
public String getSource() {
|
||||
return source;
|
||||
}
|
||||
|
@ -17,6 +17,6 @@ package com.agentsflex.core.chain;
|
||||
|
||||
public interface NodeCondition {
|
||||
|
||||
boolean check(Chain chain, ChainNode node);
|
||||
boolean check(Chain chain, NodeContext context);
|
||||
|
||||
}
|
||||
|
@ -0,0 +1,76 @@
|
||||
/*
|
||||
* Copyright (c) 2023-2025, Agents-Flex (fuhai999@gmail.com).
|
||||
* <p>
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
* <p>
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* <p>
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package com.agentsflex.core.chain;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
|
||||
public class NodeContext {
|
||||
|
||||
public ChainNode currentNode;
|
||||
public ChainNode prevNode;
|
||||
public String fromEdgeId;
|
||||
|
||||
private AtomicInteger triggerCount = new AtomicInteger(0);
|
||||
private List<String> triggerEdgeIds = new ArrayList<>();
|
||||
|
||||
private AtomicInteger executeCount = new AtomicInteger(0);
|
||||
private List<String> executeEdgeIds = new ArrayList<>();
|
||||
|
||||
|
||||
public ChainNode getCurrentNode() {
|
||||
return currentNode;
|
||||
}
|
||||
|
||||
public ChainNode getPrevNode() {
|
||||
return prevNode;
|
||||
}
|
||||
|
||||
public String getFromEdgeId() {
|
||||
return fromEdgeId;
|
||||
}
|
||||
|
||||
public int getTriggerCount() {
|
||||
return triggerCount.get();
|
||||
}
|
||||
|
||||
public List<String> getTriggerEdgeIds() {
|
||||
return triggerEdgeIds;
|
||||
}
|
||||
|
||||
public int getExecuteCount() {
|
||||
return executeCount.get();
|
||||
}
|
||||
|
||||
public List<String> getExecuteEdgeIds() {
|
||||
return executeEdgeIds;
|
||||
}
|
||||
|
||||
public synchronized void recordTrigger(Chain.ExecuteNode executeNode) {
|
||||
this.currentNode = executeNode.currentNode;
|
||||
this.prevNode = executeNode.prevNode;
|
||||
this.fromEdgeId = executeNode.fromEdgeId;
|
||||
|
||||
triggerCount.incrementAndGet();
|
||||
triggerEdgeIds.add(executeNode.fromEdgeId);
|
||||
}
|
||||
|
||||
public void recordExecute(Chain.ExecuteNode executeNode) {
|
||||
executeCount.incrementAndGet();
|
||||
executeEdgeIds.add(executeNode.fromEdgeId);
|
||||
}
|
||||
}
|
@ -0,0 +1,70 @@
|
||||
/*
|
||||
* Copyright (c) 2023-2025, Agents-Flex (fuhai999@gmail.com).
|
||||
* <p>
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
* <p>
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* <p>
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package com.agentsflex.core.util;
|
||||
|
||||
import java.util.Map;
|
||||
import java.util.function.Function;
|
||||
|
||||
public class MapUtil {
|
||||
private static final boolean IS_JDK8 = (8 == getJvmVersion0());
|
||||
|
||||
private MapUtil() {
|
||||
}
|
||||
|
||||
private static String tryTrim(String string) {
|
||||
return string != null ? string.trim() : "";
|
||||
}
|
||||
|
||||
private static int getJvmVersion0() {
|
||||
int jvmVersion = -1;
|
||||
try {
|
||||
String javaSpecVer = tryTrim(System.getProperty("java.specification.version"));
|
||||
if (StringUtil.hasText(javaSpecVer)) {
|
||||
if (javaSpecVer.startsWith("1.")) {
|
||||
javaSpecVer = javaSpecVer.substring(2);
|
||||
}
|
||||
if (javaSpecVer.indexOf('.') == -1) {
|
||||
jvmVersion = Integer.parseInt(javaSpecVer);
|
||||
}
|
||||
}
|
||||
} catch (Throwable ignore) {
|
||||
// ignore
|
||||
}
|
||||
// default is jdk8
|
||||
if (jvmVersion == -1) {
|
||||
jvmVersion = 8;
|
||||
}
|
||||
return jvmVersion;
|
||||
}
|
||||
|
||||
/**
|
||||
* A temporary workaround for Java 8 specific performance issue JDK-8161372 .<br>
|
||||
* This class should be removed once we drop Java 8 support.
|
||||
*
|
||||
* @see <a href=
|
||||
* "https://bugs.openjdk.java.net/browse/JDK-8161372">https://bugs.openjdk.java.net/browse/JDK-8161372</a>
|
||||
*/
|
||||
public static <K, V> V computeIfAbsent(Map<K, V> map, K key, Function<K, V> mappingFunction) {
|
||||
if (IS_JDK8) {
|
||||
V value = map.get(key);
|
||||
if (value != null) {
|
||||
return value;
|
||||
}
|
||||
}
|
||||
return map.computeIfAbsent(key, mappingFunction);
|
||||
}
|
||||
|
||||
}
|
Loading…
Reference in New Issue
Block a user