优化聊天会话管理,支持 websocket 断开重连之后能继续连接会话上下文

This commit is contained in:
RockYang 2023-03-24 14:24:49 +08:00
parent 2067aa3f83
commit bb019f3552
9 changed files with 244 additions and 317 deletions

View File

@ -16,6 +16,7 @@ var logger = logger2.GetLogger()
//go:embed dist //go:embed dist
var webRoot embed.FS var webRoot embed.FS
var configFile string var configFile string
var debugMode bool
func main() { func main() {
defer func() { defer func() {
@ -49,12 +50,13 @@ func main() {
if err != nil { if err != nil {
panic(err) panic(err)
} }
s.Run(webRoot, "dist") s.Run(webRoot, "dist", debugMode)
} }
func init() { func init() {
flag.StringVar(&configFile, "config", "", "Config file path (default: ~/.config/chat-gpt/config.toml)") flag.StringVar(&configFile, "config", "", "Config file path (default: ~/.config/chat-gpt/config.toml)")
flag.BoolVar(&debugMode, "debug", true, "Enable debug mode (default: true, recommend to set false in production env)")
flag.Usage = usage flag.Usage = usage
flag.Parse() flag.Parse()
} }
@ -67,7 +69,7 @@ OPTIONS:
`, os.Args[0]) `, os.Args[0])
flagSet := flag.CommandLine flagSet := flag.CommandLine
order := []string{"config"} order := []string{"config", "debug"}
for _, name := range order { for _, name := range order {
f := flagSet.Lookup(name) f := flagSet.Lookup(name)
fmt.Printf(" --%s => %s\n", f.Name, f.Usage) fmt.Printf(" --%s => %s\n", f.Name, f.Usage)

View File

@ -10,6 +10,7 @@ import (
"io" "io"
"math/rand" "math/rand"
"net/http" "net/http"
"net/url"
"openai/types" "openai/types"
"strings" "strings"
"time" "time"
@ -68,6 +69,19 @@ func (s *Server) sendMessage(userId string, text string, ws Client) error {
return err return err
} }
// 创建 HttpClient 请求对象
var client *http.Client
if s.Config.ProxyURL == "" {
client = &http.Client{}
} else { // 使用代理
uri := url.URL{}
proxy, _ := uri.Parse(s.Config.ProxyURL)
client = &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyURL(proxy),
},
}
}
request, err := http.NewRequest(http.MethodPost, s.Config.Chat.ApiURL, bytes.NewBuffer(requestBody)) request, err := http.NewRequest(http.MethodPost, s.Config.Chat.ApiURL, bytes.NewBuffer(requestBody))
if err != nil { if err != nil {
return err return err
@ -86,7 +100,7 @@ func (s *Server) sendMessage(userId string, text string, ws Client) error {
} }
logger.Infof("Use API KEY: %s", apiKey) logger.Infof("Use API KEY: %s", apiKey)
request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey)) request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey))
response, err = s.Client.Do(request) response, err = client.Do(request)
if err == nil { if err == nil {
break break
} else { } else {

View File

@ -11,12 +11,6 @@ import (
// ConfigSetHandle set configs // ConfigSetHandle set configs
func (s *Server) ConfigSetHandle(c *gin.Context) { func (s *Server) ConfigSetHandle(c *gin.Context) {
token := c.Query("token")
if token != "RockYang" {
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: types.ErrorMsg})
return
}
var data map[string]string var data map[string]string
err := json.NewDecoder(c.Request.Body).Decode(&data) err := json.NewDecoder(c.Request.Body).Decode(&data)
if err != nil { if err != nil {
@ -24,10 +18,6 @@ func (s *Server) ConfigSetHandle(c *gin.Context) {
c.JSON(http.StatusBadRequest, nil) c.JSON(http.StatusBadRequest, nil)
return return
} }
// API key
if key, ok := data["api_key"]; ok && len(key) > 20 {
s.Config.Chat.ApiKeys = append(s.Config.Chat.ApiKeys, key)
}
// proxy URL // proxy URL
if proxy, ok := data["proxy"]; ok { if proxy, ok := data["proxy"]; ok {
@ -91,12 +81,6 @@ func (s *Server) ConfigSetHandle(c *gin.Context) {
s.Config.EnableAuth = v s.Config.EnableAuth = v
} }
if token, ok := data["token"]; ok {
if !utils.ContainsItem(s.Config.Tokens, token) {
s.Config.Tokens = append(s.Config.Tokens, token)
}
}
// 保存配置文件 // 保存配置文件
err = types.SaveConfig(s.Config, s.ConfigPath) err = types.SaveConfig(s.Config, s.ConfigPath)
if err != nil { if err != nil {
@ -106,3 +90,62 @@ func (s *Server) ConfigSetHandle(c *gin.Context) {
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg}) c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg})
} }
func (s *Server) AddToken(c *gin.Context) {
var data map[string]string
err := json.NewDecoder(c.Request.Body).Decode(&data)
if err != nil {
logger.Errorf("Error decode json data: %s", err.Error())
c.JSON(http.StatusBadRequest, nil)
return
}
if token, ok := data["token"]; ok {
if !utils.ContainsItem(s.Config.Tokens, token) {
s.Config.Tokens = append(s.Config.Tokens, token)
}
}
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: s.Config.Tokens})
}
func (s *Server) RemoveToken(c *gin.Context) {
var data map[string]string
err := json.NewDecoder(c.Request.Body).Decode(&data)
if err != nil {
logger.Errorf("Error decode json data: %s", err.Error())
c.JSON(http.StatusBadRequest, nil)
return
}
if token, ok := data["token"]; ok {
for i, v := range s.Config.Tokens {
if v == token {
s.Config.Tokens = append(s.Config.Tokens[:i], s.Config.Tokens[i+1:]...)
break
}
}
}
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: s.Config.Tokens})
}
func (s *Server) AddApiKey(c *gin.Context) {
var data map[string]string
err := json.NewDecoder(c.Request.Body).Decode(&data)
if err != nil {
logger.Errorf("Error decode json data: %s", err.Error())
c.JSON(http.StatusBadRequest, nil)
return
}
if key, ok := data["api_key"]; ok && len(key) > 20 {
s.Config.Chat.ApiKeys = append(s.Config.Chat.ApiKeys, key)
}
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: s.Config.Chat.ApiKeys})
}
func (s *Server) ListApiKeys(c *gin.Context) {
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: s.Config.Chat.ApiKeys})
}

