feat: 完善 星火大模型 的 Function Calling 测试

This commit is contained in:
Michael Yang 2024-03-12 20:12:13 +08:00
parent 39dc6e8f5b
commit 625cfd3dcf
15 changed files with 141 additions and 37 deletions

View File

@ -18,4 +18,14 @@ package com.agentsflex.llm;
import java.io.Serializable;
public class LlmConfig implements Serializable {
private boolean debug;
public boolean isDebug() {
return debug;
}
public void setDebug(boolean debug) {
this.debug = debug;
}
}

View File

@ -15,11 +15,13 @@
*/
package com.agentsflex.llm.client;
import com.agentsflex.llm.LlmConfig;
import java.util.Map;
public interface LlmClient {
void start(String url, Map<String, String> headers, String payload, LlmClientListener listener);
void start(String url, Map<String, String> headers, String payload, LlmClientListener listener, LlmConfig config);
void stop();
}

View File

@ -15,6 +15,7 @@
*/
package com.agentsflex.llm.client.impl;
import com.agentsflex.llm.LlmConfig;
import com.agentsflex.llm.client.LlmClient;
import com.agentsflex.llm.client.LlmClientListener;
import okhttp3.*;
@ -28,11 +29,13 @@ public class AsyncHttpClient implements LlmClient {
private static final MediaType JSON_TYPE = MediaType.parse("application/json; charset=utf-8");
private OkHttpClient client;
private LlmClientListener listener;
private LlmConfig config;
private boolean isStop = false;
@Override
public void start(String url, Map<String, String> headers, String payload, LlmClientListener listener) {
public void start(String url, Map<String, String> headers, String payload, LlmClientListener listener, LlmConfig config) {
this.listener = listener;
this.config = config;
this.isStop = false;
Request.Builder rBuilder = new Request.Builder()
@ -50,6 +53,9 @@ public class AsyncHttpClient implements LlmClient {
.readTimeout(3, TimeUnit.MINUTES)
.build();
if (this.config.isDebug()){
System.out.println(">>>>send payload:" + payload);
}
this.listener.onStart(this);
this.client.newCall(rBuilder.build()).enqueue(new Callback() {
@ -60,6 +66,9 @@ public class AsyncHttpClient implements LlmClient {
@Override
public void onResponse(@NotNull Call call, @NotNull Response response) throws IOException {
if (config.isDebug()){
System.out.println(">>>>receive payload:" + response.message());
}
AsyncHttpClient.this.listener.onMessage(AsyncHttpClient.this, response.message());
if (!isStop) {
AsyncHttpClient.this.isStop = true;

View File

@ -15,6 +15,7 @@
*/
package com.agentsflex.llm.client.impl;
import com.agentsflex.llm.LlmConfig;
import com.agentsflex.llm.client.LlmClient;
import com.agentsflex.llm.client.LlmClientListener;
import okhttp3.*;
@ -32,11 +33,13 @@ public class SseClient extends EventSourceListener implements LlmClient {
private OkHttpClient client;
private EventSource eventSource;
private LlmClientListener listener;
private LlmConfig config;
private boolean isStop = false;
@Override
public void start(String url, Map<String, String> headers, String payload, LlmClientListener listener) {
public void start(String url, Map<String, String> headers, String payload, LlmClientListener listener, LlmConfig config) {
this.listener = listener;
this.config = config;
this.isStop = false;
Request.Builder builder = new Request.Builder()
@ -60,6 +63,10 @@ public class SseClient extends EventSourceListener implements LlmClient {
EventSource.Factory factory = EventSources.createFactory(this.client);
this.eventSource = factory.newEventSource(request, this);
if (this.config.isDebug()){
System.out.println(">>>>send payload:" + payload);
}
this.listener.onStart(this);
}
@ -84,6 +91,9 @@ public class SseClient extends EventSourceListener implements LlmClient {
@Override
public void onEvent(@NotNull EventSource eventSource, @Nullable String id, @Nullable String type, @NotNull String data) {
if (this.config.isDebug()){
System.out.println(">>>>receive payload:" + data);
}
this.listener.onMessage(this, data);
}

View File

@ -15,6 +15,7 @@
*/
package com.agentsflex.llm.client.impl;
import com.agentsflex.llm.LlmConfig;
import com.agentsflex.llm.client.LlmClient;
import com.agentsflex.llm.client.LlmClientListener;
import okhttp3.*;
@ -28,13 +29,15 @@ public class WebSocketClient extends WebSocketListener implements LlmClient {
private WebSocket webSocket;
private LlmClientListener listener;
private LlmConfig config;
private boolean isStop = false;
private String payload;
@Override
public void start(String url, Map<String, String> headers, String payload, LlmClientListener listener) {
public void start(String url, Map<String, String> headers, String payload, LlmClientListener listener, LlmConfig config) {
this.listener = listener;
this.payload = payload;
this.config = config;
OkHttpClient client = new OkHttpClient.Builder()
.readTimeout(0, TimeUnit.MILLISECONDS)
@ -46,6 +49,10 @@ public class WebSocketClient extends WebSocketListener implements LlmClient {
this.webSocket = client.newWebSocket(request, this);
this.isStop = false;
if (this.config.isDebug()){
System.out.println(">>>>send payload:" + payload);
}
}
@Override
@ -69,6 +76,9 @@ public class WebSocketClient extends WebSocketListener implements LlmClient {
@Override
public void onMessage(WebSocket webSocket, String text) {
if (this.config.isDebug()){
System.out.println(">>>>receive payload:" + text);
}
this.listener.onMessage(this, text);
}

View File

@ -19,6 +19,7 @@ import com.agentsflex.message.AiMessage;
import com.agentsflex.message.MessageStatus;
import com.agentsflex.parser.AiMessageParser;
import com.agentsflex.parser.Parser;
import com.agentsflex.util.StringUtil;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import com.alibaba.fastjson.JSONPath;
@ -76,13 +77,24 @@ public class BaseAiMessageParser implements AiMessageParser {
public AiMessage parse(String content) {
AiMessage aiMessage = new AiMessage();
JSONObject rootJson = JSON.parseObject(content);
aiMessage.setContent((String) JSONPath.eval(rootJson, this.contentPath));
aiMessage.setIndex((Integer) JSONPath.eval(rootJson, this.indexPath));
aiMessage.setTotalTokens((Integer) JSONPath.eval(rootJson, this.totalTokensPath));
String statusString = (String) JSONPath.eval(rootJson, this.statusPath);
if (this.statusParser != null) {
aiMessage.setStatus(this.statusParser.parse(statusString));
if (StringUtil.hasText(this.contentPath)) {
aiMessage.setContent((String) JSONPath.eval(rootJson, this.contentPath));
}
if (StringUtil.hasText(this.indexPath)) {
aiMessage.setIndex((Integer) JSONPath.eval(rootJson, this.indexPath));
}
if (StringUtil.hasText(this.totalTokensPath)) {
aiMessage.setTotalTokens((Integer) JSONPath.eval(rootJson, this.totalTokensPath));
}
if (StringUtil.hasText(this.statusPath)) {
Object statusString = JSONPath.eval(rootJson, this.statusPath);
if (this.statusParser != null) {
aiMessage.setStatus(this.statusParser.parse(statusString));
}
}
return aiMessage;

View File

@ -40,7 +40,13 @@ public class DefaultPromptFormat implements PromptFormat {
return null;
}
List<Map<String, String>> messageArray = new ArrayList<>(messages.size());
List<Map<String, String>> messageJsonArray = new ArrayList<>(messages.size());
buildMessageJsonArray(messageJsonArray, messages);
return messageJsonArray;
}
protected void buildMessageJsonArray(List<Map<String, String>> messageJsonArray, List<Message> messages) {
messages.forEach(message -> {
Map<String, String> map = new HashMap<>(2);
if (message instanceof HumanMessage) {
@ -53,10 +59,8 @@ public class DefaultPromptFormat implements PromptFormat {
map.put("role", "system");
map.put("content", ((SystemMessage) message).getContent());
}
messageArray.add(map);
messageJsonArray.add(map);
});
return messageArray;
}
@ -66,12 +70,18 @@ public class DefaultPromptFormat implements PromptFormat {
return null;
}
List<Function<?>> functions = ((FunctionPrompt)prompt).getFunctions();
List<Function<?>> functions = ((FunctionPrompt) prompt).getFunctions();
if (functions == null || functions.isEmpty()) {
return null;
}
List<Map<String, Object>> functionsJsonArray = new ArrayList<>();
buildFunctionJsonArray(functionsJsonArray, functions);
return functionsJsonArray;
}
protected void buildFunctionJsonArray(List<Map<String, Object>> functionsJsonArray, List<Function<?>> functions) {
for (Function<?> function : functions) {
Map<String, Object> functionRoot = new HashMap<>();
functionRoot.put("type", "function");
@ -100,8 +110,6 @@ public class DefaultPromptFormat implements PromptFormat {
}
functionsJsonArray.add(functionRoot);
}
return functionsJsonArray;
}
}

View File

@ -26,6 +26,10 @@ public class Maps {
return new Builder();
}
public static Builder of(String key, Builder value) {
return of(key, value.build());
}
public static Builder of(String key, Object value) {
return new Builder().put(key, value);
}

View File

@ -80,7 +80,7 @@ public class OpenAiLlm extends BaseLlm<OpenAiLlmConfig> {
String payload = OpenAiLLmUtil.promptToPayload(prompt, config);
LlmClientListener clientListener = new BaseLlmClientListener(this, llmClient, listener, prompt, aiMessageParser, functionMessageParser);
llmClient.start("https://api.openai.com/v1/chat/completions", headers, payload, clientListener);
llmClient.start("https://api.openai.com/v1/chat/completions", headers, payload, clientListener,config);
}

View File

@ -94,7 +94,7 @@ public class QwenLlm extends BaseLlm<QwenLlmConfig> {
return aiMessage;
}
}, functionMessageParser);
llmClient.start("https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation", headers, payload, clientListener);
llmClient.start("https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation", headers, payload, clientListener,config);
}
@Override

View File

@ -43,7 +43,6 @@ public class SparkLlm extends BaseLlm<SparkLlmConfig> {
public FunctionMessageParser functionMessageParser = SparkLlmUtil.getFunctionMessageParser();
public SparkLlm(SparkLlmConfig config) {
super(config);
}
@ -54,7 +53,6 @@ public class SparkLlm extends BaseLlm<SparkLlmConfig> {
}
@SuppressWarnings("unchecked")
@Override
public <R extends MessageResponse<M>, M extends Message> R chat(Prompt<M> prompt) {
@ -67,7 +65,7 @@ public class SparkLlm extends BaseLlm<SparkLlmConfig> {
if (messages[0] == null) {
messages[0] = response.getMessage();
} else {
((AiMessage)messages[0]).setContent(((AiMessage) response.getMessage()).getFullContent());
((AiMessage) messages[0]).setContent(((AiMessage) response.getMessage()).getFullContent());
}
} else if (response.getMessage() instanceof FunctionMessage) {
@ -95,7 +93,6 @@ public class SparkLlm extends BaseLlm<SparkLlmConfig> {
}
@Override
public <R extends MessageResponse<M>, M extends Message> void chatAsync(Prompt<M> prompt, MessageListener<R, M> listener) {
LlmClient llmClient = new WebSocketClient();
@ -104,7 +101,7 @@ public class SparkLlm extends BaseLlm<SparkLlmConfig> {
String payload = SparkLlmUtil.promptToPayload(prompt, config);
LlmClientListener clientListener = new BaseLlmClientListener(this, llmClient, listener, prompt, aiMessageParser, functionMessageParser);
llmClient.start(url, null, payload, clientListener);
llmClient.start(url, null, payload, clientListener, config);
}

View File

@ -15,13 +15,14 @@
*/
package com.agentsflex.llm.spark;
import com.agentsflex.functions.Function;
import com.agentsflex.functions.Parameter;
import com.agentsflex.message.MessageStatus;
import com.agentsflex.parser.AiMessageParser;
import com.agentsflex.parser.FunctionMessageParser;
import com.agentsflex.parser.impl.BaseAiMessageParser;
import com.agentsflex.parser.impl.BaseFunctionMessageParser;
import com.agentsflex.prompt.DefaultPromptFormat;
import com.agentsflex.prompt.FunctionPrompt;
import com.agentsflex.prompt.Prompt;
import com.agentsflex.prompt.PromptFormat;
import com.agentsflex.util.HashUtil;
@ -31,14 +32,34 @@ import com.alibaba.fastjson.JSON;
import java.io.UnsupportedEncodingException;
import java.net.URLEncoder;
import java.text.SimpleDateFormat;
import java.util.Base64;
import java.util.Date;
import java.util.Locale;
import java.util.UUID;
import java.util.*;
public class SparkLlmUtil {
private static final PromptFormat promptFormat = new DefaultPromptFormat();
private static final PromptFormat promptFormat = new DefaultPromptFormat() {
@Override
protected void buildFunctionJsonArray(List<Map<String, Object>> functionsJsonArray, List<Function<?>> functions) {
for (Function<?> function : functions) {
Map<String, Object> propertiesMap = new HashMap<>();
List<String> requiredProperties = new ArrayList<>();
Parameter[] parameters = function.getParameters();
if (parameters != null) {
for (Parameter parameter : parameters) {
if (parameter.isRequired()) {
requiredProperties.add(parameter.getName());
}
propertiesMap.put(parameter.getName(), Maps.of("type", parameter.getType()).put("description", parameter.getDescription()).build());
}
}
Maps.Builder builder = Maps.of("name", function.getName()).put("description", function.getDescription())
.put("parameters", Maps.of("type", "object").put("properties", propertiesMap).put("required", requiredProperties));
functionsJsonArray.add(builder.build());
}
}
};
public static AiMessageParser getAiMessageParser() {
BaseAiMessageParser aiMessageParser = new BaseAiMessageParser();
@ -62,14 +83,17 @@ public class SparkLlmUtil {
public static String promptToPayload(Prompt prompt, SparkLlmConfig config) {
// https://www.xfyun.cn/doc/spark/Web.html#_1-%E6%8E%A5%E5%8F%A3%E8%AF%B4%E6%98%8E
Maps.Builder root = Maps.of("header", Maps.of("app_id", config.getAppId()).put("uid", UUID.randomUUID()));
root.put("parameter", Maps.of("chat", Maps.of("domain", "generalv3").put("temperature", 0.5).put("max_tokens", 1024)));
root.put("parameter", Maps.of("chat", Maps.of("domain", getDomain(config.getVersion())).put("temperature", 0.5).put("max_tokens", 1024)));
root.put("payload", Maps.of("message", Maps.of("text", promptFormat.toMessagesJsonKey(prompt)))
.putIfNotEmpty("functions", Maps.ofNotNull("text", promptFormat.toFunctionsJsonKey((FunctionPrompt) prompt)))
.putIfNotEmpty("functions", Maps.ofNotNull("text", promptFormat.toFunctionsJsonKey(prompt)))
);
return JSON.toJSONString(root.build());
}
public static MessageStatus parseMessageStatus(Integer status) {
if (status == null) {
return MessageStatus.UNKNOW;
}
switch (status) {
case 0:
return MessageStatus.START;
@ -106,4 +130,18 @@ public class SparkLlmUtil {
throw new RuntimeException(e);
}
}
private static String getDomain(String version) {
switch (version) {
case "v3.5":
return "generalv3.5";
case "v3.1":
return "generalv3";
case "v2.1":
return "generalv2";
default:
return "general";
}
}
}

View File

@ -21,7 +21,7 @@ public class SparkLlmTest {
config.setApiSecret("****");
Llm llm = new SparkLlm(config);
String result = llm.chat("你好");
String result = llm.chat("你好,请问你是谁?");
System.out.println(result);
}
@ -32,6 +32,10 @@ public class SparkLlmTest {
config.setAppId("****");
config.setApiKey("****");
config.setApiSecret("****");
config.setDebug(true);
//只有 v3.5 版本支持 function calling
config.setVersion("v3.5");
Llm llm = new SparkLlm(config);

View File

@ -5,9 +5,9 @@ import com.agentsflex.functions.annotation.FunctionParam;
public class WeatherFunctions {
@FunctionDef(name = "weather", description = "获取天气信息")
public static String weather(
@FunctionParam(name = "location", description = "城市名称,比如: 北京, 上海") String name
@FunctionDef(name = "天气查询", description = "获取天气信息的方法")
public static String getWeatherInfo(
@FunctionParam(name = "location", description = "城市名称,比如: 北京, 上海") String name
) {
//此处应该通过 api 去第三方获取
return name + "今天的天气阴转多云";

View File

@ -57,7 +57,7 @@
<slf4j.version>1.7.29</slf4j.version>
<junit.version>4.13.2</junit.version>
<okhttp.version>4.9.3</okhttp.version>
<fastjson.version>2.0.45</fastjson.version>
<fastjson.version>2.0.47</fastjson.version>
</properties>