feat: midjourney plus service is ready

This commit is contained in:
RockYang 2024-01-11 18:16:48 +08:00
parent e8fff55c42
commit d70035ff0c
13 changed files with 671 additions and 71 deletions

View File

@ -152,6 +152,7 @@ func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc {
c.Request.URL.Path == "/api/role/list" ||
c.Request.URL.Path == "/api/mj/jobs" ||
c.Request.URL.Path == "/api/mj/client" ||
c.Request.URL.Path == "/api/mj/notify" ||
c.Request.URL.Path == "/api/invite/hits" ||
c.Request.URL.Path == "/api/sd/jobs" ||
c.Request.URL.Path == "/api/upload" ||

View File

@ -5,21 +5,23 @@ import (
)
type AppConfig struct {
Path string `toml:"-"`
Listen string
Session Session
ProxyURL string
MysqlDns string // mysql 连接地址
Manager Manager // 后台管理员账户信息
StaticDir string // 静态资源目录
StaticUrl string // 静态资源 URL
Redis RedisConfig // redis 连接信息
ApiConfig ChatPlusApiConfig // ChatPlus API authorization configs
SmsConfig AliYunSmsConfig // AliYun send message service config
OSS OSSConfig // OSS config
MjConfigs []MidJourneyConfig // mj AI draw service pool
WeChatBot bool // 是否启用微信机器人
SdConfigs []StableDiffusionConfig // sd AI draw service pool
Path string `toml:"-"`
Listen string
Session Session
ProxyURL string
MysqlDns string // mysql 连接地址
Manager Manager // 后台管理员账户信息
StaticDir string // 静态资源目录
StaticUrl string // 静态资源 URL
Redis RedisConfig // redis 连接信息
ApiConfig ChatPlusApiConfig // ChatPlus API authorization configs
SmsConfig AliYunSmsConfig // AliYun send message service config
OSS OSSConfig // OSS config
MjConfigs []MidJourneyConfig // mj AI draw service pool
MjPlusConfigs []MidJourneyPlusConfig // MJ plus config
ImgCdnURL string // 图片反代加速地址
WeChatBot bool // 是否启用微信机器人
SdConfigs []StableDiffusionConfig // sd AI draw service pool
XXLConfig XXLConfig
AlipayConfig AlipayConfig
@ -60,7 +62,6 @@ type MidJourneyConfig struct {
ChanelId string // Chanel ID
UseCDN bool
DiscordAPI string
DiscordCDN string
DiscordGateway string
}
@ -71,6 +72,14 @@ type StableDiffusionConfig struct {
Txt2ImgJsonPath string
}
type MidJourneyPlusConfig struct {
Enabled bool // 如果启用了 MidJourney Plus将会自动禁用原生的MidJourney服务
Name string // 服务名称,保持唯一
ApiURL string
ApiKey string
NotifyURL string // 任务进度更新回调地址
}
type AliYunSmsConfig struct {
AccessKey string
AccessSecret string

View File

@ -5,6 +5,7 @@ import (
"chatplus/core/types"
"chatplus/service"
"chatplus/service/mj"
"chatplus/service/mj/plus"
"chatplus/service/oss"
"chatplus/store/model"
"chatplus/store/vo"
@ -203,7 +204,6 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
}
idValue, _ := c.Get(types.LoginUserID)
jobId := 0
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
taskId, _ := h.snowflake.Next(true)
job := model.MidJourneyJob{
@ -221,7 +221,7 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
}
h.pool.PushTask(types.MjTask{
Id: jobId,
Id: int(job.Id),
SessionId: data.SessionId,
Type: types.TaskUpscale,
Prompt: data.Prompt,
@ -251,7 +251,6 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
}
idValue, _ := c.Get(types.LoginUserID)
jobId := 0
userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
taskId, _ := h.snowflake.Next(true)
job := model.MidJourneyJob{
@ -270,7 +269,7 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
}
h.pool.PushTask(types.MjTask{
Id: jobId,
Id: int(job.Id),
SessionId: data.SessionId,
Type: types.TaskVariation,
Prompt: data.Prompt,
@ -340,9 +339,13 @@ func (h *MidJourneyHandler) JobList(c *gin.Context) {
// 正在运行中任务使用代理访问图片
if item.ImgURL == "" && item.OrgURL != "" {
image, err := utils.DownloadImage(item.OrgURL, h.App.Config.ProxyURL)
if err == nil {
job.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
if h.App.Config.ImgCdnURL != "" {
job.ImgURL = strings.ReplaceAll(job.OrgURL, "https://cdn.discordapp.com", h.App.Config.ImgCdnURL)
} else {
image, err := utils.DownloadImage(item.OrgURL, h.App.Config.ProxyURL)
if err == nil {
job.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
}
}
}
}
@ -382,3 +385,24 @@ func (h *MidJourneyHandler) Remove(c *gin.Context) {
resp.SUCCESS(c)
}
// Notify MidJourney Plus 服务任务回调处理
func (h *MidJourneyHandler) Notify(c *gin.Context) {
var data plus.CBReq
if err := c.ShouldBindJSON(&data); err != nil {
logger.Error("非法任务回调:%+v", err)
return
}
err := h.pool.Notify(data)
if err != nil {
logger.Error(err)
} else {
userId := h.GetLoginUserId(c)
client := h.pool.Clients.Get(userId)
if client != nil {
_ = client.Send([]byte("Task Updated"))
}
}
resp.SUCCESS(c)
}

View File

@ -7,8 +7,8 @@ import (
"chatplus/utils"
"chatplus/utils/resp"
"fmt"
"github.com/gin-gonic/gin"
"github.com/imroc/req/v3"
"gorm.io/gorm"
)
@ -22,23 +22,176 @@ func NewTestHandler(db *gorm.DB, snowflake *service.Snowflake, js *payment.PayJS
return &TestHandler{db: db, snowflake: snowflake, js: js}
}
func (h *TestHandler) Test(c *gin.Context) {
//h.initUserNickname(c)
//h.initMjTaskId(c)
type reqBody struct {
BotType string `json:"botType"`
Prompt string `json:"prompt"`
Base64Array []interface{} `json:"base64Array,omitempty"`
AccountFilter struct {
InstanceId string `json:"instanceId"`
Modes []interface{} `json:"modes"`
Remix bool `json:"remix"`
RemixAutoConsidered bool `json:"remixAutoConsidered"`
} `json:"accountFilter,omitempty"`
NotifyHook string `json:"notifyHook"`
State string `json:"state,omitempty"`
}
orderId, _ := h.snowflake.Next(false)
params := payment.JPayReq{
TotalFee: 12345,
OutTradeNo: orderId,
Subject: "支付测试",
type resBody struct {
Code int `json:"code"`
Description string `json:"description"`
Properties struct {
} `json:"properties"`
Result string `json:"result"`
}
func (h *TestHandler) Test(c *gin.Context) {
query(c)
}
func upscale(c *gin.Context) {
apiURL := "https://api.openai1s.cn/mj/submit/action"
token := "sk-QpBaQn9Z5vngsjJaFdDfC9Db90C845EaB5E764578a7d292a"
body := map[string]string{
"customId": "MJ::JOB::upsample::1::c80a8eb1-f2d1-4f40-8785-97eb99b7ba0a",
"taskId": "1704880156226095",
"notifyHook": "http://r9it.com:6004/api/test/mj",
}
r := h.js.Pay(params)
if !r.IsOK() {
resp.ERROR(c, r.ReturnMsg)
var res resBody
var resErr errRes
r, err := req.C().R().
SetHeader("Authorization", "Bearer "+token).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&resErr).
Post(apiURL)
if err != nil {
resp.ERROR(c, "请求出错:"+err.Error())
return
}
resp.SUCCESS(c, r)
if r.IsErrorState() {
resp.ERROR(c, "返回错误状态:"+resErr.Error.Message)
return
}
resp.SUCCESS(c, res)
}
type queryRes struct {
Action string `json:"action"`
Buttons []struct {
CustomId string `json:"customId"`
Emoji string `json:"emoji"`
Label string `json:"label"`
Style int `json:"style"`
Type int `json:"type"`
} `json:"buttons"`
Description string `json:"description"`
FailReason string `json:"failReason"`
FinishTime int `json:"finishTime"`
Id string `json:"id"`
ImageUrl string `json:"imageUrl"`
Progress string `json:"progress"`
Prompt string `json:"prompt"`
PromptEn string `json:"promptEn"`
Properties struct {
} `json:"properties"`
StartTime int `json:"startTime"`
State string `json:"state"`
Status string `json:"status"`
SubmitTime int `json:"submitTime"`
}
func query(c *gin.Context) {
apiURL := "https://api.openai1s.cn/mj/task/1704960661008372/fetch"
token := "sk-QpBaQn9Z5vngsjJaFdDfC9Db90C845EaB5E764578a7d292a"
var res queryRes
r, err := req.C().R().SetHeader("Authorization", "Bearer "+token).
SetSuccessResult(&res).
Get(apiURL)
if err != nil {
resp.ERROR(c, "请求出错:"+err.Error())
return
}
if r.IsErrorState() {
resp.ERROR(c, "返回错误状态:"+r.Status)
return
}
resp.SUCCESS(c, res)
}
type errRes struct {
Error struct {
Message string `json:"message"`
} `json:"error"`
}
func image(c *gin.Context) {
apiURL := "https://api.openai1s.cn/mj-fast/mj/submit/imagine"
token := "sk-QpBaQn9Z5vngsjJaFdDfC9Db90C845EaB5E764578a7d292a"
body := reqBody{
BotType: "MID_JOURNEY",
Prompt: "一个中国美女,手上拿着一桶爆米花,脸上带着迷人的微笑,白色衣服 --s 750 --v 6",
NotifyHook: "http://r9it.com:6004/api/test/mj",
}
var res resBody
var resErr errRes
r, err := req.C().R().
SetHeader("Authorization", "Bearer "+token).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&resErr).
Post(apiURL)
if err != nil {
resp.ERROR(c, "请求出错:"+err.Error())
return
}
if r.IsErrorState() {
resp.ERROR(c, "返回错误状态:"+resErr.Error.Message)
return
}
resp.SUCCESS(c, res)
}
type cbReq struct {
Id string `json:"id"`
Action string `json:"action"`
Status string `json:"status"`
Prompt string `json:"prompt"`
PromptEn string `json:"promptEn"`
Description string `json:"description"`
SubmitTime int64 `json:"submitTime"`
StartTime int64 `json:"startTime"`
FinishTime int64 `json:"finishTime"`
Progress string `json:"progress"`
ImageUrl string `json:"imageUrl"`
FailReason interface{} `json:"failReason"`
Properties struct {
FinalPrompt string `json:"finalPrompt"`
} `json:"properties"`
}
func (h *TestHandler) Mj(c *gin.Context) {
var data cbReq
if err := c.ShouldBindJSON(&data); err != nil {
logger.Error(err)
}
logger.Debugf("任务ID%s,任务进度:%s,图片地址:%s, 最终提示词:%s", data.Id, data.Progress, data.ImageUrl, data.Properties.FinalPrompt)
apiURL := "https://api.openai1s.cn/mj/task/" + data.Id + "/fetch"
token := "sk-QpBaQn9Z5vngsjJaFdDfC9Db90C845EaB5E764578a7d292a"
var res queryRes
_, _ = req.C().R().SetHeader("Authorization", "Bearer "+token).
SetSuccessResult(&res).
Get(apiURL)
fmt.Println(res.State, ",", res.ImageUrl, ",", res.Progress)
}
func (h *TestHandler) initUserNickname(c *gin.Context) {

View File

@ -235,6 +235,7 @@ func main() {
group.POST("variation", h.Variation)
group.GET("jobs", h.JobList)
group.POST("remove", h.Remove)
group.POST("notify", h.Notify)
}),
fx.Invoke(func(s *core.AppServer, h *handler.SdJobHandler) {
group := s.Engine.Group("/api/sd")
@ -367,6 +368,7 @@ func main() {
fx.Provide(handler.NewTestHandler),
fx.Invoke(func(s *core.AppServer, h *handler.TestHandler) {
s.Engine.GET("/api/test", h.Test)
s.Engine.POST("/api/test/mj", h.Mj)
}),
fx.Invoke(func(s *core.AppServer, db *gorm.DB) {
err := s.Run(db)

View File

@ -33,7 +33,7 @@ func NewBot(name string, proxy string, config types.MidJourneyConfig, service *S
// use CDN reverse proxy
if config.UseCDN {
discordgo.SetEndpointDiscord(config.DiscordAPI)
discordgo.SetEndpointCDN(config.DiscordCDN)
discordgo.SetEndpointCDN("https://cdn.discordapp.com")
discordgo.SetEndpointStatus(config.DiscordAPI + "/api/v2/")
bot.MjGateway = config.DiscordGateway + "/"
} else { // use proxy

View File

@ -11,12 +11,13 @@ import (
// MidJourney client
type Client struct {
client *req.Client
Config types.MidJourneyConfig
apiURL string
client *req.Client
Config types.MidJourneyConfig
imgCdnURL string
apiURL string
}
func NewClient(config types.MidJourneyConfig, proxy string) *Client {
func NewClient(config types.MidJourneyConfig, proxy string, imgCdnURL string) *Client {
client := req.C().SetTimeout(10 * time.Second)
var apiURL string
// set proxy URL
@ -29,7 +30,7 @@ func NewClient(config types.MidJourneyConfig, proxy string) *Client {
}
}
return &Client{client: client, Config: config, apiURL: apiURL}
return &Client{client: client, Config: config, apiURL: apiURL, imgCdnURL: imgCdnURL}
}
func (c *Client) Imagine(prompt string) error {

View File

@ -0,0 +1,171 @@
package plus
import (
"chatplus/core/types"
logger2 "chatplus/logger"
"errors"
"fmt"
"github.com/imroc/req/v3"
)
var logger = logger2.GetLogger()
// Client MidJourney Plus Client
type Client struct {
Config types.MidJourneyPlusConfig
}
func NewClient(config types.MidJourneyPlusConfig) *Client {
return &Client{Config: config}
}
type ImageReq struct {
BotType string `json:"botType"`
Prompt string `json:"prompt"`
Base64Array []interface{} `json:"base64Array,omitempty"`
AccountFilter struct {
InstanceId string `json:"instanceId"`
Modes []interface{} `json:"modes"`
Remix bool `json:"remix"`
RemixAutoConsidered bool `json:"remixAutoConsidered"`
} `json:"accountFilter,omitempty"`
NotifyHook string `json:"notifyHook"`
State string `json:"state,omitempty"`
}
type ImageRes struct {
Code int `json:"code"`
Description string `json:"description"`
Properties struct {
} `json:"properties"`
Result string `json:"result"`
}
type ErrRes struct {
Error struct {
Message string `json:"message"`
} `json:"error"`
}
func (c *Client) Imagine(prompt string) (ImageRes, error) {
apiURL := fmt.Sprintf("%s/mj-fast/mj/submit/imagine", c.Config.ApiURL)
body := ImageReq{
BotType: "MID_JOURNEY",
Prompt: prompt,
NotifyHook: c.Config.NotifyURL,
}
var res ImageRes
var errRes ErrRes
r, err := req.C().R().
SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
}
if r.IsErrorState() {
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
}
return res, nil
}
// Upscale 放大指定的图片
func (c *Client) Upscale(index int, messageId string, hash string) (ImageRes, error) {
body := map[string]string{
"customId": fmt.Sprintf("MJ::JOB::upsample::%d::%s", index, hash),
"taskId": messageId,
"notifyHook": c.Config.NotifyURL,
}
apiURL := fmt.Sprintf("%s/mj/submit/action", c.Config.ApiURL)
var res ImageRes
var errRes ErrRes
r, err := req.C().R().
SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
}
if r.IsErrorState() {
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
}
return res, nil
}
// Variation 以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效
func (c *Client) Variation(index int, messageId string, hash string) (ImageRes, error) {
body := map[string]string{
"customId": fmt.Sprintf("MJ::JOB::variation::%d::%s", index, hash),
"taskId": messageId,
"notifyHook": c.Config.NotifyURL,
}
apiURL := fmt.Sprintf("%s/mj/submit/action", c.Config.ApiURL)
var res ImageRes
var errRes ErrRes
r, err := req.C().R().
SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
SetBody(body).
SetSuccessResult(&res).
SetErrorResult(&errRes).
Post(apiURL)
if err != nil {
return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
}
if r.IsErrorState() {
return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
}
return res, nil
}
type QueryRes struct {
Action string `json:"action"`
Buttons []struct {
CustomId string `json:"customId"`
Emoji string `json:"emoji"`
Label string `json:"label"`
Style int `json:"style"`
Type int `json:"type"`
} `json:"buttons"`
Description string `json:"description"`
FailReason string `json:"failReason"`
FinishTime int `json:"finishTime"`
Id string `json:"id"`
ImageUrl string `json:"imageUrl"`
Progress string `json:"progress"`
Prompt string `json:"prompt"`
PromptEn string `json:"promptEn"`
Properties struct {
} `json:"properties"`
StartTime int `json:"startTime"`
State string `json:"state"`
Status string `json:"status"`
SubmitTime int `json:"submitTime"`
}
func (c *Client) QueryTask(taskId string) (QueryRes, error) {
apiURL := fmt.Sprintf("%s/mj/task/%s/fetch", c.Config.ApiURL, taskId)
var res QueryRes
r, err := req.C().R().SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
SetSuccessResult(&res).
Get(apiURL)
if err != nil {
return QueryRes{}, err
}
if r.IsErrorState() {
return QueryRes{}, errors.New("error status:" + r.Status)
}
return res, nil
}

View File

@ -0,0 +1,164 @@
package plus
import (
"chatplus/core/types"
"chatplus/store"
"chatplus/store/model"
"chatplus/utils"
"fmt"
"strings"
"sync/atomic"
"time"
"gorm.io/gorm"
)
// Service MJ 绘画服务
type Service struct {
name string // service name
Client *Client // MJ Client
taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue
db *gorm.DB
maxHandleTaskNum int32 // max task number current service can handle
handledTaskNum int32 // already handled task number
taskStartTimes map[int]time.Time // task start time, to check if the task is timeout
taskTimeout int64
}
func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, maxTaskNum int32, timeout int64, db *gorm.DB, client *Client) *Service {
return &Service{
name: name,
db: db,
taskQueue: taskQueue,
notifyQueue: notifyQueue,
Client: client,
taskTimeout: timeout,
maxHandleTaskNum: maxTaskNum,
taskStartTimes: make(map[int]time.Time, 0),
}
}
func (s *Service) Run() {
logger.Infof("Starting MidJourney job consumer for %s", s.name)
for {
s.checkTasks()
if !s.canHandleTask() {
// current service is full, can not handle more task
// waiting for running task finish
time.Sleep(time.Second * 3)
continue
}
var task types.MjTask
err := s.taskQueue.LPop(&task)
if err != nil {
logger.Errorf("taking task with error: %v", err)
continue
}
// if it's reference message, check if it's this channel's message
if task.ChannelId != "" && task.ChannelId != s.Client.Config.Name {
s.taskQueue.RPush(task)
time.Sleep(time.Second)
continue
}
logger.Infof("%s handle a new MidJourney task: %+v", s.name, task)
var res ImageRes
switch task.Type {
case types.TaskImage:
index := strings.Index(task.Prompt, " ")
res, err = s.Client.Imagine(task.Prompt[index+1:])
break
case types.TaskUpscale:
res, err = s.Client.Upscale(task.Index, task.MessageId, task.MessageHash)
break
case types.TaskVariation:
res, err = s.Client.Variation(task.Index, task.MessageId, task.MessageHash)
}
if err != nil || (res.Code != 1 && res.Code != 22) {
logger.Error("绘画任务执行失败:", err)
// update the task progress
s.db.Model(&model.MidJourneyJob{Id: uint(task.Id)}).UpdateColumn("progress", -1)
// 任务失败,通知前端
s.notifyQueue.RPush(task.UserId)
// restore img_call quota
s.db.Model(&model.User{}).Where("id = ?", task.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls + ?", 1))
// TODO: 任务提交失败,加入队列重试
continue
}
logger.Infof("任务提交成功:%+v", res)
// lock the task until the execute timeout
s.taskStartTimes[task.Id] = time.Now()
atomic.AddInt32(&s.handledTaskNum, 1)
// 更新任务 ID/频道
s.db.Model(&model.MidJourneyJob{}).Where("id = ?", task.Id).UpdateColumns(map[string]interface{}{
"task_id": res.Result,
"channel_id": s.Client.Config.Name,
})
}
}
// check if current service instance can handle more task
func (s *Service) canHandleTask() bool {
handledNum := atomic.LoadInt32(&s.handledTaskNum)
return handledNum < s.maxHandleTaskNum
}
// remove the expired tasks
func (s *Service) checkTasks() {
for k, t := range s.taskStartTimes {
if time.Now().Unix()-t.Unix() > s.taskTimeout {
delete(s.taskStartTimes, k)
atomic.AddInt32(&s.handledTaskNum, -1)
// delete task from database
s.db.Delete(&model.MidJourneyJob{Id: uint(k)}, "progress < 100")
}
}
}
type CBReq struct {
Id string `json:"id"`
Action string `json:"action"`
Status string `json:"status"`
Prompt string `json:"prompt"`
PromptEn string `json:"promptEn"`
Description string `json:"description"`
SubmitTime int64 `json:"submitTime"`
StartTime int64 `json:"startTime"`
FinishTime int64 `json:"finishTime"`
Progress string `json:"progress"`
ImageUrl string `json:"imageUrl"`
FailReason interface{} `json:"failReason"`
Properties struct {
FinalPrompt string `json:"finalPrompt"`
} `json:"properties"`
}
func (s *Service) Notify(data CBReq, job model.MidJourneyJob) error {
job.Progress = utils.IntValue(strings.Replace(data.Progress, "%", "", 1), 0)
job.Prompt = data.Properties.FinalPrompt
if data.ImageUrl != "" {
job.OrgURL = data.ImageUrl
}
job.UseProxy = true
job.MessageId = data.Id
logger.Debugf("JOB: %+v", job)
res := s.db.Updates(&job)
if res.Error != nil {
return fmt.Errorf("error with update job: %v", res.Error)
}
if data.Status == "SUCCESS" {
// release lock task
atomic.AddInt32(&s.handledTaskNum, -1)
}
s.notifyQueue.RPush(job.UserId)
return nil
}

View File

@ -2,11 +2,13 @@ package mj
import (
"chatplus/core/types"
"chatplus/service/mj/plus"
"chatplus/service/oss"
"chatplus/store"
"chatplus/store/model"
"fmt"
"github.com/go-redis/redis/v8"
"strings"
"time"
"gorm.io/gorm"
@ -14,7 +16,7 @@ import (
// ServicePool Mj service pool
type ServicePool struct {
services []*Service
services []interface{}
taskQueue *store.RedisQueue
notifyQueue *store.RedisQueue
db *gorm.DB
@ -23,37 +25,53 @@ type ServicePool struct {
}
func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager, appConfig *types.AppConfig) *ServicePool {
services := make([]*Service, 0)
services := make([]interface{}, 0)
taskQueue := store.NewRedisQueue("MidJourney_Task_Queue", redisCli)
notifyQueue := store.NewRedisQueue("MidJourney_Notify_Queue", redisCli)
// create mj client and service
for k, config := range appConfig.MjConfigs {
for k, config := range appConfig.MjPlusConfigs {
if config.Enabled == false {
continue
}
// create mj client
client := NewClient(config, appConfig.ProxyURL)
name := fmt.Sprintf("MjService-%d", k)
// create mj service
service := NewService(name, taskQueue, notifyQueue, 4, 600, db, client)
botName := fmt.Sprintf("MjBot-%d", k)
bot, err := NewBot(botName, appConfig.ProxyURL, config, service)
if err != nil {
continue
}
err = bot.Run()
if err != nil {
continue
}
// run mj service
client := plus.NewClient(config)
name := fmt.Sprintf("MidJourney Plus Service-%d", k)
servicePlus := plus.NewService(name, taskQueue, notifyQueue, 10, 600, db, client)
go func() {
service.Run()
servicePlus.Run()
}()
services = append(services, servicePlus)
}
services = append(services, service)
if len(services) == 0 {
// create mj client and service
for k, config := range appConfig.MjConfigs {
if config.Enabled == false {
continue
}
// create mj client
client := NewClient(config, appConfig.ProxyURL, appConfig.ImgCdnURL)
name := fmt.Sprintf("MjService-%d", k)
// create mj service
service := NewService(name, taskQueue, notifyQueue, 4, 600, db, client)
botName := fmt.Sprintf("MjBot-%d", k)
bot, err := NewBot(botName, appConfig.ProxyURL, config, service)
if err != nil {
continue
}
err = bot.Run()
if err != nil {
continue
}
// run mj service
go func() {
service.Run()
}()
services = append(services, service)
}
}
return &ServicePool{
@ -94,7 +112,24 @@ func (p *ServicePool) DownloadImages() {
// download images
for _, v := range items {
imgURL, err := p.uploaderManager.GetUploadHandler().PutImg(v.OrgURL, true)
if v.OrgURL == "" {
continue
}
var imgURL string
var err error
if v.UseProxy {
if servicePlus := p.getServicePlus(v.ChannelId); servicePlus != nil {
task, _ := servicePlus.Client.QueryTask(v.TaskId)
if task.ImageUrl != "" {
imgURL, err = p.uploaderManager.GetUploadHandler().PutImg(task.ImageUrl, false)
}
if len(task.Buttons) > 0 {
v.Hash = getImageHash(task.Buttons[0].CustomId)
}
}
} else {
imgURL, err = p.uploaderManager.GetUploadHandler().PutImg(v.OrgURL, true)
}
if err != nil {
logger.Error("error with download image: ", err)
continue
@ -125,3 +160,37 @@ func (p *ServicePool) PushTask(task types.MjTask) {
func (p *ServicePool) HasAvailableService() bool {
return len(p.services) > 0
}
func (p *ServicePool) Notify(data plus.CBReq) error {
logger.Infof("收到任务回调:%+v", data)
var job model.MidJourneyJob
res := p.db.Where("task_id = ?", data.Id).First(&job)
if res.Error != nil {
return fmt.Errorf("非法任务:%s", data.Id)
}
if servicePlus := p.getServicePlus(job.ChannelId); servicePlus != nil {
return servicePlus.Notify(data, job)
}
return nil
}
func (p *ServicePool) getServicePlus(name string) *plus.Service {
for _, s := range p.services {
if servicePlus, ok := s.(*plus.Service); ok {
if servicePlus.Client.Config.Name == name {
return servicePlus
}
}
}
return nil
}
func getImageHash(action string) string {
split := strings.Split(action, "::")
if len(split) > 5 {
return split[4]
}
return split[len(split)-1]
}

View File

@ -58,8 +58,6 @@ func (s *Service) Run() {
// if it's reference message, check if it's this channel's message
if task.ChannelId != "" && task.ChannelId != s.client.Config.ChanelId {
s.taskQueue.RPush(task)
s.db.Model(&model.MidJourneyJob{Id: uint(task.Id)}).UpdateColumn("progress", -1)
s.notifyQueue.RPush(task.UserId)
time.Sleep(time.Second)
continue
}
@ -143,7 +141,7 @@ func (s *Service) Notify(data CBReq) {
job.OrgURL = data.Image.URL
if s.client.Config.UseCDN {
job.UseProxy = true
job.ImgURL = strings.ReplaceAll(data.Image.URL, "https://cdn.discordapp.com", s.client.Config.DiscordCDN)
job.ImgURL = strings.ReplaceAll(data.Image.URL, "https://cdn.discordapp.com", s.client.imgCdnURL)
}
res = s.db.Updates(&job)

View File

@ -1,5 +1,12 @@
package main
func main() {
import (
"fmt"
"strings"
)
func main() {
str := "7151109597841850368 一个漂亮的中国女孩,手上拿着一桶爆米花,脸上带着迷人的微笑,电影效果"
index := strings.Index(str, " ")
fmt.Println(str[index+1:])
}

View File

@ -700,6 +700,7 @@ const generate = () => {
//
const upscale = (index, item) => {
console.log(item)
send('/api/mj/upscale', index, item)
}