feat: XunFei ai mode api implements is ready

This commit is contained in:
RockYang 2023-10-11 18:17:03 +08:00
parent 21c3a419a5
commit 9cbc6c91c4
4 changed files with 294 additions and 204 deletions

View File

@ -67,8 +67,10 @@ var ModelToTokens = map[string]int{
"gpt-3.5-turbo-16k": 16384, "gpt-3.5-turbo-16k": 16384,
"gpt-4": 8192, "gpt-4": 8192,
"gpt-4-32k": 32768, "gpt-4-32k": 32768,
"chatglm_pro": 32768, "chatglm_pro": 32768, // 清华智普
"chatglm_std": 16384, "chatglm_std": 16384,
"chatglm_lite": 4096, "chatglm_lite": 4096,
"ernie_bot_turbo": 8192, // 文心一言 "ernie_bot_turbo": 8192, // 文心一言
"general": 8192, // 科大讯飞
"general2": 8192,
} }

View File

@ -36,6 +36,16 @@ func (wc *WsClient) Send(message []byte) error {
return wc.Conn.WriteMessage(wc.mt, message) return wc.Conn.WriteMessage(wc.mt, message)
} }
func (wc *WsClient) SendJson(value interface{}) error {
wc.lock.Lock()
defer wc.lock.Unlock()
if wc.Closed {
return ErrConClosed
}
return wc.Conn.WriteJSON(value)
}
func (wc *WsClient) Receive() (int, []byte, error) { func (wc *WsClient) Receive() (int, []byte, error) {
if wc.Closed { if wc.Closed {
return 0, nil, ErrConClosed return 0, nil, ErrConClosed

View File

@ -39,7 +39,12 @@ type ChatHandler struct {
} }
func NewChatHandler(app *core.AppServer, db *gorm.DB, levelDB *store.LevelDB, redis *redis.Client, service *mj.Service) *ChatHandler { func NewChatHandler(app *core.AppServer, db *gorm.DB, levelDB *store.LevelDB, redis *redis.Client, service *mj.Service) *ChatHandler {
h := ChatHandler{db: db, leveldb: levelDB, redis: redis, mjService: service} h := ChatHandler{
db: db,
leveldb: levelDB,
redis: redis,
mjService: service,
}
h.App = app h.App = app
return &h return &h
} }
@ -127,7 +132,11 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
logger.Error(err) logger.Error(err)
client.Close() client.Close()
h.App.ChatClients.Delete(sessionId) h.App.ChatClients.Delete(sessionId)
h.App.ReqCancelFunc.Delete(sessionId) cancelFunc := h.App.ReqCancelFunc.Get(sessionId)
if cancelFunc != nil {
cancelFunc()
h.App.ReqCancelFunc.Delete(sessionId)
}
return return
} }
@ -217,6 +226,9 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
} }
req.Functions = functions req.Functions = functions
} }
case types.XunFei:
req.Temperature = h.App.ChatConfig.XunFei.Temperature
req.MaxTokens = h.App.ChatConfig.XunFei.MaxTokens
default: default:
utils.ReplyMessage(ws, "不支持的平台:"+session.Model.Platform+",请联系管理员!") utils.ReplyMessage(ws, "不支持的平台:"+session.Model.Platform+",请联系管理员!")
utils.ReplyMessage(ws, "![](/images/wx.png)") utils.ReplyMessage(ws, "![](/images/wx.png)")
@ -291,6 +303,8 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
return h.sendChatGLMMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws) return h.sendChatGLMMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
case types.Baidu: case types.Baidu:
return h.sendBaiduMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws) return h.sendBaiduMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
case types.XunFei:
return h.sendXunFeiMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
} }
utils.ReplyChunkMessage(ws, types.WsMessage{ utils.ReplyChunkMessage(ws, types.WsMessage{

View File

@ -1,22 +1,54 @@
package chatimpl package chatimpl
import ( import (
"bufio"
"chatplus/core/types" "chatplus/core/types"
"chatplus/store/model" "chatplus/store/model"
"chatplus/store/vo" "chatplus/store/vo"
"chatplus/utils" "chatplus/utils"
"context" "context"
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/gorilla/websocket"
"gorm.io/gorm" "gorm.io/gorm"
"io" "io"
"net/http" "net/http"
"net/url"
"strings" "strings"
"time" "time"
"unicode/utf8" "unicode/utf8"
) )
type xunFeiResp struct {
Header struct {
Code int `json:"code"`
Message string `json:"message"`
Sid string `json:"sid"`
Status int `json:"status"`
} `json:"header"`
Payload struct {
Choices struct {
Status int `json:"status"`
Seq int `json:"seq"`
Text []struct {
Content string `json:"content"`
Role string `json:"role"`
Index int `json:"index"`
} `json:"text"`
} `json:"choices"`
Usage struct {
Text struct {
QuestionTokens int `json:"question_tokens"`
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
} `json:"text"`
} `json:"usage"`
} `json:"payload"`
}
// 科大讯飞消息发送实现 // 科大讯飞消息发送实现
func (h *ChatHandler) sendXunFeiMessage( func (h *ChatHandler) sendXunFeiMessage(
@ -29,229 +61,261 @@ func (h *ChatHandler) sendXunFeiMessage(
prompt string, prompt string,
ws *types.WsClient) error { ws *types.WsClient) error {
promptCreatedAt := time.Now() // 记录提问时间 promptCreatedAt := time.Now() // 记录提问时间
start := time.Now()
var apiKey = userVo.ChatConfig.ApiKeys[session.Model.Platform] var apiKey = userVo.ChatConfig.ApiKeys[session.Model.Platform]
response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey) if apiKey == "" {
logger.Info("HTTP请求完成耗时", time.Now().Sub(start)) var key model.ApiKey
if err != nil { res := h.db.Where("platform = ?", session.Model.Platform).Order("last_used_at ASC").First(&key)
if strings.Contains(err.Error(), "context canceled") { if res.Error != nil {
logger.Info("用户取消了请求:", prompt)
return nil
} else if strings.Contains(err.Error(), "no available key") {
utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY请联系管理员") utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY请联系管理员")
return nil return nil
} else {
logger.Error(err)
} }
// 更新 API KEY 的最后使用时间
utils.ReplyMessage(ws, ErrorMsg) h.db.Model(&key).UpdateColumn("last_used_at", time.Now().Unix())
utils.ReplyMessage(ws, "![](/images/wx.png)") apiKey = key.Value
return err
} else {
defer response.Body.Close()
} }
contentType := response.Header.Get("Content-Type") d := websocket.Dialer{
if strings.Contains(contentType, "text/event-stream") { HandshakeTimeout: 5 * time.Second,
replyCreatedAt := time.Now() // 记录回复时间 }
// 循环读取 Chunk 消息 key := strings.Split(apiKey, "|")
var message = types.Message{} if len(key) != 3 {
var contents = make([]string, 0) utils.ReplyMessage(ws, "非法的 API KEY")
var content string return nil
scanner := bufio.NewScanner(response.Body) }
for scanner.Scan() {
line := scanner.Text()
if len(line) < 5 || strings.HasPrefix(line, "id:") {
continue
}
if strings.HasPrefix(line, "data:") { var apiURL string
content = line[5:] if req.Model == "generalv2" {
} apiURL = strings.Replace(h.App.ChatConfig.XunFei.ApiURL, "{version}", "v2.1", 1)
var resp baiduResp
err := utils.JsonDecode(content, &resp)
if err != nil {
logger.Error("error with parse data line: ", err)
utils.ReplyMessage(ws, fmt.Sprintf("**解析数据行失败:%s**", err))
break
}
if len(contents) == 0 {
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
}
utils.ReplyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
Content: utils.InterfaceToString(resp.Result),
})
contents = append(contents, resp.Result)
if resp.IsTruncated {
utils.ReplyMessage(ws, "AI 输出异常中断")
break
}
if resp.IsEnd {
break
}
} // end for
if err := scanner.Err(); err != nil {
if strings.Contains(err.Error(), "context canceled") {
logger.Info("用户取消了请求:", prompt)
} else {
logger.Error("信息读取出错:", err)
}
}
// 消息发送成功
if len(contents) > 0 {
// 更新用户的对话次数
if userVo.ChatConfig.ApiKeys[session.Model.Platform] == "" {
h.db.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("calls", gorm.Expr("calls - ?", 1))
}
if message.Role == "" {
message.Role = "assistant"
}
message.Content = strings.Join(contents, "")
useMsg := types.Message{Role: "user", Content: prompt}
// 更新上下文消息,如果是调用函数则不需要更新上下文
if h.App.ChatConfig.EnableContext {
chatCtx = append(chatCtx, useMsg) // 提问消息
chatCtx = append(chatCtx, message) // 回复消息
h.App.ChatContexts.Put(session.ChatId, chatCtx)
}
// 追加聊天记录
if h.App.ChatConfig.EnableHistory {
// for prompt
promptToken, err := utils.CalcTokens(prompt, req.Model)
if err != nil {
logger.Error(err)
}
historyUserMsg := model.HistoryMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: prompt,
Tokens: promptToken,
UseContext: true,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.db.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// for reply
// 计算本次对话消耗的总 token 数量
replyToken, _ := utils.CalcTokens(message.Content, req.Model)
totalTokens := replyToken + getTotalTokens(req)
historyReplyMsg := model.HistoryMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: totalTokens,
UseContext: true,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.db.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户信息
h.db.Model(&model.User{}).Where("id = ?", userVo.Id).
UpdateColumn("total_tokens", gorm.Expr("total_tokens + ?", totalTokens))
}
// 保存当前会话
var chatItem model.ChatItem
res := h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
if res.Error != nil {
chatItem.ChatId = session.ChatId
chatItem.UserId = session.UserId
chatItem.RoleId = role.Id
chatItem.ModelId = session.Model.Id
if utf8.RuneCountInString(prompt) > 30 {
chatItem.Title = string([]rune(prompt)[:30]) + "..."
} else {
chatItem.Title = prompt
}
h.db.Create(&chatItem)
}
}
} else { } else {
body, err := io.ReadAll(response.Body) apiURL = strings.Replace(h.App.ChatConfig.XunFei.ApiURL, "{version}", "v1.1", 1)
}
wsURL, err := assembleAuthUrl(apiURL, key[1], key[2])
//握手并建立websocket 连接
conn, resp, err := d.Dial(wsURL, nil)
if err != nil {
logger.Error(readResp(resp) + err.Error())
utils.ReplyMessage(ws, "请求讯飞星火模型 API 失败:"+readResp(resp)+err.Error())
return nil
} else if resp.StatusCode != 101 {
utils.ReplyMessage(ws, "请求讯飞星火模型 API 失败:"+readResp(resp)+err.Error())
return nil
}
data := buildRequest(key[0], req)
fmt.Printf("%+v", data)
fmt.Println(apiURL)
err = conn.WriteJSON(data)
if err != nil {
utils.ReplyMessage(ws, "发送消息失败:"+err.Error())
return nil
}
replyCreatedAt := time.Now() // 记录回复时间
// 循环读取 Chunk 消息
var message = types.Message{}
var contents = make([]string, 0)
var content string
for {
_, msg, err := conn.ReadMessage()
if err != nil { if err != nil {
return fmt.Errorf("error with reading response: %v", err) logger.Error("error with read message:", err)
utils.ReplyMessage(ws, fmt.Sprintf("**数据读取失败:%s**", err))
break
} }
var res struct { // 解析数据
Code int `json:"error_code"` var result xunFeiResp
Msg string `json:"error_msg"` err = json.Unmarshal(msg, &result)
}
err = json.Unmarshal(body, &res)
if err != nil { if err != nil {
return fmt.Errorf("error with decode response: %v", err) logger.Error("error with parsing JSON:", err)
utils.ReplyMessage(ws, fmt.Sprintf("**解析数据行失败:%s**", err))
return nil
}
if result.Header.Code != 0 {
utils.ReplyMessage(ws, fmt.Sprintf("**请求 API 返回错误:%s**", result.Header.Message))
return nil
}
content = result.Payload.Choices.Text[0].Content
contents = append(contents, content)
// 第一个结果
if result.Payload.Choices.Status == 0 {
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
}
utils.ReplyChunkMessage(ws, types.WsMessage{
Type: types.WsMiddle,
Content: utils.InterfaceToString(content),
})
if result.Payload.Choices.Status == 2 { // 最终结果
break
}
select {
case <-ctx.Done():
utils.ReplyMessage(ws, "**用户取消了生成指令!**")
return nil
default:
continue
}
}
// 消息发送成功
if len(contents) > 0 {
// 更新用户的对话次数
if userVo.ChatConfig.ApiKeys[session.Model.Platform] == "" {
h.db.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("calls", gorm.Expr("calls - ?", 1))
}
if message.Role == "" {
message.Role = "assistant"
}
message.Content = strings.Join(contents, "")
useMsg := types.Message{Role: "user", Content: prompt}
// 更新上下文消息,如果是调用函数则不需要更新上下文
if h.App.ChatConfig.EnableContext {
chatCtx = append(chatCtx, useMsg) // 提问消息
chatCtx = append(chatCtx, message) // 回复消息
h.App.ChatContexts.Put(session.ChatId, chatCtx)
}
// 追加聊天记录
if h.App.ChatConfig.EnableHistory {
// for prompt
promptToken, err := utils.CalcTokens(prompt, req.Model)
if err != nil {
logger.Error(err)
}
historyUserMsg := model.HistoryMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.PromptMsg,
Icon: userVo.Avatar,
Content: prompt,
Tokens: promptToken,
UseContext: true,
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.db.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
// for reply
// 计算本次对话消耗的总 token 数量
replyToken, _ := utils.CalcTokens(message.Content, req.Model)
totalTokens := replyToken + getTotalTokens(req)
historyReplyMsg := model.HistoryMessage{
UserId: userVo.Id,
ChatId: session.ChatId,
RoleId: role.Id,
Type: types.ReplyMsg,
Icon: role.Icon,
Content: message.Content,
Tokens: totalTokens,
UseContext: true,
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.db.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
// 更新用户信息
h.db.Model(&model.User{}).Where("id = ?", userVo.Id).
UpdateColumn("total_tokens", gorm.Expr("total_tokens + ?", totalTokens))
}
// 保存当前会话
var chatItem model.ChatItem
res := h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
if res.Error != nil {
chatItem.ChatId = session.ChatId
chatItem.UserId = session.UserId
chatItem.RoleId = role.Id
chatItem.ModelId = session.Model.Id
if utf8.RuneCountInString(prompt) > 30 {
chatItem.Title = string([]rune(prompt)[:30]) + "..."
} else {
chatItem.Title = prompt
}
h.db.Create(&chatItem)
} }
utils.ReplyMessage(ws, "请求百度文心大模型 API 失败:"+res.Msg)
} }
return nil return nil
} }
func (h *ChatHandler) getXunFeiToken(apiKey string) (string, error) { // 构建 websocket 请求实体
ctx := context.Background() func buildRequest(appid string, req types.ApiRequest) map[string]interface{} {
tokenString, err := h.redis.Get(ctx, apiKey).Result() return map[string]interface{}{
if err == nil { "header": map[string]interface{}{
return tokenString, nil "app_id": appid,
},
"parameter": map[string]interface{}{
"chat": map[string]interface{}{
"domain": req.Model,
"temperature": float64(req.Temperature),
"top_k": int64(6),
"max_tokens": int64(req.MaxTokens),
"auditing": "default",
},
},
"payload": map[string]interface{}{
"message": map[string]interface{}{
"text": req.Messages,
},
},
} }
}
expr := time.Hour * 24 * 20 // access_token 有效期 // 创建鉴权 URL
key := strings.Split(apiKey, "|") func assembleAuthUrl(hostURL string, apiKey, apiSecret string) (string, error) {
if len(key) != 2 { ul, err := url.Parse(hostURL)
return "", fmt.Errorf("invalid api key: %s", apiKey)
}
url := fmt.Sprintf("https://aip.baidubce.com/oauth/2.0/token?client_id=%s&client_secret=%s&grant_type=client_credentials", key[0], key[1])
client := &http.Client{}
req, err := http.NewRequest("POST", url, nil)
if err != nil { if err != nil {
return "", err return "", err
} }
req.Header.Add("Content-Type", "application/json")
req.Header.Add("Accept", "application/json")
res, err := client.Do(req) date := time.Now().UTC().Format(time.RFC1123)
if err != nil { signString := []string{"host: " + ul.Host, "date: " + date, "GET " + ul.Path + " HTTP/1.1"}
return "", fmt.Errorf("error with send request: %w", err) //拼接签名字符串
} signStr := strings.Join(signString, "\n")
defer res.Body.Close() sha := hmacWithSha256(signStr, apiSecret)
body, err := io.ReadAll(res.Body) authUrl := fmt.Sprintf("hmac username=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", apiKey,
if err != nil { "hmac-sha256", "host date request-line", sha)
return "", fmt.Errorf("error with read response: %w", err) //将请求参数使用base64编码
} authorization := base64.StdEncoding.EncodeToString([]byte(authUrl))
var r map[string]interface{} v := url.Values{}
err = json.Unmarshal(body, &r) v.Add("host", ul.Host)
if err != nil { v.Add("date", date)
return "", fmt.Errorf("error with parse response: %w", err) v.Add("authorization", authorization)
} //将编码后的字符串url encode后添加到url后面
return hostURL + "?" + v.Encode(), nil
if r["error"] != nil { }
return "", fmt.Errorf("error with api response: %s", r["error_description"])
} // 使用 sha256 签名
func hmacWithSha256(data, key string) string {
tokenString = fmt.Sprintf("%s", r["access_token"]) mac := hmac.New(sha256.New, []byte(key))
h.redis.Set(ctx, apiKey, tokenString, expr) mac.Write([]byte(data))
return tokenString, nil encodeData := mac.Sum(nil)
return base64.StdEncoding.EncodeToString(encodeData)
}
// 读取响应
func readResp(resp *http.Response) string {
if resp == nil {
return ""
}
b, err := io.ReadAll(resp.Body)
if err != nil {
panic(err)
}
return fmt.Sprintf("code=%d,body=%s", resp.StatusCode, string(b))
} }