diff --git a/api/core/types/chat.go b/api/core/types/chat.go
index f9f20f1..95f46c0 100644
--- a/api/core/types/chat.go
+++ b/api/core/types/chat.go
@@ -2,9 +2,9 @@ package types
 
 // ApiRequest is the API request payload
 type ApiRequest struct {
-	Model       string        `json:"model"`
+	Model       string        `json:"model,omitempty"` // for Baidu ERNIE Bot (文心一言) compatibility
 	Temperature float32       `json:"temperature"`
-	MaxTokens   int           `json:"max_tokens"`
+	MaxTokens   int           `json:"max_tokens,omitempty"` // for Baidu ERNIE Bot (文心一言) compatibility
 	Stream      bool          `json:"stream"`
 	Messages    []interface{} `json:"messages,omitempty"`
 	Prompt      []interface{} `json:"prompt,omitempty"` // for ChatGLM compatibility
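A note on the two `omitempty` additions above: without them, a zero-valued `Model` or `MaxTokens` is still serialized as `"model":""` and `"max_tokens":0`, which the commit's comments say the Baidu ERNIE Bot endpoint does not tolerate. A minimal, self-contained sketch of the difference (struct and field subset are illustrative):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// With omitempty, zero-valued fields are dropped from the JSON body
// instead of being sent as "model":"" and "max_tokens":0.
type apiRequest struct {
	Model     string `json:"model,omitempty"`
	MaxTokens int    `json:"max_tokens,omitempty"`
	Stream    bool   `json:"stream"`
}

func main() {
	body, _ := json.Marshal(apiRequest{Stream: true})
	fmt.Println(string(body)) // {"stream":true} — no empty model/max_tokens keys
}
```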
diff --git a/api/core/types/config.go b/api/core/types/config.go
index 34fdbdc..9e4d957 100644
--- a/api/core/types/config.go
+++ b/api/core/types/config.go
@@ -90,6 +90,7 @@ type Platform string
 const OpenAI = Platform("OpenAI")
 const Azure = Platform("Azure")
 const ChatGLM = Platform("ChatGLM")
+const Baidu = Platform("Baidu")
 
 // UserChatConfig holds the user's chat configuration
 type UserChatConfig struct {
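Because `Platform` is a defined string type, the new constant works directly as a map key, which is how the handler below resolves per-user API keys (`userVo.ChatConfig.ApiKeys[session.Model.Platform]`). A small sketch, with the map shape assumed from that lookup:

```go
package main

import "fmt"

type Platform string

const (
	OpenAI = Platform("OpenAI")
	Baidu  = Platform("Baidu")
)

func main() {
	// Assumed shape: the handler indexes per-user keys by Platform,
	// so supporting a new provider is a new constant plus a map entry.
	apiKeys := map[Platform]string{
		OpenAI: "sk-...",
		Baidu:  "id.secret", // this key is later split on "." in getBaiduToken
	}
	fmt.Println(apiKeys[Baidu])
}
```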
diff --git a/api/handler/baidu_handler.go b/api/handler/baidu_handler.go
new file mode 100644
index 0000000..ba37674
--- /dev/null
+++ b/api/handler/baidu_handler.go
@@ -0,0 +1,235 @@
+package handler
+
+import (
+	"bufio"
+	"chatplus/core/types"
+	"chatplus/store/model"
+	"chatplus/store/vo"
+	"chatplus/utils"
+	"context"
+	"encoding/json"
+	"fmt"
+	"github.com/golang-jwt/jwt/v5"
+	"gorm.io/gorm"
+	"io"
+	"strings"
+	"time"
+	"unicode/utf8"
+)
+
+// sendBaiduMessage sends the message to the Baidu ERNIE Bot (文心一言) API,
+// reads the result, and pushes it to the client over WebSocket.
+func (h *ChatHandler) sendBaiduMessage(
+	chatCtx []interface{},
+	req types.ApiRequest,
+	userVo vo.User,
+	ctx context.Context,
+	session *types.ChatSession,
+	role model.ChatRole,
+	prompt string,
+	ws *types.WsClient) error {
+	promptCreatedAt := time.Now() // record when the prompt was asked
+	start := time.Now()
+	var apiKey = userVo.ChatConfig.ApiKeys[session.Model.Platform]
+	response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
+	logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start))
+	if err != nil {
+		if strings.Contains(err.Error(), "context canceled") {
+			logger.Info("用户取消了请求:", prompt)
+			return nil
+		} else if strings.Contains(err.Error(), "no available key") {
+			utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!")
+			return nil
+		} else {
+			logger.Error(err)
+		}
+
+		utils.ReplyMessage(ws, ErrorMsg)
+		utils.ReplyMessage(ws, "![](/images/wx.png)")
+		return err
+	} else {
+		defer response.Body.Close()
+	}
+
+	contentType := response.Header.Get("Content-Type")
+	if strings.Contains(contentType, "text/event-stream") {
+		replyCreatedAt := time.Now() // record the reply time
+		// read the chunked SSE messages in a loop
+		var message = types.Message{}
+		var contents = make([]string, 0)
+		var event, content string
+		scanner := bufio.NewScanner(response.Body)
+		for scanner.Scan() {
+			line := scanner.Text()
+			if len(line) < 5 || strings.HasPrefix(line, "id:") {
+				continue
+			}
+			if strings.HasPrefix(line, "event:") {
+				event = line[6:]
+				continue
+			}
+
+			if strings.HasPrefix(line, "data:") {
+				content = line[5:]
+			}
+			switch event {
+			case "add":
+				if len(contents) == 0 {
+					utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
+				}
+				utils.ReplyChunkMessage(ws, types.WsMessage{
+					Type:    types.WsMiddle,
+					Content: utils.InterfaceToString(content),
+				})
+				contents = append(contents, content)
+			case "finish":
+				break
+			case "error":
+				utils.ReplyMessage(ws, fmt.Sprintf("**调用文心一言 API 出错:%s**", content))
+				break
+			case "interrupted":
+				utils.ReplyMessage(ws, "**调用文心一言 API 出错,当前输出被中断!**")
+			}
+
+		} // end for
+
+		if err := scanner.Err(); err != nil {
+			if strings.Contains(err.Error(), "context canceled") {
+				logger.Info("用户取消了请求:", prompt)
+			} else {
+				logger.Error("信息读取出错:", err)
+			}
+		}
+
+		// the message was sent successfully
+		if len(contents) > 0 {
+			// update the user's remaining call count
+			if userVo.ChatConfig.ApiKeys[session.Model.Platform] == "" {
+				h.db.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("calls", gorm.Expr("calls - ?", 1))
+			}
+
+			if message.Role == "" {
+				message.Role = "assistant"
+			}
+			message.Content = strings.Join(contents, "")
+			useMsg := types.Message{Role: "user", Content: prompt}
+
+			// update the context messages; function calls do not update the context
+			if h.App.ChatConfig.EnableContext {
+				chatCtx = append(chatCtx, useMsg)  // prompt message
+				chatCtx = append(chatCtx, message) // reply message
+				h.App.ChatContexts.Put(session.ChatId, chatCtx)
+			}
+
+			// append the chat history records
+			if h.App.ChatConfig.EnableHistory {
+				// for prompt
+				promptToken, err := utils.CalcTokens(prompt, req.Model)
+				if err != nil {
+					logger.Error(err)
+				}
+				historyUserMsg := model.HistoryMessage{
+					UserId:     userVo.Id,
+					ChatId:     session.ChatId,
+					RoleId:     role.Id,
+					Type:       types.PromptMsg,
+					Icon:       userVo.Avatar,
+					Content:    prompt,
+					Tokens:     promptToken,
+					UseContext: true,
+				}
+				historyUserMsg.CreatedAt = promptCreatedAt
+				historyUserMsg.UpdatedAt = promptCreatedAt
+				res := h.db.Save(&historyUserMsg)
+				if res.Error != nil {
+					logger.Error("failed to save prompt history message: ", res.Error)
+				}
+
+				// for reply
+				// calculate the total number of tokens consumed by this conversation
+				replyToken, _ := utils.CalcTokens(message.Content, req.Model)
+				totalTokens := replyToken + getTotalTokens(req)
+				historyReplyMsg := model.HistoryMessage{
+					UserId:     userVo.Id,
+					ChatId:     session.ChatId,
+					RoleId:     role.Id,
+					Type:       types.ReplyMsg,
+					Icon:       role.Icon,
+					Content:    message.Content,
+					Tokens:     totalTokens,
+					UseContext: true,
+				}
+				historyReplyMsg.CreatedAt = replyCreatedAt
+				historyReplyMsg.UpdatedAt = replyCreatedAt
+				res = h.db.Create(&historyReplyMsg)
+				if res.Error != nil {
+					logger.Error("failed to save reply history message: ", res.Error)
+				}
+				// update the user's token usage
+				h.db.Model(&model.User{}).Where("id = ?", userVo.Id).
+					UpdateColumn("total_tokens", gorm.Expr("total_tokens + ?", totalTokens))
+			}
+
+			// save the current chat session
+			var chatItem model.ChatItem
+			res := h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
+			if res.Error != nil {
+				chatItem.ChatId = session.ChatId
+				chatItem.UserId = session.UserId
+				chatItem.RoleId = role.Id
+				chatItem.ModelId = session.Model.Id
+				if utf8.RuneCountInString(prompt) > 30 {
+					chatItem.Title = string([]rune(prompt)[:30]) + "..."
+				} else {
+					chatItem.Title = prompt
+				}
+				h.db.Create(&chatItem)
+			}
+		}
+	} else {
+		body, err := io.ReadAll(response.Body)
+		if err != nil {
+			return fmt.Errorf("error with reading response: %v", err)
+		}
+
+		var res struct {
+			Code    int    `json:"code"`
+			Success bool   `json:"success"`
+			Msg     string `json:"msg"`
+		}
+		err = json.Unmarshal(body, &res)
+		if err != nil {
+			return fmt.Errorf("error with decode response: %v", err)
+		}
+		if !res.Success {
+			utils.ReplyMessage(ws, "请求文心一言失败:"+res.Msg)
+		}
+	}
+
+	return nil
+}
+
+func (h *ChatHandler) getBaiduToken(apiKey string) (string, error) {
+	ctx := context.Background()
+	tokenString, err := h.redis.Get(ctx, apiKey).Result()
+	if err == nil {
+		return tokenString, nil
+	}
+
+	expr := time.Hour * 2
+	key := strings.Split(apiKey, ".")
+	if len(key) != 2 {
+		return "", fmt.Errorf("invalid api key: %s", apiKey)
+	}
+	token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
+		"api_key":   key[0],
+		"timestamp": time.Now().Unix(),
+		"exp":       time.Now().Add(expr).Add(time.Second * 10).Unix(),
+	})
+	token.Header["alg"] = "HS256"
+	token.Header["sign_type"] = "SIGN"
+	delete(token.Header, "typ")
+	// sign and get the complete encoded token as a string using the secret
+	tokenString, err = token.SignedString([]byte(key[1]))
+	h.redis.Set(ctx, apiKey, tokenString, expr)
+	return tokenString, err
+}
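One Go subtlety in the `for scanner.Scan()` loop above: the `break` statements in the `finish` and `error` cases exit only the `switch`, not the read loop, so scanning continues until the stream closes. If the intent is to stop on a terminal event, a labeled break is needed. A minimal sketch using the same `event:`/`data:` framing (the sample stream is made up):

```go
package main

import (
	"bufio"
	"fmt"
	"strings"
)

func main() {
	// Stand-in for the SSE body the handler reads from response.Body.
	stream := "event:add\ndata:hello \nevent:add\ndata:world\nevent:finish\ndata:\n"
	scanner := bufio.NewScanner(strings.NewReader(stream))

	var event string
read: // the label lets "break read" leave the loop, not just the switch
	for scanner.Scan() {
		line := scanner.Text()
		switch {
		case strings.HasPrefix(line, "event:"):
			event = line[6:]
			continue
		case strings.HasPrefix(line, "data:"):
			content := line[5:]
			switch event {
			case "add":
				fmt.Print(content)
			case "finish", "error", "interrupted":
				break read // a bare break here would only exit the switch
			}
		}
	}
	fmt.Println()
}
```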
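Also worth flagging: `getBaiduToken` reuses the ChatGLM-style `id.secret` JWT signing scheme verbatim, and the non-stream branch parses a ChatGLM-shaped `code`/`success`/`msg` error body. Baidu's Qianfan platform itself documents an OAuth-style exchange of an API Key / Secret Key pair for an `access_token`. A sketch of that exchange follows; the endpoint and field names are from Baidu's docs as best I recall and should be treated as assumptions to verify:

```go
package main

import (
	"encoding/json"
	"fmt"
	"net/http"
	"net/url"
)

// baiduToken exchanges an API Key / Secret Key pair for an access_token.
// Endpoint and parameter names follow Baidu Qianfan's documented OAuth
// flow, but check them against the current documentation.
func baiduToken(apiKey, secretKey string) (string, error) {
	q := url.Values{}
	q.Set("grant_type", "client_credentials")
	q.Set("client_id", apiKey)
	q.Set("client_secret", secretKey)
	resp, err := http.Post("https://aip.baidubce.com/oauth/2.0/token?"+q.Encode(), "application/json", nil)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	var res struct {
		AccessToken string `json:"access_token"`
		ExpiresIn   int    `json:"expires_in"` // seconds; cache below this TTL, as getBaiduToken does with Redis
	}
	if err := json.NewDecoder(resp.Body).Decode(&res); err != nil {
		return "", err
	}
	if res.AccessToken == "" {
		return "", fmt.Errorf("no access_token in response")
	}
	return res.AccessToken, nil
}

func main() {
	token, err := baiduToken("your-api-key", "your-secret-key")
	fmt.Println(token, err)
}
```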
diff --git a/api/handler/chat_handler.go b/api/handler/chat_handler.go
index b9241b1..c29cfe6 100644
--- a/api/handler/chat_handler.go
+++ b/api/handler/chat_handler.go
@@ -196,7 +196,10 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
 		req.Temperature = h.App.ChatConfig.ChatGML.Temperature
 		req.MaxTokens = h.App.ChatConfig.ChatGML.MaxTokens
 		break
-	default:
+	case types.Baidu:
+		req.Temperature = h.App.ChatConfig.OpenAI.Temperature
+		// TODO: only the ERNIE-Bot-turbo model is supported for now; the ERNIE-Bot model would need function support
+	case types.OpenAI:
 		req.Temperature = h.App.ChatConfig.OpenAI.Temperature
 		req.MaxTokens = h.App.ChatConfig.OpenAI.MaxTokens
 		var functions = make([]types.Function, 0)
@@ -207,6 +210,10 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
 			functions = append(functions, f)
 		}
 		req.Functions = functions
+	default:
+		utils.ReplyMessage(ws, "不支持的平台:"+session.Model.Platform+",请联系管理员!")
+		utils.ReplyMessage(ws, "![](/images/wx.png)")
+		return nil
 	}
 
 	// load the chat context
diff --git a/web/src/views/admin/ApiKey.vue b/web/src/views/admin/ApiKey.vue
index 768bedd..1ad7cd6 100644
--- a/web/src/views/admin/ApiKey.vue
+++ b/web/src/views/admin/ApiKey.vue
@@ -39,12 +39,16 @@
 [template hunk; markup lost in extraction: the platform <el-option> entries
  now render {{ item.name }} and bind item.value, instead of the bare
  {{ item }} string of the old flat array]
@@ -82,7 +86,13 @@ const rules = reactive({
 const loading = ref(true)
 const formRef = ref(null)
 const title = ref("")
-const platforms = ref(["Azure", "OpenAI", "ChatGLM"])
+const platforms = ref([
+  {name: "【清华智普】ChatGLM", value: "ChatGLM"},
+  {name: "【百度】文心一言", value: "Baidu"},
+  {name: "【微软】Azure", value: "Azure"},
+  {name: "【OpenAI】ChatGPT", value: "OpenAI"},
+
+])
 
 // fetch the data
 httpGet('/api/admin/apikey/list').then((res) => {
diff --git a/web/src/views/admin/SysConfig.vue b/web/src/views/admin/SysConfig.vue
index a273fe4..663153a 100644
--- a/web/src/views/admin/SysConfig.vue
+++ b/web/src/views/admin/SysConfig.vue
@@ -90,13 +90,25 @@
 [template hunk; markup lost in extraction: a "文心一言" settings block is added
  alongside the existing ChatGLM one, bound to chat.baidu, with inputs for the
  API 地址, 温度 (helper text: "值越大 AI 回答越发散,值越小回答越保守,建议保持默认值"),
  and MaxTokens; the 保存 (save) button area is unchanged]
@@ -116,7 +128,8 @@ const system = ref({models: []})
 const chat = ref({
   open_ai: {api_url: "", temperature: 1, max_tokens: 1024},
   azure: {api_url: "", temperature: 1, max_tokens: 1024},
-  chat_gml: {api_url: "", temperature: 1, max_tokens: 1024},
+  chat_gml: {api_url: "", temperature: 0.95, max_tokens: 1024},
+  baidu: {api_url: "", temperature: 0.95, max_tokens: 1024},
   context_deep: 0,
   enable_context: true,
   enable_history: true,
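The new `baidu` defaults imply a matching entry in the server-side `ChatConfig`, which this diff does not show (the handler reads `h.App.ChatConfig.OpenAI.Temperature` and friends, and `SysConfig.vue` posts `open_ai`/`azure`/`chat_gml`/`baidu` keys). A sketch of the assumed shape; `ModelAPIConfig` and its field names are inferred from those JSON keys and are not part of the patch:

```go
package types

// Sketch of the assumed server-side counterpart to the Vue defaults above.
// Field and key names are inferred from the frontend JSON and from handler
// accesses like ChatConfig.OpenAI.Temperature; not part of this diff.
type ModelAPIConfig struct {
	ApiURL      string  `json:"api_url"`
	Temperature float32 `json:"temperature"`
	MaxTokens   int     `json:"max_tokens"`
}

type ChatConfig struct {
	OpenAI        ModelAPIConfig `json:"open_ai"`
	Azure         ModelAPIConfig `json:"azure"`
	ChatGML       ModelAPIConfig `json:"chat_gml"`
	Baidu         ModelAPIConfig `json:"baidu"` // new: defaults to temperature 0.95, max_tokens 1024
	ContextDeep   int            `json:"context_deep"`
	EnableContext bool           `json:"enable_context"`
	EnableHistory bool           `json:"enable_history"`
}
```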