2023-06-15 09:41:30 +08:00
|
|
|
|
package handler
|
|
|
|
|
|
|
|
|
|
import (
|
|
|
|
|
"bytes"
|
|
|
|
|
"chatplus/core"
|
|
|
|
|
"chatplus/core/types"
|
2023-08-14 07:09:52 +08:00
|
|
|
|
"chatplus/store"
|
2023-06-15 09:41:30 +08:00
|
|
|
|
"chatplus/store/model"
|
|
|
|
|
"chatplus/store/vo"
|
|
|
|
|
"chatplus/utils"
|
|
|
|
|
"chatplus/utils/resp"
|
|
|
|
|
"context"
|
|
|
|
|
"encoding/json"
|
|
|
|
|
"errors"
|
|
|
|
|
"fmt"
|
2023-09-04 06:43:15 +08:00
|
|
|
|
"github.com/gin-gonic/gin"
|
|
|
|
|
"github.com/go-redis/redis/v8"
|
|
|
|
|
"github.com/gorilla/websocket"
|
|
|
|
|
"gorm.io/gorm"
|
2023-06-15 09:41:30 +08:00
|
|
|
|
"net/http"
|
|
|
|
|
"net/url"
|
|
|
|
|
"strings"
|
|
|
|
|
"time"
|
|
|
|
|
)
|
|
|
|
|
|
2023-06-25 11:34:55 +08:00
|
|
|
|
const (
	// ErrorMsg is a generic apology saying the AI assistant failed and the user
	// should retry later. NOTE(review): not referenced in the code visible here;
	// presumably sent to clients on upstream failures.
	ErrorMsg = "抱歉,AI 助手开小差了,请稍后再试。"
)
|
2023-06-15 09:41:30 +08:00
|
|
|
|
|
|
|
|
|
type ChatHandler struct {
|
|
|
|
|
BaseHandler
|
2023-08-14 07:09:52 +08:00
|
|
|
|
db *gorm.DB
|
|
|
|
|
leveldb *store.LevelDB
|
2023-09-04 06:43:15 +08:00
|
|
|
|
redis *redis.Client
|
2023-06-15 09:41:30 +08:00
|
|
|
|
}
|
|
|
|
|
|
2023-09-04 06:43:15 +08:00
|
|
|
|
func NewChatHandler(app *core.AppServer, db *gorm.DB, levelDB *store.LevelDB, redis *redis.Client) *ChatHandler {
|
|
|
|
|
handler := ChatHandler{db: db, leveldb: levelDB, redis: redis}
|
2023-06-19 07:06:59 +08:00
|
|
|
|
handler.App = app
|
2023-06-15 09:41:30 +08:00
|
|
|
|
return &handler
|
|
|
|
|
}
|
|
|
|
|
|
2023-07-15 18:37:25 +08:00
|
|
|
|
var (
	// chatConfig holds the chat settings decoded from the config row whose
	// marker is "chat". ChatHandle overwrites it on every new connection, and
	// sendMessage reads it (ContextDeep) from per-connection goroutines.
	// NOTE(review): these accesses are unsynchronized, so concurrent
	// connections race on this variable. Consider passing the decoded config
	// explicitly instead of sharing package state.
	chatConfig types.ChatConfig
)
|
|
|
|
|
|
2023-06-15 09:41:30 +08:00
|
|
|
|
// ChatHandle 处理聊天 WebSocket 请求
|
|
|
|
|
func (h *ChatHandler) ChatHandle(c *gin.Context) {
|
|
|
|
|
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
|
|
|
|
|
if err != nil {
|
|
|
|
|
logger.Error(err)
|
|
|
|
|
return
|
|
|
|
|
}
|
2023-07-24 18:18:09 +08:00
|
|
|
|
|
2023-06-16 15:32:11 +08:00
|
|
|
|
sessionId := c.Query("session_id")
|
2023-06-19 07:06:59 +08:00
|
|
|
|
roleId := h.GetInt(c, "role_id", 0)
|
2023-06-16 15:32:11 +08:00
|
|
|
|
chatId := c.Query("chat_id")
|
2023-09-04 06:43:15 +08:00
|
|
|
|
modelId := h.GetInt(c, "model_id", 0)
|
|
|
|
|
|
|
|
|
|
client := types.NewWsClient(ws)
|
|
|
|
|
// get model info
|
|
|
|
|
var chatModel model.ChatModel
|
|
|
|
|
res := h.db.First(&chatModel, modelId)
|
|
|
|
|
if res.Error != nil || chatModel.Enabled == false {
|
|
|
|
|
utils.ReplyMessage(client, "当前AI模型暂未启用,连接已关闭!!!")
|
|
|
|
|
c.Abort()
|
|
|
|
|
return
|
|
|
|
|
}
|
2023-06-16 15:32:11 +08:00
|
|
|
|
|
2023-06-19 07:06:59 +08:00
|
|
|
|
session := h.App.ChatSession.Get(sessionId)
|
2023-08-17 14:20:16 +08:00
|
|
|
|
if session == nil {
|
2023-06-26 16:39:00 +08:00
|
|
|
|
user, err := utils.GetLoginUser(c, h.db)
|
|
|
|
|
if err != nil {
|
|
|
|
|
logger.Info("用户未登录")
|
|
|
|
|
c.Abort()
|
|
|
|
|
return
|
|
|
|
|
}
|
2023-08-17 14:20:16 +08:00
|
|
|
|
session = &types.ChatSession{
|
2023-06-26 16:39:00 +08:00
|
|
|
|
SessionId: sessionId,
|
|
|
|
|
ClientIP: c.ClientIP(),
|
2023-09-04 06:43:15 +08:00
|
|
|
|
Username: user.Mobile,
|
2023-06-26 16:39:00 +08:00
|
|
|
|
UserId: user.Id,
|
|
|
|
|
}
|
|
|
|
|
h.App.ChatSession.Put(sessionId, session)
|
2023-06-15 09:41:30 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// use old chat data override the chat model and role ID
|
|
|
|
|
var chat model.ChatItem
|
2023-09-04 06:43:15 +08:00
|
|
|
|
res = h.db.Where("chat_id=?", chatId).First(&chat)
|
2023-06-15 09:41:30 +08:00
|
|
|
|
if res.Error == nil {
|
2023-09-04 06:43:15 +08:00
|
|
|
|
chatModel.Id = chat.ModelId
|
2023-06-15 09:41:30 +08:00
|
|
|
|
roleId = int(chat.RoleId)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
session.ChatId = chatId
|
2023-09-04 06:43:15 +08:00
|
|
|
|
session.Model = types.ChatModel{
|
|
|
|
|
Id: chatModel.Id,
|
|
|
|
|
Value: chatModel.Value,
|
|
|
|
|
Platform: types.Platform(chatModel.Platform)}
|
2023-09-08 18:12:18 +08:00
|
|
|
|
logger.Infof("New websocket connected, IP: %s, Username: %s", c.ClientIP(), session.Username)
|
2023-06-15 09:41:30 +08:00
|
|
|
|
var chatRole model.ChatRole
|
|
|
|
|
res = h.db.First(&chatRole, roleId)
|
|
|
|
|
if res.Error != nil || !chatRole.Enable {
|
2023-08-11 18:46:56 +08:00
|
|
|
|
utils.ReplyMessage(client, "当前聊天角色不存在或者未启用,连接已关闭!!!")
|
2023-07-15 18:37:25 +08:00
|
|
|
|
c.Abort()
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 初始化聊天配置
|
|
|
|
|
var config model.Config
|
|
|
|
|
h.db.Where("marker", "chat").First(&config)
|
|
|
|
|
err = utils.JsonDecode(config.Config, &chatConfig)
|
|
|
|
|
if err != nil {
|
2023-08-11 18:46:56 +08:00
|
|
|
|
utils.ReplyMessage(client, "加载系统配置失败,连接已关闭!!!")
|
2023-06-15 09:41:30 +08:00
|
|
|
|
c.Abort()
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 保存会话连接
|
2023-06-19 07:06:59 +08:00
|
|
|
|
h.App.ChatClients.Put(sessionId, client)
|
2023-06-15 09:41:30 +08:00
|
|
|
|
go func() {
|
|
|
|
|
for {
|
2023-08-15 18:29:53 +08:00
|
|
|
|
_, msg, err := client.Receive()
|
2023-06-15 09:41:30 +08:00
|
|
|
|
if err != nil {
|
|
|
|
|
logger.Error(err)
|
|
|
|
|
client.Close()
|
2023-06-19 07:06:59 +08:00
|
|
|
|
h.App.ChatClients.Delete(sessionId)
|
|
|
|
|
h.App.ReqCancelFunc.Delete(sessionId)
|
2023-06-15 09:41:30 +08:00
|
|
|
|
return
|
|
|
|
|
}
|
2023-08-15 18:29:53 +08:00
|
|
|
|
|
|
|
|
|
message := string(msg)
|
|
|
|
|
logger.Info("Receive a message: ", message)
|
2023-08-11 18:46:56 +08:00
|
|
|
|
//utils.ReplyMessage(client, "这是一条测试消息!")
|
2023-06-15 09:41:30 +08:00
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
2023-06-19 07:06:59 +08:00
|
|
|
|
h.App.ReqCancelFunc.Put(sessionId, cancel)
|
2023-06-15 09:41:30 +08:00
|
|
|
|
// 回复消息
|
2023-08-15 18:29:53 +08:00
|
|
|
|
err = h.sendMessage(ctx, session, chatRole, message, client)
|
2023-06-15 09:41:30 +08:00
|
|
|
|
if err != nil {
|
|
|
|
|
logger.Error(err)
|
|
|
|
|
} else {
|
2023-08-11 18:46:56 +08:00
|
|
|
|
utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsEnd})
|
2023-06-15 09:41:30 +08:00
|
|
|
|
logger.Info("回答完毕: " + string(message))
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
}
|
|
|
|
|
}()
|
|
|
|
|
}
|
|
|
|
|
|
2023-08-17 14:20:16 +08:00
|
|
|
|
func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSession, role model.ChatRole, prompt string, ws *types.WsClient) error {
|
2023-09-04 06:43:15 +08:00
|
|
|
|
defer func() {
|
|
|
|
|
if r := recover(); r != nil {
|
|
|
|
|
logger.Error("Recover message from error: ", r)
|
|
|
|
|
}
|
|
|
|
|
}()
|
2023-06-15 09:41:30 +08:00
|
|
|
|
|
|
|
|
|
var user model.User
|
|
|
|
|
res := h.db.Model(&model.User{}).First(&user, session.UserId)
|
|
|
|
|
if res.Error != nil {
|
2023-08-11 18:46:56 +08:00
|
|
|
|
utils.ReplyMessage(ws, "非法用户,请联系管理员!")
|
2023-06-15 09:41:30 +08:00
|
|
|
|
return res.Error
|
|
|
|
|
}
|
|
|
|
|
var userVo vo.User
|
|
|
|
|
err := utils.CopyObject(user, &userVo)
|
|
|
|
|
userVo.Id = user.Id
|
|
|
|
|
if err != nil {
|
|
|
|
|
return errors.New("User 对象转换失败," + err.Error())
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if userVo.Status == false {
|
2023-08-11 18:46:56 +08:00
|
|
|
|
utils.ReplyMessage(ws, "您的账号已经被禁用,如果疑问,请联系管理员!")
|
|
|
|
|
utils.ReplyMessage(ws, "![](/images/wx.png)")
|
2023-06-15 09:41:30 +08:00
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
2023-09-04 17:34:29 +08:00
|
|
|
|
if userVo.Calls <= 0 && userVo.ChatConfig.ApiKeys[session.Model.Platform] == "" {
|
2023-08-11 18:46:56 +08:00
|
|
|
|
utils.ReplyMessage(ws, "您的对话次数已经用尽,请联系管理员或者点击左下角菜单加入众筹获得100次对话!")
|
|
|
|
|
utils.ReplyMessage(ws, "![](/images/wx.png)")
|
2023-06-15 09:41:30 +08:00
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if userVo.ExpiredTime > 0 && userVo.ExpiredTime <= time.Now().Unix() {
|
2023-08-11 18:46:56 +08:00
|
|
|
|
utils.ReplyMessage(ws, "您的账号已经过期,请联系管理员!")
|
|
|
|
|
utils.ReplyMessage(ws, "![](/images/wx.png)")
|
2023-06-15 09:41:30 +08:00
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
var req = types.ApiRequest{
|
2023-09-04 06:43:15 +08:00
|
|
|
|
Model: session.Model.Value,
|
|
|
|
|
Stream: true,
|
|
|
|
|
}
|
|
|
|
|
switch session.Model.Platform {
|
|
|
|
|
case types.Azure:
|
|
|
|
|
req.Temperature = h.App.ChatConfig.Azure.Temperature
|
|
|
|
|
req.MaxTokens = h.App.ChatConfig.Azure.MaxTokens
|
|
|
|
|
break
|
2023-09-04 17:34:29 +08:00
|
|
|
|
case types.ChatGLM:
|
2023-09-04 06:43:15 +08:00
|
|
|
|
req.Temperature = h.App.ChatConfig.ChatGML.Temperature
|
|
|
|
|
req.MaxTokens = h.App.ChatConfig.ChatGML.MaxTokens
|
|
|
|
|
break
|
|
|
|
|
default:
|
|
|
|
|
req.Temperature = h.App.ChatConfig.OpenAI.Temperature
|
|
|
|
|
req.MaxTokens = h.App.ChatConfig.OpenAI.MaxTokens
|
|
|
|
|
var functions = make([]types.Function, 0)
|
|
|
|
|
for _, f := range types.InnerFunctions {
|
|
|
|
|
if !h.App.SysConfig.EnabledDraw && f.Name == types.FuncMidJourney {
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
functions = append(functions, f)
|
|
|
|
|
}
|
|
|
|
|
req.Functions = functions
|
2023-06-15 09:41:30 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 加载聊天上下文
|
2023-07-15 18:00:40 +08:00
|
|
|
|
var chatCtx []interface{}
|
2023-09-04 17:34:29 +08:00
|
|
|
|
if h.App.ChatConfig.EnableContext {
|
2023-06-19 07:06:59 +08:00
|
|
|
|
if h.App.ChatContexts.Has(session.ChatId) {
|
|
|
|
|
chatCtx = h.App.ChatContexts.Get(session.ChatId)
|
2023-06-15 09:41:30 +08:00
|
|
|
|
} else {
|
2023-08-01 17:58:03 +08:00
|
|
|
|
// calculate the tokens of current request, to prevent to exceeding the max tokens num
|
|
|
|
|
tokens := req.MaxTokens
|
|
|
|
|
for _, f := range types.InnerFunctions {
|
|
|
|
|
tks, _ := utils.CalcTokens(utils.JsonEncode(f), req.Model)
|
|
|
|
|
tokens += tks
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// loading the role context
|
2023-06-15 09:41:30 +08:00
|
|
|
|
var messages []types.Message
|
|
|
|
|
err := utils.JsonDecode(role.Context, &messages)
|
|
|
|
|
if err == nil {
|
2023-07-15 18:00:40 +08:00
|
|
|
|
for _, v := range messages {
|
2023-08-01 17:58:03 +08:00
|
|
|
|
tks, _ := utils.CalcTokens(v.Content, req.Model)
|
|
|
|
|
if tokens+tks >= types.ModelToTokens[req.Model] {
|
|
|
|
|
break
|
|
|
|
|
}
|
|
|
|
|
tokens += tks
|
2023-07-15 18:00:40 +08:00
|
|
|
|
chatCtx = append(chatCtx, v)
|
|
|
|
|
}
|
2023-06-15 09:41:30 +08:00
|
|
|
|
}
|
2023-07-15 18:37:25 +08:00
|
|
|
|
|
2023-08-01 17:58:03 +08:00
|
|
|
|
// loading recent chat history as chat context
|
2023-07-15 18:37:25 +08:00
|
|
|
|
if chatConfig.ContextDeep > 0 {
|
|
|
|
|
var historyMessages []model.HistoryMessage
|
2023-08-01 17:58:03 +08:00
|
|
|
|
res := h.db.Where("chat_id = ? and use_context = 1", session.ChatId).Limit(chatConfig.ContextDeep).Order("created_at desc").Find(&historyMessages)
|
2023-07-15 18:37:25 +08:00
|
|
|
|
if res.Error == nil {
|
|
|
|
|
for _, msg := range historyMessages {
|
2023-09-04 06:43:15 +08:00
|
|
|
|
if tokens+msg.Tokens >= types.ModelToTokens[session.Model.Value] {
|
2023-08-01 17:58:03 +08:00
|
|
|
|
break
|
|
|
|
|
}
|
|
|
|
|
tokens += msg.Tokens
|
2023-07-15 18:37:25 +08:00
|
|
|
|
ms := types.Message{Role: "user", Content: msg.Content}
|
|
|
|
|
if msg.Type == types.ReplyMsg {
|
|
|
|
|
ms.Role = "assistant"
|
|
|
|
|
}
|
|
|
|
|
chatCtx = append(chatCtx, ms)
|
2023-06-15 09:41:30 +08:00
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
2023-07-31 08:34:11 +08:00
|
|
|
|
logger.Debugf("聊天上下文:%+v", chatCtx)
|
2023-06-15 09:41:30 +08:00
|
|
|
|
}
|
2023-07-15 18:00:40 +08:00
|
|
|
|
reqMgs := make([]interface{}, 0)
|
|
|
|
|
for _, m := range chatCtx {
|
|
|
|
|
reqMgs = append(reqMgs, m)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
req.Messages = append(reqMgs, map[string]interface{}{
|
|
|
|
|
"role": "user",
|
|
|
|
|
"content": prompt,
|
2023-06-15 09:41:30 +08:00
|
|
|
|
})
|
|
|
|
|
|
2023-09-04 06:43:15 +08:00
|
|
|
|
switch session.Model.Platform {
|
|
|
|
|
case types.Azure:
|
|
|
|
|
return h.sendAzureMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
|
|
|
|
|
case types.OpenAI:
|
|
|
|
|
return h.sendOpenAiMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
|
2023-09-04 17:34:29 +08:00
|
|
|
|
case types.ChatGLM:
|
2023-09-04 06:43:15 +08:00
|
|
|
|
return h.sendChatGLMMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
|
2023-06-15 09:41:30 +08:00
|
|
|
|
}
|
2023-09-04 17:34:29 +08:00
|
|
|
|
return fmt.Errorf("not supported platform: %s", session.Model.Platform)
|
2023-06-15 09:41:30 +08:00
|
|
|
|
}
|
|
|
|
|
|
2023-09-04 06:43:15 +08:00
|
|
|
|
// Tokens 统计 token 数量
|
|
|
|
|
func (h *ChatHandler) Tokens(c *gin.Context) {
|
|
|
|
|
var data struct {
|
|
|
|
|
Text string `json:"text"`
|
|
|
|
|
Model string `json:"model"`
|
2023-06-15 09:41:30 +08:00
|
|
|
|
}
|
2023-09-04 06:43:15 +08:00
|
|
|
|
if err := c.ShouldBindJSON(&data); err != nil {
|
|
|
|
|
resp.ERROR(c, types.InvalidArgs)
|
|
|
|
|
return
|
2023-06-15 09:41:30 +08:00
|
|
|
|
}
|
|
|
|
|
|
2023-09-04 06:43:15 +08:00
|
|
|
|
tokens, err := utils.CalcTokens(data.Text, data.Model)
|
2023-06-15 09:41:30 +08:00
|
|
|
|
if err != nil {
|
|
|
|
|
resp.ERROR(c, err.Error())
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
resp.SUCCESS(c, tokens)
|
|
|
|
|
}
|
|
|
|
|
|
2023-07-15 18:00:40 +08:00
|
|
|
|
func getTotalTokens(req types.ApiRequest) int {
|
|
|
|
|
encode := utils.JsonEncode(req.Messages)
|
|
|
|
|
var items []map[string]interface{}
|
|
|
|
|
err := utils.JsonDecode(encode, &items)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return 0
|
|
|
|
|
}
|
|
|
|
|
tokens := 0
|
|
|
|
|
for _, item := range items {
|
|
|
|
|
content, ok := item["content"]
|
|
|
|
|
if ok && !utils.IsEmptyValue(content) {
|
|
|
|
|
t, err := utils.CalcTokens(utils.InterfaceToString(content), req.Model)
|
|
|
|
|
if err == nil {
|
|
|
|
|
tokens += t
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return tokens
|
|
|
|
|
}
|
|
|
|
|
|
2023-06-15 09:41:30 +08:00
|
|
|
|
// StopGenerate 停止生成
|
|
|
|
|
func (h *ChatHandler) StopGenerate(c *gin.Context) {
|
2023-06-16 15:32:11 +08:00
|
|
|
|
sessionId := c.Query("session_id")
|
2023-06-19 07:06:59 +08:00
|
|
|
|
if h.App.ReqCancelFunc.Has(sessionId) {
|
|
|
|
|
h.App.ReqCancelFunc.Get(sessionId)()
|
|
|
|
|
h.App.ReqCancelFunc.Delete(sessionId)
|
2023-06-15 09:41:30 +08:00
|
|
|
|
}
|
|
|
|
|
resp.SUCCESS(c, types.OkMsg)
|
|
|
|
|
}
|
2023-09-04 06:43:15 +08:00
|
|
|
|
|
|
|
|
|
// 发送请求到 OpenAI 服务器
|
|
|
|
|
// useOwnApiKey: 是否使用了用户自己的 API KEY
|
|
|
|
|
func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platform types.Platform, apiKey *string) (*http.Response, error) {
|
|
|
|
|
|
|
|
|
|
var apiURL string
|
|
|
|
|
switch platform {
|
|
|
|
|
case types.Azure:
|
|
|
|
|
md := strings.Replace(req.Model, ".", "", 1)
|
|
|
|
|
apiURL = strings.Replace(h.App.ChatConfig.Azure.ApiURL, "{model}", md, 1)
|
|
|
|
|
break
|
2023-09-04 17:34:29 +08:00
|
|
|
|
case types.ChatGLM:
|
2023-09-04 06:43:15 +08:00
|
|
|
|
apiURL = strings.Replace(h.App.ChatConfig.ChatGML.ApiURL, "{model}", req.Model, 1)
|
|
|
|
|
req.Prompt = req.Messages
|
|
|
|
|
req.Messages = nil
|
|
|
|
|
break
|
|
|
|
|
default:
|
|
|
|
|
apiURL = h.App.ChatConfig.OpenAI.ApiURL
|
|
|
|
|
}
|
|
|
|
|
// 创建 HttpClient 请求对象
|
|
|
|
|
var client *http.Client
|
|
|
|
|
requestBody, err := json.Marshal(req)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
request, err := http.NewRequest(http.MethodPost, apiURL, bytes.NewBuffer(requestBody))
|
|
|
|
|
if err != nil {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
request = request.WithContext(ctx)
|
|
|
|
|
request.Header.Set("Content-Type", "application/json")
|
|
|
|
|
proxyURL := h.App.Config.ProxyURL
|
|
|
|
|
if proxyURL != "" && platform == types.OpenAI { // 使用代理
|
|
|
|
|
proxy, _ := url.Parse(proxyURL)
|
|
|
|
|
client = &http.Client{
|
|
|
|
|
Transport: &http.Transport{
|
|
|
|
|
Proxy: http.ProxyURL(proxy),
|
|
|
|
|
},
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
client = http.DefaultClient
|
|
|
|
|
}
|
2023-09-04 17:34:29 +08:00
|
|
|
|
if *apiKey == "" {
|
|
|
|
|
var key model.ApiKey
|
|
|
|
|
res := h.db.Where("platform = ?", platform).Order("last_used_at ASC").First(&key)
|
|
|
|
|
if res.Error != nil {
|
|
|
|
|
return nil, errors.New("no available key, please import key")
|
|
|
|
|
}
|
|
|
|
|
// 更新 API KEY 的最后使用时间
|
|
|
|
|
h.db.Model(&key).UpdateColumn("last_used_at", time.Now().Unix())
|
|
|
|
|
*apiKey = key.Value
|
2023-09-04 06:43:15 +08:00
|
|
|
|
}
|
|
|
|
|
|
2023-09-04 17:34:29 +08:00
|
|
|
|
logger.Infof("Sending %s request, KEY: %s, PROXY: %s, Model: %s", platform, *apiKey, proxyURL, req.Model)
|
2023-09-04 06:43:15 +08:00
|
|
|
|
switch platform {
|
|
|
|
|
case types.Azure:
|
2023-09-04 17:34:29 +08:00
|
|
|
|
request.Header.Set("api-key", *apiKey)
|
2023-09-04 06:43:15 +08:00
|
|
|
|
break
|
2023-09-04 17:34:29 +08:00
|
|
|
|
case types.ChatGLM:
|
|
|
|
|
token, err := h.getChatGLMToken(*apiKey)
|
2023-09-04 06:43:15 +08:00
|
|
|
|
if err != nil {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
logger.Info(token)
|
|
|
|
|
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
|
|
|
|
|
break
|
|
|
|
|
default:
|
2023-09-04 17:34:29 +08:00
|
|
|
|
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiKey))
|
2023-09-04 06:43:15 +08:00
|
|
|
|
}
|
|
|
|
|
return client.Do(request)
|
|
|
|
|
}
|