refactor: refactor Agent.java

This commit is contained in:
Michael Yang 2024-06-21 16:43:32 +08:00
parent df2691d4ac
commit f3d29ba2eb
7 changed files with 104 additions and 13 deletions

View File

@ -28,7 +28,8 @@ public abstract class Agent {
protected String name;
protected String description;
private ContextMemory memory;
private List<String> outputKeys;
private List<Parameter> inputParameters;
private List<OutputKey> outputKeys;
public Agent() {
this.id = UUID.randomUUID().toString();
@ -76,16 +77,21 @@ public abstract class Agent {
}
public List<Parameter> getInputParameters() {
List<Parameter> parameters = defineInputParameter();
return parameters == null ? Collections.emptyList() : parameters;
if (this.inputParameters == null) {
this.inputParameters = defineInputParameter();
if (this.inputParameters == null) {
this.inputParameters = Collections.emptyList();
}
}
return this.inputParameters;
}
public List<String> getOutputKeys() {
public List<OutputKey> getOutputKeys() {
return outputKeys;
}
public void setOutputKeys(List<String> outputKeys) {
public void setOutputKeys(List<OutputKey> outputKeys) {
this.outputKeys = outputKeys;
}
@ -94,7 +100,10 @@ public abstract class Agent {
this.outputKeys = new ArrayList<>();
}
this.outputKeys.addAll(Arrays.asList(keys));
for (String key : keys) {
this.outputKeys.add(new OutputKey(key));
}
return this;
}
@ -102,7 +111,9 @@ public abstract class Agent {
return execute(variables, null);
}
protected abstract List<Parameter> defineInputParameter();
protected List<Parameter> defineInputParameter() {
return Collections.emptyList();
}
public abstract Output execute(Map<String, Object> variables, Chain chain);

View File

@ -53,7 +53,7 @@ public abstract class DefaultAgent extends Agent {
String key = variables.keySet().iterator().next();
value = variables.get(key);
}
List<String> outputKeys = getOutputKeys();
List<OutputKey> outputKeys = getOutputKeys();
if (outputKeys != null && outputKeys.size() == 1) {
return Output.of(outputKeys.get(0), execute(value, chain));
}

View File

@ -119,7 +119,7 @@ public class LLMAgent extends Agent {
protected Output onMessage(AiMessage aiMessage) {
List<String> outputKeys = getOutputKeys();
List<OutputKey> outputKeys = getOutputKeys();
if (outputKeys != null && outputKeys.size() == 1) {
return Output.of(outputKeys.get(0), aiMessage.getContent());
}

View File

@ -47,6 +47,10 @@ public class Output extends HashMap<String, Object> {
return this;
}
public static Output of(OutputKey key, Object value) {
return of(key.getKey(), value);
}
public static Output of(String key, Object value) {
Output output = new Output();
output.put(key, value);

View File

@ -0,0 +1,65 @@
/*
* Copyright (c) 2023-2025, Agents-Flex (fuhai999@gmail.com).
* <p>
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* <p>
* http://www.apache.org/licenses/LICENSE-2.0
* <p>
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.agentsflex.agent;
public class OutputKey {
private String key;
private String type;
private String description;
public OutputKey() {
}
public OutputKey(String key) {
this.key = key;
}
public OutputKey(String key, String type) {
this.key = key;
this.type = type;
}
public OutputKey(String key, String type, String description) {
this.key = key;
this.type = type;
this.description = description;
}
public String getKey() {
return key;
}
public void setKey(String key) {
this.key = key;
}
public String getType() {
return type;
}
public void setType(String type) {
this.type = type;
}
public String getDescription() {
return description;
}
public void setDescription(String description) {
this.description = description;
}
}

View File

@ -17,6 +17,7 @@ package com.agentsflex.agent;
public class Parameter {
private String name;
private String description;
private String type;
private boolean required;
private boolean isDefault;
@ -59,6 +60,14 @@ public class Parameter {
this.name = name;
}
public String getDescription() {
return description;
}
public void setDescription(String description) {
this.description = description;
}
public String getType() {
return type;
}

View File

@ -17,6 +17,7 @@ package com.agentsflex.chain.node;
import com.agentsflex.agent.Agent;
import com.agentsflex.agent.Output;
import com.agentsflex.agent.OutputKey;
import com.agentsflex.agent.Parameter;
import com.agentsflex.chain.Chain;
@ -95,16 +96,17 @@ public class AgentNode extends AbstractBaseNode {
}
Output output = agent.execute(variables, chain);
List<String> outputKeys = agent.getOutputKeys();
List<OutputKey> outputKeys = agent.getOutputKeys();
if (outputKeys == null || outputKeys.isEmpty()
|| outputMapping == null || outputMapping.isEmpty()) {
return output;
}
Map<String, Object> newResult = new HashMap<>(outputKeys.size());
for (String outputKey : outputKeys) {
String newKey = outputMapping.getOrDefault(outputKey, outputKey);
newResult.put(newKey, output.get(outputKey));
for (OutputKey outputKey : outputKeys) {
String oldKey = outputKey.getKey();
String newKey = outputMapping.getOrDefault(oldKey, oldKey);
newResult.put(newKey, output.get(oldKey));
}
return newResult;