mirror of
https://gitee.com/blackfox/geekai.git
synced 2024-12-04 05:09:12 +08:00
440 lines
11 KiB
Go
440 lines
11 KiB
Go
package handler
|
||
|
||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||
// * Use of this source code is governed by a Apache-2.0 license
|
||
// * that can be found in the LICENSE file.
|
||
// * @Author yangjian102621@163.com
|
||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||
|
||
import (
|
||
"fmt"
|
||
"geekai/core"
|
||
"geekai/core/types"
|
||
"geekai/service"
|
||
"geekai/store/model"
|
||
"geekai/store/vo"
|
||
"geekai/utils"
|
||
"geekai/utils/resp"
|
||
"strings"
|
||
"time"
|
||
|
||
"github.com/go-redis/redis/v8"
|
||
"github.com/golang-jwt/jwt/v5"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
"github.com/lionsoul2014/ip2region/binding/golang/xdb"
|
||
"gorm.io/gorm"
|
||
)
|
||
|
||
type UserHandler struct {
|
||
BaseHandler
|
||
searcher *xdb.Searcher
|
||
redis *redis.Client
|
||
licenseService *service.LicenseService
|
||
}
|
||
|
||
func NewUserHandler(
|
||
app *core.AppServer,
|
||
db *gorm.DB,
|
||
searcher *xdb.Searcher,
|
||
client *redis.Client,
|
||
licenseService *service.LicenseService) *UserHandler {
|
||
return &UserHandler{
|
||
BaseHandler: BaseHandler{DB: db, App: app},
|
||
searcher: searcher,
|
||
redis: client,
|
||
licenseService: licenseService,
|
||
}
|
||
}
|
||
|
||
// Register user register
|
||
func (h *UserHandler) Register(c *gin.Context) {
|
||
// parameters process
|
||
var data struct {
|
||
RegWay string `json:"reg_way"`
|
||
Username string `json:"username"`
|
||
Password string `json:"password"`
|
||
Code string `json:"code"`
|
||
InviteCode string `json:"invite_code"`
|
||
}
|
||
if err := c.ShouldBindJSON(&data); err != nil {
|
||
resp.ERROR(c, types.InvalidArgs)
|
||
return
|
||
}
|
||
data.Password = strings.TrimSpace(data.Password)
|
||
if len(data.Password) < 8 {
|
||
resp.ERROR(c, "密码长度不能少于8个字符")
|
||
return
|
||
}
|
||
|
||
// 检测最大注册人数
|
||
var totalUser int64
|
||
h.DB.Model(&model.User{}).Count(&totalUser)
|
||
if h.licenseService.GetLicense().UserNum > 0 && int(totalUser) >= h.licenseService.GetLicense().UserNum {
|
||
resp.ERROR(c, "当前注册用户数已达上限,请请升级 License")
|
||
return
|
||
}
|
||
|
||
// 检查验证码
|
||
var key string
|
||
if data.RegWay == "email" || data.RegWay == "mobile" {
|
||
key = CodeStorePrefix + data.Username
|
||
code, err := h.redis.Get(c, key).Result()
|
||
if err != nil || code != data.Code {
|
||
resp.ERROR(c, "验证码错误")
|
||
return
|
||
}
|
||
}
|
||
|
||
// 验证邀请码
|
||
inviteCode := model.InviteCode{}
|
||
if data.InviteCode != "" {
|
||
res := h.DB.Where("code = ?", data.InviteCode).First(&inviteCode)
|
||
if res.Error != nil {
|
||
resp.ERROR(c, "无效的邀请码")
|
||
return
|
||
}
|
||
}
|
||
|
||
// check if the username is exists
|
||
var item model.User
|
||
res := h.DB.Where("username = ?", data.Username).First(&item)
|
||
if item.Id > 0 {
|
||
resp.ERROR(c, "该用户名已经被注册")
|
||
return
|
||
}
|
||
|
||
salt := utils.RandString(8)
|
||
user := model.User{
|
||
Username: data.Username,
|
||
Password: utils.GenPassword(data.Password, salt),
|
||
Nickname: fmt.Sprintf("极客学长@%d", utils.RandomNumber(6)),
|
||
Avatar: "/images/avatar/user.png",
|
||
Salt: salt,
|
||
Status: true,
|
||
ChatRoles: utils.JsonEncode([]string{"gpt"}), // 默认只订阅通用助手角色
|
||
ChatModels: utils.JsonEncode(h.App.SysConfig.DefaultModels), // 默认开通的模型
|
||
Power: h.App.SysConfig.InitPower,
|
||
}
|
||
|
||
res = h.DB.Create(&user)
|
||
if res.Error != nil {
|
||
resp.ERROR(c, "保存数据失败")
|
||
logger.Error(res.Error)
|
||
return
|
||
}
|
||
|
||
// 记录邀请关系
|
||
if data.InviteCode != "" {
|
||
// 增加邀请数量
|
||
h.DB.Model(&model.InviteCode{}).Where("code = ?", data.InviteCode).UpdateColumn("reg_num", gorm.Expr("reg_num + ?", 1))
|
||
if h.App.SysConfig.InvitePower > 0 {
|
||
h.DB.Model(&model.User{}).Where("id = ?", inviteCode.UserId).UpdateColumn("power", gorm.Expr("power + ?", h.App.SysConfig.InvitePower))
|
||
// 记录邀请算力充值日志
|
||
var inviter model.User
|
||
h.DB.Where("id", inviteCode.UserId).First(&inviter)
|
||
h.DB.Create(&model.PowerLog{
|
||
UserId: inviter.Id,
|
||
Username: inviter.Username,
|
||
Type: types.PowerInvite,
|
||
Amount: h.App.SysConfig.InvitePower,
|
||
Balance: inviter.Power,
|
||
Mark: types.PowerAdd,
|
||
Model: "",
|
||
Remark: fmt.Sprintf("邀请用户注册奖励,金额:%d,邀请码:%s,新用户:%s", h.App.SysConfig.InvitePower, inviteCode.Code, user.Username),
|
||
CreatedAt: time.Now(),
|
||
})
|
||
}
|
||
|
||
// 添加邀请记录
|
||
h.DB.Create(&model.InviteLog{
|
||
InviterId: inviteCode.UserId,
|
||
UserId: user.Id,
|
||
Username: user.Username,
|
||
InviteCode: inviteCode.Code,
|
||
Remark: fmt.Sprintf("奖励 %d 算力", h.App.SysConfig.InvitePower),
|
||
})
|
||
}
|
||
|
||
_ = h.redis.Del(c, key) // 注册成功,删除短信验证码
|
||
|
||
// 自动登录创建 token
|
||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||
"user_id": user.Id,
|
||
"expired": time.Now().Add(time.Second * time.Duration(h.App.Config.Session.MaxAge)).Unix(),
|
||
})
|
||
tokenString, err := token.SignedString([]byte(h.App.Config.Session.SecretKey))
|
||
if err != nil {
|
||
resp.ERROR(c, "Failed to generate token, "+err.Error())
|
||
return
|
||
}
|
||
// 保存到 redis
|
||
key = fmt.Sprintf("users/%d", user.Id)
|
||
if _, err := h.redis.Set(c, key, tokenString, 0).Result(); err != nil {
|
||
resp.ERROR(c, "error with save token: "+err.Error())
|
||
return
|
||
}
|
||
resp.SUCCESS(c, tokenString)
|
||
}
|
||
|
||
// Login 用户登录
|
||
func (h *UserHandler) Login(c *gin.Context) {
|
||
var data struct {
|
||
Username string `json:"username"`
|
||
Password string `json:"password"`
|
||
}
|
||
if err := c.ShouldBindJSON(&data); err != nil {
|
||
resp.ERROR(c, types.InvalidArgs)
|
||
return
|
||
}
|
||
var user model.User
|
||
res := h.DB.Where("username = ?", data.Username).First(&user)
|
||
if res.Error != nil {
|
||
resp.ERROR(c, "用户名不存在")
|
||
return
|
||
}
|
||
|
||
password := utils.GenPassword(data.Password, user.Salt)
|
||
if password != user.Password {
|
||
resp.ERROR(c, "用户名或密码错误")
|
||
return
|
||
}
|
||
|
||
if user.Status == false {
|
||
resp.ERROR(c, "该用户已被禁止登录,请联系管理员")
|
||
return
|
||
}
|
||
|
||
// 更新最后登录时间和IP
|
||
user.LastLoginIp = c.ClientIP()
|
||
user.LastLoginAt = time.Now().Unix()
|
||
h.DB.Model(&user).Updates(user)
|
||
|
||
h.DB.Create(&model.UserLoginLog{
|
||
UserId: user.Id,
|
||
Username: user.Username,
|
||
LoginIp: c.ClientIP(),
|
||
LoginAddress: utils.Ip2Region(h.searcher, c.ClientIP()),
|
||
})
|
||
|
||
// 创建 token
|
||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||
"user_id": user.Id,
|
||
"expired": time.Now().Add(time.Second * time.Duration(h.App.Config.Session.MaxAge)).Unix(),
|
||
})
|
||
tokenString, err := token.SignedString([]byte(h.App.Config.Session.SecretKey))
|
||
if err != nil {
|
||
resp.ERROR(c, "Failed to generate token, "+err.Error())
|
||
return
|
||
}
|
||
// 保存到 redis
|
||
key := fmt.Sprintf("users/%d", user.Id)
|
||
if _, err := h.redis.Set(c, key, tokenString, 0).Result(); err != nil {
|
||
resp.ERROR(c, "error with save token: "+err.Error())
|
||
return
|
||
}
|
||
resp.SUCCESS(c, tokenString)
|
||
}
|
||
|
||
// Logout 注 销
|
||
func (h *UserHandler) Logout(c *gin.Context) {
|
||
key := h.GetUserKey(c)
|
||
if _, err := h.redis.Del(c, key).Result(); err != nil {
|
||
logger.Error("error with delete session: ", err)
|
||
}
|
||
resp.SUCCESS(c)
|
||
}
|
||
|
||
// Session 获取/验证会话
|
||
func (h *UserHandler) Session(c *gin.Context) {
|
||
user, err := h.GetLoginUser(c)
|
||
if err == nil {
|
||
var userVo vo.User
|
||
err := utils.CopyObject(user, &userVo)
|
||
if err != nil {
|
||
resp.ERROR(c)
|
||
}
|
||
userVo.Id = user.Id
|
||
resp.SUCCESS(c, userVo)
|
||
} else {
|
||
resp.NotAuth(c)
|
||
}
|
||
|
||
}
|
||
|
||
// userProfile is the profile payload exchanged with the client: Profile
// returns it and ProfileUpdate binds the request body into it (reading only
// Avatar and Nickname).
type userProfile struct {
	// Identity.
	Id       uint   `json:"id"`
	Nickname string `json:"nickname"`
	Username string `json:"username"`
	Avatar   string `json:"avatar"`

	// Account state.
	Power       int   `json:"power"`
	ExpiredTime int64 `json:"expired_time"`
	Vip         bool  `json:"vip"`
}
|
||
|
||
func (h *UserHandler) Profile(c *gin.Context) {
|
||
user, err := h.GetLoginUser(c)
|
||
if err != nil {
|
||
resp.NotAuth(c)
|
||
return
|
||
}
|
||
|
||
h.DB.First(&user, user.Id)
|
||
var profile userProfile
|
||
err = utils.CopyObject(user, &profile)
|
||
if err != nil {
|
||
logger.Error("对象拷贝失败:", err.Error())
|
||
resp.ERROR(c, "获取用户信息失败")
|
||
return
|
||
}
|
||
|
||
profile.Id = user.Id
|
||
resp.SUCCESS(c, profile)
|
||
}
|
||
|
||
func (h *UserHandler) ProfileUpdate(c *gin.Context) {
|
||
var data userProfile
|
||
if err := c.ShouldBindJSON(&data); err != nil {
|
||
resp.ERROR(c, types.InvalidArgs)
|
||
return
|
||
}
|
||
|
||
user, err := h.GetLoginUser(c)
|
||
if err != nil {
|
||
resp.NotAuth(c)
|
||
return
|
||
}
|
||
h.DB.First(&user, user.Id)
|
||
user.Avatar = data.Avatar
|
||
user.Nickname = data.Nickname
|
||
res := h.DB.Updates(&user)
|
||
if res.Error != nil {
|
||
resp.ERROR(c, "更新用户信息失败")
|
||
return
|
||
}
|
||
|
||
resp.SUCCESS(c)
|
||
}
|
||
|
||
// UpdatePass 更新密码
|
||
func (h *UserHandler) UpdatePass(c *gin.Context) {
|
||
var data struct {
|
||
OldPass string `json:"old_pass"`
|
||
Password string `json:"password"`
|
||
}
|
||
if err := c.ShouldBindJSON(&data); err != nil {
|
||
resp.ERROR(c, types.InvalidArgs)
|
||
return
|
||
}
|
||
|
||
if len(data.Password) < 8 {
|
||
resp.ERROR(c, "密码长度不能少于8个字符")
|
||
return
|
||
}
|
||
|
||
user, err := h.GetLoginUser(c)
|
||
if err != nil {
|
||
resp.NotAuth(c)
|
||
return
|
||
}
|
||
|
||
password := utils.GenPassword(data.OldPass, user.Salt)
|
||
logger.Debugf(user.Salt, ",", user.Password, ",", password, ",", data.OldPass)
|
||
if password != user.Password {
|
||
resp.ERROR(c, "原密码错误")
|
||
return
|
||
}
|
||
|
||
newPass := utils.GenPassword(data.Password, user.Salt)
|
||
res := h.DB.Model(&user).UpdateColumn("password", newPass)
|
||
if res.Error != nil {
|
||
logger.Error("更新数据库失败: ", res.Error)
|
||
resp.ERROR(c, "更新数据库失败")
|
||
return
|
||
}
|
||
|
||
resp.SUCCESS(c)
|
||
}
|
||
|
||
// ResetPass 重置密码
|
||
func (h *UserHandler) ResetPass(c *gin.Context) {
|
||
var data struct {
|
||
Username string `json:"username"`
|
||
Code string `json:"code"` // 验证码
|
||
Password string `json:"password"` // 新密码
|
||
}
|
||
if err := c.ShouldBindJSON(&data); err != nil {
|
||
resp.ERROR(c, types.InvalidArgs)
|
||
return
|
||
}
|
||
|
||
var user model.User
|
||
res := h.DB.Where("username", data.Username).First(&user)
|
||
if res.Error != nil {
|
||
resp.ERROR(c, "用户不存在!")
|
||
return
|
||
}
|
||
|
||
// 检查验证码
|
||
key := CodeStorePrefix + data.Username
|
||
code, err := h.redis.Get(c, key).Result()
|
||
if err != nil || code != data.Code {
|
||
resp.ERROR(c, "短信验证码错误")
|
||
return
|
||
}
|
||
|
||
password := utils.GenPassword(data.Password, user.Salt)
|
||
user.Password = password
|
||
res = h.DB.Updates(&user)
|
||
if res.Error != nil {
|
||
resp.ERROR(c)
|
||
} else {
|
||
h.redis.Del(c, key)
|
||
resp.SUCCESS(c)
|
||
}
|
||
}
|
||
|
||
// BindUsername 重置账号
|
||
func (h *UserHandler) BindUsername(c *gin.Context) {
|
||
var data struct {
|
||
Username string `json:"username"`
|
||
Code string `json:"code"`
|
||
}
|
||
if err := c.ShouldBindJSON(&data); err != nil {
|
||
resp.ERROR(c, types.InvalidArgs)
|
||
return
|
||
}
|
||
|
||
// 检查验证码
|
||
key := CodeStorePrefix + data.Username
|
||
code, err := h.redis.Get(c, key).Result()
|
||
if err != nil || code != data.Code {
|
||
resp.ERROR(c, "验证码错误")
|
||
return
|
||
}
|
||
|
||
// 检查手机号是否被其他账号绑定
|
||
var item model.User
|
||
res := h.DB.Where("username = ?", data.Username).First(&item)
|
||
if res.Error == nil {
|
||
resp.ERROR(c, "该账号已经被其他账号绑定")
|
||
return
|
||
}
|
||
|
||
user, err := h.GetLoginUser(c)
|
||
if err != nil {
|
||
resp.NotAuth(c)
|
||
return
|
||
}
|
||
|
||
res = h.DB.Model(&user).UpdateColumn("username", data.Username)
|
||
if res.Error != nil {
|
||
resp.ERROR(c, "更新数据库失败")
|
||
return
|
||
}
|
||
|
||
_ = h.redis.Del(c, key) // 删除短信验证码
|
||
resp.SUCCESS(c)
|
||
}
|