View File

@ -9,7 +9,6 @@ import (
"io/fs" "io/fs"
"log" "log"
"net/http" "net/http"
"net/url"
logger2 "openai/logger" logger2 "openai/logger"
"openai/types" "openai/types"
"openai/utils" "openai/utils"
@ -34,7 +33,6 @@ func (s StaticFile) Open(name string) (fs.File, error) {
type Server struct { type Server struct {
Config *types.Config Config *types.Config
ConfigPath string ConfigPath string
Client *http.Client
History map[string][]types.Message History map[string][]types.Message
// 保存 Websocket 会话 Token, 每个 Token 只能连接一次 // 保存 Websocket 会话 Token, 每个 Token 只能连接一次
@ -50,16 +48,8 @@ func NewServer(configPath string) (*Server, error) {
return nil, err return nil, err
} }
uri := url.URL{}
proxy, _ := uri.Parse(config.ProxyURL)
client := &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyURL(proxy),
},
}
return &Server{ return &Server{
Config: config, Config: config,
Client: client,
ConfigPath: configPath, ConfigPath: configPath,
History: make(map[string][]types.Message, 16), History: make(map[string][]types.Message, 16),
WsSession: make(map[string]string), WsSession: make(map[string]string),
@ -67,11 +57,13 @@ func NewServer(configPath string) (*Server, error) {
}, nil }, nil
} }
func (s *Server) Run(webRoot embed.FS, path string) { func (s *Server) Run(webRoot embed.FS, path string, debug bool) {
gin.SetMode(gin.ReleaseMode) gin.SetMode(gin.ReleaseMode)
engine := gin.Default() engine := gin.Default()
if debug {
engine.Use(corsMiddleware())
}
engine.Use(sessionMiddleware(s.Config)) engine.Use(sessionMiddleware(s.Config))
engine.Use(corsMiddleware())
engine.Use(AuthorizeMiddleware(s)) engine.Use(AuthorizeMiddleware(s))
engine.GET("/hello", Hello) engine.GET("/hello", Hello)
@ -79,6 +71,10 @@ func (s *Server) Run(webRoot embed.FS, path string) {
engine.POST("/api/login", s.LoginHandle) engine.POST("/api/login", s.LoginHandle)
engine.Any("/api/chat", s.ChatHandle) engine.Any("/api/chat", s.ChatHandle)
engine.POST("/api/config/set", s.ConfigSetHandle) engine.POST("/api/config/set", s.ConfigSetHandle)
engine.POST("api/config/token/add", s.AddToken)
engine.POST("api/config/token/remove", s.RemoveToken)
engine.POST("api/config/apikey/add", s.AddApiKey)
engine.POST("api/config/apikey/list", s.ListApiKeys)
engine.NoRoute(func(c *gin.Context) { engine.NoRoute(func(c *gin.Context) {
if c.Request.URL.Path == "/favicon.ico" { if c.Request.URL.Path == "/favicon.ico" {
@ -123,7 +119,7 @@ func corsMiddleware() gin.HandlerFunc {
origin := c.Request.Header.Get("Origin") origin := c.Request.Header.Get("Origin")
if origin != "" { if origin != "" {
// 设置允许的请求源 // 设置允许的请求源
c.Writer.Header().Set("Access-Control-Allow-Origin", origin) c.Header("Access-Control-Allow-Origin", origin)
c.Header("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, UPDATE") c.Header("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, UPDATE")
//允许跨域设置可以返回其他子段,可以自定义字段 //允许跨域设置可以返回其他子段,可以自定义字段
c.Header("Access-Control-Allow-Headers", "Authorization, Content-Length, Content-Type, ChatGPT-Token") c.Header("Access-Control-Allow-Headers", "Authorization, Content-Length, Content-Type, ChatGPT-Token")
@ -154,18 +150,28 @@ func AuthorizeMiddleware(s *Server) gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
if !s.Config.EnableAuth || if !s.Config.EnableAuth ||
c.Request.URL.Path == "/api/login" || c.Request.URL.Path == "/api/login" ||
c.Request.URL.Path == "/api/config/set" ||
!strings.HasPrefix(c.Request.URL.Path, "/api") { !strings.HasPrefix(c.Request.URL.Path, "/api") {
c.Next() c.Next()
return return
} }
if strings.HasPrefix(c.Request.URL.Path, "/api/config") {
accessKey := c.Query("access_key")
if accessKey != "RockYang" {
c.Abort()
c.JSON(http.StatusOK, types.BizVo{Code: types.NotAuthorized, Message: "No Permissions"})
} else {
c.Next()
}
return
}
// WebSocket 连接请求验证 // WebSocket 连接请求验证
if c.Request.URL.Path == "/api/chat" { if c.Request.URL.Path == "/api/chat" {
tokenName := c.Query("token") tokenName := c.Query("token")
if addr, ok := s.WsSession[tokenName]; ok && addr == c.ClientIP() { if addr, ok := s.WsSession[tokenName]; ok && addr == c.ClientIP() {
// 每个令牌只能连接一次 // 每个令牌只能连接一次
delete(s.WsSession, tokenName) //delete(s.WsSession, tokenName)
c.Next() c.Next()
} else { } else {
c.Abort() c.Abort()
@ -190,7 +196,16 @@ func AuthorizeMiddleware(s *Server) gin.HandlerFunc {
} }
func (s *Server) GetSessionHandle(c *gin.Context) { func (s *Server) GetSessionHandle(c *gin.Context) {
c.JSON(http.StatusOK, types.BizVo{Code: types.Success}) tokenName := c.GetHeader(types.TokenName)
if addr, ok := s.WsSession[tokenName]; ok && addr == c.ClientIP() {
c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Data: addr})
} else {
c.JSON(http.StatusOK, types.BizVo{
Code: types.NotAuthorized,
Message: "Not Authorized",
})
}
} }
func (s *Server) LoginHandle(c *gin.Context) { func (s *Server) LoginHandle(c *gin.Context) {

View File

@ -52,7 +52,7 @@ func NewDefaultConfig() *Config {
MaxAge: 86400, MaxAge: 86400,
Secure: true, Secure: true,
HttpOnly: false, HttpOnly: false,
SameSite: http.SameSiteNoneMode, SameSite: http.SameSiteLaxMode,
}, },
Chat: Chat{ Chat: Chat{
ApiURL: "https://api.openai.com/v1/chat/completions", ApiURL: "https://api.openai.com/v1/chat/completions",

View File

@ -1,2 +1,2 @@
VUE_APP_API_HOST=http://chat.r9it.com:6789 VUE_APP_API_HOST=https://ai.r9it.com
VUE_APP_WS_HOST=ws://chat.r9it.com:6789 VUE_APP_WS_HOST=wss://ai.r9it.com

341
web/package-lock.json generated

File diff suppressed because it is too large Load Diff

View File

@ -2,15 +2,14 @@
/** /**
* storage handler * storage handler
*/ */
import Storage from 'good-storage'
const SessionIdKey = 'ChatGPT_SESSION_ID'; const SessionIdKey = 'ChatGPT_SESSION_ID';
export const Global = {} export const Global = {}
export function getSessionId() { export function getSessionId() {
return Storage.get(SessionIdKey) return sessionStorage.getItem(SessionIdKey)
} }
export function setSessionId(value) { export function setSessionId(value) {
Storage.set(SessionIdKey, value) sessionStorage.setItem(SessionIdKey, value)
} }

View File

@ -114,8 +114,6 @@ export default defineComponent({
this.chatBoxHeight = window.innerHeight - this.toolBoxHeight; this.chatBoxHeight = window.innerHeight - this.toolBoxHeight;
}) })
this.checkSession();
// for (let i = 0; i < 10; i++) { // for (let i = 0; i < 10; i++) {
// this.chatData.push({ // this.chatData.push({
// type: "prompt", // type: "prompt",
@ -175,44 +173,11 @@ export default defineComponent({
this.chatBoxHeight = window.innerHeight - this.toolBoxHeight; this.chatBoxHeight = window.innerHeight - this.toolBoxHeight;
}); });
this.connect();
}, },
methods: { methods: {
//
checkSession: function () {
httpPost("/api/session/get").then(() => {
if (this.socket == null) {
this.connect();
}
//
//setTimeout(() => this.checkSession(), 5000);
}).catch((res) => {
if (res.code === 400) {
this.showLoginDialog = true;
} else {
this.connectingMessageBox = ElMessageBox.confirm(
'^_^ 会话发生异常,您已经从服务器断开连接!',
'注意:',
{
confirmButtonText: '重连会话',
cancelButtonText: '不聊了',
type: 'warning',
showClose: false,
closeOnClickModal: false
}
).then(() => {
this.connect();
}).catch(() => {
ElMessage({
type: 'info',
message: '您关闭了会话',
})
})
}
})
},
connect: function () { connect: function () {
// WebSocket // WebSocket
const token = getSessionId(); const token = getSessionId();
@ -264,8 +229,32 @@ export default defineComponent({
}); });
socket.addEventListener('close', () => { socket.addEventListener('close', () => {
// //
this.checkSession(); httpPost("/api/session/get").then(() => {
this.connectingMessageBox = ElMessageBox.confirm(
'^_^ 会话发生异常,您已经从服务器断开连接!',
'注意:',
{
confirmButtonText: '重连会话',
cancelButtonText: '不聊了',
type: 'warning',
showClose: false,
closeOnClickModal: false
}
).then(() => {
this.connect();
}).catch(() => {
ElMessage({
type: 'info',
message: '您关闭了会话',
})
})
}).catch((res) => {
if (res.code === 400) {
this.showLoginDialog = true;
}
})
}); });
this.socket = socket; this.socket = socket;