mirror of
https://gitee.com/agents-flex/agents-flex.git
synced 2024-12-02 03:48:11 +08:00
refactor: optimize and refactor LLM module
This commit is contained in:
parent
d1f7c05c54
commit
1fc00e4563
@ -26,15 +26,15 @@ import java.util.List;
|
||||
|
||||
public class Functions<T> extends ArrayList<Function<T>> {
|
||||
|
||||
public static <R> Functions<R> from(Class<?> clazz, Class<R> resultType, String... methodNames) {
|
||||
public static <R> Functions<R> from(Class<?> clazz, String... methodNames) {
|
||||
List<Method> methodList = ClassUtil.getAllMethods(clazz, method -> {
|
||||
if (Modifier.isStatic(method.getModifiers())) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!resultType.isAssignableFrom(method.getReturnType())) {
|
||||
return false;
|
||||
}
|
||||
// if (!resultType.isAssignableFrom(method.getReturnType())) {
|
||||
// return false;
|
||||
// }
|
||||
|
||||
if (method.getAnnotation(FunctionDef.class) == null) {
|
||||
return false;
|
||||
|
@ -15,7 +15,7 @@
|
||||
*/
|
||||
package com.agentsflex.llm;
|
||||
|
||||
public abstract class BaseLlm<T extends LlmConfig> extends Llm{
|
||||
public abstract class BaseLlm<T extends LlmConfig> implements Llm {
|
||||
|
||||
protected T config;
|
||||
|
||||
|
@ -0,0 +1,44 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2023, Agents-Flex (fuhai999@gmail.com).
|
||||
* <p>
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
* <p>
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* <p>
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package com.agentsflex.llm;
|
||||
|
||||
import com.agentsflex.llm.client.LlmClient;
|
||||
|
||||
public class ChatContext {
|
||||
private Llm llm;
|
||||
private LlmClient client;
|
||||
|
||||
public ChatContext(Llm llm, LlmClient client) {
|
||||
this.llm = llm;
|
||||
this.client = client;
|
||||
}
|
||||
|
||||
public Llm getLlm() {
|
||||
return llm;
|
||||
}
|
||||
|
||||
public void setLlm(Llm llm) {
|
||||
this.llm = llm;
|
||||
}
|
||||
|
||||
public LlmClient getClient() {
|
||||
return client;
|
||||
}
|
||||
|
||||
public void setClient(LlmClient client) {
|
||||
this.client = client;
|
||||
}
|
||||
}
|
@ -15,15 +15,16 @@
|
||||
*/
|
||||
package com.agentsflex.llm;
|
||||
|
||||
import com.agentsflex.message.AiMessage;
|
||||
|
||||
public interface ChatListener {
|
||||
|
||||
default void onStart(Llm llm){}
|
||||
default void onStart(ChatContext context) {
|
||||
}
|
||||
|
||||
void onMessage(Llm llm, AiMessage aiMessage);
|
||||
void onMessage(ChatContext context, ChatResponse<?> response);
|
||||
|
||||
default void onStop(Llm llm){}
|
||||
default void onStop(ChatContext context) {
|
||||
}
|
||||
|
||||
default void onFailure(Llm llm, Throwable throwable){}
|
||||
default void onFailure(ChatContext context, Throwable throwable) {
|
||||
}
|
||||
}
|
||||
|
@ -15,13 +15,8 @@
|
||||
*/
|
||||
package com.agentsflex.llm;
|
||||
|
||||
import com.agentsflex.functions.Function;
|
||||
import com.agentsflex.prompt.Prompt;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public interface FunctionCalling {
|
||||
|
||||
<R> R call(Prompt prompt, List<Function<R>> functions);
|
||||
import com.agentsflex.message.Message;
|
||||
|
||||
public interface ChatResponse<T extends Message> {
|
||||
T getMessage();
|
||||
}
|
@ -20,6 +20,6 @@ import com.agentsflex.vector.VectorData;
|
||||
|
||||
public interface Embeddings {
|
||||
|
||||
VectorData embeddings(Document prompt);
|
||||
VectorData embeddings(Document document);
|
||||
|
||||
}
|
||||
|
@ -15,16 +15,23 @@
|
||||
*/
|
||||
package com.agentsflex.llm;
|
||||
|
||||
import com.agentsflex.llm.client.LlmClient;
|
||||
import com.agentsflex.prompt.Prompt;
|
||||
import com.agentsflex.prompt.SimplePrompt;
|
||||
|
||||
public abstract class Llm implements Embeddings {
|
||||
public interface Llm extends Embeddings {
|
||||
|
||||
public LlmClient chat(String prompt, ChatListener listener) {
|
||||
return chat(new SimplePrompt(prompt), listener);
|
||||
default String chat(String prompt) {
|
||||
ChatResponse<?> chat = chat(new SimplePrompt(prompt));
|
||||
return chat != null ? chat.getMessage().getContent() : null;
|
||||
}
|
||||
|
||||
public abstract LlmClient chat(Prompt prompt, ChatListener listener);
|
||||
<T extends ChatResponse<?>> T chat(Prompt<T> prompt);
|
||||
|
||||
|
||||
default void chatAsync(String prompt, ChatListener listener) {
|
||||
this.chatAsync(new SimplePrompt(prompt), listener);
|
||||
}
|
||||
|
||||
void chatAsync(Prompt prompt, ChatListener listener);
|
||||
|
||||
}
|
||||
|
@ -15,66 +15,96 @@
|
||||
*/
|
||||
package com.agentsflex.llm.client;
|
||||
|
||||
import com.agentsflex.functions.Function;
|
||||
import com.agentsflex.llm.ChatContext;
|
||||
import com.agentsflex.llm.ChatListener;
|
||||
import com.agentsflex.llm.ChatResponse;
|
||||
import com.agentsflex.llm.Llm;
|
||||
import com.agentsflex.llm.response.FunctionResultResponse;
|
||||
import com.agentsflex.llm.response.MessageResponse;
|
||||
import com.agentsflex.message.AiMessage;
|
||||
import com.agentsflex.message.FunctionMessage;
|
||||
import com.agentsflex.prompt.FunctionPrompt;
|
||||
import com.agentsflex.prompt.HistoriesPrompt;
|
||||
import com.agentsflex.prompt.Prompt;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public class BaseLlmClientListener implements LlmClientListener {
|
||||
|
||||
private final Llm llm;
|
||||
private final LlmClient client;
|
||||
private final ChatListener chatListener;
|
||||
|
||||
private final Prompt prompt;
|
||||
|
||||
private final MessageParser messageParser;
|
||||
|
||||
private final AiMessageParser messageParser;
|
||||
private final FunctionMessageParser functionInfoParser;
|
||||
private final StringBuilder fullMessage = new StringBuilder();
|
||||
|
||||
private AiMessage lastAiMessage;
|
||||
private boolean isFunctionCalling = false;
|
||||
|
||||
public BaseLlmClientListener(Llm llm, ChatListener chatListener, Prompt prompt, MessageParser messageParser) {
|
||||
public BaseLlmClientListener(Llm llm, LlmClient client, ChatListener chatListener, Prompt prompt
|
||||
, AiMessageParser messageParser
|
||||
, FunctionMessageParser functionInfoParser) {
|
||||
this.llm = llm;
|
||||
this.client = client;
|
||||
this.chatListener = chatListener;
|
||||
this.prompt = prompt;
|
||||
this.messageParser = messageParser;
|
||||
this.functionInfoParser = functionInfoParser;
|
||||
|
||||
if (prompt instanceof FunctionPrompt) {
|
||||
if (functionInfoParser == null) {
|
||||
throw new IllegalArgumentException("Can not support Function Calling");
|
||||
} else {
|
||||
isFunctionCalling = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public void onStart(LlmClient client) {
|
||||
chatListener.onStart(llm);
|
||||
chatListener.onStart(new ChatContext(llm, client));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onMessage(LlmClient client, String response) {
|
||||
lastAiMessage = messageParser.parseMessage(response);
|
||||
fullMessage.append(lastAiMessage.getContent());
|
||||
chatListener.onMessage(llm, lastAiMessage);
|
||||
if (isFunctionCalling) {
|
||||
FunctionMessage functionInfo = functionInfoParser.parseMessage(response);
|
||||
List<Function<?>> functions = ((FunctionPrompt) prompt).getFunctions();
|
||||
ChatResponse<?> r = new FunctionResultResponse(functions, functionInfo);
|
||||
chatListener.onMessage(new ChatContext(llm, client), r);
|
||||
} else {
|
||||
lastAiMessage = messageParser.parseMessage(response);
|
||||
fullMessage.append(lastAiMessage.getContent());
|
||||
lastAiMessage.setFullContent(fullMessage.toString());
|
||||
ChatResponse<?> r = new MessageResponse(lastAiMessage);
|
||||
chatListener.onMessage(new ChatContext(llm, client), r);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onStop(LlmClient client) {
|
||||
if (lastAiMessage != null){
|
||||
|
||||
lastAiMessage.setFullContent(fullMessage.toString());
|
||||
|
||||
if (this.prompt instanceof HistoriesPrompt){
|
||||
if (lastAiMessage != null) {
|
||||
if (this.prompt instanceof HistoriesPrompt) {
|
||||
((HistoriesPrompt) this.prompt).addMessage(lastAiMessage);
|
||||
}
|
||||
}
|
||||
|
||||
chatListener.onStop(llm);
|
||||
chatListener.onStop(new ChatContext(llm, client));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onFailure(LlmClient client, Throwable throwable) {
|
||||
chatListener.onFailure(llm, throwable);
|
||||
chatListener.onFailure(new ChatContext(llm, client), throwable);
|
||||
}
|
||||
|
||||
|
||||
public interface MessageParser{
|
||||
public interface AiMessageParser {
|
||||
AiMessage parseMessage(String response);
|
||||
}
|
||||
|
||||
public interface FunctionMessageParser {
|
||||
FunctionMessage parseMessage(String response);
|
||||
}
|
||||
}
|
||||
|
@ -13,7 +13,7 @@
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package com.agentsflex.util;
|
||||
package com.agentsflex.llm.client;
|
||||
|
||||
import okhttp3.*;
|
||||
import org.slf4j.Logger;
|
||||
@ -23,13 +23,13 @@ import java.io.IOException;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
public class OKHttpUtil {
|
||||
private static final Logger LOG = LoggerFactory.getLogger(OKHttpUtil.class);
|
||||
public class HttpClient {
|
||||
private static final Logger LOG = LoggerFactory.getLogger(HttpClient.class);
|
||||
private static final MediaType JSON_TYPE = MediaType.parse("application/json; charset=utf-8");
|
||||
|
||||
private final OkHttpClient okHttpClient;
|
||||
|
||||
public OKHttpUtil() {
|
||||
public HttpClient() {
|
||||
this.okHttpClient = new OkHttpClient.Builder()
|
||||
.connectTimeout(3, TimeUnit.MINUTES)
|
||||
.readTimeout(3, TimeUnit.MINUTES)
|
||||
@ -37,7 +37,7 @@ public class OKHttpUtil {
|
||||
}
|
||||
|
||||
|
||||
public OKHttpUtil(OkHttpClient okHttpClient) {
|
||||
public HttpClient(OkHttpClient okHttpClient) {
|
||||
this.okHttpClient = okHttpClient;
|
||||
}
|
||||
|
@ -24,7 +24,7 @@ import java.io.IOException;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
public class HttpClient implements LlmClient {
|
||||
public class AsyncHttpClient implements LlmClient {
|
||||
private static final MediaType JSON_TYPE = MediaType.parse("application/json; charset=utf-8");
|
||||
private OkHttpClient client;
|
||||
private LlmClientListener listener;
|
||||
@ -55,15 +55,15 @@ public class HttpClient implements LlmClient {
|
||||
this.client.newCall(rBuilder.build()).enqueue(new Callback() {
|
||||
@Override
|
||||
public void onFailure(@NotNull Call call, @NotNull IOException e) {
|
||||
HttpClient.this.listener.onFailure(HttpClient.this, e);
|
||||
AsyncHttpClient.this.listener.onFailure(AsyncHttpClient.this, e);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onResponse(@NotNull Call call, @NotNull Response response) throws IOException {
|
||||
HttpClient.this.listener.onMessage(HttpClient.this, response.message());
|
||||
AsyncHttpClient.this.listener.onMessage(AsyncHttpClient.this, response.message());
|
||||
if (!isStop) {
|
||||
HttpClient.this.isStop = true;
|
||||
HttpClient.this.listener.onStop(HttpClient.this);
|
||||
AsyncHttpClient.this.isStop = true;
|
||||
AsyncHttpClient.this.listener.onStop(AsyncHttpClient.this);
|
||||
}
|
||||
}
|
||||
});
|
@ -79,6 +79,10 @@ public class WebSocketClient extends WebSocketListener implements LlmClient {
|
||||
|
||||
@Override
|
||||
public void onClosing(WebSocket webSocket, int code, String reason) {
|
||||
if (!isStop) {
|
||||
this.isStop = true;
|
||||
this.listener.onStop(this);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -0,0 +1,47 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2023, Agents-Flex (fuhai999@gmail.com).
|
||||
* <p>
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
* <p>
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* <p>
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package com.agentsflex.llm.response;
|
||||
|
||||
import com.agentsflex.functions.Function;
|
||||
import com.agentsflex.llm.ChatResponse;
|
||||
import com.agentsflex.message.FunctionMessage;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public class FunctionResultResponse implements ChatResponse<FunctionMessage> {
|
||||
|
||||
private final List<Function<?>> functions;
|
||||
private final FunctionMessage functionMessage;
|
||||
|
||||
public FunctionResultResponse(List<Function<?>> functions, FunctionMessage functionMessage) {
|
||||
this.functions = functions;
|
||||
this.functionMessage = functionMessage;
|
||||
}
|
||||
|
||||
public Object invoke() {
|
||||
for (Function<?> function : functions) {
|
||||
if (function.getName().equals(functionMessage.getFunctionName())) {
|
||||
return function.invoke(functionMessage.getArgs());
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public FunctionMessage getMessage() {
|
||||
return functionMessage;
|
||||
}
|
||||
}
|
@ -0,0 +1,32 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2023, Agents-Flex (fuhai999@gmail.com).
|
||||
* <p>
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
* <p>
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* <p>
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package com.agentsflex.llm.response;
|
||||
|
||||
import com.agentsflex.llm.ChatResponse;
|
||||
import com.agentsflex.message.AiMessage;
|
||||
|
||||
public class MessageResponse implements ChatResponse<AiMessage> {
|
||||
private final AiMessage aiMessage;
|
||||
|
||||
public MessageResponse(AiMessage aiMessage) {
|
||||
this.aiMessage = aiMessage;
|
||||
}
|
||||
|
||||
@Override
|
||||
public AiMessage getMessage() {
|
||||
return aiMessage;
|
||||
}
|
||||
}
|
@ -17,10 +17,17 @@ package com.agentsflex.memory;
|
||||
|
||||
import com.agentsflex.message.Message;
|
||||
|
||||
import java.util.Collection;
|
||||
import java.util.List;
|
||||
|
||||
public interface MessageMemory extends Memory {
|
||||
List<Message> getMessages();
|
||||
|
||||
void addMessage(Message message);
|
||||
|
||||
default void addMessages(Collection<Message> messages){
|
||||
for (Message message : messages) {
|
||||
addMessage(message);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -21,7 +21,6 @@ public class AiMessage extends Message {
|
||||
private Integer index;
|
||||
private MessageStatus status;
|
||||
private int totalTokens;
|
||||
|
||||
private String fullContent;
|
||||
|
||||
|
||||
|
@ -15,5 +15,26 @@
|
||||
*/
|
||||
package com.agentsflex.message;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
public class FunctionMessage extends Message{
|
||||
|
||||
private String functionName;
|
||||
private Map<String,Object> args;
|
||||
|
||||
public String getFunctionName() {
|
||||
return functionName;
|
||||
}
|
||||
|
||||
public void setFunctionName(String functionName) {
|
||||
this.functionName = functionName;
|
||||
}
|
||||
|
||||
public Map<String, Object> getArgs() {
|
||||
return args;
|
||||
}
|
||||
|
||||
public void setArgs(Map<String, Object> args) {
|
||||
this.args = args;
|
||||
}
|
||||
}
|
||||
|
@ -0,0 +1,53 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2023, Agents-Flex (fuhai999@gmail.com).
|
||||
* <p>
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
* <p>
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* <p>
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package com.agentsflex.prompt;
|
||||
|
||||
import com.agentsflex.functions.Function;
|
||||
import com.agentsflex.functions.Functions;
|
||||
import com.agentsflex.llm.response.FunctionResultResponse;
|
||||
import com.agentsflex.memory.DefaultMessageMemory;
|
||||
import com.agentsflex.memory.MessageMemory;
|
||||
import com.agentsflex.message.HumanMessage;
|
||||
import com.agentsflex.message.Message;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class FunctionPrompt extends Prompt<FunctionResultResponse> {
|
||||
private MessageMemory memory = new DefaultMessageMemory();
|
||||
|
||||
private List<Function<?>> functions = new ArrayList<>();
|
||||
|
||||
public FunctionPrompt(String prompt, Class<?> funcClass) {
|
||||
memory.addMessage(new HumanMessage(prompt));
|
||||
functions.addAll(Functions.from(funcClass));
|
||||
}
|
||||
|
||||
public FunctionPrompt(List<Message> messages, Class<?> funcClass) {
|
||||
memory.addMessages(messages);
|
||||
functions.addAll(Functions.from(funcClass));
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public List<Message> getMessages() {
|
||||
return memory.getMessages();
|
||||
}
|
||||
|
||||
public List<Function<?>> getFunctions() {
|
||||
return functions;
|
||||
}
|
||||
}
|
@ -15,13 +15,14 @@
|
||||
*/
|
||||
package com.agentsflex.prompt;
|
||||
|
||||
import com.agentsflex.llm.response.MessageResponse;
|
||||
import com.agentsflex.memory.DefaultMessageMemory;
|
||||
import com.agentsflex.memory.MessageMemory;
|
||||
import com.agentsflex.message.Message;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public class HistoriesPrompt extends Prompt{
|
||||
public class HistoriesPrompt extends Prompt<MessageResponse>{
|
||||
|
||||
private MessageMemory memory = new DefaultMessageMemory();
|
||||
|
||||
|
@ -21,7 +21,7 @@ import com.agentsflex.util.Metadata;
|
||||
import java.util.List;
|
||||
|
||||
|
||||
public abstract class Prompt extends Metadata {
|
||||
public abstract class Prompt<ChatResponse> extends Metadata {
|
||||
|
||||
public abstract List<Message> getMessages();
|
||||
|
||||
|
@ -15,13 +15,14 @@
|
||||
*/
|
||||
package com.agentsflex.prompt;
|
||||
|
||||
import com.agentsflex.llm.response.MessageResponse;
|
||||
import com.agentsflex.message.HumanMessage;
|
||||
import com.agentsflex.message.Message;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
public class SimplePrompt extends Prompt{
|
||||
public class SimplePrompt extends Prompt<MessageResponse> {
|
||||
|
||||
private final String content;
|
||||
|
||||
|
@ -15,17 +15,22 @@
|
||||
*/
|
||||
package com.agentsflex.llm.openai;
|
||||
|
||||
import com.agentsflex.llm.client.BaseLlmClientListener;
|
||||
import com.agentsflex.llm.client.LlmClient;
|
||||
import com.agentsflex.llm.client.LlmClientListener;
|
||||
import com.agentsflex.llm.client.impl.SseClient;
|
||||
import com.agentsflex.document.Document;
|
||||
import com.agentsflex.functions.Function;
|
||||
import com.agentsflex.llm.BaseLlm;
|
||||
import com.agentsflex.llm.ChatListener;
|
||||
import com.agentsflex.llm.FunctionCalling;
|
||||
import com.agentsflex.llm.ChatResponse;
|
||||
import com.agentsflex.llm.client.BaseLlmClientListener;
|
||||
import com.agentsflex.llm.client.HttpClient;
|
||||
import com.agentsflex.llm.client.LlmClient;
|
||||
import com.agentsflex.llm.client.LlmClientListener;
|
||||
import com.agentsflex.llm.client.impl.SseClient;
|
||||
import com.agentsflex.llm.response.FunctionResultResponse;
|
||||
import com.agentsflex.llm.response.MessageResponse;
|
||||
import com.agentsflex.message.AiMessage;
|
||||
import com.agentsflex.message.FunctionMessage;
|
||||
import com.agentsflex.prompt.FunctionPrompt;
|
||||
import com.agentsflex.prompt.Prompt;
|
||||
import com.agentsflex.document.Document;
|
||||
import com.agentsflex.util.OKHttpUtil;
|
||||
import com.agentsflex.util.StringUtil;
|
||||
import com.agentsflex.vector.VectorData;
|
||||
import com.alibaba.fastjson.JSON;
|
||||
@ -36,17 +41,46 @@ import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
public class OpenAiLlm extends BaseLlm<OpenAiLlmConfig> implements FunctionCalling {
|
||||
public class OpenAiLlm extends BaseLlm<OpenAiLlmConfig> {
|
||||
|
||||
private final OKHttpUtil httpUtil = new OKHttpUtil();
|
||||
private final HttpClient httpClient = new HttpClient();
|
||||
|
||||
public OpenAiLlm(OpenAiLlmConfig config) {
|
||||
super(config);
|
||||
}
|
||||
|
||||
@Override
|
||||
public <T extends ChatResponse<?>> T chat(Prompt<T> prompt){
|
||||
Map<String, String> headers = new HashMap<>();
|
||||
headers.put("Content-Type", "application/json");
|
||||
headers.put("Authorization", "Bearer " + getConfig().getApiKey());
|
||||
|
||||
String payload = OpenAiLLmUtil.promptToPayload(prompt, config);
|
||||
String responseString = httpClient.post("https://api.openai.com/v1/chat/completions", headers, payload);
|
||||
if (StringUtil.noText(responseString)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (prompt instanceof FunctionPrompt) {
|
||||
List<Function<?>> functions = ((FunctionPrompt) prompt).getFunctions();
|
||||
|
||||
JSONObject jsonObject = JSON.parseObject(responseString);
|
||||
String callFunctionName = (String) JSONPath.eval(jsonObject, "$.choices[0].tool_calls[0].function.name");
|
||||
String callFunctionArgsString = (String) JSONPath.eval(jsonObject, "$.choices[0].tool_calls[0].function.arguments");
|
||||
JSONObject callFunctionArgs = JSON.parseObject(callFunctionArgsString);
|
||||
|
||||
FunctionMessage functionMessage = new FunctionMessage();
|
||||
functionMessage.setFunctionName(callFunctionName);
|
||||
functionMessage.setArgs(callFunctionArgs);
|
||||
return (T) new FunctionResultResponse(functions, functionMessage);
|
||||
} else {
|
||||
AiMessage aiMessage = OpenAiLLmUtil.parseAiMessage(responseString);
|
||||
return (T) new MessageResponse(aiMessage);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public LlmClient chat(Prompt prompt, ChatListener listener) {
|
||||
public void chatAsync(Prompt prompt, ChatListener listener) {
|
||||
LlmClient llmClient = new SseClient();
|
||||
Map<String, String> headers = new HashMap<>();
|
||||
headers.put("Content-Type", "application/json");
|
||||
@ -54,22 +88,21 @@ public class OpenAiLlm extends BaseLlm<OpenAiLlmConfig> implements FunctionCalli
|
||||
|
||||
String payload = OpenAiLLmUtil.promptToPayload(prompt, config);
|
||||
|
||||
LlmClientListener clientListener = new BaseLlmClientListener(this, listener, prompt, OpenAiLLmUtil::parseAiMessage);
|
||||
LlmClientListener clientListener = new BaseLlmClientListener(this, llmClient, listener, prompt, OpenAiLLmUtil::parseAiMessage, null);
|
||||
llmClient.start("https://api.openai.com/v1/chat/completions", headers, payload, clientListener);
|
||||
return llmClient;
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public VectorData embeddings(Document text) {
|
||||
public VectorData embeddings(Document document) {
|
||||
Map<String, String> headers = new HashMap<>();
|
||||
headers.put("Content-Type", "application/json");
|
||||
headers.put("Authorization", "Bearer " + getConfig().getApiKey());
|
||||
|
||||
String payload = OpenAiLLmUtil.promptToEmbeddingsPayload(text);
|
||||
String payload = OpenAiLLmUtil.promptToEmbeddingsPayload(document);
|
||||
|
||||
// https://platform.openai.com/docs/api-reference/embeddings/create
|
||||
String response = httpUtil.post("https://api.openai.com/v1/embeddings", headers, payload);
|
||||
String response = httpClient.post("https://api.openai.com/v1/embeddings", headers, payload);
|
||||
if (StringUtil.noText(response)) {
|
||||
return null;
|
||||
}
|
||||
@ -81,30 +114,5 @@ public class OpenAiLlm extends BaseLlm<OpenAiLlmConfig> implements FunctionCalli
|
||||
return vectorData;
|
||||
}
|
||||
|
||||
@Override
|
||||
public <R> R call(Prompt prompt, List<Function<R>> functions) {
|
||||
Map<String, String> headers = new HashMap<>();
|
||||
headers.put("Content-Type", "application/json");
|
||||
headers.put("Authorization", "Bearer " + getConfig().getApiKey());
|
||||
|
||||
String payload = OpenAiLLmUtil.promptToFunctionCallingPayload(prompt, config, functions);
|
||||
|
||||
// https://platform.openai.com/docs/api-reference/embeddings/create
|
||||
String response = httpUtil.post("https://api.openai.com/v1/embeddings", headers, payload);
|
||||
if (StringUtil.noText(response)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
JSONObject jsonObject = JSON.parseObject(response);
|
||||
String callFunctionName = (String) JSONPath.eval(jsonObject, "$.choices[0].tool_calls[0].function.name");
|
||||
String callFunctionArgsString = (String) JSONPath.eval(jsonObject, "$.choices[0].tool_calls[0].function.arguments");
|
||||
JSONObject callFunctionArgs = JSON.parseObject(callFunctionArgsString);
|
||||
|
||||
for (Function<R> function : functions) {
|
||||
if (function.getName().equals(callFunctionName)) {
|
||||
return function.invoke(callFunctionArgs);
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
@ -1,43 +1,36 @@
|
||||
package com.agentsflex.llm.openai;
|
||||
|
||||
import com.agentsflex.functions.Functions;
|
||||
import com.agentsflex.llm.Llm;
|
||||
import com.agentsflex.prompt.SimplePrompt;
|
||||
import com.agentsflex.llm.response.FunctionResultResponse;
|
||||
import com.agentsflex.prompt.FunctionPrompt;
|
||||
import org.junit.Test;
|
||||
|
||||
public class OpenAiLlmTest {
|
||||
|
||||
@Test
|
||||
public void testChat() throws InterruptedException {
|
||||
|
||||
|
||||
public void testChat() {
|
||||
OpenAiLlmConfig config = new OpenAiLlmConfig();
|
||||
config.setApiKey("sk-rts5NF6n*******");
|
||||
|
||||
Llm llm = new OpenAiLlm(config);
|
||||
String response = llm.chat("请问你叫什么名字");
|
||||
|
||||
llm.chat(new SimplePrompt("请写一个小兔子战胜大灰狼的故事"), (instance, message) -> {
|
||||
System.out.println("--->" + message.getContent());
|
||||
});
|
||||
|
||||
Thread.sleep(10000);
|
||||
|
||||
System.out.println(response);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testFunctionCalling() throws InterruptedException {
|
||||
|
||||
|
||||
public void testFunctionCalling() throws InterruptedException {
|
||||
OpenAiLlmConfig config = new OpenAiLlmConfig();
|
||||
config.setApiKey("sk-rts5NF6n*******");
|
||||
|
||||
OpenAiLlm llm = new OpenAiLlm(config);
|
||||
|
||||
Functions<String> functions = Functions.from(WeatherUtil.class, String.class);
|
||||
String result = llm.call(new SimplePrompt("今天的天气怎么样"), functions);
|
||||
FunctionPrompt prompt = new FunctionPrompt("今天北京的天气怎么样", WeatherUtil.class);
|
||||
FunctionResultResponse response = llm.chat(prompt);
|
||||
|
||||
Object result = response.invoke();
|
||||
|
||||
System.out.println(result);
|
||||
|
||||
Thread.sleep(10000);
|
||||
|
||||
// "Today it will be dull and overcast in 北京"
|
||||
}
|
||||
}
|
||||
|
@ -9,6 +9,6 @@ public class WeatherUtil {
|
||||
public static String getWeatherInfo(
|
||||
@FunctionParam(name = "city", description = "the city name") String name
|
||||
) {
|
||||
return "Today it will be dull and overcast in city " + name;
|
||||
return "Today it will be dull and overcast in " + name;
|
||||
}
|
||||
}
|
||||
|
@ -15,15 +15,20 @@
|
||||
*/
|
||||
package com.agentsflex.llm.qwen;
|
||||
|
||||
import com.agentsflex.llm.ChatResponse;
|
||||
import com.agentsflex.llm.client.BaseLlmClientListener;
|
||||
import com.agentsflex.llm.client.HttpClient;
|
||||
import com.agentsflex.llm.client.LlmClient;
|
||||
import com.agentsflex.llm.client.LlmClientListener;
|
||||
import com.agentsflex.llm.client.impl.SseClient;
|
||||
import com.agentsflex.llm.BaseLlm;
|
||||
import com.agentsflex.llm.ChatListener;
|
||||
import com.agentsflex.llm.response.MessageResponse;
|
||||
import com.agentsflex.message.AiMessage;
|
||||
import com.agentsflex.prompt.FunctionPrompt;
|
||||
import com.agentsflex.prompt.Prompt;
|
||||
import com.agentsflex.document.Document;
|
||||
import com.agentsflex.util.StringUtil;
|
||||
import com.agentsflex.vector.VectorData;
|
||||
|
||||
import java.util.HashMap;
|
||||
@ -35,9 +40,34 @@ public class QwenLlm extends BaseLlm<QwenLlmConfig> {
|
||||
super(config);
|
||||
}
|
||||
|
||||
HttpClient httpClient = new HttpClient();
|
||||
|
||||
|
||||
@Override
|
||||
public LlmClient chat(Prompt prompt, ChatListener listener) {
|
||||
public ChatResponse<?> chat(Prompt prompt) {
|
||||
Map<String, String> headers = new HashMap<>();
|
||||
headers.put("Content-Type", "application/json");
|
||||
headers.put("Authorization", "Bearer " + getConfig().getApiKey());
|
||||
|
||||
|
||||
String payload = QwenLlmUtil.promptToPayload(prompt, config);
|
||||
String responseString = httpClient.post("https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation", headers, payload);
|
||||
if (StringUtil.noText(responseString)){
|
||||
return null;
|
||||
}
|
||||
|
||||
if (prompt instanceof FunctionPrompt){
|
||||
|
||||
}else {
|
||||
AiMessage aiMessage = QwenLlmUtil.parseAiMessage(responseString, 0);
|
||||
return new MessageResponse(aiMessage);
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void chatAsync(Prompt prompt, ChatListener listener) {
|
||||
LlmClient llmClient = new SseClient();
|
||||
Map<String, String> headers = new HashMap<>();
|
||||
headers.put("Content-Type", "application/json");
|
||||
@ -45,24 +75,21 @@ public class QwenLlm extends BaseLlm<QwenLlmConfig> {
|
||||
|
||||
String payload = QwenLlmUtil.promptToPayload(prompt, config);
|
||||
|
||||
LlmClientListener clientListener = new BaseLlmClientListener(this, listener, prompt, new BaseLlmClientListener.MessageParser() {
|
||||
LlmClientListener clientListener = new BaseLlmClientListener(this, llmClient,listener, prompt, new BaseLlmClientListener.AiMessageParser() {
|
||||
int prevMessageLength = 0;
|
||||
|
||||
@Override
|
||||
public AiMessage parseMessage(String response) {
|
||||
AiMessage aiMessage = QwenLlmUtil.parseAiMessage(response, prevMessageLength);
|
||||
prevMessageLength += aiMessage.getContent().length();
|
||||
return aiMessage;
|
||||
}
|
||||
});
|
||||
},null);
|
||||
llmClient.start("https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation", headers, payload, clientListener);
|
||||
|
||||
return llmClient;
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public VectorData embeddings(Document text) {
|
||||
public VectorData embeddings(Document document) {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
@ -13,7 +13,7 @@ public class QwenTest {
|
||||
config.setModel("qwen-turbo");
|
||||
|
||||
Llm llm = new QwenLlm(config);
|
||||
llm.chat(new SimplePrompt("请写一个小兔子战胜大灰狼的故事"), (llm1, aiMessage) -> {
|
||||
llm.chatAsync(new SimplePrompt("请写一个小兔子战胜大灰狼的故事"), (llm1, aiMessage) -> {
|
||||
System.out.println(">>>>" + aiMessage.getContent());
|
||||
});
|
||||
|
||||
|
@ -15,16 +15,22 @@
|
||||
*/
|
||||
package com.agentsflex.llm.spark;
|
||||
|
||||
import com.agentsflex.document.Document;
|
||||
import com.agentsflex.llm.BaseLlm;
|
||||
import com.agentsflex.llm.ChatContext;
|
||||
import com.agentsflex.llm.ChatListener;
|
||||
import com.agentsflex.llm.ChatResponse;
|
||||
import com.agentsflex.llm.client.BaseLlmClientListener;
|
||||
import com.agentsflex.llm.client.LlmClient;
|
||||
import com.agentsflex.llm.client.LlmClientListener;
|
||||
import com.agentsflex.llm.client.impl.WebSocketClient;
|
||||
import com.agentsflex.llm.BaseLlm;
|
||||
import com.agentsflex.llm.ChatListener;
|
||||
import com.agentsflex.llm.response.MessageResponse;
|
||||
import com.agentsflex.message.AiMessage;
|
||||
import com.agentsflex.prompt.Prompt;
|
||||
import com.agentsflex.document.Document;
|
||||
import com.agentsflex.vector.VectorData;
|
||||
|
||||
import java.util.concurrent.CountDownLatch;
|
||||
|
||||
public class SparkLlm extends BaseLlm<SparkLlmConfig> {
|
||||
|
||||
public SparkLlm(SparkLlmConfig config) {
|
||||
@ -32,21 +38,44 @@ public class SparkLlm extends BaseLlm<SparkLlmConfig> {
|
||||
}
|
||||
|
||||
@Override
|
||||
public LlmClient chat(Prompt prompt, ChatListener listener) {
|
||||
public VectorData embeddings(Document document) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public MessageResponse chat(Prompt prompt) {
|
||||
CountDownLatch latch = new CountDownLatch(1);
|
||||
AiMessage aiMessage = new AiMessage();
|
||||
chatAsync(prompt, new ChatListener() {
|
||||
@Override
|
||||
public void onMessage(ChatContext context, ChatResponse<?> response) {
|
||||
aiMessage.setContent(((AiMessage) response.getMessage()).getFullContent());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onStop(ChatContext context) {
|
||||
ChatListener.super.onStop(context);
|
||||
latch.countDown();
|
||||
}
|
||||
});
|
||||
try {
|
||||
latch.await();
|
||||
} catch (InterruptedException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
return new MessageResponse(aiMessage);
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public void chatAsync(Prompt prompt, ChatListener listener) {
|
||||
LlmClient llmClient = new WebSocketClient();
|
||||
String url = SparkLlmUtil.createURL(config);
|
||||
|
||||
String payload = SparkLlmUtil.promptToPayload(prompt, config);
|
||||
|
||||
LlmClientListener clientListener = new BaseLlmClientListener(this, listener, prompt, SparkLlmUtil::parseAiMessage);
|
||||
LlmClientListener clientListener = new BaseLlmClientListener(this, llmClient, listener, prompt, SparkLlmUtil::parseAiMessage, null);
|
||||
llmClient.start(url, null, payload, clientListener);
|
||||
|
||||
return llmClient;
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public VectorData embeddings(Document text) {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
@ -5,14 +5,26 @@ import com.agentsflex.llm.spark.SparkLlm;
|
||||
import com.agentsflex.llm.spark.SparkLlmConfig;
|
||||
import com.agentsflex.message.HumanMessage;
|
||||
import com.agentsflex.prompt.HistoriesPrompt;
|
||||
import org.junit.Test;
|
||||
|
||||
import java.util.Scanner;
|
||||
|
||||
public class SparkLlmTest {
|
||||
|
||||
public static void main(String[] args) throws InterruptedException {
|
||||
@Test
|
||||
public void testSimple() {
|
||||
SparkLlmConfig config = new SparkLlmConfig();
|
||||
config.setAppId("****");
|
||||
config.setApiKey("****");
|
||||
config.setApiSecret("****");
|
||||
|
||||
Llm llm = new SparkLlm(config);
|
||||
String result = llm.chat("你好");
|
||||
System.out.println(result);
|
||||
}
|
||||
|
||||
|
||||
public static void main(String[] args) {
|
||||
SparkLlmConfig config = new SparkLlmConfig();
|
||||
config.setAppId("****");
|
||||
config.setApiKey("****");
|
||||
@ -26,18 +38,15 @@ public class SparkLlmTest {
|
||||
Scanner scanner = new Scanner(System.in);
|
||||
String userInput = scanner.nextLine();
|
||||
|
||||
while (userInput != null){
|
||||
while (userInput != null) {
|
||||
|
||||
prompt.addMessage(new HumanMessage(userInput));
|
||||
|
||||
llm.chat(prompt, (instance, message) -> {
|
||||
System.out.println(">>>> " + message.getContent());
|
||||
llm.chatAsync(prompt, (context, response) -> {
|
||||
System.out.println(">>>> " + response.getMessage().getContent());
|
||||
});
|
||||
|
||||
userInput = scanner.nextLine();
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
@ -15,7 +15,7 @@
|
||||
*/
|
||||
package com.agentsflex.vector.aliyun;
|
||||
|
||||
import com.agentsflex.util.OKHttpUtil;
|
||||
import com.agentsflex.llm.client.HttpClient;
|
||||
import com.agentsflex.util.StringUtil;
|
||||
import com.agentsflex.vector.RetrieveWrapper;
|
||||
import com.agentsflex.vector.VectorDocument;
|
||||
@ -33,7 +33,7 @@ public class AliyunVectorStorage extends VectorStorage<VectorDocument> {
|
||||
|
||||
private AliyunVectorStorageConfig config;
|
||||
|
||||
private final OKHttpUtil httpUtil = new OKHttpUtil();
|
||||
private final HttpClient httpUtil = new HttpClient();
|
||||
|
||||
public AliyunVectorStorage(AliyunVectorStorageConfig config) {
|
||||
this.config = config;
|
||||
|
@ -15,7 +15,7 @@
|
||||
*/
|
||||
package com.agentsflex.vector.qcloud;
|
||||
|
||||
import com.agentsflex.util.OKHttpUtil;
|
||||
import com.agentsflex.llm.client.HttpClient;
|
||||
import com.agentsflex.util.StringUtil;
|
||||
import com.agentsflex.vector.RetrieveWrapper;
|
||||
import com.agentsflex.vector.VectorDocument;
|
||||
@ -33,7 +33,7 @@ public class QCloudVectorStorage extends VectorStorage<VectorDocument> {
|
||||
|
||||
private QCloudVectorStorageConfig config;
|
||||
|
||||
private final OKHttpUtil httpUtil = new OKHttpUtil();
|
||||
private final HttpClient httpUtil = new HttpClient();
|
||||
|
||||
|
||||
public QCloudVectorStorage(QCloudVectorStorageConfig config) {
|
||||
|
Loading…
Reference in New Issue
Block a user