mirror of
https://gitee.com/agents-flex/agents-flex.git
synced 2024-11-29 18:38:17 +08:00
refactor: optimize llm methods
This commit is contained in:
parent
91d39001ed
commit
ef3cfec7c5
@ -33,25 +33,25 @@ public interface Llm extends EmbeddingModel {
|
||||
return chat != null && chat.getMessage() != null ? chat.getMessage().getContent() : null;
|
||||
}
|
||||
|
||||
default <R extends MessageResponse<M>, M extends AiMessage> R chat(Prompt<M> prompt) {
|
||||
default <R extends MessageResponse<?>> R chat(Prompt<R> prompt) {
|
||||
return chat(prompt, ChatOptions.DEFAULT);
|
||||
}
|
||||
|
||||
<R extends MessageResponse<M>, M extends AiMessage> R chat(Prompt<M> prompt, ChatOptions options);
|
||||
<R extends MessageResponse<?>> R chat(Prompt<R> prompt, ChatOptions options);
|
||||
|
||||
default void chatStream(String prompt, StreamResponseListener<AiMessageResponse, AiMessage> listener) {
|
||||
default void chatStream(String prompt, StreamResponseListener<AiMessageResponse> listener) {
|
||||
this.chatStream(new TextPrompt(prompt), listener, ChatOptions.DEFAULT);
|
||||
}
|
||||
|
||||
default void chatStream(String prompt, StreamResponseListener<AiMessageResponse, AiMessage> listener, ChatOptions options) {
|
||||
default void chatStream(String prompt, StreamResponseListener<AiMessageResponse> listener, ChatOptions options) {
|
||||
this.chatStream(new TextPrompt(prompt), listener, options);
|
||||
}
|
||||
|
||||
//chatStream
|
||||
default <R extends MessageResponse<M>, M extends AiMessage> void chatStream(Prompt<M> prompt, StreamResponseListener<R, M> listener) {
|
||||
default <R extends MessageResponse<?>> void chatStream(Prompt<R> prompt, StreamResponseListener<R> listener) {
|
||||
this.chatStream(prompt, listener, ChatOptions.DEFAULT);
|
||||
}
|
||||
|
||||
<R extends MessageResponse<M>, M extends AiMessage> void chatStream(Prompt<M> prompt, StreamResponseListener<R, M> listener, ChatOptions options);
|
||||
<R extends MessageResponse<?>> void chatStream(Prompt<R> prompt, StreamResponseListener<R> listener, ChatOptions options);
|
||||
|
||||
}
|
||||
|
@ -15,9 +15,7 @@
|
||||
*/
|
||||
package com.agentsflex.llm;
|
||||
|
||||
import com.agentsflex.message.AiMessage;
|
||||
|
||||
public interface StreamResponseListener<R extends MessageResponse<M>, M extends AiMessage> {
|
||||
public interface StreamResponseListener<R extends MessageResponse<?>> {
|
||||
|
||||
default void onStart(ChatContext context) {
|
||||
}
|
||||
|
@ -17,9 +17,9 @@ package com.agentsflex.prompt;
|
||||
|
||||
import com.agentsflex.functions.Function;
|
||||
import com.agentsflex.functions.Functions;
|
||||
import com.agentsflex.llm.response.FunctionMessageResponse;
|
||||
import com.agentsflex.memory.ChatMemory;
|
||||
import com.agentsflex.memory.DefaultChatMemory;
|
||||
import com.agentsflex.message.FunctionMessage;
|
||||
import com.agentsflex.message.HumanMessage;
|
||||
import com.agentsflex.message.Message;
|
||||
|
||||
@ -27,7 +27,7 @@ import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.List;
|
||||
|
||||
public class FunctionPrompt extends Prompt<FunctionMessage> {
|
||||
public class FunctionPrompt extends Prompt<FunctionMessageResponse> {
|
||||
private final ChatMemory memory = new DefaultChatMemory();
|
||||
private final List<Function> functions = new ArrayList<>();
|
||||
|
||||
|
@ -15,14 +15,14 @@
|
||||
*/
|
||||
package com.agentsflex.prompt;
|
||||
|
||||
import com.agentsflex.llm.response.AiMessageResponse;
|
||||
import com.agentsflex.memory.ChatMemory;
|
||||
import com.agentsflex.memory.DefaultChatMemory;
|
||||
import com.agentsflex.message.AiMessage;
|
||||
import com.agentsflex.message.Message;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public class HistoriesPrompt extends Prompt<AiMessage> {
|
||||
public class HistoriesPrompt extends Prompt<AiMessageResponse> {
|
||||
|
||||
private ChatMemory memory = new DefaultChatMemory();
|
||||
|
||||
|
@ -15,13 +15,14 @@
|
||||
*/
|
||||
package com.agentsflex.prompt;
|
||||
|
||||
import com.agentsflex.llm.MessageResponse;
|
||||
import com.agentsflex.message.Message;
|
||||
import com.agentsflex.util.Metadata;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
|
||||
public abstract class Prompt<M extends Message> extends Metadata {
|
||||
public abstract class Prompt<M extends MessageResponse<?>> extends Metadata {
|
||||
|
||||
public abstract List<Message> toMessages();
|
||||
|
||||
|
@ -15,14 +15,14 @@
|
||||
*/
|
||||
package com.agentsflex.prompt;
|
||||
|
||||
import com.agentsflex.message.AiMessage;
|
||||
import com.agentsflex.llm.response.AiMessageResponse;
|
||||
import com.agentsflex.message.HumanMessage;
|
||||
import com.agentsflex.message.Message;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
public class TextPrompt extends Prompt<AiMessage> {
|
||||
public class TextPrompt extends Prompt<AiMessageResponse> {
|
||||
|
||||
protected String content;
|
||||
|
||||
|
@ -29,7 +29,6 @@ import com.agentsflex.llm.embedding.EmbeddingOptions;
|
||||
import com.agentsflex.llm.response.AbstractBaseMessageResponse;
|
||||
import com.agentsflex.llm.response.AiMessageResponse;
|
||||
import com.agentsflex.llm.response.FunctionMessageResponse;
|
||||
import com.agentsflex.message.AiMessage;
|
||||
import com.agentsflex.parser.AiMessageParser;
|
||||
import com.agentsflex.parser.FunctionMessageParser;
|
||||
import com.agentsflex.prompt.FunctionPrompt;
|
||||
@ -87,7 +86,7 @@ public class ChatglmLlm extends BaseLlm<ChatglmLlmConfig> {
|
||||
|
||||
|
||||
@Override
|
||||
public <R extends MessageResponse<M>, M extends AiMessage> R chat(Prompt<M> prompt, ChatOptions options) {
|
||||
public <R extends MessageResponse<?>> R chat(Prompt<R> prompt, ChatOptions options) {
|
||||
Map<String, String> headers = new HashMap<>();
|
||||
headers.put("Content-Type", "application/json");
|
||||
headers.put("Authorization", ChatglmLlmUtil.createAuthorizationToken(config));
|
||||
@ -107,15 +106,13 @@ public class ChatglmLlm extends BaseLlm<ChatglmLlmConfig> {
|
||||
JSONObject jsonObject = JSON.parseObject(response);
|
||||
JSONObject error = jsonObject.getJSONObject("error");
|
||||
|
||||
AbstractBaseMessageResponse<M> messageResponse;
|
||||
AbstractBaseMessageResponse<?> messageResponse;
|
||||
|
||||
if (prompt instanceof FunctionPrompt) {
|
||||
//noinspection unchecked
|
||||
messageResponse = (AbstractBaseMessageResponse<M>) new FunctionMessageResponse(((FunctionPrompt) prompt).getFunctions()
|
||||
messageResponse = new FunctionMessageResponse(((FunctionPrompt) prompt).getFunctions()
|
||||
, functionMessageParser.parse(jsonObject));
|
||||
} else {
|
||||
//noinspection unchecked
|
||||
messageResponse = (AbstractBaseMessageResponse<M>) new AiMessageResponse(aiMessageParser.parse(jsonObject));
|
||||
messageResponse = new AiMessageResponse(aiMessageParser.parse(jsonObject));
|
||||
}
|
||||
|
||||
if (error != null && !error.isEmpty()) {
|
||||
@ -131,7 +128,7 @@ public class ChatglmLlm extends BaseLlm<ChatglmLlmConfig> {
|
||||
|
||||
|
||||
@Override
|
||||
public <R extends MessageResponse<M>, M extends AiMessage> void chatStream(Prompt<M> prompt, StreamResponseListener<R, M> listener, ChatOptions options) {
|
||||
public <R extends MessageResponse<?>> void chatStream(Prompt<R> prompt, StreamResponseListener<R> listener, ChatOptions options) {
|
||||
LlmClient llmClient = new SseClient();
|
||||
Map<String, String> headers = new HashMap<>();
|
||||
headers.put("Content-Type", "application/json");
|
||||
|
@ -28,7 +28,6 @@ import com.agentsflex.llm.client.impl.SseClient;
|
||||
import com.agentsflex.llm.embedding.EmbeddingOptions;
|
||||
import com.agentsflex.llm.response.AbstractBaseMessageResponse;
|
||||
import com.agentsflex.llm.response.AiMessageResponse;
|
||||
import com.agentsflex.message.AiMessage;
|
||||
import com.agentsflex.parser.AiMessageParser;
|
||||
import com.agentsflex.prompt.FunctionPrompt;
|
||||
import com.agentsflex.prompt.Prompt;
|
||||
@ -58,7 +57,7 @@ public class LlamaLlm extends BaseLlm<LlamaLlmConfig> {
|
||||
|
||||
|
||||
@Override
|
||||
public <R extends MessageResponse<M>, M extends AiMessage> R chat(Prompt<M> prompt, ChatOptions options) {
|
||||
public <R extends MessageResponse<?>> R chat(Prompt<R> prompt, ChatOptions options) {
|
||||
Map<String, String> headers = new HashMap<>();
|
||||
headers.put("Content-Type", "application/json");
|
||||
headers.put("Authorization", "Bearer " + config.getApiKey());
|
||||
@ -77,13 +76,12 @@ public class LlamaLlm extends BaseLlm<LlamaLlmConfig> {
|
||||
JSONObject jsonObject = JSON.parseObject(response);
|
||||
JSONObject error = jsonObject.getJSONObject("error");
|
||||
|
||||
AbstractBaseMessageResponse<M> messageResponse;
|
||||
AbstractBaseMessageResponse<?> messageResponse;
|
||||
|
||||
if (prompt instanceof FunctionPrompt) {
|
||||
throw new IllegalStateException("Llama not support function calling");
|
||||
} else {
|
||||
//noinspection unchecked
|
||||
messageResponse = (AbstractBaseMessageResponse<M>) new AiMessageResponse(aiMessageParser.parse(jsonObject));
|
||||
messageResponse = new AiMessageResponse(aiMessageParser.parse(jsonObject));
|
||||
}
|
||||
|
||||
if (error != null && !error.isEmpty()) {
|
||||
@ -99,7 +97,7 @@ public class LlamaLlm extends BaseLlm<LlamaLlmConfig> {
|
||||
|
||||
|
||||
@Override
|
||||
public <R extends MessageResponse<M>, M extends AiMessage> void chatStream(Prompt<M> prompt, StreamResponseListener<R, M> listener, ChatOptions options) {
|
||||
public <R extends MessageResponse<?>> void chatStream(Prompt<R> prompt, StreamResponseListener<R> listener, ChatOptions options) {
|
||||
LlmClient llmClient = new SseClient();
|
||||
Map<String, String> headers = new HashMap<>();
|
||||
headers.put("Content-Type", "application/json");
|
||||
|
@ -29,7 +29,6 @@ import com.agentsflex.llm.embedding.EmbeddingOptions;
|
||||
import com.agentsflex.llm.response.AbstractBaseMessageResponse;
|
||||
import com.agentsflex.llm.response.AiMessageResponse;
|
||||
import com.agentsflex.llm.response.FunctionMessageResponse;
|
||||
import com.agentsflex.message.AiMessage;
|
||||
import com.agentsflex.parser.AiMessageParser;
|
||||
import com.agentsflex.parser.FunctionMessageParser;
|
||||
import com.agentsflex.prompt.FunctionPrompt;
|
||||
@ -68,7 +67,7 @@ public class OpenAiLlm extends BaseLlm<OpenAiLlmConfig> {
|
||||
}
|
||||
|
||||
@Override
|
||||
public <R extends MessageResponse<M>, M extends AiMessage> R chat(Prompt<M> prompt, ChatOptions options) {
|
||||
public <R extends MessageResponse<?>> R chat(Prompt<R> prompt, ChatOptions options) {
|
||||
Map<String, String> headers = new HashMap<>();
|
||||
headers.put("Content-Type", "application/json");
|
||||
headers.put("Authorization", "Bearer " + getConfig().getApiKey());
|
||||
@ -87,15 +86,13 @@ public class OpenAiLlm extends BaseLlm<OpenAiLlmConfig> {
|
||||
JSONObject jsonObject = JSON.parseObject(response);
|
||||
JSONObject error = jsonObject.getJSONObject("error");
|
||||
|
||||
AbstractBaseMessageResponse<M> messageResponse;
|
||||
AbstractBaseMessageResponse<?> messageResponse;
|
||||
|
||||
if (prompt instanceof FunctionPrompt) {
|
||||
//noinspection unchecked
|
||||
messageResponse = (AbstractBaseMessageResponse<M>) new FunctionMessageResponse(((FunctionPrompt) prompt).getFunctions()
|
||||
messageResponse = new FunctionMessageResponse(((FunctionPrompt) prompt).getFunctions()
|
||||
, functionMessageParser.parse(jsonObject));
|
||||
} else {
|
||||
//noinspection unchecked
|
||||
messageResponse = (AbstractBaseMessageResponse<M>) new AiMessageResponse(aiMessageParser.parse(jsonObject));
|
||||
messageResponse = new AiMessageResponse(aiMessageParser.parse(jsonObject));
|
||||
}
|
||||
|
||||
if (error != null && !error.isEmpty()) {
|
||||
@ -111,7 +108,7 @@ public class OpenAiLlm extends BaseLlm<OpenAiLlmConfig> {
|
||||
|
||||
|
||||
@Override
|
||||
public <R extends MessageResponse<M>, M extends AiMessage> void chatStream(Prompt<M> prompt, StreamResponseListener<R, M> listener, ChatOptions options) {
|
||||
public <R extends MessageResponse<?>> void chatStream(Prompt<R> prompt, StreamResponseListener<R> listener, ChatOptions options) {
|
||||
LlmClient llmClient = new SseClient();
|
||||
Map<String, String> headers = new HashMap<>();
|
||||
headers.put("Content-Type", "application/json");
|
||||
|
@ -5,7 +5,6 @@ import com.agentsflex.llm.Llm;
|
||||
import com.agentsflex.llm.StreamResponseListener;
|
||||
import com.agentsflex.llm.response.AiMessageResponse;
|
||||
import com.agentsflex.llm.response.FunctionMessageResponse;
|
||||
import com.agentsflex.message.AiMessage;
|
||||
import com.agentsflex.prompt.FunctionPrompt;
|
||||
import com.agentsflex.prompt.ImagePrompt;
|
||||
import org.junit.Test;
|
||||
@ -33,7 +32,7 @@ public class OpenAiLlmTest {
|
||||
|
||||
Llm llm = new OpenAiLlm(config);
|
||||
// String response = llm.chat("请问你叫什么名字");
|
||||
llm.chatStream("你叫什么名字", new StreamResponseListener<AiMessageResponse, AiMessage>() {
|
||||
llm.chatStream("你叫什么名字", new StreamResponseListener<AiMessageResponse>() {
|
||||
@Override
|
||||
public void onMessage(ChatContext context, AiMessageResponse response) {
|
||||
System.out.println(response.getMessage().getContent());
|
||||
@ -48,8 +47,6 @@ public class OpenAiLlmTest {
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
@Test
|
||||
public void testChatWithImage() {
|
||||
OpenAiLlmConfig config = new OpenAiLlmConfig();
|
||||
|
@ -57,7 +57,7 @@ public class QwenLlm extends BaseLlm<QwenLlmConfig> {
|
||||
|
||||
|
||||
@Override
|
||||
public <R extends MessageResponse<M>, M extends AiMessage> R chat(Prompt<M> prompt, ChatOptions options) {
|
||||
public <R extends MessageResponse<?>> R chat(Prompt<R> prompt, ChatOptions options) {
|
||||
Map<String, String> headers = new HashMap<>();
|
||||
headers.put("Content-Type", "application/json");
|
||||
headers.put("Authorization", "Bearer " + getConfig().getApiKey());
|
||||
@ -77,15 +77,13 @@ public class QwenLlm extends BaseLlm<QwenLlmConfig> {
|
||||
JSONObject jsonObject = JSON.parseObject(response);
|
||||
JSONObject error = jsonObject.getJSONObject("error");
|
||||
|
||||
AbstractBaseMessageResponse<M> messageResponse;
|
||||
AbstractBaseMessageResponse<?> messageResponse;
|
||||
|
||||
if (prompt instanceof FunctionPrompt) {
|
||||
//noinspection unchecked
|
||||
messageResponse = (AbstractBaseMessageResponse<M>) new FunctionMessageResponse(((FunctionPrompt) prompt).getFunctions()
|
||||
messageResponse = new FunctionMessageResponse(((FunctionPrompt) prompt).getFunctions()
|
||||
, functionMessageParser.parse(jsonObject));
|
||||
} else {
|
||||
//noinspection unchecked
|
||||
messageResponse = (AbstractBaseMessageResponse<M>) new AiMessageResponse(aiMessageParser.parse(jsonObject));
|
||||
messageResponse = new AiMessageResponse(aiMessageParser.parse(jsonObject));
|
||||
}
|
||||
|
||||
if (error != null && !error.isEmpty()) {
|
||||
@ -101,7 +99,7 @@ public class QwenLlm extends BaseLlm<QwenLlmConfig> {
|
||||
|
||||
|
||||
@Override
|
||||
public <R extends MessageResponse<M>, M extends AiMessage> void chatStream(Prompt<M> prompt, StreamResponseListener<R, M> listener, ChatOptions options) {
|
||||
public <R extends MessageResponse<?>> void chatStream(Prompt<R> prompt, StreamResponseListener<R> listener, ChatOptions options) {
|
||||
LlmClient llmClient = new SseClient();
|
||||
Map<String, String> headers = new HashMap<>();
|
||||
headers.put("Content-Type", "application/json");
|
||||
@ -112,6 +110,7 @@ public class QwenLlm extends BaseLlm<QwenLlmConfig> {
|
||||
|
||||
LlmClientListener clientListener = new BaseLlmClientListener(this, llmClient, listener, prompt, new DefaultAiMessageParser() {
|
||||
int prevMessageLength = 0;
|
||||
|
||||
@Override
|
||||
public AiMessage parse(JSONObject content) {
|
||||
AiMessage aiMessage = aiMessageParser.parse(content);
|
||||
|
@ -90,12 +90,12 @@ public class SparkLlm extends BaseLlm<SparkLlmConfig> {
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
@Override
|
||||
public <R extends MessageResponse<M>, M extends AiMessage> R chat(Prompt<M> prompt, ChatOptions options) {
|
||||
public <R extends MessageResponse<?>> R chat(Prompt<R> prompt, ChatOptions options) {
|
||||
CountDownLatch latch = new CountDownLatch(1);
|
||||
Message[] messages = new Message[1];
|
||||
chatStream(prompt, new StreamResponseListener<MessageResponse<M>, M>() {
|
||||
chatStream(prompt, new StreamResponseListener<R>() {
|
||||
@Override
|
||||
public void onMessage(ChatContext context, MessageResponse<M> response) {
|
||||
public void onMessage(ChatContext context, R response) {
|
||||
if (response.getMessage() instanceof FunctionMessage) {
|
||||
messages[0] = response.getMessage();
|
||||
} else {
|
||||
@ -128,7 +128,7 @@ public class SparkLlm extends BaseLlm<SparkLlmConfig> {
|
||||
|
||||
|
||||
@Override
|
||||
public <R extends MessageResponse<M>, M extends AiMessage> void chatStream(Prompt<M> prompt, StreamResponseListener<R, M> listener, ChatOptions options) {
|
||||
public <R extends MessageResponse<?>> void chatStream(Prompt<R> prompt, StreamResponseListener<R> listener, ChatOptions options) {
|
||||
LlmClient llmClient = new WebSocketClient();
|
||||
String url = SparkLlmUtil.createURL(config);
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user