mirror of
https://gitee.com/blackfox/geekai.git
synced 2024-12-05 05:37:41 +08:00
feat: XunFei ai mode api implements is ready
This commit is contained in:
parent
21c3a419a5
commit
9cbc6c91c4
@ -67,8 +67,10 @@ var ModelToTokens = map[string]int{
|
|||||||
"gpt-3.5-turbo-16k": 16384,
|
"gpt-3.5-turbo-16k": 16384,
|
||||||
"gpt-4": 8192,
|
"gpt-4": 8192,
|
||||||
"gpt-4-32k": 32768,
|
"gpt-4-32k": 32768,
|
||||||
"chatglm_pro": 32768,
|
"chatglm_pro": 32768, // 清华智普
|
||||||
"chatglm_std": 16384,
|
"chatglm_std": 16384,
|
||||||
"chatglm_lite": 4096,
|
"chatglm_lite": 4096,
|
||||||
"ernie_bot_turbo": 8192, // 文心一言
|
"ernie_bot_turbo": 8192, // 文心一言
|
||||||
|
"general": 8192, // 科大讯飞
|
||||||
|
"general2": 8192,
|
||||||
}
|
}
|
||||||
|
@ -36,6 +36,16 @@ func (wc *WsClient) Send(message []byte) error {
|
|||||||
return wc.Conn.WriteMessage(wc.mt, message)
|
return wc.Conn.WriteMessage(wc.mt, message)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (wc *WsClient) SendJson(value interface{}) error {
|
||||||
|
wc.lock.Lock()
|
||||||
|
defer wc.lock.Unlock()
|
||||||
|
|
||||||
|
if wc.Closed {
|
||||||
|
return ErrConClosed
|
||||||
|
}
|
||||||
|
return wc.Conn.WriteJSON(value)
|
||||||
|
}
|
||||||
|
|
||||||
func (wc *WsClient) Receive() (int, []byte, error) {
|
func (wc *WsClient) Receive() (int, []byte, error) {
|
||||||
if wc.Closed {
|
if wc.Closed {
|
||||||
return 0, nil, ErrConClosed
|
return 0, nil, ErrConClosed
|
||||||
|
@ -39,7 +39,12 @@ type ChatHandler struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func NewChatHandler(app *core.AppServer, db *gorm.DB, levelDB *store.LevelDB, redis *redis.Client, service *mj.Service) *ChatHandler {
|
func NewChatHandler(app *core.AppServer, db *gorm.DB, levelDB *store.LevelDB, redis *redis.Client, service *mj.Service) *ChatHandler {
|
||||||
h := ChatHandler{db: db, leveldb: levelDB, redis: redis, mjService: service}
|
h := ChatHandler{
|
||||||
|
db: db,
|
||||||
|
leveldb: levelDB,
|
||||||
|
redis: redis,
|
||||||
|
mjService: service,
|
||||||
|
}
|
||||||
h.App = app
|
h.App = app
|
||||||
return &h
|
return &h
|
||||||
}
|
}
|
||||||
@ -127,7 +132,11 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
|
|||||||
logger.Error(err)
|
logger.Error(err)
|
||||||
client.Close()
|
client.Close()
|
||||||
h.App.ChatClients.Delete(sessionId)
|
h.App.ChatClients.Delete(sessionId)
|
||||||
|
cancelFunc := h.App.ReqCancelFunc.Get(sessionId)
|
||||||
|
if cancelFunc != nil {
|
||||||
|
cancelFunc()
|
||||||
h.App.ReqCancelFunc.Delete(sessionId)
|
h.App.ReqCancelFunc.Delete(sessionId)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -217,6 +226,9 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
|
|||||||
}
|
}
|
||||||
req.Functions = functions
|
req.Functions = functions
|
||||||
}
|
}
|
||||||
|
case types.XunFei:
|
||||||
|
req.Temperature = h.App.ChatConfig.XunFei.Temperature
|
||||||
|
req.MaxTokens = h.App.ChatConfig.XunFei.MaxTokens
|
||||||
default:
|
default:
|
||||||
utils.ReplyMessage(ws, "不支持的平台:"+session.Model.Platform+",请联系管理员!")
|
utils.ReplyMessage(ws, "不支持的平台:"+session.Model.Platform+",请联系管理员!")
|
||||||
utils.ReplyMessage(ws, "![](/images/wx.png)")
|
utils.ReplyMessage(ws, "![](/images/wx.png)")
|
||||||
@ -291,6 +303,8 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
|
|||||||
return h.sendChatGLMMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
|
return h.sendChatGLMMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
|
||||||
case types.Baidu:
|
case types.Baidu:
|
||||||
return h.sendBaiduMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
|
return h.sendBaiduMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
|
||||||
|
case types.XunFei:
|
||||||
|
return h.sendXunFeiMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
|
||||||
|
|
||||||
}
|
}
|
||||||
utils.ReplyChunkMessage(ws, types.WsMessage{
|
utils.ReplyChunkMessage(ws, types.WsMessage{
|
||||||
|
@ -1,22 +1,54 @@
|
|||||||
package chatimpl
|
package chatimpl
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
|
||||||
"chatplus/core/types"
|
"chatplus/core/types"
|
||||||
"chatplus/store/model"
|
"chatplus/store/model"
|
||||||
"chatplus/store/vo"
|
"chatplus/store/vo"
|
||||||
"chatplus/utils"
|
"chatplus/utils"
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
"unicode/utf8"
|
"unicode/utf8"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type xunFeiResp struct {
|
||||||
|
Header struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
Sid string `json:"sid"`
|
||||||
|
Status int `json:"status"`
|
||||||
|
} `json:"header"`
|
||||||
|
Payload struct {
|
||||||
|
Choices struct {
|
||||||
|
Status int `json:"status"`
|
||||||
|
Seq int `json:"seq"`
|
||||||
|
Text []struct {
|
||||||
|
Content string `json:"content"`
|
||||||
|
Role string `json:"role"`
|
||||||
|
Index int `json:"index"`
|
||||||
|
} `json:"text"`
|
||||||
|
} `json:"choices"`
|
||||||
|
Usage struct {
|
||||||
|
Text struct {
|
||||||
|
QuestionTokens int `json:"question_tokens"`
|
||||||
|
PromptTokens int `json:"prompt_tokens"`
|
||||||
|
CompletionTokens int `json:"completion_tokens"`
|
||||||
|
TotalTokens int `json:"total_tokens"`
|
||||||
|
} `json:"text"`
|
||||||
|
} `json:"usage"`
|
||||||
|
} `json:"payload"`
|
||||||
|
}
|
||||||
|
|
||||||
// 科大讯飞消息发送实现
|
// 科大讯飞消息发送实现
|
||||||
|
|
||||||
func (h *ChatHandler) sendXunFeiMessage(
|
func (h *ChatHandler) sendXunFeiMessage(
|
||||||
@ -29,80 +61,106 @@ func (h *ChatHandler) sendXunFeiMessage(
|
|||||||
prompt string,
|
prompt string,
|
||||||
ws *types.WsClient) error {
|
ws *types.WsClient) error {
|
||||||
promptCreatedAt := time.Now() // 记录提问时间
|
promptCreatedAt := time.Now() // 记录提问时间
|
||||||
start := time.Now()
|
|
||||||
var apiKey = userVo.ChatConfig.ApiKeys[session.Model.Platform]
|
var apiKey = userVo.ChatConfig.ApiKeys[session.Model.Platform]
|
||||||
response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
|
if apiKey == "" {
|
||||||
logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start))
|
var key model.ApiKey
|
||||||
if err != nil {
|
res := h.db.Where("platform = ?", session.Model.Platform).Order("last_used_at ASC").First(&key)
|
||||||
if strings.Contains(err.Error(), "context canceled") {
|
if res.Error != nil {
|
||||||
logger.Info("用户取消了请求:", prompt)
|
|
||||||
return nil
|
|
||||||
} else if strings.Contains(err.Error(), "no available key") {
|
|
||||||
utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!")
|
utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!")
|
||||||
return nil
|
return nil
|
||||||
} else {
|
}
|
||||||
logger.Error(err)
|
// 更新 API KEY 的最后使用时间
|
||||||
|
h.db.Model(&key).UpdateColumn("last_used_at", time.Now().Unix())
|
||||||
|
apiKey = key.Value
|
||||||
}
|
}
|
||||||
|
|
||||||
utils.ReplyMessage(ws, ErrorMsg)
|
d := websocket.Dialer{
|
||||||
utils.ReplyMessage(ws, "![](/images/wx.png)")
|
HandshakeTimeout: 5 * time.Second,
|
||||||
return err
|
}
|
||||||
} else {
|
key := strings.Split(apiKey, "|")
|
||||||
defer response.Body.Close()
|
if len(key) != 3 {
|
||||||
|
utils.ReplyMessage(ws, "非法的 API KEY!")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var apiURL string
|
||||||
|
if req.Model == "generalv2" {
|
||||||
|
apiURL = strings.Replace(h.App.ChatConfig.XunFei.ApiURL, "{version}", "v2.1", 1)
|
||||||
|
} else {
|
||||||
|
apiURL = strings.Replace(h.App.ChatConfig.XunFei.ApiURL, "{version}", "v1.1", 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
wsURL, err := assembleAuthUrl(apiURL, key[1], key[2])
|
||||||
|
//握手并建立websocket 连接
|
||||||
|
conn, resp, err := d.Dial(wsURL, nil)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error(readResp(resp) + err.Error())
|
||||||
|
utils.ReplyMessage(ws, "请求讯飞星火模型 API 失败:"+readResp(resp)+err.Error())
|
||||||
|
return nil
|
||||||
|
} else if resp.StatusCode != 101 {
|
||||||
|
utils.ReplyMessage(ws, "请求讯飞星火模型 API 失败:"+readResp(resp)+err.Error())
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
data := buildRequest(key[0], req)
|
||||||
|
fmt.Printf("%+v", data)
|
||||||
|
fmt.Println(apiURL)
|
||||||
|
err = conn.WriteJSON(data)
|
||||||
|
if err != nil {
|
||||||
|
utils.ReplyMessage(ws, "发送消息失败:"+err.Error())
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
contentType := response.Header.Get("Content-Type")
|
|
||||||
if strings.Contains(contentType, "text/event-stream") {
|
|
||||||
replyCreatedAt := time.Now() // 记录回复时间
|
replyCreatedAt := time.Now() // 记录回复时间
|
||||||
// 循环读取 Chunk 消息
|
// 循环读取 Chunk 消息
|
||||||
var message = types.Message{}
|
var message = types.Message{}
|
||||||
var contents = make([]string, 0)
|
var contents = make([]string, 0)
|
||||||
var content string
|
var content string
|
||||||
scanner := bufio.NewScanner(response.Body)
|
for {
|
||||||
for scanner.Scan() {
|
_, msg, err := conn.ReadMessage()
|
||||||
line := scanner.Text()
|
|
||||||
if len(line) < 5 || strings.HasPrefix(line, "id:") {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if strings.HasPrefix(line, "data:") {
|
|
||||||
content = line[5:]
|
|
||||||
}
|
|
||||||
|
|
||||||
var resp baiduResp
|
|
||||||
err := utils.JsonDecode(content, &resp)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("error with parse data line: ", err)
|
logger.Error("error with read message:", err)
|
||||||
utils.ReplyMessage(ws, fmt.Sprintf("**解析数据行失败:%s**", err))
|
utils.ReplyMessage(ws, fmt.Sprintf("**数据读取失败:%s**", err))
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(contents) == 0 {
|
// 解析数据
|
||||||
|
var result xunFeiResp
|
||||||
|
err = json.Unmarshal(msg, &result)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("error with parsing JSON:", err)
|
||||||
|
utils.ReplyMessage(ws, fmt.Sprintf("**解析数据行失败:%s**", err))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.Header.Code != 0 {
|
||||||
|
utils.ReplyMessage(ws, fmt.Sprintf("**请求 API 返回错误:%s**", result.Header.Message))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
content = result.Payload.Choices.Text[0].Content
|
||||||
|
contents = append(contents, content)
|
||||||
|
// 第一个结果
|
||||||
|
if result.Payload.Choices.Status == 0 {
|
||||||
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
|
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
|
||||||
}
|
}
|
||||||
utils.ReplyChunkMessage(ws, types.WsMessage{
|
utils.ReplyChunkMessage(ws, types.WsMessage{
|
||||||
Type: types.WsMiddle,
|
Type: types.WsMiddle,
|
||||||
Content: utils.InterfaceToString(resp.Result),
|
Content: utils.InterfaceToString(content),
|
||||||
})
|
})
|
||||||
contents = append(contents, resp.Result)
|
|
||||||
|
|
||||||
if resp.IsTruncated {
|
if result.Payload.Choices.Status == 2 { // 最终结果
|
||||||
utils.ReplyMessage(ws, "AI 输出异常中断")
|
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp.IsEnd {
|
select {
|
||||||
break
|
case <-ctx.Done():
|
||||||
|
utils.ReplyMessage(ws, "**用户取消了生成指令!**")
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
} // end for
|
|
||||||
|
|
||||||
if err := scanner.Err(); err != nil {
|
|
||||||
if strings.Contains(err.Error(), "context canceled") {
|
|
||||||
logger.Info("用户取消了请求:", prompt)
|
|
||||||
} else {
|
|
||||||
logger.Error("信息读取出错:", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 消息发送成功
|
// 消息发送成功
|
||||||
@ -190,68 +248,74 @@ func (h *ChatHandler) sendXunFeiMessage(
|
|||||||
h.db.Create(&chatItem)
|
h.db.Create(&chatItem)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
body, err := io.ReadAll(response.Body)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("error with reading response: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var res struct {
|
|
||||||
Code int `json:"error_code"`
|
|
||||||
Msg string `json:"error_msg"`
|
|
||||||
}
|
|
||||||
err = json.Unmarshal(body, &res)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("error with decode response: %v", err)
|
|
||||||
}
|
|
||||||
utils.ReplyMessage(ws, "请求百度文心大模型 API 失败:"+res.Msg)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *ChatHandler) getXunFeiToken(apiKey string) (string, error) {
|
// 构建 websocket 请求实体
|
||||||
ctx := context.Background()
|
func buildRequest(appid string, req types.ApiRequest) map[string]interface{} {
|
||||||
tokenString, err := h.redis.Get(ctx, apiKey).Result()
|
return map[string]interface{}{
|
||||||
if err == nil {
|
"header": map[string]interface{}{
|
||||||
return tokenString, nil
|
"app_id": appid,
|
||||||
|
},
|
||||||
|
"parameter": map[string]interface{}{
|
||||||
|
"chat": map[string]interface{}{
|
||||||
|
"domain": req.Model,
|
||||||
|
"temperature": float64(req.Temperature),
|
||||||
|
"top_k": int64(6),
|
||||||
|
"max_tokens": int64(req.MaxTokens),
|
||||||
|
"auditing": "default",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"payload": map[string]interface{}{
|
||||||
|
"message": map[string]interface{}{
|
||||||
|
"text": req.Messages,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
expr := time.Hour * 24 * 20 // access_token 有效期
|
// 创建鉴权 URL
|
||||||
key := strings.Split(apiKey, "|")
|
func assembleAuthUrl(hostURL string, apiKey, apiSecret string) (string, error) {
|
||||||
if len(key) != 2 {
|
ul, err := url.Parse(hostURL)
|
||||||
return "", fmt.Errorf("invalid api key: %s", apiKey)
|
|
||||||
}
|
|
||||||
url := fmt.Sprintf("https://aip.baidubce.com/oauth/2.0/token?client_id=%s&client_secret=%s&grant_type=client_credentials", key[0], key[1])
|
|
||||||
client := &http.Client{}
|
|
||||||
req, err := http.NewRequest("POST", url, nil)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
req.Header.Add("Content-Type", "application/json")
|
|
||||||
req.Header.Add("Accept", "application/json")
|
|
||||||
|
|
||||||
res, err := client.Do(req)
|
date := time.Now().UTC().Format(time.RFC1123)
|
||||||
|
signString := []string{"host: " + ul.Host, "date: " + date, "GET " + ul.Path + " HTTP/1.1"}
|
||||||
|
//拼接签名字符串
|
||||||
|
signStr := strings.Join(signString, "\n")
|
||||||
|
sha := hmacWithSha256(signStr, apiSecret)
|
||||||
|
|
||||||
|
authUrl := fmt.Sprintf("hmac username=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", apiKey,
|
||||||
|
"hmac-sha256", "host date request-line", sha)
|
||||||
|
//将请求参数使用base64编码
|
||||||
|
authorization := base64.StdEncoding.EncodeToString([]byte(authUrl))
|
||||||
|
v := url.Values{}
|
||||||
|
v.Add("host", ul.Host)
|
||||||
|
v.Add("date", date)
|
||||||
|
v.Add("authorization", authorization)
|
||||||
|
//将编码后的字符串url encode后添加到url后面
|
||||||
|
return hostURL + "?" + v.Encode(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 使用 sha256 签名
|
||||||
|
func hmacWithSha256(data, key string) string {
|
||||||
|
mac := hmac.New(sha256.New, []byte(key))
|
||||||
|
mac.Write([]byte(data))
|
||||||
|
encodeData := mac.Sum(nil)
|
||||||
|
return base64.StdEncoding.EncodeToString(encodeData)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 读取响应
|
||||||
|
func readResp(resp *http.Response) string {
|
||||||
|
if resp == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
b, err := io.ReadAll(resp.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("error with send request: %w", err)
|
panic(err)
|
||||||
}
|
}
|
||||||
defer res.Body.Close()
|
return fmt.Sprintf("code=%d,body=%s", resp.StatusCode, string(b))
|
||||||
|
|
||||||
body, err := io.ReadAll(res.Body)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("error with read response: %w", err)
|
|
||||||
}
|
|
||||||
var r map[string]interface{}
|
|
||||||
err = json.Unmarshal(body, &r)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("error with parse response: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if r["error"] != nil {
|
|
||||||
return "", fmt.Errorf("error with api response: %s", r["error_description"])
|
|
||||||
}
|
|
||||||
|
|
||||||
tokenString = fmt.Sprintf("%s", r["access_token"])
|
|
||||||
h.redis.Set(ctx, apiKey, tokenString, expr)
|
|
||||||
return tokenString, nil
|
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user