feat: make SparkLlm to support embedding

This commit is contained in:
Michael Yang 2024-05-30 19:19:48 +08:00
parent 9a35621171
commit 6e6fbb01cd
3 changed files with 85 additions and 1 deletions

View File

@ -18,6 +18,7 @@ package com.agentsflex.llm.spark;
import com.agentsflex.document.Document;
import com.agentsflex.llm.*;
import com.agentsflex.llm.client.BaseLlmClientListener;
import com.agentsflex.llm.client.HttpClient;
import com.agentsflex.llm.client.LlmClient;
import com.agentsflex.llm.client.LlmClientListener;
import com.agentsflex.llm.client.impl.WebSocketClient;
@ -32,14 +33,24 @@ import com.agentsflex.parser.FunctionMessageParser;
import com.agentsflex.prompt.FunctionPrompt;
import com.agentsflex.prompt.Prompt;
import com.agentsflex.store.VectorData;
import com.agentsflex.util.StringUtil;
import com.alibaba.fastjson.JSONPath;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Base64;
import java.util.concurrent.CountDownLatch;
public class SparkLlm extends BaseLlm<SparkLlmConfig> {
private static final Logger logger = LoggerFactory.getLogger(SparkLlm.class);
public AiMessageParser aiMessageParser = SparkLlmUtil.getAiMessageParser();
public FunctionMessageParser functionMessageParser = SparkLlmUtil.getFunctionMessageParser();
private final HttpClient httpClient = new HttpClient();
public SparkLlm(SparkLlmConfig config) {
super(config);
@ -47,7 +58,33 @@ public class SparkLlm extends BaseLlm<SparkLlmConfig> {
@Override
public VectorData embed(Document document, EmbeddingOptions options) {
return null;
String payload = SparkLlmUtil.embedPayload(config, document);
String resp = httpClient.post(SparkLlmUtil.createEmbedURL(config), null, payload);
if (StringUtil.noText(resp)) {
return null;
}
Integer code = JSONPath.read(resp, "$.header.code", Integer.class);
if (code == null || code != 0) {
logger.error(resp);
return null;
}
String text = JSONPath.read(resp, "$.payload.feature.text", String.class);
if (StringUtil.noText(text)) {
return null;
}
byte[] buffer = Base64.getDecoder().decode(text);
double[] vector = new double[buffer.length / 4];
for (int i = 0; i < vector.length; i++) {
int n = i * 4;
vector[i] = ByteBuffer.wrap(buffer, n, 4).order(ByteOrder.LITTLE_ENDIAN).getFloat();
}
VectorData vectorData = new VectorData();
vectorData.setVector(vector);
return vectorData;
}

View File

@ -15,6 +15,7 @@
*/
package com.agentsflex.llm.spark;
import com.agentsflex.document.Document;
import com.agentsflex.functions.Function;
import com.agentsflex.functions.Parameter;
import com.agentsflex.llm.ChatOptions;
@ -150,4 +151,35 @@ public class SparkLlmUtil {
return "general";
}
}
public static String embedPayload(SparkLlmConfig config, Document document) {
String text = Maps.of("messages",Collections.singletonList(Maps.of("content",document.getContent()).put("role","user").build())).toJSON();
String textBase64 = Base64.getEncoder().encodeToString(text.getBytes());
return Maps.of("header", Maps.of("app_id", config.getAppId()).put("uid", UUID.randomUUID()).put("status", 3))
.put("parameter", Maps.of("emb", Maps.of("domain", "para").put("feature", Maps.of("encoding", "utf8").put("compress", "raw").put("format", "plain"))))
.put("payload", Maps.of("messages", Maps.of("encoding", "utf8").put("compress", "raw").put("format", "json").put("status", 3).put("text", textBase64)))
.toJSON();
}
/// http://emb-cn-huabei-1.xf-yun.com/
public static String createEmbedURL(SparkLlmConfig config) {
SimpleDateFormat sdf = new SimpleDateFormat("EEE, dd MMM yyyy HH:mm:ss '+0000'", Locale.US);
sdf.setTimeZone(java.util.TimeZone.getTimeZone("UTC"));
String date = sdf.format(new Date());
String header = "host: emb-cn-huabei-1.xf-yun.com\n";
header += "date: " + date + "\n";
header += "POST / HTTP/1.1";
String base64 = HashUtil.hmacSHA256ToBase64(header, config.getApiSecret());
String authorization_origin = "api_key=\"" + config.getApiKey()
+ "\", algorithm=\"hmac-sha256\", headers=\"host date request-line\", signature=\"" + base64 + "\"";
String authorization = Base64.getEncoder().encodeToString(authorization_origin.getBytes());
return "http://emb-cn-huabei-1.xf-yun.com/?authorization=" + authorization
+ "&date=" + urlEncode(date) + "&host=emb-cn-huabei-1.xf-yun.com";
}
}

View File

@ -1,5 +1,6 @@
package com.agentsflex.llm.spark.test;
import com.agentsflex.document.Document;
import com.agentsflex.llm.Llm;
import com.agentsflex.llm.response.FunctionMessageResponse;
import com.agentsflex.llm.spark.SparkLlm;
@ -7,6 +8,7 @@ import com.agentsflex.llm.spark.SparkLlmConfig;
import com.agentsflex.message.HumanMessage;
import com.agentsflex.prompt.FunctionPrompt;
import com.agentsflex.prompt.HistoriesPrompt;
import com.agentsflex.store.VectorData;
import org.junit.Test;
import java.util.Scanner;
@ -26,6 +28,19 @@ public class SparkLlmTest {
System.out.println(result);
}
@Test
public void testEmbedding() {
SparkLlmConfig config = new SparkLlmConfig();
config.setAppId("****");
config.setApiKey("****");
config.setApiSecret("****");
config.setVersion("v3.5");
Llm llm = new SparkLlm(config);
VectorData vectorData = llm.embed(Document.of("你好,请问你是谁?"));
System.out.println(vectorData);
}
@Test
public void testFunctionCalling() throws InterruptedException {