From 625cfd3dcf69612a593762fbdd2efe389503962b Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Tue, 12 Mar 2024 20:12:13 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=AE=8C=E5=96=84=20=E6=98=9F=E7=81=AB?= =?UTF-8?q?=E5=A4=A7=E6=A8=A1=E5=9E=8B=20=E7=9A=84=20Function=20Calling=20?= =?UTF-8?q?=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../java/com/agentsflex/llm/LlmConfig.java | 10 ++++ .../com/agentsflex/llm/client/LlmClient.java | 4 +- .../llm/client/impl/AsyncHttpClient.java | 11 +++- .../agentsflex/llm/client/impl/SseClient.java | 12 ++++- .../llm/client/impl/WebSocketClient.java | 12 ++++- .../parser/impl/BaseAiMessageParser.java | 24 ++++++--- .../prompt/DefaultPromptFormat.java | 22 +++++--- .../main/java/com/agentsflex/util/Maps.java | 4 ++ .../com/agentsflex/llm/openai/OpenAiLlm.java | 2 +- .../java/com/agentsflex/llm/qwen/QwenLlm.java | 2 +- .../com/agentsflex/llm/spark/SparkLlm.java | 7 +-- .../agentsflex/llm/spark/SparkLlmUtil.java | 54 ++++++++++++++++--- .../llm/spark/test/SparkLlmTest.java | 6 ++- .../llm/spark/test/WeatherFunctions.java | 6 +-- pom.xml | 2 +- 15 files changed, 141 insertions(+), 37 deletions(-) diff --git a/agents-flex-core/src/main/java/com/agentsflex/llm/LlmConfig.java b/agents-flex-core/src/main/java/com/agentsflex/llm/LlmConfig.java index e235232..7c36d15 100644 --- a/agents-flex-core/src/main/java/com/agentsflex/llm/LlmConfig.java +++ b/agents-flex-core/src/main/java/com/agentsflex/llm/LlmConfig.java @@ -18,4 +18,14 @@ package com.agentsflex.llm; import java.io.Serializable; public class LlmConfig implements Serializable { + + private boolean debug; + + public boolean isDebug() { + return debug; + } + + public void setDebug(boolean debug) { + this.debug = debug; + } } diff --git a/agents-flex-core/src/main/java/com/agentsflex/llm/client/LlmClient.java b/agents-flex-core/src/main/java/com/agentsflex/llm/client/LlmClient.java index 404d0dc..012c5f6 100644 --- a/agents-flex-core/src/main/java/com/agentsflex/llm/client/LlmClient.java +++ b/agents-flex-core/src/main/java/com/agentsflex/llm/client/LlmClient.java @@ -15,11 +15,13 @@ */ package com.agentsflex.llm.client; +import com.agentsflex.llm.LlmConfig; + import java.util.Map; public interface LlmClient { - void start(String url, Map headers, String payload, LlmClientListener listener); + void start(String url, Map headers, String payload, LlmClientListener listener, LlmConfig config); void stop(); } diff --git a/agents-flex-core/src/main/java/com/agentsflex/llm/client/impl/AsyncHttpClient.java b/agents-flex-core/src/main/java/com/agentsflex/llm/client/impl/AsyncHttpClient.java index 981cffe..787ce3b 100644 --- a/agents-flex-core/src/main/java/com/agentsflex/llm/client/impl/AsyncHttpClient.java +++ b/agents-flex-core/src/main/java/com/agentsflex/llm/client/impl/AsyncHttpClient.java @@ -15,6 +15,7 @@ */ package com.agentsflex.llm.client.impl; +import com.agentsflex.llm.LlmConfig; import com.agentsflex.llm.client.LlmClient; import com.agentsflex.llm.client.LlmClientListener; import okhttp3.*; @@ -28,11 +29,13 @@ public class AsyncHttpClient implements LlmClient { private static final MediaType JSON_TYPE = MediaType.parse("application/json; charset=utf-8"); private OkHttpClient client; private LlmClientListener listener; + private LlmConfig config; private boolean isStop = false; @Override - public void start(String url, Map headers, String payload, LlmClientListener listener) { + public void start(String url, Map headers, String payload, LlmClientListener listener, LlmConfig config) { this.listener = listener; + this.config = config; this.isStop = false; Request.Builder rBuilder = new Request.Builder() @@ -50,6 +53,9 @@ public class AsyncHttpClient implements LlmClient { .readTimeout(3, TimeUnit.MINUTES) .build(); + if (this.config.isDebug()){ + System.out.println(">>>>send payload:" + payload); + } this.listener.onStart(this); this.client.newCall(rBuilder.build()).enqueue(new Callback() { @@ -60,6 +66,9 @@ public class AsyncHttpClient implements LlmClient { @Override public void onResponse(@NotNull Call call, @NotNull Response response) throws IOException { + if (config.isDebug()){ + System.out.println(">>>>receive payload:" + response.message()); + } AsyncHttpClient.this.listener.onMessage(AsyncHttpClient.this, response.message()); if (!isStop) { AsyncHttpClient.this.isStop = true; diff --git a/agents-flex-core/src/main/java/com/agentsflex/llm/client/impl/SseClient.java b/agents-flex-core/src/main/java/com/agentsflex/llm/client/impl/SseClient.java index 05c2be8..2f1d43e 100644 --- a/agents-flex-core/src/main/java/com/agentsflex/llm/client/impl/SseClient.java +++ b/agents-flex-core/src/main/java/com/agentsflex/llm/client/impl/SseClient.java @@ -15,6 +15,7 @@ */ package com.agentsflex.llm.client.impl; +import com.agentsflex.llm.LlmConfig; import com.agentsflex.llm.client.LlmClient; import com.agentsflex.llm.client.LlmClientListener; import okhttp3.*; @@ -32,11 +33,13 @@ public class SseClient extends EventSourceListener implements LlmClient { private OkHttpClient client; private EventSource eventSource; private LlmClientListener listener; + private LlmConfig config; private boolean isStop = false; @Override - public void start(String url, Map headers, String payload, LlmClientListener listener) { + public void start(String url, Map headers, String payload, LlmClientListener listener, LlmConfig config) { this.listener = listener; + this.config = config; this.isStop = false; Request.Builder builder = new Request.Builder() @@ -60,6 +63,10 @@ public class SseClient extends EventSourceListener implements LlmClient { EventSource.Factory factory = EventSources.createFactory(this.client); this.eventSource = factory.newEventSource(request, this); + if (this.config.isDebug()){ + System.out.println(">>>>send payload:" + payload); + } + this.listener.onStart(this); } @@ -84,6 +91,9 @@ public class SseClient extends EventSourceListener implements LlmClient { @Override public void onEvent(@NotNull EventSource eventSource, @Nullable String id, @Nullable String type, @NotNull String data) { + if (this.config.isDebug()){ + System.out.println(">>>>receive payload:" + data); + } this.listener.onMessage(this, data); } diff --git a/agents-flex-core/src/main/java/com/agentsflex/llm/client/impl/WebSocketClient.java b/agents-flex-core/src/main/java/com/agentsflex/llm/client/impl/WebSocketClient.java index 9a8d394..bff9fd6 100644 --- a/agents-flex-core/src/main/java/com/agentsflex/llm/client/impl/WebSocketClient.java +++ b/agents-flex-core/src/main/java/com/agentsflex/llm/client/impl/WebSocketClient.java @@ -15,6 +15,7 @@ */ package com.agentsflex.llm.client.impl; +import com.agentsflex.llm.LlmConfig; import com.agentsflex.llm.client.LlmClient; import com.agentsflex.llm.client.LlmClientListener; import okhttp3.*; @@ -28,13 +29,15 @@ public class WebSocketClient extends WebSocketListener implements LlmClient { private WebSocket webSocket; private LlmClientListener listener; + private LlmConfig config; private boolean isStop = false; private String payload; @Override - public void start(String url, Map headers, String payload, LlmClientListener listener) { + public void start(String url, Map headers, String payload, LlmClientListener listener, LlmConfig config) { this.listener = listener; this.payload = payload; + this.config = config; OkHttpClient client = new OkHttpClient.Builder() .readTimeout(0, TimeUnit.MILLISECONDS) @@ -46,6 +49,10 @@ public class WebSocketClient extends WebSocketListener implements LlmClient { this.webSocket = client.newWebSocket(request, this); this.isStop = false; + + if (this.config.isDebug()){ + System.out.println(">>>>send payload:" + payload); + } } @Override @@ -69,6 +76,9 @@ public class WebSocketClient extends WebSocketListener implements LlmClient { @Override public void onMessage(WebSocket webSocket, String text) { + if (this.config.isDebug()){ + System.out.println(">>>>receive payload:" + text); + } this.listener.onMessage(this, text); } diff --git a/agents-flex-core/src/main/java/com/agentsflex/parser/impl/BaseAiMessageParser.java b/agents-flex-core/src/main/java/com/agentsflex/parser/impl/BaseAiMessageParser.java index bef4d7f..40cdf74 100644 --- a/agents-flex-core/src/main/java/com/agentsflex/parser/impl/BaseAiMessageParser.java +++ b/agents-flex-core/src/main/java/com/agentsflex/parser/impl/BaseAiMessageParser.java @@ -19,6 +19,7 @@ import com.agentsflex.message.AiMessage; import com.agentsflex.message.MessageStatus; import com.agentsflex.parser.AiMessageParser; import com.agentsflex.parser.Parser; +import com.agentsflex.util.StringUtil; import com.alibaba.fastjson.JSON; import com.alibaba.fastjson.JSONObject; import com.alibaba.fastjson.JSONPath; @@ -76,13 +77,24 @@ public class BaseAiMessageParser implements AiMessageParser { public AiMessage parse(String content) { AiMessage aiMessage = new AiMessage(); JSONObject rootJson = JSON.parseObject(content); - aiMessage.setContent((String) JSONPath.eval(rootJson, this.contentPath)); - aiMessage.setIndex((Integer) JSONPath.eval(rootJson, this.indexPath)); - aiMessage.setTotalTokens((Integer) JSONPath.eval(rootJson, this.totalTokensPath)); - String statusString = (String) JSONPath.eval(rootJson, this.statusPath); - if (this.statusParser != null) { - aiMessage.setStatus(this.statusParser.parse(statusString)); + if (StringUtil.hasText(this.contentPath)) { + aiMessage.setContent((String) JSONPath.eval(rootJson, this.contentPath)); + } + + if (StringUtil.hasText(this.indexPath)) { + aiMessage.setIndex((Integer) JSONPath.eval(rootJson, this.indexPath)); + } + + if (StringUtil.hasText(this.totalTokensPath)) { + aiMessage.setTotalTokens((Integer) JSONPath.eval(rootJson, this.totalTokensPath)); + } + + if (StringUtil.hasText(this.statusPath)) { + Object statusString = JSONPath.eval(rootJson, this.statusPath); + if (this.statusParser != null) { + aiMessage.setStatus(this.statusParser.parse(statusString)); + } } return aiMessage; diff --git a/agents-flex-core/src/main/java/com/agentsflex/prompt/DefaultPromptFormat.java b/agents-flex-core/src/main/java/com/agentsflex/prompt/DefaultPromptFormat.java index 10b12a6..5766427 100644 --- a/agents-flex-core/src/main/java/com/agentsflex/prompt/DefaultPromptFormat.java +++ b/agents-flex-core/src/main/java/com/agentsflex/prompt/DefaultPromptFormat.java @@ -40,7 +40,13 @@ public class DefaultPromptFormat implements PromptFormat { return null; } - List> messageArray = new ArrayList<>(messages.size()); + List> messageJsonArray = new ArrayList<>(messages.size()); + buildMessageJsonArray(messageJsonArray, messages); + + return messageJsonArray; + } + + protected void buildMessageJsonArray(List> messageJsonArray, List messages) { messages.forEach(message -> { Map map = new HashMap<>(2); if (message instanceof HumanMessage) { @@ -53,10 +59,8 @@ public class DefaultPromptFormat implements PromptFormat { map.put("role", "system"); map.put("content", ((SystemMessage) message).getContent()); } - messageArray.add(map); + messageJsonArray.add(map); }); - - return messageArray; } @@ -66,12 +70,18 @@ public class DefaultPromptFormat implements PromptFormat { return null; } - List> functions = ((FunctionPrompt)prompt).getFunctions(); + List> functions = ((FunctionPrompt) prompt).getFunctions(); if (functions == null || functions.isEmpty()) { return null; } List> functionsJsonArray = new ArrayList<>(); + buildFunctionJsonArray(functionsJsonArray, functions); + + return functionsJsonArray; + } + + protected void buildFunctionJsonArray(List> functionsJsonArray, List> functions) { for (Function function : functions) { Map functionRoot = new HashMap<>(); functionRoot.put("type", "function"); @@ -100,8 +110,6 @@ public class DefaultPromptFormat implements PromptFormat { } functionsJsonArray.add(functionRoot); } - - return functionsJsonArray; } } diff --git a/agents-flex-core/src/main/java/com/agentsflex/util/Maps.java b/agents-flex-core/src/main/java/com/agentsflex/util/Maps.java index fb4de22..c6f3204 100644 --- a/agents-flex-core/src/main/java/com/agentsflex/util/Maps.java +++ b/agents-flex-core/src/main/java/com/agentsflex/util/Maps.java @@ -26,6 +26,10 @@ public class Maps { return new Builder(); } + public static Builder of(String key, Builder value) { + return of(key, value.build()); + } + public static Builder of(String key, Object value) { return new Builder().put(key, value); } diff --git a/agents-flex-llm/agents-flex-llm-openai/src/main/java/com/agentsflex/llm/openai/OpenAiLlm.java b/agents-flex-llm/agents-flex-llm-openai/src/main/java/com/agentsflex/llm/openai/OpenAiLlm.java index 47d1941..558cb61 100644 --- a/agents-flex-llm/agents-flex-llm-openai/src/main/java/com/agentsflex/llm/openai/OpenAiLlm.java +++ b/agents-flex-llm/agents-flex-llm-openai/src/main/java/com/agentsflex/llm/openai/OpenAiLlm.java @@ -80,7 +80,7 @@ public class OpenAiLlm extends BaseLlm { String payload = OpenAiLLmUtil.promptToPayload(prompt, config); LlmClientListener clientListener = new BaseLlmClientListener(this, llmClient, listener, prompt, aiMessageParser, functionMessageParser); - llmClient.start("https://api.openai.com/v1/chat/completions", headers, payload, clientListener); + llmClient.start("https://api.openai.com/v1/chat/completions", headers, payload, clientListener,config); } diff --git a/agents-flex-llm/agents-flex-llm-qwen/src/main/java/com/agentsflex/llm/qwen/QwenLlm.java b/agents-flex-llm/agents-flex-llm-qwen/src/main/java/com/agentsflex/llm/qwen/QwenLlm.java index 7d49a10..7ba17ec 100644 --- a/agents-flex-llm/agents-flex-llm-qwen/src/main/java/com/agentsflex/llm/qwen/QwenLlm.java +++ b/agents-flex-llm/agents-flex-llm-qwen/src/main/java/com/agentsflex/llm/qwen/QwenLlm.java @@ -94,7 +94,7 @@ public class QwenLlm extends BaseLlm { return aiMessage; } }, functionMessageParser); - llmClient.start("https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation", headers, payload, clientListener); + llmClient.start("https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation", headers, payload, clientListener,config); } @Override diff --git a/agents-flex-llm/agents-flex-llm-spark/src/main/java/com/agentsflex/llm/spark/SparkLlm.java b/agents-flex-llm/agents-flex-llm-spark/src/main/java/com/agentsflex/llm/spark/SparkLlm.java index 5f1158c..b187515 100644 --- a/agents-flex-llm/agents-flex-llm-spark/src/main/java/com/agentsflex/llm/spark/SparkLlm.java +++ b/agents-flex-llm/agents-flex-llm-spark/src/main/java/com/agentsflex/llm/spark/SparkLlm.java @@ -43,7 +43,6 @@ public class SparkLlm extends BaseLlm { public FunctionMessageParser functionMessageParser = SparkLlmUtil.getFunctionMessageParser(); - public SparkLlm(SparkLlmConfig config) { super(config); } @@ -54,7 +53,6 @@ public class SparkLlm extends BaseLlm { } - @SuppressWarnings("unchecked") @Override public , M extends Message> R chat(Prompt prompt) { @@ -67,7 +65,7 @@ public class SparkLlm extends BaseLlm { if (messages[0] == null) { messages[0] = response.getMessage(); } else { - ((AiMessage)messages[0]).setContent(((AiMessage) response.getMessage()).getFullContent()); + ((AiMessage) messages[0]).setContent(((AiMessage) response.getMessage()).getFullContent()); } } else if (response.getMessage() instanceof FunctionMessage) { @@ -95,7 +93,6 @@ public class SparkLlm extends BaseLlm { } - @Override public , M extends Message> void chatAsync(Prompt prompt, MessageListener listener) { LlmClient llmClient = new WebSocketClient(); @@ -104,7 +101,7 @@ public class SparkLlm extends BaseLlm { String payload = SparkLlmUtil.promptToPayload(prompt, config); LlmClientListener clientListener = new BaseLlmClientListener(this, llmClient, listener, prompt, aiMessageParser, functionMessageParser); - llmClient.start(url, null, payload, clientListener); + llmClient.start(url, null, payload, clientListener, config); } diff --git a/agents-flex-llm/agents-flex-llm-spark/src/main/java/com/agentsflex/llm/spark/SparkLlmUtil.java b/agents-flex-llm/agents-flex-llm-spark/src/main/java/com/agentsflex/llm/spark/SparkLlmUtil.java index 76a558b..098628e 100644 --- a/agents-flex-llm/agents-flex-llm-spark/src/main/java/com/agentsflex/llm/spark/SparkLlmUtil.java +++ b/agents-flex-llm/agents-flex-llm-spark/src/main/java/com/agentsflex/llm/spark/SparkLlmUtil.java @@ -15,13 +15,14 @@ */ package com.agentsflex.llm.spark; +import com.agentsflex.functions.Function; +import com.agentsflex.functions.Parameter; import com.agentsflex.message.MessageStatus; import com.agentsflex.parser.AiMessageParser; import com.agentsflex.parser.FunctionMessageParser; import com.agentsflex.parser.impl.BaseAiMessageParser; import com.agentsflex.parser.impl.BaseFunctionMessageParser; import com.agentsflex.prompt.DefaultPromptFormat; -import com.agentsflex.prompt.FunctionPrompt; import com.agentsflex.prompt.Prompt; import com.agentsflex.prompt.PromptFormat; import com.agentsflex.util.HashUtil; @@ -31,14 +32,34 @@ import com.alibaba.fastjson.JSON; import java.io.UnsupportedEncodingException; import java.net.URLEncoder; import java.text.SimpleDateFormat; -import java.util.Base64; -import java.util.Date; -import java.util.Locale; -import java.util.UUID; +import java.util.*; public class SparkLlmUtil { - private static final PromptFormat promptFormat = new DefaultPromptFormat(); + private static final PromptFormat promptFormat = new DefaultPromptFormat() { + @Override + protected void buildFunctionJsonArray(List> functionsJsonArray, List> functions) { + for (Function function : functions) { + Map propertiesMap = new HashMap<>(); + List requiredProperties = new ArrayList<>(); + + Parameter[] parameters = function.getParameters(); + if (parameters != null) { + for (Parameter parameter : parameters) { + if (parameter.isRequired()) { + requiredProperties.add(parameter.getName()); + } + propertiesMap.put(parameter.getName(), Maps.of("type", parameter.getType()).put("description", parameter.getDescription()).build()); + } + } + + Maps.Builder builder = Maps.of("name", function.getName()).put("description", function.getDescription()) + .put("parameters", Maps.of("type", "object").put("properties", propertiesMap).put("required", requiredProperties)); + functionsJsonArray.add(builder.build()); + } + } + }; + public static AiMessageParser getAiMessageParser() { BaseAiMessageParser aiMessageParser = new BaseAiMessageParser(); @@ -62,14 +83,17 @@ public class SparkLlmUtil { public static String promptToPayload(Prompt prompt, SparkLlmConfig config) { // https://www.xfyun.cn/doc/spark/Web.html#_1-%E6%8E%A5%E5%8F%A3%E8%AF%B4%E6%98%8E Maps.Builder root = Maps.of("header", Maps.of("app_id", config.getAppId()).put("uid", UUID.randomUUID())); - root.put("parameter", Maps.of("chat", Maps.of("domain", "generalv3").put("temperature", 0.5).put("max_tokens", 1024))); + root.put("parameter", Maps.of("chat", Maps.of("domain", getDomain(config.getVersion())).put("temperature", 0.5).put("max_tokens", 1024))); root.put("payload", Maps.of("message", Maps.of("text", promptFormat.toMessagesJsonKey(prompt))) - .putIfNotEmpty("functions", Maps.ofNotNull("text", promptFormat.toFunctionsJsonKey((FunctionPrompt) prompt))) + .putIfNotEmpty("functions", Maps.ofNotNull("text", promptFormat.toFunctionsJsonKey(prompt))) ); return JSON.toJSONString(root.build()); } public static MessageStatus parseMessageStatus(Integer status) { + if (status == null) { + return MessageStatus.UNKNOW; + } switch (status) { case 0: return MessageStatus.START; @@ -106,4 +130,18 @@ public class SparkLlmUtil { throw new RuntimeException(e); } } + + + private static String getDomain(String version) { + switch (version) { + case "v3.5": + return "generalv3.5"; + case "v3.1": + return "generalv3"; + case "v2.1": + return "generalv2"; + default: + return "general"; + } + } } diff --git a/agents-flex-llm/agents-flex-llm-spark/src/test/java/com/agentsflex/llm/spark/test/SparkLlmTest.java b/agents-flex-llm/agents-flex-llm-spark/src/test/java/com/agentsflex/llm/spark/test/SparkLlmTest.java index 61f2b87..da3c9fa 100644 --- a/agents-flex-llm/agents-flex-llm-spark/src/test/java/com/agentsflex/llm/spark/test/SparkLlmTest.java +++ b/agents-flex-llm/agents-flex-llm-spark/src/test/java/com/agentsflex/llm/spark/test/SparkLlmTest.java @@ -21,7 +21,7 @@ public class SparkLlmTest { config.setApiSecret("****"); Llm llm = new SparkLlm(config); - String result = llm.chat("你好"); + String result = llm.chat("你好,请问你是谁?"); System.out.println(result); } @@ -32,6 +32,10 @@ public class SparkLlmTest { config.setAppId("****"); config.setApiKey("****"); config.setApiSecret("****"); + config.setDebug(true); + + //只有 v3.5 版本支持 function calling + config.setVersion("v3.5"); Llm llm = new SparkLlm(config); diff --git a/agents-flex-llm/agents-flex-llm-spark/src/test/java/com/agentsflex/llm/spark/test/WeatherFunctions.java b/agents-flex-llm/agents-flex-llm-spark/src/test/java/com/agentsflex/llm/spark/test/WeatherFunctions.java index a3ebe39..3d7ffd0 100644 --- a/agents-flex-llm/agents-flex-llm-spark/src/test/java/com/agentsflex/llm/spark/test/WeatherFunctions.java +++ b/agents-flex-llm/agents-flex-llm-spark/src/test/java/com/agentsflex/llm/spark/test/WeatherFunctions.java @@ -5,9 +5,9 @@ import com.agentsflex.functions.annotation.FunctionParam; public class WeatherFunctions { - @FunctionDef(name = "weather", description = "获取天气信息") - public static String weather( - @FunctionParam(name = "location", description = "城市名称,比如: 北京, 上海") String name + @FunctionDef(name = "天气查询", description = "获取天气信息的方法") + public static String getWeatherInfo( + @FunctionParam(name = "location", description = "城市的名称,比如: 北京, 上海") String name ) { //此处应该通过 api 去第三方获取 return name + "今天的天气阴转多云"; diff --git a/pom.xml b/pom.xml index e6c8cd2..f04b21f 100644 --- a/pom.xml +++ b/pom.xml @@ -57,7 +57,7 @@ 1.7.29 4.13.2 4.9.3 - 2.0.45 + 2.0.47