goploy/internal/model/model.go
2024-01-19 15:16:01 +08:00

191 lines
3.6 KiB
Go

// Copyright 2022 The Goploy Authors. All rights reserved.
// Use of this source code is governed by a GPLv3-style
// license that can be found in the LICENSE file.
package model
import (
"database/sql"
"errors"
"fmt"
"github.com/hashicorp/go-version"
"github.com/zhenorzz/goploy/config"
"github.com/zhenorzz/goploy/database"
"github.com/zhenorzz/goploy/internal/pkg"
"path"
"sort"
"strings"
)
// Pagination holds paging values: the requested Page, the number of Rows per
// page, and the Total count. Field tags name the keys used for JSON encoding
// and for `schema` decoding.
type Pagination struct {
	Page  uint64 `json:"page" schema:"page"`
	Rows  uint64 `json:"rows" schema:"rows"`
	Total uint64 `json:"total" schema:"total"`
}

// Outcome states: Fail is 0, Success is 1.
const (
	Fail    = 0
	Success = 1
)

// On/off states: Disable is 0, Enable is 1.
const (
	Disable = 0
	Enable  = 1
)

// Review states: PENDING is 0, APPROVE is 1, DENY is 2.
const (
	PENDING = 0
	APPROVE = 1
	DENY    = 2
)
// SQLRunner is the package's database handle. It embeds *sql.DB, so every
// sql.DB method is callable on it directly, and config.BaseObserver, which
// presumably lets it satisfy the observer interface expected by the config
// event bus (Init subscribes DB to that bus).
type SQLRunner struct {
	*sql.DB
	config.BaseObserver
}

// OnChange is the observer callback. It reconnects by calling connectDB,
// which opens a new connection from the current config and stores it in the
// package-level DB variable. Any connection error is returned unchanged.
func (db *SQLRunner) OnChange() error {
	if err := connectDB(); err != nil {
		return err
	}
	return nil
}
// DB is the package-wide database handle. It starts out as an empty SQLRunner
// and is replaced by connectDB whenever a connection is (re)established, so
// callers should read DB at use time rather than caching the pointer.
var DB = &SQLRunner{}

// Init is meant to be called once at program start. It connects to the
// database configured in config.Toml.DB (via connectDB) and panics if that
// first connection fails. It then registers DB on the config event bus under
// config.DBEventTopic. Presumably the bus calls DB.OnChange when the DB config
// changes; confirm in the config package.
func Init() {
	err := connectDB()
	if err != nil {
		panic(err)
	}
	// Subscribe only after connectDB so the registered observer is the freshly
	// installed runner rather than the empty placeholder.
	bus := config.GetEventBus()
	bus.Subscribe(config.DBEventTopic, DB)
}
// connectDB opens a new connection pool from config.Toml.DB and installs it as
// the package-level DB.
//
// On failure DB is left untouched and the error from Open is returned.
//
// On success, the previously installed pool (if any) is closed. Without this,
// every reconnect triggered by SQLRunner.OnChange would leak the old pool and
// its connections. Per the database/sql docs, DB.Close prevents new queries on
// the old pool but lets queries that have already started finish.
func connectDB() error {
	runner, err := Open(config.Toml.DB)
	if err != nil {
		return err
	}
	previous := DB
	DB = runner
	// The very first DB is an empty &SQLRunner{} whose embedded *sql.DB is nil,
	// so guard before closing.
	if previous != nil && previous.DB != nil && previous != runner {
		// Best effort: failing to close the stale pool must not fail the
		// reconnect, because the new pool is already installed.
		_ = previous.DB.Close()
	}
	return nil
}
// Open builds a "user:password@(host:port)/database" DSN from dbConfig and
// opens a connection pool with the driver named by dbConfig.Type. It then
// verifies connectivity with Ping.
//
// sql.Open may only validate its arguments without contacting the server, so
// Ping is what confirms the database is actually reachable. If Ping fails, the
// pool is closed before returning so it is not leaked.
//
// The returned SQLRunner has a zero-value BaseObserver.
//
// NOTE(review): the DSN uses go-sql-driver/mysql syntax. "charset=utf8mb4,utf8"
// is presumably a fallback list (the driver tries each charset in order).
// User and password are interpolated verbatim.
func Open(dbConfig config.DBConfig) (*SQLRunner, error) {
	dsn := fmt.Sprintf(
		"%s:%s@(%s:%s)/%s?charset=utf8mb4,utf8",
		dbConfig.User,
		dbConfig.Password,
		dbConfig.Host,
		dbConfig.Port,
		dbConfig.Database,
	)
	// @see https://github.com/go-sql-driver/mysql/wiki/Examples#a-word-on-sqlopen
	db, err := sql.Open(dbConfig.Type, dsn)
	if err != nil {
		return nil, err
	}
	// Ping the db to make sure it is actually connected.
	if err := db.Ping(); err != nil {
		// Release the pool sql.Open allocated; the Ping error is the one
		// worth reporting, so a Close error is deliberately ignored.
		_ = db.Close()
		return nil, err
	}
	return &SQLRunner{DB: db}, nil
}
// CreateDB creates the database called name if it does not already exist.
// The name is emitted as a back-quoted identifier with embedded back-quotes
// escaped (see quoteSQLIdentifier), so it cannot break out of the identifier.
// Returns the error from Exec, if any.
func (db *SQLRunner) CreateDB(name string) error {
	_, err := db.Exec("CREATE DATABASE IF NOT EXISTS " + quoteSQLIdentifier(name))
	return err
}

