refactor: refactor chain

This commit is contained in:
Michael Yang 2024-10-30 16:03:14 +08:00
parent f317ce650e
commit 1756d5782c

View File

@ -20,15 +20,16 @@ import com.agentsflex.core.chain.OutputKey;
import com.agentsflex.core.llm.ChatOptions;
import com.agentsflex.core.llm.Llm;
import com.agentsflex.core.llm.response.AiMessageResponse;
import com.agentsflex.core.message.AiMessage;
import com.agentsflex.core.message.SystemMessage;
import com.agentsflex.core.prompt.TextPrompt;
import com.agentsflex.core.prompt.template.TextPromptTemplate;
import com.agentsflex.core.util.Maps;
import com.agentsflex.core.util.StringUtil;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import java.util.Collections;
import java.util.List;
import java.util.HashMap;
import java.util.Map;
public class LLMNode extends BaseNode {
@ -40,6 +41,7 @@ public class LLMNode extends BaseNode {
protected String systemPrompt;
protected TextPromptTemplate systemPromptTemplate;
protected String outType = "text"; //text markdown json
public LLMNode() {
}
@ -107,32 +109,33 @@ public class LLMNode extends BaseNode {
}
AiMessageResponse response = llm.chat(userPrompt, chatOptions);
chain.output(this, response);
if (chain != null) {
chain.output(this, response);
}
return response.isError()
? onError(response, chain)
: onMessage(response.getMessage());
}
protected Map<String, Object> onError(AiMessageResponse response, Chain chain) {
if (chain != null) {
if (response.isError()) {
chain.stopError(response.getErrorMessage());
}
return null;
}
protected Map<String, Object> onMessage(AiMessage aiMessage) {
List<OutputKey> outputKeys = getOutputKeys();
if (outputKeys != null && outputKeys.size() == 1) {
return Maps.of("content", aiMessage.getContent()).build();
return Collections.emptyMap();
}
return Maps.of("content", aiMessage.getContent()).build();
if (outType == null || outType.equalsIgnoreCase("text") || outType.equalsIgnoreCase("markdown")) {
return Maps.of("output", response.getMessage().getContent()).build();
} else {
if (this.outputKeys != null) {
JSONObject jsonObject;
try {
jsonObject = JSON.parseObject(response.getResponse());
} catch (Exception e) {
chain.stopError("Can not parse json: " + response.getResponse() + " " + e.getMessage());
return Collections.emptyMap();
}
Map<String, Object> map = new HashMap<>();
for (OutputKey outputKey : this.outputKeys) {
map.put(outputKey.getKey(), jsonObject.get(outputKey.getKey()));
}
return map;
}
return Collections.emptyMap();
}
}