refactor: refactor chain

This commit is contained in:
Michael Yang 2024-10-30 09:50:19 +08:00
parent 9962002fdb
commit 76ddf203f8
5 changed files with 62 additions and 62 deletions

View File

@ -205,8 +205,8 @@ public class Chain extends ChainNode {
List<InputParameter> inputParameters = new ArrayList<>();
for (ChainNode node : startNodes) {
if (node instanceof BaseNode) {
List<InputParameter> nodeInputInputParameters = ((BaseNode) node).getInputInputParameters();
if (nodeInputInputParameters != null) inputParameters.addAll(nodeInputInputParameters);
List<InputParameter> nodeInputParameters = ((BaseNode) node).getInputParameters();
if (nodeInputParameters != null) inputParameters.addAll(nodeInputParameters);
} else if (node instanceof Chain) {
List<InputParameter> chainInputParameters = ((Chain) node).getInputParameters();
if (chainInputParameters != null) inputParameters.addAll(chainInputParameters);

View File

@ -29,7 +29,7 @@ import java.util.Map;
public abstract class BaseNode extends ChainNode {
protected String description;
protected List<InputParameter> inputInputParameters;
protected List<InputParameter> inputParameters;
protected List<OutputKey> outputKeys;
@ -41,19 +41,19 @@ public abstract class BaseNode extends ChainNode {
this.description = description;
}
public List<InputParameter> getInputInputParameters() {
return inputInputParameters;
public List<InputParameter> getInputParameters() {
return inputParameters;
}
public void setInputInputParameters(List<InputParameter> inputInputParameters) {
this.inputInputParameters = inputInputParameters;
public void setInputParameters(List<InputParameter> inputParameters) {
this.inputParameters = inputParameters;
}
public void addInputParameter(InputParameter inputParameter) {
if (inputInputParameters == null) {
inputInputParameters = new java.util.ArrayList<>();
if (inputParameters == null) {
inputParameters = new java.util.ArrayList<>();
}
inputInputParameters.add(inputParameter);
inputParameters.add(inputParameter);
}
public List<OutputKey> getOutputKeys() {
@ -75,24 +75,24 @@ public abstract class BaseNode extends ChainNode {
public Map<String, Object> getParameters(Chain chain) {
Map<String, Object> variables = new HashMap<>();
if (this.inputInputParameters != null) {
for (InputParameter inputInputParameter : this.inputInputParameters) {
RefType refType = inputInputParameter.getRefType();
if (this.inputParameters != null) {
for (InputParameter parameter : this.inputParameters) {
RefType refType = parameter.getRefType();
Object value = null;
if (refType == RefType.INPUT) {
value = inputInputParameter.getRef();
value = parameter.getRef();
} else if (refType == RefType.REF) {
value = chain.get(inputInputParameter.getRef());
value = chain.get(parameter.getRef());
} else {
value = chain.get(inputInputParameter.getName());
value = chain.get(parameter.getName());
}
if (inputInputParameter.isRequired() &&
if (parameter.isRequired() &&
(value == null || (value instanceof String && StringUtil.noText((String) value)))) {
chain.stopError(this.getName() + " Missing required parameter:" + inputInputParameter.getName());
chain.stopError(this.getName() + " Missing required parameter:" + parameter.getName());
}
variables.put(inputInputParameter.getName(), value);
variables.put(parameter.getName(), value);
}
}

View File

@ -60,7 +60,7 @@ public class EndNode extends BaseNode {
"normal=" + normal +
", message='" + message + '\'' +
", description='" + description + '\'' +
", inputInputParameters=" + inputInputParameters +
", inputParameters=" + inputParameters +
", outputKeys=" + outputKeys +
", id='" + id + '\'' +
", name='" + name + '\'' +

View File

@ -15,9 +15,8 @@
*/
package com.agentsflex.core.chain.node;
import com.agentsflex.core.chain.InputParameter;
import com.agentsflex.core.chain.OutputKey;
import com.agentsflex.core.chain.Chain;
import com.agentsflex.core.chain.OutputKey;
import com.agentsflex.core.llm.ChatOptions;
import com.agentsflex.core.llm.Llm;
import com.agentsflex.core.llm.response.AiMessageResponse;
@ -25,28 +24,30 @@ import com.agentsflex.core.message.AiMessage;
import com.agentsflex.core.prompt.TextPrompt;
import com.agentsflex.core.prompt.template.TextPromptTemplate;
import com.agentsflex.core.util.Maps;
import com.agentsflex.core.util.StringUtil;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
public class LLMNode extends BaseNode {
protected Llm llm;
protected ChatOptions chatOptions = ChatOptions.DEFAULT;
protected String prompt;
protected TextPromptTemplate promptTemplate;
protected String userPrompt;
protected TextPromptTemplate userPromptTemplate;
protected String systemPrompt;
protected TextPromptTemplate systemPromptTemplate;
public LLMNode() {
}
public LLMNode(Llm llm, String prompt) {
public LLMNode(Llm llm, String userPrompt) {
this.llm = llm;
this.prompt = prompt;
this.promptTemplate = new TextPromptTemplate(prompt);
this.initInputParameter();
this.userPrompt = userPrompt;
this.userPromptTemplate = new TextPromptTemplate(userPrompt);
}
@ -58,14 +59,22 @@ public class LLMNode extends BaseNode {
this.llm = llm;
}
public String getPrompt() {
return prompt;
public String getUserPrompt() {
return userPrompt;
}
public void setPrompt(String prompt) {
this.prompt = prompt;
this.promptTemplate = new TextPromptTemplate(prompt);
this.initInputParameter();
public void setUserPrompt(String userPrompt) {
this.userPrompt = userPrompt;
this.userPromptTemplate = new TextPromptTemplate(userPrompt);
}
public String getSystemPrompt() {
return systemPrompt;
}
public void setSystemPrompt(String systemPrompt) {
this.systemPrompt = systemPrompt;
this.systemPromptTemplate = StringUtil.hasText(systemPrompt) ? new TextPromptTemplate(systemPrompt) : null;
}
public ChatOptions getChatOptions() {
@ -80,30 +89,21 @@ public class LLMNode extends BaseNode {
}
public void initInputParameter() {
if (this.promptTemplate == null) {
return;
}
Set<String> keys = this.promptTemplate.getKeys();
if (keys == null || keys.isEmpty()) {
return;
}
List<InputParameter> inputParameters = new ArrayList<>(keys.size());
for (String key : keys) {
InputParameter inputParameter = new InputParameter(key, true);
inputParameters.add(inputParameter);
}
this.inputInputParameters = inputParameters;
}
@Override
protected Map<String, Object> execute(Chain chain) {
Map<String, Object> parameters = getParameters(chain);
TextPrompt textPrompt = promptTemplate.format(parameters);
AiMessageResponse response = llm.chat(textPrompt, chatOptions);
if (userPromptTemplate == null){
chain.stopError("user prompt is null or empty");
return Collections.emptyMap();
}
TextPrompt userPrompt = userPromptTemplate.format(parameters);
AiMessageResponse response = llm.chat(userPrompt, chatOptions);
if (chain != null) {
chain.output(this, response);
@ -138,10 +138,10 @@ public class LLMNode extends BaseNode {
return "LLMNode{" +
"llm=" + llm +
", chatOptions=" + chatOptions +
", prompt='" + prompt + '\'' +
", promptTemplate=" + promptTemplate +
", prompt='" + userPrompt + '\'' +
", promptTemplate=" + userPromptTemplate +
", description='" + description + '\'' +
", inputInputParameters=" + inputInputParameters +
", inputParameters=" + inputParameters +
", outputKeys=" + outputKeys +
", id='" + id + '\'' +
", name='" + name + '\'' +

View File

@ -23,7 +23,7 @@ import java.util.Map;
public class StartNode extends BaseNode {
@Override
protected Map<String, Object> execute(Chain chain) {
if (inputInputParameters != null) {
if (inputParameters != null) {
return getParameters(chain);
}
return Collections.emptyMap();
@ -33,7 +33,7 @@ public class StartNode extends BaseNode {
public String toString() {
return "StartNode{" +
"description='" + description + '\'' +
", inputInputParameters=" + inputInputParameters +
", inputParameters=" + inputParameters +
", outputKeys=" + outputKeys +
", id='" + id + '\'' +
", name='" + name + '\'' +