Feature: HumanMessage add toolChoice config

Michael Yang 2024-11-13 11:31:34 +08:00
parent cf2f0c2897
commit 9f2ab83e06
5 changed files with 30 additions and 8 deletions

HumanMessage.java

@@ -23,6 +23,7 @@ import java.util.*;
 public class HumanMessage extends AbstractTextMessage {
     private List<Function> functions;
+    private String toolChoice;
     public HumanMessage() {
     }
@@ -69,4 +70,16 @@ public class HumanMessage extends AbstractTextMessage {
         }
         return map;
     }
+    public void setFunctions(List<Function> functions) {
+        this.functions = functions;
+    }
+    public String getToolChoice() {
+        return toolChoice;
+    }
+    public void setToolChoice(String toolChoice) {
+        this.toolChoice = toolChoice;
+    }
 }
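Caller-side usage of the new field, as a minimal sketch. The content setter is assumed to be inherited from AbstractTextMessage and is not part of this diff; adjust it to the actual API.

import com.agentsflex.core.message.HumanMessage;

import java.util.Collections;

public class ToolChoiceUsageSketch {
    public static void main(String[] args) {
        HumanMessage message = new HumanMessage();
        // Assumed setter from AbstractTextMessage; not shown in this commit
        message.setContent("What is the weather like in Beijing today?");
        // Functions are normally discovered elsewhere; an empty list keeps the sketch self-contained
        message.setFunctions(Collections.emptyList());
        // New in this commit: forwarded into the provider payload as "tool_choice"
        message.setToolChoice("auto");
    }
}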

MapsTest.java

@@ -11,7 +11,7 @@ public class MapsTest {
     @Test
     public void testMaps() {
         Map<String, Object> map1 = Maps.of("key", "value")
-                .put("options.aaa", 1);
+                .putChild("options.aaa", 1);
         Assert.assertEquals(1, ((Map<?, ?>) map1.get("options")).get("aaa"));
         System.out.println(map1);
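Judging only from this test, a sketch of what the rename appears to distinguish. The nested behavior of putChild is what the assertion checks; the flat behavior of plain put with a dotted key is an assumption, not verified against the Maps source.

import com.agentsflex.core.util.Maps;

import java.util.Map;

public class MapsPutChildSketch {
    public static void main(String[] args) {
        // putChild expands the dotted key into a nested map: {key=value, options={aaa=1}}
        Map<String, Object> nested = Maps.of("key", "value")
                .putChild("options.aaa", 1);
        System.out.println(((Map<?, ?>) nested.get("options")).get("aaa")); // 1

        // Assumption: plain put keeps the literal key, i.e. {key=value, options.aaa=1}
        Map<String, Object> flat = Maps.of("key", "value")
                .put("options.aaa", 1);
        System.out.println(flat.get("options.aaa"));
    }
}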

ChatglmLlmUtil.java

@@ -16,6 +16,7 @@
 package com.agentsflex.llm.chatglm;
 import com.agentsflex.core.llm.ChatOptions;
+import com.agentsflex.core.message.HumanMessage;
 import com.agentsflex.core.message.Message;
 import com.agentsflex.core.message.MessageStatus;
 import com.agentsflex.core.parser.AiMessageParser;
@@ -23,6 +24,7 @@ import com.agentsflex.core.parser.impl.DefaultAiMessageParser;
 import com.agentsflex.core.prompt.DefaultPromptFormat;
 import com.agentsflex.core.prompt.Prompt;
 import com.agentsflex.core.prompt.PromptFormat;
+import com.agentsflex.core.util.CollectionUtil;
 import com.agentsflex.core.util.Maps;
 import com.alibaba.fastjson.JSON;
 import io.jsonwebtoken.JwtBuilder;
@@ -91,12 +93,12 @@ public class ChatglmLlmUtil {
     public static String promptToPayload(Prompt prompt, ChatglmLlmConfig config, boolean withStream, ChatOptions options) {
         List<Message> messages = prompt.toMessages();
+        HumanMessage humanMessage = (HumanMessage) CollectionUtil.lastItem(messages);
         return Maps.of("model", config.getModel())
                 .put("messages", promptFormat.toMessagesJsonObject(messages))
                 .putIf(withStream, "stream", true)
-                .putIfNotEmpty("tools", promptFormat.toFunctionsJsonObject(messages.get(messages.size() - 1)))
-                .putIfContainsKey("tools", "tool_choice", "auto")
+                .putIfNotEmpty("tools", promptFormat.toFunctionsJsonObject(humanMessage))
+                .putIfContainsKey("tools", "tool_choice", humanMessage.getToolChoice())
                 .putIfNotNull("top_p", options.getTopP())
                 .putIfNotEmpty("stop", options.getStop())
                 .putIf(map -> !map.containsKey("tools") && options.getTemperature() > 0, "temperature", options.getTemperature())
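To spell out the conditional chaining, a rough plain-Java equivalent of the tools and tool_choice lines. This is a sketch of intent only, not the Maps implementation; how putIfNotEmpty treats empty collections and whether putIfContainsKey writes a null toolChoice are assumptions here.

import java.util.Map;

final class ToolChoicePayloadSketch {
    // Equivalent in spirit to .putIfNotEmpty("tools", ...).putIfContainsKey("tools", "tool_choice", ...)
    static Map<String, Object> attachTools(Map<String, Object> payload, Object toolsJson, String toolChoice) {
        if (toolsJson != null) {            // putIfNotEmpty: only attach when the message declares functions
            payload.put("tools", toolsJson);
        }
        if (payload.containsKey("tools")) { // putIfContainsKey: tool_choice only rides along with tools
            payload.put("tool_choice", toolChoice); // previously hard-coded to "auto"
        }
        return payload;
    }
}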

GiteeAiLLmUtil.java

@@ -16,12 +16,14 @@
 package com.agentsflex.llm.gitee;
 import com.agentsflex.core.llm.ChatOptions;
+import com.agentsflex.core.message.HumanMessage;
 import com.agentsflex.core.message.Message;
 import com.agentsflex.core.parser.AiMessageParser;
 import com.agentsflex.core.parser.impl.DefaultAiMessageParser;
 import com.agentsflex.core.prompt.DefaultPromptFormat;
 import com.agentsflex.core.prompt.Prompt;
 import com.agentsflex.core.prompt.PromptFormat;
+import com.agentsflex.core.util.CollectionUtil;
 import com.agentsflex.core.util.Maps;
 import java.util.List;
@@ -36,12 +38,14 @@ public class GiteeAiLLmUtil {
     public static String promptToPayload(Prompt prompt, GiteeAiLlmConfig config, ChatOptions options, boolean withStream) {
         List<Message> messages = prompt.toMessages();
+        HumanMessage humanMessage = (HumanMessage) CollectionUtil.lastItem(messages);
         return Maps.of()
                 .put("messages", promptFormat.toMessagesJsonObject(messages))
                 .putIf(withStream, "stream", withStream)
                 .putIfNotNull("max_tokens", options.getMaxTokens())
                 .putIfNotNull("temperature", options.getTemperature())
-                .putIfNotEmpty("tools", promptFormat.toFunctionsJsonObject(messages.get(messages.size() - 1)))
+                .putIfNotEmpty("tools", promptFormat.toFunctionsJsonObject(humanMessage))
+                .putIfContainsKey("tools", "tool_choice", humanMessage.getToolChoice())
                 .putIfNotNull("top_p", options.getTopP())
                 .putIfNotNull("top_k", options.getTopK())
                 .putIfNotEmpty("stop", options.getStop())

OpenAiLLmUtil.java

@@ -18,12 +18,14 @@ package com.agentsflex.llm.openai;
 import com.agentsflex.core.document.Document;
 import com.agentsflex.core.llm.ChatOptions;
 import com.agentsflex.core.llm.embedding.EmbeddingOptions;
+import com.agentsflex.core.message.HumanMessage;
 import com.agentsflex.core.message.Message;
 import com.agentsflex.core.parser.AiMessageParser;
 import com.agentsflex.core.parser.impl.DefaultAiMessageParser;
 import com.agentsflex.core.prompt.DefaultPromptFormat;
 import com.agentsflex.core.prompt.Prompt;
 import com.agentsflex.core.prompt.PromptFormat;
+import com.agentsflex.core.util.CollectionUtil;
 import com.agentsflex.core.util.Maps;
 import java.util.List;
@@ -48,11 +50,12 @@ public class OpenAiLLmUtil {
     public static String promptToPayload(Prompt prompt, OpenAiLlmConfig config, ChatOptions options, boolean withStream) {
         List<Message> messages = prompt.toMessages();
+        HumanMessage humanMessage = (HumanMessage) CollectionUtil.lastItem(messages);
         return Maps.of("model", config.getModel())
-                .put("messages", promptFormat.toMessagesJsonObject(prompt.toMessages()))
+                .put("messages", promptFormat.toMessagesJsonObject(messages))
                 .putIf(withStream, "stream", true)
-                .putIfNotEmpty("tools", promptFormat.toFunctionsJsonObject(messages.get(messages.size() - 1)))
-                .putIfContainsKey("tools", "tool_choice", "auto")
+                .putIfNotEmpty("tools", promptFormat.toFunctionsJsonObject(humanMessage))
+                .putIfContainsKey("tools", "tool_choice", humanMessage.getToolChoice())
                 .putIfNotNull("top_p", options.getTopP())
                 .putIfNotEmpty("stop", options.getStop())
                 .putIf(map -> !map.containsKey("tools") && options.getTemperature() > 0, "temperature", options.getTemperature())
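For orientation, the kind of payload the OpenAI-style builder would now emit when the last HumanMessage declares functions and sets toolChoice to "auto". The field contents below are placeholders, not output from this commit's code; the real values come from PromptFormat, ChatOptions, and the configured model.

final class ToolChoicePayloadShape {
    // Illustrative shape only; the model name and function entry are made up
    static final String EXAMPLE =
            "{\n" +
            "  \"model\": \"gpt-4o\",\n" +
            "  \"messages\": [ { \"role\": \"user\", \"content\": \"...\" } ],\n" +
            "  \"tools\": [ { \"type\": \"function\", \"function\": { \"name\": \"get_weather\" } } ],\n" +
            "  \"tool_choice\": \"auto\"\n" +
            "}";
}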