mirror of
https://gitee.com/agents-flex/agents-flex.git
synced 2024-11-30 02:47:46 +08:00
feat: add function calling for "spark" llm
This commit is contained in:
parent
4b4fb420c5
commit
4b6267f62d
@ -17,10 +17,8 @@ package com.agentsflex.message;
|
||||
|
||||
public class SystemMessage extends Message {
|
||||
|
||||
|
||||
public SystemMessage(String content) {
|
||||
this.content = content;
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
@ -29,18 +29,27 @@ import java.util.List;
|
||||
public class FunctionPrompt extends Prompt<FunctionMessage> {
|
||||
private final ChatMemory memory = new DefaultChatMemory();
|
||||
|
||||
private List<Function<?>> functions = new ArrayList<>();
|
||||
private final List<Function<?>> functions = new ArrayList<>();
|
||||
|
||||
public FunctionPrompt(String prompt, Class<?> funcClass) {
|
||||
memory.addMessage(new HumanMessage(prompt));
|
||||
functions.addAll(Functions.from(funcClass));
|
||||
}
|
||||
|
||||
public FunctionPrompt(String prompt, Class<?> funcClass, String... methodNames) {
|
||||
memory.addMessage(new HumanMessage(prompt));
|
||||
functions.addAll(Functions.from(funcClass, methodNames));
|
||||
}
|
||||
|
||||
public FunctionPrompt(List<Message> messages, Class<?> funcClass) {
|
||||
memory.addMessages(messages);
|
||||
functions.addAll(Functions.from(funcClass));
|
||||
}
|
||||
|
||||
public FunctionPrompt(List<Message> messages, Class<?> funcClass, String... methodNames) {
|
||||
memory.addMessages(messages);
|
||||
functions.addAll(Functions.from(funcClass, methodNames));
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Message> toMessages() {
|
||||
|
@ -51,7 +51,7 @@ public class OpenAiLLmUtil {
|
||||
|
||||
// https://platform.openai.com/docs/api-reference/making-requests
|
||||
String payload = "{\n" +
|
||||
" \"input\": \""+text.getContent()+"\",\n" +
|
||||
" \"input\": \"" + text.getContent() + "\",\n" +
|
||||
" \"model\": \"text-embedding-ada-002\",\n" +
|
||||
" \"encoding_format\": \"float\"\n" +
|
||||
"}";
|
||||
@ -60,7 +60,6 @@ public class OpenAiLLmUtil {
|
||||
}
|
||||
|
||||
|
||||
|
||||
public static String promptToPayload(Prompt prompt, OpenAiLlmConfig config) {
|
||||
|
||||
List<Message> messages = prompt.toMessages();
|
||||
@ -86,7 +85,7 @@ public class OpenAiLLmUtil {
|
||||
return "{\n" +
|
||||
// " \"model\": \"gpt-3.5-turbo\",\n" +
|
||||
" \"model\": \"" + config.getModel() + "\",\n" +
|
||||
" \"messages\": "+messageText+",\n" +
|
||||
" \"messages\": " + messageText + ",\n" +
|
||||
" \"temperature\": 0.7\n" +
|
||||
"}";
|
||||
}
|
||||
@ -113,32 +112,32 @@ public class OpenAiLLmUtil {
|
||||
String messageText = JSON.toJSONString(messageArray);
|
||||
|
||||
|
||||
List<Map<String,Object>> toolsArray = new ArrayList<>();
|
||||
List<Map<String, Object>> toolsArray = new ArrayList<>();
|
||||
for (Function<?> function : functions) {
|
||||
Map<String,Object> functionRoot = new HashMap<>();
|
||||
functionRoot.put("type","function");
|
||||
Map<String, Object> functionRoot = new HashMap<>();
|
||||
functionRoot.put("type", "function");
|
||||
|
||||
Map<String,Object> functionObj = new HashMap<>();
|
||||
functionRoot.put("function",functionObj);
|
||||
Map<String, Object> functionObj = new HashMap<>();
|
||||
functionRoot.put("function", functionObj);
|
||||
|
||||
functionObj.put("name",function.getName());
|
||||
functionObj.put("description",function.getDescription());
|
||||
functionObj.put("name", function.getName());
|
||||
functionObj.put("description", function.getDescription());
|
||||
|
||||
|
||||
Map<String,Object> parametersObj = new HashMap<>();
|
||||
functionObj.put("parameters",parametersObj);
|
||||
Map<String, Object> parametersObj = new HashMap<>();
|
||||
functionObj.put("parameters", parametersObj);
|
||||
|
||||
parametersObj.put("type","object");
|
||||
parametersObj.put("type", "object");
|
||||
|
||||
Map<String,Object> propertiesObj = new HashMap<>();
|
||||
parametersObj.put("properties",propertiesObj);
|
||||
Map<String, Object> propertiesObj = new HashMap<>();
|
||||
parametersObj.put("properties", propertiesObj);
|
||||
|
||||
for (Parameter parameter : function.getParameters()) {
|
||||
Map<String,Object> parameterObj = new HashMap<>();
|
||||
parameterObj.put("type",parameter.getType());
|
||||
parameterObj.put("description",parameter.getDescription());
|
||||
parameterObj.put("enum",parameter.getEnums());
|
||||
propertiesObj.put(parameter.getName(),parameterObj);
|
||||
Map<String, Object> parameterObj = new HashMap<>();
|
||||
parameterObj.put("type", parameter.getType());
|
||||
parameterObj.put("description", parameter.getDescription());
|
||||
parameterObj.put("enum", parameter.getEnums());
|
||||
propertiesObj.put(parameter.getName(), parameterObj);
|
||||
}
|
||||
|
||||
toolsArray.add(functionRoot);
|
||||
@ -149,12 +148,11 @@ public class OpenAiLLmUtil {
|
||||
return "{\n" +
|
||||
// " \"model\": \"gpt-3.5-turbo\",\n" +
|
||||
" \"model\": \"" + config.getModel() + "\",\n" +
|
||||
" \"messages\": "+messageText+",\n" +
|
||||
" \"tools\": "+toolsText+",\n" +
|
||||
" \"messages\": " + messageText + ",\n" +
|
||||
" \"tools\": " + toolsText + ",\n" +
|
||||
" \"tool_choice\": \"auto\"\n" +
|
||||
"}";
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
@ -16,6 +16,7 @@
|
||||
package com.agentsflex.llm.spark;
|
||||
|
||||
import com.agentsflex.document.Document;
|
||||
import com.agentsflex.functions.Function;
|
||||
import com.agentsflex.llm.BaseLlm;
|
||||
import com.agentsflex.llm.ChatContext;
|
||||
import com.agentsflex.llm.MessageListener;
|
||||
@ -25,11 +26,18 @@ import com.agentsflex.llm.client.LlmClient;
|
||||
import com.agentsflex.llm.client.LlmClientListener;
|
||||
import com.agentsflex.llm.client.impl.WebSocketClient;
|
||||
import com.agentsflex.llm.response.AiMessageResponse;
|
||||
import com.agentsflex.llm.response.FunctionMessageResponse;
|
||||
import com.agentsflex.message.AiMessage;
|
||||
import com.agentsflex.message.FunctionMessage;
|
||||
import com.agentsflex.message.Message;
|
||||
import com.agentsflex.prompt.FunctionPrompt;
|
||||
import com.agentsflex.prompt.Prompt;
|
||||
import com.agentsflex.store.VectorData;
|
||||
import com.alibaba.fastjson.JSON;
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.alibaba.fastjson.JSONPath;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.concurrent.CountDownLatch;
|
||||
|
||||
public class SparkLlm extends BaseLlm<SparkLlmConfig> {
|
||||
@ -47,11 +55,20 @@ public class SparkLlm extends BaseLlm<SparkLlmConfig> {
|
||||
@Override
|
||||
public <R extends MessageResponse<M>, M extends Message> R chat(Prompt<M> prompt) {
|
||||
CountDownLatch latch = new CountDownLatch(1);
|
||||
AiMessage aiMessage = new AiMessage();
|
||||
Message[] messages = new Message[1];
|
||||
chatAsync(prompt, new MessageListener<MessageResponse<M>, M>() {
|
||||
@Override
|
||||
public void onMessage(ChatContext context, MessageResponse<M> response) {
|
||||
aiMessage.setContent(((AiMessage) response.getMessage()).getFullContent());
|
||||
if (response.getMessage() instanceof AiMessage) {
|
||||
if (messages[0] == null) {
|
||||
messages[0] = response.getMessage();
|
||||
} else {
|
||||
messages[0].setContent(((AiMessage) response.getMessage()).getFullContent());
|
||||
}
|
||||
|
||||
} else if (response.getMessage() instanceof FunctionMessage) {
|
||||
messages[0] = response.getMessage();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -65,7 +82,12 @@ public class SparkLlm extends BaseLlm<SparkLlmConfig> {
|
||||
} catch (InterruptedException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
return (R) new AiMessageResponse(aiMessage);
|
||||
|
||||
if (prompt instanceof FunctionPrompt) {
|
||||
return (R) new FunctionMessageResponse(((FunctionPrompt) prompt).getFunctions(), (FunctionMessage) messages[0]);
|
||||
} else {
|
||||
return (R) new AiMessageResponse((AiMessage) messages[0]);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -76,7 +98,20 @@ public class SparkLlm extends BaseLlm<SparkLlmConfig> {
|
||||
|
||||
String payload = SparkLlmUtil.promptToPayload(prompt, config);
|
||||
|
||||
LlmClientListener clientListener = new BaseLlmClientListener(this, llmClient, listener, prompt, SparkLlmUtil::parseAiMessage, null);
|
||||
LlmClientListener clientListener = new BaseLlmClientListener(this, llmClient, listener, prompt, SparkLlmUtil::parseAiMessage, new BaseLlmClientListener.FunctionMessageParser() {
|
||||
@Override
|
||||
public FunctionMessage parseMessage(String response) {
|
||||
JSONObject jsonObject = JSON.parseObject(response);
|
||||
String callFunctionName = (String) JSONPath.eval(jsonObject, "$.payload.choices.text[0].function_call.name");
|
||||
String callFunctionArgsString = (String) JSONPath.eval(jsonObject, "$.payload.choices.text[0].function_call.arguments");
|
||||
JSONObject callFunctionArgs = JSON.parseObject(callFunctionArgsString);
|
||||
|
||||
FunctionMessage functionMessage = new FunctionMessage();
|
||||
functionMessage.setFunctionName(callFunctionName);
|
||||
functionMessage.setArgs(callFunctionArgs);
|
||||
return functionMessage;
|
||||
}
|
||||
});
|
||||
llmClient.start(url, null, payload, clientListener);
|
||||
}
|
||||
|
||||
|
@ -15,10 +15,10 @@
|
||||
*/
|
||||
package com.agentsflex.llm.spark;
|
||||
|
||||
import com.agentsflex.message.AiMessage;
|
||||
import com.agentsflex.message.HumanMessage;
|
||||
import com.agentsflex.message.Message;
|
||||
import com.agentsflex.message.MessageStatus;
|
||||
import com.agentsflex.functions.Function;
|
||||
import com.agentsflex.functions.Parameter;
|
||||
import com.agentsflex.message.*;
|
||||
import com.agentsflex.prompt.FunctionPrompt;
|
||||
import com.agentsflex.prompt.Prompt;
|
||||
import com.agentsflex.util.HashUtil;
|
||||
import com.alibaba.fastjson.JSON;
|
||||
@ -32,14 +32,14 @@ import java.util.*;
|
||||
|
||||
public class SparkLlmUtil {
|
||||
|
||||
public static AiMessage parseAiMessage(String json){
|
||||
public static AiMessage parseAiMessage(String json) {
|
||||
AiMessage aiMessage = new AiMessage();
|
||||
JSONObject jsonObject = JSON.parseObject(json);
|
||||
Object status = JSONPath.eval(jsonObject, "$.payload.choices.status");
|
||||
MessageStatus messageStatus = SparkLlmUtil.parseMessageStatus((Integer) status);
|
||||
aiMessage.setStatus(messageStatus);
|
||||
aiMessage.setIndex((Integer) JSONPath.eval(jsonObject,"$.payload.choices.text[0].index"));
|
||||
aiMessage.setContent((String) JSONPath.eval(jsonObject,"$.payload.choices.text[0].content"));
|
||||
aiMessage.setIndex((Integer) JSONPath.eval(jsonObject, "$.payload.choices.text[0].index"));
|
||||
aiMessage.setContent((String) JSONPath.eval(jsonObject, "$.payload.choices.text[0].content"));
|
||||
return aiMessage;
|
||||
}
|
||||
|
||||
@ -64,7 +64,8 @@ public class SparkLlmUtil {
|
||||
" \"payload\": {\n" +
|
||||
" \"message\": {\n" +
|
||||
" \"text\": messageJsonString" +
|
||||
" }\n" +
|
||||
" },\n" +
|
||||
" \"functions\":functionsJsonString" +
|
||||
" }\n" +
|
||||
"}";
|
||||
|
||||
@ -78,13 +79,60 @@ public class SparkLlmUtil {
|
||||
} else if (message instanceof AiMessage) {
|
||||
map.put("role", "assistant");
|
||||
map.put("content", ((AiMessage) message).getFullContent());
|
||||
} else if (message instanceof SystemMessage) {
|
||||
map.put("role", "system");
|
||||
map.put("content", message.getContent());
|
||||
}
|
||||
|
||||
messageArray.add(map);
|
||||
});
|
||||
|
||||
|
||||
|
||||
String functionsJsonString = "\"\"";
|
||||
if (prompt instanceof FunctionPrompt) {
|
||||
List<Function<?>> functions = ((FunctionPrompt) prompt).getFunctions();
|
||||
|
||||
List<Map<String, Object>> functionsArray = new ArrayList<>();
|
||||
for (Function<?> function : functions) {
|
||||
Map<String, Object> functionRoot = new HashMap<>();
|
||||
functionRoot.put("type", "function");
|
||||
|
||||
Map<String, Object> functionObj = new HashMap<>();
|
||||
functionRoot.put("function", functionObj);
|
||||
|
||||
functionObj.put("name", function.getName());
|
||||
functionObj.put("description", function.getDescription());
|
||||
|
||||
|
||||
Map<String, Object> parametersObj = new HashMap<>();
|
||||
functionObj.put("parameters", parametersObj);
|
||||
|
||||
parametersObj.put("type", "object");
|
||||
|
||||
Map<String, Object> propertiesObj = new HashMap<>();
|
||||
parametersObj.put("properties", propertiesObj);
|
||||
|
||||
for (Parameter parameter : function.getParameters()) {
|
||||
Map<String, Object> parameterObj = new HashMap<>();
|
||||
parameterObj.put("type", parameter.getType());
|
||||
parameterObj.put("description", parameter.getDescription());
|
||||
parameterObj.put("enum", parameter.getEnums());
|
||||
propertiesObj.put(parameter.getName(), parameterObj);
|
||||
}
|
||||
|
||||
functionsArray.add(functionRoot);
|
||||
}
|
||||
Map<String, Object> functionsJsonMap = new HashMap<>();
|
||||
functionsJsonMap.put("text", functionsArray);
|
||||
|
||||
// Map<String, Object> functionsJsonRoot = new HashMap<>();
|
||||
// functionsJsonRoot.put("functions", functionsJsonMap);
|
||||
|
||||
functionsJsonString = JSON.toJSONString(functionsJsonMap);
|
||||
}
|
||||
|
||||
String messageText = JSON.toJSONString(messageArray);
|
||||
return payload.replace("messageJsonString", messageText);
|
||||
return payload.replace("messageJsonString", messageText).replace("functionsJsonString", functionsJsonString);
|
||||
}
|
||||
|
||||
public static MessageStatus parseMessageStatus(Integer status) {
|
||||
|
@ -1,9 +1,11 @@
|
||||
package com.agentsflex.llm.spark.test;
|
||||
|
||||
import com.agentsflex.llm.Llm;
|
||||
import com.agentsflex.llm.response.FunctionMessageResponse;
|
||||
import com.agentsflex.llm.spark.SparkLlm;
|
||||
import com.agentsflex.llm.spark.SparkLlmConfig;
|
||||
import com.agentsflex.message.HumanMessage;
|
||||
import com.agentsflex.prompt.FunctionPrompt;
|
||||
import com.agentsflex.prompt.HistoriesPrompt;
|
||||
import org.junit.Test;
|
||||
|
||||
@ -24,6 +26,24 @@ public class SparkLlmTest {
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testFunctionCalling() throws InterruptedException {
|
||||
SparkLlmConfig config = new SparkLlmConfig();
|
||||
config.setAppId("****");
|
||||
config.setApiKey("****");
|
||||
config.setApiSecret("****");
|
||||
|
||||
Llm llm = new SparkLlm(config);
|
||||
|
||||
FunctionPrompt prompt = new FunctionPrompt("今天北京的天气怎么样", WeatherFunctions.class);
|
||||
FunctionMessageResponse response = llm.chat(prompt);
|
||||
|
||||
Object result = response.invoke();
|
||||
|
||||
System.out.println(result);
|
||||
}
|
||||
|
||||
|
||||
public static void main(String[] args) {
|
||||
SparkLlmConfig config = new SparkLlmConfig();
|
||||
config.setAppId("****");
|
||||
|
@ -0,0 +1,15 @@
|
||||
package com.agentsflex.llm.spark.test;
|
||||
|
||||
import com.agentsflex.functions.annotation.FunctionDef;
|
||||
import com.agentsflex.functions.annotation.FunctionParam;
|
||||
|
||||
public class WeatherFunctions {
|
||||
|
||||
@FunctionDef(name = "weather", description = "获取天气信息")
|
||||
public static String weather(
|
||||
@FunctionParam(name = "location", description = "城市名称,比如: 北京, 上海") String name
|
||||
) {
|
||||
//此处应该通过 api 去第三方获取
|
||||
return name + "今天的天气阴转多云";
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user