diff --git a/agents-flex-core/src/main/java/com/agentsflex/message/SystemMessage.java b/agents-flex-core/src/main/java/com/agentsflex/message/SystemMessage.java index 3e4baf6..bb1bd60 100644 --- a/agents-flex-core/src/main/java/com/agentsflex/message/SystemMessage.java +++ b/agents-flex-core/src/main/java/com/agentsflex/message/SystemMessage.java @@ -17,10 +17,8 @@ package com.agentsflex.message; public class SystemMessage extends Message { - public SystemMessage(String content) { this.content = content; } - } diff --git a/agents-flex-core/src/main/java/com/agentsflex/prompt/FunctionPrompt.java b/agents-flex-core/src/main/java/com/agentsflex/prompt/FunctionPrompt.java index e22fa35..9642043 100644 --- a/agents-flex-core/src/main/java/com/agentsflex/prompt/FunctionPrompt.java +++ b/agents-flex-core/src/main/java/com/agentsflex/prompt/FunctionPrompt.java @@ -29,18 +29,27 @@ import java.util.List; public class FunctionPrompt extends Prompt { private final ChatMemory memory = new DefaultChatMemory(); - private List> functions = new ArrayList<>(); + private final List> functions = new ArrayList<>(); public FunctionPrompt(String prompt, Class funcClass) { memory.addMessage(new HumanMessage(prompt)); functions.addAll(Functions.from(funcClass)); } + public FunctionPrompt(String prompt, Class funcClass, String... methodNames) { + memory.addMessage(new HumanMessage(prompt)); + functions.addAll(Functions.from(funcClass, methodNames)); + } + public FunctionPrompt(List messages, Class funcClass) { memory.addMessages(messages); functions.addAll(Functions.from(funcClass)); } + public FunctionPrompt(List messages, Class funcClass, String... methodNames) { + memory.addMessages(messages); + functions.addAll(Functions.from(funcClass, methodNames)); + } @Override public List toMessages() { diff --git a/agents-flex-llm/agents-flex-llm-openai/src/main/java/com/agentsflex/llm/openai/OpenAiLLmUtil.java b/agents-flex-llm/agents-flex-llm-openai/src/main/java/com/agentsflex/llm/openai/OpenAiLLmUtil.java index 7b77748..1ad72c5 100644 --- a/agents-flex-llm/agents-flex-llm-openai/src/main/java/com/agentsflex/llm/openai/OpenAiLLmUtil.java +++ b/agents-flex-llm/agents-flex-llm-openai/src/main/java/com/agentsflex/llm/openai/OpenAiLLmUtil.java @@ -51,7 +51,7 @@ public class OpenAiLLmUtil { // https://platform.openai.com/docs/api-reference/making-requests String payload = "{\n" + - " \"input\": \""+text.getContent()+"\",\n" + + " \"input\": \"" + text.getContent() + "\",\n" + " \"model\": \"text-embedding-ada-002\",\n" + " \"encoding_format\": \"float\"\n" + "}"; @@ -60,7 +60,6 @@ public class OpenAiLLmUtil { } - public static String promptToPayload(Prompt prompt, OpenAiLlmConfig config) { List messages = prompt.toMessages(); @@ -86,7 +85,7 @@ public class OpenAiLLmUtil { return "{\n" + // " \"model\": \"gpt-3.5-turbo\",\n" + " \"model\": \"" + config.getModel() + "\",\n" + - " \"messages\": "+messageText+",\n" + + " \"messages\": " + messageText + ",\n" + " \"temperature\": 0.7\n" + "}"; } @@ -113,32 +112,32 @@ public class OpenAiLLmUtil { String messageText = JSON.toJSONString(messageArray); - List> toolsArray = new ArrayList<>(); + List> toolsArray = new ArrayList<>(); for (Function function : functions) { - Map functionRoot = new HashMap<>(); - functionRoot.put("type","function"); + Map functionRoot = new HashMap<>(); + functionRoot.put("type", "function"); - Map functionObj = new HashMap<>(); - functionRoot.put("function",functionObj); + Map functionObj = new HashMap<>(); + functionRoot.put("function", functionObj); - functionObj.put("name",function.getName()); - functionObj.put("description",function.getDescription()); + functionObj.put("name", function.getName()); + functionObj.put("description", function.getDescription()); - Map parametersObj = new HashMap<>(); - functionObj.put("parameters",parametersObj); + Map parametersObj = new HashMap<>(); + functionObj.put("parameters", parametersObj); - parametersObj.put("type","object"); + parametersObj.put("type", "object"); - Map propertiesObj = new HashMap<>(); - parametersObj.put("properties",propertiesObj); + Map propertiesObj = new HashMap<>(); + parametersObj.put("properties", propertiesObj); for (Parameter parameter : function.getParameters()) { - Map parameterObj = new HashMap<>(); - parameterObj.put("type",parameter.getType()); - parameterObj.put("description",parameter.getDescription()); - parameterObj.put("enum",parameter.getEnums()); - propertiesObj.put(parameter.getName(),parameterObj); + Map parameterObj = new HashMap<>(); + parameterObj.put("type", parameter.getType()); + parameterObj.put("description", parameter.getDescription()); + parameterObj.put("enum", parameter.getEnums()); + propertiesObj.put(parameter.getName(), parameterObj); } toolsArray.add(functionRoot); @@ -149,12 +148,11 @@ public class OpenAiLLmUtil { return "{\n" + // " \"model\": \"gpt-3.5-turbo\",\n" + " \"model\": \"" + config.getModel() + "\",\n" + - " \"messages\": "+messageText+",\n" + - " \"tools\": "+toolsText+",\n" + + " \"messages\": " + messageText + ",\n" + + " \"tools\": " + toolsText + ",\n" + " \"tool_choice\": \"auto\"\n" + "}"; } - } diff --git a/agents-flex-llm/agents-flex-llm-spark/src/main/java/com/agentsflex/llm/spark/SparkLlm.java b/agents-flex-llm/agents-flex-llm-spark/src/main/java/com/agentsflex/llm/spark/SparkLlm.java index 1f1c344..82b3913 100644 --- a/agents-flex-llm/agents-flex-llm-spark/src/main/java/com/agentsflex/llm/spark/SparkLlm.java +++ b/agents-flex-llm/agents-flex-llm-spark/src/main/java/com/agentsflex/llm/spark/SparkLlm.java @@ -16,6 +16,7 @@ package com.agentsflex.llm.spark; import com.agentsflex.document.Document; +import com.agentsflex.functions.Function; import com.agentsflex.llm.BaseLlm; import com.agentsflex.llm.ChatContext; import com.agentsflex.llm.MessageListener; @@ -25,11 +26,18 @@ import com.agentsflex.llm.client.LlmClient; import com.agentsflex.llm.client.LlmClientListener; import com.agentsflex.llm.client.impl.WebSocketClient; import com.agentsflex.llm.response.AiMessageResponse; +import com.agentsflex.llm.response.FunctionMessageResponse; import com.agentsflex.message.AiMessage; +import com.agentsflex.message.FunctionMessage; import com.agentsflex.message.Message; +import com.agentsflex.prompt.FunctionPrompt; import com.agentsflex.prompt.Prompt; import com.agentsflex.store.VectorData; +import com.alibaba.fastjson.JSON; +import com.alibaba.fastjson.JSONObject; +import com.alibaba.fastjson.JSONPath; +import java.util.List; import java.util.concurrent.CountDownLatch; public class SparkLlm extends BaseLlm { @@ -47,11 +55,20 @@ public class SparkLlm extends BaseLlm { @Override public , M extends Message> R chat(Prompt prompt) { CountDownLatch latch = new CountDownLatch(1); - AiMessage aiMessage = new AiMessage(); + Message[] messages = new Message[1]; chatAsync(prompt, new MessageListener, M>() { @Override public void onMessage(ChatContext context, MessageResponse response) { - aiMessage.setContent(((AiMessage) response.getMessage()).getFullContent()); + if (response.getMessage() instanceof AiMessage) { + if (messages[0] == null) { + messages[0] = response.getMessage(); + } else { + messages[0].setContent(((AiMessage) response.getMessage()).getFullContent()); + } + + } else if (response.getMessage() instanceof FunctionMessage) { + messages[0] = response.getMessage(); + } } @Override @@ -65,7 +82,12 @@ public class SparkLlm extends BaseLlm { } catch (InterruptedException e) { throw new RuntimeException(e); } - return (R) new AiMessageResponse(aiMessage); + + if (prompt instanceof FunctionPrompt) { + return (R) new FunctionMessageResponse(((FunctionPrompt) prompt).getFunctions(), (FunctionMessage) messages[0]); + } else { + return (R) new AiMessageResponse((AiMessage) messages[0]); + } } @@ -76,7 +98,20 @@ public class SparkLlm extends BaseLlm { String payload = SparkLlmUtil.promptToPayload(prompt, config); - LlmClientListener clientListener = new BaseLlmClientListener(this, llmClient, listener, prompt, SparkLlmUtil::parseAiMessage, null); + LlmClientListener clientListener = new BaseLlmClientListener(this, llmClient, listener, prompt, SparkLlmUtil::parseAiMessage, new BaseLlmClientListener.FunctionMessageParser() { + @Override + public FunctionMessage parseMessage(String response) { + JSONObject jsonObject = JSON.parseObject(response); + String callFunctionName = (String) JSONPath.eval(jsonObject, "$.payload.choices.text[0].function_call.name"); + String callFunctionArgsString = (String) JSONPath.eval(jsonObject, "$.payload.choices.text[0].function_call.arguments"); + JSONObject callFunctionArgs = JSON.parseObject(callFunctionArgsString); + + FunctionMessage functionMessage = new FunctionMessage(); + functionMessage.setFunctionName(callFunctionName); + functionMessage.setArgs(callFunctionArgs); + return functionMessage; + } + }); llmClient.start(url, null, payload, clientListener); } diff --git a/agents-flex-llm/agents-flex-llm-spark/src/main/java/com/agentsflex/llm/spark/SparkLlmUtil.java b/agents-flex-llm/agents-flex-llm-spark/src/main/java/com/agentsflex/llm/spark/SparkLlmUtil.java index fb53115..bf5587d 100644 --- a/agents-flex-llm/agents-flex-llm-spark/src/main/java/com/agentsflex/llm/spark/SparkLlmUtil.java +++ b/agents-flex-llm/agents-flex-llm-spark/src/main/java/com/agentsflex/llm/spark/SparkLlmUtil.java @@ -15,10 +15,10 @@ */ package com.agentsflex.llm.spark; -import com.agentsflex.message.AiMessage; -import com.agentsflex.message.HumanMessage; -import com.agentsflex.message.Message; -import com.agentsflex.message.MessageStatus; +import com.agentsflex.functions.Function; +import com.agentsflex.functions.Parameter; +import com.agentsflex.message.*; +import com.agentsflex.prompt.FunctionPrompt; import com.agentsflex.prompt.Prompt; import com.agentsflex.util.HashUtil; import com.alibaba.fastjson.JSON; @@ -32,14 +32,14 @@ import java.util.*; public class SparkLlmUtil { - public static AiMessage parseAiMessage(String json){ + public static AiMessage parseAiMessage(String json) { AiMessage aiMessage = new AiMessage(); JSONObject jsonObject = JSON.parseObject(json); Object status = JSONPath.eval(jsonObject, "$.payload.choices.status"); MessageStatus messageStatus = SparkLlmUtil.parseMessageStatus((Integer) status); aiMessage.setStatus(messageStatus); - aiMessage.setIndex((Integer) JSONPath.eval(jsonObject,"$.payload.choices.text[0].index")); - aiMessage.setContent((String) JSONPath.eval(jsonObject,"$.payload.choices.text[0].content")); + aiMessage.setIndex((Integer) JSONPath.eval(jsonObject, "$.payload.choices.text[0].index")); + aiMessage.setContent((String) JSONPath.eval(jsonObject, "$.payload.choices.text[0].content")); return aiMessage; } @@ -64,7 +64,8 @@ public class SparkLlmUtil { " \"payload\": {\n" + " \"message\": {\n" + " \"text\": messageJsonString" + - " }\n" + + " },\n" + + " \"functions\":functionsJsonString" + " }\n" + "}"; @@ -78,13 +79,60 @@ public class SparkLlmUtil { } else if (message instanceof AiMessage) { map.put("role", "assistant"); map.put("content", ((AiMessage) message).getFullContent()); + } else if (message instanceof SystemMessage) { + map.put("role", "system"); + map.put("content", message.getContent()); } - messageArray.add(map); }); + + + String functionsJsonString = "\"\""; + if (prompt instanceof FunctionPrompt) { + List> functions = ((FunctionPrompt) prompt).getFunctions(); + + List> functionsArray = new ArrayList<>(); + for (Function function : functions) { + Map functionRoot = new HashMap<>(); + functionRoot.put("type", "function"); + + Map functionObj = new HashMap<>(); + functionRoot.put("function", functionObj); + + functionObj.put("name", function.getName()); + functionObj.put("description", function.getDescription()); + + + Map parametersObj = new HashMap<>(); + functionObj.put("parameters", parametersObj); + + parametersObj.put("type", "object"); + + Map propertiesObj = new HashMap<>(); + parametersObj.put("properties", propertiesObj); + + for (Parameter parameter : function.getParameters()) { + Map parameterObj = new HashMap<>(); + parameterObj.put("type", parameter.getType()); + parameterObj.put("description", parameter.getDescription()); + parameterObj.put("enum", parameter.getEnums()); + propertiesObj.put(parameter.getName(), parameterObj); + } + + functionsArray.add(functionRoot); + } + Map functionsJsonMap = new HashMap<>(); + functionsJsonMap.put("text", functionsArray); + +// Map functionsJsonRoot = new HashMap<>(); +// functionsJsonRoot.put("functions", functionsJsonMap); + + functionsJsonString = JSON.toJSONString(functionsJsonMap); + } + String messageText = JSON.toJSONString(messageArray); - return payload.replace("messageJsonString", messageText); + return payload.replace("messageJsonString", messageText).replace("functionsJsonString", functionsJsonString); } public static MessageStatus parseMessageStatus(Integer status) { diff --git a/agents-flex-llm/agents-flex-llm-spark/src/test/java/com/agentsflex/llm/spark/test/SparkLlmTest.java b/agents-flex-llm/agents-flex-llm-spark/src/test/java/com/agentsflex/llm/spark/test/SparkLlmTest.java index ac1ef98..61f2b87 100644 --- a/agents-flex-llm/agents-flex-llm-spark/src/test/java/com/agentsflex/llm/spark/test/SparkLlmTest.java +++ b/agents-flex-llm/agents-flex-llm-spark/src/test/java/com/agentsflex/llm/spark/test/SparkLlmTest.java @@ -1,9 +1,11 @@ package com.agentsflex.llm.spark.test; import com.agentsflex.llm.Llm; +import com.agentsflex.llm.response.FunctionMessageResponse; import com.agentsflex.llm.spark.SparkLlm; import com.agentsflex.llm.spark.SparkLlmConfig; import com.agentsflex.message.HumanMessage; +import com.agentsflex.prompt.FunctionPrompt; import com.agentsflex.prompt.HistoriesPrompt; import org.junit.Test; @@ -24,6 +26,24 @@ public class SparkLlmTest { } + @Test + public void testFunctionCalling() throws InterruptedException { + SparkLlmConfig config = new SparkLlmConfig(); + config.setAppId("****"); + config.setApiKey("****"); + config.setApiSecret("****"); + + Llm llm = new SparkLlm(config); + + FunctionPrompt prompt = new FunctionPrompt("今天北京的天气怎么样", WeatherFunctions.class); + FunctionMessageResponse response = llm.chat(prompt); + + Object result = response.invoke(); + + System.out.println(result); + } + + public static void main(String[] args) { SparkLlmConfig config = new SparkLlmConfig(); config.setAppId("****"); diff --git a/agents-flex-llm/agents-flex-llm-spark/src/test/java/com/agentsflex/llm/spark/test/WeatherFunctions.java b/agents-flex-llm/agents-flex-llm-spark/src/test/java/com/agentsflex/llm/spark/test/WeatherFunctions.java new file mode 100644 index 0000000..a3ebe39 --- /dev/null +++ b/agents-flex-llm/agents-flex-llm-spark/src/test/java/com/agentsflex/llm/spark/test/WeatherFunctions.java @@ -0,0 +1,15 @@ +package com.agentsflex.llm.spark.test; + +import com.agentsflex.functions.annotation.FunctionDef; +import com.agentsflex.functions.annotation.FunctionParam; + +public class WeatherFunctions { + + @FunctionDef(name = "weather", description = "获取天气信息") + public static String weather( + @FunctionParam(name = "location", description = "城市名称,比如: 北京, 上海") String name + ) { + //此处应该通过 api 去第三方获取 + return name + "今天的天气阴转多云"; + } +}