mirror of
https://gitee.com/agents-flex/agents-flex.git
synced 2024-11-30 02:47:46 +08:00
feat: 完善 星火大模型 的 Function Calling 测试
This commit is contained in:
parent
39dc6e8f5b
commit
625cfd3dcf
@ -18,4 +18,14 @@ package com.agentsflex.llm;
|
||||
import java.io.Serializable;
|
||||
|
||||
public class LlmConfig implements Serializable {
|
||||
|
||||
private boolean debug;
|
||||
|
||||
public boolean isDebug() {
|
||||
return debug;
|
||||
}
|
||||
|
||||
public void setDebug(boolean debug) {
|
||||
this.debug = debug;
|
||||
}
|
||||
}
|
||||
|
@ -15,11 +15,13 @@
|
||||
*/
|
||||
package com.agentsflex.llm.client;
|
||||
|
||||
import com.agentsflex.llm.LlmConfig;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
public interface LlmClient {
|
||||
|
||||
void start(String url, Map<String, String> headers, String payload, LlmClientListener listener);
|
||||
void start(String url, Map<String, String> headers, String payload, LlmClientListener listener, LlmConfig config);
|
||||
|
||||
void stop();
|
||||
}
|
||||
|
@ -15,6 +15,7 @@
|
||||
*/
|
||||
package com.agentsflex.llm.client.impl;
|
||||
|
||||
import com.agentsflex.llm.LlmConfig;
|
||||
import com.agentsflex.llm.client.LlmClient;
|
||||
import com.agentsflex.llm.client.LlmClientListener;
|
||||
import okhttp3.*;
|
||||
@ -28,11 +29,13 @@ public class AsyncHttpClient implements LlmClient {
|
||||
private static final MediaType JSON_TYPE = MediaType.parse("application/json; charset=utf-8");
|
||||
private OkHttpClient client;
|
||||
private LlmClientListener listener;
|
||||
private LlmConfig config;
|
||||
private boolean isStop = false;
|
||||
|
||||
@Override
|
||||
public void start(String url, Map<String, String> headers, String payload, LlmClientListener listener) {
|
||||
public void start(String url, Map<String, String> headers, String payload, LlmClientListener listener, LlmConfig config) {
|
||||
this.listener = listener;
|
||||
this.config = config;
|
||||
this.isStop = false;
|
||||
|
||||
Request.Builder rBuilder = new Request.Builder()
|
||||
@ -50,6 +53,9 @@ public class AsyncHttpClient implements LlmClient {
|
||||
.readTimeout(3, TimeUnit.MINUTES)
|
||||
.build();
|
||||
|
||||
if (this.config.isDebug()){
|
||||
System.out.println(">>>>send payload:" + payload);
|
||||
}
|
||||
|
||||
this.listener.onStart(this);
|
||||
this.client.newCall(rBuilder.build()).enqueue(new Callback() {
|
||||
@ -60,6 +66,9 @@ public class AsyncHttpClient implements LlmClient {
|
||||
|
||||
@Override
|
||||
public void onResponse(@NotNull Call call, @NotNull Response response) throws IOException {
|
||||
if (config.isDebug()){
|
||||
System.out.println(">>>>receive payload:" + response.message());
|
||||
}
|
||||
AsyncHttpClient.this.listener.onMessage(AsyncHttpClient.this, response.message());
|
||||
if (!isStop) {
|
||||
AsyncHttpClient.this.isStop = true;
|
||||
|
@ -15,6 +15,7 @@
|
||||
*/
|
||||
package com.agentsflex.llm.client.impl;
|
||||
|
||||
import com.agentsflex.llm.LlmConfig;
|
||||
import com.agentsflex.llm.client.LlmClient;
|
||||
import com.agentsflex.llm.client.LlmClientListener;
|
||||
import okhttp3.*;
|
||||
@ -32,11 +33,13 @@ public class SseClient extends EventSourceListener implements LlmClient {
|
||||
private OkHttpClient client;
|
||||
private EventSource eventSource;
|
||||
private LlmClientListener listener;
|
||||
private LlmConfig config;
|
||||
private boolean isStop = false;
|
||||
|
||||
@Override
|
||||
public void start(String url, Map<String, String> headers, String payload, LlmClientListener listener) {
|
||||
public void start(String url, Map<String, String> headers, String payload, LlmClientListener listener, LlmConfig config) {
|
||||
this.listener = listener;
|
||||
this.config = config;
|
||||
this.isStop = false;
|
||||
|
||||
Request.Builder builder = new Request.Builder()
|
||||
@ -60,6 +63,10 @@ public class SseClient extends EventSourceListener implements LlmClient {
|
||||
EventSource.Factory factory = EventSources.createFactory(this.client);
|
||||
this.eventSource = factory.newEventSource(request, this);
|
||||
|
||||
if (this.config.isDebug()){
|
||||
System.out.println(">>>>send payload:" + payload);
|
||||
}
|
||||
|
||||
this.listener.onStart(this);
|
||||
}
|
||||
|
||||
@ -84,6 +91,9 @@ public class SseClient extends EventSourceListener implements LlmClient {
|
||||
|
||||
@Override
|
||||
public void onEvent(@NotNull EventSource eventSource, @Nullable String id, @Nullable String type, @NotNull String data) {
|
||||
if (this.config.isDebug()){
|
||||
System.out.println(">>>>receive payload:" + data);
|
||||
}
|
||||
this.listener.onMessage(this, data);
|
||||
}
|
||||
|
||||
|
@ -15,6 +15,7 @@
|
||||
*/
|
||||
package com.agentsflex.llm.client.impl;
|
||||
|
||||
import com.agentsflex.llm.LlmConfig;
|
||||
import com.agentsflex.llm.client.LlmClient;
|
||||
import com.agentsflex.llm.client.LlmClientListener;
|
||||
import okhttp3.*;
|
||||
@ -28,13 +29,15 @@ public class WebSocketClient extends WebSocketListener implements LlmClient {
|
||||
|
||||
private WebSocket webSocket;
|
||||
private LlmClientListener listener;
|
||||
private LlmConfig config;
|
||||
private boolean isStop = false;
|
||||
private String payload;
|
||||
|
||||
@Override
|
||||
public void start(String url, Map<String, String> headers, String payload, LlmClientListener listener) {
|
||||
public void start(String url, Map<String, String> headers, String payload, LlmClientListener listener, LlmConfig config) {
|
||||
this.listener = listener;
|
||||
this.payload = payload;
|
||||
this.config = config;
|
||||
|
||||
OkHttpClient client = new OkHttpClient.Builder()
|
||||
.readTimeout(0, TimeUnit.MILLISECONDS)
|
||||
@ -46,6 +49,10 @@ public class WebSocketClient extends WebSocketListener implements LlmClient {
|
||||
|
||||
this.webSocket = client.newWebSocket(request, this);
|
||||
this.isStop = false;
|
||||
|
||||
if (this.config.isDebug()){
|
||||
System.out.println(">>>>send payload:" + payload);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -69,6 +76,9 @@ public class WebSocketClient extends WebSocketListener implements LlmClient {
|
||||
|
||||
@Override
|
||||
public void onMessage(WebSocket webSocket, String text) {
|
||||
if (this.config.isDebug()){
|
||||
System.out.println(">>>>receive payload:" + text);
|
||||
}
|
||||
this.listener.onMessage(this, text);
|
||||
}
|
||||
|
||||
|
@ -19,6 +19,7 @@ import com.agentsflex.message.AiMessage;
|
||||
import com.agentsflex.message.MessageStatus;
|
||||
import com.agentsflex.parser.AiMessageParser;
|
||||
import com.agentsflex.parser.Parser;
|
||||
import com.agentsflex.util.StringUtil;
|
||||
import com.alibaba.fastjson.JSON;
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.alibaba.fastjson.JSONPath;
|
||||
@ -76,13 +77,24 @@ public class BaseAiMessageParser implements AiMessageParser {
|
||||
public AiMessage parse(String content) {
|
||||
AiMessage aiMessage = new AiMessage();
|
||||
JSONObject rootJson = JSON.parseObject(content);
|
||||
aiMessage.setContent((String) JSONPath.eval(rootJson, this.contentPath));
|
||||
aiMessage.setIndex((Integer) JSONPath.eval(rootJson, this.indexPath));
|
||||
aiMessage.setTotalTokens((Integer) JSONPath.eval(rootJson, this.totalTokensPath));
|
||||
|
||||
String statusString = (String) JSONPath.eval(rootJson, this.statusPath);
|
||||
if (this.statusParser != null) {
|
||||
aiMessage.setStatus(this.statusParser.parse(statusString));
|
||||
if (StringUtil.hasText(this.contentPath)) {
|
||||
aiMessage.setContent((String) JSONPath.eval(rootJson, this.contentPath));
|
||||
}
|
||||
|
||||
if (StringUtil.hasText(this.indexPath)) {
|
||||
aiMessage.setIndex((Integer) JSONPath.eval(rootJson, this.indexPath));
|
||||
}
|
||||
|
||||
if (StringUtil.hasText(this.totalTokensPath)) {
|
||||
aiMessage.setTotalTokens((Integer) JSONPath.eval(rootJson, this.totalTokensPath));
|
||||
}
|
||||
|
||||
if (StringUtil.hasText(this.statusPath)) {
|
||||
Object statusString = JSONPath.eval(rootJson, this.statusPath);
|
||||
if (this.statusParser != null) {
|
||||
aiMessage.setStatus(this.statusParser.parse(statusString));
|
||||
}
|
||||
}
|
||||
|
||||
return aiMessage;
|
||||
|
@ -40,7 +40,13 @@ public class DefaultPromptFormat implements PromptFormat {
|
||||
return null;
|
||||
}
|
||||
|
||||
List<Map<String, String>> messageArray = new ArrayList<>(messages.size());
|
||||
List<Map<String, String>> messageJsonArray = new ArrayList<>(messages.size());
|
||||
buildMessageJsonArray(messageJsonArray, messages);
|
||||
|
||||
return messageJsonArray;
|
||||
}
|
||||
|
||||
protected void buildMessageJsonArray(List<Map<String, String>> messageJsonArray, List<Message> messages) {
|
||||
messages.forEach(message -> {
|
||||
Map<String, String> map = new HashMap<>(2);
|
||||
if (message instanceof HumanMessage) {
|
||||
@ -53,10 +59,8 @@ public class DefaultPromptFormat implements PromptFormat {
|
||||
map.put("role", "system");
|
||||
map.put("content", ((SystemMessage) message).getContent());
|
||||
}
|
||||
messageArray.add(map);
|
||||
messageJsonArray.add(map);
|
||||
});
|
||||
|
||||
return messageArray;
|
||||
}
|
||||
|
||||
|
||||
@ -66,12 +70,18 @@ public class DefaultPromptFormat implements PromptFormat {
|
||||
return null;
|
||||
}
|
||||
|
||||
List<Function<?>> functions = ((FunctionPrompt)prompt).getFunctions();
|
||||
List<Function<?>> functions = ((FunctionPrompt) prompt).getFunctions();
|
||||
if (functions == null || functions.isEmpty()) {
|
||||
return null;
|
||||
}
|
||||
|
||||
List<Map<String, Object>> functionsJsonArray = new ArrayList<>();
|
||||
buildFunctionJsonArray(functionsJsonArray, functions);
|
||||
|
||||
return functionsJsonArray;
|
||||
}
|
||||
|
||||
protected void buildFunctionJsonArray(List<Map<String, Object>> functionsJsonArray, List<Function<?>> functions) {
|
||||
for (Function<?> function : functions) {
|
||||
Map<String, Object> functionRoot = new HashMap<>();
|
||||
functionRoot.put("type", "function");
|
||||
@ -100,8 +110,6 @@ public class DefaultPromptFormat implements PromptFormat {
|
||||
}
|
||||
functionsJsonArray.add(functionRoot);
|
||||
}
|
||||
|
||||
return functionsJsonArray;
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -26,6 +26,10 @@ public class Maps {
|
||||
return new Builder();
|
||||
}
|
||||
|
||||
public static Builder of(String key, Builder value) {
|
||||
return of(key, value.build());
|
||||
}
|
||||
|
||||
public static Builder of(String key, Object value) {
|
||||
return new Builder().put(key, value);
|
||||
}
|
||||
|
@ -80,7 +80,7 @@ public class OpenAiLlm extends BaseLlm<OpenAiLlmConfig> {
|
||||
String payload = OpenAiLLmUtil.promptToPayload(prompt, config);
|
||||
|
||||
LlmClientListener clientListener = new BaseLlmClientListener(this, llmClient, listener, prompt, aiMessageParser, functionMessageParser);
|
||||
llmClient.start("https://api.openai.com/v1/chat/completions", headers, payload, clientListener);
|
||||
llmClient.start("https://api.openai.com/v1/chat/completions", headers, payload, clientListener,config);
|
||||
}
|
||||
|
||||
|
||||
|
@ -94,7 +94,7 @@ public class QwenLlm extends BaseLlm<QwenLlmConfig> {
|
||||
return aiMessage;
|
||||
}
|
||||
}, functionMessageParser);
|
||||
llmClient.start("https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation", headers, payload, clientListener);
|
||||
llmClient.start("https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation", headers, payload, clientListener,config);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -43,7 +43,6 @@ public class SparkLlm extends BaseLlm<SparkLlmConfig> {
|
||||
public FunctionMessageParser functionMessageParser = SparkLlmUtil.getFunctionMessageParser();
|
||||
|
||||
|
||||
|
||||
public SparkLlm(SparkLlmConfig config) {
|
||||
super(config);
|
||||
}
|
||||
@ -54,7 +53,6 @@ public class SparkLlm extends BaseLlm<SparkLlmConfig> {
|
||||
}
|
||||
|
||||
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
@Override
|
||||
public <R extends MessageResponse<M>, M extends Message> R chat(Prompt<M> prompt) {
|
||||
@ -67,7 +65,7 @@ public class SparkLlm extends BaseLlm<SparkLlmConfig> {
|
||||
if (messages[0] == null) {
|
||||
messages[0] = response.getMessage();
|
||||
} else {
|
||||
((AiMessage)messages[0]).setContent(((AiMessage) response.getMessage()).getFullContent());
|
||||
((AiMessage) messages[0]).setContent(((AiMessage) response.getMessage()).getFullContent());
|
||||
}
|
||||
|
||||
} else if (response.getMessage() instanceof FunctionMessage) {
|
||||
@ -95,7 +93,6 @@ public class SparkLlm extends BaseLlm<SparkLlmConfig> {
|
||||
}
|
||||
|
||||
|
||||
|
||||
@Override
|
||||
public <R extends MessageResponse<M>, M extends Message> void chatAsync(Prompt<M> prompt, MessageListener<R, M> listener) {
|
||||
LlmClient llmClient = new WebSocketClient();
|
||||
@ -104,7 +101,7 @@ public class SparkLlm extends BaseLlm<SparkLlmConfig> {
|
||||
String payload = SparkLlmUtil.promptToPayload(prompt, config);
|
||||
|
||||
LlmClientListener clientListener = new BaseLlmClientListener(this, llmClient, listener, prompt, aiMessageParser, functionMessageParser);
|
||||
llmClient.start(url, null, payload, clientListener);
|
||||
llmClient.start(url, null, payload, clientListener, config);
|
||||
}
|
||||
|
||||
|
||||
|
@ -15,13 +15,14 @@
|
||||
*/
|
||||
package com.agentsflex.llm.spark;
|
||||
|
||||
import com.agentsflex.functions.Function;
|
||||
import com.agentsflex.functions.Parameter;
|
||||
import com.agentsflex.message.MessageStatus;
|
||||
import com.agentsflex.parser.AiMessageParser;
|
||||
import com.agentsflex.parser.FunctionMessageParser;
|
||||
import com.agentsflex.parser.impl.BaseAiMessageParser;
|
||||
import com.agentsflex.parser.impl.BaseFunctionMessageParser;
|
||||
import com.agentsflex.prompt.DefaultPromptFormat;
|
||||
import com.agentsflex.prompt.FunctionPrompt;
|
||||
import com.agentsflex.prompt.Prompt;
|
||||
import com.agentsflex.prompt.PromptFormat;
|
||||
import com.agentsflex.util.HashUtil;
|
||||
@ -31,14 +32,34 @@ import com.alibaba.fastjson.JSON;
|
||||
import java.io.UnsupportedEncodingException;
|
||||
import java.net.URLEncoder;
|
||||
import java.text.SimpleDateFormat;
|
||||
import java.util.Base64;
|
||||
import java.util.Date;
|
||||
import java.util.Locale;
|
||||
import java.util.UUID;
|
||||
import java.util.*;
|
||||
|
||||
public class SparkLlmUtil {
|
||||
|
||||
private static final PromptFormat promptFormat = new DefaultPromptFormat();
|
||||
private static final PromptFormat promptFormat = new DefaultPromptFormat() {
|
||||
@Override
|
||||
protected void buildFunctionJsonArray(List<Map<String, Object>> functionsJsonArray, List<Function<?>> functions) {
|
||||
for (Function<?> function : functions) {
|
||||
Map<String, Object> propertiesMap = new HashMap<>();
|
||||
List<String> requiredProperties = new ArrayList<>();
|
||||
|
||||
Parameter[] parameters = function.getParameters();
|
||||
if (parameters != null) {
|
||||
for (Parameter parameter : parameters) {
|
||||
if (parameter.isRequired()) {
|
||||
requiredProperties.add(parameter.getName());
|
||||
}
|
||||
propertiesMap.put(parameter.getName(), Maps.of("type", parameter.getType()).put("description", parameter.getDescription()).build());
|
||||
}
|
||||
}
|
||||
|
||||
Maps.Builder builder = Maps.of("name", function.getName()).put("description", function.getDescription())
|
||||
.put("parameters", Maps.of("type", "object").put("properties", propertiesMap).put("required", requiredProperties));
|
||||
functionsJsonArray.add(builder.build());
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
public static AiMessageParser getAiMessageParser() {
|
||||
BaseAiMessageParser aiMessageParser = new BaseAiMessageParser();
|
||||
@ -62,14 +83,17 @@ public class SparkLlmUtil {
|
||||
public static String promptToPayload(Prompt prompt, SparkLlmConfig config) {
|
||||
// https://www.xfyun.cn/doc/spark/Web.html#_1-%E6%8E%A5%E5%8F%A3%E8%AF%B4%E6%98%8E
|
||||
Maps.Builder root = Maps.of("header", Maps.of("app_id", config.getAppId()).put("uid", UUID.randomUUID()));
|
||||
root.put("parameter", Maps.of("chat", Maps.of("domain", "generalv3").put("temperature", 0.5).put("max_tokens", 1024)));
|
||||
root.put("parameter", Maps.of("chat", Maps.of("domain", getDomain(config.getVersion())).put("temperature", 0.5).put("max_tokens", 1024)));
|
||||
root.put("payload", Maps.of("message", Maps.of("text", promptFormat.toMessagesJsonKey(prompt)))
|
||||
.putIfNotEmpty("functions", Maps.ofNotNull("text", promptFormat.toFunctionsJsonKey((FunctionPrompt) prompt)))
|
||||
.putIfNotEmpty("functions", Maps.ofNotNull("text", promptFormat.toFunctionsJsonKey(prompt)))
|
||||
);
|
||||
return JSON.toJSONString(root.build());
|
||||
}
|
||||
|
||||
public static MessageStatus parseMessageStatus(Integer status) {
|
||||
if (status == null) {
|
||||
return MessageStatus.UNKNOW;
|
||||
}
|
||||
switch (status) {
|
||||
case 0:
|
||||
return MessageStatus.START;
|
||||
@ -106,4 +130,18 @@ public class SparkLlmUtil {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
private static String getDomain(String version) {
|
||||
switch (version) {
|
||||
case "v3.5":
|
||||
return "generalv3.5";
|
||||
case "v3.1":
|
||||
return "generalv3";
|
||||
case "v2.1":
|
||||
return "generalv2";
|
||||
default:
|
||||
return "general";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -21,7 +21,7 @@ public class SparkLlmTest {
|
||||
config.setApiSecret("****");
|
||||
|
||||
Llm llm = new SparkLlm(config);
|
||||
String result = llm.chat("你好");
|
||||
String result = llm.chat("你好,请问你是谁?");
|
||||
System.out.println(result);
|
||||
}
|
||||
|
||||
@ -32,6 +32,10 @@ public class SparkLlmTest {
|
||||
config.setAppId("****");
|
||||
config.setApiKey("****");
|
||||
config.setApiSecret("****");
|
||||
config.setDebug(true);
|
||||
|
||||
//只有 v3.5 版本支持 function calling
|
||||
config.setVersion("v3.5");
|
||||
|
||||
Llm llm = new SparkLlm(config);
|
||||
|
||||
|
@ -5,9 +5,9 @@ import com.agentsflex.functions.annotation.FunctionParam;
|
||||
|
||||
public class WeatherFunctions {
|
||||
|
||||
@FunctionDef(name = "weather", description = "获取天气信息")
|
||||
public static String weather(
|
||||
@FunctionParam(name = "location", description = "城市名称,比如: 北京, 上海") String name
|
||||
@FunctionDef(name = "天气查询", description = "获取天气信息的方法")
|
||||
public static String getWeatherInfo(
|
||||
@FunctionParam(name = "location", description = "城市的名称,比如: 北京, 上海") String name
|
||||
) {
|
||||
//此处应该通过 api 去第三方获取
|
||||
return name + "今天的天气阴转多云";
|
||||
|
Loading…
Reference in New Issue
Block a user