diff --git a/api/core/types/chat.go b/api/core/types/chat.go index 88f6233..40987c0 100644 --- a/api/core/types/chat.go +++ b/api/core/types/chat.go @@ -67,8 +67,10 @@ var ModelToTokens = map[string]int{ "gpt-3.5-turbo-16k": 16384, "gpt-4": 8192, "gpt-4-32k": 32768, - "chatglm_pro": 32768, + "chatglm_pro": 32768, // 清华智普 "chatglm_std": 16384, "chatglm_lite": 4096, "ernie_bot_turbo": 8192, // 文心一言 + "general": 8192, // 科大讯飞 + "general2": 8192, } diff --git a/api/core/types/client.go b/api/core/types/client.go index d1c80cd..6988a60 100644 --- a/api/core/types/client.go +++ b/api/core/types/client.go @@ -36,6 +36,16 @@ func (wc *WsClient) Send(message []byte) error { return wc.Conn.WriteMessage(wc.mt, message) } +func (wc *WsClient) SendJson(value interface{}) error { + wc.lock.Lock() + defer wc.lock.Unlock() + + if wc.Closed { + return ErrConClosed + } + return wc.Conn.WriteJSON(value) +} + func (wc *WsClient) Receive() (int, []byte, error) { if wc.Closed { return 0, nil, ErrConClosed diff --git a/api/handler/chatimpl/chat_handler.go b/api/handler/chatimpl/chat_handler.go index 49b51b8..4defda9 100644 --- a/api/handler/chatimpl/chat_handler.go +++ b/api/handler/chatimpl/chat_handler.go @@ -39,7 +39,12 @@ type ChatHandler struct { } func NewChatHandler(app *core.AppServer, db *gorm.DB, levelDB *store.LevelDB, redis *redis.Client, service *mj.Service) *ChatHandler { - h := ChatHandler{db: db, leveldb: levelDB, redis: redis, mjService: service} + h := ChatHandler{ + db: db, + leveldb: levelDB, + redis: redis, + mjService: service, + } h.App = app return &h } @@ -127,7 +132,11 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) { logger.Error(err) client.Close() h.App.ChatClients.Delete(sessionId) - h.App.ReqCancelFunc.Delete(sessionId) + cancelFunc := h.App.ReqCancelFunc.Get(sessionId) + if cancelFunc != nil { + cancelFunc() + h.App.ReqCancelFunc.Delete(sessionId) + } return } @@ -217,6 +226,9 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio } req.Functions = functions } + case types.XunFei: + req.Temperature = h.App.ChatConfig.XunFei.Temperature + req.MaxTokens = h.App.ChatConfig.XunFei.MaxTokens default: utils.ReplyMessage(ws, "不支持的平台:"+session.Model.Platform+",请联系管理员!") utils.ReplyMessage(ws, "![](/images/wx.png)") @@ -291,6 +303,8 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio return h.sendChatGLMMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws) case types.Baidu: return h.sendBaiduMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws) + case types.XunFei: + return h.sendXunFeiMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws) } utils.ReplyChunkMessage(ws, types.WsMessage{ diff --git a/api/handler/chatimpl/xunfei_handler.go b/api/handler/chatimpl/xunfei_handler.go index 3735dd6..156de4b 100644 --- a/api/handler/chatimpl/xunfei_handler.go +++ b/api/handler/chatimpl/xunfei_handler.go @@ -1,22 +1,54 @@ package chatimpl import ( - "bufio" "chatplus/core/types" "chatplus/store/model" "chatplus/store/vo" "chatplus/utils" "context" + "crypto/hmac" + "crypto/sha256" + "encoding/base64" "encoding/json" "fmt" + "github.com/gorilla/websocket" "gorm.io/gorm" "io" "net/http" + "net/url" "strings" "time" "unicode/utf8" ) +type xunFeiResp struct { + Header struct { + Code int `json:"code"` + Message string `json:"message"` + Sid string `json:"sid"` + Status int `json:"status"` + } `json:"header"` + Payload struct { + Choices struct { + Status int `json:"status"` + Seq int `json:"seq"` + Text []struct { + Content string `json:"content"` + Role string `json:"role"` + Index int `json:"index"` + } `json:"text"` + } `json:"choices"` + Usage struct { + Text struct { + QuestionTokens int `json:"question_tokens"` + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` + } `json:"text"` + } `json:"usage"` + } `json:"payload"` +} + // 科大讯飞消息发送实现 func (h *ChatHandler) sendXunFeiMessage( @@ -29,229 +61,261 @@ func (h *ChatHandler) sendXunFeiMessage( prompt string, ws *types.WsClient) error { promptCreatedAt := time.Now() // 记录提问时间 - start := time.Now() var apiKey = userVo.ChatConfig.ApiKeys[session.Model.Platform] - response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey) - logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start)) - if err != nil { - if strings.Contains(err.Error(), "context canceled") { - logger.Info("用户取消了请求:", prompt) - return nil - } else if strings.Contains(err.Error(), "no available key") { + if apiKey == "" { + var key model.ApiKey + res := h.db.Where("platform = ?", session.Model.Platform).Order("last_used_at ASC").First(&key) + if res.Error != nil { utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!") return nil - } else { - logger.Error(err) } - - utils.ReplyMessage(ws, ErrorMsg) - utils.ReplyMessage(ws, "![](/images/wx.png)") - return err - } else { - defer response.Body.Close() + // 更新 API KEY 的最后使用时间 + h.db.Model(&key).UpdateColumn("last_used_at", time.Now().Unix()) + apiKey = key.Value } - contentType := response.Header.Get("Content-Type") - if strings.Contains(contentType, "text/event-stream") { - replyCreatedAt := time.Now() // 记录回复时间 - // 循环读取 Chunk 消息 - var message = types.Message{} - var contents = make([]string, 0) - var content string - scanner := bufio.NewScanner(response.Body) - for scanner.Scan() { - line := scanner.Text() - if len(line) < 5 || strings.HasPrefix(line, "id:") { - continue - } + d := websocket.Dialer{ + HandshakeTimeout: 5 * time.Second, + } + key := strings.Split(apiKey, "|") + if len(key) != 3 { + utils.ReplyMessage(ws, "非法的 API KEY!") + return nil + } - if strings.HasPrefix(line, "data:") { - content = line[5:] - } - - var resp baiduResp - err := utils.JsonDecode(content, &resp) - if err != nil { - logger.Error("error with parse data line: ", err) - utils.ReplyMessage(ws, fmt.Sprintf("**解析数据行失败:%s**", err)) - break - } - - if len(contents) == 0 { - utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart}) - } - utils.ReplyChunkMessage(ws, types.WsMessage{ - Type: types.WsMiddle, - Content: utils.InterfaceToString(resp.Result), - }) - contents = append(contents, resp.Result) - - if resp.IsTruncated { - utils.ReplyMessage(ws, "AI 输出异常中断") - break - } - - if resp.IsEnd { - break - } - - } // end for - - if err := scanner.Err(); err != nil { - if strings.Contains(err.Error(), "context canceled") { - logger.Info("用户取消了请求:", prompt) - } else { - logger.Error("信息读取出错:", err) - } - } - - // 消息发送成功 - if len(contents) > 0 { - // 更新用户的对话次数 - if userVo.ChatConfig.ApiKeys[session.Model.Platform] == "" { - h.db.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("calls", gorm.Expr("calls - ?", 1)) - } - - if message.Role == "" { - message.Role = "assistant" - } - message.Content = strings.Join(contents, "") - useMsg := types.Message{Role: "user", Content: prompt} - - // 更新上下文消息,如果是调用函数则不需要更新上下文 - if h.App.ChatConfig.EnableContext { - chatCtx = append(chatCtx, useMsg) // 提问消息 - chatCtx = append(chatCtx, message) // 回复消息 - h.App.ChatContexts.Put(session.ChatId, chatCtx) - } - - // 追加聊天记录 - if h.App.ChatConfig.EnableHistory { - // for prompt - promptToken, err := utils.CalcTokens(prompt, req.Model) - if err != nil { - logger.Error(err) - } - historyUserMsg := model.HistoryMessage{ - UserId: userVo.Id, - ChatId: session.ChatId, - RoleId: role.Id, - Type: types.PromptMsg, - Icon: userVo.Avatar, - Content: prompt, - Tokens: promptToken, - UseContext: true, - } - historyUserMsg.CreatedAt = promptCreatedAt - historyUserMsg.UpdatedAt = promptCreatedAt - res := h.db.Save(&historyUserMsg) - if res.Error != nil { - logger.Error("failed to save prompt history message: ", res.Error) - } - - // for reply - // 计算本次对话消耗的总 token 数量 - replyToken, _ := utils.CalcTokens(message.Content, req.Model) - totalTokens := replyToken + getTotalTokens(req) - historyReplyMsg := model.HistoryMessage{ - UserId: userVo.Id, - ChatId: session.ChatId, - RoleId: role.Id, - Type: types.ReplyMsg, - Icon: role.Icon, - Content: message.Content, - Tokens: totalTokens, - UseContext: true, - } - historyReplyMsg.CreatedAt = replyCreatedAt - historyReplyMsg.UpdatedAt = replyCreatedAt - res = h.db.Create(&historyReplyMsg) - if res.Error != nil { - logger.Error("failed to save reply history message: ", res.Error) - } - // 更新用户信息 - h.db.Model(&model.User{}).Where("id = ?", userVo.Id). - UpdateColumn("total_tokens", gorm.Expr("total_tokens + ?", totalTokens)) - } - - // 保存当前会话 - var chatItem model.ChatItem - res := h.db.Where("chat_id = ?", session.ChatId).First(&chatItem) - if res.Error != nil { - chatItem.ChatId = session.ChatId - chatItem.UserId = session.UserId - chatItem.RoleId = role.Id - chatItem.ModelId = session.Model.Id - if utf8.RuneCountInString(prompt) > 30 { - chatItem.Title = string([]rune(prompt)[:30]) + "..." - } else { - chatItem.Title = prompt - } - h.db.Create(&chatItem) - } - } + var apiURL string + if req.Model == "generalv2" { + apiURL = strings.Replace(h.App.ChatConfig.XunFei.ApiURL, "{version}", "v2.1", 1) } else { - body, err := io.ReadAll(response.Body) + apiURL = strings.Replace(h.App.ChatConfig.XunFei.ApiURL, "{version}", "v1.1", 1) + } + + wsURL, err := assembleAuthUrl(apiURL, key[1], key[2]) + //握手并建立websocket 连接 + conn, resp, err := d.Dial(wsURL, nil) + if err != nil { + logger.Error(readResp(resp) + err.Error()) + utils.ReplyMessage(ws, "请求讯飞星火模型 API 失败:"+readResp(resp)+err.Error()) + return nil + } else if resp.StatusCode != 101 { + utils.ReplyMessage(ws, "请求讯飞星火模型 API 失败:"+readResp(resp)+err.Error()) + return nil + } + + data := buildRequest(key[0], req) + fmt.Printf("%+v", data) + fmt.Println(apiURL) + err = conn.WriteJSON(data) + if err != nil { + utils.ReplyMessage(ws, "发送消息失败:"+err.Error()) + return nil + } + + replyCreatedAt := time.Now() // 记录回复时间 + // 循环读取 Chunk 消息 + var message = types.Message{} + var contents = make([]string, 0) + var content string + for { + _, msg, err := conn.ReadMessage() if err != nil { - return fmt.Errorf("error with reading response: %v", err) + logger.Error("error with read message:", err) + utils.ReplyMessage(ws, fmt.Sprintf("**数据读取失败:%s**", err)) + break } - var res struct { - Code int `json:"error_code"` - Msg string `json:"error_msg"` - } - err = json.Unmarshal(body, &res) + // 解析数据 + var result xunFeiResp + err = json.Unmarshal(msg, &result) if err != nil { - return fmt.Errorf("error with decode response: %v", err) + logger.Error("error with parsing JSON:", err) + utils.ReplyMessage(ws, fmt.Sprintf("**解析数据行失败:%s**", err)) + return nil + } + + if result.Header.Code != 0 { + utils.ReplyMessage(ws, fmt.Sprintf("**请求 API 返回错误:%s**", result.Header.Message)) + return nil + } + + content = result.Payload.Choices.Text[0].Content + contents = append(contents, content) + // 第一个结果 + if result.Payload.Choices.Status == 0 { + utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart}) + } + utils.ReplyChunkMessage(ws, types.WsMessage{ + Type: types.WsMiddle, + Content: utils.InterfaceToString(content), + }) + + if result.Payload.Choices.Status == 2 { // 最终结果 + break + } + + select { + case <-ctx.Done(): + utils.ReplyMessage(ws, "**用户取消了生成指令!**") + return nil + default: + continue + } + + } + + // 消息发送成功 + if len(contents) > 0 { + // 更新用户的对话次数 + if userVo.ChatConfig.ApiKeys[session.Model.Platform] == "" { + h.db.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("calls", gorm.Expr("calls - ?", 1)) + } + + if message.Role == "" { + message.Role = "assistant" + } + message.Content = strings.Join(contents, "") + useMsg := types.Message{Role: "user", Content: prompt} + + // 更新上下文消息,如果是调用函数则不需要更新上下文 + if h.App.ChatConfig.EnableContext { + chatCtx = append(chatCtx, useMsg) // 提问消息 + chatCtx = append(chatCtx, message) // 回复消息 + h.App.ChatContexts.Put(session.ChatId, chatCtx) + } + + // 追加聊天记录 + if h.App.ChatConfig.EnableHistory { + // for prompt + promptToken, err := utils.CalcTokens(prompt, req.Model) + if err != nil { + logger.Error(err) + } + historyUserMsg := model.HistoryMessage{ + UserId: userVo.Id, + ChatId: session.ChatId, + RoleId: role.Id, + Type: types.PromptMsg, + Icon: userVo.Avatar, + Content: prompt, + Tokens: promptToken, + UseContext: true, + } + historyUserMsg.CreatedAt = promptCreatedAt + historyUserMsg.UpdatedAt = promptCreatedAt + res := h.db.Save(&historyUserMsg) + if res.Error != nil { + logger.Error("failed to save prompt history message: ", res.Error) + } + + // for reply + // 计算本次对话消耗的总 token 数量 + replyToken, _ := utils.CalcTokens(message.Content, req.Model) + totalTokens := replyToken + getTotalTokens(req) + historyReplyMsg := model.HistoryMessage{ + UserId: userVo.Id, + ChatId: session.ChatId, + RoleId: role.Id, + Type: types.ReplyMsg, + Icon: role.Icon, + Content: message.Content, + Tokens: totalTokens, + UseContext: true, + } + historyReplyMsg.CreatedAt = replyCreatedAt + historyReplyMsg.UpdatedAt = replyCreatedAt + res = h.db.Create(&historyReplyMsg) + if res.Error != nil { + logger.Error("failed to save reply history message: ", res.Error) + } + // 更新用户信息 + h.db.Model(&model.User{}).Where("id = ?", userVo.Id). + UpdateColumn("total_tokens", gorm.Expr("total_tokens + ?", totalTokens)) + } + + // 保存当前会话 + var chatItem model.ChatItem + res := h.db.Where("chat_id = ?", session.ChatId).First(&chatItem) + if res.Error != nil { + chatItem.ChatId = session.ChatId + chatItem.UserId = session.UserId + chatItem.RoleId = role.Id + chatItem.ModelId = session.Model.Id + if utf8.RuneCountInString(prompt) > 30 { + chatItem.Title = string([]rune(prompt)[:30]) + "..." + } else { + chatItem.Title = prompt + } + h.db.Create(&chatItem) } - utils.ReplyMessage(ws, "请求百度文心大模型 API 失败:"+res.Msg) } return nil } -func (h *ChatHandler) getXunFeiToken(apiKey string) (string, error) { - ctx := context.Background() - tokenString, err := h.redis.Get(ctx, apiKey).Result() - if err == nil { - return tokenString, nil +// 构建 websocket 请求实体 +func buildRequest(appid string, req types.ApiRequest) map[string]interface{} { + return map[string]interface{}{ + "header": map[string]interface{}{ + "app_id": appid, + }, + "parameter": map[string]interface{}{ + "chat": map[string]interface{}{ + "domain": req.Model, + "temperature": float64(req.Temperature), + "top_k": int64(6), + "max_tokens": int64(req.MaxTokens), + "auditing": "default", + }, + }, + "payload": map[string]interface{}{ + "message": map[string]interface{}{ + "text": req.Messages, + }, + }, } +} - expr := time.Hour * 24 * 20 // access_token 有效期 - key := strings.Split(apiKey, "|") - if len(key) != 2 { - return "", fmt.Errorf("invalid api key: %s", apiKey) - } - url := fmt.Sprintf("https://aip.baidubce.com/oauth/2.0/token?client_id=%s&client_secret=%s&grant_type=client_credentials", key[0], key[1]) - client := &http.Client{} - req, err := http.NewRequest("POST", url, nil) +// 创建鉴权 URL +func assembleAuthUrl(hostURL string, apiKey, apiSecret string) (string, error) { + ul, err := url.Parse(hostURL) if err != nil { return "", err } - req.Header.Add("Content-Type", "application/json") - req.Header.Add("Accept", "application/json") - res, err := client.Do(req) - if err != nil { - return "", fmt.Errorf("error with send request: %w", err) - } - defer res.Body.Close() + date := time.Now().UTC().Format(time.RFC1123) + signString := []string{"host: " + ul.Host, "date: " + date, "GET " + ul.Path + " HTTP/1.1"} + //拼接签名字符串 + signStr := strings.Join(signString, "\n") + sha := hmacWithSha256(signStr, apiSecret) - body, err := io.ReadAll(res.Body) - if err != nil { - return "", fmt.Errorf("error with read response: %w", err) - } - var r map[string]interface{} - err = json.Unmarshal(body, &r) - if err != nil { - return "", fmt.Errorf("error with parse response: %w", err) - } - - if r["error"] != nil { - return "", fmt.Errorf("error with api response: %s", r["error_description"]) - } - - tokenString = fmt.Sprintf("%s", r["access_token"]) - h.redis.Set(ctx, apiKey, tokenString, expr) - return tokenString, nil + authUrl := fmt.Sprintf("hmac username=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", apiKey, + "hmac-sha256", "host date request-line", sha) + //将请求参数使用base64编码 + authorization := base64.StdEncoding.EncodeToString([]byte(authUrl)) + v := url.Values{} + v.Add("host", ul.Host) + v.Add("date", date) + v.Add("authorization", authorization) + //将编码后的字符串url encode后添加到url后面 + return hostURL + "?" + v.Encode(), nil +} + +// 使用 sha256 签名 +func hmacWithSha256(data, key string) string { + mac := hmac.New(sha256.New, []byte(key)) + mac.Write([]byte(data)) + encodeData := mac.Sum(nil) + return base64.StdEncoding.EncodeToString(encodeData) +} + +// 读取响应 +func readResp(resp *http.Response) string { + if resp == nil { + return "" + } + b, err := io.ReadAll(resp.Body) + if err != nil { + panic(err) + } + return fmt.Sprintf("code=%d,body=%s", resp.StatusCode, string(b)) }