mirror of
https://gitee.com/agents-flex/agents-flex.git
synced 2024-12-01 19:37:50 +08:00
refactor: refactor chain
This commit is contained in:
parent
9962002fdb
commit
76ddf203f8
@ -205,8 +205,8 @@ public class Chain extends ChainNode {
|
||||
List<InputParameter> inputParameters = new ArrayList<>();
|
||||
for (ChainNode node : startNodes) {
|
||||
if (node instanceof BaseNode) {
|
||||
List<InputParameter> nodeInputInputParameters = ((BaseNode) node).getInputInputParameters();
|
||||
if (nodeInputInputParameters != null) inputParameters.addAll(nodeInputInputParameters);
|
||||
List<InputParameter> nodeInputParameters = ((BaseNode) node).getInputParameters();
|
||||
if (nodeInputParameters != null) inputParameters.addAll(nodeInputParameters);
|
||||
} else if (node instanceof Chain) {
|
||||
List<InputParameter> chainInputParameters = ((Chain) node).getInputParameters();
|
||||
if (chainInputParameters != null) inputParameters.addAll(chainInputParameters);
|
||||
|
@ -29,7 +29,7 @@ import java.util.Map;
|
||||
public abstract class BaseNode extends ChainNode {
|
||||
|
||||
protected String description;
|
||||
protected List<InputParameter> inputInputParameters;
|
||||
protected List<InputParameter> inputParameters;
|
||||
protected List<OutputKey> outputKeys;
|
||||
|
||||
|
||||
@ -41,19 +41,19 @@ public abstract class BaseNode extends ChainNode {
|
||||
this.description = description;
|
||||
}
|
||||
|
||||
public List<InputParameter> getInputInputParameters() {
|
||||
return inputInputParameters;
|
||||
public List<InputParameter> getInputParameters() {
|
||||
return inputParameters;
|
||||
}
|
||||
|
||||
public void setInputInputParameters(List<InputParameter> inputInputParameters) {
|
||||
this.inputInputParameters = inputInputParameters;
|
||||
public void setInputParameters(List<InputParameter> inputParameters) {
|
||||
this.inputParameters = inputParameters;
|
||||
}
|
||||
|
||||
public void addInputParameter(InputParameter inputParameter) {
|
||||
if (inputInputParameters == null) {
|
||||
inputInputParameters = new java.util.ArrayList<>();
|
||||
if (inputParameters == null) {
|
||||
inputParameters = new java.util.ArrayList<>();
|
||||
}
|
||||
inputInputParameters.add(inputParameter);
|
||||
inputParameters.add(inputParameter);
|
||||
}
|
||||
|
||||
public List<OutputKey> getOutputKeys() {
|
||||
@ -75,24 +75,24 @@ public abstract class BaseNode extends ChainNode {
|
||||
public Map<String, Object> getParameters(Chain chain) {
|
||||
Map<String, Object> variables = new HashMap<>();
|
||||
|
||||
if (this.inputInputParameters != null) {
|
||||
for (InputParameter inputInputParameter : this.inputInputParameters) {
|
||||
RefType refType = inputInputParameter.getRefType();
|
||||
if (this.inputParameters != null) {
|
||||
for (InputParameter parameter : this.inputParameters) {
|
||||
RefType refType = parameter.getRefType();
|
||||
Object value = null;
|
||||
if (refType == RefType.INPUT) {
|
||||
value = inputInputParameter.getRef();
|
||||
value = parameter.getRef();
|
||||
} else if (refType == RefType.REF) {
|
||||
value = chain.get(inputInputParameter.getRef());
|
||||
value = chain.get(parameter.getRef());
|
||||
} else {
|
||||
value = chain.get(inputInputParameter.getName());
|
||||
value = chain.get(parameter.getName());
|
||||
}
|
||||
|
||||
if (inputInputParameter.isRequired() &&
|
||||
if (parameter.isRequired() &&
|
||||
(value == null || (value instanceof String && StringUtil.noText((String) value)))) {
|
||||
chain.stopError(this.getName() + " Missing required parameter:" + inputInputParameter.getName());
|
||||
chain.stopError(this.getName() + " Missing required parameter:" + parameter.getName());
|
||||
}
|
||||
|
||||
variables.put(inputInputParameter.getName(), value);
|
||||
variables.put(parameter.getName(), value);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -60,7 +60,7 @@ public class EndNode extends BaseNode {
|
||||
"normal=" + normal +
|
||||
", message='" + message + '\'' +
|
||||
", description='" + description + '\'' +
|
||||
", inputInputParameters=" + inputInputParameters +
|
||||
", inputParameters=" + inputParameters +
|
||||
", outputKeys=" + outputKeys +
|
||||
", id='" + id + '\'' +
|
||||
", name='" + name + '\'' +
|
||||
|
@ -15,9 +15,8 @@
|
||||
*/
|
||||
package com.agentsflex.core.chain.node;
|
||||
|
||||
import com.agentsflex.core.chain.InputParameter;
|
||||
import com.agentsflex.core.chain.OutputKey;
|
||||
import com.agentsflex.core.chain.Chain;
|
||||
import com.agentsflex.core.chain.OutputKey;
|
||||
import com.agentsflex.core.llm.ChatOptions;
|
||||
import com.agentsflex.core.llm.Llm;
|
||||
import com.agentsflex.core.llm.response.AiMessageResponse;
|
||||
@ -25,28 +24,30 @@ import com.agentsflex.core.message.AiMessage;
|
||||
import com.agentsflex.core.prompt.TextPrompt;
|
||||
import com.agentsflex.core.prompt.template.TextPromptTemplate;
|
||||
import com.agentsflex.core.util.Maps;
|
||||
import com.agentsflex.core.util.StringUtil;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
||||
public class LLMNode extends BaseNode {
|
||||
|
||||
protected Llm llm;
|
||||
protected ChatOptions chatOptions = ChatOptions.DEFAULT;
|
||||
protected String prompt;
|
||||
protected TextPromptTemplate promptTemplate;
|
||||
protected String userPrompt;
|
||||
protected TextPromptTemplate userPromptTemplate;
|
||||
|
||||
protected String systemPrompt;
|
||||
protected TextPromptTemplate systemPromptTemplate;
|
||||
|
||||
public LLMNode() {
|
||||
}
|
||||
|
||||
|
||||
public LLMNode(Llm llm, String prompt) {
|
||||
public LLMNode(Llm llm, String userPrompt) {
|
||||
this.llm = llm;
|
||||
this.prompt = prompt;
|
||||
this.promptTemplate = new TextPromptTemplate(prompt);
|
||||
this.initInputParameter();
|
||||
this.userPrompt = userPrompt;
|
||||
this.userPromptTemplate = new TextPromptTemplate(userPrompt);
|
||||
}
|
||||
|
||||
|
||||
@ -58,14 +59,22 @@ public class LLMNode extends BaseNode {
|
||||
this.llm = llm;
|
||||
}
|
||||
|
||||
public String getPrompt() {
|
||||
return prompt;
|
||||
public String getUserPrompt() {
|
||||
return userPrompt;
|
||||
}
|
||||
|
||||
public void setPrompt(String prompt) {
|
||||
this.prompt = prompt;
|
||||
this.promptTemplate = new TextPromptTemplate(prompt);
|
||||
this.initInputParameter();
|
||||
public void setUserPrompt(String userPrompt) {
|
||||
this.userPrompt = userPrompt;
|
||||
this.userPromptTemplate = new TextPromptTemplate(userPrompt);
|
||||
}
|
||||
|
||||
public String getSystemPrompt() {
|
||||
return systemPrompt;
|
||||
}
|
||||
|
||||
public void setSystemPrompt(String systemPrompt) {
|
||||
this.systemPrompt = systemPrompt;
|
||||
this.systemPromptTemplate = StringUtil.hasText(systemPrompt) ? new TextPromptTemplate(systemPrompt) : null;
|
||||
}
|
||||
|
||||
public ChatOptions getChatOptions() {
|
||||
@ -80,30 +89,21 @@ public class LLMNode extends BaseNode {
|
||||
}
|
||||
|
||||
|
||||
public void initInputParameter() {
|
||||
if (this.promptTemplate == null) {
|
||||
return;
|
||||
}
|
||||
|
||||
Set<String> keys = this.promptTemplate.getKeys();
|
||||
if (keys == null || keys.isEmpty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
List<InputParameter> inputParameters = new ArrayList<>(keys.size());
|
||||
for (String key : keys) {
|
||||
InputParameter inputParameter = new InputParameter(key, true);
|
||||
inputParameters.add(inputParameter);
|
||||
}
|
||||
this.inputInputParameters = inputParameters;
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
protected Map<String, Object> execute(Chain chain) {
|
||||
Map<String, Object> parameters = getParameters(chain);
|
||||
TextPrompt textPrompt = promptTemplate.format(parameters);
|
||||
AiMessageResponse response = llm.chat(textPrompt, chatOptions);
|
||||
|
||||
if (userPromptTemplate == null){
|
||||
chain.stopError("user prompt is null or empty");
|
||||
return Collections.emptyMap();
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
TextPrompt userPrompt = userPromptTemplate.format(parameters);
|
||||
AiMessageResponse response = llm.chat(userPrompt, chatOptions);
|
||||
|
||||
if (chain != null) {
|
||||
chain.output(this, response);
|
||||
@ -138,10 +138,10 @@ public class LLMNode extends BaseNode {
|
||||
return "LLMNode{" +
|
||||
"llm=" + llm +
|
||||
", chatOptions=" + chatOptions +
|
||||
", prompt='" + prompt + '\'' +
|
||||
", promptTemplate=" + promptTemplate +
|
||||
", prompt='" + userPrompt + '\'' +
|
||||
", promptTemplate=" + userPromptTemplate +
|
||||
", description='" + description + '\'' +
|
||||
", inputInputParameters=" + inputInputParameters +
|
||||
", inputParameters=" + inputParameters +
|
||||
", outputKeys=" + outputKeys +
|
||||
", id='" + id + '\'' +
|
||||
", name='" + name + '\'' +
|
||||
|
@ -23,7 +23,7 @@ import java.util.Map;
|
||||
public class StartNode extends BaseNode {
|
||||
@Override
|
||||
protected Map<String, Object> execute(Chain chain) {
|
||||
if (inputInputParameters != null) {
|
||||
if (inputParameters != null) {
|
||||
return getParameters(chain);
|
||||
}
|
||||
return Collections.emptyMap();
|
||||
@ -33,7 +33,7 @@ public class StartNode extends BaseNode {
|
||||
public String toString() {
|
||||
return "StartNode{" +
|
||||
"description='" + description + '\'' +
|
||||
", inputInputParameters=" + inputInputParameters +
|
||||
", inputParameters=" + inputParameters +
|
||||
", outputKeys=" + outputKeys +
|
||||
", id='" + id + '\'' +
|
||||
", name='" + name + '\'' +
|
||||
|
Loading…
Reference in New Issue
Block a user