mirror of
https://gitee.com/agents-flex/agents-flex.git
synced 2024-11-29 18:38:17 +08:00
feat: make SparkLlm to support embedding
This commit is contained in:
parent
9a35621171
commit
6e6fbb01cd
@ -18,6 +18,7 @@ package com.agentsflex.llm.spark;
|
||||
import com.agentsflex.document.Document;
|
||||
import com.agentsflex.llm.*;
|
||||
import com.agentsflex.llm.client.BaseLlmClientListener;
|
||||
import com.agentsflex.llm.client.HttpClient;
|
||||
import com.agentsflex.llm.client.LlmClient;
|
||||
import com.agentsflex.llm.client.LlmClientListener;
|
||||
import com.agentsflex.llm.client.impl.WebSocketClient;
|
||||
@ -32,14 +33,24 @@ import com.agentsflex.parser.FunctionMessageParser;
|
||||
import com.agentsflex.prompt.FunctionPrompt;
|
||||
import com.agentsflex.prompt.Prompt;
|
||||
import com.agentsflex.store.VectorData;
|
||||
import com.agentsflex.util.StringUtil;
|
||||
import com.alibaba.fastjson.JSONPath;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.ByteOrder;
|
||||
import java.util.Base64;
|
||||
import java.util.concurrent.CountDownLatch;
|
||||
|
||||
public class SparkLlm extends BaseLlm<SparkLlmConfig> {
|
||||
|
||||
private static final Logger logger = LoggerFactory.getLogger(SparkLlm.class);
|
||||
public AiMessageParser aiMessageParser = SparkLlmUtil.getAiMessageParser();
|
||||
public FunctionMessageParser functionMessageParser = SparkLlmUtil.getFunctionMessageParser();
|
||||
|
||||
private final HttpClient httpClient = new HttpClient();
|
||||
|
||||
|
||||
public SparkLlm(SparkLlmConfig config) {
|
||||
super(config);
|
||||
@ -47,7 +58,33 @@ public class SparkLlm extends BaseLlm<SparkLlmConfig> {
|
||||
|
||||
@Override
|
||||
public VectorData embed(Document document, EmbeddingOptions options) {
|
||||
return null;
|
||||
String payload = SparkLlmUtil.embedPayload(config, document);
|
||||
String resp = httpClient.post(SparkLlmUtil.createEmbedURL(config), null, payload);
|
||||
if (StringUtil.noText(resp)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
Integer code = JSONPath.read(resp, "$.header.code", Integer.class);
|
||||
if (code == null || code != 0) {
|
||||
logger.error(resp);
|
||||
return null;
|
||||
}
|
||||
|
||||
String text = JSONPath.read(resp, "$.payload.feature.text", String.class);
|
||||
if (StringUtil.noText(text)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
byte[] buffer = Base64.getDecoder().decode(text);
|
||||
double[] vector = new double[buffer.length / 4];
|
||||
for (int i = 0; i < vector.length; i++) {
|
||||
int n = i * 4;
|
||||
vector[i] = ByteBuffer.wrap(buffer, n, 4).order(ByteOrder.LITTLE_ENDIAN).getFloat();
|
||||
}
|
||||
|
||||
VectorData vectorData = new VectorData();
|
||||
vectorData.setVector(vector);
|
||||
return vectorData;
|
||||
}
|
||||
|
||||
|
||||
|
@ -15,6 +15,7 @@
|
||||
*/
|
||||
package com.agentsflex.llm.spark;
|
||||
|
||||
import com.agentsflex.document.Document;
|
||||
import com.agentsflex.functions.Function;
|
||||
import com.agentsflex.functions.Parameter;
|
||||
import com.agentsflex.llm.ChatOptions;
|
||||
@ -150,4 +151,35 @@ public class SparkLlmUtil {
|
||||
return "general";
|
||||
}
|
||||
}
|
||||
|
||||
public static String embedPayload(SparkLlmConfig config, Document document) {
|
||||
String text = Maps.of("messages",Collections.singletonList(Maps.of("content",document.getContent()).put("role","user").build())).toJSON();
|
||||
String textBase64 = Base64.getEncoder().encodeToString(text.getBytes());
|
||||
|
||||
return Maps.of("header", Maps.of("app_id", config.getAppId()).put("uid", UUID.randomUUID()).put("status", 3))
|
||||
.put("parameter", Maps.of("emb", Maps.of("domain", "para").put("feature", Maps.of("encoding", "utf8").put("compress", "raw").put("format", "plain"))))
|
||||
.put("payload", Maps.of("messages", Maps.of("encoding", "utf8").put("compress", "raw").put("format", "json").put("status", 3).put("text", textBase64)))
|
||||
.toJSON();
|
||||
}
|
||||
|
||||
|
||||
/// http://emb-cn-huabei-1.xf-yun.com/
|
||||
|
||||
public static String createEmbedURL(SparkLlmConfig config) {
|
||||
SimpleDateFormat sdf = new SimpleDateFormat("EEE, dd MMM yyyy HH:mm:ss '+0000'", Locale.US);
|
||||
sdf.setTimeZone(java.util.TimeZone.getTimeZone("UTC"));
|
||||
String date = sdf.format(new Date());
|
||||
|
||||
String header = "host: emb-cn-huabei-1.xf-yun.com\n";
|
||||
header += "date: " + date + "\n";
|
||||
header += "POST / HTTP/1.1";
|
||||
|
||||
String base64 = HashUtil.hmacSHA256ToBase64(header, config.getApiSecret());
|
||||
String authorization_origin = "api_key=\"" + config.getApiKey()
|
||||
+ "\", algorithm=\"hmac-sha256\", headers=\"host date request-line\", signature=\"" + base64 + "\"";
|
||||
|
||||
String authorization = Base64.getEncoder().encodeToString(authorization_origin.getBytes());
|
||||
return "http://emb-cn-huabei-1.xf-yun.com/?authorization=" + authorization
|
||||
+ "&date=" + urlEncode(date) + "&host=emb-cn-huabei-1.xf-yun.com";
|
||||
}
|
||||
}
|
||||
|
@ -1,5 +1,6 @@
|
||||
package com.agentsflex.llm.spark.test;
|
||||
|
||||
import com.agentsflex.document.Document;
|
||||
import com.agentsflex.llm.Llm;
|
||||
import com.agentsflex.llm.response.FunctionMessageResponse;
|
||||
import com.agentsflex.llm.spark.SparkLlm;
|
||||
@ -7,6 +8,7 @@ import com.agentsflex.llm.spark.SparkLlmConfig;
|
||||
import com.agentsflex.message.HumanMessage;
|
||||
import com.agentsflex.prompt.FunctionPrompt;
|
||||
import com.agentsflex.prompt.HistoriesPrompt;
|
||||
import com.agentsflex.store.VectorData;
|
||||
import org.junit.Test;
|
||||
|
||||
import java.util.Scanner;
|
||||
@ -26,6 +28,19 @@ public class SparkLlmTest {
|
||||
System.out.println(result);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testEmbedding() {
|
||||
SparkLlmConfig config = new SparkLlmConfig();
|
||||
config.setAppId("****");
|
||||
config.setApiKey("****");
|
||||
config.setApiSecret("****");
|
||||
config.setVersion("v3.5");
|
||||
|
||||
Llm llm = new SparkLlm(config);
|
||||
VectorData vectorData = llm.embed(Document.of("你好,请问你是谁?"));
|
||||
System.out.println(vectorData);
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void testFunctionCalling() throws InterruptedException {
|
||||
|
Loading…
Reference in New Issue
Block a user