refactor: optimize and refactor LLM module

This commit is contained in:
Michael Yang 2024-01-26 17:02:29 +08:00
parent d1f7c05c54
commit 1fc00e4563
29 changed files with 456 additions and 148 deletions

View File

@ -26,15 +26,15 @@ import java.util.List;
public class Functions<T> extends ArrayList<Function<T>> {
public static <R> Functions<R> from(Class<?> clazz, Class<R> resultType, String... methodNames) {
public static <R> Functions<R> from(Class<?> clazz, String... methodNames) {
List<Method> methodList = ClassUtil.getAllMethods(clazz, method -> {
if (Modifier.isStatic(method.getModifiers())) {
return false;
}
if (!resultType.isAssignableFrom(method.getReturnType())) {
return false;
}
// if (!resultType.isAssignableFrom(method.getReturnType())) {
// return false;
// }
if (method.getAnnotation(FunctionDef.class) == null) {
return false;

View File

@ -15,7 +15,7 @@
*/
package com.agentsflex.llm;
public abstract class BaseLlm<T extends LlmConfig> extends Llm{
public abstract class BaseLlm<T extends LlmConfig> implements Llm {
protected T config;

View File

@ -0,0 +1,44 @@
/*
* Copyright (c) 2022-2023, Agents-Flex (fuhai999@gmail.com).
* <p>
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* <p>
* http://www.apache.org/licenses/LICENSE-2.0
* <p>
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.agentsflex.llm;
import com.agentsflex.llm.client.LlmClient;
/**
 * Per-conversation state handed to {@code ChatListener} callbacks: the
 * {@code Llm} that initiated the chat and the {@code LlmClient} transport
 * currently driving it.
 */
public class ChatContext {

    // The model instance that started this chat.
    private Llm llm;
    // The underlying client (SSE, WebSocket, ...) streaming the response.
    private LlmClient client;

    public ChatContext(Llm llm, LlmClient client) {
        this.llm = llm;
        this.client = client;
    }

    public LlmClient getClient() {
        return client;
    }

    public void setClient(LlmClient client) {
        this.client = client;
    }

    public Llm getLlm() {
        return llm;
    }

    public void setLlm(Llm llm) {
        this.llm = llm;
    }
}

View File

@ -15,15 +15,16 @@
*/
package com.agentsflex.llm;
import com.agentsflex.message.AiMessage;
public interface ChatListener {
default void onStart(Llm llm){}
default void onStart(ChatContext context) {
}
void onMessage(Llm llm, AiMessage aiMessage);
void onMessage(ChatContext context, ChatResponse<?> response);
default void onStop(Llm llm){}
default void onStop(ChatContext context) {
}
default void onFailure(Llm llm, Throwable throwable){}
default void onFailure(ChatContext context, Throwable throwable) {
}
}

View File

@ -15,13 +15,8 @@
*/
package com.agentsflex.llm;
import com.agentsflex.functions.Function;
import com.agentsflex.prompt.Prompt;
import java.util.List;
public interface FunctionCalling {
<R> R call(Prompt prompt, List<Function<R>> functions);
import com.agentsflex.message.Message;
public interface ChatResponse<T extends Message> {
T getMessage();
}

View File

@ -20,6 +20,6 @@ import com.agentsflex.vector.VectorData;
public interface Embeddings {
VectorData embeddings(Document prompt);
VectorData embeddings(Document document);
}

View File

@ -15,16 +15,23 @@
*/
package com.agentsflex.llm;
import com.agentsflex.llm.client.LlmClient;
import com.agentsflex.prompt.Prompt;
import com.agentsflex.prompt.SimplePrompt;
public abstract class Llm implements Embeddings {
public interface Llm extends Embeddings {
public LlmClient chat(String prompt, ChatListener listener) {
return chat(new SimplePrompt(prompt), listener);
default String chat(String prompt) {
ChatResponse<?> chat = chat(new SimplePrompt(prompt));
return chat != null ? chat.getMessage().getContent() : null;
}
public abstract LlmClient chat(Prompt prompt, ChatListener listener);
<T extends ChatResponse<?>> T chat(Prompt<T> prompt);
default void chatAsync(String prompt, ChatListener listener) {
this.chatAsync(new SimplePrompt(prompt), listener);
}
void chatAsync(Prompt prompt, ChatListener listener);
}

View File

@ -15,66 +15,96 @@
*/
package com.agentsflex.llm.client;
import com.agentsflex.functions.Function;
import com.agentsflex.llm.ChatContext;
import com.agentsflex.llm.ChatListener;
import com.agentsflex.llm.ChatResponse;
import com.agentsflex.llm.Llm;
import com.agentsflex.llm.response.FunctionResultResponse;
import com.agentsflex.llm.response.MessageResponse;
import com.agentsflex.message.AiMessage;
import com.agentsflex.message.FunctionMessage;
import com.agentsflex.prompt.FunctionPrompt;
import com.agentsflex.prompt.HistoriesPrompt;
import com.agentsflex.prompt.Prompt;
import java.util.List;
public class BaseLlmClientListener implements LlmClientListener {
private final Llm llm;
private final LlmClient client;
private final ChatListener chatListener;
private final Prompt prompt;
private final MessageParser messageParser;
private final AiMessageParser messageParser;
private final FunctionMessageParser functionInfoParser;
private final StringBuilder fullMessage = new StringBuilder();
private AiMessage lastAiMessage;
private boolean isFunctionCalling = false;
public BaseLlmClientListener(Llm llm, ChatListener chatListener, Prompt prompt, MessageParser messageParser) {
public BaseLlmClientListener(Llm llm, LlmClient client, ChatListener chatListener, Prompt prompt
, AiMessageParser messageParser
, FunctionMessageParser functionInfoParser) {
this.llm = llm;
this.client = client;
this.chatListener = chatListener;
this.prompt = prompt;
this.messageParser = messageParser;
this.functionInfoParser = functionInfoParser;
if (prompt instanceof FunctionPrompt) {
if (functionInfoParser == null) {
throw new IllegalArgumentException("Can not support Function Calling");
} else {
isFunctionCalling = true;
}
}
}
@Override
public void onStart(LlmClient client) {
chatListener.onStart(llm);
chatListener.onStart(new ChatContext(llm, client));
}
@Override
public void onMessage(LlmClient client, String response) {
lastAiMessage = messageParser.parseMessage(response);
fullMessage.append(lastAiMessage.getContent());
chatListener.onMessage(llm, lastAiMessage);
if (isFunctionCalling) {
FunctionMessage functionInfo = functionInfoParser.parseMessage(response);
List<Function<?>> functions = ((FunctionPrompt) prompt).getFunctions();
ChatResponse<?> r = new FunctionResultResponse(functions, functionInfo);
chatListener.onMessage(new ChatContext(llm, client), r);
} else {
lastAiMessage = messageParser.parseMessage(response);
fullMessage.append(lastAiMessage.getContent());
lastAiMessage.setFullContent(fullMessage.toString());
ChatResponse<?> r = new MessageResponse(lastAiMessage);
chatListener.onMessage(new ChatContext(llm, client), r);
}
}
@Override
public void onStop(LlmClient client) {
if (lastAiMessage != null){
lastAiMessage.setFullContent(fullMessage.toString());
if (this.prompt instanceof HistoriesPrompt){
if (lastAiMessage != null) {
if (this.prompt instanceof HistoriesPrompt) {
((HistoriesPrompt) this.prompt).addMessage(lastAiMessage);
}
}
chatListener.onStop(llm);
chatListener.onStop(new ChatContext(llm, client));
}
@Override
public void onFailure(LlmClient client, Throwable throwable) {
chatListener.onFailure(llm, throwable);
chatListener.onFailure(new ChatContext(llm, client), throwable);
}
public interface MessageParser{
public interface AiMessageParser {
AiMessage parseMessage(String response);
}
public interface FunctionMessageParser {
FunctionMessage parseMessage(String response);
}
}

View File

@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.agentsflex.util;
package com.agentsflex.llm.client;
import okhttp3.*;
import org.slf4j.Logger;
@ -23,13 +23,13 @@ import java.io.IOException;
import java.util.Map;
import java.util.concurrent.TimeUnit;
public class OKHttpUtil {
private static final Logger LOG = LoggerFactory.getLogger(OKHttpUtil.class);
public class HttpClient {
private static final Logger LOG = LoggerFactory.getLogger(HttpClient.class);
private static final MediaType JSON_TYPE = MediaType.parse("application/json; charset=utf-8");
private final OkHttpClient okHttpClient;
public OKHttpUtil() {
public HttpClient() {
this.okHttpClient = new OkHttpClient.Builder()
.connectTimeout(3, TimeUnit.MINUTES)
.readTimeout(3, TimeUnit.MINUTES)
@ -37,7 +37,7 @@ public class OKHttpUtil {
}
public OKHttpUtil(OkHttpClient okHttpClient) {
public HttpClient(OkHttpClient okHttpClient) {
this.okHttpClient = okHttpClient;
}

View File

@ -24,7 +24,7 @@ import java.io.IOException;
import java.util.Map;
import java.util.concurrent.TimeUnit;
public class HttpClient implements LlmClient {
public class AsyncHttpClient implements LlmClient {
private static final MediaType JSON_TYPE = MediaType.parse("application/json; charset=utf-8");
private OkHttpClient client;
private LlmClientListener listener;
@ -55,15 +55,15 @@ public class HttpClient implements LlmClient {
this.client.newCall(rBuilder.build()).enqueue(new Callback() {
@Override
public void onFailure(@NotNull Call call, @NotNull IOException e) {
HttpClient.this.listener.onFailure(HttpClient.this, e);
AsyncHttpClient.this.listener.onFailure(AsyncHttpClient.this, e);
}
@Override
public void onResponse(@NotNull Call call, @NotNull Response response) throws IOException {
HttpClient.this.listener.onMessage(HttpClient.this, response.message());
AsyncHttpClient.this.listener.onMessage(AsyncHttpClient.this, response.message());
if (!isStop) {
HttpClient.this.isStop = true;
HttpClient.this.listener.onStop(HttpClient.this);
AsyncHttpClient.this.isStop = true;
AsyncHttpClient.this.listener.onStop(AsyncHttpClient.this);
}
}
});

View File

@ -79,6 +79,10 @@ public class WebSocketClient extends WebSocketListener implements LlmClient {
@Override
public void onClosing(WebSocket webSocket, int code, String reason) {
    // A server-initiated close ends the stream: notify the listener exactly
    // once. The isStop flag guards against firing onStop a second time if
    // another callback path already signalled completion.
    if (!isStop) {
        this.isStop = true;
        this.listener.onStop(this);
    }
}
@Override

View File

@ -0,0 +1,47 @@
/*
* Copyright (c) 2022-2023, Agents-Flex (fuhai999@gmail.com).
* <p>
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* <p>
* http://www.apache.org/licenses/LICENSE-2.0
* <p>
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.agentsflex.llm.response;
import com.agentsflex.functions.Function;
import com.agentsflex.llm.ChatResponse;
import com.agentsflex.message.FunctionMessage;
import java.util.List;
/**
 * A {@code ChatResponse} produced by a function-calling chat. Holds the
 * {@code FunctionMessage} describing the function the model selected and the
 * candidate {@code Function}s it may be matched against.
 */
public class FunctionResultResponse implements ChatResponse<FunctionMessage> {

    // Candidate functions registered for this prompt.
    private final List<Function<?>> functions;
    // The model's chosen function name plus its arguments.
    private final FunctionMessage functionMessage;

    public FunctionResultResponse(List<Function<?>> functions, FunctionMessage functionMessage) {
        this.functions = functions;
        this.functionMessage = functionMessage;
    }

    /**
     * Invokes the function named by the model, passing the model-supplied
     * arguments.
     *
     * @return the function's result, or {@code null} when no registered
     *         function matches the requested name
     */
    public Object invoke() {
        String targetName = functionMessage.getFunctionName();
        for (Function<?> candidate : functions) {
            if (candidate.getName().equals(targetName)) {
                return candidate.invoke(functionMessage.getArgs());
            }
        }
        return null;
    }

    @Override
    public FunctionMessage getMessage() {
        return functionMessage;
    }
}

View File

@ -0,0 +1,32 @@
/*
* Copyright (c) 2022-2023, Agents-Flex (fuhai999@gmail.com).
* <p>
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* <p>
* http://www.apache.org/licenses/LICENSE-2.0
* <p>
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.agentsflex.llm.response;
import com.agentsflex.llm.ChatResponse;
import com.agentsflex.message.AiMessage;
/**
 * The simplest {@code ChatResponse}: wraps the single {@code AiMessage}
 * returned by a plain (non function-calling) chat.
 */
public class MessageResponse implements ChatResponse<AiMessage> {

    // The model's reply; fixed at construction time.
    private final AiMessage message;

    public MessageResponse(AiMessage aiMessage) {
        this.message = aiMessage;
    }

    @Override
    public AiMessage getMessage() {
        return message;
    }
}

View File

@ -17,10 +17,17 @@ package com.agentsflex.memory;
import com.agentsflex.message.Message;
import java.util.Collection;
import java.util.List;
/**
 * A {@code Memory} specialised for chat {@code Message}s: stores messages and
 * exposes them for prompt construction.
 */
public interface MessageMemory extends Memory {

    /** Returns all stored messages, in the order they were added. */
    List<Message> getMessages();

    /** Appends a single message to this memory. */
    void addMessage(Message message);

    /**
     * Appends every message in the collection, preserving its iteration
     * order. Delegates each element to {@link #addMessage(Message)}.
     */
    default void addMessages(Collection<Message> messages) {
        messages.forEach(this::addMessage);
    }
}

View File

@ -21,7 +21,6 @@ public class AiMessage extends Message {
private Integer index;
private MessageStatus status;
private int totalTokens;
private String fullContent;

View File

@ -15,5 +15,26 @@
*/
package com.agentsflex.message;
import java.util.Map;
/**
 * A {@code Message} emitted when the model decides to call a function:
 * carries the target function's name and the call arguments.
 */
public class FunctionMessage extends Message {

    // Name of the function the model wants invoked.
    private String functionName;
    // Arguments the model supplied, keyed by parameter name.
    private Map<String, Object> args;

    public Map<String, Object> getArgs() {
        return args;
    }

    public void setArgs(Map<String, Object> args) {
        this.args = args;
    }

    public String getFunctionName() {
        return functionName;
    }

    public void setFunctionName(String functionName) {
        this.functionName = functionName;
    }
}

View File

@ -0,0 +1,53 @@
/*
* Copyright (c) 2022-2023, Agents-Flex (fuhai999@gmail.com).
* <p>
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* <p>
* http://www.apache.org/licenses/LICENSE-2.0
* <p>
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.agentsflex.prompt;
import com.agentsflex.functions.Function;
import com.agentsflex.functions.Functions;
import com.agentsflex.llm.response.FunctionResultResponse;
import com.agentsflex.memory.DefaultMessageMemory;
import com.agentsflex.memory.MessageMemory;
import com.agentsflex.message.HumanMessage;
import com.agentsflex.message.Message;
import java.util.ArrayList;
import java.util.List;
/**
 * A {@code Prompt} that asks the model to pick and call one of the functions
 * declared on a given class (via {@code Functions.from}). Resolves to a
 * {@code FunctionResultResponse} so the caller can invoke the chosen function.
 */
public class FunctionPrompt extends Prompt<FunctionResultResponse> {

    // Backing store for this prompt's conversation messages.
    private MessageMemory memory = new DefaultMessageMemory();
    // Candidate functions extracted from the supplied class.
    private List<Function<?>> functions = new ArrayList<>();

    public FunctionPrompt(String prompt, Class<?> funcClass) {
        memory.addMessage(new HumanMessage(prompt));
        functions.addAll(Functions.from(funcClass));
    }

    public FunctionPrompt(List<Message> messages, Class<?> funcClass) {
        memory.addMessages(messages);
        functions.addAll(Functions.from(funcClass));
    }

    public List<Function<?>> getFunctions() {
        return functions;
    }

    @Override
    public List<Message> getMessages() {
        return memory.getMessages();
    }
}

View File

@ -15,13 +15,14 @@
*/
package com.agentsflex.prompt;
import com.agentsflex.llm.response.MessageResponse;
import com.agentsflex.memory.DefaultMessageMemory;
import com.agentsflex.memory.MessageMemory;
import com.agentsflex.message.Message;
import java.util.List;
public class HistoriesPrompt extends Prompt{
public class HistoriesPrompt extends Prompt<MessageResponse>{
private MessageMemory memory = new DefaultMessageMemory();

View File

@ -21,7 +21,7 @@ import com.agentsflex.util.Metadata;
import java.util.List;
public abstract class Prompt extends Metadata {
public abstract class Prompt<ChatResponse> extends Metadata {
public abstract List<Message> getMessages();

View File

@ -15,13 +15,14 @@
*/
package com.agentsflex.prompt;
import com.agentsflex.llm.response.MessageResponse;
import com.agentsflex.message.HumanMessage;
import com.agentsflex.message.Message;
import java.util.Collections;
import java.util.List;
public class SimplePrompt extends Prompt{
public class SimplePrompt extends Prompt<MessageResponse> {
private final String content;

View File

@ -15,17 +15,22 @@
*/
package com.agentsflex.llm.openai;
import com.agentsflex.llm.client.BaseLlmClientListener;
import com.agentsflex.llm.client.LlmClient;
import com.agentsflex.llm.client.LlmClientListener;
import com.agentsflex.llm.client.impl.SseClient;
import com.agentsflex.document.Document;
import com.agentsflex.functions.Function;
import com.agentsflex.llm.BaseLlm;
import com.agentsflex.llm.ChatListener;
import com.agentsflex.llm.FunctionCalling;
import com.agentsflex.llm.ChatResponse;
import com.agentsflex.llm.client.BaseLlmClientListener;
import com.agentsflex.llm.client.HttpClient;
import com.agentsflex.llm.client.LlmClient;
import com.agentsflex.llm.client.LlmClientListener;
import com.agentsflex.llm.client.impl.SseClient;
import com.agentsflex.llm.response.FunctionResultResponse;
import com.agentsflex.llm.response.MessageResponse;
import com.agentsflex.message.AiMessage;
import com.agentsflex.message.FunctionMessage;
import com.agentsflex.prompt.FunctionPrompt;
import com.agentsflex.prompt.Prompt;
import com.agentsflex.document.Document;
import com.agentsflex.util.OKHttpUtil;
import com.agentsflex.util.StringUtil;
import com.agentsflex.vector.VectorData;
import com.alibaba.fastjson.JSON;
@ -36,17 +41,46 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
public class OpenAiLlm extends BaseLlm<OpenAiLlmConfig> implements FunctionCalling {
public class OpenAiLlm extends BaseLlm<OpenAiLlmConfig> {
private final OKHttpUtil httpUtil = new OKHttpUtil();
private final HttpClient httpClient = new HttpClient();
public OpenAiLlm(OpenAiLlmConfig config) {
super(config);
}
@Override
public <T extends ChatResponse<?>> T chat(Prompt<T> prompt){
Map<String, String> headers = new HashMap<>();
headers.put("Content-Type", "application/json");
headers.put("Authorization", "Bearer " + getConfig().getApiKey());
String payload = OpenAiLLmUtil.promptToPayload(prompt, config);
String responseString = httpClient.post("https://api.openai.com/v1/chat/completions", headers, payload);
if (StringUtil.noText(responseString)) {
return null;
}
if (prompt instanceof FunctionPrompt) {
List<Function<?>> functions = ((FunctionPrompt) prompt).getFunctions();
JSONObject jsonObject = JSON.parseObject(responseString);
String callFunctionName = (String) JSONPath.eval(jsonObject, "$.choices[0].tool_calls[0].function.name");
String callFunctionArgsString = (String) JSONPath.eval(jsonObject, "$.choices[0].tool_calls[0].function.arguments");
JSONObject callFunctionArgs = JSON.parseObject(callFunctionArgsString);
FunctionMessage functionMessage = new FunctionMessage();
functionMessage.setFunctionName(callFunctionName);
functionMessage.setArgs(callFunctionArgs);
return (T) new FunctionResultResponse(functions, functionMessage);
} else {
AiMessage aiMessage = OpenAiLLmUtil.parseAiMessage(responseString);
return (T) new MessageResponse(aiMessage);
}
}
@Override
public LlmClient chat(Prompt prompt, ChatListener listener) {
public void chatAsync(Prompt prompt, ChatListener listener) {
LlmClient llmClient = new SseClient();
Map<String, String> headers = new HashMap<>();
headers.put("Content-Type", "application/json");
@ -54,22 +88,21 @@ public class OpenAiLlm extends BaseLlm<OpenAiLlmConfig> implements FunctionCalli
String payload = OpenAiLLmUtil.promptToPayload(prompt, config);
LlmClientListener clientListener = new BaseLlmClientListener(this, listener, prompt, OpenAiLLmUtil::parseAiMessage);
LlmClientListener clientListener = new BaseLlmClientListener(this, llmClient, listener, prompt, OpenAiLLmUtil::parseAiMessage, null);
llmClient.start("https://api.openai.com/v1/chat/completions", headers, payload, clientListener);
return llmClient;
}
@Override
public VectorData embeddings(Document text) {
public VectorData embeddings(Document document) {
Map<String, String> headers = new HashMap<>();
headers.put("Content-Type", "application/json");
headers.put("Authorization", "Bearer " + getConfig().getApiKey());
String payload = OpenAiLLmUtil.promptToEmbeddingsPayload(text);
String payload = OpenAiLLmUtil.promptToEmbeddingsPayload(document);
// https://platform.openai.com/docs/api-reference/embeddings/create
String response = httpUtil.post("https://api.openai.com/v1/embeddings", headers, payload);
String response = httpClient.post("https://api.openai.com/v1/embeddings", headers, payload);
if (StringUtil.noText(response)) {
return null;
}
@ -81,30 +114,5 @@ public class OpenAiLlm extends BaseLlm<OpenAiLlmConfig> implements FunctionCalli
return vectorData;
}
@Override
public <R> R call(Prompt prompt, List<Function<R>> functions) {
Map<String, String> headers = new HashMap<>();
headers.put("Content-Type", "application/json");
headers.put("Authorization", "Bearer " + getConfig().getApiKey());
String payload = OpenAiLLmUtil.promptToFunctionCallingPayload(prompt, config, functions);
// https://platform.openai.com/docs/api-reference/embeddings/create
String response = httpUtil.post("https://api.openai.com/v1/embeddings", headers, payload);
if (StringUtil.noText(response)) {
return null;
}
JSONObject jsonObject = JSON.parseObject(response);
String callFunctionName = (String) JSONPath.eval(jsonObject, "$.choices[0].tool_calls[0].function.name");
String callFunctionArgsString = (String) JSONPath.eval(jsonObject, "$.choices[0].tool_calls[0].function.arguments");
JSONObject callFunctionArgs = JSON.parseObject(callFunctionArgsString);
for (Function<R> function : functions) {
if (function.getName().equals(callFunctionName)) {
return function.invoke(callFunctionArgs);
}
}
return null;
}
}

View File

@ -1,43 +1,36 @@
package com.agentsflex.llm.openai;
import com.agentsflex.functions.Functions;
import com.agentsflex.llm.Llm;
import com.agentsflex.prompt.SimplePrompt;
import com.agentsflex.llm.response.FunctionResultResponse;
import com.agentsflex.prompt.FunctionPrompt;
import org.junit.Test;
public class OpenAiLlmTest {
@Test
public void testChat() throws InterruptedException {
public void testChat() {
OpenAiLlmConfig config = new OpenAiLlmConfig();
config.setApiKey("sk-rts5NF6n*******");
Llm llm = new OpenAiLlm(config);
String response = llm.chat("请问你叫什么名字");
llm.chat(new SimplePrompt("请写一个小兔子战胜大灰狼的故事"), (instance, message) -> {
System.out.println("--->" + message.getContent());
});
Thread.sleep(10000);
System.out.println(response);
}
@Test
public void testFunctionCalling() throws InterruptedException {
public void testFunctionCalling() throws InterruptedException {
OpenAiLlmConfig config = new OpenAiLlmConfig();
config.setApiKey("sk-rts5NF6n*******");
OpenAiLlm llm = new OpenAiLlm(config);
Functions<String> functions = Functions.from(WeatherUtil.class, String.class);
String result = llm.call(new SimplePrompt("今天的天气怎么样"), functions);
FunctionPrompt prompt = new FunctionPrompt("今天北京的天气怎么样", WeatherUtil.class);
FunctionResultResponse response = llm.chat(prompt);
Object result = response.invoke();
System.out.println(result);
Thread.sleep(10000);
// "Today it will be dull and overcast in 北京"
}
}

View File

@ -9,6 +9,6 @@ public class WeatherUtil {
public static String getWeatherInfo(
@FunctionParam(name = "city", description = "the city name") String name
) {
return "Today it will be dull and overcast in city " + name;
return "Today it will be dull and overcast in " + name;
}
}

View File

@ -15,15 +15,20 @@
*/
package com.agentsflex.llm.qwen;
import com.agentsflex.llm.ChatResponse;
import com.agentsflex.llm.client.BaseLlmClientListener;
import com.agentsflex.llm.client.HttpClient;
import com.agentsflex.llm.client.LlmClient;
import com.agentsflex.llm.client.LlmClientListener;
import com.agentsflex.llm.client.impl.SseClient;
import com.agentsflex.llm.BaseLlm;
import com.agentsflex.llm.ChatListener;
import com.agentsflex.llm.response.MessageResponse;
import com.agentsflex.message.AiMessage;
import com.agentsflex.prompt.FunctionPrompt;
import com.agentsflex.prompt.Prompt;
import com.agentsflex.document.Document;
import com.agentsflex.util.StringUtil;
import com.agentsflex.vector.VectorData;
import java.util.HashMap;
@ -35,9 +40,34 @@ public class QwenLlm extends BaseLlm<QwenLlmConfig> {
super(config);
}
HttpClient httpClient = new HttpClient();
@Override
public LlmClient chat(Prompt prompt, ChatListener listener) {
public ChatResponse<?> chat(Prompt prompt) {
Map<String, String> headers = new HashMap<>();
headers.put("Content-Type", "application/json");
headers.put("Authorization", "Bearer " + getConfig().getApiKey());
String payload = QwenLlmUtil.promptToPayload(prompt, config);
String responseString = httpClient.post("https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation", headers, payload);
if (StringUtil.noText(responseString)){
return null;
}
if (prompt instanceof FunctionPrompt){
}else {
AiMessage aiMessage = QwenLlmUtil.parseAiMessage(responseString, 0);
return new MessageResponse(aiMessage);
}
return null;
}
@Override
public void chatAsync(Prompt prompt, ChatListener listener) {
LlmClient llmClient = new SseClient();
Map<String, String> headers = new HashMap<>();
headers.put("Content-Type", "application/json");
@ -45,24 +75,21 @@ public class QwenLlm extends BaseLlm<QwenLlmConfig> {
String payload = QwenLlmUtil.promptToPayload(prompt, config);
LlmClientListener clientListener = new BaseLlmClientListener(this, listener, prompt, new BaseLlmClientListener.MessageParser() {
LlmClientListener clientListener = new BaseLlmClientListener(this, llmClient,listener, prompt, new BaseLlmClientListener.AiMessageParser() {
int prevMessageLength = 0;
@Override
public AiMessage parseMessage(String response) {
AiMessage aiMessage = QwenLlmUtil.parseAiMessage(response, prevMessageLength);
prevMessageLength += aiMessage.getContent().length();
return aiMessage;
}
});
},null);
llmClient.start("https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation", headers, payload, clientListener);
return llmClient;
}
@Override
public VectorData embeddings(Document text) {
public VectorData embeddings(Document document) {
return null;
}
}

View File

@ -13,7 +13,7 @@ public class QwenTest {
config.setModel("qwen-turbo");
Llm llm = new QwenLlm(config);
llm.chat(new SimplePrompt("请写一个小兔子战胜大灰狼的故事"), (llm1, aiMessage) -> {
llm.chatAsync(new SimplePrompt("请写一个小兔子战胜大灰狼的故事"), (llm1, aiMessage) -> {
System.out.println(">>>>" + aiMessage.getContent());
});

View File

@ -15,16 +15,22 @@
*/
package com.agentsflex.llm.spark;
import com.agentsflex.document.Document;
import com.agentsflex.llm.BaseLlm;
import com.agentsflex.llm.ChatContext;
import com.agentsflex.llm.ChatListener;
import com.agentsflex.llm.ChatResponse;
import com.agentsflex.llm.client.BaseLlmClientListener;
import com.agentsflex.llm.client.LlmClient;
import com.agentsflex.llm.client.LlmClientListener;
import com.agentsflex.llm.client.impl.WebSocketClient;
import com.agentsflex.llm.BaseLlm;
import com.agentsflex.llm.ChatListener;
import com.agentsflex.llm.response.MessageResponse;
import com.agentsflex.message.AiMessage;
import com.agentsflex.prompt.Prompt;
import com.agentsflex.document.Document;
import com.agentsflex.vector.VectorData;
import java.util.concurrent.CountDownLatch;
public class SparkLlm extends BaseLlm<SparkLlmConfig> {
public SparkLlm(SparkLlmConfig config) {
@ -32,21 +38,44 @@ public class SparkLlm extends BaseLlm<SparkLlmConfig> {
}
@Override
public LlmClient chat(Prompt prompt, ChatListener listener) {
public VectorData embeddings(Document document) {
return null;
}
@Override
public MessageResponse chat(Prompt prompt) {
    // Spark only exposes a streaming API, so synchronous chat is emulated:
    // run chatAsync and block until the stream finishes.
    CountDownLatch latch = new CountDownLatch(1);
    AiMessage aiMessage = new AiMessage();
    chatAsync(prompt, new ChatListener() {
        @Override
        public void onMessage(ChatContext context, ChatResponse<?> response) {
            // Overwrite with the accumulated text on every chunk; the last
            // write before onStop holds the complete reply.
            aiMessage.setContent(((AiMessage) response.getMessage()).getFullContent());
        }

        @Override
        public void onStop(ChatContext context) {
            ChatListener.super.onStop(context);
            latch.countDown();
        }

        @Override
        public void onFailure(ChatContext context, Throwable throwable) {
            // Release the waiting caller on transport failure; without this
            // the latch would never reach zero and chat() would hang forever.
            latch.countDown();
        }
    });
    try {
        latch.await();
    } catch (InterruptedException e) {
        // Restore the interrupt flag so callers can observe the interruption.
        Thread.currentThread().interrupt();
        throw new RuntimeException(e);
    }
    return new MessageResponse(aiMessage);
}
@Override
public void chatAsync(Prompt prompt, ChatListener listener) {
LlmClient llmClient = new WebSocketClient();
String url = SparkLlmUtil.createURL(config);
String payload = SparkLlmUtil.promptToPayload(prompt, config);
LlmClientListener clientListener = new BaseLlmClientListener(this, listener, prompt, SparkLlmUtil::parseAiMessage);
LlmClientListener clientListener = new BaseLlmClientListener(this, llmClient, listener, prompt, SparkLlmUtil::parseAiMessage, null);
llmClient.start(url, null, payload, clientListener);
return llmClient;
}
@Override
public VectorData embeddings(Document text) {
return null;
}
}

View File

@ -5,14 +5,26 @@ import com.agentsflex.llm.spark.SparkLlm;
import com.agentsflex.llm.spark.SparkLlmConfig;
import com.agentsflex.message.HumanMessage;
import com.agentsflex.prompt.HistoriesPrompt;
import org.junit.Test;
import java.util.Scanner;
public class SparkLlmTest {
public static void main(String[] args) throws InterruptedException {
@Test
public void testSimple() {
SparkLlmConfig config = new SparkLlmConfig();
config.setAppId("****");
config.setApiKey("****");
config.setApiSecret("****");
Llm llm = new SparkLlm(config);
String result = llm.chat("你好");
System.out.println(result);
}
public static void main(String[] args) {
SparkLlmConfig config = new SparkLlmConfig();
config.setAppId("****");
config.setApiKey("****");
@ -26,18 +38,15 @@ public class SparkLlmTest {
Scanner scanner = new Scanner(System.in);
String userInput = scanner.nextLine();
while (userInput != null){
while (userInput != null) {
prompt.addMessage(new HumanMessage(userInput));
llm.chat(prompt, (instance, message) -> {
System.out.println(">>>> " + message.getContent());
llm.chatAsync(prompt, (context, response) -> {
System.out.println(">>>> " + response.getMessage().getContent());
});
userInput = scanner.nextLine();
}
}
}

View File

@ -15,7 +15,7 @@
*/
package com.agentsflex.vector.aliyun;
import com.agentsflex.util.OKHttpUtil;
import com.agentsflex.llm.client.HttpClient;
import com.agentsflex.util.StringUtil;
import com.agentsflex.vector.RetrieveWrapper;
import com.agentsflex.vector.VectorDocument;
@ -33,7 +33,7 @@ public class AliyunVectorStorage extends VectorStorage<VectorDocument> {
private AliyunVectorStorageConfig config;
private final OKHttpUtil httpUtil = new OKHttpUtil();
private final HttpClient httpUtil = new HttpClient();
public AliyunVectorStorage(AliyunVectorStorageConfig config) {
this.config = config;

View File

@ -15,7 +15,7 @@
*/
package com.agentsflex.vector.qcloud;
import com.agentsflex.util.OKHttpUtil;
import com.agentsflex.llm.client.HttpClient;
import com.agentsflex.util.StringUtil;
import com.agentsflex.vector.RetrieveWrapper;
import com.agentsflex.vector.VectorDocument;
@ -33,7 +33,7 @@ public class QCloudVectorStorage extends VectorStorage<VectorDocument> {
private QCloudVectorStorageConfig config;
private final OKHttpUtil httpUtil = new OKHttpUtil();
private final HttpClient httpUtil = new HttpClient();
public QCloudVectorStorage(QCloudVectorStorageConfig config) {