refactor: optimize Maps

This commit is contained in:
Michael Yang 2024-11-13 11:43:34 +08:00
parent 9f2ab83e06
commit f8817a1455
18 changed files with 127 additions and 121 deletions

View File

@ -73,8 +73,8 @@ public class ImagePrompt extends TextPrompt {
@Override
public Object getMessageContent() {
List<Map<String, Object>> messageContent = new ArrayList<>();
messageContent.add(Maps.of("type", "text").put("text", prompt.content));
messageContent.add(Maps.of("type", "image_url").put("image_url", Maps.of("url", prompt.imageUrl)));
messageContent.add(Maps.of("type", "text").set("text", prompt.content));
messageContent.add(Maps.of("type", "image_url").set("image_url", Maps.of("url", prompt.imageUrl)));
return messageContent;
}
}

View File

@ -36,19 +36,24 @@ public class Maps extends HashMap<String, Object> {
}
public static Maps ofNotNull(String key, Object value) {
return new Maps().putIfNotNull(key, value);
return new Maps().setIfNotNull(key, value);
}
public static Maps ofNotEmpty(String key, Object value) {
return new Maps().putIfNotEmpty(key, value);
return new Maps().setIfNotEmpty(key, value);
}
public static Maps ofNotEmpty(String key, Maps value) {
return new Maps().putIfNotEmpty(key, value);
return new Maps().setIfNotEmpty(key, value);
}
public Maps put(String key, Object value) {
public Maps set(String key, Object value) {
super.put(key, value);
return this;
}
public Maps setChild(String key, Object value) {
if (key.contains(".")) {
String[] keys = key.split("\\.");
Map<String, Object> currentMap = this;
@ -71,37 +76,57 @@ public class Maps extends HashMap<String, Object> {
return this;
}
public Maps putOrDefault(String key, Object value, Object orDefault) {
public Maps setOrDefault(String key, Object value, Object orDefault) {
if (isNullOrEmpty(value)) {
return this.put(key, orDefault);
return this.set(key, orDefault);
} else {
return this.put(key, value);
return this.set(key, value);
}
}
public Maps putIf(boolean condition, String key, Object value) {
public Maps setIf(boolean condition, String key, Object value) {
if (condition) put(key, value);
return this;
}
public Maps putIf(Function<Maps, Boolean> func, String key, Object value) {
public Maps setIf(Function<Maps, Boolean> func, String key, Object value) {
if (func.apply(this)) put(key, value);
return this;
}
public Maps putIfNotNull(String key, Object value) {
public Maps setIfNotNull(String key, Object value) {
if (value != null) put(key, value);
return this;
}
public Maps putIfNotEmpty(String key, Object value) {
public Maps setIfNotEmpty(String key, Object value) {
if (!isNullOrEmpty(value)) {
put(key, value);
}
return this;
}
private boolean isNullOrEmpty(Object value) {
public Maps setIfContainsKey(String checkKey, String key, Object value) {
if (this.containsKey(checkKey)) {
this.put(key, value);
}
return this;
}
public Maps setIfNotContainsKey(String checkKey, String key, Object value) {
if (!this.containsKey(checkKey)) {
this.put(key, value);
}
return this;
}
public String toJSON() {
return JSON.toJSONString(this);
}
private static boolean isNullOrEmpty(Object value) {
if (value == null) {
return true;
}
@ -125,22 +150,4 @@ public class Maps extends HashMap<String, Object> {
}
public Maps putIfContainsKey(String checkKey, String key, Object value) {
if (this.containsKey(checkKey)) {
this.put(key, value);
}
return this;
}
public Maps putIfNotContainsKey(String checkKey, String key, Object value) {
if (!this.containsKey(checkKey)) {
this.put(key, value);
}
return this;
}
public String toJSON() {
return JSON.toJSONString(this);
}
}

View File

@ -11,7 +11,7 @@ public class MapsTest {
@Test
public void testMaps() {
Map<String, Object> map1 = Maps.of("key", "value")
.putChild("options.aaa", 1);
.setChild("options.aaa", 1);
Assert.assertEquals(1, ((Map<?, ?>) map1.get("options")).get("aaa"));
System.out.println(map1);

View File

@ -37,8 +37,8 @@ public class GiteeImageModel implements ImageModel {
headers.put("Authorization", "Bearer " + config.getApiKey());
String payload = Maps.of("inputs", request.getPrompt())
.putIfNotNull("width", request.getWidth())
.putIfNotNull("height", request.getHeight())
.setIfNotNull("width", request.getWidth())
.setIfNotNull("height", request.getHeight())
.toJSON();
String url = config.getEndpoint() + "/api/serverless/" + config.getModel() + "/text-to-image";

View File

@ -42,9 +42,9 @@ public class OpenAiImageModel implements ImageModel {
headers.put("Authorization", "Bearer " + config.getApiKey());
String payload = Maps.of("model", config.getModel())
.put("prompt", request.getPrompt())
.putIfNotNull("n", request.getN())
.put("size", request.getSize())
.set("prompt", request.getPrompt())
.setIfNotNull("n", request.getN())
.set("size", request.getSize())
.toJSON();

View File

@ -41,11 +41,11 @@ public class SiliconImageModel implements ImageModel {
headers.put("Authorization", "Bearer " + config.getApiKey());
String payload = Maps.of("prompt", request.getPrompt())
.putIfNotEmpty("negative_prompt", request.getNegativePrompt())
.putOrDefault("image_size", request.getSize(), config.getImageSize())
.putOrDefault("batch_size", request.getN(), 1)
.putOrDefault("num_inference_steps", request.getOption("num_inference_steps"), config.getNumInferenceSteps())
.putOrDefault("guidance_scale", request.getOption("guidance_scale"), config.getGuidanceScale())
.setIfNotEmpty("negative_prompt", request.getNegativePrompt())
.setOrDefault("image_size", request.getSize(), config.getImageSize())
.setOrDefault("batch_size", request.getN(), 1)
.setOrDefault("num_inference_steps", request.getOption("num_inference_steps"), config.getNumInferenceSteps())
.setOrDefault("guidance_scale", request.getOption("guidance_scale"), config.getGuidanceScale())
.toJSON();
String url = config.getEndpoint() + SiliconflowImageModels.getPath(config.getModel());

View File

@ -25,12 +25,12 @@ public class SiliconflowImageModels {
private static Map<String, Object> modelsPathMapping = Maps
.of(flux_1_schnell, "/v1/black-forest-labs/FLUX.1-schnell/text-to-image")
.put(Stable_Diffusion_3, "/v1/stabilityai/stable-diffusion-3-medium/text-to-image")
.put(Stable_Diffusion_XL, "/v1/stabilityai/stable-diffusion-xl-base-1.0/text-to-image")
.put(Stable_Diffusion_2_1, "/v1/stabilityai/stable-diffusion-2-1/text-to-image")
.put(Stable_Diffusion_Turbo, "/v1/stabilityai/sd-turbo/text-to-image")
.put(Stable_Diffusion_XL_Turbo, "/v1/stabilityai/sdxl-turbo/text-to-image")
.put(Stable_Diffusion_XL_Lighting, "/v1/ByteDance/SDXL-Lightning/text-to-image")
.set(Stable_Diffusion_3, "/v1/stabilityai/stable-diffusion-3-medium/text-to-image")
.set(Stable_Diffusion_XL, "/v1/stabilityai/stable-diffusion-xl-base-1.0/text-to-image")
.set(Stable_Diffusion_2_1, "/v1/stabilityai/stable-diffusion-2-1/text-to-image")
.set(Stable_Diffusion_Turbo, "/v1/stabilityai/sd-turbo/text-to-image")
.set(Stable_Diffusion_XL_Turbo, "/v1/stabilityai/sdxl-turbo/text-to-image")
.set(Stable_Diffusion_XL_Lighting, "/v1/ByteDance/SDXL-Lightning/text-to-image")
;
public static String getPath(String model) {

View File

@ -43,7 +43,7 @@ public class StabilityImageModel implements ImageModel {
headers.put("Authorization", "Bearer " + config.getApiKey());
Map<String, Object> payload = Maps.of("prompt", request.getPrompt())
.putIfNotNull("output_format", "jpeg");
.setIfNotNull("output_format", "jpeg");
String url = config.getEndpoint() + "/v2beta/stable-image/generate/sd3";

View File

@ -62,7 +62,7 @@ public class ChatglmLlm extends BaseLlm<ChatglmLlmConfig> {
String endpoint = config.getEndpoint();
String payload = Maps.of("model", "embedding-2")
.put("input", document.getContent())
.set("input", document.getContent())
.toJSON();
String response = httpClient.post(endpoint + "/api/paas/v4/embeddings", headers, payload);

View File

@ -95,14 +95,14 @@ public class ChatglmLlmUtil {
List<Message> messages = prompt.toMessages();
HumanMessage humanMessage = (HumanMessage) CollectionUtil.lastItem(messages);
return Maps.of("model", config.getModel())
.put("messages", promptFormat.toMessagesJsonObject(messages))
.putIf(withStream, "stream", true)
.putIfNotEmpty("tools", promptFormat.toFunctionsJsonObject(humanMessage))
.putIfContainsKey("tools", "tool_choice", humanMessage.getToolChoice())
.putIfNotNull("top_p", options.getTopP())
.putIfNotEmpty("stop", options.getStop())
.putIf(map -> !map.containsKey("tools") && options.getTemperature() > 0, "temperature", options.getTemperature())
.putIf(map -> !map.containsKey("tools") && options.getMaxTokens() != null, "max_tokens", options.getMaxTokens())
.set("messages", promptFormat.toMessagesJsonObject(messages))
.setIf(withStream, "stream", true)
.setIfNotEmpty("tools", promptFormat.toFunctionsJsonObject(humanMessage))
.setIfContainsKey("tools", "tool_choice", humanMessage.getToolChoice())
.setIfNotNull("top_p", options.getTopP())
.setIfNotEmpty("stop", options.getStop())
.setIf(map -> !map.containsKey("tools") && options.getTemperature() > 0, "temperature", options.getTemperature())
.setIf(map -> !map.containsKey("tools") && options.getMaxTokens() != null, "max_tokens", options.getMaxTokens())
.toJSON();
}

View File

@ -62,12 +62,12 @@ public class CozeLlmUtil {
public static String promptToPayload(Prompt prompt, String botId, String userId, Map<String, String> customVariables, boolean stream) {
List<Message> messages = prompt.toMessages();
return Maps.of()
.put("bot_id", botId)
.put("user_id", userId)
.put("auto_save_history", true)
.put("additional_messages", promptFormat.toMessagesJsonObject(messages))
.put("stream", stream)
.putIf(customVariables != null, "custom_variables", customVariables)
.set("bot_id", botId)
.set("user_id", userId)
.set("auto_save_history", true)
.set("additional_messages", promptFormat.toMessagesJsonObject(messages))
.set("stream", stream)
.setIf(customVariables != null, "custom_variables", customVariables)
.toJSON();
}

View File

@ -40,17 +40,17 @@ public class GiteeAiLLmUtil {
List<Message> messages = prompt.toMessages();
HumanMessage humanMessage = (HumanMessage) CollectionUtil.lastItem(messages);
return Maps.of()
.put("messages", promptFormat.toMessagesJsonObject(messages))
.putIf(withStream, "stream", withStream)
.putIfNotNull("max_tokens", options.getMaxTokens())
.putIfNotNull("temperature", options.getTemperature())
.putIfNotEmpty("tools", promptFormat.toFunctionsJsonObject(humanMessage))
.putIfContainsKey("tools", "tool_choice", humanMessage.getToolChoice())
.putIfNotNull("top_p", options.getTopP())
.putIfNotNull("top_k", options.getTopK())
.putIfNotEmpty("stop", options.getStop())
.putIf(map -> !map.containsKey("tools") && options.getTemperature() > 0, "temperature", options.getTemperature())
.putIf(map -> !map.containsKey("tools") && options.getMaxTokens() != null, "max_tokens", options.getMaxTokens())
.set("messages", promptFormat.toMessagesJsonObject(messages))
.setIf(withStream, "stream", withStream)
.setIfNotNull("max_tokens", options.getMaxTokens())
.setIfNotNull("temperature", options.getTemperature())
.setIfNotEmpty("tools", promptFormat.toFunctionsJsonObject(humanMessage))
.setIfContainsKey("tools", "tool_choice", humanMessage.getToolChoice())
.setIfNotNull("top_p", options.getTopP())
.setIfNotNull("top_k", options.getTopK())
.setIfNotEmpty("stop", options.getStop())
.setIf(map -> !map.containsKey("tools") && options.getTemperature() > 0, "temperature", options.getTemperature())
.setIf(map -> !map.containsKey("tools") && options.getMaxTokens() != null, "max_tokens", options.getMaxTokens())
.toJSON();
}

View File

@ -32,23 +32,22 @@ public class MoonshotLlmUtil {
}
/**
* 将给定的Prompt转换为特定的payload格式用于与语言模型进行交互
*
* @param prompt 需要转换为payload的Prompt对象包含了对话的具体内容
* @param config 用于配置Moonshot LLM行为的配置对象例如指定使用的模型
* @param isStream 指示payload是否应该以流的形式进行处理
* @param prompt 需要转换为payload的Prompt对象包含了对话的具体内容
* @param config 用于配置Moonshot LLM行为的配置对象例如指定使用的模型
* @param isStream 指示payload是否应该以流的形式进行处理
* @param chatOptions 包含了对话选项的配置如温度和最大令牌数等
* @return 返回一个字符串形式的payload供进一步的处理或发送给语言模型
*/
public static String promptToPayload(Prompt prompt, MoonshotLlmConfig config, Boolean isStream, ChatOptions chatOptions) {
// 构建payload的根结构包括模型信息流式处理标志对话选项和格式化后的prompt消息
return Maps.of("model", config.getModel())
.put("stream", isStream)
.put("temperature", chatOptions.getTemperature())
.put("max_tokens", chatOptions.getMaxTokens())
.put("messages", promptFormat.toMessagesJsonObject(prompt.toMessages()))
.set("stream", isStream)
.set("temperature", chatOptions.getTemperature())
.set("max_tokens", chatOptions.getMaxTokens())
.set("messages", promptFormat.toMessagesJsonObject(prompt.toMessages()))
.toJSON();
}
}

View File

@ -61,7 +61,7 @@ public class OllamaLlm extends BaseLlm<OllamaLlmConfig> {
}
String payload = Maps.of("model", options.getModelOrDefault(config.getModel()))
.put("input", document.getContent())
.set("input", document.getContent())
.toJSON();
String endpoint = config.getEndpoint();

View File

@ -104,14 +104,14 @@ public class OllamaLlmUtil {
public static String promptToPayload(Prompt prompt, OllamaLlmConfig config, ChatOptions options, boolean stream) {
List<Message> messages = prompt.toMessages();
return Maps.of("model", config.getModel())
.put("messages", promptFormat.toMessagesJsonObject(messages))
.putIf(!stream, "stream", stream)
.putIfNotEmpty("tools", promptFormat.toFunctionsJsonObject(messages.get(messages.size() - 1)))
.putIfNotEmpty("options.seed", options.getSeed())
.putIfNotEmpty("options.top_k", options.getTopK())
.putIfNotEmpty("options.top_p", options.getTopP())
.putIfNotEmpty("options.temperature", options.getTemperature())
.putIfNotEmpty("options.stop", options.getStop())
.set("messages", promptFormat.toMessagesJsonObject(messages))
.setIf(!stream, "stream", stream)
.setIfNotEmpty("tools", promptFormat.toFunctionsJsonObject(messages.get(messages.size() - 1)))
.setIfNotEmpty("options.seed", options.getSeed())
.setIfNotEmpty("options.top_k", options.getTopK())
.setIfNotEmpty("options.top_p", options.getTopP())
.setIfNotEmpty("options.temperature", options.getTemperature())
.setIfNotEmpty("options.stop", options.getStop())
.toJSON();
}

View File

@ -42,8 +42,8 @@ public class OpenAiLLmUtil {
public static String promptToEmbeddingsPayload(Document text, EmbeddingOptions options, OpenAiLlmConfig config) {
// https://platform.openai.com/docs/api-reference/making-requests
return Maps.of("model", options.getModelOrDefault(config.getDefaultEmbeddingModel()))
.put("encoding_format", "float")
.put("input", text.getContent())
.set("encoding_format", "float")
.set("input", text.getContent())
.toJSON();
}
@ -52,14 +52,14 @@ public class OpenAiLLmUtil {
List<Message> messages = prompt.toMessages();
HumanMessage humanMessage = (HumanMessage) CollectionUtil.lastItem(messages);
return Maps.of("model", config.getModel())
.put("messages", promptFormat.toMessagesJsonObject(messages))
.putIf(withStream, "stream", true)
.putIfNotEmpty("tools", promptFormat.toFunctionsJsonObject(humanMessage))
.putIfContainsKey("tools", "tool_choice", humanMessage.getToolChoice())
.putIfNotNull("top_p", options.getTopP())
.putIfNotEmpty("stop", options.getStop())
.putIf(map -> !map.containsKey("tools") && options.getTemperature() > 0, "temperature", options.getTemperature())
.putIf(map -> !map.containsKey("tools") && options.getMaxTokens() != null, "max_tokens", options.getMaxTokens())
.set("messages", promptFormat.toMessagesJsonObject(messages))
.setIf(withStream, "stream", true)
.setIfNotEmpty("tools", promptFormat.toFunctionsJsonObject(humanMessage))
.setIfContainsKey("tools", "tool_choice", humanMessage.getToolChoice())
.setIfNotNull("top_p", options.getTopP())
.setIfNotEmpty("stop", options.getStop())
.setIf(map -> !map.containsKey("tools") && options.getTemperature() > 0, "temperature", options.getTemperature())
.setIf(map -> !map.containsKey("tools") && options.getMaxTokens() != null, "max_tokens", options.getMaxTokens())
.toJSON();
}

View File

@ -92,14 +92,14 @@ public class QwenLlmUtil {
List<Message> messages = prompt.toMessages();
return Maps.of("model", config.getModel())
.put("input", Maps.of("messages", promptFormat.toMessagesJsonObject(messages)))
.put("parameters", Maps.of("result_format", "message")
.putIfNotEmpty("tools", promptFormat.toFunctionsJsonObject(messages.get(messages.size() - 1)))
.putIf(map -> !map.containsKey("tools") && options.getTemperature() > 0, "temperature", options.getTemperature())
.putIf(map -> !map.containsKey("tools") && options.getMaxTokens() != null, "max_tokens", options.getMaxTokens())
.putIfNotNull("top_p", options.getTopP())
.putIfNotNull("top_k", options.getTopK())
.putIfNotEmpty("stop", options.getStop())
.set("input", Maps.of("messages", promptFormat.toMessagesJsonObject(messages)))
.set("parameters", Maps.of("result_format", "message")
.setIfNotEmpty("tools", promptFormat.toFunctionsJsonObject(messages.get(messages.size() - 1)))
.setIf(map -> !map.containsKey("tools") && options.getTemperature() > 0, "temperature", options.getTemperature())
.setIf(map -> !map.containsKey("tools") && options.getMaxTokens() != null, "max_tokens", options.getMaxTokens())
.setIfNotNull("top_p", options.getTopP())
.setIfNotNull("top_k", options.getTopK())
.setIfNotEmpty("stop", options.getStop())
).toJSON();
}
@ -108,7 +108,7 @@ public class QwenLlmUtil {
List<String> list = new ArrayList<>();
list.add(text.getContent());
return Maps.of("model", config.getModel())
.put("input", Maps.of("texts", list))
.set("input", Maps.of("texts", list))
.toJSON();
}

View File

@ -57,8 +57,8 @@ public class SparkLlmUtil {
}
Maps builder = Maps.of("name", function.getName())
.put("description", function.getDescription())
.put("parameters", Maps.of("type", "object").put("properties", propertiesMap).put("required", requiredProperties));
.set("description", function.getDescription())
.set("parameters", Maps.of("type", "object").set("properties", propertiesMap).set("required", requiredProperties));
functionsJsonArray.add(builder);
}
}
@ -132,13 +132,13 @@ public class SparkLlmUtil {
List<Message> messages = prompt.toMessages();
Maps root = Maps.of("header", Maps.of("app_id", config.getAppId()).put("uid", UUID.randomUUID()));
root.put("parameter", Maps.of("chat", Maps.of("domain", getDomain(config.getVersion()))
.putIf(options.getTemperature() > 0, "temperature", options.getTemperature())
.putIf(options.getMaxTokens() != null, "max_tokens", options.getMaxTokens())
.putIfNotNull("top_k", options.getTopK())
.setIf(options.getTemperature() > 0, "temperature", options.getTemperature())
.setIf(options.getMaxTokens() != null, "max_tokens", options.getMaxTokens())
.setIfNotNull("top_k", options.getTopK())
)
);
root.put("payload", Maps.of("message", Maps.of("text", promptFormat.toMessagesJsonObject(messages)))
.putIfNotEmpty("functions", Maps.ofNotNull("text", promptFormat.toFunctionsJsonObject(messages.get(messages.size() - 1))))
.setIfNotEmpty("functions", Maps.ofNotNull("text", promptFormat.toFunctionsJsonObject(messages.get(messages.size() - 1))))
);
return JSON.toJSONString(root);
}
@ -192,9 +192,9 @@ public class SparkLlmUtil {
String text = Maps.of("messages", Collections.singletonList(Maps.of("content", document.getContent()).put("role", "user"))).toJSON();
String textBase64 = Base64.getEncoder().encodeToString(text.getBytes());
return Maps.of("header", Maps.of("app_id", config.getAppId()).put("uid", UUID.randomUUID()).put("status", 3))
.put("parameter", Maps.of("emb", Maps.of("domain", "para").put("feature", Maps.of("encoding", "utf8").put("compress", "raw").put("format", "plain"))))
.put("payload", Maps.of("messages", Maps.of("encoding", "utf8").put("compress", "raw").put("format", "json").put("status", 3).put("text", textBase64)))
return Maps.of("header", Maps.of("app_id", config.getAppId()).set("uid", UUID.randomUUID()).set("status", 3))
.set("parameter", Maps.of("emb", Maps.of("domain", "para").put("feature", Maps.of("encoding", "utf8").set("compress", "raw").set("format", "plain"))))
.set("payload", Maps.of("messages", Maps.of("encoding", "utf8").set("compress", "raw").set("format", "json").set("status", 3).set("text", textBase64)))
.toJSON();
}