feat: add function calling for "spark" llm

Michael Yang 2024-03-08 17:04:38 +08:00
parent 4b4fb420c5
commit 4b6267f62d
7 changed files with 163 additions and 40 deletions
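In short: building a FunctionPrompt from a class annotated with @FunctionDef makes SparkLlm return a FunctionMessageResponse, whose invoke() is expected to dispatch to the matching local method with the model-chosen arguments. A minimal sketch of the flow exercised by the new testFunctionCalling in SparkLlmTest.java (llm is assumed to be a configured SparkLlm; see the full test below):

    // collect the @FunctionDef methods of WeatherFunctions and send them along with the prompt
    FunctionPrompt prompt = new FunctionPrompt("今天北京的天气怎么样", WeatherFunctions.class);
    FunctionMessageResponse response = llm.chat(prompt);
    // invoke() resolves the returned function_call against WeatherFunctions and runs it locally
    Object result = response.invoke();
    System.out.println(result);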

File: SystemMessage.java

@ -17,10 +17,8 @@ package com.agentsflex.message;
public class SystemMessage extends Message {
public SystemMessage(String content) {
this.content = content;
}
}

File: FunctionPrompt.java

@ -29,18 +29,27 @@ import java.util.List;
public class FunctionPrompt extends Prompt<FunctionMessage> {
private final ChatMemory memory = new DefaultChatMemory();
private final List<Function<?>> functions = new ArrayList<>();
public FunctionPrompt(String prompt, Class<?> funcClass) {
memory.addMessage(new HumanMessage(prompt));
functions.addAll(Functions.from(funcClass));
}
public FunctionPrompt(String prompt, Class<?> funcClass, String... methodNames) {
memory.addMessage(new HumanMessage(prompt));
functions.addAll(Functions.from(funcClass, methodNames));
}
public FunctionPrompt(List<Message> messages, Class<?> funcClass) {
memory.addMessages(messages);
functions.addAll(Functions.from(funcClass));
}
public FunctionPrompt(List<Message> messages, Class<?> funcClass, String... methodNames) {
memory.addMessages(messages);
functions.addAll(Functions.from(funcClass, methodNames));
}
@Override
public List<Message> toMessages() {

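The two added constructors accept existing message history and an optional list of method names, so callers can reuse a conversation and expose only some of the annotated functions. A minimal sketch, assuming the WeatherFunctions sample class added by this commit and restricting the prompt to its "weather" method (both names are illustrative):

    import com.agentsflex.message.HumanMessage;
    import com.agentsflex.message.Message;
    import com.agentsflex.prompt.FunctionPrompt;

    import java.util.ArrayList;
    import java.util.List;

    class FunctionPromptSketch {
        FunctionPrompt fromHistory() {
            // reuse prior conversation turns instead of a single prompt string
            List<Message> history = new ArrayList<>();
            history.add(new HumanMessage("今天北京的天气怎么样"));

            // the varargs overload limits which annotated methods are exposed
            return new FunctionPrompt(history, WeatherFunctions.class, "weather");
        }
    }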
File: OpenAiLLmUtil.java

@ -51,7 +51,7 @@ public class OpenAiLLmUtil {
// https://platform.openai.com/docs/api-reference/making-requests
String payload = "{\n" +
" \"input\": \""+text.getContent()+"\",\n" +
" \"input\": \"" + text.getContent() + "\",\n" +
" \"model\": \"text-embedding-ada-002\",\n" +
" \"encoding_format\": \"float\"\n" +
"}";
@ -60,7 +60,6 @@ public class OpenAiLLmUtil {
}
public static String promptToPayload(Prompt prompt, OpenAiLlmConfig config) {
List<Message> messages = prompt.toMessages();
@ -86,7 +85,7 @@ public class OpenAiLLmUtil {
return "{\n" +
// " \"model\": \"gpt-3.5-turbo\",\n" +
" \"model\": \"" + config.getModel() + "\",\n" +
" \"messages\": "+messageText+",\n" +
" \"messages\": " + messageText + ",\n" +
" \"temperature\": 0.7\n" +
"}";
}
@ -113,32 +112,32 @@ public class OpenAiLLmUtil {
String messageText = JSON.toJSONString(messageArray);
List<Map<String, Object>> toolsArray = new ArrayList<>();
for (Function<?> function : functions) {
Map<String, Object> functionRoot = new HashMap<>();
functionRoot.put("type", "function");
Map<String, Object> functionObj = new HashMap<>();
functionRoot.put("function", functionObj);
functionObj.put("name", function.getName());
functionObj.put("description", function.getDescription());
Map<String, Object> parametersObj = new HashMap<>();
functionObj.put("parameters", parametersObj);
parametersObj.put("type", "object");
Map<String, Object> propertiesObj = new HashMap<>();
parametersObj.put("properties", propertiesObj);
for (Parameter parameter : function.getParameters()) {
Map<String, Object> parameterObj = new HashMap<>();
parameterObj.put("type", parameter.getType());
parameterObj.put("description", parameter.getDescription());
parameterObj.put("enum", parameter.getEnums());
propertiesObj.put(parameter.getName(), parameterObj);
}
toolsArray.add(functionRoot);
@ -149,12 +148,11 @@ public class OpenAiLLmUtil {
return "{\n" +
// " \"model\": \"gpt-3.5-turbo\",\n" +
" \"model\": \"" + config.getModel() + "\",\n" +
" \"messages\": "+messageText+",\n" +
" \"tools\": "+toolsText+",\n" +
" \"messages\": " + messageText + ",\n" +
" \"tools\": " + toolsText + ",\n" +
" \"tool_choice\": \"auto\"\n" +
"}";
}
}

File: SparkLlm.java

@ -16,6 +16,7 @@
package com.agentsflex.llm.spark;
import com.agentsflex.document.Document;
import com.agentsflex.functions.Function;
import com.agentsflex.llm.BaseLlm;
import com.agentsflex.llm.ChatContext;
import com.agentsflex.llm.MessageListener;
@ -25,11 +26,18 @@ import com.agentsflex.llm.client.LlmClient;
import com.agentsflex.llm.client.LlmClientListener;
import com.agentsflex.llm.client.impl.WebSocketClient;
import com.agentsflex.llm.response.AiMessageResponse;
import com.agentsflex.llm.response.FunctionMessageResponse;
import com.agentsflex.message.AiMessage;
import com.agentsflex.message.FunctionMessage;
import com.agentsflex.message.Message;
import com.agentsflex.prompt.FunctionPrompt;
import com.agentsflex.prompt.Prompt;
import com.agentsflex.store.VectorData;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import com.alibaba.fastjson.JSONPath;
import java.util.List;
import java.util.concurrent.CountDownLatch;
public class SparkLlm extends BaseLlm<SparkLlmConfig> {
@ -47,11 +55,20 @@ public class SparkLlm extends BaseLlm<SparkLlmConfig> {
@Override
public <R extends MessageResponse<M>, M extends Message> R chat(Prompt<M> prompt) {
CountDownLatch latch = new CountDownLatch(1);
Message[] messages = new Message[1];
chatAsync(prompt, new MessageListener<MessageResponse<M>, M>() {
@Override
public void onMessage(ChatContext context, MessageResponse<M> response) {
if (response.getMessage() instanceof AiMessage) {
if (messages[0] == null) {
messages[0] = response.getMessage();
} else {
messages[0].setContent(((AiMessage) response.getMessage()).getFullContent());
}
} else if (response.getMessage() instanceof FunctionMessage) {
messages[0] = response.getMessage();
}
}
@Override
@ -65,7 +82,12 @@ public class SparkLlm extends BaseLlm<SparkLlmConfig> {
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
if (prompt instanceof FunctionPrompt) {
return (R) new FunctionMessageResponse(((FunctionPrompt) prompt).getFunctions(), (FunctionMessage) messages[0]);
} else {
return (R) new AiMessageResponse((AiMessage) messages[0]);
}
}
@ -76,7 +98,20 @@ public class SparkLlm extends BaseLlm<SparkLlmConfig> {
String payload = SparkLlmUtil.promptToPayload(prompt, config);
LlmClientListener clientListener = new BaseLlmClientListener(this, llmClient, listener, prompt, SparkLlmUtil::parseAiMessage, new BaseLlmClientListener.FunctionMessageParser() {
@Override
public FunctionMessage parseMessage(String response) {
JSONObject jsonObject = JSON.parseObject(response);
String callFunctionName = (String) JSONPath.eval(jsonObject, "$.payload.choices.text[0].function_call.name");
String callFunctionArgsString = (String) JSONPath.eval(jsonObject, "$.payload.choices.text[0].function_call.arguments");
JSONObject callFunctionArgs = JSON.parseObject(callFunctionArgsString);
FunctionMessage functionMessage = new FunctionMessage();
functionMessage.setFunctionName(callFunctionName);
functionMessage.setArgs(callFunctionArgs);
return functionMessage;
}
});
llmClient.start(url, null, payload, clientListener);
}
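The anonymous FunctionMessageParser above pulls the function call out of the Spark websocket frame with JSONPath, so it assumes a response fragment roughly of this shape (field values are illustrative; note that arguments arrives as a JSON string and is re-parsed):

    {
      "payload": {
        "choices": {
          "status": 2,
          "text": [{
            "index": 0,
            "content": "",
            "function_call": {
              "name": "weather",
              "arguments": "{\"location\": \"北京\"}"
            }
          }]
        }
      }
    }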

File: SparkLlmUtil.java

@ -15,10 +15,10 @@
*/
package com.agentsflex.llm.spark;
import com.agentsflex.functions.Function;
import com.agentsflex.functions.Parameter;
import com.agentsflex.message.*;
import com.agentsflex.prompt.FunctionPrompt;
import com.agentsflex.prompt.Prompt;
import com.agentsflex.util.HashUtil;
import com.alibaba.fastjson.JSON;
@ -32,14 +32,14 @@ import java.util.*;
public class SparkLlmUtil {
public static AiMessage parseAiMessage(String json) {
AiMessage aiMessage = new AiMessage();
JSONObject jsonObject = JSON.parseObject(json);
Object status = JSONPath.eval(jsonObject, "$.payload.choices.status");
MessageStatus messageStatus = SparkLlmUtil.parseMessageStatus((Integer) status);
aiMessage.setStatus(messageStatus);
aiMessage.setIndex((Integer) JSONPath.eval(jsonObject, "$.payload.choices.text[0].index"));
aiMessage.setContent((String) JSONPath.eval(jsonObject, "$.payload.choices.text[0].content"));
return aiMessage;
}
@ -64,7 +64,8 @@ public class SparkLlmUtil {
" \"payload\": {\n" +
" \"message\": {\n" +
" \"text\": messageJsonString" +
" }\n" +
" },\n" +
" \"functions\":functionsJsonString" +
" }\n" +
"}";
@ -78,13 +79,60 @@ public class SparkLlmUtil {
} else if (message instanceof AiMessage) {
map.put("role", "assistant");
map.put("content", ((AiMessage) message).getFullContent());
} else if (message instanceof SystemMessage) {
map.put("role", "system");
map.put("content", message.getContent());
}
messageArray.add(map);
});
String functionsJsonString = "\"\"";
if (prompt instanceof FunctionPrompt) {
List<Function<?>> functions = ((FunctionPrompt) prompt).getFunctions();
List<Map<String, Object>> functionsArray = new ArrayList<>();
for (Function<?> function : functions) {
Map<String, Object> functionRoot = new HashMap<>();
functionRoot.put("type", "function");
Map<String, Object> functionObj = new HashMap<>();
functionRoot.put("function", functionObj);
functionObj.put("name", function.getName());
functionObj.put("description", function.getDescription());
Map<String, Object> parametersObj = new HashMap<>();
functionObj.put("parameters", parametersObj);
parametersObj.put("type", "object");
Map<String, Object> propertiesObj = new HashMap<>();
parametersObj.put("properties", propertiesObj);
for (Parameter parameter : function.getParameters()) {
Map<String, Object> parameterObj = new HashMap<>();
parameterObj.put("type", parameter.getType());
parameterObj.put("description", parameter.getDescription());
parameterObj.put("enum", parameter.getEnums());
propertiesObj.put(parameter.getName(), parameterObj);
}
functionsArray.add(functionRoot);
}
Map<String, Object> functionsJsonMap = new HashMap<>();
functionsJsonMap.put("text", functionsArray);
// Map<String, Object> functionsJsonRoot = new HashMap<>();
// functionsJsonRoot.put("functions", functionsJsonMap);
functionsJsonString = JSON.toJSONString(functionsJsonMap);
}
String messageText = JSON.toJSONString(messageArray);
return payload.replace("messageJsonString", messageText).replace("functionsJsonString", functionsJsonString);
}
public static MessageStatus parseMessageStatus(Integer status) {

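For reference, after promptToPayload substitutes messageJsonString and functionsJsonString, the request body sent to Spark ends up roughly like this (header and parameter sections omitted; values are illustrative and taken from the WeatherFunctions sample below):

    {
      "payload": {
        "message": {
          "text": [{ "role": "user", "content": "今天北京的天气怎么样" }]
        },
        "functions": {
          "text": [{
            "type": "function",
            "function": {
              "name": "weather",
              "description": "获取天气信息",
              "parameters": {
                "type": "object",
                "properties": {
                  "location": { "type": "String", "description": "城市名称,比如: 北京, 上海" }
                }
              }
            }
          }]
        }
      }
    }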
File: SparkLlmTest.java

@ -1,9 +1,11 @@
package com.agentsflex.llm.spark.test;
import com.agentsflex.llm.Llm;
import com.agentsflex.llm.response.FunctionMessageResponse;
import com.agentsflex.llm.spark.SparkLlm;
import com.agentsflex.llm.spark.SparkLlmConfig;
import com.agentsflex.message.HumanMessage;
import com.agentsflex.prompt.FunctionPrompt;
import com.agentsflex.prompt.HistoriesPrompt;
import org.junit.Test;
@ -24,6 +26,24 @@ public class SparkLlmTest {
}
@Test
public void testFunctionCalling() throws InterruptedException {
SparkLlmConfig config = new SparkLlmConfig();
config.setAppId("****");
config.setApiKey("****");
config.setApiSecret("****");
Llm llm = new SparkLlm(config);
FunctionPrompt prompt = new FunctionPrompt("今天北京的天气怎么样", WeatherFunctions.class);
FunctionMessageResponse response = llm.chat(prompt);
Object result = response.invoke();
System.out.println(result);
}
public static void main(String[] args) {
SparkLlmConfig config = new SparkLlmConfig();
config.setAppId("****");

File: WeatherFunctions.java (new file)

@ -0,0 +1,15 @@
package com.agentsflex.llm.spark.test;
import com.agentsflex.functions.annotation.FunctionDef;
import com.agentsflex.functions.annotation.FunctionParam;
public class WeatherFunctions {
@FunctionDef(name = "weather", description = "获取天气信息")
public static String weather(
@FunctionParam(name = "location", description = "城市名称,比如: 北京, 上海") String name
) {
// in a real implementation this should fetch the weather from a third-party API
return name + "今天的天气阴转多云";
}
}