mirror of
https://gitee.com/blackfox/geekai.git
synced 2024-12-04 05:09:12 +08:00
161 lines
4.3 KiB
Go
161 lines
4.3 KiB
Go
package mj
|
|
|
|
import (
	"chatplus/core/types"
	"chatplus/store"
	"chatplus/store/model"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"gorm.io/gorm"
)
|
|
|
|
// Service MJ 绘画服务
|
|
type Service struct {
|
|
name string // service name
|
|
client *Client // MJ client
|
|
taskQueue *store.RedisQueue
|
|
notifyQueue *store.RedisQueue
|
|
db *gorm.DB
|
|
maxHandleTaskNum int32 // max task number current service can handle
|
|
handledTaskNum int32 // already handled task number
|
|
taskStartTimes map[int]time.Time // task start time, to check if the task is timeout
|
|
taskTimeout int64
|
|
}
|
|
|
|
func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, maxTaskNum int32, timeout int64, db *gorm.DB, client *Client) *Service {
|
|
return &Service{
|
|
name: name,
|
|
db: db,
|
|
taskQueue: taskQueue,
|
|
notifyQueue: notifyQueue,
|
|
client: client,
|
|
taskTimeout: timeout,
|
|
maxHandleTaskNum: maxTaskNum,
|
|
taskStartTimes: make(map[int]time.Time, 0),
|
|
}
|
|
}
|
|
|
|
func (s *Service) Run() {
|
|
logger.Infof("Starting MidJourney job consumer for %s", s.name)
|
|
for {
|
|
s.checkTasks()
|
|
if !s.canHandleTask() {
|
|
// current service is full, can not handle more task
|
|
// waiting for running task finish
|
|
time.Sleep(time.Second * 3)
|
|
continue
|
|
}
|
|
|
|
var task types.MjTask
|
|
err := s.taskQueue.LPop(&task)
|
|
if err != nil {
|
|
logger.Errorf("taking task with error: %v", err)
|
|
continue
|
|
}
|
|
|
|
// if it's reference message, check if it's this channel's message
|
|
if task.ChannelId != "" && task.ChannelId != s.client.Config.ChanelId {
|
|
s.taskQueue.RPush(task)
|
|
time.Sleep(time.Second)
|
|
continue
|
|
}
|
|
|
|
logger.Infof("%s handle a new MidJourney task: %+v", s.name, task)
|
|
switch task.Type {
|
|
case types.TaskImage:
|
|
err = s.client.Imagine(task.Prompt)
|
|
break
|
|
case types.TaskUpscale:
|
|
err = s.client.Upscale(task.Index, task.MessageId, task.MessageHash)
|
|
|
|
break
|
|
case types.TaskVariation:
|
|
err = s.client.Variation(task.Index, task.MessageId, task.MessageHash)
|
|
}
|
|
|
|
if err != nil {
|
|
logger.Error("绘画任务执行失败:", err)
|
|
// update the task progress
|
|
s.db.Model(&model.MidJourneyJob{Id: uint(task.Id)}).UpdateColumn("progress", -1)
|
|
s.notifyQueue.RPush(task.UserId)
|
|
// restore img_call quota
|
|
s.db.Model(&model.User{}).Where("id = ?", task.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls + ?", 1))
|
|
continue
|
|
}
|
|
|
|
// lock the task until the execute timeout
|
|
s.taskStartTimes[task.Id] = time.Now()
|
|
atomic.AddInt32(&s.handledTaskNum, 1)
|
|
|
|
}
|
|
}
|
|
|
|
// check if current service instance can handle more task
|
|
func (s *Service) canHandleTask() bool {
|
|
handledNum := atomic.LoadInt32(&s.handledTaskNum)
|
|
return handledNum < s.maxHandleTaskNum
|
|
}
|
|
|
|
// remove the expired tasks
|
|
func (s *Service) checkTasks() {
|
|
for k, t := range s.taskStartTimes {
|
|
if time.Now().Unix()-t.Unix() > s.taskTimeout {
|
|
delete(s.taskStartTimes, k)
|
|
atomic.AddInt32(&s.handledTaskNum, -1)
|
|
// delete task from database
|
|
s.db.Delete(&model.MidJourneyJob{Id: uint(k)}, "progress < 100")
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *Service) Notify(data CBReq) {
|
|
// extract the task ID
|
|
split := strings.Split(data.Prompt, " ")
|
|
var job model.MidJourneyJob
|
|
res := s.db.Where("message_id = ?", data.MessageId).First(&job)
|
|
if res.Error == nil && data.Status == Finished {
|
|
logger.Warn("重复消息:", data.MessageId)
|
|
return
|
|
}
|
|
|
|
tx := s.db.Session(&gorm.Session{}).Where("progress < ?", 100).Order("id ASC")
|
|
if data.ReferenceId != "" {
|
|
tx = tx.Where("reference_id = ?", data.ReferenceId)
|
|
} else {
|
|
tx = tx.Where("task_id = ?", split[0])
|
|
}
|
|
res = tx.First(&job)
|
|
if res.Error != nil {
|
|
logger.Warn("非法任务:", res.Error)
|
|
return
|
|
}
|
|
|
|
job.ChannelId = data.ChannelId
|
|
job.MessageId = data.MessageId
|
|
job.ReferenceId = data.ReferenceId
|
|
job.Progress = data.Progress
|
|
job.Prompt = data.Prompt
|
|
job.Hash = data.Image.Hash
|
|
job.OrgURL = data.Image.URL
|
|
if s.client.Config.UseCDN {
|
|
job.UseProxy = true
|
|
job.ImgURL = strings.ReplaceAll(data.Image.URL, "https://cdn.discordapp.com", s.client.imgCdnURL)
|
|
}
|
|
|
|
res = s.db.Updates(&job)
|
|
if res.Error != nil {
|
|
logger.Error("error with update job: ", res.Error)
|
|
return
|
|
}
|
|
|
|
if data.Status == Finished {
|
|
// release lock task
|
|
atomic.AddInt32(&s.handledTaskNum, -1)
|
|
}
|
|
|
|
s.notifyQueue.RPush(job.UserId)
|
|
|
|
}
|