// UseDB issues "USE <name>" with name quoted like CreateDB does. Returns the
// error from Exec, if any.
//
// NOTE(review): USE changes the default database of a single connection only.
// *sql.DB is a connection pool, so later statements may run on a different
// connection that never executed this USE. Confirm callers rely only on the
// pool reusing the same idle connection.
func (db *SQLRunner) UseDB(name string) error {
	_, err := db.Exec("USE " + quoteSQLIdentifier(name))
	return err
}

// quoteSQLIdentifier wraps name in back-quotes for use as a MySQL identifier.
// Any back-quote inside name is doubled, which is MySQL's escape for the
// quote character within a quoted identifier.
func quoteSQLIdentifier(name string) string {
	return "`" + strings.ReplaceAll(name, "`", "``") + "`"
}
// ImportSQL reads the embedded file sqlPath from database.File and executes
// its statements one by one.
//
// The content is split on every ';'. Statements therefore must not contain ';'
// inside string literals or comments. Each piece is passed through
// pkg.ClearNewline (presumably stripping line breaks), and empty pieces are
// skipped.
//
// Execution stops at the first failing statement and returns its error.
// Statements that already ran are not rolled back; no transaction is used.
func (db *SQLRunner) ImportSQL(sqlPath string) error {
	raw, err := database.File.ReadFile(sqlPath)
	if err != nil {
		return err
	}
	statements := strings.Split(string(raw), ";")
	for i := range statements {
		stmt := pkg.ClearNewline(statements[i])
		if stmt == "" {
			continue
		}
		if _, execErr := db.Exec(stmt); execErr != nil {
			return execErr
		}
	}
	return nil
}
// Update migrates the database schema from the version stored in the
// "version" system-config row to targetVerStr. A missing (empty) stored
// version is treated as "0.0.1".
//
// Every embedded SQL file in database.File is a candidate migration if:
//   - its name ends in database.FileExt, and
//   - the rest of the name parses as a version v with current < v <= target.
//
// Matching files are executed via DB.ImportSQL in ascending version order.
// Afterwards the stored version is set to targetVerStr.
//
// Files are opened by their real name, not by re-rendering the parsed
// version. go-version normalises when formatting (a file named "1.2.sql"
// would otherwise be looked up as "1.2.0.sql").
//
// Return values:
//   - nil, without doing anything, when the stored version already equals the
//     target;
//   - an error when the stored version is newer than the target;
//   - the first migration error, if any. In that case the stored version is
//     left unchanged even though earlier migrations may already have run.
func Update(targetVerStr string) error {
	systemConfig, err := SystemConfig{
		Key: "version",
	}.GetDataByKey()
	if err != nil {
		return err
	}
	if systemConfig.Value == "" {
		systemConfig.Value = "0.0.1"
	}
	currentVer, err := version.NewVersion(systemConfig.Value)
	if err != nil {
		return err
	}
	targetVer, err := version.NewVersion(targetVerStr)
	if err != nil {
		return err
	}
	switch currentVer.Compare(targetVer) {
	case 0:
		return nil
	case 1:
		return errors.New("currentVer greater than targetVer")
	}

	sqlEntries, err := database.File.ReadDir(".")
	if err != nil {
		return err
	}

	// migration pairs a parsed version with the embedded file it came from.
	type migration struct {
		ver      *version.Version
		filename string
	}
	var migrations []migration
	for _, entry := range sqlEntries {
		filename := entry.Name()
		if path.Ext(filename) != database.FileExt {
			continue
		}
		ver, err := version.NewVersion(strings.TrimSuffix(filename, database.FileExt))
		if err != nil {
			// Not a versioned migration file; ignore it.
			continue
		}
		migrations = append(migrations, migration{ver: ver, filename: filename})
	}
	sort.Slice(migrations, func(i, j int) bool {
		return migrations[i].ver.LessThan(migrations[j].ver)
	})

	for _, m := range migrations {
		if currentVer.LessThan(m.ver) && targetVer.GreaterThanOrEqual(m.ver) {
			if err := DB.ImportSQL(m.filename); err != nil {
				return err
			}
		}
	}

	systemConfig.Value = targetVerStr
	if err := systemConfig.EditRowByKey(); err != nil {
		return err
	}
	// Report success only once the new version has actually been persisted.
	println(`Update app success`)
	return nil
}