refactor: optimize llm methods

Michael Yang 2024-06-14 19:37:35 +08:00
parent 91d39001ed
commit ef3cfec7c5
12 changed files with 40 additions and 53 deletions

View File

@@ -33,25 +33,25 @@ public interface Llm extends EmbeddingModel {
         return chat != null && chat.getMessage() != null ? chat.getMessage().getContent() : null;
     }
-    default <R extends MessageResponse<M>, M extends AiMessage> R chat(Prompt<M> prompt) {
+    default <R extends MessageResponse<?>> R chat(Prompt<R> prompt) {
         return chat(prompt, ChatOptions.DEFAULT);
     }
-    <R extends MessageResponse<M>, M extends AiMessage> R chat(Prompt<M> prompt, ChatOptions options);
+    <R extends MessageResponse<?>> R chat(Prompt<R> prompt, ChatOptions options);
-    default void chatStream(String prompt, StreamResponseListener<AiMessageResponse, AiMessage> listener) {
+    default void chatStream(String prompt, StreamResponseListener<AiMessageResponse> listener) {
         this.chatStream(new TextPrompt(prompt), listener, ChatOptions.DEFAULT);
     }
-    default void chatStream(String prompt, StreamResponseListener<AiMessageResponse, AiMessage> listener, ChatOptions options) {
+    default void chatStream(String prompt, StreamResponseListener<AiMessageResponse> listener, ChatOptions options) {
         this.chatStream(new TextPrompt(prompt), listener, options);
     }
     //chatStream
-    default <R extends MessageResponse<M>, M extends AiMessage> void chatStream(Prompt<M> prompt, StreamResponseListener<R, M> listener) {
+    default <R extends MessageResponse<?>> void chatStream(Prompt<R> prompt, StreamResponseListener<R> listener) {
        this.chatStream(prompt, listener, ChatOptions.DEFAULT);
     }
-    <R extends MessageResponse<M>, M extends AiMessage> void chatStream(Prompt<M> prompt, StreamResponseListener<R, M> listener, ChatOptions options);
+    <R extends MessageResponse<?>> void chatStream(Prompt<R> prompt, StreamResponseListener<R> listener, ChatOptions options);
 }
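Note: after this change the response type is inferred from the prompt alone. A minimal call-site sketch against the new interface; the com.agentsflex.llm.openai package paths are an assumption (the constructor and config usage are taken from the test change later in this commit):

import com.agentsflex.llm.Llm;
import com.agentsflex.llm.openai.OpenAiLlm;       // assumed package path
import com.agentsflex.llm.openai.OpenAiLlmConfig; // assumed package path
import com.agentsflex.llm.response.AiMessageResponse;
import com.agentsflex.prompt.TextPrompt;

public class ChatExample {
    public static void main(String[] args) {
        Llm llm = new OpenAiLlm(new OpenAiLlmConfig());
        // TextPrompt now extends Prompt<AiMessageResponse>, so R is inferred
        // as AiMessageResponse; no second message type parameter is needed.
        AiMessageResponse response = llm.chat(new TextPrompt("Hello"));
        System.out.println(response.getMessage().getContent());
    }
}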

View File

@@ -15,9 +15,7 @@
  */
 package com.agentsflex.llm;
-import com.agentsflex.message.AiMessage;
-public interface StreamResponseListener<R extends MessageResponse<M>, M extends AiMessage> {
+public interface StreamResponseListener<R extends MessageResponse<?>> {
     default void onStart(ChatContext context) {
     }
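Note: listeners are now declared with a single response type argument. A small sketch; it compiles as a lambda only if onMessage is the interface's sole abstract method, which the default onStart above suggests but this diff does not fully show:

import com.agentsflex.llm.Llm;

public class StreamExample {
    static void stream(Llm llm) {
        // Target type is StreamResponseListener<AiMessageResponse>, so the
        // lambda parameters are inferred as (ChatContext, AiMessageResponse).
        llm.chatStream("Hello", (context, response) ->
                System.out.print(response.getMessage().getContent()));
    }
}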

View File

@@ -17,9 +17,9 @@ package com.agentsflex.prompt;
 import com.agentsflex.functions.Function;
 import com.agentsflex.functions.Functions;
+import com.agentsflex.llm.response.FunctionMessageResponse;
 import com.agentsflex.memory.ChatMemory;
 import com.agentsflex.memory.DefaultChatMemory;
 import com.agentsflex.message.FunctionMessage;
 import com.agentsflex.message.HumanMessage;
 import com.agentsflex.message.Message;
@@ -27,7 +27,7 @@ import java.util.ArrayList;
 import java.util.Collection;
 import java.util.List;
-public class FunctionPrompt extends Prompt<FunctionMessage> {
+public class FunctionPrompt extends Prompt<FunctionMessageResponse> {
     private final ChatMemory memory = new DefaultChatMemory();
     private final List<Function> functions = new ArrayList<>();

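Note: because FunctionPrompt now extends Prompt<FunctionMessageResponse>, a function-calling chat is typed at the call site with no cast. A hedged sketch; the (String, Class) constructor, the WeatherFunctions holder, and the getFunctionResult() accessor are assumptions for illustration, not shown in this diff:

import com.agentsflex.llm.Llm;
import com.agentsflex.llm.response.FunctionMessageResponse;
import com.agentsflex.prompt.FunctionPrompt;

public class FunctionCallExample {
    // Hypothetical holder class whose methods would be exposed as functions.
    static class WeatherFunctions { }

    static void call(Llm llm) {
        FunctionPrompt prompt =
                new FunctionPrompt("What is the weather in Beijing?", WeatherFunctions.class); // assumed constructor
        // R is inferred as FunctionMessageResponse from the prompt type.
        FunctionMessageResponse response = llm.chat(prompt);
        Object result = response.getFunctionResult(); // assumed accessor
        System.out.println(result);
    }
}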
View File

@@ -15,14 +15,14 @@
  */
 package com.agentsflex.prompt;
+import com.agentsflex.llm.response.AiMessageResponse;
 import com.agentsflex.memory.ChatMemory;
 import com.agentsflex.memory.DefaultChatMemory;
-import com.agentsflex.message.AiMessage;
 import com.agentsflex.message.Message;
 import java.util.List;
-public class HistoriesPrompt extends Prompt<AiMessage> {
+public class HistoriesPrompt extends Prompt<AiMessageResponse> {
     private ChatMemory memory = new DefaultChatMemory();

View File

@@ -15,13 +15,14 @@
  */
 package com.agentsflex.prompt;
+import com.agentsflex.llm.MessageResponse;
 import com.agentsflex.message.Message;
 import com.agentsflex.util.Metadata;
 import java.util.List;
-public abstract class Prompt<M extends Message> extends Metadata {
+public abstract class Prompt<M extends MessageResponse<?>> extends Metadata {
     public abstract List<Message> toMessages();

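Note: the base class now fixes the response type rather than the message type, so each Prompt subclass declares what chat returns for it. A minimal custom prompt under the new signature; the HumanMessage(String) constructor is an assumption inferred from its use elsewhere in the codebase:

import com.agentsflex.llm.response.AiMessageResponse;
import com.agentsflex.message.HumanMessage;
import com.agentsflex.message.Message;
import com.agentsflex.prompt.Prompt;
import java.util.Collections;
import java.util.List;

public class GreetingPrompt extends Prompt<AiMessageResponse> {
    @Override
    public List<Message> toMessages() {
        // Binding AiMessageResponse here means llm.chat(new GreetingPrompt())
        // is statically typed to return AiMessageResponse, with no casts.
        return Collections.singletonList(new HumanMessage("Say hello")); // assumed constructor
    }
}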
View File

@@ -15,14 +15,14 @@
  */
 package com.agentsflex.prompt;
-import com.agentsflex.message.AiMessage;
+import com.agentsflex.llm.response.AiMessageResponse;
 import com.agentsflex.message.HumanMessage;
 import com.agentsflex.message.Message;
 import java.util.Collections;
 import java.util.List;
-public class TextPrompt extends Prompt<AiMessage> {
+public class TextPrompt extends Prompt<AiMessageResponse> {
     protected String content;

View File

@@ -29,7 +29,6 @@ import com.agentsflex.llm.embedding.EmbeddingOptions;
 import com.agentsflex.llm.response.AbstractBaseMessageResponse;
 import com.agentsflex.llm.response.AiMessageResponse;
 import com.agentsflex.llm.response.FunctionMessageResponse;
-import com.agentsflex.message.AiMessage;
 import com.agentsflex.parser.AiMessageParser;
 import com.agentsflex.parser.FunctionMessageParser;
 import com.agentsflex.prompt.FunctionPrompt;
@@ -87,7 +86,7 @@ public class ChatglmLlm extends BaseLlm<ChatglmLlmConfig> {
     @Override
-    public <R extends MessageResponse<M>, M extends AiMessage> R chat(Prompt<M> prompt, ChatOptions options) {
+    public <R extends MessageResponse<?>> R chat(Prompt<R> prompt, ChatOptions options) {
         Map<String, String> headers = new HashMap<>();
         headers.put("Content-Type", "application/json");
         headers.put("Authorization", ChatglmLlmUtil.createAuthorizationToken(config));
@@ -107,15 +106,13 @@ public class ChatglmLlm extends BaseLlm<ChatglmLlmConfig> {
         JSONObject jsonObject = JSON.parseObject(response);
         JSONObject error = jsonObject.getJSONObject("error");
-        AbstractBaseMessageResponse<M> messageResponse;
+        AbstractBaseMessageResponse<?> messageResponse;
         if (prompt instanceof FunctionPrompt) {
-            //noinspection unchecked
-            messageResponse = (AbstractBaseMessageResponse<M>) new FunctionMessageResponse(((FunctionPrompt) prompt).getFunctions()
+            messageResponse = new FunctionMessageResponse(((FunctionPrompt) prompt).getFunctions()
                 , functionMessageParser.parse(jsonObject));
         } else {
-            //noinspection unchecked
-            messageResponse = (AbstractBaseMessageResponse<M>) new AiMessageResponse(aiMessageParser.parse(jsonObject));
+            messageResponse = new AiMessageResponse(aiMessageParser.parse(jsonObject));
         }
         if (error != null && !error.isEmpty()) {
@@ -131,7 +128,7 @@ public class ChatglmLlm extends BaseLlm<ChatglmLlmConfig> {
     @Override
-    public <R extends MessageResponse<M>, M extends AiMessage> void chatStream(Prompt<M> prompt, StreamResponseListener<R, M> listener, ChatOptions options) {
+    public <R extends MessageResponse<?>> void chatStream(Prompt<R> prompt, StreamResponseListener<R> listener, ChatOptions options) {
         LlmClient llmClient = new SseClient();
         Map<String, String> headers = new HashMap<>();
         headers.put("Content-Type", "application/json");

View File

@@ -28,7 +28,6 @@ import com.agentsflex.llm.client.impl.SseClient;
 import com.agentsflex.llm.embedding.EmbeddingOptions;
 import com.agentsflex.llm.response.AbstractBaseMessageResponse;
 import com.agentsflex.llm.response.AiMessageResponse;
-import com.agentsflex.message.AiMessage;
 import com.agentsflex.parser.AiMessageParser;
 import com.agentsflex.prompt.FunctionPrompt;
 import com.agentsflex.prompt.Prompt;
@@ -58,7 +57,7 @@ public class LlamaLlm extends BaseLlm<LlamaLlmConfig> {
     @Override
-    public <R extends MessageResponse<M>, M extends AiMessage> R chat(Prompt<M> prompt, ChatOptions options) {
+    public <R extends MessageResponse<?>> R chat(Prompt<R> prompt, ChatOptions options) {
         Map<String, String> headers = new HashMap<>();
         headers.put("Content-Type", "application/json");
         headers.put("Authorization", "Bearer " + config.getApiKey());
@@ -77,13 +76,12 @@ public class LlamaLlm extends BaseLlm<LlamaLlmConfig> {
         JSONObject jsonObject = JSON.parseObject(response);
         JSONObject error = jsonObject.getJSONObject("error");
-        AbstractBaseMessageResponse<M> messageResponse;
+        AbstractBaseMessageResponse<?> messageResponse;
         if (prompt instanceof FunctionPrompt) {
             throw new IllegalStateException("Llama not support function calling");
         } else {
-            //noinspection unchecked
-            messageResponse = (AbstractBaseMessageResponse<M>) new AiMessageResponse(aiMessageParser.parse(jsonObject));
+            messageResponse = new AiMessageResponse(aiMessageParser.parse(jsonObject));
         }
         if (error != null && !error.isEmpty()) {
@@ -99,7 +97,7 @@ public class LlamaLlm extends BaseLlm<LlamaLlmConfig> {
     @Override
-    public <R extends MessageResponse<M>, M extends AiMessage> void chatStream(Prompt<M> prompt, StreamResponseListener<R, M> listener, ChatOptions options) {
+    public <R extends MessageResponse<?>> void chatStream(Prompt<R> prompt, StreamResponseListener<R> listener, ChatOptions options) {
         LlmClient llmClient = new SseClient();
         Map<String, String> headers = new HashMap<>();
         headers.put("Content-Type", "application/json");

View File

@@ -29,7 +29,6 @@ import com.agentsflex.llm.embedding.EmbeddingOptions;
 import com.agentsflex.llm.response.AbstractBaseMessageResponse;
 import com.agentsflex.llm.response.AiMessageResponse;
 import com.agentsflex.llm.response.FunctionMessageResponse;
-import com.agentsflex.message.AiMessage;
 import com.agentsflex.parser.AiMessageParser;
 import com.agentsflex.parser.FunctionMessageParser;
 import com.agentsflex.prompt.FunctionPrompt;
@@ -68,7 +67,7 @@ public class OpenAiLlm extends BaseLlm<OpenAiLlmConfig> {
     }
     @Override
-    public <R extends MessageResponse<M>, M extends AiMessage> R chat(Prompt<M> prompt, ChatOptions options) {
+    public <R extends MessageResponse<?>> R chat(Prompt<R> prompt, ChatOptions options) {
         Map<String, String> headers = new HashMap<>();
         headers.put("Content-Type", "application/json");
         headers.put("Authorization", "Bearer " + getConfig().getApiKey());
@@ -87,15 +86,13 @@ public class OpenAiLlm extends BaseLlm<OpenAiLlmConfig> {
         JSONObject jsonObject = JSON.parseObject(response);
         JSONObject error = jsonObject.getJSONObject("error");
-        AbstractBaseMessageResponse<M> messageResponse;
+        AbstractBaseMessageResponse<?> messageResponse;
         if (prompt instanceof FunctionPrompt) {
-            //noinspection unchecked
-            messageResponse = (AbstractBaseMessageResponse<M>) new FunctionMessageResponse(((FunctionPrompt) prompt).getFunctions()
+            messageResponse = new FunctionMessageResponse(((FunctionPrompt) prompt).getFunctions()
                 , functionMessageParser.parse(jsonObject));
         } else {
-            //noinspection unchecked
-            messageResponse = (AbstractBaseMessageResponse<M>) new AiMessageResponse(aiMessageParser.parse(jsonObject));
+            messageResponse = new AiMessageResponse(aiMessageParser.parse(jsonObject));
         }
         if (error != null && !error.isEmpty()) {
@@ -111,7 +108,7 @@ public class OpenAiLlm extends BaseLlm<OpenAiLlmConfig> {
     @Override
-    public <R extends MessageResponse<M>, M extends AiMessage> void chatStream(Prompt<M> prompt, StreamResponseListener<R, M> listener, ChatOptions options) {
+    public <R extends MessageResponse<?>> void chatStream(Prompt<R> prompt, StreamResponseListener<R> listener, ChatOptions options) {
         LlmClient llmClient = new SseClient();
         Map<String, String> headers = new HashMap<>();
         headers.put("Content-Type", "application/json");

View File

@@ -5,7 +5,6 @@ import com.agentsflex.llm.Llm;
 import com.agentsflex.llm.StreamResponseListener;
 import com.agentsflex.llm.response.AiMessageResponse;
 import com.agentsflex.llm.response.FunctionMessageResponse;
-import com.agentsflex.message.AiMessage;
 import com.agentsflex.prompt.FunctionPrompt;
 import com.agentsflex.prompt.ImagePrompt;
 import org.junit.Test;
@@ -33,7 +32,7 @@ public class OpenAiLlmTest {
         Llm llm = new OpenAiLlm(config);
         // String response = llm.chat("请问你叫什么名字");
-        llm.chatStream("你叫什么名字", new StreamResponseListener<AiMessageResponse, AiMessage>() {
+        llm.chatStream("你叫什么名字", new StreamResponseListener<AiMessageResponse>() {
             @Override
             public void onMessage(ChatContext context, AiMessageResponse response) {
                 System.out.println(response.getMessage().getContent());
@@ -48,8 +47,6 @@ public class OpenAiLlmTest {
     }
     @Test
     public void testChatWithImage() {
         OpenAiLlmConfig config = new OpenAiLlmConfig();

View File

@@ -57,7 +57,7 @@ public class QwenLlm extends BaseLlm<QwenLlmConfig> {
     @Override
-    public <R extends MessageResponse<M>, M extends AiMessage> R chat(Prompt<M> prompt, ChatOptions options) {
+    public <R extends MessageResponse<?>> R chat(Prompt<R> prompt, ChatOptions options) {
         Map<String, String> headers = new HashMap<>();
         headers.put("Content-Type", "application/json");
         headers.put("Authorization", "Bearer " + getConfig().getApiKey());
@@ -77,15 +77,13 @@ public class QwenLlm extends BaseLlm<QwenLlmConfig> {
         JSONObject jsonObject = JSON.parseObject(response);
         JSONObject error = jsonObject.getJSONObject("error");
-        AbstractBaseMessageResponse<M> messageResponse;
+        AbstractBaseMessageResponse<?> messageResponse;
         if (prompt instanceof FunctionPrompt) {
-            //noinspection unchecked
-            messageResponse = (AbstractBaseMessageResponse<M>) new FunctionMessageResponse(((FunctionPrompt) prompt).getFunctions()
+            messageResponse = new FunctionMessageResponse(((FunctionPrompt) prompt).getFunctions()
                 , functionMessageParser.parse(jsonObject));
         } else {
-            //noinspection unchecked
-            messageResponse = (AbstractBaseMessageResponse<M>) new AiMessageResponse(aiMessageParser.parse(jsonObject));
+            messageResponse = new AiMessageResponse(aiMessageParser.parse(jsonObject));
         }
         if (error != null && !error.isEmpty()) {
@@ -101,7 +99,7 @@ public class QwenLlm extends BaseLlm<QwenLlmConfig> {
     @Override
-    public <R extends MessageResponse<M>, M extends AiMessage> void chatStream(Prompt<M> prompt, StreamResponseListener<R, M> listener, ChatOptions options) {
+    public <R extends MessageResponse<?>> void chatStream(Prompt<R> prompt, StreamResponseListener<R> listener, ChatOptions options) {
         LlmClient llmClient = new SseClient();
         Map<String, String> headers = new HashMap<>();
         headers.put("Content-Type", "application/json");
@@ -112,6 +110,7 @@ public class QwenLlm extends BaseLlm<QwenLlmConfig> {
         LlmClientListener clientListener = new BaseLlmClientListener(this, llmClient, listener, prompt, new DefaultAiMessageParser() {
+            int prevMessageLength = 0;
             @Override
             public AiMessage parse(JSONObject content) {
                 AiMessage aiMessage = aiMessageParser.parse(content);

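Note: the prevMessageLength counter added above suggests the Qwen stream reports the accumulated text on each event, and the parser slices out only the newly appended part. A standalone, hypothetical sketch of that slicing technique; the rest of the parse body is not shown in this diff, and none of these names are from it:

public class DeltaDemo {
    int prevMessageLength = 0;

    // Given the full accumulated text so far, return only the new suffix.
    String delta(String accumulated) {
        String newPart = accumulated.substring(prevMessageLength);
        prevMessageLength = accumulated.length();
        return newPart;
    }

    public static void main(String[] args) {
        DeltaDemo d = new DeltaDemo();
        System.out.println(d.delta("Hel"));   // prints: Hel
        System.out.println(d.delta("Hello")); // prints: lo
    }
}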
View File

@@ -90,12 +90,12 @@ public class SparkLlm extends BaseLlm<SparkLlmConfig> {
     @SuppressWarnings("unchecked")
     @Override
-    public <R extends MessageResponse<M>, M extends AiMessage> R chat(Prompt<M> prompt, ChatOptions options) {
+    public <R extends MessageResponse<?>> R chat(Prompt<R> prompt, ChatOptions options) {
         CountDownLatch latch = new CountDownLatch(1);
         Message[] messages = new Message[1];
-        chatStream(prompt, new StreamResponseListener<MessageResponse<M>, M>() {
+        chatStream(prompt, new StreamResponseListener<R>() {
             @Override
-            public void onMessage(ChatContext context, MessageResponse<M> response) {
+            public void onMessage(ChatContext context, R response) {
                 if (response.getMessage() instanceof FunctionMessage) {
                     messages[0] = response.getMessage();
                 } else {
@@ -128,7 +128,7 @@ public class SparkLlm extends BaseLlm<SparkLlmConfig> {
     @Override
-    public <R extends MessageResponse<M>, M extends AiMessage> void chatStream(Prompt<M> prompt, StreamResponseListener<R, M> listener, ChatOptions options) {
+    public <R extends MessageResponse<?>> void chatStream(Prompt<R> prompt, StreamResponseListener<R> listener, ChatOptions options) {
         LlmClient llmClient = new WebSocketClient();
         String url = SparkLlmUtil.createURL(config);
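Note: SparkLlm is the one provider here that builds the blocking chat on top of chatStream, parking the caller on a CountDownLatch until the stream completes. A self-contained sketch of that synchronization pattern in plain Java (a stand-in streaming client, not the actual Spark WebSocket code):

import java.util.concurrent.CountDownLatch;
import java.util.function.Consumer;

public class BlockingOverStreaming {
    // Stand-in for an async streaming client: delivers chunks on another
    // thread, then signals completion.
    static void streamAsync(Consumer<String> onChunk, Runnable onStop) {
        new Thread(() -> {
            for (String part : new String[]{"Hel", "lo"}) onChunk.accept(part);
            onStop.run();
        }).start();
    }

    static String chatBlocking() throws InterruptedException {
        CountDownLatch latch = new CountDownLatch(1);
        StringBuffer content = new StringBuffer(); // thread-safe accumulator
        streamAsync(content::append, latch::countDown);
        latch.await(); // block the caller until the stream signals completion
        return content.toString();
    }

    public static void main(String[] args) throws InterruptedException {
        System.out.println(chatBlocking()); // prints: Hello
    }
}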