mirror of
https://gitee.com/agents-flex/agents-flex.git
synced 2024-11-30 02:47:46 +08:00
refactor: optimize llms
This commit is contained in:
parent
f7ab096796
commit
39dc6e8f5b
@ -15,26 +15,21 @@
|
||||
<maven.compiler.source>8</maven.compiler.source>
|
||||
<maven.compiler.target>8</maven.compiler.target>
|
||||
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
||||
<okhttp.version>4.9.3</okhttp.version>
|
||||
<fastjson.version>2.0.45</fastjson.version>
|
||||
</properties>
|
||||
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>com.squareup.okhttp3</groupId>
|
||||
<artifactId>okhttp</artifactId>
|
||||
<version>${okhttp.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.squareup.okhttp3</groupId>
|
||||
<artifactId>okhttp-sse</artifactId>
|
||||
<version>${okhttp.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.alibaba</groupId>
|
||||
<artifactId>fastjson</artifactId>
|
||||
<version>${fastjson.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
|
@ -40,7 +40,7 @@ public class SimpleSplitter implements Splitter {
|
||||
List<Document> texts = new ArrayList<>(textArray.length);
|
||||
for (String textString : textArray) {
|
||||
Document newText = new Document();
|
||||
newText.setMetadataMap(text.getMetadataMap());
|
||||
newText.setMetadatas(text.getMetadatas());
|
||||
newText.setContent(textString);
|
||||
texts.add(newText);
|
||||
}
|
||||
|
@ -24,7 +24,7 @@ import com.agentsflex.prompt.SimplePrompt;
|
||||
public interface Llm extends Embeddings {
|
||||
|
||||
default String chat(String prompt) {
|
||||
MessageResponse<?> chat = chat(new SimplePrompt(prompt));
|
||||
MessageResponse<AiMessage> chat = chat(new SimplePrompt(prompt));
|
||||
return chat != null ? chat.getMessage().getContent() : null;
|
||||
}
|
||||
|
||||
|
@ -17,13 +17,15 @@ package com.agentsflex.llm.client;
|
||||
|
||||
import com.agentsflex.functions.Function;
|
||||
import com.agentsflex.llm.ChatContext;
|
||||
import com.agentsflex.llm.Llm;
|
||||
import com.agentsflex.llm.MessageListener;
|
||||
import com.agentsflex.llm.MessageResponse;
|
||||
import com.agentsflex.llm.Llm;
|
||||
import com.agentsflex.llm.response.AiMessageResponse;
|
||||
import com.agentsflex.llm.response.FunctionMessageResponse;
|
||||
import com.agentsflex.message.AiMessage;
|
||||
import com.agentsflex.message.FunctionMessage;
|
||||
import com.agentsflex.parser.AiMessageParser;
|
||||
import com.agentsflex.parser.FunctionMessageParser;
|
||||
import com.agentsflex.prompt.FunctionPrompt;
|
||||
import com.agentsflex.prompt.HistoriesPrompt;
|
||||
import com.agentsflex.prompt.Prompt;
|
||||
@ -72,12 +74,12 @@ public class BaseLlmClientListener implements LlmClientListener {
|
||||
@Override
|
||||
public void onMessage(LlmClient client, String response) {
|
||||
if (isFunctionCalling) {
|
||||
FunctionMessage functionInfo = functionInfoParser.parseMessage(response);
|
||||
FunctionMessage functionInfo = functionInfoParser.parse(response);
|
||||
List<Function<?>> functions = ((FunctionPrompt) prompt).getFunctions();
|
||||
MessageResponse<?> r = new FunctionMessageResponse(functions, functionInfo);
|
||||
messageListener.onMessage(context, r);
|
||||
} else {
|
||||
lastAiMessage = messageParser.parseMessage(response);
|
||||
lastAiMessage = messageParser.parse(response);
|
||||
fullMessage.append(lastAiMessage.getContent());
|
||||
lastAiMessage.setFullContent(fullMessage.toString());
|
||||
MessageResponse<?> r = new AiMessageResponse(lastAiMessage);
|
||||
@ -102,11 +104,11 @@ public class BaseLlmClientListener implements LlmClientListener {
|
||||
}
|
||||
|
||||
|
||||
public interface AiMessageParser {
|
||||
AiMessage parseMessage(String response);
|
||||
}
|
||||
|
||||
public interface FunctionMessageParser {
|
||||
FunctionMessage parseMessage(String response);
|
||||
}
|
||||
// public interface AiMessageParser {
|
||||
// AiMessage parseMessage(String response);
|
||||
// }
|
||||
//
|
||||
// public interface FunctionMessageParser {
|
||||
// FunctionMessage parseMessage(String response);
|
||||
// }
|
||||
}
|
||||
|
@ -67,12 +67,7 @@ public class HttpClient {
|
||||
try {
|
||||
Response response = okHttpClient.newCall(request).execute();
|
||||
return response.body().string();
|
||||
// if (response.isSuccessful()) {
|
||||
// return response.message();
|
||||
// } else {
|
||||
// return response.body().string();
|
||||
// }
|
||||
} catch (IOException e) {
|
||||
} catch (Exception e) {
|
||||
LOG.error(e.toString(), e);
|
||||
}
|
||||
return null;
|
||||
|
@ -18,5 +18,4 @@ package com.agentsflex.parser;
|
||||
import com.agentsflex.message.AiMessage;
|
||||
|
||||
public interface AiMessageParser extends TextParser<AiMessage> {
|
||||
|
||||
}
|
||||
|
@ -18,5 +18,4 @@ package com.agentsflex.parser;
|
||||
import com.agentsflex.message.FunctionMessage;
|
||||
|
||||
public interface FunctionMessageParser extends TextParser<FunctionMessage> {
|
||||
|
||||
}
|
||||
|
@ -0,0 +1,90 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2023, Agents-Flex (fuhai999@gmail.com).
|
||||
* <p>
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
* <p>
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* <p>
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package com.agentsflex.parser.impl;
|
||||
|
||||
import com.agentsflex.message.AiMessage;
|
||||
import com.agentsflex.message.MessageStatus;
|
||||
import com.agentsflex.parser.AiMessageParser;
|
||||
import com.agentsflex.parser.Parser;
|
||||
import com.alibaba.fastjson.JSON;
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.alibaba.fastjson.JSONPath;
|
||||
|
||||
|
||||
public class BaseAiMessageParser implements AiMessageParser {
|
||||
|
||||
private String contentPath;
|
||||
private String indexPath;
|
||||
private String statusPath;
|
||||
private String totalTokensPath;
|
||||
private Parser<Object, MessageStatus> statusParser;
|
||||
|
||||
public String getContentPath() {
|
||||
return contentPath;
|
||||
}
|
||||
|
||||
public void setContentPath(String contentPath) {
|
||||
this.contentPath = contentPath;
|
||||
}
|
||||
|
||||
public String getIndexPath() {
|
||||
return indexPath;
|
||||
}
|
||||
|
||||
public void setIndexPath(String indexPath) {
|
||||
this.indexPath = indexPath;
|
||||
}
|
||||
|
||||
public String getStatusPath() {
|
||||
return statusPath;
|
||||
}
|
||||
|
||||
public void setStatusPath(String statusPath) {
|
||||
this.statusPath = statusPath;
|
||||
}
|
||||
|
||||
public String getTotalTokensPath() {
|
||||
return totalTokensPath;
|
||||
}
|
||||
|
||||
public void setTotalTokensPath(String totalTokensPath) {
|
||||
this.totalTokensPath = totalTokensPath;
|
||||
}
|
||||
|
||||
public Parser<Object, MessageStatus> getStatusParser() {
|
||||
return statusParser;
|
||||
}
|
||||
|
||||
public void setStatusParser(Parser<Object, MessageStatus> statusParser) {
|
||||
this.statusParser = statusParser;
|
||||
}
|
||||
|
||||
@Override
|
||||
public AiMessage parse(String content) {
|
||||
AiMessage aiMessage = new AiMessage();
|
||||
JSONObject rootJson = JSON.parseObject(content);
|
||||
aiMessage.setContent((String) JSONPath.eval(rootJson, this.contentPath));
|
||||
aiMessage.setIndex((Integer) JSONPath.eval(rootJson, this.indexPath));
|
||||
aiMessage.setTotalTokens((Integer) JSONPath.eval(rootJson, this.totalTokensPath));
|
||||
|
||||
String statusString = (String) JSONPath.eval(rootJson, this.statusPath);
|
||||
if (this.statusParser != null) {
|
||||
aiMessage.setStatus(this.statusParser.parse(statusString));
|
||||
}
|
||||
|
||||
return aiMessage;
|
||||
}
|
||||
}
|
@ -0,0 +1,69 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2023, Agents-Flex (fuhai999@gmail.com).
|
||||
* <p>
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
* <p>
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* <p>
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package com.agentsflex.parser.impl;
|
||||
|
||||
import com.agentsflex.message.FunctionMessage;
|
||||
import com.agentsflex.parser.FunctionMessageParser;
|
||||
import com.agentsflex.parser.Parser;
|
||||
import com.alibaba.fastjson.JSON;
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.alibaba.fastjson.JSONPath;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
public class BaseFunctionMessageParser implements FunctionMessageParser {
|
||||
|
||||
private String functionNamePath;
|
||||
private String functionArgsPath;
|
||||
private Parser<String, Map<String, Object>> functionArgsParser;
|
||||
|
||||
public String getFunctionNamePath() {
|
||||
return functionNamePath;
|
||||
}
|
||||
|
||||
public void setFunctionNamePath(String functionNamePath) {
|
||||
this.functionNamePath = functionNamePath;
|
||||
}
|
||||
|
||||
public String getFunctionArgsPath() {
|
||||
return functionArgsPath;
|
||||
}
|
||||
|
||||
public void setFunctionArgsPath(String functionArgsPath) {
|
||||
this.functionArgsPath = functionArgsPath;
|
||||
}
|
||||
|
||||
public Parser<String, Map<String, Object>> getFunctionArgsParser() {
|
||||
return functionArgsParser;
|
||||
}
|
||||
|
||||
public void setFunctionArgsParser(Parser<String, Map<String, Object>> functionArgsParser) {
|
||||
this.functionArgsParser = functionArgsParser;
|
||||
}
|
||||
|
||||
@Override
|
||||
public FunctionMessage parse(String content) {
|
||||
FunctionMessage functionMessage = new FunctionMessage();
|
||||
JSONObject jsonObject = JSON.parseObject(content);
|
||||
String functionName = (String) JSONPath.eval(jsonObject, this.functionNamePath);
|
||||
functionMessage.setFunctionName(functionName);
|
||||
String functionArgsString = (String) JSONPath.eval(jsonObject, this.functionArgsPath);
|
||||
if (functionArgsString != null) {
|
||||
functionMessage.setArgs(this.functionArgsParser.parse(functionArgsString));
|
||||
}
|
||||
return functionMessage;
|
||||
}
|
||||
}
|
@ -0,0 +1,107 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2023, Agents-Flex (fuhai999@gmail.com).
|
||||
* <p>
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
* <p>
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* <p>
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package com.agentsflex.prompt;
|
||||
|
||||
import com.agentsflex.functions.Function;
|
||||
import com.agentsflex.functions.Parameter;
|
||||
import com.agentsflex.message.AiMessage;
|
||||
import com.agentsflex.message.HumanMessage;
|
||||
import com.agentsflex.message.Message;
|
||||
import com.agentsflex.message.SystemMessage;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
public class DefaultPromptFormat implements PromptFormat {
|
||||
|
||||
@Override
|
||||
public Object toMessagesJsonKey(Prompt<?> prompt) {
|
||||
if (prompt == null) {
|
||||
return null;
|
||||
}
|
||||
|
||||
List<Message> messages = prompt.toMessages();
|
||||
if (messages == null || messages.isEmpty()) {
|
||||
return null;
|
||||
}
|
||||
|
||||
List<Map<String, String>> messageArray = new ArrayList<>(messages.size());
|
||||
messages.forEach(message -> {
|
||||
Map<String, String> map = new HashMap<>(2);
|
||||
if (message instanceof HumanMessage) {
|
||||
map.put("role", "user");
|
||||
map.put("content", ((HumanMessage) message).getContent());
|
||||
} else if (message instanceof AiMessage) {
|
||||
map.put("role", "assistant");
|
||||
map.put("content", ((AiMessage) message).getFullContent());
|
||||
} else if (message instanceof SystemMessage) {
|
||||
map.put("role", "system");
|
||||
map.put("content", ((SystemMessage) message).getContent());
|
||||
}
|
||||
messageArray.add(map);
|
||||
});
|
||||
|
||||
return messageArray;
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public Object toFunctionsJsonKey(Prompt<?> prompt) {
|
||||
if (!(prompt instanceof FunctionPrompt)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
List<Function<?>> functions = ((FunctionPrompt)prompt).getFunctions();
|
||||
if (functions == null || functions.isEmpty()) {
|
||||
return null;
|
||||
}
|
||||
|
||||
List<Map<String, Object>> functionsJsonArray = new ArrayList<>();
|
||||
for (Function<?> function : functions) {
|
||||
Map<String, Object> functionRoot = new HashMap<>();
|
||||
functionRoot.put("type", "function");
|
||||
|
||||
Map<String, Object> functionObj = new HashMap<>();
|
||||
functionRoot.put("function", functionObj);
|
||||
|
||||
functionObj.put("name", function.getName());
|
||||
functionObj.put("description", function.getDescription());
|
||||
|
||||
|
||||
Map<String, Object> parametersObj = new HashMap<>();
|
||||
functionObj.put("parameters", parametersObj);
|
||||
|
||||
parametersObj.put("type", "object");
|
||||
|
||||
Map<String, Object> propertiesObj = new HashMap<>();
|
||||
parametersObj.put("properties", propertiesObj);
|
||||
|
||||
for (Parameter parameter : function.getParameters()) {
|
||||
Map<String, Object> parameterObj = new HashMap<>();
|
||||
parameterObj.put("type", parameter.getType());
|
||||
parameterObj.put("description", parameter.getDescription());
|
||||
parameterObj.put("enum", parameter.getEnums());
|
||||
propertiesObj.put(parameter.getName(), parameterObj);
|
||||
}
|
||||
functionsJsonArray.add(functionRoot);
|
||||
}
|
||||
|
||||
return functionsJsonArray;
|
||||
}
|
||||
|
||||
}
|
@ -0,0 +1,23 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2023, Agents-Flex (fuhai999@gmail.com).
|
||||
* <p>
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
* <p>
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* <p>
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package com.agentsflex.prompt;
|
||||
|
||||
public interface PromptFormat {
|
||||
|
||||
Object toMessagesJsonKey(Prompt<?> prompt);
|
||||
|
||||
Object toFunctionsJsonKey(Prompt<?> prompt);
|
||||
}
|
@ -26,7 +26,7 @@ public class VectorDocument extends VectorData{
|
||||
|
||||
public VectorDocument(VectorData vectorData) {
|
||||
this.setVector(vectorData.getVector());
|
||||
this.setMetadataMap(vectorData.getMetadataMap());
|
||||
this.setMetadatas(vectorData.getMetadatas());
|
||||
}
|
||||
|
||||
public String getId() {
|
||||
|
@ -1,35 +0,0 @@
|
||||
package com.agentsflex.util;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
public class MapUtil {
|
||||
|
||||
public static Map<String, Object> ofSingleKey(String key, Object value) {
|
||||
Map<String, Object> map = new HashMap<>();
|
||||
map.put(key, value);
|
||||
return map;
|
||||
}
|
||||
|
||||
public static MapBuilder of() {
|
||||
return new MapBuilder();
|
||||
}
|
||||
|
||||
public static MapBuilder of(String key, Object value) {
|
||||
return new MapBuilder().put(key, value);
|
||||
}
|
||||
|
||||
public static class MapBuilder {
|
||||
private Map<String, Object> map = new HashMap<>();
|
||||
|
||||
public MapBuilder put(String key, Object value) {
|
||||
map.put(key, value);
|
||||
return this;
|
||||
}
|
||||
|
||||
public Map<String, Object> build() {
|
||||
return map;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
147
agents-flex-core/src/main/java/com/agentsflex/util/Maps.java
Normal file
147
agents-flex-core/src/main/java/com/agentsflex/util/Maps.java
Normal file
@ -0,0 +1,147 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2023, Agents-Flex (fuhai999@gmail.com).
|
||||
* <p>
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
* <p>
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* <p>
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package com.agentsflex.util;
|
||||
|
||||
import java.lang.reflect.Array;
|
||||
import java.util.Collection;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
public class Maps {
|
||||
|
||||
public static Builder of() {
|
||||
return new Builder();
|
||||
}
|
||||
|
||||
public static Builder of(String key, Object value) {
|
||||
return new Builder().put(key, value);
|
||||
}
|
||||
|
||||
public static Builder ofNotNull(String key, Object value) {
|
||||
return new Builder().putIfNotNull(key, value);
|
||||
}
|
||||
|
||||
public static Builder ofNotEmpty(String key, Object value) {
|
||||
return new Builder().putIfNotEmpty(key, value);
|
||||
}
|
||||
|
||||
public static Builder ofNotEmpty(String key, Builder value) {
|
||||
return new Builder().putIfNotEmpty(key, value);
|
||||
}
|
||||
|
||||
public static class Builder {
|
||||
private Map<String, Object> map = new HashMap<>();
|
||||
|
||||
public Builder put(String key, Object value) {
|
||||
map.put(key, value);
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder put(String key, Builder value) {
|
||||
map.put(key, value.build());
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder putIf(boolean condition, String key, Builder value) {
|
||||
if (condition) put(key, value);
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder putIf(boolean condition, String key, Object value) {
|
||||
if (condition) put(key, value);
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder putIfNotNull(String key, Object value) {
|
||||
if (value != null) put(key, value);
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder putIfNotEmpty(String key, Builder value) {
|
||||
Map<String, Object> map = value.build();
|
||||
if (map != null && !map.isEmpty()) put(key, value);
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder putIfNotEmpty(String key, Object value) {
|
||||
if (value == null) {
|
||||
return this;
|
||||
}
|
||||
|
||||
if (value instanceof Collection && ((Collection<?>) value).isEmpty()) {
|
||||
return this;
|
||||
}
|
||||
|
||||
if (value instanceof Map && ((Map<?, ?>) value).isEmpty()) {
|
||||
return this;
|
||||
}
|
||||
|
||||
if (value.getClass().isArray() && Array.getLength(value) == 0) {
|
||||
return this;
|
||||
}
|
||||
|
||||
if (value instanceof String && ((String) value).isEmpty()) {
|
||||
return this;
|
||||
}
|
||||
|
||||
put(key, value);
|
||||
return this;
|
||||
}
|
||||
|
||||
|
||||
public Builder putIfContainsKey(String checkKey, String key, Object value) {
|
||||
if (map.containsKey(checkKey)) {
|
||||
this.put(key, value);
|
||||
}
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder putIfContainsKey(String checkKey, String key, Builder value) {
|
||||
if (map.containsKey(checkKey)) {
|
||||
this.put(key, value);
|
||||
}
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder putIfNotContainsKey(String checkKey, String key, Object value) {
|
||||
if (!map.containsKey(checkKey)) {
|
||||
this.put(key, value);
|
||||
}
|
||||
return this;
|
||||
}
|
||||
|
||||
public Builder putIfNotContainsKey(String checkKey, String key, Builder value) {
|
||||
if (!map.containsKey(checkKey)) {
|
||||
this.put(key, value);
|
||||
}
|
||||
return this;
|
||||
}
|
||||
|
||||
public Map<String, Object> build() {
|
||||
return map;
|
||||
}
|
||||
|
||||
public Object get(String key) {
|
||||
return map.get(key);
|
||||
}
|
||||
|
||||
public Map getAsMap(String key) {
|
||||
return (Map) map.get(key);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
@ -21,34 +21,41 @@ import java.util.Map;
|
||||
|
||||
public class Metadata implements Serializable {
|
||||
|
||||
protected Map<String, Object> metadataMap;
|
||||
protected Map<String, Object> metadatas;
|
||||
|
||||
public Object getMetadata(String key) {
|
||||
return metadataMap != null ? metadataMap.get(key) : null;
|
||||
return metadatas != null ? metadatas.get(key) : null;
|
||||
}
|
||||
|
||||
public void addMetadata(String key, Object value) {
|
||||
if (metadataMap == null) {
|
||||
metadataMap = new HashMap<>();
|
||||
if (metadatas == null) {
|
||||
metadatas = new HashMap<>();
|
||||
}
|
||||
metadataMap.put(key, value);
|
||||
metadatas.put(key, value);
|
||||
}
|
||||
|
||||
public void addMetadata(Map<String, Object> metadata) {
|
||||
if (metadata == null || metadata.isEmpty()) {
|
||||
return;
|
||||
}
|
||||
if (metadataMap == null) {
|
||||
metadataMap = new HashMap<>();
|
||||
if (metadatas == null) {
|
||||
metadatas = new HashMap<>();
|
||||
}
|
||||
metadataMap.putAll(metadata);
|
||||
metadatas.putAll(metadata);
|
||||
}
|
||||
|
||||
public Map<String, Object> getMetadataMap() {
|
||||
return metadataMap;
|
||||
public Object removeMetadata(String key) {
|
||||
if (this.metadatas == null) {
|
||||
return null;
|
||||
}
|
||||
return this.metadatas.remove(key);
|
||||
}
|
||||
|
||||
public void setMetadataMap(Map<String, Object> metadataMap) {
|
||||
this.metadataMap = metadataMap;
|
||||
public Map<String, Object> getMetadatas() {
|
||||
return metadatas;
|
||||
}
|
||||
|
||||
public void setMetadatas(Map<String, Object> metadatas) {
|
||||
this.metadatas = metadatas;
|
||||
}
|
||||
}
|
||||
|
@ -15,40 +15,47 @@
|
||||
*/
|
||||
package com.agentsflex.llm.openai;
|
||||
|
||||
import com.agentsflex.functions.Function;
|
||||
import com.agentsflex.functions.Parameter;
|
||||
import com.agentsflex.message.AiMessage;
|
||||
import com.agentsflex.message.HumanMessage;
|
||||
import com.agentsflex.message.Message;
|
||||
import com.agentsflex.message.MessageStatus;
|
||||
import com.agentsflex.prompt.Prompt;
|
||||
import com.agentsflex.document.Document;
|
||||
import com.agentsflex.message.MessageStatus;
|
||||
import com.agentsflex.parser.AiMessageParser;
|
||||
import com.agentsflex.parser.FunctionMessageParser;
|
||||
import com.agentsflex.parser.impl.BaseAiMessageParser;
|
||||
import com.agentsflex.parser.impl.BaseFunctionMessageParser;
|
||||
import com.agentsflex.prompt.DefaultPromptFormat;
|
||||
import com.agentsflex.prompt.Prompt;
|
||||
import com.agentsflex.prompt.PromptFormat;
|
||||
import com.agentsflex.util.Maps;
|
||||
import com.alibaba.fastjson.JSON;
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.alibaba.fastjson.JSONPath;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
public class OpenAiLLmUtil {
|
||||
|
||||
public static AiMessage parseAiMessage(String json) {
|
||||
AiMessage aiMessage = new AiMessage();
|
||||
JSONObject jsonObject = JSON.parseObject(json);
|
||||
Object status = JSONPath.eval(jsonObject, "$.choices[0].finish_reason");
|
||||
MessageStatus messageStatus = parseMessageStatus((String) status);
|
||||
aiMessage.setStatus(messageStatus);
|
||||
aiMessage.setIndex((Integer) JSONPath.eval(jsonObject, "$.choices[0].index"));
|
||||
aiMessage.setContent((String) JSONPath.eval(jsonObject, "$.choices[0].delta.content"));
|
||||
return aiMessage;
|
||||
private static final PromptFormat promptFormat = new DefaultPromptFormat();
|
||||
|
||||
public static AiMessageParser getAiMessageParser() {
|
||||
BaseAiMessageParser aiMessageParser = new BaseAiMessageParser();
|
||||
aiMessageParser.setContentPath("$.choices[0].delta.content");
|
||||
aiMessageParser.setIndexPath("$.choices[0].index");
|
||||
aiMessageParser.setStatusPath("$.choices[0].finish_reason");
|
||||
aiMessageParser.setStatusParser(content -> parseMessageStatus((String) content));
|
||||
return aiMessageParser;
|
||||
}
|
||||
|
||||
|
||||
public static FunctionMessageParser getFunctionMessageParser() {
|
||||
BaseFunctionMessageParser functionMessageParser = new BaseFunctionMessageParser();
|
||||
functionMessageParser.setFunctionNamePath("$.choices[0].message.tool_calls[0].function.name");
|
||||
functionMessageParser.setFunctionArgsPath("$.choices[0].message.tool_calls[0].function.arguments");
|
||||
functionMessageParser.setFunctionArgsParser(JSON::parseObject);
|
||||
return functionMessageParser;
|
||||
}
|
||||
|
||||
|
||||
public static MessageStatus parseMessageStatus(String status) {
|
||||
return "stop".equals(status) ? MessageStatus.END : MessageStatus.MIDDLE;
|
||||
}
|
||||
|
||||
|
||||
public static String promptToEmbeddingsPayload(Document text) {
|
||||
|
||||
// https://platform.openai.com/docs/api-reference/making-requests
|
||||
String payload = "{\n" +
|
||||
" \"input\": \"" + text.getContent() + "\",\n" +
|
||||
@ -61,96 +68,13 @@ public class OpenAiLLmUtil {
|
||||
|
||||
|
||||
public static String promptToPayload(Prompt prompt, OpenAiLlmConfig config) {
|
||||
Maps.Builder builder = Maps.of("model", config.getModel())
|
||||
.put("messages", promptFormat.toMessagesJsonKey(prompt))
|
||||
.putIfNotEmpty("tools", promptFormat.toFunctionsJsonKey(prompt))
|
||||
.putIfContainsKey("tools", "tool_choice", "auto")
|
||||
.putIfNotContainsKey("tools", "temperature", 0.7);
|
||||
|
||||
List<Message> messages = prompt.toMessages();
|
||||
|
||||
List<Map<String, String>> messageArray = new ArrayList<>();
|
||||
messages.forEach(message -> {
|
||||
Map<String, String> map = new HashMap<>(2);
|
||||
if (message instanceof HumanMessage) {
|
||||
map.put("role", "user");
|
||||
map.put("content", ((HumanMessage) message).getContent());
|
||||
} else if (message instanceof AiMessage) {
|
||||
map.put("role", "assistant");
|
||||
map.put("content", ((AiMessage) message).getFullContent());
|
||||
}
|
||||
messageArray.add(map);
|
||||
});
|
||||
|
||||
String messageText = JSON.toJSONString(messageArray);
|
||||
|
||||
|
||||
// https://platform.openai.com/docs/api-reference/making-requests
|
||||
return "{\n" +
|
||||
// " \"model\": \"gpt-3.5-turbo\",\n" +
|
||||
" \"model\": \"" + config.getModel() + "\",\n" +
|
||||
" \"messages\": " + messageText + ",\n" +
|
||||
" \"temperature\": 0.7\n" +
|
||||
"}";
|
||||
}
|
||||
|
||||
|
||||
public static <R> String promptToFunctionCallingPayload(Prompt prompt, OpenAiLlmConfig config, List<Function<R>> functions) {
|
||||
|
||||
List<Message> messages = prompt.toMessages();
|
||||
|
||||
List<Map<String, String>> messageArray = new ArrayList<>();
|
||||
messages.forEach(message -> {
|
||||
Map<String, String> map = new HashMap<>(2);
|
||||
if (message instanceof HumanMessage) {
|
||||
map.put("role", "user");
|
||||
map.put("content", ((HumanMessage) message).getContent());
|
||||
} else if (message instanceof AiMessage) {
|
||||
map.put("role", "assistant");
|
||||
map.put("content", ((AiMessage) message).getFullContent());
|
||||
}
|
||||
|
||||
messageArray.add(map);
|
||||
});
|
||||
|
||||
String messageText = JSON.toJSONString(messageArray);
|
||||
|
||||
|
||||
List<Map<String, Object>> toolsArray = new ArrayList<>();
|
||||
for (Function<?> function : functions) {
|
||||
Map<String, Object> functionRoot = new HashMap<>();
|
||||
functionRoot.put("type", "function");
|
||||
|
||||
Map<String, Object> functionObj = new HashMap<>();
|
||||
functionRoot.put("function", functionObj);
|
||||
|
||||
functionObj.put("name", function.getName());
|
||||
functionObj.put("description", function.getDescription());
|
||||
|
||||
|
||||
Map<String, Object> parametersObj = new HashMap<>();
|
||||
functionObj.put("parameters", parametersObj);
|
||||
|
||||
parametersObj.put("type", "object");
|
||||
|
||||
Map<String, Object> propertiesObj = new HashMap<>();
|
||||
parametersObj.put("properties", propertiesObj);
|
||||
|
||||
for (Parameter parameter : function.getParameters()) {
|
||||
Map<String, Object> parameterObj = new HashMap<>();
|
||||
parameterObj.put("type", parameter.getType());
|
||||
parameterObj.put("description", parameter.getDescription());
|
||||
parameterObj.put("enum", parameter.getEnums());
|
||||
propertiesObj.put(parameter.getName(), parameterObj);
|
||||
}
|
||||
|
||||
toolsArray.add(functionRoot);
|
||||
}
|
||||
|
||||
String toolsText = JSON.toJSONString(toolsArray);
|
||||
// https://platform.openai.com/docs/api-reference/making-requests
|
||||
return "{\n" +
|
||||
// " \"model\": \"gpt-3.5-turbo\",\n" +
|
||||
" \"model\": \"" + config.getModel() + "\",\n" +
|
||||
" \"messages\": " + messageText + ",\n" +
|
||||
" \"tools\": " + toolsText + ",\n" +
|
||||
" \"tool_choice\": \"auto\"\n" +
|
||||
"}";
|
||||
return JSON.toJSONString(builder.build());
|
||||
}
|
||||
|
||||
|
||||
|
@ -16,7 +16,6 @@
|
||||
package com.agentsflex.llm.openai;
|
||||
|
||||
import com.agentsflex.document.Document;
|
||||
import com.agentsflex.functions.Function;
|
||||
import com.agentsflex.llm.BaseLlm;
|
||||
import com.agentsflex.llm.MessageListener;
|
||||
import com.agentsflex.llm.MessageResponse;
|
||||
@ -27,24 +26,24 @@ import com.agentsflex.llm.client.LlmClientListener;
|
||||
import com.agentsflex.llm.client.impl.SseClient;
|
||||
import com.agentsflex.llm.response.AiMessageResponse;
|
||||
import com.agentsflex.llm.response.FunctionMessageResponse;
|
||||
import com.agentsflex.message.AiMessage;
|
||||
import com.agentsflex.message.FunctionMessage;
|
||||
import com.agentsflex.message.Message;
|
||||
import com.agentsflex.parser.AiMessageParser;
|
||||
import com.agentsflex.parser.FunctionMessageParser;
|
||||
import com.agentsflex.prompt.FunctionPrompt;
|
||||
import com.agentsflex.prompt.Prompt;
|
||||
import com.agentsflex.util.StringUtil;
|
||||
import com.agentsflex.store.VectorData;
|
||||
import com.alibaba.fastjson.JSON;
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.agentsflex.util.StringUtil;
|
||||
import com.alibaba.fastjson.JSONPath;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
public class OpenAiLlm extends BaseLlm<OpenAiLlmConfig> {
|
||||
|
||||
private final HttpClient httpClient = new HttpClient();
|
||||
public AiMessageParser aiMessageParser = OpenAiLLmUtil.getAiMessageParser();
|
||||
public FunctionMessageParser functionMessageParser = OpenAiLLmUtil.getFunctionMessageParser();
|
||||
|
||||
|
||||
public OpenAiLlm(OpenAiLlmConfig config) {
|
||||
super(config);
|
||||
@ -64,20 +63,10 @@ public class OpenAiLlm extends BaseLlm<OpenAiLlmConfig> {
|
||||
}
|
||||
|
||||
if (prompt instanceof FunctionPrompt) {
|
||||
List<Function<?>> functions = ((FunctionPrompt) prompt).getFunctions();
|
||||
|
||||
JSONObject jsonObject = JSON.parseObject(responseString);
|
||||
String callFunctionName = (String) JSONPath.eval(jsonObject, "$.choices[0].tool_calls[0].function.name");
|
||||
String callFunctionArgsString = (String) JSONPath.eval(jsonObject, "$.choices[0].tool_calls[0].function.arguments");
|
||||
JSONObject callFunctionArgs = JSON.parseObject(callFunctionArgsString);
|
||||
|
||||
FunctionMessage functionMessage = new FunctionMessage();
|
||||
functionMessage.setFunctionName(callFunctionName);
|
||||
functionMessage.setArgs(callFunctionArgs);
|
||||
return (R) new FunctionMessageResponse(functions, functionMessage);
|
||||
return (R) new FunctionMessageResponse(((FunctionPrompt) prompt).getFunctions()
|
||||
, functionMessageParser.parse(responseString));
|
||||
} else {
|
||||
AiMessage aiMessage = OpenAiLLmUtil.parseAiMessage(responseString);
|
||||
return (R) new AiMessageResponse(aiMessage);
|
||||
return (R) new AiMessageResponse(aiMessageParser.parse(responseString));
|
||||
}
|
||||
}
|
||||
|
||||
@ -90,7 +79,7 @@ public class OpenAiLlm extends BaseLlm<OpenAiLlmConfig> {
|
||||
|
||||
String payload = OpenAiLLmUtil.promptToPayload(prompt, config);
|
||||
|
||||
LlmClientListener clientListener = new BaseLlmClientListener(this, llmClient, listener, prompt, OpenAiLLmUtil::parseAiMessage, null);
|
||||
LlmClientListener clientListener = new BaseLlmClientListener(this, llmClient, listener, prompt, aiMessageParser, functionMessageParser);
|
||||
llmClient.start("https://api.openai.com/v1/chat/completions", headers, payload, clientListener);
|
||||
}
|
||||
|
||||
|
@ -25,8 +25,12 @@ import com.agentsflex.llm.client.LlmClient;
|
||||
import com.agentsflex.llm.client.LlmClientListener;
|
||||
import com.agentsflex.llm.client.impl.SseClient;
|
||||
import com.agentsflex.llm.response.AiMessageResponse;
|
||||
import com.agentsflex.llm.response.FunctionMessageResponse;
|
||||
import com.agentsflex.message.AiMessage;
|
||||
import com.agentsflex.message.Message;
|
||||
import com.agentsflex.parser.AiMessageParser;
|
||||
import com.agentsflex.parser.FunctionMessageParser;
|
||||
import com.agentsflex.parser.impl.BaseAiMessageParser;
|
||||
import com.agentsflex.prompt.FunctionPrompt;
|
||||
import com.agentsflex.prompt.Prompt;
|
||||
import com.agentsflex.store.VectorData;
|
||||
@ -37,12 +41,16 @@ import java.util.Map;
|
||||
|
||||
public class QwenLlm extends BaseLlm<QwenLlmConfig> {
|
||||
|
||||
|
||||
HttpClient httpClient = new HttpClient();
|
||||
|
||||
public AiMessageParser aiMessageParser = QwenLlmUtil.getAiMessageParser();
|
||||
public FunctionMessageParser functionMessageParser = QwenLlmUtil.getFunctionMessageParser();
|
||||
|
||||
public QwenLlm(QwenLlmConfig config) {
|
||||
super(config);
|
||||
}
|
||||
|
||||
HttpClient httpClient = new HttpClient();
|
||||
|
||||
|
||||
@Override
|
||||
public <R extends MessageResponse<M>, M extends Message> R chat(Prompt<M> prompt) {
|
||||
@ -58,13 +66,10 @@ public class QwenLlm extends BaseLlm<QwenLlmConfig> {
|
||||
}
|
||||
|
||||
if (prompt instanceof FunctionPrompt) {
|
||||
|
||||
return (R) new FunctionMessageResponse(((FunctionPrompt) prompt).getFunctions(), functionMessageParser.parse(responseString));
|
||||
} else {
|
||||
AiMessage aiMessage = QwenLlmUtil.parseAiMessage(responseString, 0);
|
||||
return (R) new AiMessageResponse(aiMessage);
|
||||
return (R) new AiMessageResponse(aiMessageParser.parse(responseString));
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
|
||||
@ -77,16 +82,18 @@ public class QwenLlm extends BaseLlm<QwenLlmConfig> {
|
||||
|
||||
String payload = QwenLlmUtil.promptToPayload(prompt, config);
|
||||
|
||||
LlmClientListener clientListener = new BaseLlmClientListener(this, llmClient, listener, prompt, new BaseLlmClientListener.AiMessageParser() {
|
||||
LlmClientListener clientListener = new BaseLlmClientListener(this, llmClient, listener, prompt, new BaseAiMessageParser() {
|
||||
int prevMessageLength = 0;
|
||||
|
||||
@Override
|
||||
public AiMessage parseMessage(String response) {
|
||||
AiMessage aiMessage = QwenLlmUtil.parseAiMessage(response, prevMessageLength);
|
||||
prevMessageLength += aiMessage.getContent().length();
|
||||
public AiMessage parse(String content) {
|
||||
AiMessage aiMessage = aiMessageParser.parse(content);
|
||||
String messageContent = aiMessage.getContent();
|
||||
aiMessage.setContent(messageContent.substring(prevMessageLength));
|
||||
prevMessageLength = messageContent.length();
|
||||
return aiMessage;
|
||||
}
|
||||
}, null);
|
||||
}, functionMessageParser);
|
||||
llmClient.start("https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation", headers, payload, clientListener);
|
||||
}
|
||||
|
||||
|
@ -15,33 +15,37 @@
|
||||
*/
|
||||
package com.agentsflex.llm.qwen;
|
||||
|
||||
import com.agentsflex.message.AiMessage;
|
||||
import com.agentsflex.message.HumanMessage;
|
||||
import com.agentsflex.message.Message;
|
||||
import com.agentsflex.message.MessageStatus;
|
||||
import com.agentsflex.parser.AiMessageParser;
|
||||
import com.agentsflex.parser.FunctionMessageParser;
|
||||
import com.agentsflex.parser.impl.BaseAiMessageParser;
|
||||
import com.agentsflex.parser.impl.BaseFunctionMessageParser;
|
||||
import com.agentsflex.prompt.DefaultPromptFormat;
|
||||
import com.agentsflex.prompt.Prompt;
|
||||
import com.agentsflex.prompt.PromptFormat;
|
||||
import com.agentsflex.util.Maps;
|
||||
import com.alibaba.fastjson.JSON;
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.alibaba.fastjson.JSONPath;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
public class QwenLlmUtil {
|
||||
|
||||
private static final PromptFormat promptFormat = new DefaultPromptFormat();
|
||||
|
||||
public static AiMessage parseAiMessage(String json, int length) {
|
||||
AiMessage aiMessage = new AiMessage();
|
||||
JSONObject jsonObject = JSON.parseObject(json);
|
||||
MessageStatus messageStatus = parseMessageStatus((String) JSONPath.eval(json, "$.output.finish_reason"));
|
||||
aiMessage.setStatus(messageStatus);
|
||||
public static AiMessageParser getAiMessageParser() {
|
||||
BaseAiMessageParser aiMessageParser = new BaseAiMessageParser();
|
||||
aiMessageParser.setContentPath("$.output.text");
|
||||
aiMessageParser.setStatusPath("$.output.finish_reason");
|
||||
aiMessageParser.setTotalTokensPath("$.usage.total_tokens");
|
||||
aiMessageParser.setStatusParser(content -> parseMessageStatus((String) content));
|
||||
return aiMessageParser;
|
||||
}
|
||||
|
||||
String text = (String) JSONPath.eval(jsonObject, "$.output.text");
|
||||
aiMessage.setContent(text.substring(length));
|
||||
aiMessage.setTotalTokens((Integer) JSONPath.eval(jsonObject, "$.usage.total_tokens"));
|
||||
return aiMessage;
|
||||
|
||||
public static FunctionMessageParser getFunctionMessageParser() {
|
||||
BaseFunctionMessageParser functionMessageParser = new BaseFunctionMessageParser();
|
||||
functionMessageParser.setFunctionNamePath("$.choices[0].message.tool_calls[0].function.name");
|
||||
functionMessageParser.setFunctionArgsPath("$.choices[0].message.tool_calls[0].function.arguments");
|
||||
functionMessageParser.setFunctionArgsParser(JSON::parseObject);
|
||||
return functionMessageParser;
|
||||
}
|
||||
|
||||
|
||||
@ -50,34 +54,9 @@ public class QwenLlmUtil {
|
||||
}
|
||||
|
||||
|
||||
public static String promptToPayload(Prompt prompt, QwenLlmConfig config) {
|
||||
|
||||
List<Message> messages = prompt.toMessages();
|
||||
|
||||
public static String promptToPayload(Prompt<?> prompt, QwenLlmConfig config) {
|
||||
// https://help.aliyun.com/zh/dashscope/developer-reference/api-details?spm=a2c4g.11186623.0.0.1ff6fa70jCgGRc#b8ebf6b25eul6
|
||||
String payload = "{\n" +
|
||||
" \"model\": \"" + config.getModel() + "\",\n" +
|
||||
" \"input\": {\n" +
|
||||
" \"messages\": messageJsonString\n" +
|
||||
" }\n" +
|
||||
"}";
|
||||
|
||||
|
||||
List<Map<String, String>> messageArray = new ArrayList<>();
|
||||
messages.forEach(message -> {
|
||||
Map<String, String> map = new HashMap<>(2);
|
||||
if (message instanceof HumanMessage) {
|
||||
map.put("role", "user");
|
||||
map.put("content", ((HumanMessage) message).getContent());
|
||||
} else if (message instanceof AiMessage) {
|
||||
map.put("role", "assistant");
|
||||
map.put("content", ((AiMessage) message).getFullContent());
|
||||
}
|
||||
|
||||
messageArray.add(map);
|
||||
});
|
||||
|
||||
String messageText = JSON.toJSONString(messageArray);
|
||||
return payload.replace("messageJsonString", messageText);
|
||||
Maps.Builder root = Maps.of("model", config.getModel()).put("input", Maps.of("messages", promptFormat.toMessagesJsonKey(prompt)));
|
||||
return JSON.toJSONString(root.build());
|
||||
}
|
||||
}
|
||||
|
@ -16,7 +16,6 @@
|
||||
package com.agentsflex.llm.spark;
|
||||
|
||||
import com.agentsflex.document.Document;
|
||||
import com.agentsflex.functions.Function;
|
||||
import com.agentsflex.llm.BaseLlm;
|
||||
import com.agentsflex.llm.ChatContext;
|
||||
import com.agentsflex.llm.MessageListener;
|
||||
@ -30,18 +29,21 @@ import com.agentsflex.llm.response.FunctionMessageResponse;
|
||||
import com.agentsflex.message.AiMessage;
|
||||
import com.agentsflex.message.FunctionMessage;
|
||||
import com.agentsflex.message.Message;
|
||||
import com.agentsflex.parser.AiMessageParser;
|
||||
import com.agentsflex.parser.FunctionMessageParser;
|
||||
import com.agentsflex.prompt.FunctionPrompt;
|
||||
import com.agentsflex.prompt.Prompt;
|
||||
import com.agentsflex.store.VectorData;
|
||||
import com.alibaba.fastjson.JSON;
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.alibaba.fastjson.JSONPath;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.concurrent.CountDownLatch;
|
||||
|
||||
public class SparkLlm extends BaseLlm<SparkLlmConfig> {
|
||||
|
||||
public AiMessageParser aiMessageParser = SparkLlmUtil.getAiMessageParser();
|
||||
public FunctionMessageParser functionMessageParser = SparkLlmUtil.getFunctionMessageParser();
|
||||
|
||||
|
||||
|
||||
public SparkLlm(SparkLlmConfig config) {
|
||||
super(config);
|
||||
}
|
||||
@ -51,6 +53,8 @@ public class SparkLlm extends BaseLlm<SparkLlmConfig> {
|
||||
return null;
|
||||
}
|
||||
|
||||
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
@Override
|
||||
public <R extends MessageResponse<M>, M extends Message> R chat(Prompt<M> prompt) {
|
||||
@ -91,6 +95,7 @@ public class SparkLlm extends BaseLlm<SparkLlmConfig> {
|
||||
}
|
||||
|
||||
|
||||
|
||||
@Override
|
||||
public <R extends MessageResponse<M>, M extends Message> void chatAsync(Prompt<M> prompt, MessageListener<R, M> listener) {
|
||||
LlmClient llmClient = new WebSocketClient();
|
||||
@ -98,20 +103,7 @@ public class SparkLlm extends BaseLlm<SparkLlmConfig> {
|
||||
|
||||
String payload = SparkLlmUtil.promptToPayload(prompt, config);
|
||||
|
||||
LlmClientListener clientListener = new BaseLlmClientListener(this, llmClient, listener, prompt, SparkLlmUtil::parseAiMessage, new BaseLlmClientListener.FunctionMessageParser() {
|
||||
@Override
|
||||
public FunctionMessage parseMessage(String response) {
|
||||
JSONObject jsonObject = JSON.parseObject(response);
|
||||
String callFunctionName = (String) JSONPath.eval(jsonObject, "$.payload.choices.text[0].function_call.name");
|
||||
String callFunctionArgsString = (String) JSONPath.eval(jsonObject, "$.payload.choices.text[0].function_call.arguments");
|
||||
JSONObject callFunctionArgs = JSON.parseObject(callFunctionArgsString);
|
||||
|
||||
FunctionMessage functionMessage = new FunctionMessage();
|
||||
functionMessage.setFunctionName(callFunctionName);
|
||||
functionMessage.setArgs(callFunctionArgs);
|
||||
return functionMessage;
|
||||
}
|
||||
});
|
||||
LlmClientListener clientListener = new BaseLlmClientListener(this, llmClient, listener, prompt, aiMessageParser, functionMessageParser);
|
||||
llmClient.start(url, null, payload, clientListener);
|
||||
}
|
||||
|
||||
|
@ -15,119 +15,58 @@
|
||||
*/
|
||||
package com.agentsflex.llm.spark;
|
||||
|
||||
import com.agentsflex.functions.Function;
|
||||
import com.agentsflex.functions.Parameter;
|
||||
import com.agentsflex.message.*;
|
||||
import com.agentsflex.message.MessageStatus;
|
||||
import com.agentsflex.parser.AiMessageParser;
|
||||
import com.agentsflex.parser.FunctionMessageParser;
|
||||
import com.agentsflex.parser.impl.BaseAiMessageParser;
|
||||
import com.agentsflex.parser.impl.BaseFunctionMessageParser;
|
||||
import com.agentsflex.prompt.DefaultPromptFormat;
|
||||
import com.agentsflex.prompt.FunctionPrompt;
|
||||
import com.agentsflex.prompt.Prompt;
|
||||
import com.agentsflex.prompt.PromptFormat;
|
||||
import com.agentsflex.util.HashUtil;
|
||||
import com.agentsflex.util.Maps;
|
||||
import com.alibaba.fastjson.JSON;
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.alibaba.fastjson.JSONPath;
|
||||
|
||||
import java.io.UnsupportedEncodingException;
|
||||
import java.net.URLEncoder;
|
||||
import java.text.SimpleDateFormat;
|
||||
import java.util.*;
|
||||
import java.util.Base64;
|
||||
import java.util.Date;
|
||||
import java.util.Locale;
|
||||
import java.util.UUID;
|
||||
|
||||
public class SparkLlmUtil {
|
||||
|
||||
public static AiMessage parseAiMessage(String json) {
|
||||
AiMessage aiMessage = new AiMessage();
|
||||
JSONObject jsonObject = JSON.parseObject(json);
|
||||
Object status = JSONPath.eval(jsonObject, "$.payload.choices.status");
|
||||
MessageStatus messageStatus = SparkLlmUtil.parseMessageStatus((Integer) status);
|
||||
aiMessage.setStatus(messageStatus);
|
||||
aiMessage.setIndex((Integer) JSONPath.eval(jsonObject, "$.payload.choices.text[0].index"));
|
||||
aiMessage.setContent((String) JSONPath.eval(jsonObject, "$.payload.choices.text[0].content"));
|
||||
return aiMessage;
|
||||
private static final PromptFormat promptFormat = new DefaultPromptFormat();
|
||||
|
||||
public static AiMessageParser getAiMessageParser() {
|
||||
BaseAiMessageParser aiMessageParser = new BaseAiMessageParser();
|
||||
aiMessageParser.setContentPath("$.payload.choices.text[0].content");
|
||||
aiMessageParser.setIndexPath("$.payload.choices.text[0].index");
|
||||
aiMessageParser.setStatusPath("$.payload.choices.status");
|
||||
aiMessageParser.setStatusParser(content -> parseMessageStatus((Integer) content));
|
||||
return aiMessageParser;
|
||||
}
|
||||
|
||||
|
||||
public static FunctionMessageParser getFunctionMessageParser() {
|
||||
BaseFunctionMessageParser functionMessageParser = new BaseFunctionMessageParser();
|
||||
functionMessageParser.setFunctionNamePath("$.payload.choices.text[0].function_call.name");
|
||||
functionMessageParser.setFunctionArgsPath("$.payload.choices.text[0].function_call.arguments");
|
||||
functionMessageParser.setFunctionArgsParser(JSON::parseObject);
|
||||
return functionMessageParser;
|
||||
}
|
||||
|
||||
|
||||
public static String promptToPayload(Prompt prompt, SparkLlmConfig config) {
|
||||
|
||||
|
||||
// https://www.xfyun.cn/doc/spark/Web.html#_1-%E6%8E%A5%E5%8F%A3%E8%AF%B4%E6%98%8E
|
||||
String payload = "{\n" +
|
||||
" \"header\": {\n" +
|
||||
" \"app_id\": \"" + config.getAppId() + "\",\n" +
|
||||
" \"uid\": \"" + UUID.randomUUID() + "\"\n" +
|
||||
" },\n" +
|
||||
" \"parameter\": {\n" +
|
||||
" \"chat\": {\n" +
|
||||
" \"domain\": \"generalv3\",\n" +
|
||||
" \"temperature\": 0.5,\n" +
|
||||
" \"max_tokens\": 1024 \n" +
|
||||
" }\n" +
|
||||
" },\n" +
|
||||
" \"payload\": {\n" +
|
||||
" \"message\": {\n" +
|
||||
" \"text\": messageJsonString" +
|
||||
" },\n" +
|
||||
" \"functions\":functionsJsonString" +
|
||||
" }\n" +
|
||||
"}";
|
||||
|
||||
List<Message> messages = prompt.toMessages();
|
||||
List<Map<String, String>> messageArray = new ArrayList<>();
|
||||
messages.forEach(message -> {
|
||||
Map<String, String> map = new HashMap<>(2);
|
||||
if (message instanceof HumanMessage) {
|
||||
map.put("role", "user");
|
||||
map.put("content", ((HumanMessage) message).getContent());
|
||||
} else if (message instanceof AiMessage) {
|
||||
map.put("role", "assistant");
|
||||
map.put("content", ((AiMessage) message).getFullContent());
|
||||
} else if (message instanceof SystemMessage) {
|
||||
map.put("role", "system");
|
||||
map.put("content", ((SystemMessage) message).getContent());
|
||||
}
|
||||
messageArray.add(map);
|
||||
});
|
||||
|
||||
|
||||
|
||||
String functionsJsonString = "\"\"";
|
||||
if (prompt instanceof FunctionPrompt) {
|
||||
List<Function<?>> functions = ((FunctionPrompt) prompt).getFunctions();
|
||||
|
||||
List<Map<String, Object>> functionsArray = new ArrayList<>();
|
||||
for (Function<?> function : functions) {
|
||||
Map<String, Object> functionRoot = new HashMap<>();
|
||||
functionRoot.put("type", "function");
|
||||
|
||||
Map<String, Object> functionObj = new HashMap<>();
|
||||
functionRoot.put("function", functionObj);
|
||||
|
||||
functionObj.put("name", function.getName());
|
||||
functionObj.put("description", function.getDescription());
|
||||
|
||||
|
||||
Map<String, Object> parametersObj = new HashMap<>();
|
||||
functionObj.put("parameters", parametersObj);
|
||||
|
||||
parametersObj.put("type", "object");
|
||||
|
||||
Map<String, Object> propertiesObj = new HashMap<>();
|
||||
parametersObj.put("properties", propertiesObj);
|
||||
|
||||
for (Parameter parameter : function.getParameters()) {
|
||||
Map<String, Object> parameterObj = new HashMap<>();
|
||||
parameterObj.put("type", parameter.getType());
|
||||
parameterObj.put("description", parameter.getDescription());
|
||||
parameterObj.put("enum", parameter.getEnums());
|
||||
propertiesObj.put(parameter.getName(), parameterObj);
|
||||
}
|
||||
|
||||
functionsArray.add(functionRoot);
|
||||
}
|
||||
Map<String, Object> functionsJsonMap = new HashMap<>();
|
||||
functionsJsonMap.put("text", functionsArray);
|
||||
functionsJsonString = JSON.toJSONString(functionsJsonMap);
|
||||
}
|
||||
|
||||
String messageText = JSON.toJSONString(messageArray);
|
||||
return payload.replace("messageJsonString", messageText).replace("functionsJsonString", functionsJsonString);
|
||||
Maps.Builder root = Maps.of("header", Maps.of("app_id", config.getAppId()).put("uid", UUID.randomUUID()));
|
||||
root.put("parameter", Maps.of("chat", Maps.of("domain", "generalv3").put("temperature", 0.5).put("max_tokens", 1024)));
|
||||
root.put("payload", Maps.of("message", Maps.of("text", promptFormat.toMessagesJsonKey(prompt)))
|
||||
.putIfNotEmpty("functions", Maps.ofNotNull("text", promptFormat.toFunctionsJsonKey((FunctionPrompt) prompt)))
|
||||
);
|
||||
return JSON.toJSONString(root.build());
|
||||
}
|
||||
|
||||
public static MessageStatus parseMessageStatus(Integer status) {
|
||||
@ -139,7 +78,6 @@ public class SparkLlmUtil {
|
||||
case 2:
|
||||
return MessageStatus.END;
|
||||
}
|
||||
|
||||
return MessageStatus.UNKNOW;
|
||||
}
|
||||
|
||||
|
@ -31,11 +31,11 @@ import java.util.*;
|
||||
*/
|
||||
public class AliyunVectorStore extends VectorStore<VectorDocument> {
|
||||
|
||||
private AliyunVectorStorageConfig config;
|
||||
private AliyunVectorStoreConfig config;
|
||||
|
||||
private final HttpClient httpUtil = new HttpClient();
|
||||
|
||||
public AliyunVectorStore(AliyunVectorStorageConfig config) {
|
||||
public AliyunVectorStore(AliyunVectorStoreConfig config) {
|
||||
this.config = config;
|
||||
}
|
||||
|
||||
@ -50,8 +50,8 @@ public class AliyunVectorStore extends VectorStore<VectorDocument> {
|
||||
List<Map<String,Object>> payloadDocs = new ArrayList<>();
|
||||
for (VectorDocument vectorDocument : documents) {
|
||||
Map<String, Object> document = new HashMap<>();
|
||||
if (vectorDocument.getMetadataMap() != null) {
|
||||
document.put("fields", vectorDocument.getMetadataMap());
|
||||
if (vectorDocument.getMetadatas() != null) {
|
||||
document.put("fields", vectorDocument.getMetadatas());
|
||||
}
|
||||
document.put("vector", vectorDocument.getVector());
|
||||
document.put("id", vectorDocument.getId());
|
||||
@ -79,6 +79,7 @@ public class AliyunVectorStore extends VectorStore<VectorDocument> {
|
||||
httpUtil.delete("https://" + config.getEndpoint() + "/v1/collections/" + config.getCollection() + "/docs", headers, payload);
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public void update(List<VectorDocument> documents) {
|
||||
Map<String, String> headers = new HashMap<>();
|
||||
@ -90,8 +91,8 @@ public class AliyunVectorStore extends VectorStore<VectorDocument> {
|
||||
List<Map<String,Object>> payloadDocs = new ArrayList<>();
|
||||
for (VectorDocument vectorDocument : documents) {
|
||||
Map<String, Object> document = new HashMap<>();
|
||||
if (vectorDocument.getMetadataMap() != null) {
|
||||
document.put("fields", vectorDocument.getMetadataMap());
|
||||
if (vectorDocument.getMetadatas() != null) {
|
||||
document.put("fields", vectorDocument.getMetadatas());
|
||||
}
|
||||
document.put("vector", vectorDocument.getVector());
|
||||
document.put("id", vectorDocument.getId());
|
||||
|
@ -20,7 +20,7 @@ import java.io.Serializable;
|
||||
/**
|
||||
* https://help.aliyun.com/document_detail/2510317.html
|
||||
*/
|
||||
public class AliyunVectorStorageConfig implements Serializable {
|
||||
public class AliyunVectorStoreConfig implements Serializable {
|
||||
private String endpoint;
|
||||
private String apiKey;
|
||||
private String database;
|
@ -31,12 +31,12 @@ import java.util.*;
|
||||
*/
|
||||
public class QCloudVectorStore extends VectorStore<VectorDocument> {
|
||||
|
||||
private QCloudVectorStorageConfig config;
|
||||
private QCloudVectorStoreConfig config;
|
||||
|
||||
private final HttpClient httpUtil = new HttpClient();
|
||||
|
||||
|
||||
public QCloudVectorStore(QCloudVectorStorageConfig config) {
|
||||
public QCloudVectorStore(QCloudVectorStoreConfig config) {
|
||||
this.config = config;
|
||||
}
|
||||
|
||||
@ -55,8 +55,8 @@ public class QCloudVectorStore extends VectorStore<VectorDocument> {
|
||||
List<Map<String, Object>> payloadDocs = new ArrayList<>();
|
||||
for (VectorDocument vectorDocument : documents) {
|
||||
Map<String, Object> document = new HashMap<>();
|
||||
if (vectorDocument.getMetadataMap() != null) {
|
||||
document.putAll(vectorDocument.getMetadataMap());
|
||||
if (vectorDocument.getMetadatas() != null) {
|
||||
document.putAll(vectorDocument.getMetadatas());
|
||||
}
|
||||
document.put("vector", vectorDocument.getVector());
|
||||
document.put("id", vectorDocument.getId());
|
||||
@ -103,7 +103,7 @@ public class QCloudVectorStore extends VectorStore<VectorDocument> {
|
||||
Map<String, Object> documentIdsObj = new HashMap<>();
|
||||
documentIdsObj.put("documentIds", Collections.singletonList(document.getId()));
|
||||
payloadMap.put("query", documentIdsObj);
|
||||
payloadMap.put("update", document.getMetadataMap());
|
||||
payloadMap.put("update", document.getMetadatas());
|
||||
String payload = JSON.toJSONString(payloadMap);
|
||||
httpUtil.post(config.getHost() + "/document/update", headers, payload);
|
||||
}
|
||||
|
@ -17,7 +17,7 @@ package com.agentsflex.store.qcloud;
|
||||
|
||||
import java.io.Serializable;
|
||||
|
||||
public class QCloudVectorStorageConfig implements Serializable {
|
||||
public class QCloudVectorStoreConfig implements Serializable {
|
||||
|
||||
private String host;
|
||||
private String apiKey;
|
@ -16,5 +16,11 @@
|
||||
<maven.compiler.target>8</maven.compiler.target>
|
||||
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
||||
</properties>
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>com.alibaba</groupId>
|
||||
<artifactId>fastjson</artifactId>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
</project>
|
||||
|
@ -0,0 +1,7 @@
|
||||
package com.agentsflex;
|
||||
|
||||
public class JavaMain {
|
||||
|
||||
public static void main(String[] args) {
|
||||
}
|
||||
}
|
19
pom.xml
19
pom.xml
@ -56,11 +56,30 @@
|
||||
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
||||
<slf4j.version>1.7.29</slf4j.version>
|
||||
<junit.version>4.13.2</junit.version>
|
||||
<okhttp.version>4.9.3</okhttp.version>
|
||||
<fastjson.version>2.0.45</fastjson.version>
|
||||
</properties>
|
||||
|
||||
|
||||
<dependencyManagement>
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>com.squareup.okhttp3</groupId>
|
||||
<artifactId>okhttp</artifactId>
|
||||
<version>${okhttp.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.squareup.okhttp3</groupId>
|
||||
<artifactId>okhttp-sse</artifactId>
|
||||
<version>${okhttp.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.alibaba</groupId>
|
||||
<artifactId>fastjson</artifactId>
|
||||
<version>${fastjson.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.slf4j</groupId>
|
||||
<artifactId>slf4j-api</artifactId>
|
||||
|
Loading…
Reference in New Issue
Block a user